-
Notifications
You must be signed in to change notification settings - Fork 559
/
nets.py
144 lines (127 loc) · 8.18 KB
/
nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import division
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python.layers import utils
import numpy as np
# Range of disparity/inverse depth values.
# disp_net maps each sigmoid output s to DISP_SCALING * s + MIN_DISP, so the
# predicted disparities lie between MIN_DISP and DISP_SCALING + MIN_DISP.
DISP_SCALING = 10  # multiplier applied to the sigmoid output
MIN_DISP = 0.01  # additive offset that keeps the disparity above zero
def resize_like(inputs, ref):
    """Return `inputs` with the static height and width of `ref`.

    Axes 1 and 2 of both static shapes are compared. When they already
    agree, `inputs` is handed back untouched; otherwise it is resampled with
    nearest-neighbour interpolation to the axis-1/axis-2 size of `ref`.

    Args:
        inputs: tensor to be resized.
        ref: tensor whose axis-1 and axis-2 sizes define the target size.

    Returns:
        `inputs` itself when the sizes match, else the resized tensor.
    """
    src_shape = inputs.get_shape()
    dst_shape = ref.get_shape()
    dst_h, dst_w = dst_shape[1], dst_shape[2]
    already_matched = src_shape[1] == dst_h and src_shape[2] == dst_w
    if not already_matched:
        return tf.image.resize_nearest_neighbor(
            inputs, [dst_h.value, dst_w.value])
    return inputs
def pose_exp_net(tgt_image, src_image_stack, do_exp=True, is_training=True):
    """Build the joint pose / explainability-mask network.

    The target image and the source image stack are concatenated along
    axis 3 and fed through five stride-2 convolutions (cnv1..cnv5) that are
    shared by the two heads:

    * the pose head (cnv6, cnv7, pred) whose output is averaged over axes 1
      and 2 and reshaped to [-1, num_source, 6];
    * the optional mask head, a chain of transposed convolutions starting
      from cnv5 that emits one mask per decoder level.

    Args:
        tgt_image: target image tensor.
        src_image_stack: source images stacked along axis 3. The number of
            source views is taken as its axis-3 size divided by 3 (i.e. it
            is assumed that every source image contributes 3 channels).
        do_exp: when true, the mask head is built; otherwise all four mask
            entries in the result are None.
        is_training: not used by this function; kept so existing callers
            keep working.

    Returns:
        A tuple (pose_final, [mask1, mask2, mask3, mask4], end_points) where
        pose_final has shape [-1, num_source, 6], each mask has
        num_source * 2 channels (no activation applied), and end_points is
        the dict built from the layers' outputs collection.
    """
    inputs = tf.concat([tgt_image, src_image_stack], axis=3)
    num_source = int(src_image_stack.get_shape()[3].value // 3)
    with tf.variable_scope('pose_exp_net') as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                            normalizer_fn=None,
                            weights_regularizer=slim.l2_regularizer(0.05),
                            activation_fn=tf.nn.relu,
                            outputs_collections=end_points_collection):
            # cnv1 to cnv5 are shared between pose and explainability
            # prediction.
            cnv1 = slim.conv2d(inputs, 16, [7, 7], stride=2, scope='cnv1')
            cnv2 = slim.conv2d(cnv1, 32, [5, 5], stride=2, scope='cnv2')
            cnv3 = slim.conv2d(cnv2, 64, [3, 3], stride=2, scope='cnv3')
            cnv4 = slim.conv2d(cnv3, 128, [3, 3], stride=2, scope='cnv4')
            cnv5 = slim.conv2d(cnv4, 256, [3, 3], stride=2, scope='cnv5')

            # Pose specific layers
            with tf.variable_scope('pose'):
                cnv6 = slim.conv2d(cnv5, 256, [3, 3], stride=2, scope='cnv6')
                cnv7 = slim.conv2d(cnv6, 256, [3, 3], stride=2, scope='cnv7')
                # Linear 1x1 prediction: 6 values per source view.
                pose_pred = slim.conv2d(cnv7, 6 * num_source, [1, 1],
                                        scope='pred', stride=1,
                                        normalizer_fn=None,
                                        activation_fn=None)
                pose_avg = tf.reduce_mean(pose_pred, [1, 2])
                # Empirically we found that scaling by a small constant
                # facilitates training.
                pose_final = 0.01 * tf.reshape(pose_avg, [-1, num_source, 6])

            # Exp mask specific layers; the masks stay None unless the
            # decoder is requested.
            mask1 = mask2 = mask3 = mask4 = None
            if do_exp:
                with tf.variable_scope('exp'):
                    upcnv5 = slim.conv2d_transpose(
                        cnv5, 256, [3, 3], stride=2, scope='upcnv5')

                    upcnv4 = slim.conv2d_transpose(
                        upcnv5, 128, [3, 3], stride=2, scope='upcnv4')
                    mask4 = slim.conv2d(upcnv4, num_source * 2, [3, 3],
                                        stride=1, scope='mask4',
                                        normalizer_fn=None,
                                        activation_fn=None)

                    upcnv3 = slim.conv2d_transpose(
                        upcnv4, 64, [3, 3], stride=2, scope='upcnv3')
                    mask3 = slim.conv2d(upcnv3, num_source * 2, [3, 3],
                                        stride=1, scope='mask3',
                                        normalizer_fn=None,
                                        activation_fn=None)

                    upcnv2 = slim.conv2d_transpose(
                        upcnv3, 32, [5, 5], stride=2, scope='upcnv2')
                    mask2 = slim.conv2d(upcnv2, num_source * 2, [5, 5],
                                        stride=1, scope='mask2',
                                        normalizer_fn=None,
                                        activation_fn=None)

                    upcnv1 = slim.conv2d_transpose(
                        upcnv2, 16, [7, 7], stride=2, scope='upcnv1')
                    mask1 = slim.conv2d(upcnv1, num_source * 2, [7, 7],
                                        stride=1, scope='mask1',
                                        normalizer_fn=None,
                                        activation_fn=None)

            end_points = utils.convert_collection_to_dict(
                end_points_collection)
            return pose_final, [mask1, mask2, mask3, mask4], end_points
