模型蒸馏

发布时间：2023-04-10 10:21:48  作者：开发者的灵感

蒸馏过程：学生模型同时拟合真实标签（交叉熵损失）和教师模型经温度软化后的输出（蒸馏损失），两者加权求和后反向传播并更新学生模型；教师模型保持冻结，不参与参数更新。

import torch
import torch.nn.functional as F

# Knowledge-distillation training loop: the student learns from both the
# ground-truth labels and the teacher's temperature-softened outputs.
# NOTE(review): student_model, teacher_model, train_loader, optimizer, epochs,
# temperature, alpha, beta and soft_cross_entropy are not defined in this
# snippet; they are assumed to be provided by surrounding code.

# The teacher is frozen. eval() disables dropout and stops BatchNorm layers
# from updating their running statistics, which torch.no_grad() alone does
# not prevent.
teacher_model.eval()

for epoch in range(epochs):
    student_model.train()
    for data, target in train_loader:
        student_logits = student_model(data)

        # Teacher forward pass only: no autograd graph is built, so the
        # teacher's parameters receive no gradients.
        with torch.no_grad():
            teacher_logits = teacher_model(data)

        # Hard-label loss: student predictions vs. ground-truth labels.
        loss_cri = F.cross_entropy(student_logits, target)

        # Soft-label loss: student vs. teacher, both divided by the same
        # temperature before comparison.
        loss_kd = soft_cross_entropy(student_logits / temperature, teacher_logits / temperature)

        # Alternative soft-label loss (Hinton et al., 2015): KL divergence
        # between the temperature-softened distributions. The temperature**2
        # factor compensates for the 1/T**2 scaling of soft-target gradients.
        # reduction="batchmean" equals sum / batch_size.
        # p_s = F.log_softmax(student_logits / temperature, dim=1)
        # p_t = F.softmax(teacher_logits / temperature, dim=1)
        # loss_kd = F.kl_div(p_s, p_t, reduction="batchmean") * temperature**2

        # Weighted total of the hard-label and soft-label losses.
        loss = alpha * loss_cri + beta * loss_kd

        # Clear gradients left over from the previous step, backpropagate,
        # then apply the update to the student's parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

参考文献