TensorFlowのmatrix_band_partで三角行列を作る

TensorFlowのmatrix_band_partで三角行列を作る

こんにちは、大阪DI部の大澤です。

今回はTensorFlowのmatrix_band_partという関数について紹介します。

ソースコードを読んでいてふと現れたこの関数が何をするものなのかが全然わからず、ドキュメントを読んでも理解するのに少し時間がかかったので、メモがわりに残しておきます。

概要

入力されたテンソルの下位2階のテンソル(行列)に対して、対角要素を基準に上下の要素を0にするかそのまま残すかという処理を行います。 引数は次の通りです。

  • input: 2階以上のテンソル
    • 操作対象
  • num_lower: 0階のテンソル(int32かint64)→要は整数
    • 対角要素から下側に何個離れた要素まで残すかを決める
    • -1の場合は対角要素から下側の要素は全て残す、0の場合は全て0にする
  • num_upper: 0階のテンソル(int32かint64)→要は整数
    • 対角要素から上側に何個離れた要素まで残すかを決める
    • -1の場合は対角要素から上側の要素は全て残し、0の場合は全て0にする
  • name: (optional)文字列
    • 操作名
tf.matrix_band_part(
    input,
    num_lower,
    num_upper,
    name=None
)

特にnum_lower=0、num_upper=-1の時は上三角行列、num_lower=-1、num_upper=0の時は下三角行列にする処理となります。

また、numpyで似た操作が可能な関数としてはnumpy.trilやnumpy.triuがありますが、引数の意味が少々異なるので注意が必要です。 例えば、numpy.tril(x, m)であれば、xの対角要素を基準に上側にm個の要素を残し、それより上の要素は0にするという処理になります。(mは負値も可能) numpy.triuはその逆の操作になります。詳細については以下のドキュメントをご覧ください。

試してみる

まずは処理対象として扱う3階のテンソル(行列)を作成します。

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()

x = np.reshape(range(75), (3,5,5))
x

まずは何も変化を及ぼさない処理から試します。

tf.matrix_band_part(x, -1, -1).eval()

次は対角要素のみ残します。

tf.matrix_band_part(x, 0, 0).eval()

次は三角行列です。 まずは下三角行列。

tf.matrix_band_part(x, -1, 0).eval()

そして上三角行列。

tf.matrix_band_part(x, 0, -1).eval()

次は下側の要素は0で、上側は対角要素から2つ目まで残して、それより上は0にします。

tf.matrix_band_part(x, 0, 2).eval()

次は上側の要素は全てそのままで、下側は対角要素から2つ目まで残して、それより下は0にします。

tf.matrix_band_part(x, 2, -1).eval()

TensorFlowのセッションを閉じておきます。

sess.close()

さいごに

TensorFlowのmatrix_band_partという関数について紹介しました。分かってしまえば簡単なシンプルな操作でした。 matrix_band_partは頻繁に使うものでは無いかもしれませんが、三角行列はマスクなどする時に便利なので、覚えておきたいところです。