- PyTorchにはいろいろなモデルの書き方があるけど,このメモの内容を考えると学習パラメタが無いレイヤ含めてinitにもたせておくほうが良いのかもしれない。
(と思ったけど,一番最後のレイヤをとりあえずコピーしてforwardを再定義するやり方ならどっちでも良い,と思った)
重みにアクセスしたい場合
- 重みを特定の値で初期化した場合などが利用としては考えられるか。
- model.state_dict()が簡単。
- 辞書で返してくれるのでmodelのinitに付けた名前でアクセスが可能。
model = Net()
model.state_dict()
中間層の出力にアクセスしたい場合
- forwardのhookしたり,色々やり方はあるようだけど,モデルの構築の段階からSequencialを使って分割して作成しておくのがわかりやすい。
- 隠層の出力をそれぞれ得たい,という場合には後述のようにforwardを再定義するような方法があるようだ。
- サンプルとしては,torchvisionのモデルが参考になる。例えば,AlexNetは特徴量抽出までのfeaturesと,それを使って分類するclassifeirに分けて実装されている。例えば,AlexNetは下記のようになっている。
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x
- これを用いて,学習済みのAlexNetの特徴量抽出までを使用したければ,下記のように使える。
model = models.alexnet(pretrained=True)
y = model.features(x)
- VGGの各層からの出力を得るには下記のような例がある。ここでは,各層ではなく,特定のReLU層,3,8,15,22を出力している。やっていることは,VGG16のfeaturesのレイヤをそれぞれ通して行きながら,その過程で所望のレイヤの出力を保存している。
- このやり方は結構汎用的だ。学習済みのモデルをとりあえずもってきて,自分でfowardを書き直しているイメージ。必要なら,適時dropoutを新たに挟む,とかも出来る。
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple
class Vgg16(torch.nn.Module):
def __init__(self):
super(Vgg16, self).__init__()
features = list(vgg16(pretrained = True).features)[:23]
self.features = nn.ModuleList(features).eval()
def forward(self, x):
results = []
for ii,model in enumerate(self.features):
x = model(x)
if ii in {3,8,15,22}:
results.append(x)
return results