[小ネタ]pytestのmark.parametrizeでサブテストに簡単に名前をつける方法

pytestの[mark.parametrize](https://docs.pytest.org/en/2.8.7/parametrize.html)でサブテストに簡単に名前をつける方法をご紹介します。
2019.10.17

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

はじめに

おはようございます、加藤です。pytestのmark.parametrizeでサブテストに簡単に名前をつける方法をご紹介します。

説明

説明様に、引数同士を足すだけの簡単な関数を用意します。

src/main.py

def sum(x, y):
    return x + y

テストコードは下記の様に書きます。

tests/test.py

import pytest
from src.main import sum

params = {
    "normal 1": (1, 2, 3),
    "normal 2": (3, 4, 7)
}


@pytest.mark.parametrize("x, y, want", list(params.values()), ids=list(params.keys()))
def test(x, y, want):
    assert want == sum(x, y)

テストパラメータは、辞書型で定義します。Keyにテスト名を、Valueにタプル型で入力と期待する出力を定義します。

"normal 1": (1, 2, 3) の場合は、 "normal 1" がテスト名で、 1, 2 が入力、 3 が期待する出力です。

params = {
    "normal 1": (1, 2, 3),
    "normal 2": (3, 4, 7)
}

@pytest.mark.parametrize の第一引数で、入力及び出力に名前を付けます。先程テストパラメータで定義したタプル型の順序どおりにします。
list(<dict>.keys()), list(<dict>.values())で、辞書のKeyのみ、Valueのみをリスト型で取得できます。第二引数に list(<dict>.values()) を、引数idslist(<dict>.keys()) を定義する事で、テストパラメータの名前と値を適合させます。

@pytest.mark.parametrize("x, y, want", list(params.values()), ids=list(params.keys()))

参考例を1つ紹介します。
API Gateway → Lambda 構成のAPIを想定しています。テストパラメータの入力は pathParameters で、出力は statusCodeBody です。
紹介した内容を使うと、こんな感じでテストが書けます。
英語がだいぶ変だと思いますが、絶賛勉強中なので、許してください。

import pytest
import json

// 色々インポート

@pytest.mark.usefixtures('start_moto_mock', 'create_gg_group', 'create_lambda_function',
                         'create_rdb_records_for_test', 'create_enabled_failure_notification_email_records')
class TestHogeHoge:
    params_for_res_ok = {
        'have child': (
            {
                'pathParameters': {
                    'id': '1',
                },
            },
            200,
            json.dumps(
                {
                    'id': 1,
                    'name': 'TEST NAME',
                    'chlid': {
                            'id': 2,
                            'name': 'CHILD NAME',
                    }
                }
            )
        ),
        'don\'t have child': (
            {
                'pathParameters': {
                    'id': '1',
                },
            },
            200,
            json.dumps(
                {
                    'id': 1,
                    'name': 'TEST NAME',
                    'chlid': {}
                }
            )
        )
    }

    params_for_res_not_found = {
        'not found': (
            {
                'pathParameters': {
                    'id': '9999',
                },
            },
            404,
            json.dumps(
                {
                    'errors': [
                        {
                            'error_code': 'E404001',
                            'message': 'the specified object was not found'
                        }
                    ]
                }
            )
        ),
        // 権限不一致によるエラーだが存在を非権限者に伝えないようにNotFoundを返す
        'permission error': (
            {
                'pathParameters': {
                    'id': '8888',
                },
            },
            404,
            json.dumps(
                {
                    'errors': [
                        {
                            'error_code': 'E404001',
                            'message': 'the specified object was not found'
                        }
                    ]
                }
            )
        )
    }

    @pytest.mark.parametrize('event, status_code, body', list(params_for_res_ok.values()),
                             ids=list(params_for_res_ok.keys()))
    def test_res_ok(self, event, status_code, body):
        """
        正常系: 200 OK
        :param event:
        :param status_code:
        :param body:
        :return:
        """
        res = handler(event, {})

        assert res['statusCode'] == status_code
        assert res['body'] == body

    @pytest.mark.parametrize('event, status_code, body', list(params_for_res_not_found.values()),
                             ids=list(params_for_res_not_found.keys()))
    def test_res_not_found(self, event, status_code, body):
        """
        異常系: 404 Not Found
        :param event:
        :param status_code:
        :param body:
        :return:
        """
        res = handler(event, {})

        assert res['statusCode'] == status_code
        assert res['body'] == body

あとがき

今回の例が特にわかりやすいのですが、正常系と異常系のコードが完全に一致しているので、まとめようと思えばまとめることができます。 だからといって、まとめ過ぎると扱いにくくなるので気をつけましょう!!(レビューで指摘を受けたので自戒を込めて)

参考元