-
Notifications
You must be signed in to change notification settings - Fork 53
/
config.py
105 lines (89 loc) · 4.26 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, Optional
from ...system_utils import is_rocm_system
from ..config import ScenarioConfig
# Module-level logger registered under the "inference" name.
LOGGER = getLogger("inference")

# Fallback input shapes; merged underneath any user-supplied `input_shapes`
# so that keys missing from the user's mapping still get a value.
INPUT_SHAPES = {"batch_size": 2}
@dataclass
class InferenceConfig(ScenarioConfig):
    """Configuration of the inference benchmark scenario.

    Holds the benchmark stopping criteria (`iterations`, `duration`), the
    number of warmup runs, the input shapes, which metrics to track, and the
    keyword arguments handed to the backend's forward/generate/call methods.
    Values are normalized and validated in `__post_init__`.
    """

    name: str = "inference"
    # Dotted path of the scenario class associated with this config
    # (presumably resolved by the config framework -- confirm against caller).
    _target_: str = "optimum_benchmark.scenarios.inference.scenario.InferenceScenario"

    # benchmark options
    iterations: int = field(
        default=10,
        metadata={
            "help": "Minimum number of iterations to run the benchmark. "
            "The number of tracked inferences will be at least this value. "
            "Set to 0 to disable this constraint (benchmark will run for `duration` seconds)."
        },
    )
    duration: int = field(
        default=10,
        metadata={
            "help": "Minimum duration of the benchmark in seconds. "
            "The sum of tracked inferences will be at least this value. "
            "Set to 0 to disable this constraint (benchmark will run for `iterations` iterations)."
        },
    )
    warmup_runs: int = field(
        default=10,
        metadata={"help": "Number of warmup runs to perform before benchmarking."},
    )

    # input/output config
    input_shapes: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Input shapes for the model. Missing keys will be filled with default values."},
    )
    new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "If set, `max_new_tokens` and `min_new_tokens` will be set to this value."},
    )

    # tracking options
    memory: bool = field(default=False, metadata={"help": "Measure max memory usage"})
    latency: bool = field(default=True, metadata={"help": "Measure latencies and throughputs"})
    energy: bool = field(default=False, metadata={"help": "Measure energy usage and efficiency"})

    # methods kwargs
    forward_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the forward method of the backend."}
    )
    generate_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the generate method of the backend."}
    )
    call_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the call method of the backend."}
    )

    def __post_init__(self):
        """Fill in defaults and validate the configuration after construction.

        Steps, in order:
          1. Run the parent class's `__post_init__`.
          2. Merge `INPUT_SHAPES` underneath the user-provided `input_shapes`.
          3. If the deprecated `new_tokens` is set, copy it into both
             `max_new_tokens` and `min_new_tokens` of `generate_kwargs`.
          4. Make sure `max_new_tokens` and `min_new_tokens` agree, filling in
             whichever one is missing.
          5. Reject energy tracking on ROCm systems.

        Raises:
            ValueError: if `max_new_tokens` and `min_new_tokens` are both set
                to different values, or if `energy` is enabled on a ROCm
                system.
        """
        super().__post_init__()

        # Later keys win in a dict merge, so user-provided shapes override the defaults.
        self.input_shapes = {**INPUT_SHAPES, **self.input_shapes}

        if self.new_tokens is not None:
            LOGGER.warning(
                "`new_tokens` is deprecated. Use `max_new_tokens` and `min_new_tokens` instead. "
                "Setting `max_new_tokens` and `min_new_tokens` to `new_tokens`."
            )
            # Overwrites any value already present under these two keys.
            self.generate_kwargs["max_new_tokens"] = self.new_tokens
            self.generate_kwargs["min_new_tokens"] = self.new_tokens

        if (
            "max_new_tokens" in self.generate_kwargs
            and "min_new_tokens" in self.generate_kwargs
            and self.generate_kwargs["max_new_tokens"] != self.generate_kwargs["min_new_tokens"]
        ):
            raise ValueError(
                "Setting `min_new_tokens` and `max_new_tokens` to different values results in non-deterministic behavior."
            )
        elif "max_new_tokens" in self.generate_kwargs and "min_new_tokens" not in self.generate_kwargs:
            LOGGER.warning(
                "Setting `max_new_tokens` without `min_new_tokens` results in non-deterministic behavior. "
                "Setting `min_new_tokens` to `max_new_tokens`."
            )
            self.generate_kwargs["min_new_tokens"] = self.generate_kwargs["max_new_tokens"]
        elif "min_new_tokens" in self.generate_kwargs and "max_new_tokens" not in self.generate_kwargs:
            LOGGER.warning(
                "Setting `min_new_tokens` without `max_new_tokens` results in non-deterministic behavior. "
                "Setting `max_new_tokens` to `min_new_tokens`."
            )
            self.generate_kwargs["max_new_tokens"] = self.generate_kwargs["min_new_tokens"]

        if self.energy and is_rocm_system():
            raise ValueError("Energy measurement through codecarbon is not yet available on ROCm-powered devices.")