
Commit 8d84ffa
Upgrade to SynapseAI 1.18 (#227)
Signed-off-by: yuanwu <[email protected]>
Co-authored-by: Thanaji Rao Thakkalapelli <[email protected]>
yuanwu2017 and tthakkal authored Oct 31, 2024
1 parent 7fb4af9 commit 8d84ffa
Showing 5 changed files with 1,402 additions and 1,267 deletions.
Dockerfile: 4 changes (2 additions & 2 deletions)
@@ -32,7 +32,7 @@ COPY launcher launcher
 RUN cargo build --release
 
 # Text Generation Inference base image
-FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base
+FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest as base
 
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -61,7 +61,7 @@ RUN cd server && \
     make gen-server && \
     pip install -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.17.0 && \
+    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.18.0 && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
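
The net effect of this Dockerfile change is that the image now builds on the Gaudi 1.18.0 / PyTorch 2.4.0 base and installs the Habana DeepSpeed fork from its 1.18.0 tag. A minimal, hypothetical sanity check one could run inside the rebuilt container (exact version strings are not asserted here):

# Hypothetical post-build check inside the rebuilt container: the imports and
# version attributes below exist in any DeepSpeed/PyTorch install; the exact
# strings printed depend on the 1.18.0 base image and fork tag.
import deepspeed
import torch

print("torch:", torch.__version__)          # expected to report a 2.4.x build
print("deepspeed:", deepspeed.__version__)  # Habana fork installed from tag 1.18.0
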
examples/run_generation.py: 11 changes (8 additions & 3 deletions)
@@ -31,13 +31,18 @@ def get_args():
     parser.add_argument(
         "--max_concurrent_requests", type=int, default=256, help="Max number of concurrent requests"
     )
+    parser.add_argument(
+        "--seed", type=int, default=42, help="Random seed for datasets"
+    )
 
     return parser.parse_args()
 
 
 def read_dataset(
     max_input_length: int,
     total_sample_count: int,
-    model_id: str
+    model_id: str,
+    seed: int,
 ) -> List[str]:
     """
     Loads public dataset from HF: https://huggingface.co/datasets/DIBT/10k_prompts_ranked
@@ -51,7 +56,7 @@
     )
     if len(dataset) > total_sample_count:
         dataset = dataset.select(range(total_sample_count))
-    dataset = dataset.shuffle()
+    dataset = dataset.shuffle(seed=seed)
     return [sample["prompt"] for sample in dataset]
 
 
@@ -71,7 +76,7 @@ def is_tgi_available(
 def main():
     args = get_args()
     dataset = read_dataset(
-        args.max_input_length, args.total_sample_count, args.model_id
+        args.max_input_length, args.total_sample_count, args.model_id, args.seed
     )
 
     if not is_tgi_available(args.server_address):
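
For context, the new --seed flag makes the order of the sampled prompts reproducible: shuffle() without a seed reorders differently on every run, while shuffle(seed=...) is deterministic. A minimal sketch of that behaviour, assuming the dataset's "train" split and a bare load_dataset() call (the real read_dataset() also takes max_input_length and model_id and may pass further arguments):

from typing import List

from datasets import load_dataset


def sample_prompts(total_sample_count: int, seed: int) -> List[str]:
    # Public prompts dataset referenced in read_dataset()'s docstring.
    dataset = load_dataset("DIBT/10k_prompts_ranked", split="train")
    if len(dataset) > total_sample_count:
        dataset = dataset.select(range(total_sample_count))
    # An explicit seed makes the shuffled prompt order identical across runs.
    dataset = dataset.shuffle(seed=seed)
    return [sample["prompt"] for sample in dataset]


prompts_a = sample_prompts(total_sample_count=32, seed=42)
prompts_b = sample_prompts(total_sample_count=32, seed=42)
assert prompts_a == prompts_b  # same seed, same prompt order

With the default seed of 42, repeated runs of examples/run_generation.py should now send the same prompt sequence unless --seed is overridden on the command line.
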
(Diffs for the remaining changed files are not shown.)
