pytorch学习-使用torch.nn模块自定义模型_from torch import nn-程序员宅基地

技术标签: pytorch  

使用nn.Module构建神经网络

完成模型的训练包括:参数求导、梯度计算、参数更新以及训练过程的控制

nn.Module是PyTorch提供的神经网络类,并在类中实现了网络各层的定义及前向计算与反向传播机制。
在实际使用时,如果想要实现某个神经网络,只需继承nn.Module,在初始化中定义模型结构与参数,在函数forward()中编写网络前向过程即可。


下面具体以一个由两个全连接层组成的感知机为例,介绍如何使用nn.Module构造模块化的神经网络
可以把代码的两个类分为四个步骤:

1.初始化层参数
2.定义层结构
3.定义所有网络层
4.定义完整的模型

import torch
from torch import nn

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        # 1.初始化层参数
        super(Linear, self).__init__()
        self.w = nn.Parameter(torch.randn(in_dim, out_dim))
        self.b = nn.Parameter(torch.randn(out_dim))

    def forward(self, x):
        # 2.定义层结构
        x = x.matmul(self.w)
        y = x + self.b.expand_as(x)
        return y

class Perception(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        # 3.定义所有网络层
        super(Perception, self).__init__()
        self.layer1 = Linear(in_dim, hid_dim)
        self.layer2 = Linear(hid_dim, out_dim)

    def forward(self, x):
        # 4.定义模型
        x = self.layer1(x)
        y = torch.sigmoid(x)
        y = self.layer2(y)
        y = torch.sigmoid(y)
        return y

说明:


1 打印网络perception,可以看到上述定义的layer1和layer2

perception = Perception(2,3,2)  # 输入样本维度为2,输出维度也为2	
print(perception)

在这里插入图片描述

2 named_parameters()可以返回学习参数的参数名与参数值

for name, parameter in perception.named_parameters():
    print(name, parameter)  # print(name, ':', parameter.size())

在这里插入图片描述

3 将输入数据传入perception,perception()相当于调用perception中的forward()函数

data = torch.randn(4,2)
print("data:", data)

output = perception(data)
print("output:", output)

在这里插入图片描述

4 nn.Parameter函数

在类的__init__()中需要定义网络学习的参数,在此使用nn.Parameter()函数定义了全连接中的ω和b,这是一种特殊的Tensor的构造方法,默认需要求导,即requires_grad为True。


5 forward()函数与反向传播

forward()函数用来进行网络的前向传播,并需要传入相应的Tensor,例如上例的perception(data)即是直接调用了forward()。在具体底层实现中,perception.call(data)将类的实例perception变成了可调用对象perception(data),而在perception.call(data)中主要调用了forward()函数。nn.Module可以自动利用Autograd机制实现反向传播,不需要自己手动实现。


6 多个module的嵌套

在Module的搭建时,可以嵌套包含子Module,上例的Perception中调用了Linear这个类,这样的代码分布可以使网络更加模块化,提升代码的复用性。在实际的应用中,PyTorch也提供了绝大多数的网络层,如全连接、卷积网络中的卷积、池化等,并自动实现前向与反向传播。

7 nn.Module与nn.functional库

在PyTorch中,还有一个库为nn.functional,同样也提供了很多网络层与函数功能,但与nn.Module不同的是,利用nn.functional定义的网络层不可自动学习参数,还需要使用nn.Parameter封装。nn.functional的设计初衷是对于一些不需要学习参数的层,如激活层、BN(Batch Normalization)层,可以使用nn.functional,这样这些层就不需要在nn.Module中定义了。

总体来看,对于需要学习参数的层,最好使用nn.Module,对于无参数学习的层,可以使用nn.functional,当然这两者间并没有严格的好坏之分。


8 nn.sequential()模块

当模型中只是简单的前馈网络时,即上一层的输出直接作为下一层的输入,这时可以采用nn.Sequential()模块来快速搭建模型,而不必手动在forward()函数中一层一层地前向传播。因此,如果想快速搭建模型而不考虑中间过程的话,推荐使用nn.Sequential()模块。

在上面的例子中,Perception类中的layer1与layer2是直接传递的,因此该Perception类可以使用nn.Sequential()快速搭建。上面代码可改写为:

import torch
from torch import nn

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        # 1.初始化网络层参数
        super(Linear, self).__init__()
        self.w = nn.Parameter(torch.randn(in_dim, out_dim))
        self.b = nn.Parameter(torch.randn(out_dim))

    def forward(self, x):
        # 2.定义网络层结构
        x = x.matmul(self.w)
        y = x + self.b.expand_as(x)
        return y

class Perception(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        # 3.定义所有网络层
        super(Perception, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.Sigmoid(),
            nn.Linear(hid_dim, out_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        # 4.定义模型
        y = self.layer(x)
        return y

perception = Perception(100,1000,100)
print(perception)

在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Joseph__Lagrange/article/details/109678866

智能推荐

GCC 分别64位和32为体系_gcc查看64 32位-程序员宅基地

文章浏览阅读1k次。#if defined(__X86_64__) || defined(__X86_64) || defined(__amd_64) || defined(__amd_64__)_gcc查看64 32位

剑指offer-js 二进制中1的个数-程序员宅基地

文章浏览阅读268次。二进制中1的个数题目描述:输入一个整数,输出该数二进制表示中1的个数。其中负数用补码表示。问题分析:将数字转换成二进制数字,如果是直接求的话,注意数组反转,反转之后,若不够32位进行补零负数的补码: 求出绝对值的原码,从最后一个数字开始往前数,到第一个1之后把前面的全部取反代码展示:function NumberOf1(n){ //还可以使用二进制转换来减少代码量 num.toString(2) var nums = []; var num = 0; _js 二进制中1的个数

Drawable资源_drawable 资源-程序员宅基地

文章浏览阅读468次。Drawable资源是对图像的一个抽象,你可以通过getDrawable(int)得到并绘制到屏幕上。这里有几种不同类型的Drawable: Bitmap File 一个Bitmap图像文件(.png、.jpg或.gif)。BitmapDrawable。 Nine-Patch File 一个带有伸缩区域的PNG文件,可以基于content伸缩图片(.9.png)_drawable 资源

[ 移植 ] ___ Library : Iconv_iconv 嵌入式-程序员宅基地

文章浏览阅读1.5k次。简介国际文本大多以Unicode编码。然而,由于历史原因,有时仍使用与语言或国家/地区相关的字符编码对其进行编码,随着互联网的出现和国家间文本的频繁交换,在这种情况下,即使从外国查看网页也是一种文本交换,这些编码之间的转换已成为必要。特别是,具有Windows操作系统的计算机仍然在具有传统(有限)字符编码的区域设置中运行。某些程序,如邮件程序和Web浏览器,必须能够在给定的文本编码和用户的编码之间进行转换。其他程序在内部以Unicode存储字符串,以方便内部处理,并且在进行I/O时需要在内部字符串表示_iconv 嵌入式

【ThinkPHP】修改伪静态后缀名_thinkphp 添加伪静态后 请求后缀怎么设为php-程序员宅基地

文章浏览阅读1.5k次。伪静态让url看起来好看.还有ThinkPHP的路由功能,视频说是对seo有优化如图:在项目里面的Conf/config.php里面写$arr=array('URL_HTML_SUFFIX' => '.html'//'URL_HTML_SUFFIX' => '你想要的后续名')return $arr;_thinkphp 添加伪静态后 请求后缀怎么设为php

SQL报错注入的12个函数及sql注入语句-程序员宅基地

文章浏览阅读1.2k次。转来的 侵删1、通过floor报错,注入语句如下: andselect1from(selectcount(*),concat(version(),floor(rand(0)*2))xfrominformation_schema.tablesgroupbyx)a);2、通过ExtractValue报错,注入语句如下:andextractva..._报错注入 十二个函数

随便推点

1.1 仿射集与凸集-1_仿射集为相对开集-程序员宅基地

文章浏览阅读211次。2021 年 2 月 5 日星期五,想喝奶茶的一天。想要 coco 今日推出的奶茶包,可惜要 88 赞。废话少说,进入正题。1 凸集与凸函数1.1 仿射集与凸集-1定义 1.1.1 若集合 W⊆RnW\subseteq\mathbb{R}^nW⊆Rn 中任意两个不同元素对于加法和数乘皆是封闭的,即任取 x1,x2∈Wx_1,x_2\in Wx1​,x2​∈W,λ1,λ2∈R\lambda_1,\lambda_2\in\mathbb{R}λ1​,λ2​∈R 皆有λ1x1+λ2x2∈W,\lambd_仿射集为相对开集

mysql mysqldumpslow_慢日志分析工具—mysqldumpslow 和 mysqlsla-程序员宅基地

文章浏览阅读328次。前提:分析mysql性能的时候会查看数据库的哪些sql语句有问题,效率低。这就用到了数据库的慢查询,作用就是: 它能记录下所有执行超过long_query_time时间的SQL语句,帮你找到执行慢的SQL,方便我们对这些SQL进行优化。1.配置慢查询在mysql客户端执行mysql> show variables like "%query%" ;得到结果:设置慢查询.png涉及参数解释:sl..._mysqldumpslow /var/lib/mysql/slow.log

黑苹果记录-程序员宅基地

文章浏览阅读3.6k次。--2018-4-26 更新一般情况下,小版本的更新直接使用app store 进行更新就行。主要是参照tonymacx86 。这里记一下一些相关问题。如果NVIDIA 官方 驱动没有更新,可以用下面的脚本进行打补丁bash <(curl -s https://raw.githubusercontent.com/Benjamin-Dobell/nvidia-update/master/nvi...

Nginx 又一牛 X 功能:流量拷贝_nginx x-origin-uri-程序员宅基地

文章浏览阅读844次。1. 需求 将生产环境的流量拷贝到预上线环境或测试环境,这样做有很多好处,比如:可以验证功能是否正常,以及服务的性能;用真实有效的流量请求去验证,又不用造数据,不影响线上正常访问;这跟灰度发布还不太一样,镜像流量不会影响真实流量;可以用来排查线上问题;重构,假如服务做了重构,这也是一种测试方式;为了实现流量拷贝,Nginx提供了ngx_http_mirror_modu..._nginx x-origin-uri

写prime函数判断一个数是否是素数(C语言 + 详细注释)-程序员宅基地

文章浏览阅读2.3w次,点赞16次,收藏43次。int prime( int p ){ if(p <= 3) //或者if (p <= 1) return 0; return p > 1; int i; for(i = 2; i * i <= p; i++) // i只需要遍历到根号p以节省时间,且等号不能少,否则像4,9等数就会判断错误 ..._prime函数