Skip to content

Commit

Permalink
AIP-72: Handle SIGTERM signal on Supervisor
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Dec 9, 2024
1 parent 320bb38 commit 691d183
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
25 changes: 25 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def start(
client=client,
)

log.debug("Started subprocess", pid=pid, ti_id=ti.id, supervisor_pid=os.getpid())

# Set up signal handlers for the Supervisor process
proc._setup_signal_handlers()

# We've forked, but the task won't start until we send it the StartupDetails message. But before we do
# that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
# reason)
Expand Down Expand Up @@ -366,6 +371,26 @@ def start(
proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
return proc

def _setup_signal_handlers(self):
"""
Set up signal handlers for the **supervisor process**.
These handlers catch signals like SIGTERM or SIGSEGV sent to the supervisor,
allowing it to terminate the task process (child process) gracefully.
"""

def signal_handler(signum, frame):
"""Handle termination signals sent to the supervisor."""
log.error(
"Received termination signal in supervisor. Terminating watched subprocess",
signal=signum,
process_pid=self.pid,
supervisor_pid=os.getpid(),
)
self.kill(signal.SIGTERM, force=True)

signal.signal(signal.SIGTERM, signal_handler)

def _register_pipe_readers(
self, logger: FilteringBoundLogger, stdout: socket, stderr: socket, requests: socket, logs: socket
):
Expand Down
27 changes: 26 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,31 @@ def subprocess_main():

assert rc == -9

def test_supervisor_signal_handling(self, mocker):
"""Verify that the supervisor correctly handles signals and terminates the task process."""
mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log")
mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")

proc = WatchedSubprocess(
ti_id=TI_ID, pid=12345, stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock()
)

# Send a SIGTERM signal to the supervisor
proc._setup_signal_handlers()
os.kill(os.getpid(), signal.SIGTERM)

# Verify task process termination and log messages
# Asserting that `proc.kill` is called with the correct signal is sufficient to verify the supervisor
# correctly handles the signal and terminates the task process
# The actual signal sent to the task process is tested in `TestWatchedSubprocessKill` class
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_logger.error.assert_called_once_with(
"Received termination signal in supervisor. Terminating watched subprocess",
signal=signal.SIGTERM,
supervisor_pid=os.getpid(),
process_pid=proc.pid,
)

def test_last_chance_exception_handling(self, capfd):
def subprocess_main():
# The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
Expand Down Expand Up @@ -628,7 +653,7 @@ def test_kill_process_custom_signal(self, watched_subprocess, mock_process):
),
],
)
def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch):
def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, monkeypatch):
def subprocess_main():
import signal

Expand Down

0 comments on commit 691d183

Please sign in to comment.