博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow 基础学习七:模型的持久化
阅读量:4970 次
发布时间:2019-06-12

本文共 4879 字,大约阅读时间需要 16 分钟。

tf.train.Saver类的使用

保存模型:

import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2')result=v1+v2init_op=tf.global_variables_initializer()saver=tf.train.Saver()with tf.Session() as sess:    sess.run(init_op)    saver.save(sess,'log/model.ckpt')

加载模型:

import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2')result=v1+v2saver=tf.train.Saver()with tf.Session() as sess:    ckpt = tf.train.get_checkpoint_state('log')      if ckpt and ckpt.model_checkpoint_path:          saver.restore(sess, ckpt.model_checkpoint_path)

  在加载模型时,也是先定义tensorflow计算图上的所有运算,但不需要运行变量的初始化,因为变量的值可以通过已经保存的模型加载进来。如果不希望重复定义图上的运算,也可以直接加载已经 持久化的图。

加载计算图:

import tensorflow as tf# 直接加载持久化的图saver=tf.train.import_meta_graph('log/model.ckpt.meta')with tf.Session() as sess:    saver.restore(sess,'log/model.ckpt')    # 通过张量的名称来获取张量    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

tf.train.Saver类还支持在保存和加载模型时给变量重命名。

在加载模型时给变量重命名:

import tensorflow as tf# 这里声明的变量名称和已经保存的模型中变量的名称不同。v1=tf.Variable(tf.constant(1.0,shape=[1]),name='new-v1')v2=tf.Variable(tf.constant(2.0,shape=[1]),name='new-v2')# 直接使用tf.train.Saver()加载模型会提示变量找不到的错误# 需要使用一个字典来重命名变量。这个字典指定# 原来名称为v1的变量现在加载在变量v1中('new-v1'),名称为v2的变量加载到# 变量v2中('new-v2')saver=tf.train.Saver({
'v1':v1,'v2':v2})with tf.Session() as sess: saver.restore(sess, 'log/model.ckpt')

  重命名的好处是可以方便使用变量的滑动平均值。使用变量的滑动平均值可以让神经网络模型更加健壮。在tensorflow中,每一个变量的滑动平均值是通过影子变量维护的,获取变量的滑动平均值实际上就是获取这个影子变量的取值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。这样方便了滑动平均模型的使用。以下代码给出了一个保存滑动平均模型的样例。

import tensorflow as tf v=tf.Variable(0,dtype=tf.float32,name='v')# 没有声明滑动平均模型时,只有一个变量v,所以下面语句只会输出'v:0'for variables in tf.global_variables():    print(variables.name)    ema=tf.train.ExponentialMovingAverage(0.99)maintain_averages_op=ema.apply(tf.global_variables())# 在声明滑动平均模型后,tensorflow会自动生成一个影子变量# 下面语句会输出:'v:0'和'v/ExponentialMovingAverage:0'for variables in tf.global_variables():    print(variables.name)    saver=tf.train.Saver()with tf.Session() as sess:    init_op=tf.global_variables_initializer()    sess.run(init_op)    sess.run(tf.assign(v,10))    sess.run(maintain_averages_op)    # 保存时,tensorflow会将'v:0'和'v/ExponentialMovingAverage:0'两个变量都保存下来    saver.save(sess,'log/model.ckpt')    print(sess.run([v,ema.average(v)])) # 输出:[10.0, 0.099999905]

  基于上面的代码,通过变量重命名直接读取变量的滑动平均值。从程序输出可以看出,读取的变量v的值实际上是上面代码中变量v的滑动平均值。通过该方法,就可以使用完全一样的代码来计算滑动平均模型前向传播的结果。

v=tf.Variable(0,dtype=tf.float32,name='v')# 通过变量重命名将原来变量v的滑动平均值直接赋值给vsaver=tf.train.Saver({
'v/ExponentialMovingAverage':v})with tf.Session() as sess: saver.restore(sess,'log/model.ckpt') print(sess.run(v)) # 输出:0.099999905

  为了方便加载时重命名滑动平均变量,tf.train.ExpoentialMovingAverage类提供了variables_to_restore函数来生成tf.train.Saver类所需要的变量重命名字典。示例代码如下:

