DistilBERTの日本語事前学習済みモデルを使ったりFine-Tuningもしてみる

2020.05.13

こんにちは、Mr.Moです。

先日、バンダイナムコ研究所様から日本語Wikipediaで学習したDistilBERTのモデルが公開されましたね。気になっていたのでさっそく使わせていただきました!(ちなみに拙者、BERT自体あまり使ったことがありません。内容を把握した程度^^;)

BERTに関しては下記の記事をご参照ください!

DistilBERTとは

DistilBERTはHuggingface が NeurIPS 2019 に公開したモデルで、名前は「Distilated-BERT」の略となります。投稿された論文はこちらをご参考ください。 DistilBERTはBERTアーキテクチャをベースにした、小さくて、速くて、軽いTransformerモデルです。DistilBERTは、BERT-baseよりもパラメータが40%少なく、60%高速に動作し、GLUE Benchmarkで測定されたBERTの97%の性能を維持できると言われています。 DistilBERTは、教師と呼ばれる大きなモデルを生徒と呼ばれる小さなモデルに圧縮する技術である知識蒸留を用いて訓練されます。BERTを蒸留することで、元のBERTモデルと多くの類似点を持ちながら、より軽量で実行速度が速いTransformerモデルを得ることができます。

上記は公開していただいている文章をそのまま持ってきたものです。DistilBERTは性能をできるだけ維持した軽量BERTという認識です。

使ってみる

公式ページの説明を参考に作業していきます。
ここから先は自分の作業メモみたいな内容です、ご了承ください。m(_ _)m

前提

前提としてSageMaker上で動かしています。

フォルダ構成は下記です。

.
├── DistilBERT.ipynb
├── DistilBERT-base-jp.zip
├── input.txt
├── output.tsv
├── data
│   ├── dev.tsv
│   └── train.tsv
├── DistilBERT-base-jp
│   ├── docs
│   ├── config.json
│   ├── pytorch_model.bin
│   ├── README.md
│   └── vocab.txt
└── out
    ├── config.json
    ├── eval_results.txt
    ├── pytorch_model.bin
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    ├── training_args.bin
    └── vocab.txt

必要なモジュールのインストール

必要なモジュールをpip installでインストールします。

!pip install --upgrade pip
!pip install -r requirements.txt
  • requirements.txt
torch>=1.3.1
torchvision>=0.4.2
transformers>=2.8.0
tensorboard>=1.14.0
tensorboardX==1.8
scikit-learn>=0.21.0
requests
mecab-python3

DistilBERTの日本語事前学習モデルの読み込み

今回はモデルをダウンロードして解凍したものを使用します。

まずモデルのダウンロードと解凍。

!wget https://github.com/BandaiNamcoResearchInc/DistilBERT-base-jp/releases/download/1.0/DistilBERT-base-jp.zip
!unzip DistilBERT-base-jp.zip

下記のコードでダウンロードしたモデルを読み込みます。

from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-japanese-whole-word-masking")   
model = AutoModel.from_pretrained("./DistilBERT-base-jp")

文章の埋め込み表現ベクトルの取得

入力用のデータは下記です。input.txtに保存しておきます。

  • input.txt
こんにちは
おはよう
こんばんは
お腹すいた
ご飯食べたい
明日晴れてると良いな
明日の天気はどうだろうか
雨降ったら嫌やな

下記の関数を作成して、埋め込み表現ベクトルを取得します。

import torch
def get_embedding(model, tokenizer, text):
  tokenized_text = tokenizer.tokenize(text)
  tokenized_text.insert(0, '[CLS]')
  tokenized_text.append('[SEP]')
  tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
  tokens_tensor = torch.tensor([tokens])
  model.eval()
  with torch.no_grad():
      layers, _ = model(tokens_tensor)
  target_layer = -2
  embedding = layers[0][target_layer].numpy()
  return embedding

さらに下記のコードで取得した埋め込み表現ベクトルの情報をtsvファイルに保存しています。

import numpy as np
embedding_list = []
f = open('./input.txt')
sentens = f.readlines()
f.close()

for s in sentens:
  mbedding = get_embedding(model, tokenizer, s.strip())
  embedding_list.append(mbedding)
np.savetxt('./output.tsv', embedding_list, delimiter='\t')

文章ベクトルの可視化

上記で取得した文章の埋め込み表現ベクトルをEmbedding Projectorを使って可視化します。output.tsvinput.txtをLoadさせT-SNEで次元圧縮して3次元にマッピングさせた結果が下記ですが、似た文章が近くに集まってるっぽいですかね。

image.png

自前のデータセットでFine-Tuning

ここからは独自で用意したデータセットでFine-Tuningして、分類タスクを実施します。

まずtransformersのコードをベースに修正するので、GitHubからcloneしてきます。

!git clone https://github.com/huggingface/transformers.git -b v2.8.0

競合を避けるために pip uninstall transformersしてインストールしていたtransformersをアンインストールしておきます。

!pip uninstall transformers -y

下記のファイルに今回用のコードを追加します。Customizeと名付けて関連の処理を追加しています。

  • transformers/src/transformers/data/processors/glue.py
class CustomizeProcessor(DataProcessor):
    """Processor for the original data set."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        print(lines)
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
    "customize": 2, #add
}

glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
    "customize": CustomizeProcessor, #add
}

glue_output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
    "customize": "classification", #add
}

さらに下記のコードも追加します。今回修正した上記のコードをimport(transformers)で使うようにするためですね。

import sys
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))

次に下記のファイルに評価用のメトリクスを追加します。

  • transformers/src/transformers/data/metrics/__init__.py
def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "hans":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "customize":                   #add
            return {"acc": acc_and_f1(preds, labels)}   #add
        else:
            raise KeyError(task_name)

    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        # f1 = f1_score(y_true=labels, y_pred=preds)
        f1 = f1_score(y_true=labels, y_pred=preds, average='macro') #add
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

データセットはtrain.tsv(学習データ)、dev.tsv(検証データ)を用意する必要があります。今回は簡単ですが下記のデータセットとしました。([テキスト\tラベル]の形式です。ヘッダー行は不要です)

  • data/train.tsv
断固として非難する   1
非難する    1
極めて遺憾   1
深く憂慮する  1
憂慮する    1
強く懸念する  1
懸念する    1
評価する    0
期待している  0
  • data/dev.tsv
遺憾の意を表する    0
一定評価できる 1

ここまでの準備が終わったら下記のコマンドを実行してFine-Tuning実行です。--output_dirに指定したフォルダにモデルが出力されているはずです。

!python transformers/examples/run_glue.py \
    --task_name customize \
    --do_train \
    --do_eval \
    --data_dir ./data \
    --output_dir ./out \
    --model_type distilbert \
    --model_name_or_path ./DistilBERT-base-jp \
    --max_seq_length 128 \
    --learning_rate 2e-5 \
    --num_train_epochs 100 \
    --per_gpu_train_batch_size 32 \
    --per_gpu_eval_batch_size 32 \

下記はFine-Tuningの実行結果です。accの値が1.0となっており、dev.tsv(検証データ)のデータを正しく分類できたようですね。

image.png

まとめ

まずは動かすことで直感的な理解を得たいと思った次第でした。まだまだ使いこなせていませんしDistilated-BERTの論文も読んでいないので、そちらも読み込んで理解を深めたいです。 今回は日本語Wikipediaで学習済みのモデルを公開していただいていたので簡単に試すことができました!