破局的忘却に立ち向かうEWCを試してみた

破局的忘却とは何か、それを回避するにはどうしたら良いかをご紹介します。
2022.08.09

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

皆さん、こんにちは。hotoke_nekoです。

今回はニューラルネットワークが抱える破局的忘却に関して簡単な説明をします。その後、回避手段の一つElastic Weight Consolidation(以下、EWC)を実際に使って回避できるのか確認してみました。

破局的忘却とは?

一言で言うと、今までの学習した結果を忘れてしまう事です。 例えば、犬や猫を分類できるように学習したモデルで、鳥を分類できるように学習させます。 そうすると、鳥を分類できるようにはなったものの、元の精度で犬や猫を分類できなくなる、という現象です。

解決策としては、新しいデータを含めた今までの全てのデータを使って再度学習させます。 この解決方法の場合、「新しく何かを学習させたい」とデータを用意する度に学習に時間がかかります。

そこで、破局的忘却に対処するEWCを使ってみましょう。

EWCの発想

こちらの論文で紹介されています。

一言で説明すると、以前学習時のモデルのパラメータを忘れないように新しいデータを学習しよう、というものです。

次の図を見てください。(上述したリンクにある論文から引用。)

EWC Image

EWCは先に学習していたtaskAと次に学習するtaskBの両方にとって良いモデルのパラメータを探索するようにしています。

つまり、EWCはtaskAとtaskBで共通するlow errorを目指しています。 low errorであるという事はモデルのパラメータの値が適切であるため、予測の精度が高いという事です。

やってみた

OSSでEWCを実装しているコードを発見したので、こちらのコード例や実装を参考にしてEWCを使うと破局的忘却に対処できるのか確かめてみました。なお、Pytorchが使用されています。

1.モデル準備

同じ構成のモデルを2つ用意します。(ElasticWeightConsolidationクラスやBaseModelクラスの定義については上記リンク先と同様です。)

変数ewcはEWCを使って学習後にパラメータの保持をするように設定します。(次の学習時でもパラメータをできるだけ維持しつつ学習します。)

変数non_ewcは学習後にパラメータ保持の設定をしません。(学習の度にパラメータの調整をデータセットに合わせて行っていきます。)

ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=crit, lr=1e-4)
non_ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=crit, lr=1e-4)

2.データ準備

MNISTのデータを準備します。MNISTは手書きの数字を集めたデータセットです。

mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

3.学習

MNISTのデータで学習を行います。学習回数は解説の為のミニマムな実装として4回です。

for _ in range(4):
    for input, target in tqdm(train_loader):
        ewc.forward_backward_update(input, target)
        non_ewc.forward_backward_update(input, target)

4.パラメータを維持するよう設定

次のコードで学習後のパラメータを維持するように設定します。

ewc.register_ewc_params(mnist_train, 100, 300)

繰り返し

上述の「2.データ準備」から「4.パラメータを維持するよう設定」までを、次はFashionMNISTで実行しました。 FashionMNISTは服や靴などの画像が集まったデータセットです。

コードを示すと以下のようになります。

データの準備をします。

f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

学習させます。MNISTと同じく学習回数は4回です。

for _ in range(4):
    for input, target in tqdm(f_train_loader):
        ewc.forward_backward_update(input, target)
        non_ewc.forward_backward_update(input, target)

学習後にモデルにパラメータを維持するよう設定しておきます。

ewc.register_ewc_params(f_mnist_train, 100, 300)

学習結果

それぞれのモデルにおけるデータセットでのaccuracy(正解率)を確認しました。

モデル MNIST FashionMNIST
ewc 0.7145 0.8198
non_ewc 0.3658 0.8635

特徴的なのは、最初に学習したMNISTでは変数non_ewcのaccuracyが四捨五入しても4割ほどに対して、変数ewcでは7割ほどと維持できている点です。

FashionMNISTでは、non_ewcの方が高い正解率です。これは新しく学習したデータセットに沿って学習を行えたからだと考えられます。 ただし、ewcの方も同じく8割の正解率なので、正解率に大きい差はありません。

今回試してみて、新しいデータセットで学習すると破局的忘却が起こり実際にaccuracyなど目に見える結果として差が出てきました。そのような場面においてEWCを活用すると、副作用少なく学習結果を維持できました。

終わりに

今回は簡単に破局的忘却についての説明を行い、それを回避するEWCについてお伝えしました。

pre-trainedなモデルを使う場合にも起こりうる可能性があるため、うまくいかない場合はこの記事を思い出していただければ幸いです。

今回はここまで。

それでは、また!