Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Bug][Fix][WIP] Fix pre-layernormalization in Transformer #1488

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/machine_translation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ horovodrun -np 4 -H localhost:4 python3 train_transformer.py \
--warmup_steps 4000 \
--warmup_init_lr 0.0 \
--seed 123 \
--max_grad_norm 1.0 \
--fp16
```

Expand Down
2 changes: 1 addition & 1 deletion scripts/processing/clean_tok_mono_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_parser():
parser.add_argument('--discard-non-latin1', action='store_true',
help='Whether to discard the sentence pair if both sentences cannot be '
'encoded into latin1.')
parser.add_argument('--num-process', type=int, default=8,
parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
help='number of process')
parser.add_argument('--overwrite', action='store_true')

Expand Down
2 changes: 1 addition & 1 deletion scripts/processing/clean_tok_para_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def get_parser():
parser.add_argument('--discard-non-latin1', action='store_true',
help='Whether to discard the sentence pair if both sentences cannot be '
'encoded into latin1.')
parser.add_argument('--num-process', type=int, default=8,
parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
help='number of process')
parser.add_argument('--overwrite', action='store_true')

Expand Down
4 changes: 2 additions & 2 deletions src/gluonnlp/data/batchify.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ class Pad:
val : float or int, default 0
The padding value.
axis : int, default 0
The axis to pad the arrays. The arrays will be padded to the largest dimension at
`axis`. For example, assume the input arrays have shape
The axis to pad the arrays. The arrays will be padded to the largest possible dimension,
and then stack at `axis`. For example, assume the input arrays have shape
(10, 8, 5), (6, 8, 5), (3, 8, 5) and the `axis` is 0. Each input will be padded into
(10, 8, 5) and then stacked to form the final output, which has shape(3, 10, 8, 5).
dtype : str or numpy.dtype, default None
Expand Down
6 changes: 3 additions & 3 deletions src/gluonnlp/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ class PositionwiseFFN(HybridBlock):
"""The Position-wise FFN layer used in Transformer-like architectures

If pre_norm is True:
norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
data -> norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
Else:
data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(res(+data))
"""
Expand Down Expand Up @@ -566,7 +566,6 @@ def forward(self, data):

Parameters
----------
F
data :
Shape (B, seq_length, C_in)

Expand All @@ -575,13 +574,14 @@ def forward(self, data):
out :
Shape (B, seq_length, C_out)
"""
residual = data
if self._pre_norm:
data = self.layer_norm(data)
out = self.activation(self.ffn_1(data))
out = self.activation_dropout_layer(out)
out = self.ffn_2(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.layer_norm(out)
return out
Expand Down
20 changes: 13 additions & 7 deletions src/gluonnlp/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def transformer_wmt_en_de_big_t2t():
cfg.defrost()
cfg.MODEL.attention_dropout = 0.1
cfg.MODEL.activation_dropout = 0.1
cfg.MODEL.dropout = 0.1
cfg.MODEL.ENCODER.pre_norm = True
cfg.MODEL.DECODER.pre_norm = True
cfg.freeze()
Expand Down Expand Up @@ -255,6 +256,7 @@ def forward(self, data, attn_mask):
attn_weight
Shape (batch_size, seq_length, seq_length)
"""
residual = data
if self._pre_norm:
data = self.layer_norm(data)
query, key, value = np.split(self.attn_qkv(data), 3, axis=-1)
Expand All @@ -264,7 +266,7 @@ def forward(self, data, attn_mask):
out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
out = self.attention_proj(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.layer_norm(out)
out = self.ffn(out)
Expand Down Expand Up @@ -565,6 +567,7 @@ def forward(self, data, mem, self_causal_mask, mem_attn_mask):
Shape (seq_length, batch_size, C_out)
"""
# 1. Get the causal self-attention value
residual = data
if self._pre_norm:
data = self.ln_in(data)
self_query, self_key, self_value = np.split(self.attn_in_qkv(data), 3, axis=-1)
Expand All @@ -575,11 +578,12 @@ def forward(self, data, mem, self_causal_mask, mem_attn_mask):
self_causal_mask)
out = self.proj_in(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.ln_in(out)
# 2. Attend to the contextual memory
data = out
residual = data
if self._pre_norm:
data = self.ln_inter(data)
out, [_, context_attn_weight] = self.inter_attention(
Expand All @@ -589,7 +593,7 @@ def forward(self, data, mem, self_causal_mask, mem_attn_mask):
mem_attn_mask)
out = self.proj_inter(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.ln_inter(out)
# 3. Encode the output via an FFN layer
Expand Down Expand Up @@ -681,13 +685,14 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
Shape (batch_size, prev_seq_length + 1, num_heads, C_value)

"""
if self._pre_norm:
data = self.ln_in(data)
if self.layout == 'NT':
time_axis = 1
else:
time_axis = 0
data = np.expand_dims(data, axis=time_axis)
residual = data
if self._pre_norm:
data = self.ln_in(data)
# Shape (B, prev_L, #Head, C_K), (B, prev_L, #Head, C_V)
# or (prev_L, B, #Head, C_K), (prev_L, B, #Head, C_V)
prev_key, prev_value = states
Expand All @@ -708,11 +713,12 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
out = self.proj_in(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.ln_in(out)
# 2. Attend to the contextual memory
data = out
residual = data
if self._pre_norm:
data = self.ln_inter(data)
out, _ = self.inter_attention(npx.reshape(self.attn_inter_q(data),
Expand All @@ -724,7 +730,7 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
mem_attn_mask)
out = self.proj_inter(out)
out = self.dropout_layer(out)
out = out + data
out = out + residual
if not self._pre_norm:
out = self.ln_inter(out)
# 3. Encode the output via an FFN layer
Expand Down