Skip to content

Commit

Permalink
Merge pull request #192 from lean-dojo/dev
Browse files Browse the repository at this point in the history
Incorporate Recent Commits
  • Loading branch information
yangky11 authored Aug 6, 2024
2 parents 8134232 + 297b96e commit b9d2115
Show file tree
Hide file tree
Showing 18 changed files with 649 additions and 192 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/type_check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- name: Install dependencies
run: pip install ".[all]"
- name: Type Check (mypy)
run: mypy src/lean_dojo/interaction
run: mypy src/lean_dojo
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ dmypy.json

# Pyre type checker
.pyre/

# vscode debug config
.vscode/
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "LeanDojo"
copyright = "2023, LeanDojo Team"
author = "Kaiyu Yang"
release = "2.0.3"
release = "2.1.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,13 @@ disallow_untyped_calls = False
follow_imports = skip

[mypy-pexpect.*]
ignore_missing_imports = True

[mypy-lxml.*]
ignore_missing_imports = True

[mypy-tqdm.*]
ignore_missing_imports = True

[mypy-networkx.*]
ignore_missing_imports = True
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude = [

[project]
name = "lean-dojo"
version = "2.0.3"
version = "2.1.0"
authors = [
{ name="Kaiyu Yang", email="[email protected]" },
]
Expand All @@ -31,6 +31,7 @@ dependencies = [
"python-dotenv",
"loguru",
"filelock",
"gitpython",
"psutil",
"pexpect",
"types-psutil",
Expand Down
13 changes: 7 additions & 6 deletions src/lean_dojo/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

load_dotenv()

__version__ = "2.0.3"
__version__ = "2.1.0"

logger.remove()
if "VERBOSE" in os.environ or "DEBUG" in os.environ:
Expand Down Expand Up @@ -71,15 +71,16 @@
assert re.fullmatch(r"\d+g", TACTIC_MEMORY_LIMIT)


def check_git_version(min_version: Tuple[int, int, int]) -> Tuple[int, int, int]:
def check_git_version(min_version: Tuple[int, int, int]) -> None:
"""Check the version of Git installed on the system."""
res = subprocess.run("git --version", shell=True, capture_output=True, check=True)
output = res.stdout.decode()
output = res.stdout.decode().strip()
error = res.stderr.decode()
assert error == "", error
m = re.match(r"git version (?P<version>[0-9.]+)", output)
version = tuple(int(_) for _ in m["version"].split("."))

m = re.search(r"git version (\d+\.\d+\.\d+)", output)
assert m, f"Could not parse Git version from: {output}"
# Convert version number string to tuple of integers
version = tuple(int(_) for _ in m.group(1).split("."))
version_str = ".".join(str(_) for _ in version)
min_version_str = ".".join(str(_) for _ in min_version)
assert (
Expand Down
10 changes: 6 additions & 4 deletions src/lean_dojo/data_extraction/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node":
return subcls.from_data(node_data, lean_file)

@classmethod
def _kind_to_node_type(cls, kind: str) -> type:
def _kind_to_node_type(cls, kind: str) -> type["Node"]:
prefix = "Lean.Parser."
if kind.startswith(prefix):
kind = kind[len(prefix) :]
Expand Down Expand Up @@ -83,7 +83,7 @@ def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node":
start = Pos.from_str(tree.attrib["start"]) if "start" in tree.attrib else None
end = Pos.from_str(tree.attrib["end"]) if "end" in tree.attrib else None
children = [Node.from_xml(subtree, lean_file) for subtree in tree]
kwargs = {}
kwargs: Dict[str, Any] = {}

for field in subcls.__dataclass_fields__.values():
if field.name in ("lean_file", "start", "end", "children"):
Expand Down Expand Up @@ -113,11 +113,13 @@ def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node":

return subcls(lean_file, start, end, children, **kwargs) # type: ignore

def get_closure(self) -> Tuple[Pos, Pos]:
def get_closure(self) -> Tuple[Optional[Pos], Optional[Pos]]:
return self.start, self.end


def _parse_pos(info: Dict[str, Any], lean_file: LeanFile) -> Pos:
def _parse_pos(
info: Dict[str, Any], lean_file: LeanFile
) -> Optional[Tuple[Optional[Pos], Optional[Pos]]]:
if "synthetic" in info and not info["synthetic"]["canonical"]:
return None

Expand Down
57 changes: 24 additions & 33 deletions src/lean_dojo/data_extraction/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
from pathlib import Path
from loguru import logger
from filelock import FileLock
from typing import Optional, Generator
from dataclasses import dataclass, field
from typing import Optional, Tuple, Generator

from ..utils import (
execute,
url_exists,
get_repo_info,
report_critical_failure,
)
from ..constants import (
Expand All @@ -23,22 +22,6 @@
)


def _split_git_url(url: str) -> Tuple[str, str]:
"""Split a Git URL into user name and repo name."""
if url.endswith("/"):
url = url[:-1]
assert not url.endswith("/"), f"Unexpected URL: {url}"
fields = url.split("/")
user_name = fields[-2]
repo_name = fields[-1]
return user_name, repo_name


def _format_dirname(url: str, commit: str) -> str:
    """Build the cache directory name ``<user>-<repo>-<commit>`` for a repo URL."""
    user, repo = _split_git_url(url)
    return "-".join((user, repo, commit))


_CACHE_CORRPUTION_MSG = "The cache may have been corrputed!"


Expand All @@ -59,16 +42,20 @@ def __post_init__(self):
lock_path = self.cache_dir.with_suffix(".lock")
object.__setattr__(self, "lock", FileLock(lock_path))

def get(self, url: str, commit: str) -> Optional[Path]:
"""Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found."""
_, repo_name = _split_git_url(url)
dirname = _format_dirname(url, commit)
def get(self, rel_cache_dir: Path) -> Optional[Path]:
"""Get the cache repo at ``CACHE_DIR / rel_cache_dir`` from the cache.
Args:
rel_cache_dir (Path): The relative path of the stored repo in the cache.
"""
dirname = rel_cache_dir.parent
dirpath = self.cache_dir / dirname
cache_path = self.cache_dir / rel_cache_dir

with self.lock:
if dirpath.exists():
assert (dirpath / repo_name).exists()
return dirpath / repo_name
assert cache_path.exists()
return cache_path

elif not DISABLE_REMOTE_CACHE:
url = os.path.join(REMOTE_CACHE_URL, f"{dirname}.tar.gz")
Expand All @@ -83,23 +70,27 @@ def get(self, url: str, commit: str) -> Optional[Path]:
with tarfile.open(f"{dirpath}.tar.gz") as tar:
tar.extractall(self.cache_dir)
os.remove(f"{dirpath}.tar.gz")
assert (dirpath / repo_name).exists()
assert (cache_path).exists()

return dirpath / repo_name
return cache_path

else:
return None

def store(self, src: Path) -> Path:
"""Store a traced repo at path ``src``. Return its path in the cache."""
url, commit = get_repo_info(src)
dirpath = self.cache_dir / _format_dirname(url, commit)
_, repo_name = _split_git_url(url)
def store(self, src: Path, rel_cache_dir: Path) -> Path:
"""Store a repo at path ``src``. Return its path in the cache.
Args:
src (Path): Path to the repo.
rel_cache_dir (Path): The relative path of the stored repo in the cache.
"""
dirpath = self.cache_dir / rel_cache_dir.parent
cache_path = self.cache_dir / rel_cache_dir
if not dirpath.exists():
with self.lock:
with report_critical_failure(_CACHE_CORRPUTION_MSG):
shutil.copytree(src, dirpath / repo_name)
return dirpath / repo_name
shutil.copytree(src, cache_path)
return cache_path


cache = Cache(CACHE_DIR)
Expand Down
Loading

0 comments on commit b9d2115

Please sign in to comment.