Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logging logic when in_order is set to True #3280

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/accelerate/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import logging
Expand All @@ -35,6 +36,16 @@ def _should_log(main_process_only):
state = PartialState()
return not main_process_only or (main_process_only and state.is_main_process)

def process(self, msg, kwargs):
msg, kwargs = super().process(msg, kwargs)

# set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
kwargs.setdefault("stacklevel", 2)

state = PartialState()
msg = f"[RANK {state.process_index}] {msg}"
return msg, kwargs

def log(self, level, msg, *args, **kwargs):
"""
Delegates logger call after checking if we should log.
Expand All @@ -46,27 +57,24 @@ def log(self, level, msg, *args, **kwargs):
read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
break with the previous behavior.

`in_order` is ignored if `main_process_only` is passed.
`main_process_only` is ignored if `in_order` is passed.
"""
if PartialState._shared_state == {}:
raise RuntimeError(
"You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
)
main_process_only = kwargs.pop("main_process_only", True)
in_order = kwargs.pop("in_order", False)
# set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
kwargs.setdefault("stacklevel", 2)

if self.isEnabledFor(level):
if self._should_log(main_process_only):
msg, kwargs = self.process(msg, kwargs)
msg, kwargs = self.process(msg, kwargs)
if not in_order and self._should_log(main_process_only):
self.logger.log(level, msg, *args, **kwargs)

elif in_order:
state = PartialState()
for i in range(state.num_processes):
if i == state.process_index:
msg, kwargs = self.process(msg, kwargs)
self.logger.log(level, msg, *args, **kwargs)
state.wait_for_everyone()

Expand All @@ -82,7 +90,7 @@ def warning_once(self, *args, **kwargs):
self.warning(*args, **kwargs)


def get_logger(name: str, log_level: str = None):
def get_logger(name: str, log_level: str | None = None):
"""
Returns a `logging.Logger` for `name` that can handle multiprocessing.

Expand Down
Loading