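This one is simple: running DGL on the GPU works the same way as GPU computation in PyTorch. The one rule is that the graph and all of its node and edge features must be on the same device, so create the tensors there (or move them with `.to(device)`). I won't explain further; here is the code.
Code: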
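```python
import dgl
import torch

# Use the first GPU if PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Edges 0->2, 1->3, 2->4; building the graph from tensors that are already
# on `device` creates the graph on that device as well
u = torch.tensor([0, 1, 2]).to(device)
v = torch.tensor([2, 3, 4]).to(device)
graph = dgl.graph((u, v))

# Node and edge features must live on the same device as the graph
node_features = torch.randn((graph.num_nodes(), 5)).to(device)
edge_features = torch.randn((graph.num_edges(), 3)).to(device)
graph.ndata['x'] = node_features
graph.edata['e'] = edge_features
print(graph)
```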