参考内容都出自于官方API教程tf.Session
一、Session类基本使用方法
这里使用的是1.15版本,TF官方为了能够在2.0+版本中保持兼容,因此调用时使用了tf.compat.v1.Session。
定义:一个Session是对环境的封装,环境中包含执行/executed过的Operation和评估/evaluated过的Tensor。
一个Session会含有很多资源,例如Variable、QueueBase、RenderBase等。当Session运行结束后需要通过**Session().close()**方法释放资源。
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
sess = tf.Session()
print(sess.run(c))
sess.close()
with tf.Session() as sess:
sess.run(...)
Session在创建时,构造函数__init__可以指定三个参数:
- target:一般用于分布式TF中,用于连接执行引擎;
- graph:此Session要launch的图,不指定则为默认;
- config:是一个protocol buffer:ConfigProto
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=True)
)
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
with tf.Session(config=tfconfig) as sess:
二、Properties
Session中有很多properties,这些可以直接通过Session().调用查看内部的值。此处介绍Session中关于graph的两个properties:
- graph:返回此Session中已经launch的graph。
- graph_def:将构建的TF图以串行化方式显示出来。
a = tf.constant(1.0)
sess = tf.Session()
assert sess.graph == tf.get_default_graph()
print(sess.graph)
print(sess.graph_def)
二、Methods
并没有全列出来,见一个记录一个:
1.as_default()
返回一个context manager将当前Session设置为默认。一般结合with语句进行使用。as_default()语句并不会结束Session,必须调用close()方法手动结束。
c = tf.constant(1.0)
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
sess = tf.Session(config=tfconfig)
with sess.as_default():
assert tf.get_default_session() is sess
print tf.get_default_session()
print(c.eval())
print(sess.run(c))
print(c)
在会话中进行运算并取值一共有三种:
- tf.Operation.run():针对操作,Run operations
- tf.Tensor.eval():针对张量,Evaluate tensors
- sess.run():在全局空间取值,下部分会讲
2.run()
用于运行Operations和评估Tensors。
run(
fetches,
feed_dict=None,
options=None,
run_metadata=None
)
三、应用实例:多Graph多Session
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
g1 = tf.Graph()
g_ = tf.Graph()
g2 = tf.Graph()
with g1.as_default():
a = tf.constant(2.0)
b = tf.constant(3.0)
c = tf.constant(4.0)
d = tf.multiply(a, b) + c
with g2.as_default():
e = tf.constant(4.0)
f = tf.constant(6.0)
g = tf.constant(6.0)
h = tf.multiply(e, f) + g
with g_.as_default():
i = tf.placeholder(dtype=tf.float32, shape=[], name='G_Input')
j = i + 100.0
sess1 = tf.Session(graph=g1, config=tfconfig)
sess_ = tf.Session(graph=g_, config=tfconfig)
sess2 = tf.Session(graph=g2, config=tfconfig)
print('result of Net1: ', sess1.run(d))
print('result of Net2: ', sess2.run(h))
print(sess_.run(fetches=j, feed_dict={i: sess1.run(d)}))
print(sess_.run(fetches=j, feed_dict={i: sess2.run(h)}))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)