在代码最前面加入已下代码
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
在DataLoader中使用worker_init_fn来确定种子,worker_init_fn定义如下:
def worker_init_fn(worker_id):
np.random.seed(int(seed))