-
Notifications
You must be signed in to change notification settings - Fork 149
/
CheckpointReader.swift
196 lines (178 loc) · 10 KB
/
CheckpointReader.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
// Copyright 2021 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Checkpoints
import Foundation
import ModelSupport
import TensorFlow
/// Decodable configuration consumed by the checkpoint-based initializers in this file.
///
/// Property names differ from the serialized keys; see `CodingKeys` for the mapping.
public struct NetGConfig: Codable {
    /// Serialized as `i_channels`.
    public let inChannels: Int
    /// Serialized as `o_channels`.
    public let outChannels: Int
    /// Size of feature map. Serialized as `ngf`.
    public let ngf: Int
    /// Serialized as `useDrop`.
    public let useDropout: Bool
    /// Base filter count; `NetD` sizes its batch norms as multiples of this value.
    /// Serialized as `n_lastConvFilters`.
    public let lastConvFilters: Int
    /// Serialized as `lRate`.
    public let learningRate: Float
    /// Serialized as `beta`.
    public let beta: Float
    /// Optional explicit padding; when absent, `ConvLayer` falls back to `kernelSize / 2`.
    /// Serialized as `pad`.
    public let padding: Int?
    /// Serialized as `kSize`.
    public let kernelSize: Int

    /// Maps Swift property names onto the keys used in the serialized configuration.
    enum CodingKeys: String, CodingKey {
        case inChannels = "i_channels"
        case outChannels = "o_channels"
        case ngf  // implicit raw value "ngf"
        case useDropout = "useDrop"
        case lastConvFilters = "n_lastConvFilters"
        case learningRate = "lRate"
        case beta  // implicit raw value "beta"
        case padding = "pad"
        case kernelSize = "kSize"
    }
}
extension CheckpointReader {
    /// Loads the checkpoint entry called `name` and wraps it in a `Tensor`.
    ///
    /// - Parameter name: Entry name forwarded to `loadTensor(named:)`.
    /// - Returns: A tensor constructed from the loaded value.
    func readTensor<Scalar: TensorFlowScalar>(name: String) -> Tensor<Scalar> {
        Tensor<Scalar>(loadTensor(named: name))
    }

    /// Same as `readTensor(name:)`, but restricted to integer scalar types.
    ///
    /// - Parameter name: Entry name forwarded to `loadTensor(named:)`.
    /// - Returns: An integer tensor constructed from the loaded value.
    func readIntTensor<Element: TensorFlowInteger>(name: String) -> Tensor<Element> {
        Tensor<Element>(loadTensor(named: name))
    }
}
// TODO: Come up with better names for these protocols
/// A type that can be constructed from a checkpoint reader, a `NetGConfig`,
/// and a scope string.
protocol InitializableFromPythonCheckpoint2 {
/// Creates an instance from values obtained through `reader`.
///
/// - Parameters:
///   - reader: Checkpoint reader the conforming type loads its tensors from.
///   - config: Network configuration made available to the conforming type.
///   - scope: Conformers in this file use it as the prefix of the checkpoint
///     entry names they look up.
init(reader: CheckpointReader, config: NetGConfig, scope: String)
}
/// A type that can be constructed from a checkpoint plus an already-built
/// nested layer.
protocol InitializableFromPythonCheckpoint3 {
    /// The nested layer type: a `Layer` mapping `Tensor<Float>` to `Tensor<Float>`
    /// whose tangent vector has `Float` scalars.
    associatedtype Sublayer: Layer
    where
        Sublayer.Input == Tensor<Float>,
        Sublayer.Output == Tensor<Float>,
        Sublayer.TangentVector.VectorSpaceScalar == Float

