広告
2022/09/15

TensorFlowのテンソルとPIL, NumPy, PyTorchの相互変換

タグ:python machine_learning

テンソルは機械学習や数値計算ライブラリで使われる多次元配列の ラッパークラスであるが、ライブラリごとに色々ありややこしい。 個人的にはPyTorchをよく使うのだが、たまにTensorFlowが 必要になると忘れてしまっている。
ということで思い出し用にTensorFlowテンソルから様々なほかの型への変換をまとめた。 なおバージョンはTensorflow v2系を前提としている。
PyTorch版はこちら
またKerasも(標準バックエンドがTensorFlowなので) TensorFlowのテンソルを返すのでこれと同じようにできる。

TensorFlowテンソルからNumPy array

そのものずばりnumpy()メソッドがある。
In [1]: import tensorflow as tf
In [2]: a = tf.zeros((2, 3, 4))
In [3]: b = a.numpy()
In [4]: type(b)
Out[4]: numpy.ndarray
  


TensorFlowテンソルからPIL Image

Kerasからarray_to_img関数を使う。インポートパスが長い・バージョンにとって代わるので注意
In [1]: import tensorflow as tf
In [2]: a = tf.zeros((5, 4, 3))  # 末尾次元=チャンネル数は3か1のみサポート
In [3]: b = a.numpy()
In [4]: b = tf.keras.preprocessing.image.array_to_img(a)
In [4]: b = tf.keras.utils.array_to_img(a)  # バージョンによってはこっちらしい
In [5]: type(b)
Out[5]: PIL.Image.Image
  


TensorFlowテンソルからPyTorch Tensor

直接変換はできないようなのでNumpy arrayを経由する。
In [1]: import tensorflow as tf
In [2]: import torch
In [3]: a = tf.zeros((2, 3, 4))
In [4]: b = torch.Tensor(a.numpy())
In [23]: type(b)
Out[23]: torch.Tensor
  


NumPy ArrayからTensorFlowテンソル

tf.constantにarrayを渡す。
In [1]: import tensorflow as tf
In [2]: from PIL import Image
In [3]: a = np.array([[1, 2, 3], [4, 5, 6]])
In [4]: b = tf.constant(a)
In [5]: type(b)
Out[5]: tensorflow.python.framework.ops.EagerTensor
  

PIL ImageからTensorFlowテンソル

どうも直接変化する関数はなさそうである。NumPyを経由すればよい。
In [1]: import tensorflow as tf
In [2]: from PIL import Image
In [3]: a = Image.new("RGB", (16, 16))
In [4]: b = tf.keras.preprocessing.image.img_to_array(a)
In [5]: type(b)
Out[5]: numpy.ndarray
In [6]: c = tf.convert_to_tensor(b)
In [7]: type(c)
Out[7]: tensorflow.python.framework.ops.EagerTensor
  

PyTorchテンソルからTensorFlowテンソル

In [1]: import tensorflow as tf
In [2]: import torch
In [3]: a = torch.Tensor([[1, 2], [3, 4]])  # 注意:PyTorchテンソルはcpu側にある必要がある。
In [4]: b = tf.constant(a)
In [23]: type(b)
Out[23]: tensorflow.python.framework.ops.EagerTensor
  


動作確認したバージョン

  • torch==1.11.0+cu113
  • tensorflow==2.5.0
  • numpy==1.19.5