広告

Squeeze / unsqueezeの使い方:要素数1の次元を消したり作ったりする


タグ:python machine_learning

PyTorchのコードを見ているとよく出現する squeeze と unsqueeze だが、 意味がよくわかりにくいので解説する。
Tensor.squeeze()は「テンソル中の要素数1の次元を削除する」、 Tensor.unsqueeze(d: int)は「テンソルのd階目に要素数1の次元を挿入する」 というものだ。
要素数1の次元の扱いはテンソル操作上面倒になるポイントなので、 使い方を覚えると実はこれらはうれしい関数なのだ。

Unsqueezeの使用例:
In [6]: t =  Tensor([[1, 2, 3], [4, 5, 6]])

In [7]: t.shape
Out[7]: torch.Size([2, 3])

In [8]: t.unsqueeze(0)
Out[8]:
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])

In [9]: t.unsqueeze(0).shape
Out[9]: torch.Size([1, 2, 3])  #0に要素数1の次元が挿入された

In [10]: t.unsqueeze(1).shape
Out[10]: torch.Size([2, 1, 3])  #1に要素数1の次元が挿入された

In [11]: t.unsqueeze(2).shape
Out[11]: torch.Size([2, 3, 1])  #2に要素数1の次元が挿入された

In [12]: t.shape
Out[12]: torch.Size([2, 3])  # squeeze/unsqueezeは非破壊的:元テンソルt自体は変更なし


Squeezeの使用例:
In [2]: t =  Tensor(1, 2, 3)
In [4]: t.squeeze().shape
Out[4]: torch.Size([2, 3])  # 要素数1の次元が消える

In [5]: t =  Tensor(2, 1, 3)
In [6]: t.squeeze().shape
Out[6]: torch.Size([2, 3])  # 要素数1の次元が消える

In [7]: t =  Tensor(2, 1, 1, 3)
In [8]: t.squeeze().shape
Out[8]: torch.Size([2, 3])  # 要素数1の次元が複数あればすべて消える


Squeeze / unsqueezeを使っておいしいポイントの一つは 「入力データにバッチ用の次元を付与する」、「出力ベクトルからバッチ用の次元を削除する」 というものだ。ニューラルネットの実装ではデータが複数あるときにまとめて処理できるように 入出力テンソルの次元はN×Dになっていることが多い。 ここでNはデータ数、Dは各データの次元である。
このときもし処理したいデータがD次元のベクトル1つであったら、データが 一つしかないことを示すためにわざわざ1×Dのテンソルに形状を整えて渡さねばならない。 これをunsqueezeによってうまいことできる、というわけだ。
逆に出力側ではsqueezeでバッチの次元を削除すればよい。
  In [1]: import torch
   ...: from torchvision import models, transforms
   ...: from PIL import Image
   In [2]: vgg16 = models.vgg16(pretrained=True)

   In [3]: img = torch.Tensor(3, 224, 224)  # 画像 (C, H, W)

   In [4]: out = vgg16(img)  # バッチ次元がない、これはエラー
   RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 3, 3], but got 3-dimensional input of size [3, 224, 224] instead

   In [9]: out = vgg16(img.unsqueeze(0))  # こっちはOK
   In [11]: out.shape
   Out[11]: torch.Size([1, 1000])  #出力もバッチ化されていて邪魔……

   In [12]: out.squeeze().shape
  Out[12]: torch.Size([1000])


豆知識:英単語としての"squeeze"は「圧搾する」という意味だ。"Unsqueeze"は辞書に 載っておらず造語と思われるが、「squeezeの逆」という程度の意味だろう。
ちなみにSqueeze-and-Excitation Network (SENet) という深層学習手法も 存在するが、これとはあまり関係がないと思われる。

おすすめ記事

PIL, NumPy, PyTorchのデータ相互変換早見表

NumPyのarray.sizeに相当するのはPytorchのTensor.numel()

PyTorch Tensorを確実にNumpy Arrayに変換する