    /// Creates an instance from values obtained through `reader`, wrapping `submodule`.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader the conforming type loads its tensors from.
    ///   - config: Network configuration made available to the conforming type.
    ///   - scope: Conformers in this file use it as the prefix of the checkpoint
    ///     entry names they look up.
    ///   - submodule: The already-constructed nested layer.
    init(reader: CheckpointReader, config: NetGConfig, scope: String, submodule: Sublayer)
}
extension ConvLayer: InitializableFromPythonCheckpoint2 {
    /// Builds the layer's convolution from the checkpoint and its zero padding from `config`.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the convolution.
    ///   - config: Supplies `padding` and `kernelSize` for the zero-padding layer.
    ///   - scope: Prefix of the checkpoint entries; the convolution is read from
    ///     `scope` + "/conv2d".
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        conv2d = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/conv2d")

        // Use the configured padding when present; otherwise half the kernel size.
        let border: Int
        if let explicit = config.padding {
            border = explicit
        } else {
            border = config.kernelSize / 2
        }

        // The same amount is applied on every side.
        let edge = (border, border)
        pad = ZeroPadding2D(padding: (edge, edge))
    }
}
extension Conv2D: InitializableFromPythonCheckpoint2 {
    /// Loads the filter and bias from the checkpoint and builds the convolution.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the tensors.
    ///   - config: Not consulted by this initializer.
    ///   - scope: Prefix of the entries; reads `scope` + "/fil" and `scope` + "/bias".
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        let kernel: Tensor<Scalar> = reader.readTensor(name: "\(scope)/fil")
        let offset: Tensor<Scalar> = reader.readTensor(name: "\(scope)/bias")
        // TODO: read/write activation, strides, padding, and dilations from checkpoint file
        // Until then the strides and padding mode are fixed here.
        self.init(filter: kernel, bias: offset, strides: (2, 2), padding: .same)
    }
}
extension BatchNorm: InitializableFromPythonCheckpoint2 {
    /// Loads every batch-norm parameter from the checkpoint.
    ///
    /// The axis, momentum and epsilon are stored as tensors; the first scalar of
    /// each is used. The remaining entries are passed through as tensors.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the tensors.
    ///   - config: Not consulted by this initializer.
    ///   - scope: Prefix of the entries ("/axis", "/mom", "/eps", "/off", "/sc",
    ///     "/rmean", "/rvar" are appended).
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        // Reads the entry stored under `scope` followed by `suffix`.
        func entry(_ suffix: String) -> Tensor<Scalar> {
            reader.readTensor(name: scope + suffix)
        }

        // Load the three hyperparameter tensors first, then unpack their first scalars.
        let axisTensor = entry("/axis")
        let momentumTensor = entry("/mom")
        let epsilonTensor = entry("/eps")

        let axisValue = Int(axisTensor.array.scalars[0])
        let momentumValue = momentumTensor.array.scalars[0]
        let epsilonValue = epsilonTensor.array.scalars[0]

