python
machine_learning
Tensor.detach() は計算グラフからテンソルを切り離す関数である。 しかし「計算グラフ?なにそれ」となりがちであるように、PyTorch内部の動作をわかっていないと 何をやっているのかよくわからないものでもある。 なんとなく「PyTorch TensorをNumPy Arrayに変換するときに呼ぶやつ」と思っている人も多いのではないだろうか。
PyTorchにおける計算グラフとは「テンソルの値がどんな計算順序で算出されたものか」 を記録しておくためのグラフ構造である。 計算グラフを作る処理はプログラマは普段意識しないと思うが、実はPyTorchは基本的な 演算子をオーバーライドして独自実装を入れるなどせておりこれが知らないうちに 記録されているのである。
PyTorchでニューラルネットを学習するときによく使われるバックプロパゲーション などはこの計算グラフにアクセスして各テンソル(=ニューラルネットのパラメータ) などをどうやって微調整するか決めている。
とはいえこの計算グラフ、用途によっては邪魔になるので切り捨てるためのメソッドdetach が用意されている。具体的には
- 計算グラフをサポートしない他の形式(NumPy arrayなど)に変換するとき
- パックプロパゲーションで使わない値(例えば、テスト時に参考として記録する分類精度の値など)をテンソルに貯める
- パックプロパゲーションによらないカスタマイズされた学習手法を試したいとき
なお"新規テンソルを生成する"といってもデータ自体のコピーが発生するわけではない。 同じデータを参照する、別のメタデータとしてのテンソルを作るというだけである。

図:detachで計算グラフから切り離されたテンソル (z1) を作ったところ
利用例:勾配を流したくない結合を明示的に指定
機械学習手法の中で「ネットのこの箇所は勾配を流してほしくないな……」 となるケースがある。例えば、ネットワークの一部を全体の学習から切り離して別の方法で パラメータを手動設定する、というようなアルゴリズムが世の中に 存在する。
例 SimSiamの"stop-grad": https://github.com/facebookresearch/simsiam/blob/main/simsiam/builder.py#L61
おすすめ記事
PIL, NumPy, PyTorchのデータ相互変換早見表NumPyのarray.sizeに相当するのはPytorchのTensor.numel()
PyTorch Tensorを確実にNumpy Arrayに変換する