Skip to content

Commit

Permalink
Implement TGI model config from path (#448)
Browse files Browse the repository at this point in the history
Implement TGI model config from path:
```python
TGIModelConfig.from_path(model_config_path)
```

Follow-up to:
- #434 

Related to:
- #439
  • Loading branch information
albertvillanova authored Dec 17, 2024
1 parent 0ebc7ec commit 1b9e2c3
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 8 deletions.
9 changes: 1 addition & 8 deletions src/lighteval/main_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def tgi(
"""
Evaluate models using TGI as backend.
"""
import yaml

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.tgi_model import TGIModelConfig
Expand All @@ -332,14 +331,8 @@ def tgi(

# TODO (nathan): better handling of model_args
parallelism_manager = ParallelismManager.TGI
with open(model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]

model_config = TGIModelConfig(
inference_server_address=config["instance"]["inference_server_address"],
inference_server_auth=config["instance"]["inference_server_auth"],
model_id=config["instance"]["model_id"],
)
model_config = TGIModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
Expand Down
8 changes: 8 additions & 0 deletions src/lighteval/models/endpoints/endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def __post_init__(self):

@classmethod
def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
"""Load configuration for inference endpoint model from YAML file path.
Args:
path (`str`): Path of the model configuration YAML file.
Returns:
[`InferenceEndpointModelConfig`]: Configuration for inference endpoint model.
"""
import yaml

with open(path, "r") as f:
Expand Down
16 changes: 16 additions & 0 deletions src/lighteval/models/endpoints/tgi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ class TGIModelConfig:
inference_server_auth: str
model_id: str

@classmethod
def from_path(cls, path: str) -> "TGIModelConfig":
"""Load configuration for TGI endpoint model from YAML file path.
Args:
path (`str`): Path of the model configuration YAML file.
Returns:
[`TGIModelConfig`]: Configuration for TGI endpoint model.
"""
import yaml

with open(path, "r") as f:
config = yaml.safe_load(f)["model"]
return cls(**config["instance"])


# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
# the client functions, since they use a different client.
Expand Down
File renamed without changes.
42 changes: 42 additions & 0 deletions tests/models/endpoints/test_tgi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import asdict

import pytest

from lighteval.models.endpoints.tgi_model import TGIModelConfig


class TestTGIModelConfig:
@pytest.mark.parametrize(
"config_path, expected_config",
[
(
"examples/model_configs/tgi_model.yaml",
{"inference_server_address": "", "inference_server_auth": None, "model_id": None},
),
],
)
def test_from_path(self, config_path, expected_config):
config = TGIModelConfig.from_path(config_path)
assert asdict(config) == expected_config

0 comments on commit 1b9e2c3

Please sign in to comment.