Pytorch实现网络部分层的固定不进行回传更新问题及思路详解

目录
  • 实际问题
  • 问题解决思路
  • 代码实现
  • LAST 参考文献

实际问题

Pytorch有的时候需要对一些层的参数进行固定,这些层不进行参数的梯度更新

问题解决思路

那么从理论上来说就有两种办法

  • 优化器初始化的时候不包含这些不想被更新的参数,这样他们会进行梯度回传,但是不会被更新
  • 将这些不会被更新的参数梯度归零,或者不计算它们的梯度

思路就是利用tensorrequires_grad,每一个tensor都有自己的requires_grad成员,值只能为TrueFalse。我们对不需要参与训练的参数的requires_grad设置为False

在optim参数模型参数中过滤掉requires_grad为False的参数。
还是以上面搭建的简单网络为例,我们固定第一个卷积层的参数,训练其他层的所有参数。

代码实现

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,32,3)
        self.conv2 = nn.Conv2d(32,24,3)
        self.prelu = nn.PReLU()
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                nn.init.constant_(m.bias.data,0)
            if isinstance(m,nn.Linear):
                m.weight.data.normal_(0.01,0,1)
                m.bias.data.zero_()
    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.prelu(out)
        return out

遍历第一层的参数,然后为其设置requires_grad

model = Net()
for name, p in model.named_parameters():
    if name.startswith('conv1'):
        p.requires_grad = False

optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad is not False ,model.parameters()),lr= 0.2)

为了验证一下我们的设置是否正确,我们分别看看model中的参数的requires_gradoptim中的params_group()

for p in model.parameters():
    print(p.requires_grad)

能看出优化器仅仅对requires_gradTrue的参数进行迭代优化。

LAST 参考文献

Pytorch中,动态调整学习率、不同层设置不同学习率和固定某些层训练的方法_我的博客有点东西-CSDN博客