import tensorflow as tf v=tf.Variable(0,dtype=tf.float32,name='v')ema=tf.train.ExponentialMovingAverage(0.99)# 通过使用variables_to_restore函数来直接生成上面代码中提供的字典# {'v/ExponentialMovingAverage':v}# 以下代码会输出:# {'v/ExponentialMovingAverage': 
}print(ema.variables_to_restore())saver=tf.train.Saver(ema.variables_to_restore())with tf.Session() as sess: saver.restore(sess,'log/model.ckpt') print(sess.run(v)) # 输出:0.099999905

   tf.train.Saver的缺点就是每次会保存程序的全部信息,但有时并不需要全部信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似于变量初始化、模型保存等辅助结点的信息。而且,将变量取值和计算图结构分成不同文件存储有时候也不方便,tensorflow中提供了convert_variables_to_constants函数,可以将计算图中的变量及其取值通过常量的方式保存,这样可以将整个计算图统一存放在一个文件中。示例代码如下:

 

import tensorflow as tffrom tensorflow.python.framework import graph_utilv1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2')result=v1+v2init_op=tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init_op)    # 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算    # 过程    graph_def=tf.get_default_graph().as_graph_def()        # 将图中的变量及其取值转化为常量,同时将图中不必要的结点去掉。在下面一行代码中,    # 最后一个参数['add']给出了需要保存的节点名称。add节点是上面定义的两个变量相加    # 的操作。注意,'add:0'表示某个计算节点的第一个输出,是一个张量名。    output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])    # 将导出的模型存入文件    with tf.gfile.GFile('log/combined_model.pb','wb') as f:        f.write(output_graph_def.SerializeToString())

 

  通过下面的代码可以直接计算定义的加法运算的结果。这种方法可以使用训练的模型完成迁移学习

import tensorflow as tfwith tf.Session() as sess:    model_filename='log/combined_model.pb'    # 读取保存的模型文件,将文件解析成对应的GraphDef Protocol Buffer    with tf.gfile.FastGFile(model_filename,'rb') as f:        graph_def=tf.GraphDef()        graph_def.ParseFromString(f.read())        # 将graph_def中保存的图加载到当前图中。return_elements=['add:0']给出了返回的    # 张量的名称。在保存时给出的是计算节点的名称,所以为'add'。在加载时给出的是    # 张量的名称,所以是add:0    result=tf.import_graph_def(graph_def,return_elements=['add:0'])    print(sess.run(result))

 

转载于:https://www.cnblogs.com/hypnus-ly/p/8319571.html

你可能感兴趣的文章
使用 ref 和 out 传递数组注意事项
查看>>
三个线程ABC,交替打印ABC
查看>>
flex4.6 图表 在module中 x轴旋转正确的做法
查看>>
LeetCode Binary Tree Preorder Traversal
查看>>
spark_to_kakfa
查看>>
The superclass "javax.servlet.http.HttpServlet" was not found on the Java Build Path 解决方法
查看>>
combobox和textbox中输入数据为非数字leave时的公用事件,只需要在控件的leave事件中选择本事件即可...
查看>>
[原创]如何确保JavaScript的执行顺序 –之jQuery1.5.1与jQuery1.4.4的差异
查看>>
Java生鲜电商平台-源码地址公布与思考和建议
查看>>
一个数据库小题目
查看>>
小学生运算题目生成器说明书
查看>>
20145104张家明 《Java程序设计》第三次实验设计
查看>>
fetch 添加请求头headers
查看>>
【javascript】javascript中function(){},function(){}(),new function(){},new Function()
查看>>
Linux命令—tar
查看>>
BLE GATT 介绍
查看>>
唤起支付宝的链接地址
查看>>
cu命令
查看>>
form标签之form:checkboxes
查看>>
纵越6省1市-重新启动
查看>>