Skip to content

Commit

Permalink
Use filter arg in tarfile.extractall to prevent unsafe unarchival ope…
Browse files Browse the repository at this point in the history
…rations (#2722)

* use filter in tarfile.extractall

* update release notes

* update release notes action

* update docstring
  • Loading branch information
thehomebrewnerd authored May 10, 2024
1 parent 8001f77 commit 5b37bd8
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/release_notes_updated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ jobs:
- name: Check for development branch
id: branch
shell: python
env:
REF: ${{ github.event.pull_request.head.ref }}
run: |
from re import compile
main = '^main$'
Expand All @@ -19,7 +21,7 @@ jobs:
min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
regex = main, release, backport, dep_update, min_dep_update
patterns = list(map(compile, regex))
ref = "${{ github.event.pull_request.head.ref }}"
ref = "$REF"
is_dev = not any(pattern.match(ref) for pattern in patterns)
print('::set-output name=is_dev::' + str(is_dev))
- if: ${{ steps.branch.outputs.is_dev == 'true' }}
Expand Down
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Future Release
* Temporarily restrict Dask version (:pr:`2694`)
* Remove support for creating ``EntitySets`` from Dask or Pyspark dataframes (:pr:`2705`)
* Bump minimum versions of ``tqdm`` and ``pip`` in requirements files (:pr:`2716`)
* Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize EntitySets (:pr:`2722`)
* Documentation Changes
* Testing Changes
* Fix serialization test to work with pytest 8.1.1 (:pr:`2694`)
Expand Down
10 changes: 9 additions & 1 deletion featuretools/entityset/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import tarfile
import tempfile
from inspect import getfullargspec

import pandas as pd
import woodwork.type_sys.type_system as ww_type_system
Expand Down Expand Up @@ -140,6 +141,8 @@ def read_data_description(path):
def read_entityset(path, profile_name=None, **kwargs):
"""Read entityset from disk, S3 path, or URL.
NOTE: Never attempt to read an archived EntitySet from an untrusted source.
Args:
path (str): Directory on disk, S3 path, or URL to read `data_description.json`.
profile_name (str, bool): The AWS profile specified to write to S3. Will default to None and search for AWS credentials.
Expand All @@ -159,7 +162,12 @@ def read_entityset(path, profile_name=None, **kwargs):
use_smartopen_es(local_path, path, transport_params)

with tarfile.open(str(local_path)) as tar:
tar.extractall(path=tmpdir)
if "filter" in getfullargspec(tar.extractall).kwonlyargs:
tar.extractall(path=tmpdir, filter="data")
else:
raise RuntimeError(
"Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
)

data_description = read_data_description(tmpdir)
return description_to_entityset(data_description, **kwargs)
Expand Down
14 changes: 13 additions & 1 deletion featuretools/tests/entityset_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import tempfile
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from urllib.request import urlretrieve

import boto3
Expand Down Expand Up @@ -292,6 +292,18 @@ def test_deserialize_local_tar(es):
assert es.__eq__(new_es, deep=True)


@patch("featuretools.entityset.deserialize.getfullargspec")
def test_deserialize_errors_if_python_version_unsafe(mock_inspect, es):
mock_response = MagicMock()
mock_response.kwonlyargs = []
mock_inspect.return_value = mock_response
with tempfile.TemporaryDirectory() as tmp_path:
temp_tar_filepath = os.path.join(tmp_path, TEST_FILE)
urlretrieve(URL, filename=temp_tar_filepath)
with pytest.raises(RuntimeError, match=""):
deserialize.read_entityset(temp_tar_filepath)


def test_deserialize_url_csv(es):
new_es = deserialize.read_entityset(URL)
assert es.__eq__(new_es, deep=True)
Expand Down

0 comments on commit 5b37bd8

Please sign in to comment.