Potential bug in gradient accumulation #495

Open · yxchng opened this issue Dec 26, 2024 · 0 comments
yxchng commented Dec 26, 2024

for epoch_idx in range(args.num_epochs):
    b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt)
    minibatch_idx = 0
    for mini_batch_start in range(
        0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size
    ):
        mini_batch_end = mini_batch_start + args.local_mini_batch_size
        mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
        gradient_accumulation_idx = 0
        for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
            mb_advantage = advantages[micro_batch_inds]
            mb_responses = responses[micro_batch_inds]
            mb_query_responses = query_responses[micro_batch_inds]
            mb_logprobs = logprobs[micro_batch_inds]
            mb_return = returns[micro_batch_inds]
            mb_values = values[micro_batch_inds]
            mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds]

            # Clipped value-function loss.
            vpred_temp = get_reward(
                self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length
            )
            vpred_temp = vpred_temp[0]
            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
            vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0)
            vpredclipped = torch.clamp(
                vpred,
                mb_values - args.cliprange_value,
                mb_values + args.cliprange_value,
            )
            vf_losses1 = torch.square(vpred - mb_return)
            vf_losses2 = torch.square(vpredclipped - mb_return)
            vf_loss_max = torch.max(vf_losses1, vf_losses2)
            vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1)
            self.value_model.backward(vf_loss * args.vf_coef)
            self.value_model.step()  # NOTE: called once per micro batch

            # Clipped PPO policy loss.
            new_logprobs = self.forward(
                mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature
            )
            new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB)
            logprobs_diff = new_logprobs - mb_logprobs
            ratio = torch.exp(logprobs_diff)
            pg_losses = -mb_advantage * ratio
            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
            pg_loss_max = torch.max(pg_losses, pg_losses2)
            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
            loss = pg_loss
            self.model.backward(loss)
            self.model.step()  # NOTE: also called once per micro batch

            # Bookkeeping for logging.
            with torch.no_grad():
                vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1)
                pg_clipfrac = masked_mean(
                    (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                )
                approxkl = 0.5 * (logprobs_diff**2).mean()
                approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
                pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
                ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
            gradient_accumulation_idx += 1
        minibatch_idx += 1
        # fmt: off
        del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1
        del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss
        # fmt: on
        # del everything and empty cache
        torch.cuda.empty_cache()
del b_inds, mini_batch_inds
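
For context (my reading of the snippet, not confirmed in it): the backward()/step() calls on self.model and self.value_model look like the DeepSpeed engine API rather than plain torch.optim. Under DeepSpeed's documented training loop, both calls are made once per micro batch, and the engine itself decides whether a given step() call actually applies an optimizer update, based on gradient_accumulation_steps in its config. A minimal sketch of that contract, using a toy model and synthetic data, and assuming the script is launched with the deepspeed launcher so distributed init succeeds:

import torch
import deepspeed

model = torch.nn.Linear(8, 1)  # toy stand-in for the real policy
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,  # micro batches per optimizer update
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
}
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

for _ in range(8):
    x = torch.randn(4, 8, device=engine.device)
    y = torch.randn(4, 1, device=engine.device)
    loss = torch.nn.functional.mse_loss(engine(x), y)
    engine.backward(loss)  # engine scales the loss across accumulation steps
    engine.step()          # applies an optimizer update only on every
                           # gradient_accumulation_steps-th call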

Currently, this code calls backward() and step() on both value_model and model inside the innermost loop, i.e. once per micro batch of every local mini batch, rather than once per mini batch after the gradients of all its micro batches have been accumulated.
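
If step() did apply an optimizer update on every call, as a plain torch.optim optimizer would, the gradients of the micro batches within a mini batch would never actually be accumulated. For comparison, here is a minimal sketch of conventional gradient accumulation with a plain optimizer; all names and sizes below are placeholders for illustration, not taken from the trainer:

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

local_mini_batch_size = 16
per_device_train_batch_size = 4
accum_steps = local_mini_batch_size // per_device_train_batch_size

data = torch.randn(local_mini_batch_size, 8)    # synthetic mini batch
target = torch.randn(local_mini_batch_size, 1)

optimizer.zero_grad()
for start in range(0, local_mini_batch_size, per_device_train_batch_size):
    end = start + per_device_train_batch_size
    loss = torch.nn.functional.mse_loss(model(data[start:end]), target[start:end])
    # Scale so the accumulated gradient equals the gradient of the mean loss
    # over the whole mini batch.
    (loss / accum_steps).backward()             # gradients accumulate in .grad
optimizer.step()                                # one update per mini batch
optimizer.zero_grad()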
