DeepSpeed ZeRO-3: how to save a custom model? #3309

Open
NLPJCL opened this issue Dec 21, 2024 · 0 comments


NLPJCL commented Dec 21, 2024

DeepSpeedEngine(
  (module): LLMDecoder(
    (model): Qwen2ForSequenceClassification(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x Qwen2DecoderLayer(
            (self_attn): Qwen2SdpaAttention(
              (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
              (k_proj): Linear(in_features=1536, out_features=256, bias=True)
              (v_proj): Linear(in_features=1536, out_features=256, bias=True)
              (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm((0,), eps=1e-06)
            (post_attention_layernorm): Qwen2RMSNorm((0,), eps=1e-06)
          )
        )
        (norm): Qwen2RMSNorm((0,), eps=1e-06)
        (rotary_emb): Qwen2RotaryEmbedding()
      )
      (score): Linear(in_features=1536, out_features=1, bias=False)
    )
  )
)
Hello, the above is my model structure. In short, I use a custom LLMDecoder module, which has an attribute named model that holds a Qwen2ForSequenceClassification object.
In this case, how should I save the model under DeepSpeed ZeRO-3?

The following code does not fit my model structure. How should I modify it?

unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=accelerator.get_state_dict(model),
)
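
One possible adaptation, offered as a sketch rather than a confirmed fix: gather the full state dict through the accelerator as before, strip the wrapper's attribute prefix from the keys, and call save_pretrained on the inner Qwen2ForSequenceClassification instead of on LLMDecoder. The helper names strip_wrapper_prefix and save_inner_model are hypothetical, and the "model." prefix is an assumption based on the attribute name shown in the printout above. This also assumes the DeepSpeed config sets stage3_gather_16bit_weights_on_model_save to true, and that the function is called on every rank, since gathering ZeRO-3 shards is a collective operation.

```python
from collections import OrderedDict


def strip_wrapper_prefix(state_dict, prefix="model."):
    """Keep only keys under `prefix` and drop the prefix from them.

    Returns None unchanged, because the gathered state dict may be None
    on ranks other than the main one.
    """
    if state_dict is None:
        return None
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def save_inner_model(accelerator, model, output_dir, prefix="model."):
    """Hypothetical helper: save the Hugging Face model held inside a wrapper.

    `model` is the object returned by accelerator.prepare(...).
    """
    # Collective call: every rank must reach this line under ZeRO-3.
    full_state_dict = accelerator.get_state_dict(model)

    # unwrap_model returns the custom wrapper (LLMDecoder in this issue);
    # its `model` attribute is assumed to be the PreTrainedModel.
    inner_model = accelerator.unwrap_model(model).model

    inner_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=strip_wrapper_prefix(full_state_dict, prefix),
    )
```

With this sketch, a gathered key such as model.score.weight would be saved as score.weight, which is the name the inner model expects when it is later loaded with from_pretrained. Any parameters that LLMDecoder defines outside the model attribute would be dropped by the prefix filter and would need to be saved separately.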
