[#432] Add Groq Provider - chat completions (#609)
# What does this PR do?

Contributes towards issue (#432)

- Groq text chat completions
- Streaming
- All the sampling params that Groq supports (see the sketch below)
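
For illustration, a minimal hedged sketch of exercising those sampling params through the client. The call shape mirrors the manual test further down; passing `sampling_params` as a plain dict and the exact field names are assumptions based on the `SamplingParams` fields read in `groq_utils.py`:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

# temperature, top_p and max_tokens are forwarded to Groq;
# repetition_penalty and logprobs are unsupported and only emit warnings.
response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    sampling_params={
        "temperature": 0.7,
        "top_p": 0.9,
        "max_tokens": 128,
    },
)
print(response.completion_message.content)
```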

A lot of inspiration was taken from @mattf's good work in #355.

**What this PR does not do**

- Tool calls (future PR)
- Adding the llama-guard model
- Investigating whether we can add embeddings

### PR Train

- #609 👈 
- #630


## Test Plan

<details>

<summary>Environment</summary>

```bash
export GROQ_API_KEY=<api_key>

wget https://raw.githubusercontent.com/aidando73/llama-stack/240e6e2a9c20450ffdcfbabd800a6c0291f19288/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/92c9b5297f9eda6a6e901e1adbd894e169dbb278/run.yaml

# Build and run environment
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
  --port 5001
```

</details>

<details>

<summary>Manual tests</summary>

Using this Jupyter notebook to test manually:
https://github.com/aidando73/llama-stack/blob/2140976d76ee7ef46025c862b26ee87585381d2a/hello.ipynb

Use this code to test passing in the API key from provider_data:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5001",
)

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[
        {"role": "user", "content": "Hello, world client!"},
    ],
    # Test passing in groq_api_key from the client
    # Need to comment out the groq_api_key in the run.yaml file
    x_llama_stack_provider_data='{"groq_api_key": "<api-key>"}',
    # stream=True,
)
response
```
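
And a hedged sketch of the streaming variant; the `chunk.event.delta` access is an assumption based on the `ChatCompletionResponseStreamChunk` shape added in this PR:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Hello, world client!"}],
    stream=True,
)

# Each chunk wraps a ChatCompletionResponseEvent carrying a text delta
for chunk in response:
    print(chunk.event.delta, end="", flush=True)
```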

</details>

<details>
<summary>Integration</summary>

`pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k groq`

(run in the same environment)

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-groq] PASSED                 [  6%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-groq] SKIPPED (Other inf...) [ 12%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_3b-groq] SKIPPED [ 18%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-groq] PASSED [ 25%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-groq] SKIPPED (Ot...) [ 31%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-groq] PASSED  [ 37%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-groq] SKIPPED [ 43%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-groq] SKIPPED [ 50%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-groq] PASSED                 [ 56%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-groq] SKIPPED (Other inf...) [ 62%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_8b-groq] SKIPPED [ 68%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-groq] PASSED [ 75%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-groq] SKIPPED (Ot...) [ 81%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-groq] PASSED  [ 87%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-groq] SKIPPED [ 93%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-groq] SKIPPED [100%]

======================================= 6 passed, 10 skipped, 160 deselected, 7 warnings in 2.05s ========================================
```
</details>

<details>
<summary>Unit tests</summary>

`pytest llama_stack/providers/tests/inference/groq/ -v`

```
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_sets_model PASSED            [  5%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_user_message PASSED [ 10%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_system_message PASSED [ 15%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_completion_message PASSED [ 20%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_logprobs PASSED [ 25%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_response_format PASSED [ 30%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_repetition_penalty PASSED [ 35%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_stream PASSED       [ 40%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_n_is_1 PASSED                [ 45%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_if_max_tokens_is_0_then_it_is_not_included PASSED [ 50%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_max_tokens_if_set PASSED [ 55%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_temperature PASSED  [ 60%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_top_p PASSED        [ 65%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_returns_response PASSED [ 70%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_stop_to_end_of_message PASSED [ 75%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_length_to_end_of_message PASSED [ 80%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertStreamChatCompletionResponse::test_returns_stream PASSED [ 85%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_raises_runtime_error_if_config_is_not_groq_config PASSED [ 90%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_returns_groq_adapter PASSED                            [ 95%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqConfig::test_api_key_defaults_to_env_var PASSED                   [100%]

==================================================== 20 passed, 11 warnings in 0.08s =====================================================
```

</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation
- [x] Wrote necessary unit or integration tests.
aidando73 authored Jan 3, 2025
1 parent e3f187f commit e1f42eb
Showing 10 changed files with 692 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
| Groq | Hosted | | :heavy_check_mark: | | | |
| Ollama | Single Node | | :heavy_check_mark: | | | |
| TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
10 changes: 10 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -154,6 +154,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
26 changes: 26 additions & 0 deletions llama_stack/providers/remote/inference/groq/__init__.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.inference import Inference

from .config import GroqConfig


class GroqProviderDataValidator(BaseModel):
groq_api_key: str


async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter

if not isinstance(config, GroqConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")

adapter = GroqInferenceAdapter(config)
return adapter
19 changes: 19 additions & 0 deletions llama_stack/providers/remote/inference/groq/config.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
)
150 changes: 150 additions & 0 deletions llama_stack/providers/remote/inference/groq/groq.py
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
ToolChoice,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
build_model_alias(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_alias(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
_config: GroqConfig

def __init__(self, config: GroqConfig):
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
self._config = config

def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
# Groq doesn't support non-chat completion as of time of writing
raise NotImplementedError()

async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
model_id = self.get_provider_model_id(model_id)
if model_id == "llama-3.2-3b-preview":
warnings.warn(
"Groq only contains a preview version for llama-3.2-3b-instruct. "
"Preview models aren't recommended for production use. "
"They can be discontinued on short notice."
)

request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
)

response = self._get_client().chat.completions.create(**request)

if stream:
return convert_chat_completion_response_stream(response)
else:
return convert_chat_completion_response(response)

async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

def _get_client(self) -> Groq:
if self._config.api_key is not None:
return Groq(api_key=self._config.api_key)
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.groq_api_key:
raise ValueError(
'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
)
return Groq(api_key=provider_data.groq_api_key)
153 changes: 153 additions & 0 deletions llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncGenerator, Literal

from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)

from groq.types.chat.completion_create_params import CompletionCreateParams

from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Message,
StopReason,
)


def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
"""
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
Warns client if request contains unsupported features.
"""

if request.logprobs:
# Groq doesn't support logprobs at the time of writing
warnings.warn("logprobs are not supported yet")

if request.response_format:
# Groq's JSON mode is beta at the time of writing
warnings.warn("response_format is not supported yet")

if request.sampling_params.repetition_penalty != 1.0:
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
# seem to have different semantics
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")

if request.tools:
warnings.warn("tools are not supported yet")

return CompletionCreateParams(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
logprobs=None,
frequency_penalty=None,
stream=request.stream,
max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
)


def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == "system":
return ChatCompletionSystemMessageParam(role="system", content=message.content)
elif message.role == "user":
return ChatCompletionUserMessageParam(role="user", content=message.content)
elif message.role == "assistant":
return ChatCompletionAssistantMessageParam(
role="assistant", content=message.content
)
else:
raise ValueError(f"Invalid message role: {message.role}")


def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)


def _map_finish_reason_to_stop_reason(
finish_reason: Literal["stop", "length", "tool_calls"]
) -> StopReason:
"""
Convert a Groq chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls"]
- stop -> model hit a natural stop point or a provided stop sequence
- length -> maximum number of tokens specified in the request was reached
- tool_calls -> model called a tool
"""
if finish_reason == "stop":
return StopReason.end_of_turn
elif finish_reason == "length":
return StopReason.out_of_tokens
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")


async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:

event_type = ChatCompletionResponseEventType.start
for chunk in stream:
choice = chunk.choices[0]

# We assume there's only one finish_reason for the entire stream.
# We collect the last finish_reason
if choice.finish_reason:
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
logprobs=None,
stop_reason=stop_reason,
)
)
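
For reference (not part of the diff), a minimal hedged sketch of how `convert_chat_completion_request` is exercised; the `UserMessage` import path and the exact defaults in the resulting dict are assumptions:

```python
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
)

request = ChatCompletionRequest(
    model="llama3-8b-8192",
    messages=[UserMessage(content="Hello, world!")],
    sampling_params=SamplingParams(max_tokens=128),
)

# Produces a plain dict ready for Groq's chat.completions.create(**params),
# roughly {"model": "llama3-8b-8192", "messages": [...], "stream": False,
#          "max_tokens": 128, "temperature": ..., "top_p": ...}
params = convert_chat_completion_request(request)
```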