Commit

init
dynamic decoding + separate scheduler
mzegla committed Dec 11, 2024
1 parent 74a6c34 commit f8af511
Showing 10 changed files with 104 additions and 25 deletions.
1 change: 0 additions & 1 deletion ci/lib_search.py
@@ -83,7 +83,6 @@ def check_dir(start_dir):
'__pycache__',
'add.xml',
'azure_sdk.patch',
'cb.patch',
'bazel-',
'check_coverage.bat',
'genhtml',
1 change: 0 additions & 1 deletion external/BUILD
@@ -47,5 +47,4 @@ exports_files([
"listen.patch",
"tf.patch",
"net_http.patch",
"cb.patch",
])
1 change: 0 additions & 1 deletion spelling-whitelist.txt
@@ -1,7 +1,6 @@
client/common/resnet_labels.txt
demos/common/python/classes.py
demos/image_classification/go/labels.go
external/cb.patch
extras/nginx-mtls-auth/model_server.conf.template
release_files/thirdparty-licenses/boringssl.LICENSE.txt
src/shape.cpp:436: strIn
38 changes: 35 additions & 3 deletions src/llm/apis/openai_completions.cpp
@@ -121,7 +121,7 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart() {
return absl::OkStatus();
}

absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline) {
OVMS_PROFILE_FUNCTION();
// stream: bool; optional
if (!doc.IsObject())
@@ -350,6 +350,38 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim
request.numReturnSequences = it->value.GetUint();
}

// Speculative decoding specific parameters

auto numAssistantTokensIt = doc.FindMember("num_assistant_tokens");
auto assistantConfidenceThresholdIt = doc.FindMember("assistant_confidence_threshold");

if (isSpeculativePipeline) {
if (numAssistantTokensIt == doc.MemberEnd() && assistantConfidenceThresholdIt == doc.MemberEnd())
return absl::InvalidArgumentError("Speculative decoding requires either num_assistant_tokens or assistant_confidence_threshold to be set.");

if (numAssistantTokensIt != doc.MemberEnd() && assistantConfidenceThresholdIt != doc.MemberEnd())
return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are mutually exclusive and cannot both be set.");
} else if (numAssistantTokensIt != doc.MemberEnd() || assistantConfidenceThresholdIt != doc.MemberEnd()) {
return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are only supported when speculative decoding is enabled.");
}
// num_assistant_tokens: uint;
if (numAssistantTokensIt != doc.MemberEnd()) {
if (!numAssistantTokensIt->value.IsUint() || numAssistantTokensIt->value.GetUint() == 0) {
return absl::InvalidArgumentError("num_assistant_tokens must be an unsigned integer greater than 0");
}
request.numAssistantTokens = numAssistantTokensIt->value.GetUint();
}
// assistant_confidence_threshold: float;
if (assistantConfidenceThresholdIt != doc.MemberEnd()) {
if (!assistantConfidenceThresholdIt->value.IsDouble() && !assistantConfidenceThresholdIt->value.IsInt()) {
return absl::InvalidArgumentError("assistant_confidence_threshold must be a positive number");
}
request.assistantConfidenceThreshold = assistantConfidenceThresholdIt->value.GetDouble();
if (request.assistantConfidenceThreshold <= 0.0) {
return absl::InvalidArgumentError("assistant_confidence_threshold must be greater than 0");
}
}

// use_beam_search: bool; optional - defaults to false
// Extension from vLLM, unsupported by OpenAI API, not available directly in CB lib
// Use best_of>1 to steer into beam search
@@ -392,8 +424,8 @@ ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfi
return request.createGenerationConfig();
}

absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit);
absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline) {
absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit, isSpeculativePipeline);

if (status != absl::OkStatus())
return status;
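For context, a minimal sketch of a chat completions request body that exercises the new fields; the model name and values are illustrative, not taken from the change. With a speculative decoding pipeline, exactly one of num_assistant_tokens or assistant_confidence_threshold must be present:

{
  "model": "llama-draft-pipeline",
  "messages": [
    {"role": "user", "content": "What is speculative decoding?"}
  ],
  "max_tokens": 128,
  "num_assistant_tokens": 5
}

Swapping in "assistant_confidence_threshold": 0.4 is the alternative; sending both at once, or omitting both on a speculative pipeline, is rejected with InvalidArgumentError, as is sending either one to a non-speculative pipeline.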
42 changes: 27 additions & 15 deletions src/llm/apis/openai_completions.hpp
@@ -60,29 +60,36 @@ struct CompletionUsageStatistics {

// Class that maps OpenAI request content and provides methods to create GenerationConfig from it.
struct OpenAIChatCompletionsRequest {
// Generic
chat_t messages;
std::optional<std::string> prompt{std::nullopt};
bool stream{false};
StreamOptions streamOptions;
std::string model;
std::optional<int> maxTokens{std::nullopt};
std::optional<float> frequencyPenalty{std::nullopt};
std::optional<float> presencePenalty{std::nullopt};
std::optional<float> diversityPenalty{std::nullopt};
std::optional<float> repetitionPenalty{std::nullopt};
std::optional<float> lengthPenalty{std::nullopt};
std::optional<int> numReturnSequences{std::nullopt};
bool logprobs = 0;
int logprobschat = false;
bool echo{false};
std::optional<bool> ignoreEOS{std::nullopt};
std::optional<std::set<std::string>> stop{std::nullopt};
std::optional<bool> includeStopStrInOutput{std::nullopt};
std::optional<int> numReturnSequences{std::nullopt}; // effective for beam search and multinomial decoding
// Multinomial decoding specific
std::optional<float> temperature{std::nullopt};
std::optional<float> topP{std::nullopt};
std::optional<int> topK{std::nullopt};
std::optional<int> seed{std::nullopt};
std::optional<std::set<std::string>> stop{std::nullopt};
std::optional<bool> includeStopStrInOutput{std::nullopt};
std::optional<float> frequencyPenalty{std::nullopt};
std::optional<float> presencePenalty{std::nullopt};
std::optional<float> repetitionPenalty{std::nullopt};
// Beam search specific
std::optional<int> bestOf{std::nullopt};
std::optional<bool> ignoreEOS{std::nullopt};
int logprobs = 0;
bool logprobschat = false;
bool echo{false};
std::optional<float> lengthPenalty{std::nullopt};
std::optional<float> diversityPenalty{std::nullopt};

// Speculative decoding specific (only with speculative decoding pipeline, see <docs> for reference)
std::optional<int> numAssistantTokens{std::nullopt};
std::optional<float> assistantConfidenceThreshold{std::nullopt};

OpenAIChatCompletionsRequest() = default;
~OpenAIChatCompletionsRequest() = default;
@@ -120,7 +127,7 @@ struct OpenAIChatCompletionsRequest {
// TODO: early_finish = ?
// TODO use_beam_search is unused ?

// Multinomial specific
// Multinomial sampling specific
if (temperature.has_value())
config.temperature = temperature.value();
if (topK.has_value())
@@ -141,6 +148,11 @@

if (logprobschat || logprobs > 0)
config.logprobs = 1;
// Speculative decoding specific
if (numAssistantTokens.has_value())
config.num_assistant_tokens = numAssistantTokens.value();
if (assistantConfidenceThreshold.has_value())
config.assistant_confidence_threshold = assistantConfidenceThreshold.value();

return config;
}
@@ -159,7 +171,7 @@ class OpenAIChatCompletionsHandler {

absl::Status parseCompletionsPart();
absl::Status parseChatCompletionsPart();
absl::Status parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit);
absl::Status parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline);

public:
OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
@@ -182,7 +194,7 @@ class OpenAIChatCompletionsHandler {

ov::genai::GenerationConfig createGenerationConfig() const;

absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);
absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline);

std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs);
std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason);
4 changes: 3 additions & 1 deletion src/llm/http_llm_calculator.cc
@@ -127,7 +127,7 @@ class HttpLLMCalculator : public CalculatorBase {
nodeResources->cbPipe->get_tokenizer());
this->client = payload.client;

auto status = this->apiHandler->parseRequest(nodeResources->maxTokensLimit, nodeResources->bestOfLimit);
auto status = this->apiHandler->parseRequest(nodeResources->maxTokensLimit, nodeResources->bestOfLimit, nodeResources->isSpeculativePipeline);
if (status != absl::OkStatus())
return status;

@@ -204,7 +204,9 @@ class HttpLLMCalculator : public CalculatorBase {
if (this->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING || this->generationHandle->can_read()) {
// Subsequent iteration
OVMS_PROFILE_SCOPE("Generation of subsequent streaming response");
ov::genai::GenerationOutputs generationOutputs = this->generationHandle->read();
RET_CHECK(generationOutputs.size() == 1); // TODO: Support multiple generations
this->apiHandler->incrementProcessedTokens(generationOutputs.begin()->second.generated_ids.size());

17 changes: 17 additions & 0 deletions src/llm/llm_calculator.proto
@@ -45,4 +45,21 @@ message LLMCalculatorOptions {
optional uint32 max_tokens_limit = 9 [default = 4096];

optional bool enable_prefix_caching = 10 [default = false];

// Speculative decoding - draft model config (ignore the fields below if you don't want to use speculative decoding).
// When draft_models_path is set, the pipeline will use speculative decoding.
// The other draft_* values default to the main model's settings when speculative decoding is enabled, but can be overridden.
optional string draft_models_path = 11;

optional string draft_device = 12;

optional uint64 draft_max_num_batched_tokens = 13;

optional uint64 draft_cache_size = 14;

optional uint64 draft_block_size = 15;

optional uint64 draft_max_num_seqs = 16;

optional bool draft_dynamic_split_fuse = 17;
}
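As an illustration only, a node_options sketch for enabling these fields in a graph configuration; it assumes the usual OVMS MediaPipe pbtxt layout and that the main-model options keep their existing names (e.g. models_path), with placeholder paths and sizes:

node_options: {
  [type.googleapis.com / mediapipe.LLMCalculatorOptions]: {
    models_path: "/workspace/model"              # main model (assumed existing field)
    draft_models_path: "/workspace/draft_model"  # presence of this field enables speculative decoding
    draft_device: "CPU"
    draft_cache_size: 2
    # remaining draft_* fields fall back to the main model values
  }
}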
20 changes: 19 additions & 1 deletion src/llm/llmnoderesources.cpp
@@ -35,7 +35,6 @@
#pragma GCC diagnostic pop

#include "../mediapipe_internal/mediapipe_utils.hpp"
#include "src/llm/llm_calculator.pb.h"
#include "src/llm/llm_executor.hpp"
#include "src/llm/text_processor.hpp"

@@ -155,6 +154,14 @@ Status LLMNodeResources::initializeLLMNodeResources(LLMNodeResources& nodeResour

nodeResources.device = nodeOptions.device();

if (!nodeOptions.draft_models_path().empty()) {
auto draftSchedulerConfig = prepareDraftModelSchedulerConfig(nodeOptions);
auto draftModelConfig = ov::genai::draft_model(nodeOptions.draft_models_path(), nodeOptions.draft_device(),
ov::genai::scheduler_config(draftSchedulerConfig));
nodeResources.pluginConfig.insert(draftModelConfig);
nodeResources.isSpeculativePipeline = true;
}

auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), nodeResources.pluginConfig);
if (!status.ok()) {
SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config());
@@ -208,4 +215,15 @@ std::unordered_map<std::string, std::string> LLMNodeResources::prepareLLMNodeIni
return LLMArguments;
}

ov::genai::SchedulerConfig LLMNodeResources::prepareDraftModelSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions) {
return {
.max_num_batched_tokens = nodeOptions.has_draft_max_num_batched_tokens() ? nodeOptions.draft_max_num_batched_tokens() : nodeOptions.max_num_batched_tokens(),
.cache_size = nodeOptions.has_draft_cache_size() ? nodeOptions.draft_cache_size() : nodeOptions.cache_size(),
.block_size = nodeOptions.has_draft_block_size() ? nodeOptions.draft_block_size() : nodeOptions.block_size(),
.dynamic_split_fuse = nodeOptions.has_draft_dynamic_split_fuse() ? nodeOptions.draft_dynamic_split_fuse() : nodeOptions.dynamic_split_fuse(),
.max_num_seqs = nodeOptions.has_draft_max_num_seqs() ? nodeOptions.draft_max_num_seqs() : nodeOptions.max_num_seqs(),
.enable_prefix_caching = nodeOptions.enable_prefix_caching(),
};
}

} // namespace ovms
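For orientation, a condensed, self-contained sketch of what the draft-model plugin entry amounts to outside the server, assuming the OpenVINO GenAI draft_model/scheduler_config property helpers and the ContinuousBatchingPipeline constructor behave as in the upstream library; header locations, paths, and devices are placeholders, not OVMS code:

#include <openvino/genai/continuous_batching_pipeline.hpp>
#include <openvino/genai/llm_pipeline.hpp>  // draft_model / scheduler_config helpers (location assumed)

int main() {
    // Scheduler settings for the main and the draft model; in the server these
    // come from LLMCalculatorOptions and prepareDraftModelSchedulerConfig().
    ov::genai::SchedulerConfig mainScheduler;
    ov::genai::SchedulerConfig draftScheduler;

    // Adding the draft_model entry to the plugin config is what switches the
    // pipeline into speculative decoding mode.
    ov::AnyMap pluginConfig;
    pluginConfig.insert(ov::genai::draft_model(
        "/models/draft", "CPU", ov::genai::scheduler_config(draftScheduler)));

    ov::genai::ContinuousBatchingPipeline pipe(
        "/models/main", mainScheduler, "CPU", pluginConfig);
    return 0;
}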
3 changes: 3 additions & 0 deletions src/llm/llmnoderesources.hpp
@@ -37,6 +37,7 @@
#include "../logging.hpp"
#include "../stringutils.hpp"
#include "src/python/utils.hpp"
#include "src/llm/llm_calculator.pb.h"
#include "text_processor.hpp"

namespace ovms {
@@ -105,6 +106,7 @@ using plugin_config_t = std::map<std::string, ov::Any>;
struct LLMNodeResources {
public:
std::shared_ptr<ov::genai::ContinuousBatchingPipeline> cbPipe = nullptr;
bool isSpeculativePipeline{false};
std::string modelsPath;
std::string device;
plugin_config_t pluginConfig;
@@ -128,6 +130,7 @@ struct LLMNodeResources {
private:
std::unique_ptr<LLMExecutorWrapper> llmExecutorWrapper;
static std::unordered_map<std::string, std::string> prepareLLMNodeInitializeArguments(const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string basePath);
static ov::genai::SchedulerConfig prepareDraftModelSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions);

public:
virtual void initializeContinuousBatchingPipeline(
2 changes: 0 additions & 2 deletions third_party/llm_engine/llm_engine.bzl
@@ -24,8 +24,6 @@ def llm_engine():
build_file = "@_llm_engine//:BUILD",
init_submodules = True,
recursive_init_submodules = True,
patch_args = ["-p1"],
patches = ["cb.patch"],
)
# when using local repository manually run: git submodule update --recursive
#native.new_local_repository(
