c测试

发布时间 2023-12-10 22:19:58作者: 太好了还有脑子可以用
    def eval(self, phase='val', openset=False):
        print("Enter eval")
        print_str = ['Phase: %s' % (phase)]
        print_write(print_str, self.log_file)  # Phase: test
        time.sleep(0.25)  # 暂停程序执行0.25秒,以便给打印输出留出时间。

        if openset:  # 检查是否处于openset测试模式
            print('Under openset test mode. Open threshold is %.1f'
                  % self.training_opt['open_threshold'])

        torch.cuda.empty_cache()  # 清空GPU缓存,以释放内存。

        # In validation or testing mode, set model to eval() and initialize running loss/correct
        for model in self.networks.values():
            model.eval()

        # 创建一个空的张量,用于存储所有样本的预测logits,labels,paths。
        self.total_logits = torch.empty((0, self.training_opt['num_classes'])).to(self.device)
        self.total_labels = torch.empty(0, dtype=torch.long).to(self.device)
        self.total_paths = np.empty(0)

        # Iterate over dataset
        # 对数据集中的每个样本执行以下操作,并使用tqdm库显示进度条:
        for inputs, labels, paths in tqdm(self.data[phase]):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # print(labels)
            # If on training phase, enable gradients
            with torch.set_grad_enabled(False):  # 不计算梯度

                # In validation or testing
                self.batch_forward(inputs,
                                   labels,
                                   centroids=self.memory['centroids'],
                                   phase=phase)

                # 将当前样本的x连接到张量x中
                self.total_logits = torch.cat((self.total_logits, self.logits))
                self.total_labels = torch.cat((self.total_labels, labels))
                self.total_paths = np.concatenate((self.total_paths, paths))

        # 对self.total_logits张量进行softmax操作,并获取每个样本的概率和预测结果。
        probs, preds = F.softmax(self.total_logits.detach(), dim=1).max(dim=1)

        if openset:  # 处于openset测试模式
            # 将概率低于开放阈值的样本的预测结果设置为 - 1。
            preds[probs < self.training_opt['open_threshold']] = -1
            self.openset_acc = mic_acc_cal(preds[self.total_labels == -1],
                                           self.total_labels[self.total_labels == -1])
            print('\n\nOpenset Accuracy: %.3f' % self.openset_acc)

        # Calculate the overall accuracy and F measurement
        # 计算总体准确率。
        self.eval_acc_mic_top1 = mic_acc_cal(preds[self.total_labels != -1],
                                             self.total_labels[self.total_labels != -1])
        # 计算F度量值。
        self.eval_f_measure = F_measure(preds, self.total_labels, openset=openset,
                                        theta=self.training_opt['open_threshold'])

        # 计算多样本准确率、中等样本准确率和少样本准确率。
        self.many_acc_top1, \
        self.median_acc_top1, \
        self.low_acc_top1 = shot_acc(preds[self.total_labels != -1],
                                     self.total_labels[self.total_labels != -1],
                                     self.data['train'])
        # Top-1 accuracy and additional string
        # format:
        #  Phase: val
        #  总体准确率Evaluation_accuracy_micro_top1: 0.011
        #  计算F度量值Averaged F-measure: 0.002
        #  多中少样本准确率Many_shot_accuracy_top1: 0.029 Median_shot_accuracy_top1: 0.000 Low_shot_accuracy_top1: 0.000
        print_str = ['\n\n',
                     'Phase: %s'
                     % (phase),
                     '\n\n',
                     '总体准确率(Evaluation_accuracy_micro_top1): %.3f'
                     % (self.eval_acc_mic_top1),
                     '\n',
                     '计算F度量值(Averaged F-measure): %.3f'
                     % (self.eval_f_measure),
                     '\n',
                     '多样本准确率(Many_shot_accuracy_top1): %.3f'
                     % (self.many_acc_top1),
                     '中等样本准确率(Median_shot_accuracy_top1): %.3f'
                     % (self.median_acc_top1),
                     '少样本准确率(Low_shot_accuracy_top1): %.3f'
                     % (self.low_acc_top1),
                     '\n']

        if phase == 'val':
            print_write(print_str, self.log_file)
        else:
            print(*print_str)