欢迎来到代码驿站!

Python代码

当前位置:首页 > 软件编程 > Python代码

pytorch 常用函数 max ,eq说明

时间:2022-02-22 10:57:45|栏目:Python代码|点击:

max找出tensor 的行或者列最大的值:

找出每行的最大值:

import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,1))

输出:

(tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))

找出每列的最大值:

import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,0))

输出结果:

(tensor([ 3.]), tensor([ 2]))

Tensor比较eq相等:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))

输出结果:

tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)

使用sum() 统计相等的个数:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())

输出结果:

tensor(2)

补充知识:PyTorch - torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le

flyfish

torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le

以上全是简写

参数是input, other, out=None

逐元素比较input和other

返回是torch.BoolTensor

import torch

a=torch.tensor([[1, 2], [3, 4]])
b=torch.tensor([[1, 2], [4, 3]])

print(torch.eq(a,b))#equals
# tensor([[ True, True],
#     [False, False]])

print(torch.ne(a,b))#not equal to
# tensor([[False, False],
#     [ True, True]])

print(torch.gt(a,b))#greater than
# tensor([[False, False],
#     [False, True]])

print(torch.lt(a,b))#less than
# tensor([[False, False],
#     [ True, False]])

print(torch.ge(a,b))#greater than or equal to
# tensor([[ True, True],
#     [False, True]])

print(torch.le(a,b))#less than or equal to
# tensor([[ True, True],
#     [ True, False]])

上一篇:Pytorch测试神经网络时出现 RuntimeError:的解决方案

栏    目:Python代码

下一篇:详解Python语法之模块Module

本文标题:pytorch 常用函数 max ,eq说明

本文地址:http://www.codeinn.net/misctech/194232.html

推荐教程

广告投放 | 联系我们 | 版权申明

重要申明:本站所有的文章、图片、评论等,均由网友发表或上传并维护或收集自网络,属个人行为,与本站立场无关。

如果侵犯了您的权利,请与我们联系,我们将在24小时内进行处理、任何非本站因素导致的法律后果,本站均不负任何责任。

联系QQ:914707363 | 邮箱:codeinn#126.com(#换成@)

Copyright © 2020 代码驿站 版权所有