Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix signal handling in SubprocessManager #2039

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 151 additions & 27 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,62 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple


def kill_process_and_descendants(pid, termination_timeout):
def send_signals(pid, signal):
# TODO: there's a race condition that new descendants might
# spawn b/w the invocations of 'pkill' and 'kill'.
# Needs to be fixed in future.
try:
subprocess.check_call(["pkill", "-TERM", "-P", str(pid)])
subprocess.check_call(["kill", "-TERM", str(pid)])
except subprocess.CalledProcessError:
pass
retcode = subprocess.call(["pkill", signal, "-P", str(pid)])
# 2: Invalid options
# 3: No processes matched
if retcode == 2 or retcode == 3:
print(f"'pkill {signal} -P {pid}' failed with return code: {retcode}.")

retcode = subprocess.call(["kill", signal, str(pid)])
if retcode != 0:
print(f"'kill {signal} {pid}' failed with return code: {retcode}.")


def kill_process_and_descendants(pid, termination_timeout):
send_signals(pid, "-TERM")

time.sleep(termination_timeout)

send_signals(pid, "-KILL")


def kill_processes_and_descendants(pids, termination_timeout):
for pid in pids:
send_signals(pid, "-TERM")

time.sleep(termination_timeout)

try:
subprocess.check_call(["pkill", "-KILL", "-P", str(pid)])
subprocess.check_call(["kill", "-KILL", str(pid)])
except subprocess.CalledProcessError:
pass
for pid in pids:
send_signals(pid, "-KILL")


async def async_send_signals(pids, signal):
pkill_processes = [
await asyncio.create_subprocess_exec("pkill", signal, "-P", str(pid))
for pid in pids
]

for proc in pkill_processes:
await proc.wait()

kill_processes = [
await asyncio.create_subprocess_exec("kill", signal, str(pid)) for pid in pids
]

for proc in kill_processes:
await proc.wait()


async def async_kill_processes_and_descendants(pids, termination_timeout):
await async_send_signals(pids, "-TERM")

await asyncio.sleep(termination_timeout)

await async_send_signals(pids, "-KILL")


class LogReadTimeoutError(Exception):
Expand All @@ -42,6 +81,18 @@ class SubprocessManager(object):
def __init__(self):
self.commands: Dict[int, CommandManager] = {}

try:

async def handle_sigint():
await self._async_handle_sigint()

asyncio.get_running_loop().add_signal_handler(
signal.SIGINT, lambda: asyncio.create_task(handle_sigint())
)
Comment on lines +86 to +91
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closure required in order to capture self, also asyncio signal handlers can't be async.


except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)

async def __aenter__(self) -> "SubprocessManager":
return self

Expand Down Expand Up @@ -81,8 +132,12 @@ def run_command(
"""

command_obj = CommandManager(command, env, cwd)
pid = command_obj.run(show_output=show_output)
pid = command_obj.run(show_output=show_output, wait=False)

self.commands[pid] = command_obj

command_obj.sync_wait()

return pid

async def async_run_command(
Expand Down Expand Up @@ -138,6 +193,42 @@ def cleanup(self) -> None:
for v in self.commands.values():
v.cleanup()

async def kill(self, termination_timeout: float = 5):
"""
Kill all managed subprocesses and their descendants.

Parameters
----------
termination_timeout : float, default 5
The time to wait after sending a SIGTERM to a subprocess and its descendants
before sending a SIGKILL.
"""

pids = [v.process.pid for v in self.commands.values() if v.process is not None]
await async_kill_processes_and_descendants(pids, 5)

def sync_kill(self, termination_timeout: float = 5):
"""
Kill all managed subprocesses and their descendants synchronously.

Parameters
----------
termination_timeout : float, default 5
The time to wait after sending a SIGTERM to a subprocess and its descendants
before sending a SIGKILL.
"""
pids = [v.process.pid for v in self.commands.values() if v.process is not None]
kill_processes_and_descendants(
pids,
termination_timeout,
)

def _handle_sigint(self, signum, frame):
self.sync_kill()

async def _async_handle_sigint(self):
await self.kill()


class CommandManager(object):
"""A manager for an individual subprocess."""
Expand Down Expand Up @@ -169,11 +260,11 @@ def __init__(
self.cwd = cwd if cwd is not None else os.getcwd()

self.process = None
self.stdout_thread = None
self.stderr_thread = None
self.run_called: bool = False
self.log_files: Dict[str, str] = {}

signal.signal(signal.SIGINT, self._handle_sigint)

async def __aenter__(self) -> "CommandManager":
return self

Expand Down Expand Up @@ -221,11 +312,23 @@ async def wait(
"within %s seconds." % (self.process.pid, command_string, timeout)
)

def run(self, show_output: bool = False):
def sync_wait(self):
"""
Run the subprocess synchronously. This can only be called once.
Wait for the subprocess to finish synchronously.

This also waits on the process implicitly.
You can only call `sync_wait` if `run` has already been called.
"""

if not self.run_called:
raise RuntimeError("No command run yet to wait for...")

self.process.wait()
self.stdout_thread.join()
self.stderr_thread.join()

def run(self, show_output: bool = False, wait: bool = True) -> int:
"""
Run the subprocess synchronously. This can only be called once.

Parameters
----------
Expand All @@ -234,6 +337,10 @@ def run(self, show_output: bool = False):
They can be accessed later by reading the files present in:
- self.log_files["stdout"]
- self.log_files["stderr"]
wait : bool, default True
Wait for the process to finish before returning.
If false, the process will run in the background. You can then wait on
the process (using `sync_wait`) or kill it (using `sync_kill`).
"""

if not self.run_called:
Expand Down Expand Up @@ -265,22 +372,22 @@ def stream_to_stdout_and_file(pipe, log_file):

self.run_called = True

stdout_thread = threading.Thread(
self.stdout_thread = threading.Thread(
target=stream_to_stdout_and_file,
args=(self.process.stdout, stdout_logfile),
)
stderr_thread = threading.Thread(
self.stderr_thread = threading.Thread(
target=stream_to_stdout_and_file,
args=(self.process.stderr, stderr_logfile),
)

stdout_thread.start()
stderr_thread.start()

self.process.wait()
self.stdout_thread.start()
self.stderr_thread.start()

stdout_thread.join()
stderr_thread.join()
if wait:
self.process.wait()
self.stdout_thread.join()
self.stderr_thread.join()

return self.process.pid
except Exception as e:
Expand Down Expand Up @@ -457,8 +564,25 @@ async def kill(self, termination_timeout: float = 5):
else:
print("No process to kill.")

def _handle_sigint(self, signum, frame):
asyncio.create_task(self.kill())
def sync_kill(self, termination_timeout: float = 5):
"""
Kill the subprocess and its descendants synchronously.

Parameters
----------
termination_timeout : float, default 5
The time to wait after sending a SIGTERM to the process and its descendants
before sending a SIGKILL.
"""

if self.process is not None:
send_signals(self.process.pid, "-TERM")

time.sleep(termination_timeout)

send_signals(self.process.pid, "-KILL")
else:
print("No process to kill.")


async def main():
Expand Down
Loading