Skip to content

Commit

Permalink
Add Claude 2 support (#231)
Browse files Browse the repository at this point in the history
* Add Claude 2 support.

* Add Claude-2 support.
  • Loading branch information
rmitsch authored Jul 25, 2023
1 parent a3ee82e commit 40248a1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
8 changes: 3 additions & 5 deletions spacy_llm/models/rest/anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import srsly # type: ignore[import]
from requests import HTTPError

from ....compat import Literal
from ..base import REST


Expand All @@ -25,10 +24,6 @@ class SystemPrompt(str, Enum):


class Anthropic(REST):
MODEL_NAMES = {
"claude-1": Literal["claude-1", "claude-1-100k"],
}

@property
def credentials(self) -> Dict[str, str]:
# Fetch and check the key, set up headers
Expand Down Expand Up @@ -115,6 +110,9 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return (
# claude-1
"claude-2",
"claude-2-100k",
# claude-1
"claude-1",
"claude-1-100k",
Expand Down
34 changes: 34 additions & 0 deletions spacy_llm/models/rest/anthropic/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,40 @@
from .model import Anthropic, Endpoints


@registry.llm_models("spacy.Claude-2.v1")
def anthropic_claude_2(
config: Dict[Any, Any] = SimpleFrozenDict(),
name: Literal["claude-2", "claude-2-100k"] = "claude-2", # noqa: F722
strict: bool = Anthropic.DEFAULT_STRICT,
max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
interval: float = Anthropic.DEFAULT_INTERVAL,
max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[str]], Iterable[str]]:
"""Returns Anthropic instance for 'claude-2' model using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (Literal["claude-2", "claude-2-100k"]): Model to use.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to
prompt API.
"""
return Anthropic(
name=name,
endpoint=Endpoints.COMPLETIONS,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
)


@registry.llm_models("spacy.Claude-1.v1")
def anthropic_claude_1(
config: Dict[Any, Any] = SimpleFrozenDict(),
Expand Down

0 comments on commit 40248a1

Please sign in to comment.