[NEW] Llama3.2 weight converters 🦙 #255

Open · wants to merge 1 commit into main
Conversation

@TJ-Solergibert (Contributor) commented Nov 28, 2024

Hi!

In this branch I reintroduce the Llama model & conversion scripts, updated to the current main branch, to support Llama 3.1 and the Llama 3.2 1B & 3B models.
The main changes are the following:

  1. Deleted the Flash Attention RoPE kernels, as they don't support Llama 3's RoPE scaling (see the scaling sketch after this list).
  2. Copied and cleaned up transformers' LlamaRotaryEmbedding layer; it is now the only RoPE class in llama.py. I think it shouldn't break generation for the inference case WITHOUT LlamaConfig.rope_interleaved = True in CausalSelfAttention.forward. Are there any tests?
  3. Added @eliebak's fixes (resuming a checkpoint without LR schedule or optimizer state, #253) to load only the weights from a checkpoint. I would suggest adding a config.optimizer.finetuning flag that either (True) loads just the weights or (False) loads weights, optimizer & LR scheduler, instead of the separate config.checkpoints.load_optimizer & config.checkpoints.load_lr_scheduler flags (a sketch also follows this list).
  4. Switched from flash_attn_varlen_func to flash_attn_func, as the latter achieves greater performance. Keep in mind that we aren't using any of the varlen features, so it's recommended to stick with flash_attn_func.
  5. Should we deprecate LlamaConfig.rope_interleaved? It was useful for training with the Flash Attention RoPE kernels, and now it also seems to be used in the inference code. IMO we should unify all 3 cases (training, inference with rope_interleaved & inference without it) within a single RoPE implementation.
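
Regarding points 1 and 2, here is a minimal sketch of the Llama 3.x RoPE frequency scaling that the Flash Attention rotary kernels don't implement. It follows the logic of transformers' llama3 rope-scaling computation; the parameter values are not hard-coded because they differ between Llama 3.1 and 3.2 and should be read from the checkpoint's rope_scaling config:

```python
import math
import torch

def llama3_scaled_inv_freq(
    inv_freq: torch.Tensor,   # base RoPE inverse frequencies, shape [head_dim // 2]
    factor: float,            # rope_scaling["factor"], e.g. 8.0 for Llama 3.1
    low_freq_factor: float,   # rope_scaling["low_freq_factor"], typically 1.0
    high_freq_factor: float,  # rope_scaling["high_freq_factor"], typically 4.0
    old_context_len: int,     # rope_scaling["original_max_position_embeddings"], typically 8192
) -> torch.Tensor:
    """High frequencies are kept, low frequencies are divided by `factor`,
    and the band in between is smoothly interpolated."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) get scaled down by `factor`.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium wavelengths are linearly blended between scaled and original values.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)
```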
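
And for point 3, a generic PyTorch sketch of how the proposed config.optimizer.finetuning flag could gate what gets restored. This is illustrative only (the function name and checkpoint keys are assumptions), not Nanotron's actual checkpointing code:

```python
import torch

def load_from_checkpoint(model, optimizer, lr_scheduler, path: str, finetuning: bool) -> None:
    # Assumes a single-file checkpoint with "model"/"optimizer"/"lr_scheduler" entries.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])             # always restore the weights
    if not finetuning:                                 # full resume: also restore training state
        optimizer.load_state_dict(state["optimizer"])
        lr_scheduler.load_state_dict(state["lr_scheduler"])
```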

Results

You can run the conversion & generation tests using the scripts in tools/converters. As I already mentioned in the previous PR (#174), although we need at least 1 GPU (to init the ParallelContext), the conversion itself runs on the CPU.

  1. Generate HF predictions with the HF Hub model:
torchrun --nproc-per-node 1 tools/converters/delete/generate_hf_predictions.py --pretrained-model-name-or-path meta-llama/Llama-3.2-3B
  2. Convert the HF model to Nanotron:
torchrun --nproc-per-node 1 tools/converters/convert_hf_to_nanotron.py --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B --pretrained-model-name-or-path meta-llama/Llama-3.2-3B
  3. Generate predictions with the converted Nanotron model:
torchrun --nproc-per-node 1 tools/converters/delete/generate_nanotron_predictions.py --tp 1 --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B
  4. Convert the model back to HF:
torchrun --nproc-per-node 1 tools/converters/convert_nanotron_to_hf.py --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B --hugging-face-checkpoint-path checkpoints/huggingface_converted/Converted-Nanotron-Llama-3.2-3B
  5. Generate predictions with the converted-back HF model:
torchrun --nproc-per-node 1 tools/converters/delete/generate_hf_predictions.py --pretrained-model-name-or-path checkpoints/huggingface_converted/Converted-Nanotron-Llama-3.2-3B

As can be seen in the following table, we observe slight differences between the two backends. Those differences come from the QKV projections in the CausalSelfAttention layer (Nanotron computes them in a single GEMM vs. 3 separate GEMMs in HF) and from the LayerNorm layer (Nanotron uses an optimized kernel from FlashAttention vs. the basic PyTorch LayerNorm in HF). Also note that the differences grow with TP, which is expected: the GEMM sizes change, so different GEMM algorithms get selected.

| Experiment | Backend | Size (B) | TP | Accuracy |
|---|---|---|---|---|
| OG HF | HF | 3 | 1 | 0.73046875 |
| OG HF --> Nanotron | Nanotron | 3 | 1 | 0.7265625 |
| OG HF --> Nanotron --> HF | HF | 3 | 1 | 0.73046875 |
| OG HF --> Nanotron | Nanotron | 3 | 2 | 0.703125 |
| OG HF --> Nanotron | Nanotron | 3 | 4 | 0.65234375 |
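
To make the fused-vs-split QKV point concrete, here is a small standalone snippet (illustrative shapes, not the actual model code) showing that one fused GEMM and three separate GEMMs are mathematically identical but can differ slightly in low precision, since the different GEMM shapes may select different kernels and accumulation orders:

```python
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
d = 3072  # hidden size of Llama-3.2-3B

x = torch.randn(8, d, dtype=torch.bfloat16, device=device)
wq, wk, wv = (torch.randn(d, d, dtype=torch.bfloat16, device=device) for _ in range(3))

# HF-style: separate projection (one of three GEMMs)
q_split = x @ wq.T

# Nanotron-style: single GEMM over the concatenated QKV weight, then slice out Q
w_qkv = torch.cat([wq, wk, wv], dim=0)  # [3*d, d]
q_fused = (x @ w_qkv.T)[:, :d]

# Any nonzero value here comes purely from kernel / accumulation-order differences.
print((q_split.float() - q_fused.float()).abs().max())
```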

To run the Nanotron generations with different TP sizes:

torchrun --nproc-per-node 2 tools/converters/delete/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B
torchrun --nproc-per-node 4 tools/converters/delete/generate_nanotron_predictions.py --tp 4 --nanotron-checkpoint-path checkpoints/nanotron_pretrained_checkpoints/Nanotron-Llama-3.2-3B

TODO (Preferably in other PRs):

  • Add docs
  • Edit data_stages.data.dataset.dataset_folder in examples/config_llama3.2-3B.yaml
  • Remove nanotron/tools/converters/delete/generate_hf_predictions.py & nanotron/tools/converters/delete/generate_nanotron_predictions.py scripts
  • Integrate Liger kernels for apply_rotary_pos_emb
  • Switch to SDPA instead of FA2
  • Reduce number of transposes in CausalSelfAttention.forward

@Lauler commented Dec 15, 2024

Have you managed to train with tp=4 after converting Llama 3.2 from HF to Nanotron? In your earlier Llama3 PR you wrote that the conversion could be done with tp=dp=pp=1 and that it was TP-agnostic.

When using your conversion script above for the Llama 3.2 3B model, everything works fine with tp=2, but training with tp=4 runs into a tensor size mismatch:

[3236:3]:RuntimeError: The expanded size of the tensor (32128) must match the existing size (31872) at non-singleton dimension 0.  Target sizes: [32128, 3072].  Tensor sizes: [31872, 3072]

traceback.txt

(using your Llama 3.2 YAML config in examples/ as a template for starting continued pretraining)

@NouamaneTazi (Member) left a comment

Very nice PR @TJ-Solergibert! Thanks
Added some small questions before merging.

@@ -0,0 +1,73 @@
"""
Member:
is this supposed to be pushed? 👀


# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
- attn_output = flash_attn_varlen_func(
+ attn_output = flash_attn_func(
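
For reference, a rough sketch of the two call patterns involved in this change, assuming the flash-attn 2.x Python API (requires a CUDA GPU and fp16/bf16 tensors; shapes and values are illustrative):

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

batch, seqlen, n_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Batched path: padded [batch, seqlen, n_heads, head_dim] tensors, causal mask in the kernel.
out_batched = flash_attn_func(q, k, v, causal=True)

# Varlen path: sequences packed into [total_tokens, n_heads, head_dim] plus cumulative
# sequence lengths. Only useful for ragged batches, which this code path never produces.
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")
out_packed = flash_attn_varlen_func(
    q.reshape(-1, n_heads, head_dim),
    k.reshape(-1, n_heads, head_dim),
    v.reshape(-1, n_heads, head_dim),
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen,
    max_seqlen_k=seqlen,
    causal=True,
)
```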
Member:
yes this is faster but only for causal masks. How do you deal with kv cache in inference? Are generations the same with and without use_kv_cache?
