当前位置:主页 > 软件编程 > Python代码 >

在Pytorch中计算自己模型的FLOPs方式

时间:2020-10-20 13:22:03 | 栏目:Python代码 | 点击:

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

您可能感兴趣的文章:

相关文章