如果你的cudatoolkit是9.x版本的,在执行两个很大的batch做matmal的时候,可能会报一个很奇怪的错误:
但是实际上你的显存是够的。为什么会报这样的错误呢?
这个问题困扰了我好几天。从网上查阅了很多资料,才发现是cublas的内部的一个保护机制。当你对两个batch做matmul的时候,如果batch的大小大于172800(大概是这么一个数),就会报错。不太确定cudatoolkit10.x还有没有类似的问题,但是至少cudatoolkit9.x都会遇到这个问题,所以只能想办法把batch改小一点。
注意这里说的batch大小是说矩阵相乘的前面的维度的综合。比如你要做的操作是:
tf.matmul(tf.ones([512, 1024, 4, 2]), tf.ones([512, 1024, 2, 1]))
也会报错的。虽然后面真实相乘的矩阵很小,但是512*1024>172800了,所以会报错。
不信的话,你可以用下面的程序测试一下:
import tensorflow as tf
import numpy as np
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
tf.Session(config=config).close()
def calc():
N = 15 # works for N <= 14