Exporting a model that uses nn.LSTMCell to ONNX can fail with an unsupported-operator error. If the model is already trained, a workaround is to reimplement the LSTM cell from basic ops that ONNX does support and load the trained weights into it. The implementation is below:
import math

import torch
import torch.nn as nn

class MyLSTMCell(nn.Module):
    """ONNX-exportable LSTM cell built from basic ops (MatMul/Add/Sigmoid/Tanh).

    Parameter names and the (input, forget, cell, output) gate layout match
    nn.LSTMCell, so a trained cell's state_dict loads directly.
    """

    def __init__(self, input_size, hidden_size):
        super(MyLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        # Only matters when training from scratch; these values are
        # overwritten once the trained state_dict is loaded.
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight_hh, a=math.sqrt(5))
        nn.init.zeros_(self.bias_ih)
        nn.init.zeros_(self.bias_hh)

    def forward(self, input, state=None):
        # input: (batch_size, input_size)
        # state: (h, c) tuple, each of shape (batch_size, hidden_size)
        if state is None:
            # Same default as nn.LSTMCell: start from zero h and c.
            zeros = input.new_zeros(input.size(0), self.hidden_size)
            state = (zeros, zeros)
        hx, cx = state
        # One fused matmul per weight matrix computes all four gates at once.
        gates = (input @ self.weight_ih.t() + self.bias_ih +
                 hx @ self.weight_hh.t() + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        # New cell state: the forget gate scales the previous cell state cx,
        # and the input gate scales the candidate cellgate.
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, cy
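Because the parameter names and gate ordering mirror nn.LSTMCell, loading the trained weights is a single load_state_dict call. Below is a minimal sketch of the swap, a numerical sanity check, and the ONNX export; the sizes, tolerance, and file name are placeholder assumptions, not values from the original model.

import torch
import torch.nn as nn

input_size, hidden_size, batch = 32, 64, 4  # placeholder sizes

# Stand-in for the trained cell that fails to export.
trained = nn.LSTMCell(input_size, hidden_size)

# Rebuild it with the hand-written cell and copy the weights 1:1.
cell = MyLSTMCell(input_size, hidden_size)
cell.load_state_dict(trained.state_dict())
cell.eval()

# Sanity check: both cells should agree on the same inputs.
x = torch.randn(batch, input_size)
h0 = torch.zeros(batch, hidden_size)
c0 = torch.zeros(batch, hidden_size)
h_ref, c_ref = trained(x, (h0, c0))
h_new, c_new = cell(x, (h0, c0))
assert torch.allclose(h_ref, h_new, atol=1e-6)
assert torch.allclose(c_ref, c_new, atol=1e-6)

# Export the replacement; the traced graph now contains only
# MatMul/Add/Sigmoid/Tanh-style ops.
torch.onnx.export(cell, (x, (h0, c0)), "lstm_cell.onnx",
                  input_names=["x", "h0", "c0"],
                  output_names=["h1", "c1"])

If the original model wraps the cell inside a larger module, the same idea applies: replace the nn.LSTMCell attribute with MyLSTMCell before loading the full checkpoint, since the state_dict keys under that submodule are unchanged.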