diff --git a/train.py b/train.py index 951bda9914..5ff482f023 100644 --- a/train.py +++ b/train.py @@ -176,6 +176,7 @@ def get_batch(split): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) + state_dict = None # free up memory iter_num = checkpoint['iter_num'] best_val_loss = checkpoint['best_val_loss'] elif init_from.startswith('gpt2'):