JP / EN

広告

PyTorchのテンソルを画像にして保存する

タグ:python machine_learning

デバッグなどのためにニューラルネットの出力や中間層を 目視で確認できるようにしたいことがある。 そのようなときは取り出したテンソルを画像にして保存した後 ビューワで開けばよい。

ここでは画像の書き出しにopencv-pythonを使う。 PyTorchテンソルを大きさ (H, W, 3) のNumPy arrayに変換、cv2.imwriteで書き出す、という作戦だ。 もちろんPILでも同様のことができるはず。
    In [1]: import torch
    In [2]: import cv2
    In [3]: a = torch.rand(32, 32, 3)
    In [4]: cv2.imwrite((a * 255).astype(torch.uint8), "tmp.png")  # 正規化と型キャストを入れる
    Out[4]: True
  
注意点としては、実際のニューラルネット中間層は実数値で任意の値を取る。 これをunsigned 8bit (0~255整数値) 相当のjpgやpngの画像にすると 真っ黒や真っ白になったり、丸めのためにパターンが目視できないような 画像になってしまう。
これを適当な正規化を入れて補正しないといけない。
最小値を0, 最大値を255にするような線形変換 でなんとかなる。
なおサンプルではrandで最小値0, 最大値1の乱数を生成しているので、単に255をかければよい。
これがうまくいくと次のようなランダムノイズ画像ができる。



動作確認したバージョン

  • Python 3.7.2
  • torch==1.11.0+cu113
  • opencv-python==3.4.4.19
  • numpy==1.19.5


おすすめ記事

PIL, NumPy, PyTorchのデータ相互変換早見表

NumPyのarray.sizeに相当するのはPytorchのTensor.numel()

Squeeze / unsqueezeの使い方:要素数1の次元を消したり作ったりする



このエントリーをはてなブックマークに追加

https://wonderhorn.net/