CV・NLPハマりどころメモ

画像認識と自然言語処理を研究する中でうまくいかなかったことと、その対策をまとめる自分用メモが中心。

Pytorch v0.4のコードをv0.3で動かす際には.dataに注意!![Pytorch]

Pytorchのコードを見ているとミニバッチごとのlossやaccuracyを計算する際、.dataを用いて値を取り出されることが頻繁にある。

よくある例:

for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
      inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
      dis_opt.zero_grad()
      out = discriminator.batchClassify(inp)
      loss_fn = nn.BCELoss()
      loss = loss_fn(out, target)
      loss.backward()
      dis_opt.step()

      total_loss += loss.data # .dataを使ってlossから値を取り出しtotal_lossに蓄積
      total_acc += torch.sum((out>0.5)==(target>0.5)).data # .dataを使ってaccuracyを取り出しtotal_accに蓄積

.dataを用いて値を取り出す上のコードは、pytorch v0.4で動かすならエラーがでない。

しかし、上のコードをpytorch v0.3で動かすと下のエラーが発生する。

# pytorch v0.3で動かした際のエラー

TypeError: div_ received an invalid combination of arguments - got (float), but expected one of:
 * (int value)
      didn't match because some of the arguments have invalid types: (!float!)
 * (torch.cuda.ByteTensor other)
      didn't match because some of the arguments have invalid types: (!float!)

なぜv0.4で動いていたコードがv0.3で動かなくなってしまうのか?

その原因は、pytorch v0.3でVariableの型の行列から値を取り出す際に.dataを使用するとデータの型がVariableからTensorに変わってしまうからである。

簡単な例で問題をみていこう。

# Pytorch v0.4で動かすと問題ないのだが、v0.3で動かすとエラーが発生するコード
# 本コードではv0.3で動かしたことを想定

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換

y = Variable(torch.Tensor([4,5,6])).data # .dataを使ってVariableから値を取り出す。

# 返り値を表示。xにはVariableが格納されていることがわかる。
x
>>
Variable containing:
 1
 2
 3
[torch.FloatTensor of size 3]

# 返り値を表示。yはVariableではなくTensorが格納されてしまった。
y
>>

 4
 5
 6
[torch.FloatTensor of size 3]

# TensorとVariableを加算するとエラーが発生
x+y
>>
Traceback (most recent call last):

  File "<ipython-input-188-259706549f3d>", line 1, in <module>
    x+y

RuntimeError: add() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * (float other, float alpha)
 * (Variable other, float alpha)

上の例のようにpytorch v0.3では、.dataで値を取り出すと型違いによるエラーが発生してしまう為、v0.4で動いていたコードが動かなくなる。

v0.4ではVariableとTensorが統合された為、型違いによるエラーが発生しないようだ。

v0.4公式がまとめた変更点(英語)

pytorch.org

v0.4の変更点が日本語でまとめられた記事

qiita.com

ちなみに、v0.4の公式ドキュメントでは.dataを使うことはunsafeだと言っており、代わりに.detachを使うことが推奨されている。

pytorch.org

What about .data ? のセクションに .dataを使うことの危険性が述べられている。

.dataで値を取り出した場合、xに対する変更がautogradで追跡できない為、危険視されているようだ。

最後に.detachを使用して書いた、v0.3とv0.4ともにエラーが発生しないコードを載せる。

# v0.3とv0.4ともにエラーが発生しないコード
import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換

z = Variable(torch.Tensor([7,8,9])).detach()

z
>>
Variable containing:
 7
 8
 9
[torch.FloatTensor of size 3]

# v0.3とv0.4ともにエラー無しで加算が行える
x+z
>>
Variable containing:
  8
 10
 12
[torch.FloatTensor of size 3]

追記

あまり良い方法ではないのだが、lossやaccuracyなど1x1の値を取り出す場合には、.data[0]を使う手もある。

.data[0]は、下のIrfan_Buluさんの回答を見て知った。

discuss.pytorch.org