到此这篇关于Pytorch实现网络部分层的固定不进行回传更新的文章就介绍到这了,更多相关Pytorch网络部分层内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch 一行代码查看网络参数总量的实现

    大家还是直接看代码吧~ netG = Generator() print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) 补充:PyTorch查看网络模型的参数量PARA

  • Pytorch训练网络过程中loss突然变为0的解决方案

    问题 // loss 突然变成0 python train.py -b=8 INFO: Using device cpu INFO: Network: 1 input channels 7 output channels (classes) Bilinear upscaling INFO: Creating dataset with 868 examples INFO: Starting training: Epochs: 5 Batch size: 8 Learning rate: 0.001

  • PyTorch实现更新部分网络,其他不更新

    torch.Tensor.detach()的使用 detach()的官方说明如下: Returns a new Tensor, detached from the current graph. The result will never require gradient. 假设有模型A和模型B,我们需要将A的输出作为B的输入,但训练时我们只训练模型B. 那么可以这样做: input_B = output_A.detach() 它可以使两个计算图的梯度传递断开,从而实现我们所需的功能. 以上这篇P

  • pytorch  网络参数 weight bias 初始化详解

    权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生. 在pytorch的使用过程中有几种权重初始化的方法供大家参考. 注意:第一种方法不推荐.尽量使用后两种方法. # not recommend def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.

  • Pytorch实现网络部分层的固定不进行回传更新问题及思路详解

    目录 实际问题 问题解决思路 代码实现 LAST 参考文献 实际问题 Pytorch有的时候需要对一些层的参数进行固定,这些层不进行参数的梯度更新 问题解决思路 那么从理论上来说就有两种办法 优化器初始化的时候不包含这些不想被更新的参数,这样他们会进行梯度回传,但是不会被更新 将这些不会被更新的参数梯度归零,或者不计算它们的梯度 思路就是利用tensor的requires_grad,每一个tensor都有自己的requires_grad成员,值只能为True和False.我们对不需要参与训练的参

  • 基于Pytorch版yolov5的滑块验证码破解思路详解

    前言 本文将使用pytorch框架的目标识别技术实现滑块验证码的破解.我们这里选择了yolov5算法 例:输入图像 输出图像 可以看到经过检测之后,我们能很准确的定位到缺口的位置,并且能得到缺口的坐标,这样一来我们就能很轻松的实现滑动验证码的破解. 一.前期工作 yolov系列是常用的目标检测算法,yolov5不仅配置简单,而且在速度上也有不小的提升,我们很容易就能训练我们自己的数据集. YOLOV5 Pytorch版本GIthub网址感谢这位作者的代码. 下载之后,是这样的格式 ---data

  • Vue实现textarea固定输入行数与添加下划线样式的思路详解

    先上效果图### textarea下划线 设置一张1*35 //行高 的图片 , 设置背景图即可. background: url('./img/linebg.png') repeat; border: none;outline: none;overflow: hidden; line-height: 35px;//注意行高要和背景图高度一致resize: none; 固定输入行数 需求:用户固定不论多少字节,只能输入2行. 因为是限制行数,所以不能用maxlength设置. 实现思路 首先想到

  • Oracle固定执行计划之SQL PROFILE概要文件详解

    1.  引子 Oracle系统为了合理分配和使用系统的资源提出了概要文件的概念.所谓概要文件,就是一份描述如何使用系统的资源(主要是CPU资源)的配置文件.将概要文件赋予某个数据库用户,在用户连接并访问数据库服务器时,系统就按照概要文件给他分配资源. 包括: 1.管理数据库系统资源. 利用Profile来分配资源限额,必须把初始化参数resource_limit设置为true默认是TRUE的. 2.管理数据库口令及验证方式. 默认给用户分配的是DEFAULT概要文件,将该文件赋予了每个创建的用户

  • java固定大小队列的几种实现方式详解

    目录 前言 基于Hutool中的FixedLinkedHashMap 基于Guava的EvictingQueue 基于Redis的list操作 总结 前言 最近团队有同学在开发中,遇到一个需求,统计最近10次的异常次数,咨询有没有类似的list.针对这个问题,记录一下几种处理方式. 基于Hutool中的FixedLinkedHashMap 引入maven依赖 <dependency> <groupId>cn.hutool</groupId> <artifactId

  • Python LeNet网络详解及pytorch实现

    目录 1.LeNet介绍 2.LetNet网络模型 3.pytorch实现LeNet 1.LeNet介绍 LeNet神经网络由深度学习三巨头之一的Yan LeCun提出,他同时也是卷积神经网络 (CNN,Convolutional Neural Networks)之父.LeNet主要用来进行手写字符的识别与分类,并在美国的银行中投入了使用.LeNet的实现确立了CNN的结构,现在神经网络中的许多内容在LeNet的网络结构中都能看到,例如卷积层,Pooling层,ReLU层.虽然LeNet早在20

  • 关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu18.04 显卡:GTX1080ti python版本:2.7(3.7) 网络架构 具有4层的CNN具有以下架构. 输入层:784个节点(MNIST图像大小) 第一卷积层:5x5x32 第一个最大池层 第二卷积层:5x5x64 第二个最大池层 第三个完全连接层:1024个节点 输出层:10个节点(M

  • Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

    原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的"fake"数据,目的是网络生成的fake数据可以"骗过"判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据.总的来说是:判别器区分真实数据和fake数据的能力越强越好:生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能

  • Pytorch之finetune使用详解

    finetune分为全局finetune和局部finetune.首先介绍一下局部finetune步骤: 1.固定参数 for name, child in model.named_children(): for param in child.parameters(): param.requires_grad = False 后,只传入 需要反传的参数,否则会报错 filter(lambda param: param.requires_grad, model.parameters()) 2.调低学

  • pytorch:model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!! Class Inpaint_Network() ...... Model = Inpaint_Nerwoek() #train: Model.train(mode=True) ..... #test: Model.ev

随机推荐