python
machine_learning
PyTorchのコードを見ているとよく出現する squeeze と unsqueeze だが、 意味がよくわかりにくいので解説する。
Tensor.squeeze()は「テンソル中の要素数1の次元を削除する」、 Tensor.unsqueeze(d: int)は「テンソルのd階目に要素数1の次元を挿入する」 というものだ。
要素数1の次元の扱いはテンソル操作上面倒になるポイントなので、 使い方を覚えると実はこれらはうれしい関数なのだ。
Unsqueezeの使用例:
In [6]: t = Tensor([[1, 2, 3], [4, 5, 6]]) In [7]: t.shape Out[7]: torch.Size([2, 3]) In [8]: t.unsqueeze(0) Out[8]: tensor([[[1., 2., 3.], [4., 5., 6.]]]) In [9]: t.unsqueeze(0).shape Out[9]: torch.Size([1, 2, 3]) #0に要素数1の次元が挿入された In [10]: t.unsqueeze(1).shape Out[10]: torch.Size([2, 1, 3]) #1に要素数1の次元が挿入された In [11]: t.unsqueeze(2).shape Out[11]: torch.Size([2, 3, 1]) #2に要素数1の次元が挿入された In [12]: t.shape Out[12]: torch.Size([2, 3]) # squeeze/unsqueezeは非破壊的:元テンソルt自体は変更なし
Squeezeの使用例:
In [2]: t = Tensor(1, 2, 3) In [4]: t.squeeze().shape Out[4]: torch.Size([2, 3]) # 要素数1の次元が消える In [5]: t = Tensor(2, 1, 3) In [6]: t.squeeze().shape Out[6]: torch.Size([2, 3]) # 要素数1の次元が消える In [7]: t = Tensor(2, 1, 1, 3) In [8]: t.squeeze().shape Out[8]: torch.Size([2, 3]) # 要素数1の次元が複数あればすべて消える
Squeeze / unsqueezeを使っておいしいポイントの一つは 「入力データにバッチ用の次元を付与する」、「出力ベクトルからバッチ用の次元を削除する」 というものだ。ニューラルネットの実装ではデータが複数あるときにまとめて処理できるように 入出力テンソルの次元はN×Dになっていることが多い。 ここでNはデータ数、Dは各データの次元である。
このときもし処理したいデータがD次元のベクトル1つであったら、データが 一つしかないことを示すためにわざわざ1×Dのテンソルに形状を整えて渡さねばならない。 これをunsqueezeによってうまいことできる、というわけだ。
逆に出力側ではsqueezeでバッチの次元を削除すればよい。
In [1]: import torch ...: from torchvision import models, transforms ...: from PIL import Image In [2]: vgg16 = models.vgg16(pretrained=True) In [3]: img = torch.Tensor(3, 224, 224) # 画像 (C, H, W) In [4]: out = vgg16(img) # バッチ次元がない、これはエラー RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 3, 3], but got 3-dimensional input of size [3, 224, 224] instead In [9]: out = vgg16(img.unsqueeze(0)) # こっちはOK In [11]: out.shape Out[11]: torch.Size([1, 1000]) #出力もバッチ化されていて邪魔…… In [12]: out.squeeze().shape Out[12]: torch.Size([1000])
豆知識:英単語としての"squeeze"は「圧搾する」という意味だ。"Unsqueeze"は辞書に 載っておらず造語と思われるが、「squeezeの逆」という程度の意味だろう。
ちなみにSqueeze-and-Excitation Network (SENet) という深層学習手法も 存在するが、これとはあまり関係がないと思われる。
おすすめ記事
PIL, NumPy, PyTorchのデータ相互変換早見表NumPyのarray.sizeに相当するのはPytorchのTensor.numel()
PyTorch Tensorを確実にNumpy Arrayに変換する