RNN: GRU

PyTorch blog

1. 计算公式

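GRU (as implemented by `torch.nn.GRU`):

$$
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$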

2. Python代码计算

def gru_test(inputs, dict_):
    """Compute one GRU step by hand, mirroring layer 0 of ``torch.nn.GRU``.

    Evaluates the PyTorch GRU cell equations for the single input vector
    ``inputs.value[0, 0, :]``, starting from an all-zero hidden state::

        r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
        z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
        n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
        h' = (1 - z) * n + z * h

    Args:
        inputs: Object whose ``value`` attribute is a rank-3 array-like,
            indexed here as ``value[0, 0, :]``. Presumably laid out as
            ``(batch, seq, input_size)`` for a ``batch_first`` GRU, making
            this the first time step of the first sample. TODO: confirm
            against the caller.
        dict_: Mapping with a ``'_parameters'`` entry holding the layer-0
            tensors ``weight_ih_l0`` (3H x I), ``weight_hh_l0`` (3H x H),
            ``bias_ih_l0`` and ``bias_hh_l0`` (3H each). An ``nn.GRU``
            module's ``__dict__`` is one such mapping.

    Returns:
        torch.Tensor of shape ``(H,)``: the hidden state after one step.

    Raises:
        ValueError: if the gate dimension is not a multiple of 3
            (raised by ``np.split``).
    """
    params = dict_['_parameters']
    # detach() drops autograd tracking. cpu() lets numpy() work even when
    # the parameters live on an accelerator.
    w_ih, w_hh, b_ih, b_hh = (
        params[name].detach().cpu().numpy()
        for name in ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0')
    )
    hidden_size = w_hh.shape[1]

    h0 = np.zeros(hidden_size, dtype=np.float32)
    x = np.asarray(inputs.value[0, 0, :])

    # PyTorch stacks the gate weights row-wise as (reset | update | new).
    # One matmul per side followed by a three-way split therefore yields
    # every gate pre-activation term.
    i_r, i_z, i_n = np.split(w_ih @ x + b_ih, 3)
    h_r, h_z, h_n = np.split(w_hh @ h0 + b_hh, 3)

    r = torch.sigmoid(torch.from_numpy(i_r + h_r))
    z = torch.sigmoid(torch.from_numpy(i_z + h_z))
    # The reset gate scales the whole hidden term, including b_hn.
    n = torch.tanh(torch.from_numpy(i_n) + r * torch.from_numpy(h_n))

    return (1 - z) * n + z * torch.from_numpy(h0)
Loading Disqus comments...
Table of Contents