python
machine_learning
しょっちゅう忘れるので自分用にPyTorchと Pythonライブラリのデータ受け渡しのための データ変換をまとめておきます。
ポイントは
-
Pytorch系ははtorchvision.transforms.functionalに入っている
必要に応じてデータ型の変換を行う・ふつうは画像はuint8型
PyTorch -> NumPyではdetach()で計算グラフの情報を取り除く。
取り除いておかないとエラーになる。
From / To | PIL | NumPy | PyTorch |
PIL | -- | np.array(img) | torchvision.transforms.functional.to_tensor(img) |
NumPy | Image.fromarray(np.uint8(arr)) | -- | torch.from_numpy(arr) |
PyTorch | torchvision.transforms.functional.to_pil_image(tensor) | tensor.to('cpu').detach().numpy() | -- |
PIL → PyTorch
from PIL import Image import torchvision img = Image.open("img.jpg") torchvision.transforms.functional.to_tensor(img)
PIL → NumPy
import numpy as np from PIL import Image img = Image.open("img.jpg") np.array(img)
NumPy → PIL
import numpy as np from PIL import Image arr = np.zeros((64, 128, 3)) # H * W * C Image.fromarray(np.uint8(arr)) # データ型が違うとエラーになる
NumPy → PyTorch
import numpy as np import torch arr = np.zeros((64, 128, 3)) torch.from_numpy(arr)
PyTorch → PIL
import torch import torchvision tensor = torch.zeros((64, 128, 3)) # H * W * C torchvision.transforms.functional.to_pil_image(tensor)
PyTorch → NumPy
import torch tensor = torch.zeros((64, 128, 3)) tensor.to('cpu').detach().numpy()
おすすめ記事
PyenvでPythonのバージョンが切り替わらないと思ったらインストール先が変わっただけだったSqueeze / unsqueezeの使い方:要素数1の次元を消したり作ったりする
PyTorch Tensorを確実にNumpy Arrayに変換する