Sampler: route every max_new_tokens check through get_max_new_tokens(prompt_len)

Replaces direct reads of sampling_params.max_new_tokens in sampler.cpp with
sampling_params.get_max_new_tokens(<prompt length>) and adds a prompt_len
parameter to Sampler::GroupBeamSearcher::Group::is_done so the beam-search
stop check can use the same value.

Review fix: the StopCriteria::NEVER branch of is_done called
get_max_new_tokens() with no argument, leaving the new prompt_len parameter
unused; it now passes prompt_len like every other call site in this patch.

NOTE(review): presumably get_max_new_tokens() derives the limit from the
prompt length when only a total-length limit is configured - confirm against
GenerationConfig. Indentation and blank context lines below were reconstructed
from the hunk line counts; the index blob ids are carried over unchanged from
the original patch.

diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 3febadf112..ec59ceffd0 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -419,7 +419,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
             }
 
             // check whether group has finished
-            group.is_done(m_parameters);
+            group.is_done(m_parameters, this->m_sequence_group->get_prompt_len());
 
             // group cannot continue if there are no valid child beams
             if (child_beams_per_group[group_id].size() == 0) {
@@ -560,14 +560,14 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
     std::vector dropped_seq_ids;
     for (auto& running_sequence : sequence_group->get_running_sequences()) {
         const auto generated_len = running_sequence->get_generated_len();
-        if (sampling_params.max_new_tokens <= generated_len ||
+        if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) <= generated_len ||
             is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
             // stop sequence by max_new_tokens or stop token (eos included)
             running_sequence->set_status(SequenceStatus::FINISHED);
 
             if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
                 running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-            } else if (sampling_params.max_new_tokens == generated_len) {
+            } else if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) == generated_len) {
                 running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
             }
 
@@ -786,8 +786,8 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups,
                     // max counter of needed to be sampled tokens
                     OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
                     size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
-                    OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
-                    size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
+                    OPENVINO_ASSERT(sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) >= generated_and_verified_len);
+                    size_t max_num_sampled_token = sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) - generated_and_verified_len;
                     if (max_num_sampled_token == 0) {
                         stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
                         break;
@@ -873,7 +873,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups,
             // check max length stop criteria
             std::vector running_sequences = sequence_group->get_running_sequences();
             if (!sequence_group->has_finished() &&
-                running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
+                running_sequences[0]->get_generated_len() == sampling_params.get_max_new_tokens(sequence_group->get_prompt_len())) {
                 // stop sequence by max_new_tokens
                 m_beam_search_info.at(request_id).finalize(sampler_output);
             }
@@ -939,7 +939,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
     return preeempted_sequence_id;
 }
 
-void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
+void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len) {
     assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 && "number of beams should be divisible by number of groups");
     size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
 
@@ -960,7 +960,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
         return;
     }
     case ov::genai::StopCriteria::NEVER: {
-        size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
+        size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.get_max_new_tokens(prompt_len) : cur_len;
         float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
         done = worst_score >= highest_attainable_score;
         return;
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 0f7876cbf9..9979e0ff16 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -105,7 +105,7 @@ class Sampler::GroupBeamSearcher {
         bool done = false;
 
         int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
-        void is_done(const ov::genai::GenerationConfig& sampling_params);
+        void is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len);
     };
 
     SequenceGroup::Ptr m_sequence_group;