
Distillation process:
# Knowledge-distillation training loop: the student is optimized against both
# the ground-truth labels and the (frozen) teacher's temperature-softened logits.
for epoch in range(epochs):
    student_model.train()
    for batch, (data, target) in enumerate(train_loader):
        # Clear gradients BEFORE the backward pass (was incorrectly placed
        # after backward, which erased the gradients before any update).
        optimizer.zero_grad()

        student_logits = student_model(data)

        # Teacher is not updated: run its forward pass without grad tracking.
        with torch.no_grad():
            teacher_logits = teacher_model(data)

        # Supervised loss: student predictions vs. ground-truth labels.
        # (Fixed: original referenced undefined `y_s`.)
        loss_cri = F.cross_entropy(student_logits, target)

        # Distillation loss: student vs. teacher soft targets at `temperature`.
        # NOTE(review): classic KD (Hinton et al.) scales this term by
        # temperature**2 to keep gradient magnitudes comparable with the
        # supervised loss — confirm whether `soft_cross_entropy` does that
        # internally before relying on `beta` alone.
        loss_kd = soft_cross_entropy(
            student_logits / temperature, teacher_logits / temperature
        )

        # Weighted total loss.
        loss = alpha * loss_cri + beta * loss_kd
        loss.backward()
        # Apply the gradient update (was missing entirely).
        optimizer.step()