Set input_embeds before it gets used (#1261)
tthakkal authored and regisss committed Aug 16, 2024
1 parent 427d325 commit 09551d9
Showing 1 changed file with 3 additions and 3 deletions.
@@ -606,6 +606,9 @@ def forward(
         else:
             past_key_values_length = past_key_values[0][0].shape[2]

+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
@@ -615,9 +618,6 @@
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

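The point of the move: the `cache_position` block reads the sequence length and device from `inputs_embeds`, so `inputs_embeds` has to be populated before that block runs rather than after it. Below is a hypothetical, simplified sketch of the fixed ordering, assuming (as in the upstream Llama-style `forward`) that `cache_position` is derived from `inputs_embeds.shape[1]`; the helper name and toy sizes are illustrative, not part of the actual modeling code.

import torch

# Hypothetical, simplified extract of the ordering this commit fixes;
# names mirror the diff, but this is not the real forward() implementation.
def embed_then_cache_position(input_ids, inputs_embeds, embed_tokens, past_seen_tokens=0):
    # After the change, token embeddings are materialized first ...
    if inputs_embeds is None:
        inputs_embeds = embed_tokens(input_ids)
    # ... so the sequence length and device used for cache_position are never read from None.
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
    return inputs_embeds, cache_position

embed_tokens = torch.nn.Embedding(32, 8)   # toy vocabulary of 32 tokens, hidden size 8
input_ids = torch.tensor([[1, 2, 3]])
inputs_embeds, cache_position = embed_then_cache_position(input_ids, None, embed_tokens)
print(cache_position)                       # tensor([0, 1, 2])

With the pre-commit ordering, calling the model with `input_ids` only (no `inputs_embeds`) would reach the `cache_position` computation while `inputs_embeds` was still `None`.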