[Feature Request] Rename Recorder and LogReward #2610
Comments
Hello! Can I take this issue?

Sure!

And an additional thought: one should be able to choose the aggregation function, e.g. "mean", "sum", ...
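Maybe something like this?

```python
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from tensordict import TensorDictBase
from torchrl.trainers import Trainer, TrainerHookBase


class LogScalar(TrainerHookBase):
    """Logs a reduced scalar (mean by default) of any entry of the collected batch."""

    def __init__(
        self,
        key: Union[str, Tuple[str, ...]],
        logname: str,
        log_pbar: bool = False,
        reduce_fx: Union[str, Callable] = "mean",
    ):
        self.logname = logname
        self.log_pbar = log_pbar
        self.key = key
        # Accept a callable, or the name of a torch reduction such as "mean" or "sum".
        self.reduce_fx = reduce_fx if callable(reduce_fx) else getattr(torch, reduce_fx)

    def __call__(self, batch: TensorDictBase) -> Dict:
        values = batch.get(self.key)
        # If the collector provides a validity mask, reduce over the valid entries only.
        if ("collector", "mask") in batch.keys(True):
            values = values[batch.get(("collector", "mask"))]
        value = self.reduce_fx(values.float()).item()
        return {
            self.logname: value,
            "log_pbar": self.log_pbar,
        }

    # The hook is stateless, so there is nothing to save or restore with the trainer.
    def state_dict(self) -> Dict:
        return {}

    def load_state_dict(self, state_dict: Dict) -> None:
        pass

    def register(self, trainer: Trainer, name: Optional[str] = None):
        if name is None:
            name = f"log_{self.logname}"
        trainer.register_op("pre_steps_log", self)
        trainer.register_module(name, self)
```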
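To make the interplay between the mask and `reduce_fx` concrete without pulling in torch, here is a minimal, dependency-free sketch of the same logic on plain Python lists. The names `REDUCERS`, `resolve_reduce_fx` and `reduce_masked` are hypothetical and exist only for illustration; the lookup table stands in for `getattr(torch, reduce_fx)`, and `statistics.mean` stands in for `torch.mean`.

```python
from statistics import mean
from typing import Callable, Dict, Optional, Sequence, Union

# Hypothetical name -> reducer table, standing in for getattr(torch, name).
REDUCERS: Dict[str, Callable[[Sequence[float]], float]] = {
    "mean": mean,
    "sum": lambda xs: float(sum(xs)),
    "max": max,
    "min": min,
}


def resolve_reduce_fx(reduce_fx: Union[str, Callable]) -> Callable:
    """Return reduce_fx unchanged if it is callable, else look it up by name."""
    if callable(reduce_fx):
        return reduce_fx
    try:
        return REDUCERS[reduce_fx]
    except KeyError:
        raise ValueError(f"unknown reduce_fx {reduce_fx!r}") from None


def reduce_masked(
    values: Sequence[float],
    mask: Optional[Sequence[bool]] = None,
    reduce_fx: Union[str, Callable] = "mean",
) -> float:
    """Keep only the entries where mask is True (if a mask is given), then reduce."""
    fx = resolve_reduce_fx(reduce_fx)
    if mask is not None:
        if len(mask) != len(values):
            raise ValueError("mask and values must have the same length")
        values = [v for v, keep in zip(values, mask) if keep]
    # Cast to float first, mirroring values.float() in the hook above.
    return float(fx([float(v) for v in values]))
```

For example, `reduce_masked([1, 2, 3, 10], mask=[True, True, True, False])` ignores the masked-out `10` and returns `2.0`, while passing `reduce_fx="sum"` returns `6.0`. Accepting either a string or a callable keeps the common cases short while still allowing custom aggregations.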
Makes sense, I'd split these things into separate PRs though.
Motivation
When dealing with logging, I found it hard to understand how the different loggers and classes are meant to be used. The Recorder in particular makes it difficult to grasp the idea behind it.
As for the LogReward class, I would like to make it more general, since it is really just a class for logging numeric values, isn't it?
Solution
Alternatives
Checklist