pytorch中forwod函数在父类中的调用方式解读

目录
  • pytorch forwod函数在父类中的调用
    • 问题背景
  • pytorch forward方法调用原理
  • 总结

pytorch forwod函数在父类中的调用

问题背景

最近在研究Detetron2的代码结构时,发现有些网络代码里面没有forward函数,却照样可以推理,深入挖掘之后,发现其将forword函数都写在了同一个父类里面。

这就牵涉到了下面这个问题,子类中没有forward函数,只有父类中有forward函数,这样能不能正常调用网络。

import torch.nn as nn

class Network1(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,x):
        return x

class Network2(Network1):
    def __init__(self):
        super().__init__()

data = [1,2,3]
model = Network2().eval()
output = model(data)
print(output)

输出结果如下:

[1,2,3]

pytorch forward方法调用原理

在使用Pytorch自定义网络模型的时候,我们需要继承nn.Module这个类,然后定义forward方法来实现前向转播。

如下图的一个自定义的网络模型

首先该网络模型的初始化方法__init__需要继承父类nn.Module的初始化方法,用语句super().init()实现。

并在初始化方法里面,定义了卷积、BN、激活函数等。接下来定义forward方法,将整个网络连接起来。

有了上面的定义,我们可以实例化一个对象,例如:

fire2 = Fire(96, 128,16,64,64)

实现前向传播,使用 y= fire2(x) 其中x是该网络的输入,y是输出,实现了forward方法的额功能。

这里就会有人感到奇怪,forward作为Fire这个类的方法,使用的时候不应该是 y= fire2.forward(x)吗。

这里为什么一个类的实例可以当做方法直接使用?这是因为这个Fire类继承的父类nn.Module里面定义了__call__方法。

一个类如果定义了__call__方法,则该类的实例就可以作为一个方法那样直接使用。

例如下列代码[1]

class A():
    def __call__(self):
        print('i can be called like a function')

a = A()
a()

就会执行print函数,打印其中搞的文字。这里需要区别的是,实例化的时候,类的名称后面括号可以传递参数,例如前面实例化Fire的时候,传递in_channel,out_channel等参数。

但是要利用__call__的特性,是在实例名后面的括号中传递参数,例如上面的例子a(),这里虽然没有参数,但是也可以改变__call__的定义使之可以传递参数。

回到网络模型的内容上来。翻看nn.Module的部分源码[2],可以发现,nn.Module里面果然定义了__call__,并且传递了参数*input。在__call__的定义中国,调用了self.forward。

这里其实还有一个点值得注意。其实nn.Module里面并没有定义forward,但他却调用self.forward,严格来说,他是“想要”调用self.forward。

如果我们没有定义一个类,例如Fire,来继承nn.Module,并且在这个类里面定义forward,那么nn.Module中__call__下面的self.forward就是无效的。

这意味着,父类中__call__下面调用的函数,可以在继承他的子类中定义

下面给出一个简单的例子。

class father():
    def __call__(self):
        self.forward()
        print('I''m the father!')

class child(father):
    def forward(self):
        print('Forward!')
F=father()
C=child()

这里定义了父类father,并定义了继承他的一个子类child。此外还进行了他们的实例化。

显然,在father的__call__方法下面,调用了self.forward,但是没有定义。child在继承了father之后,定义了forward。

首先,这段代码不会报错,即使father的__call__下面的self.forward并没有定义,这也是前面我说的,虽然没有定义forward,但是可以理解为他“想要”调用self.forward。

那么在child记成了father之后,进行了forward的定义,这使得child本身可以调用forward。

在上面这段代码的基础上,如果我们执行F(),汇报下面这一段错误,这解释了forward没有定义,只是“想要”调用self.forward。

如果我们执行C(),则如下图输出。

