Deep Learningを簡単・ローコードで使えるライブラリktrainを使ってみた

Deep Learningを簡単・ローコードで使えるライブラリktrainを使ってみた

Clock Icon2020.06.10

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、Mr.Moです。

機械学習関連の情報は今の時代オープンにされていることが多く、勉強するにあたり情報不足で困ることがあまり無いと感じています。一方で気になる論文や参考文献を見て理論を理解しそこからさらに実装(適用)に移行して動かしていきますが、そこまで集中力が持たなかったという経験は無いでしょうか?

そんな時に非常に短いコードでディープラーニングを実行できるktrainというライブラリがありましたので使ってみたいと思います。

ktrainとは

image.png

ktrainはTensorFlow Keras (他にもtransformers, scikit-learn, stellargraphなど) の軽量ラッパーで、ニューラルネットワークやその他の機械学習モデルの構築、トレーニング、デプロイを支援します。初心者から経験豊富な方まで利用できシンプルで統一されたインターフェースにより様々なタスクを素早く解決します。

機械学習のワークフローは正直、複雑で特に初心者の方とかはハードルが高いので、難易度の高い機械学習モデルの作成からデプロイまでのステップを数行のコードで実施して試すことができることを目的にしているようですね。

また、データ毎で実施できるタスクは現在下記のようです。

  • text data:
    • Text Classification
    • Text Regression
    • Sequence Labeling (NER)
    • Ready-to-Use NER models for English, Chinese, and Russian
    • Sentence Pair Classification
    • Unsupervised Topic Modeling
    • Document Similarity with One-Class Learning
    • Document Recommendation Engine
    • Text Summarization
    • Open-Domain Question-Answering
    • Zero-Shot Learning
  • vision data:
    • image classification
    • image regression
  • graph data:
    • node classification
    • link prediction

使ってみる

今回はBertでテキスト分類を実施します。データセットはみんな大好き「livedoor ニュースコーパス」です。
(前提としてSageMakerのml.p2.xlargeインスタンス、tensorflow_p36のNotebookを使っています)

  • Fine-Tuningを実施

たった数行で実行できています。

import ktrain
from ktrain import text
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
t = text.Transformer(MODEL_NAME, maxlen=128, classes=topics)
trn = t.preprocess_train(X_train, y_train)
val = t.preprocess_test(X_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=32) # lower bs if OOM occurs
#learner.fit_onecycle(5e-5, 3)
learner.autofit(5e-5)

image.png

  • 機械学習のモデル評価

一行で各メトリクスの確認ができました。

learner.validate(class_names=t.get_classes())

image.png

  • モデルの保存・ロード

こちらもローコードで実現できています。

## ktrainのバージョンが0.16.1以上でないと動かないので注意してください。
predictor.save('model/livedoor')
p = ktrain.load_predictor('model/livedoor')

上記のコード全文はこちらに置いております。

まとめ

まずは動かした方が機械学習に対する理解が進むこともあると思いますので、そういった時にもktrainは便利そうですね。 こちらのライブラリは更新頻度が高いように思いますので今後のUPDATEも楽しみです。実際にktrainを使ってみてフィードバックを送るのも良いかもしれません。

参考

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.