RNN: GRU
PyTorch
blog
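1. 计算模式

单层 GRU 在一个时间步上的计算公式（与 PyTorch `nn.GRU` 的定义一致，$\sigma$ 为 sigmoid，$\odot$ 为逐元素乘积）：

$$r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})$$

$$z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})$$

$$n_t = \tanh\big(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\big)$$

$$h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$$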

2. Python代码计算
def gru_test(inputs, dict_):
    """Recompute the first time step of a single-layer GRU by hand.

    The layer-0 weights and biases are read from ``dict_['_parameters']``,
    split into their reset (r), update (z) and new (n) gate blocks, and the
    GRU update is applied once, starting from an all-zero hidden state:

        r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
        z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
        n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
        h' = (1 - z) * n + z * h

    Args:
        inputs: Object whose ``value`` attribute is indexed as
            ``value[0, 0, :]`` to obtain the input vector, i.e. the first
            time step of the first sequence.
            NOTE(review): presumably a (batch, seq, input_size) array --
            confirm against the caller.
        dict_: Mapping with a ``'_parameters'`` entry holding the tensors
            ``weight_ih_l0``, ``weight_hh_l0``, ``bias_ih_l0`` and
            ``bias_hh_l0`` (e.g. the ``__dict__`` of a ``torch.nn.GRU``).

    Returns:
        A 1-D ``torch.Tensor`` of length ``hidden_size``: the hidden state
        after the first time step.

    Raises:
        ValueError: If a parameter's first dimension cannot be split into
            three equal gate blocks.
    """
    params = dict_['_parameters']
    weight_ih = params['weight_ih_l0'].detach().numpy()
    weight_hh = params['weight_hh_l0'].detach().numpy()
    bias_ih = params['bias_ih_l0'].detach().numpy()
    bias_hh = params['bias_hh_l0'].detach().numpy()

    # Every parameter stacks the r, z and n gate blocks (in that order)
    # along axis 0, so its leading dimension is 3 * hidden_size.
    hidden_size = weight_ih.shape[0] // 3
    w_ir, w_iz, w_in = np.split(weight_ih, 3)
    w_hr, w_hz, w_hn = np.split(weight_hh, 3)
    b_ir, b_iz, b_in = np.split(bias_ih, 3)
    b_hr, b_hz, b_hn = np.split(bias_hh, 3)

    # The initial hidden state is all zeros.
    h0 = np.zeros(hidden_size).astype(np.float32)
    # Only the first time step of the first sequence is evaluated.
    x0 = inputs.value[0, 0, :]

    # Affine projections of the input and of the previous hidden state.
    wir_x = np.dot(w_ir, x0) + b_ir
    whr_h = np.dot(w_hr, h0) + b_hr
    wiz_x = np.dot(w_iz, x0) + b_iz
    whz_h = np.dot(w_hz, h0) + b_hz
    win_x = np.dot(w_in, x0) + b_in
    whn_h = np.dot(w_hn, h0) + b_hn

    r0 = torch.sigmoid(torch.from_numpy(wir_x + whr_h))  # reset gate
    z0 = torch.sigmoid(torch.from_numpy(wiz_x + whz_h))  # update gate
    # The reset gate scales the hidden-state projection (bias included)
    # before it enters the candidate state.
    n0 = torch.tanh(torch.from_numpy(win_x) + r0 * torch.from_numpy(whn_h))

    # Blend the candidate state with the previous hidden state.
    h1 = (1.0 - z0) * n0 + z0 * torch.from_numpy(h0)
    return h1