
Added first version of custom model. #437

Open
wants to merge 31 commits into main

Conversation

JoelNiklaus
Contributor

Enables the evaluation of any system in the user's control. Fixes Issue 430.

Try with

python -m lighteval custom google-translate /path/to/google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10

google_translate_model.py

import logging
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):

    def __init__(self, config, env_config) -> None:
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )
        
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility

        import httpcore

        # Workaround for a googletrans bug: newer httpcore versions removed
        # SyncHTTPTransport, which googletrans still references at import time.
        # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618
        setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy")
        from googletrans import Translator
        self.translator = Translator()

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[GreedyUntilRequest]): list of requests containing the context and ending conditions.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                context = r.context.replace("French phrase: ", "")
                # TODO: Get src and dest from request
                translation = self.translator.translate(context, src='fr', dest='de')

                result = translation.text
                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
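The hard-coded `src='fr'`/`dest='de'` pair (flagged by the TODO in `greedy_until`) could instead be derived from the task name. A minimal sketch, assuming the task string follows the `wmt20:fr-de` pattern used in the example command; the helper name is hypothetical, not part of lighteval's API:

```python
# Hypothetical helper: derive (src, dest) language codes from a task name
# like "wmt20:fr-de". Shown only as a sketch for replacing the hard-coded
# src='fr', dest='de' above; not lighteval's actual API.
def parse_language_pair(task_name: str) -> tuple[str, str]:
    pair = task_name.rsplit(":", 1)[-1]   # "wmt20:fr-de" -> "fr-de"
    src, dest = pair.split("-", 1)        # "fr-de" -> ("fr", "de")
    return src, dest
```

With such a helper, `greedy_until` could call `self.translator.translate(context, src=src, dest=dest)` for any language pair, assuming the request carries its task name.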

Member

@clefourrier clefourrier left a comment


Looking very nice, exactly what I had in mind!

No big comments on the main PR, but

  • you could add your model class in examples.
  • you need to update the doc pages to explain how this works
  • it would be good to add a small test to our suite for this feature

I'll try to run it this afternoon and if all goes well and you update the doc, we'll be good to go!
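For the suggested test, one thing it might check is that a model definition file can be loaded from a path and exposes the expected class. A sketch under the assumption that the custom-model loader resolves the file with `importlib`; the helper name and behavior here are illustrative, not lighteval's actual internals:

```python
# Illustrative sketch: load a class from a model definition file by path,
# assuming an importlib-based loader. Names are hypothetical.
import importlib.util
import inspect


def load_class_from_file(path: str, class_name: str):
    spec = importlib.util.spec_from_file_location("custom_model_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    cls = getattr(module, class_name, None)
    if not inspect.isclass(cls):
        raise TypeError(f"{class_name} is not a class in {path}")
    return cls
```

A test could then write a small model file to a temporary directory, load it with this helper, and assert the returned object is the expected class.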

@clefourrier
Member

clefourrier commented Dec 12, 2024

Hahaha, please also provide an explicit requirements file :)

@JoelNiklaus
Contributor Author

The explicit requirement file is only needed for the google translate example, right? Where should I add that?

@clefourrier
Member

google_translate_model_requirements.txt for now, next to the py file

@JoelNiklaus
Contributor Author

Great, I fixed those things. @clefourrier, ready for review again.

@NathanHB
Member

NathanHB commented Dec 16, 2024

Hi @JoelNiklaus! Great PR; however, I just tried it and it does not seem to work.

When running:

lighteval custom google-translate google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10

I get:

AttributeError: 'NoneType' object has no attribute 'group'
    raised at line 79: translation = self.translator.translate(context, src='fr', dest='de')
    locals:
        _ = (0, 10)
        context = '"J\'aimerais faire du parlement européen une instance plus démocratique, plus ouv' (+296 chars)

deps:

httpx==0.28.1
googletrans==3.0.0

@JoelNiklaus
Contributor Author

Hmm, would you mind trying an environment with the requirements in examples/custom_models/google-translate-requirements-freeze.txt?

@JoelNiklaus
Contributor Author

JoelNiklaus commented Dec 17, 2024

@NathanHB I added another custom model example at examples/custom_models/local_mt_model.py


Successfully merging this pull request may close these issues.

[FT] Enable the evaluation of any function
3 participants