広告
2022/05/25

TensorFlowのテンソルをPyTorchに変換したりPyTorchから TensorFlowにしたりする

タグ:python machine_learning

事情によりTensorFlowとPyTorchが混ざったシステムができてしまった。
このようなときはTensorFlowの処理結果をPyTorchに渡したり、 その逆をしたくなる。
基本的にはNumPy arrayを介して変換すると思えばよく、 PyTorch -> Tensorflowは暗黙的な変換がうまくいくのでこれを省略できる。
Tensorflow -> PyTorch は明示的にTensor.numpy()で変換してから 渡せばよい。
import tensorflow as tf
import torch

a = tf.constant((1, 32, 32, 3))
b = torch.Tensor(a.numpy())
print(b)

c = tf.constant(b)
print(c)
  
tensor([ 1., 32., 32.,  3.])
tf.Tensor([ 1. 32. 32.  3.], shape=(4,), dtype=float32)
  

注意点

一般的に使われる次元順がTensorFlowとPyTorchで異なるため、 単に型を変換しただけでなく中身の変換も必要になるケースがある。
例えば画像はTensorFlowでは (画像数, 縦, 横, チャンネル) のような4次元テンソル にすることが普通だが、PyTorchでは (画像数, チャンネル, 縦, 横) が普通である。
適宜permuteなどを使うこと。

動作確認したバージョン

  • Python 3.7.2
  • torch==1.11.0+cu113
  • tensorflow==2.5.0

おすすめ記事

PIL, NumPy, PyTorchのデータ相互変換早見表

NumPyのarray.sizeに相当するのはPytorchのTensor.numel()

PyTorch Tensorを確実にNumpy Arrayに変換する