pytorch-损失函数

torch.dtype

最近进行一些代码实操的过程中,经常遇到一个报错,TypeError: tensor(): argument ‘dtype’ must be torch.dtype, not torch.tensortype,原因则是因为在dtype参数中错误的传递了torch.tensortype。
因为有些类似分类任务中,运行报错可能时张量类型要求是LongTensor,但是直接在定义的时候,我们需要使用对应的torch.intdtype。而不能直接使用张量。

Each torch.Tensor has a torch.dtype, torch.device, and torch.layout.

Data type dtype Legacy Constructors
32-bit floating point torch.float32 or torch.float torch.*.FloatTensor
64-bit floating point torch.float64 or torch.double torch.*.DoubleTensor
64-bit complex torch.complex64 or torch.cfloat
128-bit complex torch.complex128 or torch.cdouble
16-bit floating point 1 torch.float16 or torch.half torch.*.HalfTensor
16-bit floating point 2 torch.bfloat16 torch.*.BFloat16Tensor
8-bit integer (unsigned) torch.uint8 torch.*.ByteTensor
8-bit integer (signed) torch.int8 torch.*.CharTensor
16-bit integer (signed) torch.int16 or torch.short torch.*.ShortTensor
32-bit integer (signed) torch.int32 or torch.int torch.*.IntTensor
64-bit integer (signed) torch.int64 or torch.long torch.*.LongTensor
Boolean torch.bool torch.*.BoolTensor

分类任务中,默认是从0起始的,所以针对分类任务进行编号时,需要从 0 开始进行编号。二分类(0,1),三分类(0,1,2)。

-------------本文结束感谢您的阅读-------------