Tensorflow笔记----自定义网络、模型保存与加载_刘星星儿的博客-程序员宝宝_tensorflow 自定义模型保存

技术标签: Tensorflow  tensorflow  

一.自定义网络(CustomNetwork)

通过自定义的网络我们可以将一些现有的网络和我们自己的网络串联起来,从而实现各种高效的网络。

  • Keras.Sequential:可以将现有的层跟我们自己的层串联在一起,也可以很方便的组织层的参数;不过我们要使用Sequential需要准守一些协议:
    1.我们自定义的层必须继承自Keras.layers.Layer;
    2.我们自己的模型需要继承自Keras.Model;
model = Sequential([    #五层的网络的一个容器
    layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]    #降维
    layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]
    layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]
    layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]
    layers.Dense(10) # [b, 32] => [b, 10], 参数量330 = 32*10 + 10
])

以上代码使用了Sequential容器,只通过五行的代码就完成了5层神经网络的搭建,十分方便快捷;还可以使用model.trainable_variables来集中管理各种参数;

  • Keras.layers.Layer / Keras.Model:都是自定义层的母类,提供了很多属性,我们只需要实现init初始化函数以及call函数;代码如下:
class MyDense(layers.Layer):    #自己实现一个线性层Dense,继承自layers.Layer
    def __init__(self,inp_dim,outp_dim):    #实现初始化方法,必须要有
        super(MyDense,self).__init__()  #调用母类的初始化函数,必须要有

        #创建遍历
        self.kernel = self.add_variable('w',[inp_dim,outp_dim])
        self.bias = self.add_variable('b',[outp_dim])

    def call(self, inputs, training=None):  #实现call方法,必须要有
        out = inputs @ self.kernel + self.bias  #矩阵相乘再相加
        return out

