PyTorch框架学习九——网络模型的构建_6.pytorch中用来搭建网络层的模块是什么?-程序员宅基地

技术标签: PyTorch  深度学习  pytorch  神经网络  

笔记二到八主要介绍与数据有关的内容,这次笔记将开始介绍网络模型有关的内容,首先我们不追求网络内部各层的具体内容,重点关注模型的构建,学会了如何构建模型,然后再开始一些具体网络层的学习。

一、概述

模型有关的内容主要如下图所示:
在这里插入图片描述
主要是模型的搭建权值的初始化两个问题,而模型的搭建里,首先需要构建单独的网络层,然后将这些网络层按顺序拼接起来,就构成了一个模型,然后进行某种权值初始化,就可以用于训练数据。

今天介绍PyTorch中是如何实现模型创建的,具体内部的卷积、池化、激活函数等知识下次笔记介绍。上述的所有内容,在PyTorch中都有一个叫nn.Module的模块来实现。

看一个LeNet模型的例子:
LeNet网络结构
从上图可以看出LeNet模型经过了这样一个网络层的流程:
在这里插入图片描述
那我们要来搭建这个模型的话,就要先单独构建卷积层Conv,池化层pool,全连接层fc,然后按照上面的顺序进行拼接,拼接后的整体才是一个构建好的网络模型。

看一下LeNet的模型构建的代码:

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

可以看出__init__()函数实现了对每一个单独的网络层的构建,forward()函数实现了子网络层的拼接。

二、nn.Module

介绍nn.Module之前先看一下torch.nn里四个重要的模块:

  1. torch.nn.Parameter:张量的子类,表示可学习的参数,如weight、bias。
  2. torch.nn.Module:所有网络层的基类,管理网络属性。
  3. torch.nn.functional:函数具体实现,如卷积、池化、激活函数等。
  4. torch.nn.init:参数初始化的方法。

这里重点介绍nn.Parameter和nn.Module。

nn.Module来构建网络层时会创建8个字典管理它的不同属性,分别如下所示:

  • parameters:存储管理nn.Parameter类。
  • modules:存储管理nn.Module类。
  • buffers:存储管理缓冲属性,如BN层中的running_mean。
  • ×××_hooks(5个):存储管理钩子函数(目前不了解)。

下面的代码是创建一个module时对8个字典的初始化:

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

注意:

  1. 一个module可以包含多个子module,如LeNet是一个module,它包含了conv、fc等子module。
  2. 一个module相当于一个运算,必须实现forward函数。
  3. 每个module都有8个字典管理它的属性。

三、模型容器Container

模型容器有三种,如下图所示:
在这里插入图片描述

1.nn.Sequential

功能:是nn.Module的容器,用于按顺序包装一组网络层。

还是以LeNet为例,我们将LeNet分成features和classifier两部分,每个部分都是一个sequential:
在这里插入图片描述
代码如下:

class LeNetSequential(nn.Module):
    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

但是,这种构建网络的方式有一个小问题,每一层网络层都会自动按顺序编一个号作为name,如features这个Sequential里每层网络层在module属性内部是这样的:

在这里插入图片描述
这里只有六个网络层,所以还可以在短时间内找到你需要的那一个,但是当层数非常多的时候,这种数字命名的方式就很不友好,而Sequential也有相应的应对方法,即为每一层网络命名,具体代码如下所示:

class LeNetSequentialOrderDict(nn.Module):
    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()

        self.features = nn.Sequential(OrderedDict({
    
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),

            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
    
            'fc1': nn.Linear(16*5*5, 120),
            'relu3': nn.ReLU(),

            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),

            'fc3': nn.Linear(84, classes),
        }))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

与原来不同的地方就是,构建了一个OrderedDict字典来存放键值对,key就是每一层网络的名字,value就是具体的网络层实现,看一下此时的module属性内部:

在这里插入图片描述
这样就很好寻找所需要的某一层网络。

综上,Sequential的特点:

  1. 顺序性:各网络层之间严格按照顺序构建。
  2. 自带forward():通过for循环依次执行前向传播运算。

2.nn.ModuleList

也是nn.module的容器,用于包装一组网络层,以迭代方式调用网络层。

主要方法:

  1. append():在ModuleList后面添加网络层
  2. extend():拼接两个ModuleList
  3. insert():指定在ModuleList中位置插入网络层

