解决pytorch CrossEntropyLoss报错RuntimeError: 1D target tensor expected, multi-target not supported

2021/7/29 23:09:48

本文主要是介绍解决pytorch CrossEntropyLoss报错RuntimeError: 1D target tensor expected, multi-target not supported,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

解决方法

CrossEntropyLoss(预测值,label)需要的输入维度是:

  1. 有batch时,预测值维度为2,size为[ batch_size, n ]时,label的维度是1,size为[ batch_size ]
  2. 没有batch时,预测值的维度为2,size为[ m, n ],label的维度是1,size为[ m ]

问题解析

一个案例即可说明:

import torch
import torch.nn as nn
import numpy as np

a = torch.tensor(np.random.random((30, 5)))
b = torch.tensor(np.random.randint(0, 4, (30))).long()
loss = nn.CrossEntropyLoss()

print("a的维度:", a.size()) # torch.Size([30, 5])
print("b的维度:", b.size()) # torch.Size([30])
print(loss(a, b)) # tensor(1.6319, dtype=torch.float64)


这篇关于解决pytorch CrossEntropyLoss报错RuntimeError: 1D target tensor expected, multi-target not supported的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程