迁移学习

实际训练中很少有网络能够拥有足够大的数据集进行训练,所以迁移学习是实际卷积网络训练过程中非常重要的步骤

简介

首先将模型在大数据集(比如ImageNet,包含120万张共1000类的图像)上进行预训练,然后将训练后的模型作为指定数据集的初始化或者固定特征提取器,这一操作称为迁移学习(Transfer Learning

适用场景

迁移学习主要有2个适用场景:

  1. 将卷积网络作为固定特征提取器。除了最后的全连接层外,将会冻结所有网络的权重。最后的全连接层将会被一个新的随机初始化的全连接层替代,并且仅训练该层
  2. 微调卷积网络。不使用随机初始化而是用一个预训练网络来初始化网络,就像在ImageNet 1000数据集上训练网络一样,剩下的训练和往常一样。可以微调卷积网络的所有层,或者可以保持一些早期的层固定不变(由于过度拟合的问题),并且只微调网络的一些较高层部分。这是因为观察到卷积网络的早期特征包含更多通用特征(例如,边缘检测器或颜色斑点检测器),这些特征对许多任务都很有用,但是卷积网络的顶层对于原始数据集中包含的类的细节变得越来越具体。例如,在包含许多犬种的ImageNet的情况下,卷积网络的很大一部分表示能力可以用于区分犬种的特定特性

预训练模型

  • Caffe训练的卷积网络可以在Model Zoo上进行分享
  • PyTorchtorchvision.models中提供了多个网络及其预训练模型

示例

使用网络ResNet18对小数据集(只有蚂蚁和蜜蜂两类,每类有约120张训练图片和75张测试图片)进行识别,比较随机初始化参数、使用预训练模型作为固定特征提取器以及微调网络的差异

加载数据

下载数据集,保存并解压到data文件夹

1
2
3
├── data
│   ├── hymenoptera_data
│   └── hymenoptera_data.zip
  • 训练阶段。随机裁剪图片并缩放至224x224大小,同时进行随机水平翻转,最后进行数据标准化操作
  • 测试阶段。缩放图片至256x256大小,从中心裁剪224x224大小,最后进行数据标准化操作
  • 批量大小:4

训练参数

  • 网络:ResNet18
  • 反向传播:随机梯度下降(SGD
  • 学习率:1e-3
  • 动量:0.9
  • 学习率调度器:随步长衰减,每7轮迭代衰减一次,gamma=0.1

实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# -*- coding: utf-8 -*-

"""
@author: zj
@file: finetune.py
@time: 2020-02-26
"""

import time
import copy
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models


def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated


def load_data():
# Data augmentation and normalization for training
# Just normalization for validation
# 进行数据扩充
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

return dataloaders, dataset_sizes, class_names


def show_data(dataloaders):
# 可视化数据集
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
# 制作图像网格
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])


def visualize_model(model, dataloaders, class_names, num_images=6):
"""
可视化模型训练结果
"""
was_training = model.training
model.eval()
images_so_far = 0
fig = plt.figure()

with torch.no_grad():
for i, (inputs, labels) in enumerate(dataloaders['val']):
inputs = inputs.to(device)
labels = labels.to(device)

outputs = model(inputs)
_, preds = torch.max(outputs, 1)

for j in range(inputs.size()[0]):
images_so_far += 1
ax = plt.subplot(num_images // 2, 2, images_so_far)
ax.axis('off')
ax.set_title('predicted: {}'.format(class_names[preds[j]]))
imshow(inputs.cpu().data[j])

if images_so_far == num_images:
model.train(mode=was_training)
return
model.train(mode=was_training)


def visualize_train():
"""
可视化训练损失和精度
"""


def create_model(mode='ri'):
if mode == 'fixed':
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
param.requires_grad = False
return model_conv
elif mode == 'ft':
return models.resnet18(pretrained=True)
else:
return models.resnet18()


def train_model(model, criterion, optimizer, scheduler, dataset_sizes, dataloaders, num_epochs=25):
since = time.time()

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)

# Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode

running_loss = 0.0
running_corrects = 0

# Iterate over data.
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)

# zero the parameter gradients
optimizer.zero_grad()

# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()

# statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()

epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]

print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc))

# deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())

print()

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

# load best model weights
model.load_state_dict(best_model_wts)
return model


if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataloaders, dataset_sizes, class_names = load_data()
show_data(dataloaders)

for mode_name in {'ri', 'ft', 'fixed'}:
print('begin mode: %s' % mode_name)
print('#' * 20)
# 创建网络模型,指定参数初始化方式
model = create_model(mode=mode_name)
# model = create_model(mode='ri')
# model = create_model(mode='ft')
# model = create_model(mode='fixed')

num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model.fc = nn.Linear(num_ftrs, 2)

model_conv = model.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
# 随机梯度下降
optimizer_conv = optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
# 随步长衰减
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion,
optimizer_conv,
exp_lr_scheduler,
dataset_sizes,
dataloaders,
num_epochs=25)

# visualize_model(model_conv, dataloaders, class_names, num_images=6)

比较

分别用3种不同的权重处理方式(随机初始化、微调、固定特征提取器)进行训练,共迭代25次,结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
begin mode: ri
####################
...
Training complete in 0m 39s
Best val Acc: 0.725490
begin mode: ft
####################
...
Training complete in 0m 40s
Best val Acc: 0.941176
begin mode: fixed
####################
...
Training complete in 0m 26s
Best val Acc: 0.960784

从训练结果中发现,使用迁移学习后的网络模型能够得到更好的训练结果

使用场景

在何种情况下进行迁移学习的使用,最大的因素有两个:

  • 新数据集的规模
  • 新数据集与原先数据集的相似程度

根据以上两个因素共分为4个使用场景:

  1. 新数据集很小,与原始数据集相似。由于数据集很小,存在过拟合的问题,所以微调卷积网络不是一个好主意;由于数据与原始数据相似,卷积网络中的高级特征也与此数据集相关。最好的办法就是使用固定特征提取器的方式,再训练一个线性分类器
  2. 新数据集很大,与原始数据集相似。因为有更多的数据,所以对整个网络进行微调也不会产生过拟合
  3. 新数据集很小,与原始数据集非常不同。因为数据很小,所以使用固定特征提取器的方式,再训练一个线性分类器。由于数据集有很大的不同,所以不能从包含更多数据集特定特征的网络顶部来训练分类器,而是固定网络早期权重,微调网络顶部权重的方式来训练线性分类器
  4. 新数据集很大,与原始数据集非常不同。由于数据集非常大,能够从头开始训练一个卷积网络。但是在实践中,用来自预训练模型的权重初始化仍然非常有效。在这种情况下,将有足够的数据和信心通过整个网络进行微调

相关阅读