Flaskでデコレータを使うときの落とし穴

2024.06.03

はじめに

Pythonにはデコレータという機能があります。デコレータを使うことで、以下のようなシンタックスシュガーを使うことができ、簡潔に共通処理なんかを書くことができます。

def sample_decorator(f):
    def _wrapper(*args, **keywords):
        # 前処理
        print("pre process")
        _f = f(*args, **keywords)
        # 後処理
        return _f

    return _wrapper

@sample_decorator # hello = sample_decorator(hello)と同じ
def hello():
    print("Hello World")

さて、このデコレータですが、Flaskで使おうとすると落とし穴にハマることがあります。この記事では、落とし穴の回避方法を紹介します。

Flaskでデコレータを使うユースケース

Flaskは強力なWebAppフレームワークなので、大抵のことはデフォルトの機能で実現できます。僕は昔、なんちゃって権限管理をするためにデコレータを使いました。ほかにもLoggerを拡張して使いたいとか、リクエストを整形したりしたいときなんかに使えそうです。

落とし穴

以下のようなFlask appを作成しました。それぞれのエンドポイントにアクセスがあった際に、アクセスされたよ!と書き出すだけのデコレータをつけています。

from flask import Flask, jsonify

app = Flask(__name__)


def sample_decorator(f):
    def _wrapper(*args, **keywords):
        # エンドポイントにアクセスされたらアクセスされたよ!と出力する
        print("アクセスされたよ!")
        _f = f(*args, **keywords)
        return _f
    return _wrapper


@app.route("/")
@sample_decorator
def hello_world():
    return jsonify({"message": "hello world"}), 200


@app.route("/not-found")
@sample_decorator
def not_found():
    return jsonify({"message": "Not Found"}), 404

一見すると動いてくれそうですが、このコードではサーバが立ち上がりません。

AssertionError: View function mapping is overwriting an existing endpoint function: _wrapper

_wrapperというview関数が重複しているっぽいですね。次の節で詳しく調べてみます。お急ぎの方は読み飛ばしてください。

エラーの原因を探索

前提として、デコレータにはとある重要な仕様があります。

def sample():
    pass


# sample関数の関数名を確認
print(sample.__name__)


def sample_decorator(f):
    def _wrapper(*args, **keywords):
        return f(*args, **keywords)

    return _wrapper


@sample_decorator
def sample2():
    pass


# sample2関数の関数名を確認...しているはず
print(sample2.__name__)

このコードの実行結果は以下のようになります。

sample
_wrapper

おや。。。デコレータをつけている方は_wrapperが出力されていますね。これはデコレータで糖衣しなかったコードを見ると理由がわかります。

def sample_decorator(f):
    def _wrapper(*args, **keywords):
        return f(*args, **keywords)

    return _wrapper

def sample2():
    pass

sample2 = sample_decorator(sample2)
print(sample2.__name__)

デコレータとして使っていると意識しづらいですが、元の関数オブジェクトをデコレータ内のラッパーで置き換えています。これで、さっきのエラーメッセージの意味がわかりました。

解決方法

assertがかかっているのは関数名に対してなので、デコレータの内部で_wrapperの関数名を被デコレート関数の関数名で置き換えてやれば良さそうです。

from flask import Flask, jsonify

app = Flask(__name__)


def sample_decorator(f):
    def _wrapper(*args, **keywords):
        # エンドポイントにアクセスされたらアクセスされたよ!と出力する
        print("アクセスされたよ!")
        _f = f(*args, **keywords)
        return _f

    _wrapper.__name__ = f.__name__ # ここ!
    return _wrapper


@app.route("/")
@sample_decorator
def hello_world():
    return jsonify({"message": "hello world"}), 200


@app.route("/not-found")
@sample_decorator
def not_found():
    return jsonify({"message": "Not Found"}), 404

こうして関数名を偽装することで、意図した通りの動作をするようになりました。

おまけ

せっかくなのでFlaskで提供されているデコレータがこの部分をどうやって回避しているか確認してみます。コードを辿ると以下の関数を見つけました。

WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__', '__type_params__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))

    wrapper.__wrapped__ = wrapped

    return wrapper

※FlaskはBSD-3ライセンスです。

これを読む限り、nameを含めたいくつかの属性をデコレータのラッパーに移植しています。概ね同じやり方で回避しているので問題なさそうです。もし解決法に記載したやり方でうまくいかないときはここに列挙されている属性を全部コピーしてやるといいはずです。

終わりに

引数つきのデコレータでも要領は変わりません。

ちなみに僕はFlaskをフラスコと読む派です。