pytorch 一行代码查看网络参数总量的实现
大家还是直接看代码吧~
netG = Generator()
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
补充:PyTorch查看网络模型的参数量PARAMS和FLOPS等
在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。
示例代码如下:
from torchstat import stat from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d model = resnet50() stat(model, (3, 224, 224))
打印信息如下:


以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。
相关推荐
-
pytorch 实现查看网络中的参数
可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数) 可示例代码如下: params = list(model.named_parameters()) (name, param) = params[28] print(name) print(param.grad) print('-------------------------------------------------') (name
-
pytorch 求网络模型参数实例
用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参
-
pytorch 实现在一个优化器中设置多个网络参数的例子
我就废话不多说了,直接上代码吧! 其实也不难,使用tertools.chain将参数链接起来即可 import itertools ... self.optimizer = optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) ... 以上这篇pytorch 实现在一个优化器中设置多个网络参数的
-
关于pytorch中网络loss传播和参数更新的理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇. TensorFlow: 228--->266 Keras: 42--->56 Pytorch: 87--->252 在使用pytorch中,自己有一些思考,如下: 1. loss计算和反向传播 import torch.nn as nn
-
pytorch查看网络参数显存占用量等操作
1.使用torchstat pip install torchstat from torchstat import stat import torchvision.models as models model = models.resnet152() stat(model, (3, 224, 224)) 关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数.我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数. 2.使用torchsummary pip in
-
pytorch 网络参数 weight bias 初始化详解
权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生. 在pytorch的使用过程中有几种权重初始化的方法供大家参考. 注意:第一种方法不推荐.尽量使用后两种方法. # not recommend def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.
-
pytorch 一行代码查看网络参数总量的实现
大家还是直接看代码吧~ netG = Generator() print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) 补充:PyTorch查看网络模型的参数量PARA
-
Pytorch反向求导更新网络参数的方法
方法一:手动计算变量的梯度,然后更新梯度 import torch from torch.autograd import Variable # 定义参数 w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定义输出 d = torch.mean(w1) # 反向求导 d.backward() # 定义学习率等参数 lr = 0.001 # 手动更新参数 w1.data.zero_() # BP求导更新参数之前,需先对导
-
在pytorch中如何查看模型model参数parameters
目录 pytorch查看模型model参数parameters pytorch查看模型参数总结 1:DNN_printer 2:parameters 3:get_model_complexity_info() 4:torchstat pytorch查看模型model参数parameters 示例1:pytorch自带的faster r-cnn模型 import torch import torchvision model = torchvision.models.detection.faster
-
Okhttp、Retrofit进度获取的方法(一行代码搞定)
起因 对于广大Android开发者来说,最近用的最多的网络库,莫过于Okhttp啦(Retrofit依赖Okhttp). Okhttp不像SDK内置的HttpUrlConnection一样,可以明确的获取数据读写的过程,我们需要执行一些操作. 介绍 Retrofit依赖Okhttp.Okhttp依赖于Okio.那么Okio又是什么鬼?别急,看官方介绍: Okio is a library that complements java.io and java.nio to make it much
-
PyTorch 编写代码遇到的问题及解决方案
PyTorch编写代码遇到的问题 错误提示:no module named xxx xxx为自定义文件夹的名字 因为搜索不到,所以将当前路径加入到包的搜索目录 解决方法: import sys sys.path.append('..') #将上层目录加入到搜索路径中 sys.path.append('/home/xxx') # 绝对路径 import os sys.path.append(os.getcwd()) # #将当前工作路径加入到搜索路径中 还可以在当前终端的命令行设置 export
-
pytorch加载自定义网络权重的实现
在将自定义的网络权重加载到网络中时,报错: AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead. 我们一步一步分析. 模型网络权重保存额代码是:torch.sa
随机推荐
- php批量删除数据
- Java输出系统当前的日期(年月日时分秒毫秒)
- jsp fckeditor 上传中文图片乱码问题的解决方法
- 批处理实现过滤重复行
- mongodb主从复制_动力节点Java学院整理
- PHP容易忘记的知识点分享
- Android编程中调用Camera时预览画面有旋转问题的解决方法
- Python使用tablib生成excel文件的简单实现方法
- 自动重启服务的shell脚本代码
- oracle中left join和right join的区别浅谈
- Jquery+Ajax+Json+存储过程实现高效分页
- PHP 5.3新特性命名空间规则解析及高级功能
- win2003服务器删除服务的方法
- 安装服务器常见组件之ISAPI_Rewrite组件图文安装教程
- JS+CSS实现下拉刷新/上拉加载插件
- C#开发微信门户及应用(5) 用户分组信息管理
- android 获取手机GSM/CDMA信号信息,并获得基站信息的方法
- Android自定义ScrollView实现放大回弹效果实例代码
- c语言通过opencv实现轮廓处理与切割
- 移动端android上line-height不居中的问题的解决
