広告

PyTorchとTensorFlowのレイヤ・モジュールの対応関係

タグ:python machine_learning

あるときPyTorchで書かれたネットワークをTensorflowに移植することがあった。 レイヤの名前や引数などが微妙に違うのでややこしかったので、メモをかねて 主要なレイヤの名前とドキュメントへのリンクをまとめた。
またレイヤのクラス名のほかに引数も違う。たいていの場合PyTorchでは必須な入力次元が TensorFlowでは省力されている(ネットワークを作ったときに推論する)ようになっているのだが、 他にも細かい違いがあることがあるのでドキュメントを参照されたい。
なおTensorFlowのレイヤはkerasのものとした。

全結合層
PyTorch: Linear
TensorFlow: Dense


畳み込み(二次元)
PyTorch: Conv2d
TensorFlow: Conv2D


最大値プーリング(二次元)
PyTorch: ReLU
TensorFlow: MaxPooling2D


ReLU
PyTorch: MaxPool2d
TensorFlow: ReLU


バッチ正規化
PyTorch: BatchNorm2d (次元数ごとに別に実装されているので注意)
TensorFlow: BatchNormalization


ドロップアウト
PyTorch: Dropout
TensorFlow: Dropout



おすすめ記事

PIL, NumPy, PyTorchのデータ相互変換早見表

NumPyのarray.sizeに相当するのはPytorchのTensor.numel()

PyTorch Tensorを確実にNumpy Arrayに変換する