-
Notifications
You must be signed in to change notification settings - Fork 25
/
shrink.py
38 lines (30 loc) · 1.03 KB
/
shrink.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding: utf-8 -*-
"""Shrink model, and save and run.
- Author: Junghoon Kim
- Email: [email protected]
"""
import argparse
import os
import shutil
from src.runners import initialize
from src.runners.shrinker import Shrinker
# Command-line arguments: the GPU to use, the checkpoint to load, and the
# configuration file that drives the run.
parser = argparse.ArgumentParser(description="Model shrinker.")
parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
parser.add_argument("--checkpoint", type=str, help="input checkpoint path to quantize")
parser.add_argument("--config", type=str, help="Pruning configuration path")
args = parser.parse_args()

# Get config and directory path prefix for logging (plus the device returned
# by `initialize`).
config, dir_prefix, device = initialize(
    mode="shrink", config_path=args.config, gpu_id=args.gpu
)

# Validate the checkpoint argument explicitly instead of with `assert`:
# assertions are removed when Python runs with -O, which would let a missing
# path fall through to `shutil.copyfile`. `parser.error` prints the usage
# message and exits with status 2.
if not args.checkpoint:
    parser.error("--checkpoint required")
if not os.path.isfile(args.checkpoint):
    # `shutil.copyfile` only accepts regular files, so reject directories and
    # non-existent paths here with a message that names the offending path.
    parser.error(
        "--checkpoint is not an existing file: {}".format(args.checkpoint)
    )

# Keep a copy of the input checkpoint next to the run's outputs.
shutil.copyfile(args.checkpoint, os.path.join(dir_prefix, "orig_model.pth.tar"))

# Build the shrinker from the loaded config and checkpoint, then run it.
shrinker = Shrinker(
    config=config,
    checkpoint_path=args.checkpoint,
    dir_prefix=dir_prefix,
    device=device,
)
shrinker.run()