tensorflowのtf.dataとmodel.predict()の相性

追記 (2022/01/07)

tensorflow/kerasの推論で tf.data.datasetを活用する方法がわかりました。

model.predict(...)の入力データに「datasetのイテレータ」ではなく、「dataset自体」を渡すことです。

また、dataset.repeat()を使用している場合はイテレータのループ回数が不定になるため、stepsオプションを追加します。

以下に、前回記事のコードを修正したものを貼っておきます。

import os
import tensorflow as tf
from model import MyModel

batch_size=32

def auto_select_accelerator():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    return strategy

strategy = auto_select_accelerator()
REPLICAS = strategy.num_replicas_in_sync
AUTO = tf.data.experimental.AUTOTUNE

filenames = [ file for file in os.listdir( "/tmp/" )   if  ".jpg" in file ]
files = [ "/kaggle/input/dataset-name/" + filename for filename in filenames ]
dataset = tf.data.dataset.from_tensor_slices( files )
dataset = dataset.map( read_file, AUTO ).batch(batch_size).prefetch(1)
steps = len(files)//(batch_size*REPLICAS)
if len(files)%(batch_size*REPLICAS) != 0:
    steps += 1

with strategy.scope():
    model = MyModel( ... )

# datasetをそのまま入力する.
preds = model.predict( dataset, batch_size=batch_size*REPLICAS, steps=steps)

参考情報: TPUを使用している場合、かつ、数十GBのそこそこ大きめのデータセットに対して推論を適用する場合、model.predict()の処理中にTPUのsocket closedエラーが発生することがありました。→datasetを分割して投入することで解消できました。

元の記事

kaggleのコンペでTPUインスタンスを使って学習を実行することに慣れてきました。
tf.data.datasetで使うと効率的に学習できるので、推論でも使ってみたい!と思って試したところ、OOMエラー(メモリ使用上限に達したこと)が発生してセッションが落ちてしまいました。。
使用したtensorflowは2.6.0です。 kaggleのNotebookやGoogle colaboratoryで発生することを確認しています。 (GPUインスタンスでもTPUインスタンスでも発生するようです)

以下のコードを実行すると、もりもり使用メモリが増加していきます。

import os
import tensorflow as tf
from model import MyModel

def auto_select_accelerator():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    return strategy

strategy = auto_select_accelerator()
REPLICAS = strategy.num_replicas_in_sync
AUTO = tf.data.experimental.AUTOTUNE

filenames = [ file for file in os.listdir( "/tmp/" )   if  ".jpg" in file ]
files = [ "/kaggle/input/dataset-name/" + filename for filename in filenames ]
dataset = tf.data.dataset.from_tensor_slices( files )
dataset = dataset.map( read_file, AUTO ).batch(32).prefetch(1)

with strategy.scope():
    model = MyModel( ... )

for batch in dataset:
    preds = model.predict( batch )

del modeltf.keras.backend.clear_session()gc.collect()を試しても開放されず。。
なにか知見が得られましたら、ここに情報を追記しようと思います。