def disp_net(tgt_image, is_training=True):
    """Build the multi-scale disparity (inverse depth) network.

    An encoder of seven stride-2 stages (each followed by a stride-1
    convolution) is mirrored by a decoder of transposed convolutions. Every
    decoder level is concatenated along axis 3 with the encoder feature map
    of matching resolution, and the four finest levels each emit a
    one-channel disparity map computed as
    DISP_SCALING * sigmoid(conv) + MIN_DISP. Each coarser disparity map is
    bilinearly resized and fed into the next finer level.

    Args:
        tgt_image: input image tensor; its static axis-1 and axis-2 sizes
            (H and W) must be known because they are used to compute the
            resize targets.
        is_training: not used by this function; kept so existing callers
            keep working.

    Returns:
        A tuple ([disp1, disp2, disp3, disp4], end_points), ordered from the
        finest to the coarsest prediction, where end_points is the dict
        built from the layers' outputs collection.
    """
    H = tgt_image.get_shape()[1].value
    W = tgt_image.get_shape()[2].value
    with tf.variable_scope('depth_net') as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                            normalizer_fn=None,
                            weights_regularizer=slim.l2_regularizer(0.05),
                            activation_fn=tf.nn.relu,
                            outputs_collections=end_points_collection):
            # Encoder
            cnv1 = slim.conv2d(tgt_image, 32, [7, 7], stride=2, scope='cnv1')
            cnv1b = slim.conv2d(cnv1, 32, [7, 7], stride=1, scope='cnv1b')
            cnv2 = slim.conv2d(cnv1b, 64, [5, 5], stride=2, scope='cnv2')
            cnv2b = slim.conv2d(cnv2, 64, [5, 5], stride=1, scope='cnv2b')
            cnv3 = slim.conv2d(cnv2b, 128, [3, 3], stride=2, scope='cnv3')
            cnv3b = slim.conv2d(cnv3, 128, [3, 3], stride=1, scope='cnv3b')
            cnv4 = slim.conv2d(cnv3b, 256, [3, 3], stride=2, scope='cnv4')
            cnv4b = slim.conv2d(cnv4, 256, [3, 3], stride=1, scope='cnv4b')
            cnv5 = slim.conv2d(cnv4b, 512, [3, 3], stride=2, scope='cnv5')
            cnv5b = slim.conv2d(cnv5, 512, [3, 3], stride=1, scope='cnv5b')
            cnv6 = slim.conv2d(cnv5b, 512, [3, 3], stride=2, scope='cnv6')
            cnv6b = slim.conv2d(cnv6, 512, [3, 3], stride=1, scope='cnv6b')
            cnv7 = slim.conv2d(cnv6b, 512, [3, 3], stride=2, scope='cnv7')
            cnv7b = slim.conv2d(cnv7, 512, [3, 3], stride=1, scope='cnv7b')

            # Decoder
            upcnv7 = slim.conv2d_transpose(
                cnv7b, 512, [3, 3], stride=2, scope='upcnv7')
            # There might be dimension mismatch due to uneven
            # down/up-sampling
            upcnv7 = resize_like(upcnv7, cnv6b)
            i7_in = tf.concat([upcnv7, cnv6b], axis=3)
            icnv7 = slim.conv2d(i7_in, 512, [3, 3], stride=1, scope='icnv7')

            upcnv6 = slim.conv2d_transpose(
                icnv7, 512, [3, 3], stride=2, scope='upcnv6')
            upcnv6 = resize_like(upcnv6, cnv5b)
            i6_in = tf.concat([upcnv6, cnv5b], axis=3)
            icnv6 = slim.conv2d(i6_in, 512, [3, 3], stride=1, scope='icnv6')

            upcnv5 = slim.conv2d_transpose(
                icnv6, 256, [3, 3], stride=2, scope='upcnv5')
            upcnv5 = resize_like(upcnv5, cnv4b)
            i5_in = tf.concat([upcnv5, cnv4b], axis=3)
            icnv5 = slim.conv2d(i5_in, 256, [3, 3], stride=1, scope='icnv5')

            upcnv4 = slim.conv2d_transpose(
                icnv5, 128, [3, 3], stride=2, scope='upcnv4')
            i4_in = tf.concat([upcnv4, cnv3b], axis=3)
            icnv4 = slim.conv2d(i4_in, 128, [3, 3], stride=1, scope='icnv4')
            disp4 = DISP_SCALING * slim.conv2d(
                icnv4, 1, [3, 3], stride=1, activation_fn=tf.sigmoid,
                normalizer_fn=None, scope='disp4') + MIN_DISP
            # Plain integer floor division replaces np.int(H/4): np.int was
            # only an alias of the builtin int and no longer exists in
            # recent NumPy releases.
            disp4_up = tf.image.resize_bilinear(disp4, [H // 4, W // 4])

            upcnv3 = slim.conv2d_transpose(
                icnv4, 64, [3, 3], stride=2, scope='upcnv3')
            i3_in = tf.concat([upcnv3, cnv2b, disp4_up], axis=3)
            icnv3 = slim.conv2d(i3_in, 64, [3, 3], stride=1, scope='icnv3')
            disp3 = DISP_SCALING * slim.conv2d(
                icnv3, 1, [3, 3], stride=1, activation_fn=tf.sigmoid,
                normalizer_fn=None, scope='disp3') + MIN_DISP
            disp3_up = tf.image.resize_bilinear(disp3, [H // 2, W // 2])

            upcnv2 = slim.conv2d_transpose(
                icnv3, 32, [3, 3], stride=2, scope='upcnv2')
            i2_in = tf.concat([upcnv2, cnv1b, disp3_up], axis=3)
            icnv2 = slim.conv2d(i2_in, 32, [3, 3], stride=1, scope='icnv2')
            disp2 = DISP_SCALING * slim.conv2d(
                icnv2, 1, [3, 3], stride=1, activation_fn=tf.sigmoid,
                normalizer_fn=None, scope='disp2') + MIN_DISP
            disp2_up = tf.image.resize_bilinear(disp2, [H, W])

            upcnv1 = slim.conv2d_transpose(
                icnv2, 16, [3, 3], stride=2, scope='upcnv1')
            i1_in = tf.concat([upcnv1, disp2_up], axis=3)
            icnv1 = slim.conv2d(i1_in, 16, [3, 3], stride=1, scope='icnv1')
            disp1 = DISP_SCALING * slim.conv2d(
                icnv1, 1, [3, 3], stride=1, activation_fn=tf.sigmoid,
                normalizer_fn=None, scope='disp1') + MIN_DISP

            end_points = utils.convert_collection_to_dict(
                end_points_collection)
            return [disp1, disp2, disp3, disp4], end_points