From 589365babbef7de7d281f68a42d9c493ce7ba730 Mon Sep 17 00:00:00 2001
From: mzegla
Date: Mon, 18 Nov 2024 13:15:40 +0100
Subject: [PATCH] refactor

---
 src/cpp/src/sampler.cpp    | 63 +++++++++++++++++---------------------
 src/cpp/src/sampler.hpp    | 10 ++++--
 src/cpp/src/threadpool.hpp | 61 ++++++++++++++++--------------------
 3 files changed, 63 insertions(+), 71 deletions(-)

diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 7e53117411..3869b7ab23 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -741,12 +741,11 @@ float get_p_prime(Sequence::Ptr& running_sequence,
     return p_prime;
 }
 
-std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(SequenceGroup::Ptr sequence_group,
-                                                                              ov::Tensor sequence_group_logits,
-                                                                              LogitProcessor& logit_processor,
-                                                                              bool is_validation_mode_enabled) {
-    SamplerOutput sampler_output;
-    size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits<size_t>::max();
+SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr sequence_group,
+                                                              ov::Tensor sequence_group_logits,
+                                                              LogitProcessor& logit_processor,
+                                                              bool is_validation_mode_enabled) {
+    SequenceGroupSamplingInfo sampling_info;
     auto num_running_sequences = sequence_group->num_running_seqs();
     auto sampling_params = sequence_group->get_sampling_parameters();
     // get the number of tokens to be validated
@@ -769,7 +768,7 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
             OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
             size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
             if (max_num_sampled_token == 0) {
-                stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
+                stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, sampling_info.max_removed_tokens_per_request);
                 break;
             }
 
@@ -797,13 +796,13 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
                     // to create n sequences just in case of `sequence_group->num_total_seqs() == 1` and `sampling_params.num_return_sequences > 1`
                     if (is_generate_n_tokens) {
                         const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids);
-                        sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
+                        sampling_info.sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
                     }
                     sampled_token = sampled_token_ids.front();
                     // make `_speculative_sampling` in case the previous token was not accepted in speculative decoding
                     if (!is_validation_passed) {
                         float p_prime = get_p_prime(running_sequence, sampled_token, token_offset + 1);
-                        max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_offset);
+                        sampling_info.max_removed_tokens_per_request = std::max(sampling_info.max_removed_tokens_per_request, token_offset);
                         // update prob only in case candidate prob > sampled token prob
                         if (p_prime > 0.f) {
                             auto prob = std::exp(sampled_token.m_log_prob);
@@ -816,7 +815,7 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
                 bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
                 if (is_validation_mode_enabled && !is_extend_sequence) {
                     is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token,
-                                                              is_extend_sequence, max_removed_tokens_per_request, sampling_params.do_sample);
+                                                              is_extend_sequence, sampling_info.max_removed_tokens_per_request, sampling_params.do_sample);
                     // doing resample in case of non-accepted tokens in speculative sampling
                     if (!is_validation_passed && sampling_params.do_sample) {
                         continue;
@@ -833,11 +832,11 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
                     break;
                 }
             }
-            min_generated_len = std::min(min_generated_len, running_sequence->get_generated_len());
+            sampling_info.min_generated_len = std::min(sampling_info.min_generated_len, running_sequence->get_generated_len());
         }
-        align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
+        align_all_sequence_len(sequence_group, sampling_info.min_generated_len, logit_processor);
         for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
-            sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
+            sampling_info.sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
         }
     } else if (sampling_params.is_beam_search()) {
         uint64_t request_id = sequence_group->get_request_id();
@@ -853,23 +852,23 @@ std::tuple<SamplerOutput, size_t, size_t> Sampler::sample_from_sequence_group(Se
         }
 
         // current algorithm already adds new tokens to running sequences and
-        beam_searcher->select_next_tokens(sequence_group_logits, sampler_output);
+        beam_searcher->select_next_tokens(sequence_group_logits, sampling_info.sampler_output);
 
         // check max length stop criteria
         std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
         if (!sequence_group->has_finished() &&
             running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
             // stop sequence by max_new_tokens
-            beam_searcher->finalize(sampler_output);
+            beam_searcher->finalize(sampling_info.sampler_output);
         }
     }
     // Notify handle after sampling is done.
     // For non-streaming this is effective only when the generation is finished.
-    OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
-    size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
+    OPENVINO_ASSERT(num_tokens_to_process >= sampling_info.max_removed_tokens_per_request);
+    size_t num_output_token_to_push = num_tokens_to_process - sampling_info.max_removed_tokens_per_request + 1;
     sequence_group->notify_handle(num_output_token_to_push);
-    return std::make_tuple(sampler_output, min_generated_len, max_removed_tokens_per_request);
+    return sampling_info;
 }
 
 SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
@@ -881,7 +880,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
     size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
 
     SamplerOutput sampler_output;
-    std::unordered_map<SequenceGroup::Ptr, std::future<std::tuple<SamplerOutput, size_t, size_t>>> future_map;
+    std::unordered_map<SequenceGroup::Ptr, std::future<SequenceGroupSamplingInfo>> future_map;
     for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
         SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
         if (!sequence_group->is_scheduled())
@@ -898,36 +897,30 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
         }
 
-        //std::cout << "\nSequence group ID: " << sequence_group_id << std::endl;
-        //std::cout << "Sequence group data valid capacity: " << logits.get_size() << std::endl;
-        //std::cout << "Sequence group data offset: " << vocab_size * currently_processed_tokens << std::endl;
 
         const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
         ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
-        //std::cout << "Sequence group logits tensor size: " << sequence_group_logits.get_size() << std::endl;
" << sequence_group_logits.get_size() << std::endl; // Call sample_from_sequence_group asynchronously - //future_map[sequence_group] = std::async(std::launch::async, &Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits, is_validation_mode_enabled); - future_map[sequence_group] = m_thread_pool.enqueue(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits, - m_logit_processors.at(sequence_group->get_request_id()), is_validation_mode_enabled); + future_map[sequence_group] = m_thread_pool.submit(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits, + m_logit_processors.at(sequence_group->get_request_id()), is_validation_mode_enabled); } // accumulate a number of processed tokens currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences; } - // Iterate over sequence groups and check if future_map contains the key for (auto& sequence_group : sequence_groups) { if (future_map.find(sequence_group) != future_map.end()) { - auto [sequence_group_sampler_output, min_generated_len, max_removed_tokens_per_request] = future_map[sequence_group].get(); + // If there is a future assigned to a sequence group we read it's result (blocking if results not available yet) + auto sequence_group_sampling_info = future_map[sequence_group].get(); - // Merge sequence_group_sampler_output into sampler_output + // Merge sampler output from sequence group to the main one sampler_output.m_dropped_sequences.insert( sampler_output.m_dropped_sequences.end(), - sequence_group_sampler_output.m_dropped_sequences.begin(), - sequence_group_sampler_output.m_dropped_sequences.end() + sequence_group_sampling_info.sampler_output.m_dropped_sequences.begin(), + sequence_group_sampling_info.sampler_output.m_dropped_sequences.end() ); - for (const auto& forked_seq : sequence_group_sampler_output.m_forked_sequences) { + for (const auto& forked_seq : sequence_group_sampling_info.sampler_output.m_forked_sequences) { sampler_output.m_forked_sequences[forked_seq.first].insert( sampler_output.m_forked_sequences[forked_seq.first].end(), forked_seq.second.begin(), @@ -939,8 +932,8 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, // update internal state of sequence group to reset scheduler tokens and update currently processed ones sequence_group->finish_iteration(); // decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model - if (max_removed_tokens_per_request) { - auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1; + if (sequence_group_sampling_info.max_removed_tokens_per_request) { + auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1; sequence_group->update_processed_tokens_num(min_processed_tokens); auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id()); logit_processor.update_generated_len(min_processed_tokens); diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 517b27664b..0da2d63af8 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -41,6 +41,12 @@ struct SamplerOutput { std::unordered_map> m_forked_sequences; }; +struct SequenceGroupSamplingInfo { + SamplerOutput sampler_output; + size_t max_removed_tokens_per_request = 0; + size_t min_generated_len = std::numeric_limits::max(); +}; + class Sampler { class GroupBeamSearcher; @@ -68,8 +74,8 @@ class Sampler { Sampler() = default; Sampler(Tokenizer & 
 
-    std::tuple<SamplerOutput, size_t, size_t> sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits,
-                                                                         LogitProcessor& logit_processor, bool is_validation_mode_enabled = false);
+    SequenceGroupSamplingInfo sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits,
+                                                         LogitProcessor& logit_processor, bool is_validation_mode_enabled = false);
     SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
     void set_seed(size_t seed) { rng_engine.seed(seed); }
diff --git a/src/cpp/src/threadpool.hpp b/src/cpp/src/threadpool.hpp
index 35159d5b00..9ecf8b585e 100644
--- a/src/cpp/src/threadpool.hpp
+++ b/src/cpp/src/threadpool.hpp
@@ -6,38 +6,33 @@
 #include <mutex>
 #include <condition_variable>
 #include <future>
-using namespace std;
 
-// Class that represents a simple thread pool
 class ThreadPool {
 private:
-    vector<thread> threads_;
-    queue<function<void()>> tasks_;
-    mutex queue_mutex_;
-    condition_variable cv_;
-    bool stop_ = false;
+    std::vector<std::thread> threads;
+    std::queue<std::function<void()>> tasks;
+    std::mutex queue_mutex;
+    std::condition_variable cv;
+    bool stop = false;
 
 public:
-    // Constructor to create a thread pool with given
-    // number of threads
-    ThreadPool(size_t num_threads = thread::hardware_concurrency())
+    ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
     {
-        // Creating worker threads
         for (size_t i = 0; i < num_threads; ++i) {
-            threads_.emplace_back([this] {
+            threads.emplace_back([this] {
                 while (true) {
-                    function<void()> task;
+                    std::function<void()> task;
                     {
-                        unique_lock<mutex> lock(queue_mutex_);
-                        cv_.wait(lock, [this] {
-                            return !tasks_.empty() || stop_;
+                        std::unique_lock<std::mutex> lock(queue_mutex);
+                        cv.wait(lock, [this] {
+                            return !tasks.empty() || stop;
                         });
-                        if (stop_ && tasks_.empty()) {
+                        if (stop && tasks.empty()) {
                             return;
                         }
-                        task = move(tasks_.front());
-                        tasks_.pop();
+                        task = std::move(tasks.front());
+                        tasks.pop();
                     }
                     task();
                 }
             });
@@ -45,33 +40,31 @@ class ThreadPool {
         }
     }
 
-    // Destructor to stop the thread pool
     ~ThreadPool()
     {
         {
-            unique_lock<mutex> lock(queue_mutex_);
-            stop_ = true;
+            std::unique_lock<std::mutex> lock(queue_mutex);
+            stop = true;
         }
-        cv_.notify_all();
-        for (auto& thread : threads_) {
+        cv.notify_all();
+        for (auto& thread : threads) {
             thread.join();
         }
     }
 
-    // Enqueue task for execution by the thread pool
     template <class F, class... Args>
-    auto enqueue(F&& f, Args&&... args) -> future<invoke_result_t<F, Args...>>
+    auto submit(F&& f, Args&&... args) -> std::future<std::invoke_result_t<F, Args...>>
     {
-        using return_type = invoke_result_t<F, Args...>;
-        auto task = make_shared<packaged_task<return_type()>>(
-            bind(forward<F>(f), forward<Args>(args)...)
+        using return_type = std::invoke_result_t<F, Args...>;
+        auto task = std::make_shared<std::packaged_task<return_type()>>(
+            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
         );
-        future<return_type> res = task->get_future();
+        std::future<return_type> result = task->get_future();
         {
-            unique_lock<mutex> lock(queue_mutex_);
-            tasks_.emplace([task]() { (*task)(); });
+            std::unique_lock<std::mutex> lock(queue_mutex);
+            tasks.emplace([task]() { (*task)(); });
         }
-        cv_.notify_one();
-        return res;
+        cv.notify_one();
+        return result;
    }
 };
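
-- 
Reviewer note (after the signature delimiter, so it is ignored by git am and is not part of the patch): a minimal sketch of the refactored ThreadPool::submit API in use. The standalone example below is hypothetical; the add() helper and the 4-thread pool size are invented for illustration, and threadpool.hpp is assumed to be on the include path.

    // submit() wraps the callable in a std::packaged_task, queues it for a
    // worker thread, and returns a std::future holding the callable's result.
    #include <cstdio>
    #include "threadpool.hpp"

    static int add(int a, int b) { return a + b; }  // hypothetical helper

    int main() {
        ThreadPool pool(4);                                // spawns 4 worker threads
        std::future<int> result = pool.submit(add, 2, 3);  // task runs on a worker thread
        std::printf("add(2, 3) = %d\n", result.get());     // get() blocks until the task completes
        return 0;                                          // ~ThreadPool() joins the workers
    }

Sampler::sample follows the same pattern: one submit() per scheduled sequence group, then future.get() on each to merge the per-group SequenceGroupSamplingInfo into the main SamplerOutput.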