python
machine_learning
あるときPyTorchで書かれたネットワークをTensorflowに移植することがあった。 レイヤの名前や引数などが微妙に違うのでややこしかったので、メモをかねて 主要なレイヤの名前とドキュメントへのリンクをまとめた。
またレイヤのクラス名のほかに引数も違う。たいていの場合PyTorchでは必須な入力次元が TensorFlowでは省力されている(ネットワークを作ったときに推論する)ようになっているのだが、 他にも細かい違いがあることがあるのでドキュメントを参照されたい。
なおTensorFlowのレイヤはkerasのものとした。
全結合層
PyTorch: Linear
TensorFlow: Dense
畳み込み(二次元)
PyTorch: Conv2d
TensorFlow: Conv2D
最大値プーリング(二次元)
PyTorch: ReLU
TensorFlow: MaxPooling2D
ReLU
PyTorch: MaxPool2d
TensorFlow: ReLU
バッチ正規化
PyTorch: BatchNorm2d (次元数ごとに別に実装されているので注意)
TensorFlow: BatchNormalization
ドロップアウト
PyTorch: Dropout
TensorFlow: Dropout
おすすめ記事
PIL, NumPy, PyTorchのデータ相互変換早見表NumPyのarray.sizeに相当するのはPytorchのTensor.numel()
PyTorch Tensorを確実にNumpy Arrayに変換する