        // Arguments are evaluated left to right, so the remaining entries are read
        // in the order "/off", "/sc", "/rmean", "/rvar".
        self.init(
            axis: axisValue,
            momentum: momentumValue,
            offset: entry("/off"),
            scale: entry("/sc"),
            epsilon: epsilonValue,
            runningMean: entry("/rmean"),
            runningVariance: entry("/rvar")
        )
    }
}
extension TransposedConv2D: InitializableFromPythonCheckpoint2 {
    /// Loads the filter and bias from the checkpoint and builds the transposed convolution.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the tensors.
    ///   - config: Not consulted by this initializer.
    ///   - scope: Prefix of the entries; reads `scope` + "/fil" and `scope` + "/bias".
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        let kernel: Tensor<Scalar> = reader.readTensor(name: "\(scope)/fil")
        let offset: Tensor<Scalar> = reader.readTensor(name: "\(scope)/bias")
        // TODO: read/write activation, strides, and padding from checkpoint file
        // Until then the strides and padding mode are fixed here.
        self.init(filter: kernel, bias: offset, strides: (2, 2), padding: .same)
    }
}
extension UNetSkipConnectionInnermost: InitializableFromPythonCheckpoint2 {
    /// Loads the block's sublayers from the checkpoint.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the sublayers.
    ///   - config: Forwarded unchanged to each sublayer initializer.
    ///   - scope: Prefix of the entries; sublayers are read from "/dc", "/un"
    ///     and "/uc" beneath it, in that order.
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        let down = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/dc")
        let norm = BatchNorm<Float>(reader: reader, config: config, scope: "\(scope)/un")
        let up = TransposedConv2D<Float>(reader: reader, config: config, scope: "\(scope)/uc")
        self.init(downConv: down, upConv: up, upNorm: norm)
    }
}
extension UNetSkipConnection: InitializableFromPythonCheckpoint3 {
    /// Loads the block's own sublayers from the checkpoint and wraps `submodule`.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the sublayers.
    ///   - config: Forwarded unchanged to each sublayer initializer.
    ///   - scope: Prefix of the entries; sublayers are read from "/dc", "/dn",
    ///     "/uc", "/un" and "/drop" beneath it, in that order.
    ///   - submodule: The already-constructed nested block.
    init(reader: CheckpointReader, config: NetGConfig, scope: String, submodule: Sublayer) {
        let down = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/dc")
        let downBatchNorm = BatchNorm<Float>(reader: reader, config: config, scope: "\(scope)/dn")
        let up = TransposedConv2D<Float>(reader: reader, config: config, scope: "\(scope)/uc")
        let upBatchNorm = BatchNorm<Float>(reader: reader, config: config, scope: "\(scope)/un")
        let drop = Dropout<Float>(reader: reader, config: config, scope: "\(scope)/drop")
        self.init(
            downConv: down,
            downNorm: downBatchNorm,
            upConv: up,
            upNorm: upBatchNorm,
            dropOut: drop,
            submodule: submodule
        )
    }
}
extension UNetSkipConnectionOutermost: InitializableFromPythonCheckpoint3 {
    /// Loads the outermost block's convolutions from the checkpoint and wraps `submodule`.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the sublayers.
    ///   - config: Forwarded unchanged to each sublayer initializer.
    ///   - scope: Prefix of the entries; sublayers are read from "/dc" and "/uc"
    ///     beneath it, in that order.
    ///   - submodule: The already-constructed nested block.
    init(reader: CheckpointReader, config: NetGConfig, scope: String, submodule: Sublayer) {
        let down = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/dc")
        let up = TransposedConv2D<Float>(reader: reader, config: config, scope: "\(scope)/uc")
        self.init(downConv: down, upConv: up, submodule: submodule)
    }
}
extension Dropout: InitializableFromPythonCheckpoint2 {
    /// Builds a dropout layer whose probability is read from the checkpoint.
    ///
    /// Traps with a descriptive message if element 0 of the stored tensor does not
    /// yield a scalar (previously this was an unexplained force-unwrap failure).
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the probability tensor.
    ///   - config: Not consulted by this initializer.
    ///   - scope: Prefix of the entry; the probability is read from `scope` + "/prob".
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        let probability: Tensor<Float> = reader.readTensor(name: scope + "/prob")
        // A malformed checkpoint is unrecoverable here (the initializer cannot fail),
        // so stop with a message that names the offending entry.
        guard let value = probability[0].scalar else {
            fatalError(
                "Checkpoint entry '\(scope)/prob' does not hold a scalar at index 0 "
                    + "(shape: \(probability.shape))."
            )
        }
        self.init(probability: Double(value))
    }
}
extension NetG: InitializableFromPythonCheckpoint2 {
    /// Rebuilds the generator from the checkpoint, from the innermost block outward.
    ///
    /// The innermost block sits seven "/submod" levels below `scope` + "/module";
    /// each enclosing block is read from one level higher, and the outermost block
    /// from `scope` + "/module" itself.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load every block.
    ///   - config: Forwarded unchanged to each block initializer.
    ///   - scope: Prefix of all checkpoint entries belonging to the generator.
    public init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        // Names for the progressively nested block types.
        typealias Depth0 = UNetSkipConnectionInnermost
        typealias Depth1 = UNetSkipConnection<Depth0>
        typealias Depth2 = UNetSkipConnection<Depth1>
        typealias Depth3 = UNetSkipConnection<Depth2>
        typealias Depth4 = UNetSkipConnection<Depth3>
        typealias Depth5 = UNetSkipConnection<Depth4>
        typealias Depth6 = UNetSkipConnection<Depth5>

