Skip to content

Commit

Permalink
fix(qwen2): adapt to latest TnX
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Dec 23, 2024
1 parent 211f6f8 commit 1d60cef
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optimum/neuron/models/qwen2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,13 @@ def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwar
return padded_inputs, input_embeddings, *rst

def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs):
original_input_ids = input_ids
if last_token_id is not None: # preprocess_and_embed() has already been invoked
rst = cache_ids, start_ids, last_token_id
else: # invoke preprocess_and_embed()
input_ids, input_embeddings, *rst = self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs)
# either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding)
inputs = input_embeddings if input_embeddings is not None else input_ids
logits = self._forward(inputs, *rst)
logits = self._postprocess(logits, start_ids=start_ids, **kwargs)
logits = self._postprocess(original_input_ids, logits, start_ids=start_ids, **kwargs)
return logits

0 comments on commit 1d60cef

Please sign in to comment.