-
Notifications
You must be signed in to change notification settings - Fork 8
/
attention.py
122 lines (99 loc) · 5.23 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import tensorflow as tf
import os
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras import backend as K
class AttentionLayer(Layer):
    """Bahdanau (additive) attention (https://arxiv.org/pdf/1409.0473.pdf).

    The layer takes ``[encoder_output_sequence, decoder_output_sequence]``
    and, for every decoder time step, scores each encoder time step with
    ``softmax(V_a . tanh(S . W_a + h_j . U_a))`` and returns the resulting
    weighted sum of encoder outputs (the context vector) together with the
    attention weights.

    Three trainable weights are introduced: ``W_a``, ``U_a`` and ``V_a``.
    """

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create ``W_a``, ``U_a`` and ``V_a``.

        input_shape: ``[encoder_shape, decoder_shape]``; index 2 of each
            entry is read as the respective hidden size.
        """
        assert isinstance(input_shape, (list, tuple)), \
            "input_shape must be a list of [encoder_shape, decoder_shape]"
        en_hidden = input_shape[0][2]
        de_hidden = input_shape[1][2]

        # Projects encoder outputs: (en_hidden, en_hidden)
        self.W_a = self.add_weight(name='W_a',
                                   shape=tf.TensorShape((en_hidden, en_hidden)),
                                   initializer='uniform',
                                   trainable=True)
        # Projects a decoder output into the encoder space: (de_hidden, en_hidden)
        self.U_a = self.add_weight(name='U_a',
                                   shape=tf.TensorShape((de_hidden, en_hidden)),
                                   initializer='uniform',
                                   trainable=True)
        # Reduces the combined representation to a scalar score: (en_hidden, 1)
        self.V_a = self.add_weight(name='V_a',
                                   shape=tf.TensorShape((en_hidden, 1)),
                                   initializer='uniform',
                                   trainable=True)

        super(AttentionLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs, verbose=False):
        """Compute context vectors and attention weights.

        inputs: [encoder_output_sequence, decoder_output_sequence]
        verbose: when True, print the shapes of the intermediate tensors.

        Returns ``(c_outputs, e_outputs)``:
            c_outputs: context vectors, (batch_size, de_seq_len, en_hidden)
            e_outputs: attention weights, (batch_size, de_seq_len, en_seq_len)
        """
        assert isinstance(inputs, (list, tuple)), \
            "inputs must be a list of [encoder_output_sequence, decoder_output_sequence]"
        encoder_out_seq, decoder_out_seq = inputs
        if verbose:
            print('encoder_out_seq>', encoder_out_seq.shape)
            print('decoder_out_seq>', decoder_out_seq.shape)

        # S.Wa where S=[s0, s1, ..., si]. It does not depend on the decoder
        # step, so it is computed once here rather than inside energy_step.
        # K.dot on a 3-D tensor avoids reshapes that need a static en_seq_len.
        # <= batch_size, en_seq_len, latent_dim
        W_a_dot_s = K.dot(encoder_out_seq, self.W_a)
        if verbose:
            print('wa.s>', W_a_dot_s.shape)

        def energy_step(inputs, states):
            """ Step function for computing energy for a single decoder state """
            assert isinstance(states, (list, tuple)), \
                "States must be a list. However states {} is of type {}".format(states, type(states))

            # hj.Ua
            # <= batch_size, 1, latent_dim
            U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1)
            if verbose:
                print('Ua.h>', U_a_dot_h.shape)

            # tanh(S.Wa + hj.Ua); the size-1 axis broadcasts over en_seq_len
            # <= batch_size, en_seq_len, latent_dim
            Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)
            if verbose:
                print('Ws+Uh>', Ws_plus_Uh.shape)

            # softmax(va.tanh(S.Wa + hj.Ua))
            # <= batch_size, en_seq_len
            e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)
            e_i = K.softmax(e_i)
            if verbose:
                print('ei>', e_i.shape)

            return e_i, [e_i]

        def context_step(inputs, states):
            """ Step function for computing ci using ei """
            # Weight every encoder output by its attention score and sum over
            # the encoder time axis.
            # <= batch_size, hidden_size
            c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
            if verbose:
                print('ci>', c_i.shape)
            return c_i, [c_i]

        # The step functions ignore their incoming state, but K.rnn still needs
        # an initial state shaped like the one each step returns. Build zero
        # tensors from the encoder output so batch size (and sequence length)
        # may be dynamic.
        fake_state_c = K.zeros_like(K.sum(encoder_out_seq, axis=1))  # <= (batch_size, latent_dim)
        fake_state_e = K.zeros_like(K.sum(encoder_out_seq, axis=2))  # <= (batch_size, en_seq_len)

        # Computing energy outputs
        # e_outputs => (batch_size, de_seq_len, en_seq_len)
        last_out, e_outputs, _ = K.rnn(
            energy_step, decoder_out_seq, [fake_state_e],
        )

        # Computing context vectors
        # c_outputs => (batch_size, de_seq_len, latent_dim)
        last_out, c_outputs, _ = K.rnn(
            context_step, e_outputs, [fake_state_c],
        )

        return c_outputs, e_outputs

    def compute_output_shape(self, input_shape):
        """ Outputs produced by the layer

        Returns the shapes of ``[context_vectors, attention_weights]``. The
        context vector is a weighted sum of encoder outputs, so its last
        dimension is the encoder hidden size (``input_shape[0][2]``).
        """
        return [
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][2])),
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
        ]