class MyModel(tf.keras.Model):  #自己实现一个模型,必须继承自tf.keras.Model
    def __init__(self): #实现初始化方法,必须要有
        super(MyModel,self).__init__()  #调用母类的初始化函数,必须要有
        self.fc1 = MyDense(28*28,256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        x = self.fc1(inputs)
        x = tf.nn.relu(x)   #激活函数,一遍一遍的通过,从而形成新的x
        x = self.fc2(inputs)
        x = tf.nn.relu(x)
        x = self.fc3(inputs)
        x = tf.nn.relu(x)
        x = self.fc4(inputs)
        x = tf.nn.relu(x)
        x = self.fc5(inputs)    #最后一遍 一般不通过激活函数
        return x

二.模型保存与加载(ModelSavingAndLoading)

在模型训练过程中,我们可能遇到各种问题,比如断电、终止等,如果不对之前运行的数据进行保存,那我们可能就必须重新训练;

所以我们必须定时进行模型的保存与加载,即使遇到上述的情况,也可以从上次最后一次保存的状态开始。有三种模式:

Save / load weight:最干净最轻量级的方式,只保存网络的参数,其他状态通过不管,适合自己对代码有精细控制且自己拥有源代码的情况。

model.save_weight('保存路径')   #保存
model = create_model()  #加载
model.load_weight('保存路径')

Save / load entire model:最简单粗暴的方式,把所有的状态都保存起来,可以随意的进行恢复,但效率可能有点低。

network.save('model.h5')    #保存
network = tf.keras.models.load_model('model.h5')    #加载

Saved_model:最通用的方式:

tf.saved_model.save(m,'路径') #保存
imported = tf.saved_model.load(path) #加载
f = imported.signatures["serving_default"]
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43580130/article/details/108133790

智能推荐

作为程序员,你最常上的网站是什么_weixin_30840253的博客-程序员宝宝

Wikipedia当然是最应该先查阅的网站,作为了解概念、背景和相关知识的绝佳索引,能够引导走向包含详细技术细节文档,以及一些相关的技术和概念。http://www.wikipedia.org/Google还用多说么?Google it,什么都有!如果需要论文,那怎么少得了Google Scholar,还有Springer、IEEE Xplore和ACM Digital Library...

HTML基础-跟着李南江学编程_子.星语的博客-程序员宝宝_李南江博客

HTML5定义了一组新的语义化的结构标记来描述网页内容。虽然语义化结构标记可以被HTML4标记进行替换,但是他可以简化HTML页面设计,明确的语义化更适合搜索引擎检索和抓取。header:表示页面中一个内容区块或整个页面的标题footer:表示整个页面或页面中一个内容区块的脚注。section:表示页面中的一个内容区块,如章节、页眉、页脚或页面中的其他部分。article:表示页面中的一块与上下文不相关的独立内容,如博客中的一篇文章aside:表示article元素的内容之外的、与article

IDEA2018.2.8的安装_苍山如海z的博客-程序员宝宝_intellij idea2018.2.8

1.官网下载2018.2月版本。(other version->选中2018.2)2.下载JetbrainsCrack_jb51.rarhttp://wangshuo.jb51.net:81/201904/tools/JetbrainsCrack_jb51.rar3.解压缩,找到JetbrainsCrack-3.4-release-enc.jar。 把jar包复制到 idea b...

【技术综述】一文道尽R-CNN系列目标检测_言有三的博客-程序员宝宝

文章首发于微信公众号《有三学AI》【技术综述】一文道尽R-CNN系列目标检测目标检测任务关注的是图片中特定目标物体的位置。一个检测任务包含两个子任务,其一是输出这一目标的类别信息,属于分类任务。其二是输出目标的具体位置信息,属于定位任务。分类的结果是一个类别标签,对于单分类任务而言,它就是一个数,对于多分类任务,就是一个向量。定位任务的输出是一个位置,用矩形框表示,包含矩形框左上角或中...

css的vm,vm · CSS 3 中文手册 · 看云_Jakcwin的博客-程序员宝宝

[CSS参考手册](http://css.doyoe.com/)»[单位列表](#)»[长度单位](#)»相关内容:[**其它长度单位参考**选择其它项](#)- [em](#)- [ex](#)- [ch](#)- [rem](#)- [vw](#)- [vh](#)- [vm](#)- [cm](#)- [mm](#)- [in](#)- [pt](#)- [pc](#)- [px](#)# ...

虚拟机栈中的局部变量表、操作数栈、动态链接、方法返回地址、栈顶缓存技术_执傲i的博客-程序员宝宝

局部变量表局部变量表也被称之为局部变量数组或本地变量表定义为一个数字数组,主要用于存储方法参数和定义在方法体内的局部变量,这些数据类型包括各种基本数据类型,对象引用,以及returnAddress类型由于局部变量表是建立在线程的栈上,是线程的私有数据,因为不存在数据安全问题。局部变量表所需的容量大小是在编译期确定下来的,并保存在方法的Code属性的maximum local variables数据项中,在方法运行期间是不会改变局部变量表的大小的。方法嵌套调用的次数由栈的大小决定

随便推点

CAD图纸中插入或删除编号的图文教程_weixin_44684531的博客-程序员宝宝

方法一:1、 首先,同时按CTRL+A键进行全选图形后,接着输入命令ss,然后回车2、然后,在提示框中输入122,回车。这时就可以发现,大于等122的编号都自动增加了1,接着复制任何一个编号粘贴到缺少编号的阀门附件,并修改为122,就完成了3、需要删除时:以编号97为例子,删除后,大于97的编号要自动减14、按CTRL+A键进行全选图形后,输入dd命令,然后回车...

字节跳动2019年春季实习招聘机器学习算法岗第二批笔试题及解答_skj1995的博客-程序员宝宝

以下解答是我自己的解法,有待优化,仅供参考!第一题:.题目描述公司的程序员不够用了,决定把产品经理都转变为程序员以解决开发时间长的问题。在给定的矩形网格中,每个单元格可以有以下三个值之一:.值0代表空单元格;.值1代表产品经理;.值2代表程序员;每分钟,任何与程序员(在4个正方向上)相邻的产品经理都会变成程序员。返回直到单元格中没有产品经理为止所必须经过的最小分钟数。如...

Mybatis(19)注解实现多表查询_小布米的博客-程序员宝宝

两个类,User,Account查询每个账户以及其对应的用户信息,一对一查询,采用及时加载Account类package com.itheima.domain;import java.io.Serializable;public class Account implements Serializable { private Integer id; private Integer uid; private Double money; //多对一(mybat

kissy ajax,KISSY - A Powerful JavaScript Framework_闲白客的博客-程序员宝宝

快速开始1,复制 & 粘贴种子文件是一个非常小的 JS 文件,通过他可以动态加载 KISSY 的模块文件,因为体积很小,推荐将种子文件至于标签内。data-config="{combine:true}" 表示启用服务器 combo 机制,可用于减少网络请求数目。2,开始使用 KISSY// 创建一个 KISSY 沙箱KISSY.use('node',function(S,Node){// ...

如何在MySQL中设置主从复制_dayou7738的博客-程序员宝宝

mysql主从同步定义 主从同步机制配置主从同步 配置主服务器 配置从服务器使用主从同步来备份 使用mysqldump来备份 备份原始文件主从同步的小技巧排错 Slave_IO_Running: NO Slave_SQL_Running: Nomysql主从同步定义主从同步使得数据...

推荐文章

热门文章

相关标签