diff --git a/src/accelerate/logging.py b/src/accelerate/logging.py index 1615bc313b7..9132f8cbf4b 100644 --- a/src/accelerate/logging.py +++ b/src/accelerate/logging.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import functools import logging @@ -35,6 +36,16 @@ def _should_log(main_process_only): state = PartialState() return not main_process_only or (main_process_only and state.is_main_process) + def process(self, msg, kwargs): + msg, kwargs = super().process(msg, kwargs) + + # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice + kwargs.setdefault("stacklevel", 2) + + state = PartialState() + msg = f"[RANK {state.process_index}] {msg}" + return msg, kwargs + def log(self, level, msg, *args, **kwargs): """ Delegates logger call after checking if we should log. @@ -46,7 +57,7 @@ def log(self, level, msg, *args, **kwargs): read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not break with the previous behavior. - `in_order` is ignored if `main_process_only` is passed. + `main_process_only` is ignored if `in_order` is passed. """ if PartialState._shared_state == {}: raise RuntimeError( @@ -54,19 +65,16 @@ def log(self, level, msg, *args, **kwargs): ) main_process_only = kwargs.pop("main_process_only", True) in_order = kwargs.pop("in_order", False) - # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice - kwargs.setdefault("stacklevel", 2) if self.isEnabledFor(level): - if self._should_log(main_process_only): - msg, kwargs = self.process(msg, kwargs) + msg, kwargs = self.process(msg, kwargs) + if not in_order and self._should_log(main_process_only): self.logger.log(level, msg, *args, **kwargs) elif in_order: state = PartialState() for i in range(state.num_processes): if i == state.process_index: - msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) state.wait_for_everyone() @@ -82,7 +90,7 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -def get_logger(name: str, log_level: str = None): +def get_logger(name: str, log_level: str | None = None): """ Returns a `logging.Logger` for `name` that can handle multiprocessing. diff --git a/tests/test_logging.py b/tests/test_logging.py index a91c609ddc0..a7be5d71f0b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -53,6 +53,7 @@ def test_log_stack(caplog): ) message = "Test" + expected_message, _ = logger.process(message, {}) lineno = current_lineno() + 1 # the next line is the actual callsite logger.warning(message) @@ -63,7 +64,7 @@ def test_log_stack(caplog): assert rec.name == __name__ assert rec.lineno == lineno assert rec.funcName == test_log_stack.__name__ - assert rec.message == message + assert rec.message == expected_message @pytest.mark.usefixtures("accelerator") @@ -76,6 +77,7 @@ def test_custom_stacklevel(caplog): logger = CustomLogger(wrapped_logger, {}) message = "Test" + expected_message, _ = wrapped_logger.process(message, {}) lineno = current_lineno() + 1 # the next line is the actual callsite logger.warning(message) @@ -88,4 +90,4 @@ def test_custom_stacklevel(caplog): assert rec.name == __name__ assert rec.lineno == lineno assert rec.funcName == test_custom_stacklevel.__name__ - assert rec.message == message + assert rec.message == expected_message