Skip to content

Commit

Permalink
Fix pad to multiple of (#25732)
Browse files Browse the repository at this point in the history
* nits

* update the test

* nits

* update

* fix bark

* fix bark tests and allow padding to multiple of without new tokens
  • Loading branch information
ArthurZucker authored and LysandreJik committed Sep 15, 2023
1 parent 2ba46c1 commit b033d1a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
12 changes: 7 additions & 5 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,8 @@ def resize_token_embeddings(
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value.
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
Expand All @@ -1431,12 +1432,12 @@ def resize_token_embeddings(
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None:
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds

# Update base model and current model config
self.config.vocab_size = new_num_tokens
self.vocab_size = new_num_tokens
self.config.vocab_size = model_embeds.weight.shape[0]
self.vocab_size = model_embeds.weight.shape[0]

# Tie weights again if needed
self.tie_weights()
Expand Down Expand Up @@ -1492,7 +1493,8 @@ def _get_resized_embeddings(
vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
`torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value.
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
Expand Down
46 changes: 41 additions & 5 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,21 +1086,57 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
]
)
self.set_input_embeddings(new_embeddings_list)
new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list]
new_num_tokens = new_embeddings_list[0].weight.shape[0]

# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
old_lm_head_list = self.get_output_embeddings()
new_lm_head_list = nn.ModuleList(
[
self._get_resized_lm_head(old_lm_head, new_num_token)
for old_lm_head, new_num_token in zip(old_lm_head_list, new_num_tokens)
]
[self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
)
self.set_output_embeddings(new_lm_head_list)

return self.get_input_embeddings()

def resize_token_embeddings(
    self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> nn.Embedding:
    """
    Resizes the input token embeddings of the model and keeps the vocabulary-size bookkeeping in sync.

    The resizing itself is delegated to `_resize_token_embeddings`. Afterwards the size that was actually
    produced is written to `config.vocab_size`, `config.output_vocab_size`, `self.vocab_size` and
    `self.output_vocab_size`, and `tie_weights()` is called.

    Arguments:
        new_num_tokens (`int`, *optional*):
            The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
            vectors at the end. Reducing the size will remove vectors from the end. If `None` and
            `pad_to_multiple_of` is also `None`, the value returned by `_resize_token_embeddings` is handed
            back as is, and neither the config nor the weight tying is touched.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set
            to `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
            `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
            details about this, or help on choosing the correct value for resizing, refer to this guide:
            https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

    Return:
        The input embeddings returned by `_resize_token_embeddings`. The code indexes this value
        (`model_embeds[0].weight`), so it is a collection of embedding modules rather than a single
        `torch.nn.Embedding`; the return annotation is kept unchanged for interface compatibility.
    """
    model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

    # Nothing was requested: leave the config and the weight tying untouched.
    if new_num_tokens is None and pad_to_multiple_of is None:
        return model_embeds

    # Read the size back from the resized weights instead of trusting `new_num_tokens`: padding may have
    # rounded it up, and `new_num_tokens` may be `None`. The first entry is used as the reference size.
    resized_vocab_size = model_embeds[0].weight.shape[0]

    # Update base model and current model config
    self.config.output_vocab_size = resized_vocab_size
    self.config.vocab_size = resized_vocab_size
    self.output_vocab_size = resized_vocab_size
    self.vocab_size = resized_vocab_size

    # Tie weights again if needed
    self.tie_weights()

    return model_embeds

def tie_weights(self):
"""
Tie the weights between the input embeddings list and the output embeddings list.
Expand Down
3 changes: 3 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,9 @@ def test_resize_tokens_embeddings(self):
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)

self.assertTrue(model_embed.weight.shape[0], model.config.vocab_size)
self.assertTrue(model.config.vocab_size, model.vocab_size)

model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)

Expand Down

0 comments on commit b033d1a

Please sign in to comment.