Skip to content

Commit

Permalink
fixup! AIP-72: Handle SIGTERM signal on Supervisor
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Dec 9, 2024
1 parent 8d5e6c3 commit 3f1a533
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 29 deletions.
28 changes: 18 additions & 10 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,6 @@ def start(

log.debug("Started subprocess", pid=pid, ti_id=ti.id, supervisor_pid=os.getpid())

# Set up signal handlers for the Supervisor process
proc._setup_signal_handlers()

# We've forked, but the task won't start until we send it the StartupDetails message. But before we do
# that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
# reason)
Expand Down Expand Up @@ -371,23 +368,25 @@ def start(
proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
return proc

def _setup_signal_handlers(self):
@classmethod
def setup_signal_handlers(cls):
"""
Set up signal handlers for the **supervisor process**.
These handlers catch signals like SIGTERM or SIGSEGV sent to the supervisor,
allowing it to terminate the task process (child process) gracefully.
These handlers catch signals like SIGTERM sent to the supervisor, allowing it to
terminate the task processes (child processes) gracefully.
"""

def signal_handler(signum, frame):
"""Handle termination signals sent to the supervisor."""
log.error(
"Received termination signal in supervisor. Terminating watched subprocess",
"Received termination signal in supervisor. Terminating all watched subprocesses",
signal=signum,
process_pid=self.pid,
process_pids=list(cls.procs.keys()),
supervisor_pid=os.getpid(),
)
self.kill(signal.SIGTERM, force=True)
for proc in list(cls.procs.values()):
proc.kill(signal.SIGTERM, force=True)

signal.signal(signal.SIGTERM, signal_handler)

Expand Down Expand Up @@ -498,6 +497,7 @@ def kill(

def wait(self) -> int:
if self._exit_code is not None:
self._update_final_ti_state()
return self._exit_code

try:
Expand All @@ -509,6 +509,12 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

self._update_final_ti_state()

return self._exit_code

def _update_final_ti_state(self):
"""Update the TaskInstance state."""
# If the process has finished in a terminal state, update the state of the TaskInstance
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will be updated
Expand All @@ -517,7 +523,6 @@ def wait(self) -> int:
self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
return self._exit_code

def _monitor_subprocess(self):
"""
Expand Down Expand Up @@ -870,6 +875,9 @@ def supervise(
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()

# Set up signal handlers for the supervisor process
WatchedSubprocess.setup_signal_handlers()

process = WatchedSubprocess.start(dag_path, ti, client=client, logger=logger)

exit_code = process.wait()
Expand Down
29 changes: 29 additions & 0 deletions task_sdk/tests/dags/sleeper_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.providers.standard.operators.bash import BashOperator
from airflow.sdk.definitions.dag import dag


@dag()
def sleeper_dag():
    """Define a DAG with a single ``sleep`` task that runs ``sleep 10`` in bash."""
    # The command is named separately so the task's duration is easy to spot.
    sleep_command = "sleep 10"
    BashOperator(task_id="sleep", bash_command=sleep_command)


# Invoke the decorated function at import time
# (presumably this is what instantiates the DAG for the parser — confirm).
sleeper_dag()
66 changes: 47 additions & 19 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import selectors
import signal
import sys
import threading
from io import BytesIO
from operator import attrgetter
from time import sleep
Expand Down Expand Up @@ -179,31 +180,58 @@ def subprocess_main():

assert rc == -9

def test_supervisor_signal_handling(self, mocker):
"""Verify that the supervisor correctly handles signals and terminates the task process."""
mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log")
mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
def test_supervisor_signal_handling_in_supervise(self, test_dags_dir, captured_logs, mocker):
"""Test that `supervise` handles SIGTERM by terminating all managed subprocesses."""

proc = WatchedSubprocess(
ti_id=TI_ID, pid=12345, stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock()
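# TODO: Speed up this test. It runs a real `supervise` call against a task that sleeps for 10 seconds and waits a fixed 1 second before sending SIGTERM.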
ti = TaskInstance(
id=TI_ID,
task_id="sleep",
dag_id="sleeper_dag",
run_id="test_run",
try_number=1,
)

# Send a SIGTERM signal to the supervisor
proc._setup_signal_handlers()
os.kill(os.getpid(), signal.SIGTERM)
# Spy on `WatchedSubprocess.kill` to track calls
kill_spy = mocker.spy(WatchedSubprocess, "kill")

# Verify task process termination and log messages
# Asserting that `proc.kill` is called with the correct signal is sufficient to verify the supervisor
# correctly handles the signal and terminates the task process
# The actual signal sent to the task process is tested in `TestWatchedSubprocessKill` class
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_logger.error.assert_called_once_with(
"Received termination signal in supervisor. Terminating watched subprocess",
signal=signal.SIGTERM,
supervisor_pid=os.getpid(),
process_pid=proc.pid,
# Simulate sending SIGTERM to the supervisor process
supervisor_pid = os.getpid()

def send_sigterm():
sleep(1) # Allow `supervise` to start
os.kill(supervisor_pid, signal.SIGTERM)

sigterm_thread = threading.Thread(target=send_sigterm)
sigterm_thread.start()

# Run `supervise` (expect it to handle SIGTERM gracefully)
exit_code = supervise(
ti=ti,
dag_path=test_dags_dir / "sleeper_dag.py",
token="fake_token",
dry_run=True,
)

sigterm_thread.join() # Ensure SIGTERM is processed

# Assert the process exited due to SIGTERM
assert exit_code in [-signal.SIGTERM, -signal.SIGKILL]

# Verify `kill` was called for all processes
kill_spy.assert_called_once_with(mocker.ANY, signal.SIGTERM, force=True)

# Validate logs
assert {
"signal": signal.SIGTERM,
"process_pids": list(WatchedSubprocess.procs.keys()),
"supervisor_pid": supervisor_pid,
"event": "Received termination signal in supervisor. Terminating all watched subprocesses",
"timestamp": mocker.ANY,
"level": "error",
"logger": "supervisor",
} in captured_logs

def test_last_chance_exception_handling(self, capfd):
def subprocess_main():
# The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
Expand Down

0 comments on commit 3f1a533

Please sign in to comment.