Knowledge Distillation with DeiT in Practice: Distilling a DeiT Model with RegNet

Abstract

This article shows how to apply knowledge distillation to the DeiT model with an external distillation algorithm. It discusses the two distillation approaches and then walks through the second one in detail: distilling a DeiT model with a convolutional neural network (RegNet) as the teacher.


Model and Loss

model.py

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


# Register the model with timm
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
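
The training scripts below import deit_tiny_distilled_patch16_224 from models.models, whose registration is not reproduced above. As a minimal sketch, assuming the DistilledVisionTransformer class from the official facebookresearch/deit repository (it adds a dist_token and a second classification head on top of VisionTransformer) is defined in the same file, the registration mirrors the one above:

# Sketch only: DistilledVisionTransformer is assumed to come from the official
# DeiT repo's models.py; it returns (cls_logits, dist_logits) in training mode
# and their average in eval mode.
@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    model.default_cfg = _cfg()
    return model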

losses.py

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
"""Implements the knowledge distillation loss."""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision."""

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume the model returns (class_token logits, dist_token logits)
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # KL divergence between the softened student and teacher
            # distributions, scaled by T^2 as in Hinton et al.
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum', log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            # cross-entropy against the teacher's hard predictions
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
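
As a quick usage sketch (num_classes=12 is an assumption for illustration; substitute your dataset's class count):

# Hypothetical wiring: a RegNet teacher wrapped by DistillationLoss.
import torch.nn as nn
from timm.models import regnetx_160
from losses import DistillationLoss

teacher = regnetx_160(pretrained=False, num_classes=12)  # num_classes is an assumption
teacher.eval()  # frozen; DistillationLoss never backprops through the teacher
base_criterion = nn.CrossEntropyLoss()
criterion_kd = DistillationLoss(base_criterion, teacher,
                                distillation_type='hard', alpha=0.5, tau=1.0)
# loss = criterion_kd(images, student(images), labels)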

Training the Teacher Model

teacher_train.py

# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from timm.models import regnetx_160
import json
import os


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop for one epoch
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # save the checkpoint whenever validation accuracy improves
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc

Global Parameters

if __name__ == '__main__':
    # Create the folder that stores the teacher checkpoints
    file_dir = 'TeacherModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
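
The rest of the main block (data pipeline, model, optimizer, and epoch loop) is not shown above. A minimal sketch of that wiring, in which the data paths, the transforms, and the use of ImageNet pretraining are all assumptions:

    # (continuing inside the __main__ block; paths and transforms are assumptions)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset_train = datasets.ImageFolder('data/train', transform)
    dataset_test = datasets.ImageFolder('data/val', transform)
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    model = regnetx_160(pretrained=True, num_classes=len(dataset_train.classes))
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=modellr)

    Best_ACC = 0
    for epoch in range(1, EPOCHS + 1):
        train(model, DEVICE, train_loader, optimizer, epoch)
        val(model, DEVICE, test_loader)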

Student Network

student_train.py

# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop for one epoch
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # in training mode the distilled DeiT returns (class_token logits,
        # dist_token logits); without a teacher we train on the class head only
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # save the checkpoint whenever validation accuracy improves
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc

Global Parameters

if __name__ == '__main__':
    # Create the folder that stores the student checkpoints
    file_dir = 'StudentModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
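
The data pipeline and epoch loop have the same shape as in the teacher sketch above; the piece worth sketching here is the student model itself (num_classes=12 is again an assumption):

    # (continuing inside the __main__ block; loaders built as in the teacher sketch)
    criterion = nn.CrossEntropyLoss()
    model = deit_tiny_distilled_patch16_224(pretrained=False, num_classes=12)
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=modellr)

    Best_ACC = 0
    for epoch in range(1, EPOCHS + 1):
        train(model, DEVICE, train_loader, optimizer, epoch)
        val(model, DEVICE, test_loader)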

Distilling the Student Network

train_kd.py

# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.loss import LabelSmoothingCrossEntropy
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
from losses import DistillationLoss


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop: the student is optimized against DistillationLoss
def train(s_net, t_net, device, criterionKD, train_loader, optimizer, epoch):
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out_s = s_net(data)
        # DistillationLoss(inputs, outputs, labels): it runs the teacher on
        # `data` internally, so t_net is not called here directly
        loss = criterionKD(data, out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop: plain classification loss, no distillation
@torch.no_grad()
def val(model, device, criterionCls, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out_s = model(data)
        loss = criterionCls(out_s, target)
        _, pred = torch.max(out_s.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # save the checkpoint whenever validation accuracy improves
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc

Global Parameters

if __name__ == '__main__':
    # Create the folder that stores the distilled checkpoints
    file_dir = 'KDModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 4
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
    # distillation settings: hard distillation with equal weighting
    distillation_type = 'hard'
    distillation_alpha = 0.5
    distillation_tau = 1.0
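
The distillation wiring that follows these parameters is not shown above. A minimal sketch, assuming the teacher checkpoint saved by teacher_train.py and the same hypothetical data pipeline as in the earlier sketches:

    # (continuing inside the __main__ block; loaders built as in the teacher sketch)
    teacher_model = torch.load('TeacherModel/best.pth')  # saved by teacher_train.py
    teacher_model.to(DEVICE)
    teacher_model.eval()

    s_net = deit_tiny_distilled_patch16_224(pretrained=False, num_classes=12)  # num_classes is an assumption
    s_net.to(DEVICE)

    criterionCls = LabelSmoothingCrossEntropy()  # base criterion for the student
    criterionKD = DistillationLoss(criterionCls, teacher_model,
                                   distillation_type, distillation_alpha, distillation_tau)
    optimizer = optim.Adam(s_net.parameters(), lr=modellr)

    Best_ACC = 0
    for epoch in range(1, EPOCHS + 1):
        train(s_net, teacher_model, DEVICE, criterionKD, train_loader, optimizer, epoch)
        val(s_net, DEVICE, criterionCls, test_loader)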

Comparing Results

# Import the required libraries
import json
import matplotlib.pyplot as plt

# Result files produced by the three training runs
teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'


# Read an epoch -> accuracy dict from a JSON file
def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data


# Load the three accuracy curves
teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

# Plot the validation accuracy curves
x = list(range(len(teacher_data)))
plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x, list(student_data.values()), label='student without KD')
plt.plot(x, list(student_kd_data.values()), label='student with KD')
plt.title('Test accuracy')
plt.legend()
plt.show()
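
The plot assumes each result file is a dict mapping epoch numbers to validation accuracy. A hypothetical sketch of how the training scripts could produce such a file:

# Sketch only: dump {epoch: accuracy} after each epoch so the plot can read it.
import json

results = {}
for epoch in range(1, 4):          # stand-in for the real training loop
    acc = 0.5 + 0.05 * epoch       # stand-in for the value returned by val(...)
    results[str(epoch)] = acc
    with open('result.json', 'w', encoding='utf8') as fp:
        json.dump(results, fp)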

Summary

This article walked through applying external knowledge distillation to the DeiT model with a RegNet teacher. The experimental results show that hard distillation effectively improves the performance of the student network. Code and dataset: link
