Can't PromptTuning in Multi-GPU with DeepSpeed and Qwen2.5-14B-Instruct #2266

Open

dongshou opened this issue Dec 9, 2024 · 2 comments
dongshou commented Dec 9, 2024

System Info

Name: peft
Version: 0.12.0

Name: transformers
Version: 4.47.0

Name: accelerate
Version: 0.34.2

Python 3.11.9

CUDA
Build cuda_11.8.r11.8/compiler.31833905_0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

1. Prompt tuning

from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model

model_name_or_path = "/workspace/labels/Qwen2-fintune/qwen/Qwen2.5-14B-Instruct"
tokenizer_name_or_path = "/workspace/labels/Qwen2-fintune/qwen/Qwen2.5-14B-Instruct"

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=16,
    prompt_tuning_init_text="a prompt text whose length is more than 16 tokens",
    tokenizer_name_or_path=tokenizer_name_or_path,
)
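
Side note (added for clarity, not part of the failing run): with PromptTuningInit.TEXT, PEFT tokenizes the init text and repeats or truncates the resulting ids to match num_virtual_tokens. A quick way to see how many tokens the placeholder text actually yields, assuming the tokenizer loaded in step 2 below:

# Sketch only: count the tokens of the init text with the Qwen tokenizer.
init_ids = tokenizer(peft_config.prompt_tuning_init_text)["input_ids"]
print(len(init_ids), "tokens in the init text; num_virtual_tokens is", peft_config.num_virtual_tokens)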

2. Dataset

The dataset has been loaded from JSON.

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def preprocess_fn(
    examples,
):
    """Preprocesses the data for supervised fine-tuning."""
    # tokenize input
    goal = "xxxxx"

    texts = []
    for query, response in zip(examples['query'], examples['response']):
        msg = [
            {"role": "user", "content": goal + query},
            {"role": "assistant", "content": response}
        ]
        texts.append(tokenizer.apply_chat_template(
                    msg,
                    chat_template=qwen_chat_template,
                    tokenize=True,
                    add_generation_prompt=False,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                ))
    input_ids = torch.tensor(texts, dtype=torch.int)
    target_ids = input_ids.clone()
    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
    # target_ids[target_ids <= tokenizer.assistant] = IGNORE_TOKEN_ID
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return dict(
        input_ids=input_ids, labels=target_ids, attention_mask=attention_mask
    )


processed_datasets = dataset.map(
    preprocess_fn,
    batched=True,
    num_proc=16,
    remove_columns=dataset.column_names, #remove unprocessed column for training
    load_from_cache_file=False,
    desc="Running tokenizer on dataset"
)
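
Side note (added for clarity, not part of the failing run): since the CUDA assert later is `t >= 0 && t < n_classes`, it is easy to verify up front that the processed labels only contain -100 or valid token ids. A sketch, assuming processed_datasets is a single split with the labels column produced above, and using len(tokenizer) as a rough proxy for the model vocab size:

import torch

# Sketch only: check that every label is either -100 (ignore index) or a plausible token id.
labels = torch.tensor(processed_datasets["labels"])
valid = (labels == -100) | ((labels >= 0) & (labels < len(tokenizer)))
print("all labels valid:", bool(valid.all()))
print("label range:", labels.min().item(), labels.max().item())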

3. Model

import torch
import transformers
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=transformers.AutoConfig.from_pretrained(model_name_or_path),
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map='balanced',
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

4. Trainer

from transformers import Trainer, TrainingArguments, default_data_collator

trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=default_data_collator,
    args=TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        lr_scheduler_type='cosine',
        per_device_eval_batch_size=batch_size,
        deepspeed='deepspeed/ds_z2_config.json',
        load_best_model_at_end=False,
        logging_strategy='steps',
        logging_steps=10,
        evaluation_strategy='steps',
        eval_steps=1000,
        save_strategy='steps',
        save_steps=10,
    ),
)

trainer.train()

5. DeepSpeed config JSON

{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "round_robin_gradients": true
  }
}

6. Debug info

When the labels are moved to another CUDA device, their values change!
The loss code below is from transformers/loss/loss_utils.py (the two print statements were added for debugging):

def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    print("label before move",shift_labels.min(),shift_labels.max(),shift_labels.shape)
    shift_labels = shift_labels.to(shift_logits.device)
    print("label after move",shift_labels.min(),shift_labels.max(),shift_labels.shape)

    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
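
To narrow this down outside the Trainer, here is a minimal standalone sketch (not part of the training script, and it needs at least 4 visible GPUs) that copies an int64 labels tensor between the same two devices the log shows (cuda:0 to cuda:3):

import torch

# Sketch only: if this copy alone already corrupts the values, the problem is
# below PEFT/Trainer (driver, NCCL/P2P, or hardware), not in the modelling code.
src = torch.full((8252,), -100, dtype=torch.long, device="cuda:0")
src[::2] = 151645  # mimic the label range seen in the log
dst = src.to("cuda:3")
torch.cuda.synchronize()
print("before:", src.min().item(), src.max().item())
print("after: ", dst.min().item(), dst.max().item())  # expected to match "before"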

6.1 Log and error

peft label dtail tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
peft label dtail 2 tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
### label before move tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([8252])
###label after move tensor(0, device='cuda:3') tensor(0, device='cuda:3') torch.Size([8252])
  0%|                                                                                                                          | 1/47500 [00:16<218:02:43, 16.53s/it]peft label dtail tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
peft label dtail 2 tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
### label before move tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([8252])
### label after move tensor(-9223372034707292160, device='cuda:3') tensor(0, device='cuda:3') torch.Size([8252])
Traceback (most recent call last):
  File "/workspace/llm-tuning/prompt_tuning_qa.py", line 173, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2522, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3688, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 2196, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 27, in backward
    grad_input = torch.zeros(
                 ^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.

Expected behavior

Prompt tuning should work on multiple GPUs.

hiyamgh commented Dec 10, 2024

Same error here

Code

tokenizer = AutoTokenizer.from_pretrained(f"{model_checkpoint}", add_prefix_space=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=[col for col in dataset["train"].column_names if col not in labels])
# example = encoded_dataset['train'][0]
# print(example.keys())
# tokenizer.decode(example['input_ids'])

encoded_dataset.set_format("torch")

dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ['LOCAL_RANK'])  # Automatically set by SLURM or torchrun
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(local_rank)  # Set the GPU for this process

# Step 1: Initialize the model with empty weights
llm_model = AutoModelForSequenceClassification.from_pretrained(
    f"{model_checkpoint}",
    problem_type="multi_label_classification",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    trust_remote_code=True,
    device_map="auto",  # Distributes the model across GPUs
)

# Step 3: Apply LoRA configuration
llm_peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules="all-linear",
)

# Ensure model weights are fully loaded before applying LoRA
llm_model = get_peft_model(llm_model, llm_peft_config)
llm_model.print_trainable_parameters()

batch_size = 8
metric_name = "f1"

llm_model = llm_model.to(device)

model_checkpoint_saving_name = model_checkpoint.replace("/", "-")
args = TrainingArguments(
    f"{model_checkpoint_saving_name}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    fp16=True,
    # push_to_hub=True,
    local_rank=int(os.environ['LOCAL_RANK']),  # Set local rank for multi-GPU
)

t1 = time.time()
trainer = Trainer(
    llm_model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()
t2 = time.time()
print(f'Training time took: {(t2-t1)/60:.3f} mins')

trainer.evaluate()

BenjaminBossan (Member) commented

@dongshou Thanks for reporting the error.

peft label dtail tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
peft label dtail 2 tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
### label before move tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([8252])
###label after move tensor(0, device='cuda:3') tensor(0, device='cuda:3') torch.Size([8252])
  0%|                                                                                                                          | 1/47500 [00:16<218:02:43, 16.53s/it]peft label dtail tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
peft label dtail 2 tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([4, 2048])
### label before move tensor(-100, device='cuda:0') tensor(151645, device='cuda:0') torch.Size([8252])
### label after move tensor(-9223372034707292160, device='cuda:3') tensor(0, device='cuda:3') torch.Size([8252])

It is strange that the labels are changed; the last value looks like an overflow error. PEFT does not touch the data at all, so I'm fairly certain that this is not a PEFT issue. Could you please try full fine-tuning with the same setup and check if the same error occurs? If you don't have enough memory for full fine-tuning, please test a smaller Qwen model and/or reduce the batch size, but leave the data the same. Please report back if you still get the same type of error with full fine-tuning.
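
For reference, a minimal sketch of that check (assuming the same train_data, val_data, output_dir, batch_size, and DeepSpeed config from the report above; the model name is only a placeholder for a smaller Qwen checkpoint):

import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator

# Sketch only: same data and Trainer setup, but no get_peft_model(), i.e. full fine-tuning.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder for a smaller Qwen model
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=default_data_collator,
    args=TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        num_train_epochs=1,
        deepspeed="deepspeed/ds_z2_config.json",
    ),
)
trainer.train()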

@hiyamgh Do you get the exact same error? Please paste the full error message. Also, are you using FSDP?
