forked from openvinotoolkit/training_extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
138 lines (109 loc) · 4.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
MIT License
Copyright (c) 2018 Kaiyang Zhou
"""
"""
Copyright (c) 2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import sys
import os
import os.path as osp
import time
import argparse
import torch
import torch.nn as nn
from config.default_config import (
get_default_config, imagedata_kwargs, videodata_kwargs,
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
)
import torchreid
from torchreid.utils import (
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
load_pretrained_weights, compute_model_complexity, collect_env_info
)
from data.datamanager import ImageDataManagerWithTransforms
from engine.builder import build_engine
from engine.schedulers.lr_scheduler import build_lr_scheduler
from models.builder import build_model
def build_datamanager(cfg):
    """Create the data manager matching ``cfg.data.type``.

    Args:
        cfg: project config; only ``cfg.data.type`` is inspected here, the rest
            is forwarded through ``imagedata_kwargs`` / ``videodata_kwargs``.

    Returns:
        ``ImageDataManagerWithTransforms`` when ``cfg.data.type == 'image'``;
        for any other value, ``torchreid.data.VideoDataManager``.
    """
    is_image_data = cfg.data.type == 'image'
    if not is_image_data:
        # Every non-'image' type is routed to the video pipeline.
        return torchreid.data.VideoDataManager(**videodata_kwargs(cfg))
    return ImageDataManagerWithTransforms(**imagedata_kwargs(cfg))
def reset_config(cfg, args):
    """Copy data-related command-line overrides onto ``cfg.data`` in place.

    For each of ``root``, ``sources`` and ``targets``, the value from ``args``
    replaces the same-named entry in ``cfg.data`` only when it is truthy, so
    an empty string, ``None`` or an empty list leaves the config untouched.

    Args:
        cfg: config object whose ``data`` node is modified.
        args: parsed command-line namespace providing the override values.
    """
    for option in ('root', 'sources', 'targets'):
        override = getattr(args, option)
        if override:
            setattr(cfg.data, option, override)
def _parse_args():
    """Build the command-line parser and return the parsed namespace."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config-file', type=str, default='', help='path to config file')
    parser.add_argument('-s', '--sources', type=str, nargs='+', help='source datasets (delimited by space)')
    parser.add_argument('-t', '--targets', type=str, nargs='+', help='target datasets (delimited by space)')
    parser.add_argument('--root', type=str, default='', help='path to data root')
    # Everything after the named options is forwarded to cfg.merge_from_list().
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                        help='Modify config options using the command-line')
    return parser.parse_args()


def _load_config(args):
    """Assemble and freeze the run config from defaults and CLI input.

    Precedence, lowest to highest: built-in defaults, ``--config-file``,
    ``--root/--sources/--targets``, then trailing ``opts`` key/value pairs.
    """
    cfg = get_default_config()
    # Set before merging so a config file or opts can still override it.
    cfg.use_gpu = torch.cuda.is_available()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_config(cfg, args)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def main():
    """Script entry point: build data, model and optimizer, then run the engine.

    Whether the engine trains or only evaluates is decided by the config
    (``cfg.test.evaluate`` also selects the log file name).
    """
    args = _parse_args()
    cfg = _load_config(args)
    set_random_seed(cfg.train.seed)

    log_name = 'test.log' if cfg.test.evaluate else 'train.log'
    log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
    # Logger presumably tees stdout into this file -- confirm in torchreid.utils.
    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))
    print('Show configuration\n{}\n'.format(cfg))
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))
    if cfg.use_gpu:
        torch.backends.cudnn.benchmark = True

    datamanager = build_datamanager(cfg)

    print('Building model: {}'.format(cfg.model.name))
    model = build_model(
        name=cfg.model.name,
        num_classes=datamanager.num_train_pids,
        loss=cfg.loss.name,
        pretrained=cfg.model.pretrained,
        use_gpu=cfg.use_gpu,
        dropout_cfg=cfg.model.dropout,
        feature_dim=cfg.model.feature_dim,
        fpn_cfg=cfg.model.fpn,
        pooling_type=cfg.model.pooling_type,
        input_size=(cfg.data.height, cfg.data.width),
        IN_first=cfg.model.IN_first,
        extra_blocks=cfg.model.extra_blocks
    )
    num_params, flops = compute_model_complexity(model, (1, 3, cfg.data.height, cfg.data.width))
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if cfg.model.load_weights and check_isfile(cfg.model.load_weights):
        load_pretrained_weights(model, cfg.model.load_weights)

    if cfg.use_gpu:
        model = nn.DataParallel(model).cuda()

    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
    scheduler = build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))

    if cfg.model.resume and check_isfile(cfg.model.resume):
        start_epoch = resume_from_checkpoint(cfg.model.resume, model, optimizer=optimizer)
        # Kept for backward compatibility; nothing in this script reads it.
        args.start_epoch = start_epoch
        # engine.run() receives its arguments from engine_run_kwargs(cfg), so the
        # resumed epoch must be stored in cfg; previously it was kept only on
        # `args` and silently discarded. Assumes the key is cfg.train.start_epoch
        # -- confirm in config/default_config.py. cfg is frozen, so unlock briefly.
        cfg.defrost()
        cfg.train.start_epoch = start_epoch
        cfg.freeze()

    if cfg.model.openvino.name:
        # Imported lazily so the OpenVINO runtime is only required when used.
        from models.openvino_wrapper import OpenVINOModel
        openvino_model = OpenVINOModel(cfg.model.openvino.name, cfg.model.openvino.cpu_extension)
    else:
        openvino_model = None

    print('Building {}-engine for {}-reid'.format(cfg.loss.name, cfg.data.type))
    engine = build_engine(cfg, datamanager, model, optimizer, scheduler, openvino_model=openvino_model)
    engine.run(**engine_run_kwargs(cfg))


if __name__ == '__main__':
    main()