Skip to content

Commit

Permalink
Merge pull request #180 from lean-dojo/kaiyu
Browse files Browse the repository at this point in the history
Use `pexpect` instead of `signal` in dojo.py
  • Loading branch information
yangky11 authored Jul 10, 2024
2 parents 3742322 + 6ce9e9c commit 501a573
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 97 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "LeanDojo"
copyright = "2023, LeanDojo Team"
author = "Kaiyu Yang"
release = "2.0.0"
release = "2.0.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ pretty = True
implicit_reexport = True
disallow_untyped_calls = False
follow_imports = skip

[mypy-pexpect.*]
ignore_missing_imports = True
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude = [

[project]
name = "lean-dojo"
version = "2.0.0"
version = "2.0.1"
authors = [
{ name="Kaiyu Yang", email="[email protected]" },
]
Expand All @@ -32,6 +32,7 @@ dependencies = [
"loguru",
"filelock",
"psutil",
"pexpect",
"types-psutil",
"tqdm",
"toml",
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TimeoutError,
TacticResult,
DojoCrashError,
DojoHardTimeoutError,
DojoTacticTimeoutError,
DojoInitError,
Dojo,
ProofFinished,
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

load_dotenv()

__version__ = "2.0.0"
__version__ = "2.0.1"

logger.remove()
if "VERBOSE" in os.environ or "DEBUG" in os.environ:
Expand Down
129 changes: 38 additions & 91 deletions src/lean_dojo/interaction/dojo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import re
import os
import sys
import json
import time
import shlex
import signal
import psutil
import pexpect
import tempfile
import subprocess
from pathlib import Path
from loguru import logger
from dataclasses import dataclass, field
Expand All @@ -21,9 +18,6 @@
from ..data_extraction.traced_data import TracedFile, get_code_without_comments


_REPL_PROMPT = "REPL>"


@dataclass(frozen=True)
class CommandState:
id: int = field(compare=False)
Expand Down Expand Up @@ -87,14 +81,21 @@ def is_out_of_memory(self) -> bool:
return str(self) == "OOM"


class DojoHardTimeoutError(Exception):
class DojoTacticTimeoutError(Exception):
pass


class DojoInitError(Exception):
pass


def kill_descendants(pid: int) -> None:
try:
_kill_descendants(psutil.Process(pid))
except psutil.NoSuchProcess:
pass


def _kill_descendants(proc: psutil.Process) -> None:
for child in proc.children():
_kill_descendants(child)
Expand All @@ -108,7 +109,6 @@ class Dojo:
"""Gym-like environment for programmatic interaction with Lean through tactics or commands."""

entry: Union[Theorem, Tuple[LeanGitRepo, Path, int]]
hard_timeout: Optional[float]
additional_imports: List[str]
repo: LeanGitRepo
file_path: Path
Expand All @@ -120,7 +120,7 @@ class Dojo:
def __init__(
self,
entry: Union[Theorem, Tuple[LeanGitRepo, Path, int]],
hard_timeout: Optional[float] = None,
timeout: int = 600,
additional_imports: List[str] = [],
):
"""Initialize Dojo.
Expand All @@ -130,10 +130,10 @@ def __init__(
the :class:`Dojo` object enables interaction with the theorem through tactics.
When a tuple of (repo, file_path, line_nb) is given (only supported in Lean 4),
the :class:`Dojo` object enables interaction with Lean through commands (similar to a REPL).
hard_timeout (Optional[float], optional): Hard timeout in seconds. Defaults to None.
timeout (int): The maximum number of seconds for a single interaction (e.g., tactic).
"""
self.entry = entry
self.hard_timeout = hard_timeout
self.timeout = timeout
self.additional_imports = additional_imports

if self.uses_tactics:
Expand All @@ -146,11 +146,6 @@ def __init__(
self.repo, self.file_path, _ = entry
self.file_path = Path(self.file_path)

if self.hard_timeout is None:
logger.warning(
"Running tactics without a hard timeout may hang indefinitely."
)

@property
def uses_tactics(self) -> bool:
return isinstance(self.entry, Theorem)
Expand All @@ -162,7 +157,6 @@ def uses_commands(self) -> bool:
def __enter__(self) -> Tuple["Dojo", State]:
"""Initialize Dojo."""
logger.debug(f"Initializing Dojo for {self.entry}")
self._install_handlers()

# Replace the human-written proof with a `repl` tactic.
traced_repo_path = get_traced_repo_path(self.repo)
Expand All @@ -180,14 +174,8 @@ def __enter__(self) -> Tuple["Dojo", State]:
memory_limit = 1024 * int(TACTIC_MEMORY_LIMIT[:-1])
modified_path = Path(self.modified_file.name).relative_to(traced_repo_path)
cmd = f"lake env lean --threads={TACTIC_CPU_LIMIT} --memory={memory_limit} {modified_path}"
self.proc = subprocess.Popen(
shlex.split(cmd),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
encoding="utf-8",
bufsize=1,
self.proc = pexpect.spawn(
cmd, timeout=self.timeout, maxread=1, encoding="utf-8", echo=False
)

# Get the initial tactic state.
Expand Down Expand Up @@ -217,41 +205,12 @@ def __enter__(self) -> Tuple["Dojo", State]:
init_state = CommandState(int(res["sid"]))

self.start_time = time.monotonic()
self._set_timer()

return self, init_state

def _locate_traced_file(self, traced_repo_path: Path) -> TracedFile:
json_path = to_json_path(traced_repo_path, self.file_path, self.repo)
return TracedFile.from_traced_file(traced_repo_path, json_path, self.repo)

def _set_timer(self) -> None:
if self.hard_timeout is not None:
signal.signal(signal.SIGALRM, self._handle_hard_timeout)
signal.alarm(int(self.hard_timeout))

def _cancel_timer(self) -> None:
if self.hard_timeout is not None:
signal.alarm(0)
signal.signal(signal.SIGALRM, signal.SIG_DFL)

def _handle_hard_timeout(self, signum: Any, frame: Any) -> None:
logger.debug(f"Hard timeout in {self}")
self.has_timedout = True
raise DojoHardTimeoutError()

def _install_handlers(self) -> None:
self.old_sigint = signal.signal(signal.SIGINT, self._exit_gracefully)
self.old_sigterm = signal.signal(signal.SIGTERM, self._exit_gracefully)

def _uninstall_handlers(self) -> None:
signal.signal(signal.SIGINT, self.old_sigint)
signal.signal(signal.SIGTERM, self.old_sigterm)

def _exit_gracefully(self, signum: Any, frame: Any) -> None:
logger.debug("Exiting gracefully.")
sys.exit(-1)

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
"""Exit Dojo.
Expand All @@ -261,12 +220,8 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
exc_tb (None): _description_
"""
logger.debug("Cleaning up.")
self._cancel_timer()
try:
_kill_descendants(psutil.Process(self.proc.pid))
self.modified_file.__exit__(exc_type, exc_val, exc_tb)
finally:
self._uninstall_handlers()
kill_descendants(self.proc.pid)
self.modified_file.__exit__(exc_type, exc_val, exc_tb)

def _post_process(self, tactic_state: str) -> str:
"""Post-process the pretty-printed tactic state.
Expand Down Expand Up @@ -372,8 +327,6 @@ def run_tac(self, state: TacticState, tactic: str) -> TacticResult:
if res["error"] is not None:
if "proof contains `sorry`" in res["error"]:
return ProofGivenUp()
elif "try_for_time tactic failed, timeout" in res["error"]:
return TimeoutError(res["error"].strip())
else:
return LeanError(res["error"].strip())
elif res["tacticState"] == "no goals":
Expand Down Expand Up @@ -415,11 +368,9 @@ def _submit_request(self, req: str) -> Dict[str, Any]:
Returns:
Dict[str, Any]: _description_
"""
if self.proc.stdin is None:
raise RuntimeError("self.proc.stdin is not initialized")
self._check_alive()
logger.debug(req)
self.proc.stdin.write(req + "\n")
self.proc.sendline(req)
try:
res, msg = self._read_next_line()
except EOFError:
Expand All @@ -433,10 +384,11 @@ def _submit_request(self, req: str) -> Dict[str, Any]:
return result

def _check_alive(self) -> None:
exit_code = self.proc.poll()
if exit_code is None:
if self.proc.isalive():
return
elif exit_code == 137:
exit_code = self.proc.exitstatus
assert exit_code is not None
if exit_code == 137:
raise DojoCrashError("OOM")
else:
raise DojoCrashError(f"Unknown exit code: {exit_code}")
Expand All @@ -452,28 +404,23 @@ def _read_next_line(self) -> Tuple[str, str]:
Returns:
str: _description_
"""
if self.proc.stdout is None:
raise RuntimeError("self.proc.stout is not initialized")
_REPL_PROMPT = "REPL>"
msg: List[str] = []
while True:
line = self.proc.stdout.readline().strip()
logger.debug(line)
if line == "":
raise EOFError
if line.startswith(_REPL_PROMPT):
try:
index = self.proc.expect(["\n", f"{_REPL_PROMPT}.*?\n"])
if index == 0:
if self.proc.before == "":
raise EOFError
else:
msg.append(self.proc.before.strip())
continue
self._check_alive()
return line[len(_REPL_PROMPT) :].strip(), "\n".join(msg)
elif "error: " in line:
if (
"error: deep recursion was detected" in line
or "error: [fatal] not_a_theorem" in line
):
self.is_crashed = True
raise DojoCrashError(line)
elif "error: unknown package" in line:
self.is_crashed = True
raise DojoInitError(line)
else:
pass
else:
msg.append(line)
res = self.proc.match.string[len(_REPL_PROMPT) :].strip()
return res, "\n".join(msg) + self.proc.before
except pexpect.EOF:
raise EOFError
except pexpect.TIMEOUT:
logger.debug(f"Tactic timed out")
self.has_timedout = True
raise DojoTacticTimeoutError()
4 changes: 2 additions & 2 deletions tests/interaction/test_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ def test_timeout_1(lean4_example_repo: LeanGitRepo) -> None:
"Lean4Example.lean",
"hello_world",
)
with Dojo(thm, hard_timeout=10) as (dojo, init_state):
with pytest.raises(DojoHardTimeoutError):
with Dojo(thm) as (dojo, init_state):
with pytest.raises(DojoTacticTimeoutError):
dojo.run_tac(init_state, "sleep 99999999999999")

0 comments on commit 501a573

Please sign in to comment.