Tensorflow のgraph modeをeager modeに変える

Python R

tensorflowには、graph modeとeager execution mode があります。graph mode では演算が高速に行われる一方、演算の途中の結果を見るには面倒な手続きを踏まなくてはいけません。
eager execution mode1では、演算の途中でも、numpy のデータを扱うように、自由に値を見る事が出来ます。

Graph mode の挙動

tf2.0以降では、基本的にはEager execution modeになっています。確認するには、

import tensorflow as tf
tf.executing_eagerly()
#True

とすれば良いです。True ならeager execution modeになっています。
graph Mode を使うには、@tf.function を関数定義の前につけます。

@tf.function
def hoge():
    print(tf.executing_eagerly())
hoge()
#False

keras で機械学習モデルを作ったりするときは、keras.hoge とやりますが、その中ではgraph mode になっています。graph mode の特徴は、tensor の形と型だけを保持する事です。

@tf.function
def hoge(x,y):
    z=x+y
    print(z)
x=tf.constant(1)
y=tf.constant(2)
hoge(x,y)
Tensor("add:0", shape=(), dtype=int32)

中の数値を具体的に見るには、tf.print()を使う必要があります。

@tf.function
def hoge(x,y):
    z=x+y
    tf.print(z)
x=tf.constant(1)
y=tf.constant(2)
hoge(x,y)
#3

graph mode の時は、tensor の値をnumpy 型に変換することが出来ません。

@tf.function
def hoge(x):
    x=x.numpy()
    print(x)
x=tf.constant(1)
hoge(x)
AttributeError: 'Tensor' object has no attribute 'numpy'

graph mode 出ない時は、numpy 型にして値を確認したりできます。

def hoge(x):
    x=x.numpy()
    print(x)
x=tf.constant(1)
hoge(x)
#1

Graph Mode からEager Execution Mode に移行する

keras の中で処理がどう行われているか確認したい時、graph mode のままでは困る事があります。そのような時の為に、eager execution mode にする関数が用意されています。
graph mode の中で、tf.config.experimental_run_functions_eagerly(True) とすれば良いです。

@tf.function
def hoge():
    print(tf.executing_eagerly())
hoge()
tf.config.experimental_run_functions_eagerly(True)
hoge()
#False
#True

もう一度graph mode に戻したい時はtf.config.experimental_run_functions_eagerly(False)とします。

@tf.function
def hoge():
    print(tf.executing_eagerly())
hoge()
tf.config.experimental_run_functions_eagerly(False)
hoge()
#True
#False

例えば、keras で機械学習モデルを作った後に、自作の損失関数を使いたい時とかにこの関数を使うかもしれません。

def custom_loss(y_val, y_pred):
    tf.config.experimental_run_functions_eagerly(True)
    loss= sugoisyori
    return loss

こう書いておくとcustom_loss()の中ではeager execution mode になり、予測値などをnumpy ぽく扱えます。eager execution mode にしておかないと、中身の一部を弄るような操作は難しくなります。

まとめ

  • graph mode の説明をした
  • graph modeとeager execution mode を行き来する方法を説明した
  1. tf2.0以降はこっちがデフォルト
タイトルとURLをコピーしました