☄️ Update Comet integration to include LogCompletionsCallback and Trainer.evaluation_loop() #2501
Conversation
```python
if "comet_ml" in self.args.report_to:
    log_table_to_comet_experiment(
        name="game_log.csv",
```
can you save it in the `output_dir` instead?
The Comet SDK posts submitted files to the server in the background. The intermediate copy of the file is stored in a temporary directory until the upload to the Comet server completes.
Could you please elaborate on your idea so I can understand it better?
When you run `experiment.log_table(tabular_data=table, filename=name)`, does it save something locally in a file named `name`?
Yes, it indeed saves the table data into a temporary file in a temporary directory. That file lives until its upload to the Comet server is complete. After that it is cleaned up automatically, either by the Comet SDK or by the OS if something goes wrong during the Python script's execution.
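For illustration only, here is a minimal sketch of what a helper like the one this PR adds could look like, using only public `comet_ml` calls; the assumption is that `comet_ml.get_global_experiment()` returns the experiment created by the `transformers` Comet callback, and the actual TRL implementation may differ:

```python
import comet_ml
import pandas as pd


def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
    # Assumed lookup of the already-running experiment; returns None if no
    # Comet experiment has been started for this process.
    experiment = comet_ml.get_global_experiment()
    if experiment is not None:
        # log_table() writes the DataFrame to a temporary CSV file and uploads
        # it in the background, as described in the comment above.
        experiment.log_table(tabular_data=table, filename=name)


# Example with the same table shape as the game_log discussed in this thread.
table = pd.DataFrame(columns=["Prompt", "Policy"], data=[["2 + 2 =", " 4"]])
log_table_to_comet_experiment(name="game_log.csv", table=table)
```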
Can you screenshot a result?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
And this is a DataFrame encoded as CSV.
The script I was using to test the DPO trainer integration:

```python
import os

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    output_dir = "models/minimal/dpo_my"
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    # model_id = "Qwen/Qwen2-0.5B-Instruct"

    model = AutoModelForCausalLM.from_pretrained(model_id)
    ref_model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    training_args = DPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        max_steps=1,
        remove_unused_columns=False,
        gradient_accumulation_steps=8,
        precompute_ref_log_probs=False,
        learning_rate=5.0e-7,
        eval_strategy="steps",
        eval_steps=1,
        report_to="all",
        generate_during_eval=True,
        max_length=1024,
    )

    # dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
    dummy_dataset = load_dataset("trl-lib/ultrafeedback_binarized", "default")
    dummy_dataset["train"] = dummy_dataset["train"].select(range(20))
    dummy_dataset["test"] = dummy_dataset["test"].select(range(40))

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=dummy_dataset["train"],
        eval_dataset=dummy_dataset["test"],
    )
    trainer.train()
    trainer.evaluate()


if __name__ == "__main__":
    main()
```

Do not forget to set …
```diff
@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

 import numpy as np
+import pandas as pd
```
`pandas` is indeed installed: it comes in transitively via `trl` -> `datasets` -> `pandas`.
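If you want to double-check that transitive dependency locally, a quick stdlib-only sanity check (nothing TRL-specific is assumed here):

```python
import importlib.util

# pandas arrives via trl -> datasets -> pandas, so all three should resolve
# in any environment where trl is installed.
for pkg in ("trl", "datasets", "pandas"):
    print(pkg, "->", "found" if importlib.util.find_spec(pkg) else "missing")
```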
```python
table = pd.DataFrame(
    columns=["Prompt", "Policy"],
    data=[
        [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
    ],
```
For the record, this doesn't work, because `pol` can be over-truncated. It can be reproduced with:
```python
from datasets import load_dataset
from trl import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:1%]")

training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10, generate_during_eval=True, eval_steps=2, eval_strategy="steps")
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()
```
But this is out of the scope of this PR.
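For context, a common way to sidestep that over-truncation pitfall is to slice in token space before decoding, rather than stripping the prompt string afterwards. This is only a hedged sketch with example model and prompt names, not what the trainers currently do:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prompts = ["The capital of France is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
output_ids = model.generate(**inputs, max_new_tokens=32)

# Everything after the prompt length along the sequence dimension is newly
# generated, so decoding only that slice cannot eat into the completion even
# when the decoded prompt prefix does not match the original string exactly.
generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
completions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(completions)
```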
very nice, thanks @yaricom!
What does this PR do?
Updated the Comet integration to include the following (a brief usage sketch follows this list):
- `LogCompletionsCallback`
- `CPOTrainer.evaluation_loop()`
- `DPOTrainer.evaluation_loop()`
- `BCOTrainer.evaluation_loop()`
- `KTOTrainer.evaluation_loop()`
- `ORPOTrainer.evaluation_loop()`
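As a rough usage sketch (not part of this PR's diff): the callback is attached to an existing TRL trainer whose eval dataset has a "prompt" column, for example the DPO script shown earlier in this thread. The constructor arguments below follow TRL's callback docs and should be treated as an assumption rather than something verified against this PR:

```python
from transformers import GenerationConfig
from trl import LogCompletionsCallback

# Assumes `trainer` was created with report_to="comet_ml" (or "all") and an
# eval dataset containing a "prompt" column, as in the DPO script above.
generation_config = GenerationConfig(max_new_tokens=64)
completions_callback = LogCompletionsCallback(
    trainer=trainer,
    generation_config=generation_config,
    num_prompts=8,  # how many eval prompts to sample for the completions table
)
trainer.add_callback(completions_callback)
trainer.train()  # completion tables should then appear in the Comet experiment
```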
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.