要得到每一类的预测概率,首先通过torch.eq判断每个图片预测的准不准确,循环每个预测结果,得到没个结果对应的标签,如果准确,在该标签类的正确数量加一,在该类的总的数量加一。最后输出该类正确的数量除以该类总的数量就得到了该类的预测概率了。
# 查看单类准确率
classes = ('0', '1', '2', '3','4')
N_CLASSES = 5
class_correct = list(0. for i in range(N_CLASSES))
class_total = list(0. for i in range(N_CLASSES))
net.eval()
acc = 0.0 # accumulate accurate number / epoch
with torch.no_grad():
val_bar = tqdm(validate_loader, file=sys.stdout)
for val_data in val_bar:
val_images, val_labels = val_data
# print(val_labels.shape)
outputs = net(val_images.to(device))
predict_y = torch.max(outputs, dim=1)[1]
c = torch.eq(predict_y, val_labels.to(device)).squeeze()
size = int(val_labels.shape[0])
for i in range(size):
label = val_labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
val_accurate = acc / val_num
print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
(epoch + 1, running_loss / train_steps, val_accurate))
for i in range(N_CLASSES):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
若该分类任务存在类间分类,每一类差距很小,想要使预测结果处于相邻类就算分类正确时,则需要先将val_loader的batch_size设置为1,再通过一系列if语句实现该效果。
# 查看单类准确率
classes = ('0', '1', '2', '3','4')
N_CLASSES = 5
class_correct = list(0. for i in range(N_CLASSES))
class_total = list(0. for i in range(N_CLASSES))
net.eval()
acc = 0.0 # accumulate accurate number / epoch
with torch.no_grad():
val_bar = tqdm(validate_loader, file=sys.stdout)
for val_data in val_bar:
val_images, val_labels = val_data
outputs = net(val_images.to(device))
predict_y = torch.max(outputs, dim=1)[1]
labels = val_labels.numpy()
predict = predict_y.cpu().numpy()
if labels == 0:
if predict==0 or predict==1:
c = True
else:
c = False
elif labels == 4:
if predict==3 or predict==4:
c = True
else:
c = False
else:
if predict==labels-1 or predict==labels or predict==labels+1:
c = True
else:
c = False
size = int(val_labels.shape[0])
for i in range(size):
label = val_labels[i]
class_correct[label] += c
class_total[label] += 1
acc += c
val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
val_accurate = acc / val_num
print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
(epoch + 1, running_loss / train_steps, val_accurate))
for i in range(N_CLASSES):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)