这种容器比较适合构建大量重复的网络层,因为利用了迭代的方法,下面就是构建20个线性层的例子

class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
        return x

在这里插入图片描述

3.nn.ModuleDict()

也是nn.module的容器,用于包装一组网络层,以索引方式调用网络层。

主要方法:

  1. clear():清空ModuleDict
  2. items():返回可迭代的键值对
  3. keys():返回字典的键
  4. values():返回字典的值
  5. pop():返回一对键值,并从字典中删除

这种容器的特点是,因为键值对可以索引的特性,可用于选择网络层:

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
    
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })

        self.activations = nn.ModuleDict({
    
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x


net = ModuleDict()

fake_img = torch.randn((4, 10, 32, 32))

output = net(fake_img, 'conv', 'relu')

print(output)

我们构建了conv、pool以及relu、prelu,然后我们选择使用conv和relu。

4.总结

对于上述提及的三种容器,它们各自的特点以及适用范围如下所示:

  1. nn.Sequential:顺序性,各层之间按顺序执行,常用于block的构建。
  2. nn.ModuleList:迭代性,常用于大量重复网络层的构建,通过for循环实现重复构建。
  3. nn.ModuleDict:索引性,常用于可选择的网络层的构建,通过字典的键值对实现选择。
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_40467656/article/details/108007685

智能推荐

安装Hadoop2.10.1_hadoop2.10.1下载-程序员宅基地

文章浏览阅读4k次。文章目录前言通过在Hadoop1安装Hadoop,然后配置相应的配置文件,最后将Hadoop 所有文件同步到其他Hadoop节点。一、集群规划#主机名‘master/hadoop1’ ‘slave01/hadoop2’ ‘slave02/hadoop3’#启动节点Namenode Nodemanager Nodemanager Resourcemanager _hadoop2.10.1下载

NUMA的取舍和优化设置_numa optimized-程序员宅基地

文章浏览阅读4.2k次。转载:https://www.cnblogs.com/wjoyxt/p/4804081.htmlNUMA的取舍与优化设置 在os层numa关闭时,打开bios层的numa会影响性能,QPS会下降15-30%; 在bios层面numa关闭时,无论os层面的numa是否打开,都不会影响性能。 安装numactl: #yum install numactl -y ..._numa optimized

DID去中心化身份认证技术调研_did过程-程序员宅基地

文章浏览阅读1.3w次。数字身份国际电子技术委员会将“身份”定义为“一组与实体关联的属性”。这里的实体不仅仅是人,对于机器或者物体都可以是实体,甚至网络中虚拟的东西也可以是实体并拥有身份。随着互联网的出现和普及,传统的身份有了另外一种表现形式,即数字身份。一般认为,数字身份的演进经历了以下四个阶段:(1)中心化身份:中心化身份是由单一的权威机构进行管理和控制的,现在互联网上的大多数身份还是中心化身份,比如ICANN管理的域名与IP地址分配,以及PKI(Public Key Infrastructure)系统中的CA(Cert_did过程

Ubuntu18.04安装fsl5.0步骤_怎么下载fsl v5.0版本-程序员宅基地

文章浏览阅读3.2k次,点赞3次,收藏18次。要做医学图像处理相关研究,fsl软件不可或缺。在我安装fsl的过程中遇到很多问题,根本原因是过于依赖网上的安装教程,从而忽视了官网教程。其实官网的教程虽然是英文,但是写的十分简单明了,按照官网教程逐步就可完成安装。此处为官网链接其实对比中文教程和官网教程,差别不大,但是最关键的一处区别为以下两句指令wget -O- http://neuro.debian.net/lists/trusty.c..._怎么下载fsl v5.0版本

编程语言的战争-程序员宅基地

文章浏览阅读971次。GO ,JAVA, C# ,JAVAScript(Html,Css),node.js, c++ /c,Rust,VB.net ,Python,R,Lua ,Ryby晕 ~~~~~~~~计算机语言的战争早就开始了。每一种语言都声称自己是最好的,自己的理念,是解决世界不二法门。可是身为一个不算年轻程序员狗,在面对那么多语言的选择的时候常常觉得自己很渺小,太渺小了,可是人类的好...

Delphi中URL的编码与解码,即urlencode的使用_delphi urlencode-程序员宅基地

文章浏览阅读8.4k次。一、URL简介 URL是网页的地址,比如 http://www.shanhaiMy.com。Web 浏览器通过 URL 从 web 服务器请求页面。 由于URL字符串常常会包含非ASCII字符,URL在传输过程中,往往出现错误。因此,可以将非字符串字符,让一些特殊ASCII字符组合,代替非ASCII字符。这就是编码转换,当字符串传输后,可以返回原RUL字符串(解码)。 UR..._delphi urlencode

随便推点

Ajax的常用技巧(3)---实现自动刷新页面._ajax页面自动刷新-程序员宅基地

文章浏览阅读1.5w次,点赞4次,收藏20次。网页自动刷新功能在web网站上已经屡见不鲜了,如即时新闻信息,股票信息等,都需要不断获取最新信息。在传统的web实现方式中,想要实现类似的效果,必须进行整个页面的刷新,在网络速度受到一定限制的情况下,这种因为一个局部变动而牵动整个页面的处理方式显得有些得不偿失。Ajax技术的出现很好的解决了这个问题,利用Ajax技术可以实现网页的局部刷新,只更新指定的数据,并不更新其他的数据。 现在创建一_ajax页面自动刷新

Chrome 设置使用已安装JRE的方式_chrome jre-程序员宅基地

文章浏览阅读5.5w次。情况描述此状况发生在Chrome 的较旧的版本上:机器已经安装过JRE 或是JDK, 但是每次打开Chrome 浏览器使用Applet时, 会报需要下载JRE的提示信息(最新的JRE1.7)。看上去, Chrome 并没有找到系统已经安装的JRE。相比而言, 对于IE和Firefox 浏览器。我们可以在Java 控制台设置使用的JRE版本和支持的浏览器, 而且我们可以更改浏览器_chrome jre

实验室linux服务器 配置anaconda+pytorch环境_linux 服务器 anaconda pytorch-程序员宅基地

文章浏览阅读539次。登录服务器后,0 安装anaconda1 进入 base 环境 : conda activate base2 #创建虚拟环境:建议不在base 环境中装pytorchconda create -n pytorch python=3.8 anaconda#pytorch 是虚拟环境的名字,你随意起一个也行#anaconda 让虚拟环境中有anaconda的各种第三方库,否则虚拟环境中就只有python的官方库,不方便.3 进入虚拟环境: conda activate pytorch4 _linux 服务器 anaconda pytorch

【STM32+cubemx】0026 HAL库开发:NRF24L01无线2.4G通信模块的应用_stm32cubemx nrf-程序员宅基地

文章浏览阅读1.6w次,点赞34次,收藏202次。NRF24L01是NORDIC公司生产的一款无线通信通信芯片,可以工作在免费开放的2.4GHz频段;通信速率可以达到最高2Mbps;MUC可以使用SPI接口与它交互。本节我们就来使用stm32驱动NRF24L01实现无线通信,先实现简单的一对一通信,然后讲解一对多通信,最后实现在ack中返回数据的应用。1)NRF24L01模块硬件介绍直接使用nrf24L01芯片搭建电路需要比较高的射频功底,一般情况下推荐使用现成的电路模块,本文使用的是下图这种:类似的nrf24L01模块的对外引_stm32cubemx nrf

中秋节及其典故来源英文怎么说?_博客英语中秋节-程序员宅基地

文章浏览阅读2.2k次。中秋节的英文是"Moon Festival"或者"Mid-Autumn Festival"The Mid-Autumn Festival , also known as the Lantern Festival, Moon Festival or Mooncake Festival, is a festival celebrated by the Chinese and East Asian people. It is the second most important Chinese festival_博客英语中秋节

C++学习笔记5——封装篇(下)-程序员宅基地

文章浏览阅读215次。深拷贝浅拷贝定义一个类叫Array,定义一个数据成员 m_iCount ,在构造函数中给m_iCount赋初值为5,在拷贝构造函数中,传入的参数是arr,其数据类型也是Array,所以该拷贝构造函数中,也有数据成员m_iCount。拷贝构造函数中的语句是把arr中的m_iCount赋值给本身Array这个中的数据成员m_iCount。使用时,若采用图中的方式进行实例化,那么会调用arr1的构...