好玩的人工智能
快乐的深度学习

tensorflow学习之: session run

session.run([fetch1, fetch2])
关于 session.run([fetch1, fetch2]),请看http://stackoverflow.com/questions/42407611/how-tensorflow-handle-the-computional-graph-when-executing-sess-run/42408368?noredirect=1#comment71994086_42408368

执行sess.run()时,tensorflow是否计算了整个图
我们在编写代码的时候,总是要先定义好整个图,然后才调用sess.run()。那么调用sess.run()的时候,程序是否执行了整个图

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
state = tf.Variable(0.0,dtype=tf.float32)
one = tf.constant(1.0,dtype=tf.float32)
new_val = tf.add(state, one)
update = tf.assign(state, new_val) #返回tensor, 值为new_val
update2 = tf.assign(state, 10000) #没有fetch,便没有执行
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
for _ in range(3):
print sess.run(update)

和上个程序差不多,但我们这次仅仅是fetch “update”,输出是1.0 , 2.0, 3.0,可以看出,tensorflow并没有计算整个图,只是计算了与想要fetch 的值相关的部分

sess.run() 中的feed_dict
我们都知道feed_dict的作用是给使用placeholder创建出来的tensor赋值。其实,他的作用更加广泛:feed 使用一个 值临时替换一个 op 的输出结果. 你可以提供 feed 数据作为 run() 调用的参数. feed 只在调用它的方法内有效, 方法结束, feed 就会消失.

import tensorflow as tf
y = tf.Variable(1)
b = tf.identity(y)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(b,feed_dict={y:3})) #使用3 替换掉
#tf.Variable(1)的输出结果,所以打印出来3
#feed_dict{y.name:3} 和上面写法等价

1
print(sess.run(b))

#由于feed只在调用他的方法范围内有效,所以这个打印的结果是 1

Session(会话)
**
在tensorflow中数据流图中的Op在得到执行之前,必须先创建Session对象,
Session对象负责着图中所以Op的执行.
Session 对象创建时有三个可选参数:
1.target,在不是分布式中使用Session对象时,该参数默认为空
2.graph,指定了Session对象中加载的Graph对象,如果不指定的话默认加载当前默认的数据流图,但是如果有多个图,就需要传入加载的图对象
3.config,Session对象的一些配置信息,CPU/GPU使用上的一些限制,或者一些优化设置

1
2
3
4
5
6
7
8
9
import tensorflow as tf
# 在默认数据流图中创建Op,tensor
# tf.add()是一个Operation简称Op,a是一个tensor
a = tf.add(1, 2)
b = tf.add(a, 3)
# 创建一个Session对象,没有指定graph参数默认加载,默认数据流图
sess = tf.Session()
sess = tf.Session(graph=tf.get_default_graph)
# 以上两种方式是等价的,都是加载默认数据流图
1
2
3
4
Session对象的run()方法
run()计算张量对象的输出,
sess.run(a)    # 输出3
sess.run(b)    # 输出6

run()方法接受一个参数和三个可选参数:
1.fecthes,接受数据流图中的(所有的Op和tensor),也就是希望执行的对象,tensor对象一般会返回数或数组,Op没有返回值None.

# 在sess.run(b)中,fecthes的参数是tensor:b,(对应的Op是tf.add()),
# sess会找到与b有数据依赖的节点,然后顺序执行
sess.run(b)

1
2
3
4
#fecthes也会接受Op,比如变量的初始化:
init = tf.global_variables_initializer()
sess.run(init)
# 返回值为None

2.feed_dict参数,用于覆盖数据流图中的tensor对象,接受的参数类型为Python的字典对象
字典的’键’为被覆盖的tensor对象的句柄,’值’可以是各种数据类型,但是必须和被覆盖tensor
的类型相同或者能够转换为相同类型

1
2
3
4
5
import tensorflow as tf
a = tf.add(1, 2)
b = tf.add(a, 10)
sess = tf.Session()
sess.run(b)
1
2
3
# 定义一个字典,覆盖tensor a
replace_dict = {a : 20}
sess.run(b, feed_dict=replace_dict)

关于回话对象使用结束后需要关闭调用close方法,释放资源sess.close()
# 将Session对象作为上下文管理器来使用,离开作用域,Session对象会自动关闭

1
2
3
4
5
6
7
8
9
with tf.Session() as sess:
#.......Op,tensor
-----------------------------------
# 也可以像图对象一样被隐式使用
sess = tf.Session()
with sess.as_default():
a.eval()       # 类似与run(a)
# 但是必须手动关闭
sess.close()

tensorflow中的占位节点
占位节点是用来接收输入值的,它们的作用和tensor对象类似,在创建时不用指定具体的数值,它们的作用是为将要用到的tensor对象预留位置,相当于输入节点

1
2
3
4
5
6
7
8
9
10
11
12
13
import tensorflow as tf
import numpy as np
# 创建一个占位节点tf.placeholder()
# 第一参数是dtype是必须指定的,
# 第二个参数shape,默认为None,接收任意形状的tensor,长度为2的一阶张量
# 第三个参数name,用来标识这个Op
a = tf.placeholder(tf.int32, shape=[3], name='input')
b = tf.reduce_sum(a, name='sum')
# 在Session对象的run()方法中,通过字典给占位节点传递值
# 在计算b时,在run()方法中,只用给b有依赖的占位节点,覆盖数值
# 没有依赖的节点不需要包含在feed_dict中
sess = tf.Session()
sess.run(b, feed_dict={a: np.array([1, 2, 3], dtype=np.int32)})
未经允许不得转载:零点智能 » tensorflow学习之: session run
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

零点智能 人工智能社区,加Q群:469331966

投稿&建议&加Q群