# RNN: GRU PyTorch blog
# 1. Computation scheme
# 2. Computing it in Python code
def gru_test(inputs, dict_):
    """Compute one GRU step by hand from a layer's raw parameters.

    Only the first time step of the first sequence is evaluated, starting
    from an all-zero hidden state:

        r  = sigmoid(W_ir x + b_ir + W_hr h0 + b_hr)
        z  = sigmoid(W_iz x + b_iz + W_hz h0 + b_hz)
        n  = tanh(W_in x + b_in + r * (W_hn h0 + b_hn))
        h1 = (1 - z) * n + z * h0

    Args:
        inputs: Object whose ``value`` attribute can be indexed as
            ``[0, 0, :]``; that slice is used as the input vector ``x``.
            (The original note suggests a (1, 31, 128) array - confirm
            against the caller.)
        dict_: Mapping with a ``'_parameters'`` entry that holds the tensors
            ``weight_ih_l0``, ``weight_hh_l0``, ``bias_ih_l0`` and
            ``bias_hh_l0`` (presumably the ``__dict__`` of a GRU module).
            Each one stacks the r, z and n gates along its first axis, in
            that order.

    Returns:
        torch.Tensor: the hidden state ``h1``, one value per hidden unit.
    """
    params = dict_['_parameters']
    w_ih = params['weight_ih_l0'].detach().numpy()
    w_hh = params['weight_hh_l0'].detach().numpy()
    b_ih = params['bias_ih_l0'].detach().numpy()
    b_hh = params['bias_hh_l0'].detach().numpy()

    # The first axis stacks three gates, so a third of it is the hidden size.
    hidden_size = w_ih.shape[0] // 3
    r_rows = slice(0, hidden_size)
    z_rows = slice(hidden_size, 2 * hidden_size)
    n_rows = slice(2 * hidden_size, 3 * hidden_size)

    x = inputs.value[0, 0, :]
    h0 = np.zeros(hidden_size, dtype=np.float32)

    # Input-side and hidden-side pre-activations for each gate.  Each side is
    # summed with its own bias first, matching the reference formulation.
    r_in = np.dot(w_ih[r_rows], x) + b_ih[r_rows]
    r_hid = np.dot(w_hh[r_rows], h0) + b_hh[r_rows]
    z_in = np.dot(w_ih[z_rows], x) + b_ih[z_rows]
    z_hid = np.dot(w_hh[z_rows], h0) + b_hh[z_rows]
    n_in = np.dot(w_ih[n_rows], x) + b_ih[n_rows]
    n_hid = np.dot(w_hh[n_rows], h0) + b_hh[n_rows]

    reset_gate = torch.sigmoid(torch.from_numpy(r_in + r_hid))
    update_gate = torch.sigmoid(torch.from_numpy(z_in + z_hid))

    # The reset gate scales the whole hidden-side term, bias included.
    candidate = torch.tanh(
        torch.from_numpy(n_in) + reset_gate * torch.from_numpy(n_hid)
    )

    h1 = (1 - update_gate) * candidate + update_gate * torch.from_numpy(h0)
    return h1