Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snapshotter for torch #2270

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions src/garage/experiment/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import errno
import os
import pathlib
import sys

import cloudpickle
from dowel import logger

# pylint: disable=no-name-in-module

SnapshotConfig = collections.namedtuple(
'SnapshotConfig', ['snapshot_dir', 'snapshot_mode', 'snapshot_gap'])
Expand Down Expand Up @@ -82,6 +86,7 @@ def snapshot_gap(self):
"""
return self._snapshot_gap

# pylint: disable=too-many-branches
def save_itr_params(self, itr, params):
"""Save the parameters if at the right iteration.

Expand All @@ -94,8 +99,13 @@ def save_itr_params(self, itr, params):
"gap_overwrite", "gap_and_last", or "none".

"""
# pylint: disable=import-outside-toplevel
torch = False
if 'torch' in sys.modules:
import torch
from garage.torch import global_device
file_name = None

# pylint: enable=import-outside-toplevel
if self._snapshot_mode == 'all':
file_name = os.path.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
elif self._snapshot_mode == 'gap_overwrite':
Expand All @@ -113,17 +123,32 @@ def save_itr_params(self, itr, params):
file_name = os.path.join(self._snapshot_dir,
'itr_%d.pkl' % itr)
file_name_last = os.path.join(self._snapshot_dir, 'params.pkl')
with open(file_name_last, 'wb') as file:
cloudpickle.dump(params, file)
if torch:
torch.save(params, file_name_last, pickle_module=cloudpickle)
else:
with open(file_name_last, 'wb') as file:
cloudpickle.dump(params, file)
elif self._snapshot_mode == 'none':
pass
else:
raise ValueError('Invalid snapshot mode {}'.format(
self._snapshot_mode))

if file_name:
with open(file_name, 'wb') as file:
cloudpickle.dump(params, file)
if torch:

class _pickle_module:
dump = cloudpickle.dump
Pickler = cloudpickle.CloudPickler

params['global_device'] = global_device()
torch.save(params,
file_name,
pickle_module=_pickle_module,
_use_new_zipfile_serialization=False)
else:
with open(file_name, 'wb') as file:
cloudpickle.dump(params, file)

def load(self, load_dir, itr='last'):
# pylint: disable=no-self-use
Expand All @@ -145,6 +170,12 @@ def load(self, load_dir, itr='last'):
NotAFileError: If the snapshot exists but is not a file.

"""
torch = False
# pylint: disable=import-outside-toplevel
if 'torch' in sys.modules:
import torch
from garage.torch import global_device
# pylint: enable=import-outside-toplevel
if isinstance(itr, int) or itr.isdigit():
load_from_file = os.path.join(load_dir, 'itr_{}.pkl'.format(itr))
else:
Expand All @@ -165,7 +196,13 @@ def load(self, load_dir, itr='last'):

if not os.path.isfile(load_from_file):
raise NotAFileError('File not existing: ', load_from_file)

if torch:
device = global_device()
params = torch.load(load_from_file, map_location=device)
origin_device = params['global_device']
del params['global_device']
logger.log(f'Resuming experiment from {origin_device} on {device}')
return params
with open(load_from_file, 'rb') as file:
return cloudpickle.load(file)

Expand Down
9 changes: 5 additions & 4 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Test fixtures."""
# yapf: disable
from tests.fixtures.fixtures import (snapshot_config,
TfGraphTestCase,
TfTestCase)
from tests.fixtures.fixtures import (reset_gpu_mode, snapshot_config,
TfGraphTestCase, TfTestCase)

# yapf: enable

__all__ = ['snapshot_config', 'TfGraphTestCase', 'TfTestCase']
__all__ = [
'reset_gpu_mode', 'snapshot_config', 'TfGraphTestCase', 'TfTestCase'
]
6 changes: 6 additions & 0 deletions tests/fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from garage.experiment import deterministic
from garage.experiment.snapshotter import SnapshotConfig
from garage.torch import set_gpu_mode

from tests.fixtures.logger import NullOutput

Expand Down Expand Up @@ -64,3 +65,8 @@ def teardown_method(self):
del self.graph
del self.sess
gc.collect()


def reset_gpu_mode():
"""Reset mode to CPU after test."""
set_gpu_mode(False)
4 changes: 3 additions & 1 deletion tests/garage/experiment/test_snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class TestSnapshotter:

def setup_method(self):
# pylint: disable=consider-using-with
self.temp_dir = tempfile.TemporaryDirectory()

def teardown_method(self):
Expand All @@ -44,6 +45,7 @@ def test_snapshotter(self, mode, files):
assert osp.exists(filename)
with open(filename, 'rb') as pkl_file:
data = pickle.load(pkl_file)
snapshot_data[num]['global_device'] = None
assert data == snapshot_data[num]

def test_gap_overwrite(self):
Expand All @@ -60,7 +62,7 @@ def test_gap_overwrite(self):
assert osp.exists(filename)
with open(filename, 'rb') as pkl_file:
data = pickle.load(pkl_file)
assert data == snapshot_data[1]
assert data == {'global_device': None, 'testparam': 4}

def test_invalid_snapshot_mode(self):
with pytest.raises(ValueError):
Expand Down
Loading