diff --git a/src/embeddings/embeddings_api.cpp b/src/embeddings/embeddings_api.cpp index 8c2ff60139..91742bed41 100644 --- a/src/embeddings/embeddings_api.cpp +++ b/src/embeddings/embeddings_api.cpp @@ -112,9 +112,12 @@ std::variant EmbeddingsRequest::fromJson(rapidjs if (input_strings.size() > 0) { request.input = input_strings; } - if (input_tokens.size() > 0) { + else if (input_tokens.size() > 0) { request.input = input_tokens; } + else { + return "no input provided in request"; + } return request; } @@ -128,7 +131,7 @@ absl::Status EmbeddingsHandler::parseRequest() { return absl::OkStatus(); } -std::variant, std::vector>>& EmbeddingsHandler::getInput() { +std::variant, std::vector>>& EmbeddingsHandler::getInput() { return request.input; } EmbeddingsRequest::EncodingFormat EmbeddingsHandler::getEncodingFormat() const { diff --git a/src/embeddings/embeddings_api.hpp b/src/embeddings/embeddings_api.hpp index 043cdcd618..c49f1e2ebb 100644 --- a/src/embeddings/embeddings_api.hpp +++ b/src/embeddings/embeddings_api.hpp @@ -36,7 +36,7 @@ struct EmbeddingsRequest { FLOAT, BASE64 }; - std::variant, std::vector>> input; + std::variant, std::vector>> input; EncodingFormat encoding_format; static std::variant fromJson(rapidjson::Document* request); @@ -51,7 +51,7 @@ class EmbeddingsHandler { EmbeddingsHandler(rapidjson::Document& document) : doc(document) {} - std::variant, std::vector>>& getInput(); + std::variant, std::vector>>& getInput(); EmbeddingsRequest::EncodingFormat getEncodingFormat() const; absl::Status parseRequest(); diff --git a/src/test/embeddingsnode_test.cpp b/src/test/embeddingsnode_test.cpp index 61a616a3f3..aa3c6d5388 100644 --- a/src/test/embeddingsnode_test.cpp +++ b/src/test/embeddingsnode_test.cpp @@ -283,6 +283,19 @@ TEST_F(EmbeddingsHttpTest, simplePositiveMultipleStrings) { ASSERT_EQ(d["data"][1]["embedding"].Size(), EMBEDDING_OUTPUT_SIZE); } +TEST_F(EmbeddingsHttpTest, emptyInput) { + std::string requestBody = R"( + { + "model": "embeddings", + "input": [] + } + )"; + Status status = handler->dispatchToProcessor(endpointEmbeddings, requestBody, &response, comp, responseComponents, &writer); + ASSERT_EQ(status, + ovms::StatusCode::MEDIAPIPE_EXECUTION_ERROR) + << status.string(); +} + class EmbeddingsExtensionTest : public ::testing::Test { protected: static std::unique_ptr t;