        // Scope of the block nested `depth` levels below the outermost module.
        func nestedScope(_ depth: Int) -> String {
            scope + "/module" + String(repeating: "/submod", count: depth)
        }

        let core = Depth0(reader: reader, config: config, scope: nestedScope(7))
        let wrap1 = Depth1(reader: reader, config: config, scope: nestedScope(6), submodule: core)
        let wrap2 = Depth2(reader: reader, config: config, scope: nestedScope(5), submodule: wrap1)
        let wrap3 = Depth3(reader: reader, config: config, scope: nestedScope(4), submodule: wrap2)
        let wrap4 = Depth4(reader: reader, config: config, scope: nestedScope(3), submodule: wrap3)
        let wrap5 = Depth5(reader: reader, config: config, scope: nestedScope(2), submodule: wrap4)
        let wrap6 = Depth6(reader: reader, config: config, scope: nestedScope(1), submodule: wrap5)

        self.module = UNetSkipConnectionOutermost<Depth6>(
            reader: reader,
            config: config,
            scope: nestedScope(0),
            submodule: wrap6
        )
        // NOTE(review): `wrap6` was already passed as `submodule` above; this
        // reassignment looks redundant but is kept as in the original — confirm.
        module.submodule = wrap6
    }
}
// TODO: Convert this to suitable protocol with Discriminator configuration
extension NetD: InitializableFromPythonCheckpoint2 {
    /// Assembles the discriminator as two nested `Sequential` stacks.
    ///
    /// The three leading convolutions, both `ConvLayer`s and the final batch norm
    /// are loaded from the checkpoint. The two batch norms in the first stack are
    /// created from `config.lastConvFilters` and are not read from the checkpoint.
    ///
    /// - Parameters:
    ///   - reader: Checkpoint reader used to load the layers.
    ///   - config: Supplies `lastConvFilters` and is forwarded to the loaded layers.
    ///   - scope: Prefix of the checkpoint entries belonging to the discriminator.
    init(reader: CheckpointReader, config: NetGConfig, scope: String) {
        typealias Activation = Function<Tensor<Float>, Tensor<Float>>
        let baseFilters = config.lastConvFilters

        let firstConv = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/conv1")
        let firstActivation = Activation { leakyRelu($0) }

        let secondConv = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/conv2")
        // NOTE(review): created with default statistics, not loaded from the checkpoint.
        let secondNorm = BatchNorm<Float>(featureCount: 2 * baseFilters)
        let secondActivation = Activation { leakyRelu($0) }

        let thirdConv = Conv2D<Float>(reader: reader, config: config, scope: "\(scope)/conv3")
        // NOTE(review): created with default statistics, not loaded from the checkpoint.
        let thirdNorm = BatchNorm<Float>(featureCount: 4 * baseFilters)
        let thirdActivation = Activation { leakyRelu($0) }

        let backbone = Sequential {
            firstConv
            firstActivation
            secondConv
            secondNorm
            secondActivation
            thirdConv
            thirdNorm
            thirdActivation
        }

        // NOTE(review): unlike every other scope in this file, the two ConvLayer
        // scopes below have no "/" separator, and the batch norm is read from the
        // bare `scope`. Kept as in the original — confirm against the code that
        // writes the checkpoint.
        let discriminator = Sequential {
            backbone
            ConvLayer(reader: reader, config: config, scope: scope + "convLayer1")
            BatchNorm<Float>(reader: reader, config: config, scope: scope)
            Activation { leakyRelu($0) }
            ConvLayer(reader: reader, config: config, scope: scope + "convLayer2")
        }
        self.module = discriminator
    }
}