显然,在child中补充了forward的定义,就可以成功调用。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch __init__、forward与__call__的用法小结

    1.介绍 当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__.build 和call小结)类似的情况,即经常会遇到__init__.forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢? 1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的 2)forward是表示一个前向传播,构建网络层的先后运算步骤 3)__call__的功能其实和fo

  • pytorch 中forward 的用法与解释说明

    前言 最近在使用pytorch的时候,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数 即: forward 的使用 class Module(nn.Module): def __init__(self): super(Module, self).__init__() # ...... def forward(self, x): # ...... return x data = ..... #输入数据 # 实例化一个对象 module

  • pytorch forward两个参数实例

    以channel Attention Block为例子 class CAB(nn.Module): def __init__(self, in_channels, out_channels): super(CAB, self).__init__() self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, st

  • pytorch中forwod函数在父类中的调用方式解读

    目录 pytorch forwod函数在父类中的调用 问题背景 pytorch forward方法调用原理 总结 pytorch forwod函数在父类中的调用 问题背景 最近在研究Detetron2的代码结构时,发现有些网络代码里面没有forward函数,却照样可以推理,深入挖掘之后,发现其将forword函数都写在了同一个父类里面. 这就牵涉到了下面这个问题,子类中没有forward函数,只有父类中有forward函数,这样能不能正常调用网络. import torch.nn as nn c

  • 浅谈js中test()函数在正则中的使用

    test() 方法用于检测一个字符串是否匹配某个模式. 返回一个 Boolean 值,它指出在被查找的字符串中是否匹配给出的正则表达式. regexp.test(str) 参数 regexp 必选项.包含正则表达式模式或可用标志的正则表达式对象. str    必选项.要在其上测试查找的字符串. 说明 test 方法检查字符串是否与给出的正则表达式模式相匹配,如果是则返回 true,否则就返回 false. 每个正则表达式都有一个 lastIndex 属性,用于记录上一次匹配结束的位置. var

  • 解析mysql中UNIX_TIMESTAMP()函数与php中time()函数的区别

    mysql 中:UNIX_TIMESTAMP(), UNIX_TIMESTAMP(date)若无参数调用,则返回一个Unix timestamp ('1970-01-01 00:00:00' GMT 之后的秒数) 作为无符号整数.若用date 来调用UNIX_TIMESTAMP(),它会将参数值以'1970-01-01 00:00:00' GMT后的秒数的形式返回.date 可以是一个DATE 字符串.一个 DATETIME字符串.一个 TIMESTAMP或一个当地时间的YYMMDD 或YYYM

  • Linux中mkdir函数与Windows中_mkdir函数的区别

    下面先来给大家介绍windows下_mkdir函数 复制代码 代码如下: #include<direct.h> int _mkdir( const char *dirname ); 参数: dirname是目录的路径名指针 返回值: 如果新目录的创建时间,这些功能中的每一个返回值 0. 在错误,则函数返回 – 1 linux下mkdir函数mode_t参数详解 复制代码 代码如下: #include <sys/stat.h> int mkdir(const char *path,

  • python中super()函数的理解与基本使用

    目录 前言 super的用法 super的原理 Python super()使用注意事项 混用super与显式类调用 不同种类的参数 总结 前言 Python是一门面向对象的语言,定义类时经常要用到继承,在类的继承中,子类继承父类中已经封装好的方法,不需要再次编写,如果子类如果重新定义了父类的某一方法,那么该方法就会覆盖父类的同名方法,但是有时我们希望子类保持父类方法的基础上进行扩展,而不是直接覆盖,就需要先调用父类的方法,然后再进行功能的扩展,这时就可以通过super来实现对父类方法的调用.

  • AngularJS中控制器函数的定义与使用方法示例

    本文实例讲述了AngularJS中控制器函数的定义与使用方法.分享给大家供大家参考,具体如下: HTML正文: <body ng-app="myApp" ng-controller="myCtrl"> <h2>AngularJS函数绑定</h2> <textarea ng-model="message" cols="40" rows="10"></tex

  • Lua中的函数知识总结

    前言 Lua中的函数和C++中的函数的含义是一致的,Lua中的函数格式如下: 复制代码 代码如下: function MyFunc(param)      -- Do something end 在调用函数时,也需要将对应的参数放在一对圆括号中,即使调用函数时没有参数,也必须写出一对空括号.对于这个规则只有一种特殊的例外情况:一个函数若只有一个参数,并且此参数是一个字符串或table构造式,那么圆括号便可以省略掉.看以下代码: 复制代码 代码如下: print "Hello World"

  • Python中函数参数调用方式分析

    本文实例讲述了Python中函数参数调用方式.分享给大家供大家参考,具体如下: Python中函数的参数是很灵活的,下面分四种情况进行说明. (1) fun(arg1, arg2, ...) 这是最常见的方式,也是和其它语言类似的方式 下面是一个实例: >>> def fun(x, y): return x - y >>> fun(12, -2) 14 (2) fun(arg1, arg2=value2, ...) 这种就是所谓的带默认参数的函数,调用的时候我们可以指定

  • 谈谈JavaScript中的函数

    JS中的函数简介 JS中的函数是一种通过调用来完成具体业务的一段代码块.最核心的目的是将可重复执行的操作进行封装,然后供调用方无限制的调用. JS中的函数的定义 JS中函数定义,有如下两种形式: 方式1 function f1(){} //函数声明,f1为函数名,可以将其理解为变量f1指向一个函数 function f2(){return 100;}//函数允许有返回值 function f3(a,b){}//函数中可以定义多个参数,无需指定变量类型 方式2 var f4=function(){

  • TypeScript中的函数

    目录 1.函数定义 1.1JavaScript中的函数 1.2TypeScript中的函数 3.可选参数和默认参数 4.剩余参数 1.函数定义 1.1JavaScript中的函数 在学习TypeScript中的函数前我们先来回顾一下JavaScript中的函数定义常用的包含以下几种: 第一种:使用function关键字声明函数 function add1 (x, y) { return x + y } 第二种:使用字面量方式声明函数 const add2 = function (x, y) {

随机推荐