Win10 RTX30系列 安装tensorflow1.15
1.遇到的问题:
直接PiP安装,能够安装完成。
pip install tensorflow-gpu==1.15
并且测试TF的版本和显卡是否正确,也都正常。
import tensorflow as tf
print(tf.__version__)
print(tf.test.is_gpu_available())
但是:
训练会卡住,也不报错,就卡着。
遂安装失败。
2.最近发现了一个知乎大神的版本
废话不多说,上链接: https://zhuanlan.zhihu.com/p/356526953
作者已经有打包好的了,但是没有把rtx30的二进制码包进去。
下面是作者编译的老黄魔改版, python3.7 cuda11.2 cudnn8.1.0。
链接:https://pan.baidu.com/s/1hw50gNAzh9A7dWTJQWwCwg
提取码: dg4q
老黄魔改版tf1.15的c++版,和Python配套的,可以给30系的卡使用,不保证稳定性。
链接:https://pan.baidu.com/share/init?surl=NTxi4ftMBHR5MTRfQNNY3g
提取码: jcbh
3.安装过程
直接whl包安装即可
安装cuda和cudnn如下:
conda install cudatoolkit=11.2 -c conda-forge
conda install cudnn=8.1.0 -c conda-forge
4.实际测试:
测试代码
import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
x = tf.placeholder(tf.float32, shape=(None,2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
y = tf.sigmoid(y)
cross_entropy = -tf.reduce_mean(y_*tf.log(tf.clip_by_value(y, 1e-10, 1.0))
+ (1-y_)*tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
Y = [[int(x1+x2 < 1)] for (x1,x2) in X]
with tf.Session() as sess:
init_go = tf.global_variables_initializer()
sess.run(init_go)
print('parameter w1 before train: ', sess.run(w1))
print('parameter w2 before train: ', sess.run(w2))
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % dataset_size
end = min(start+batch_size, dataset_size)
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i%1000 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print('After %d training_steps, cross entropy on all data is %g'%(i,total_cross_entropy))
print('parameter w1 after train: ', sess.run(w1))
print('parameter w2 after train: ', sess.run(w2))
运行结果:
备注: 基本的代码能够运行起来,但似乎对显存占用的更高。没有tflite模块。
其他的待测试。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)