
Release v4.44.0

Released by @ArthurZucker on 06 Aug 18:39

Release v4.44.0: End to end compile generation!!! Gemma2 (with assisted decoding), Codestral (Mistral for code), Nemotron, Efficient SFT training, CPU Offloaded KVCache, torch export for static cache

This release comes a bit early in our cycle because we wanted to ship important and requested models along with improved performance for everyone!

All of these come with examples in the awesome https://github.com/huggingface/local-gemma repository! 🎈 It showcases what is now possible with the shipped features. Kudos to @gante, @sanchit-gandhi and @xenova

💥 End-to-end generation compile

Generate: end-to-end compilation #30788 by @gante: model.generate now supports compiling! There are a few limitations, but here is a small snippet:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import copy

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# compile generate
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
generation_config = copy.deepcopy(model.generation_config)
generation_config.pad_token_id = model.config.eos_token_id

model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
model_inputs = model_inputs.to(model.device)
output_compiled = compiled_generate(**model_inputs, generation_config=generation_config)
print(output_compiled)

⚡ 3 to 5x compile speedup (compilation time 👀 not runtime)

  • 3-5x faster torch.compile forward compilation for autoregressive decoder models #32227 by @fxmarty.
    As documented in the PR, this makes the whole generation a lot faster when you re-use the cache!
    You can see the speedup when you run model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True), as in the sketch below.
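
A minimal sketch of the recipe this speeds up (the checkpoint and prompt are arbitrary choices for illustration): compile the forward pass once, then let generate re-use the compiled graph together with a static cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B"  # any autoregressive decoder checkpoint works the same way
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# compiling only the forward pass is what got 3-5x faster in this release
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# a static cache keeps tensor shapes stable, so the compiled graph is re-used across decoding steps
inputs = tokenizer("The theory of relativity states", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))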

🪶 Offloaded KV cache: offload the cache to CPU when you are GPU poooooor 🚀

  • Offloaded KV Cache #31325 by @n17s: you just have to set cache_implementation="offloaded" when calling generate, or pass it through a GenerationConfig:

from transformers import GenerationConfig

gen_config = GenerationConfig(
    cache_implementation="offloaded",
    # other generation options work as usual, e.g. num_beams=4, num_beam_groups=2,
    # num_return_sequences=4, diversity_penalty=1.0, early_stopping=True
    max_new_tokens=50,
)
outputs = model.generate(inputs["input_ids"], generation_config=gen_config)
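
Equivalently (a minimal sketch, re-using the model and inputs from the snippet above), the cache implementation can be passed straight to generate:

# no GenerationConfig needed: generate() accepts cache_implementation directly
outputs = model.generate(inputs["input_ids"], cache_implementation="offloaded", max_new_tokens=50)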

📦 Torch export for static cache

The PyTorch team gave us a great gift: you can now use torch.export on models with a static cache, directly compatible with ExecuTorch! Find examples here.

This also unlocks support for prompt reuse:

import os, torch, copy
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
device = "cuda"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"

INITIAL_PROMPT = "From now on, you are going to answer all my questions with historical details. Make sure to always add a bit of french here and there, for style."

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt_cache = DynamicCache()
inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
prompt_cache = model(**inputs, past_key_values=prompt_cache).past_key_values

prompt = "Why are french people obsessed with french?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)

prompt = "What is the best city to swim in?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)

Gemma2: assisted decoding

Gemma 2: support assisted generation #32357 by @gante

We now have a 2B Gemma 2 model, a perfect sidekick for the 27B with assisted generation. We've enabled assisted generation in Gemma 2, with a caveat: assisted generation currently requires the use of a windowless cache (as opposed to the default cache for Gemma 2), so you might observe some output mismatch on long sequences. Read more about it here.

# transformers assisted generation reference: 
# https://huggingface.co/docs/transformers/main/en/llm_optims#speculative-decoding 
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# we DON’T recommend using the 9b model with the 2b model as its assistant
assistant_model_name = 'google/gemma-2-2b-it'
reference_model_name = 'google/gemma-2-27b-it'

tokenizer = AutoTokenizer.from_pretrained(reference_model_name)
model = AutoModelForCausalLM.from_pretrained(
   reference_model_name, device_map='auto', torch_dtype=torch.bfloat16
)
assistant_model = AutoModelForCausalLM.from_pretrained(
   assistant_model_name, device_map='auto', torch_dtype=torch.bfloat16
)

model_inputs = tokenizer("Einstein's theory of relativity states", return_tensors="pt").to(model.device)
generation_options = {
   "assistant_model": assistant_model,
   "do_sample": True,
   "temperature": 0.7,
   "max_new_tokens": 64,
}

outputs = model.generate(**model_inputs, **generation_options)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

Nemotron support


Nemotron-4-340B-Instruct is a large language model (LLM) that can be used as part of a synthetic data generation pipeline to create training data that helps researchers and developers build their own LLMs. It is a fine-tuned version of the Nemotron-4-340B-Base model, optimized for English-based single and multi-turn chat use-cases. It supports a context length of 4,096 tokens.

The conversion script should be able to cover Minitron and Nemotron, thanks and kudos to @suiyoubi. See:

  • Add Nemotron HF Support #31699
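
A minimal usage sketch once a checkpoint is in the Hugging Face format (the repo name below is an assumption for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "nvidia/Minitron-4B-Base"  # assumed HF-format checkpoint from the Nemotron family
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("The Eiffel Tower is located in", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))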

Codestral support


Codestral is trained on a diverse dataset of 80+ programming languages, including the most popular ones, such as Python, Java, C, C++, JavaScript, and Bash. It also performs well on more specific ones like Swift and Fortran. This broad language base ensures Codestral can assist developers in various coding environments and projects.

Codestral saves developers time and effort: it can complete coding functions, write tests, and complete any partial code using a fill-in-the-middle mechanism. Interacting with Codestral will help level up the developer’s coding game and reduce the risk of errors and bugs.

It uses the Mamba2 architecture; it was a bit of a pain to remove all the einops, but we hope we made it better for everyone!
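
A minimal sketch with the Mamba2 classes added in this release (the checkpoint name is an assumption for illustration; the weights on the Hub may need to be fetched from a transformers-format revision):

from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = "mistralai/Mamba-Codestral-7B-v0.1"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Mamba2ForCausalLM.from_pretrained(model_id)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0]))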

Breaking changes:

We removed the chat templates that were hardcoded in the library code; they should all live on the Hub!
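
In practice, this means relying on the template stored alongside the checkpoint on the Hub. A minimal sketch (with an arbitrary chat checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")  # any chat checkpoint whose template lives on the Hub
messages = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)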

Long-form decoding for Whisper, even faster:

Our great @sanchit-gandhi worked on porting the recent compile upgrades to long-form decoding in the PR below:

  • [whisper] compile compatibility with long-form decoding #31772
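
A minimal sketch of what this enables (the checkpoint and the dummy audio are assumptions for illustration): compile the forward pass with a static cache and run long-form (longer than 30 s) transcription as usual.

import numpy as np
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

model_id = "openai/whisper-large-v3"  # assumed checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# static cache + compiled forward, re-used across the 30 s chunks of long-form decoding
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

audio = np.random.randn(16_000 * 60).astype(np.float32)  # 60 s of dummy audio at 16 kHz
inputs = processor(
    audio, sampling_rate=16_000, return_tensors="pt",
    truncation=False, padding="longest", return_attention_mask=True,
)
inputs = inputs.to("cuda", torch.float16)

outputs = model.generate(**inputs, return_timestamps=True)
print(processor.batch_decode(outputs, skip_special_tokens=True))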

What's Changed

New Contributors

Full Changelog: v4.43.4...v4.44.0