-
Notifications
You must be signed in to change notification settings - Fork 313
/
common.hpp
523 lines (437 loc) · 21.4 KB
/
common.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
#ifndef __COMMON_HPP__
#define __COMMON_HPP__
#include "ggml_extend.hpp"
class DownSampleBlock : public GGMLBlock {
protected:
int channels;
int out_channels;
bool vae_downsample;
public:
DownSampleBlock(int channels,
int out_channels,
bool vae_downsample = false)
: channels(channels),
out_channels(out_channels),
vae_downsample(vae_downsample) {
if (vae_downsample) {
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
} else {
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
if (vae_downsample) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = conv->forward(ctx, x);
} else {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
x = conv->forward(ctx, x);
}
return x; // [N, out_channels, h/2, w/2]
}
};
class UpSampleBlock : public GGMLBlock {
protected:
int channels;
int out_channels;
public:
UpSampleBlock(int channels,
int out_channels)
: channels(channels),
out_channels(out_channels) {
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
return x;
}
};
class ResBlock : public GGMLBlock {
protected:
// network hparams
int64_t channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4)
int64_t emb_channels; // time_embed_dim
int64_t out_channels; // mult * model_channels
std::pair<int, int> kernel_size;
int dims;
bool skip_t_emb;
bool exchange_temb_dims;
std::shared_ptr<GGMLBlock> conv_nd(int dims,
int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
std::pair<int, int> padding) {
GGML_ASSERT(dims == 2 || dims == 3);
if (dims == 3) {
return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first));
} else {
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
}
}
public:
ResBlock(int64_t channels,
int64_t emb_channels,
int64_t out_channels,
std::pair<int, int> kernel_size = {3, 3},
int dims = 2,
bool exchange_temb_dims = false,
bool skip_t_emb = false)
: channels(channels),
emb_channels(emb_channels),
out_channels(out_channels),
kernel_size(kernel_size),
dims(dims),
skip_t_emb(skip_t_emb),
exchange_temb_dims(exchange_temb_dims) {
std::pair<int, int> padding = {kernel_size.first / 2, kernel_size.second / 2};
blocks["in_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(channels));
// in_layer_1 is nn.SILU()
blocks["in_layers.2"] = conv_nd(dims, channels, out_channels, kernel_size, padding);
if (!skip_t_emb) {
// emb_layer_0 is nn.SILU()
blocks["emb_layers.1"] = std::shared_ptr<GGMLBlock>(new Linear(emb_channels, out_channels));
}
blocks["out_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
// out_layer_1 is nn.SILU()
// out_layer_2 is nn.Dropout(), skip for inference
blocks["out_layers.3"] = conv_nd(dims, out_channels, out_channels, kernel_size, padding);
if (out_channels != channels) {
blocks["skip_connection"] = conv_nd(dims, channels, out_channels, {1, 1}, {0, 0});
}
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = NULL) {
// For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml
// [N, c, t, h, w] => [N, c, t, h * w]
// x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w]
// emb: [N, emb_channels] if dims == 2 else [N, t, emb_channels]
auto in_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["in_layers.0"]);
auto in_layers_2 = std::dynamic_pointer_cast<UnaryBlock>(blocks["in_layers.2"]);
auto out_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out_layers.0"]);
auto out_layers_3 = std::dynamic_pointer_cast<UnaryBlock>(blocks["out_layers.3"]);
if (emb == NULL) {
GGML_ASSERT(skip_t_emb);
}
// in_layers
auto h = in_layers_0->forward(ctx, x);
h = ggml_silu_inplace(ctx, h);
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
// emb_layers
if (!skip_t_emb) {
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
auto emb_out = ggml_silu(ctx, emb);
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels]
if (dims == 2) {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
} else {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
if (exchange_temb_dims) {
// emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
}
}
h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
// out_layers
h = out_layers_0->forward(ctx, h);
h = ggml_silu_inplace(ctx, h);
// dropout, skip for inference
h = out_layers_3->forward(ctx, h);
// skip connection
if (out_channels != channels) {
auto skip_connection = std::dynamic_pointer_cast<UnaryBlock>(blocks["skip_connection"]);
x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
h = ggml_add(ctx, h, x);
return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
};
class GEGLU : public GGMLBlock {
protected:
int64_t dim_in;
int64_t dim_out;
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
}
public:
GEGLU(int64_t dim_in, int64_t dim_out)
: dim_in(dim_in), dim_out(dim_out) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [ne3, ne2, ne1, dim_in]
// return: [ne3, ne2, ne1, dim_out]
struct ggml_tensor* w = params["proj.weight"];
struct ggml_tensor* b = params["proj.bias"];
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
gate = ggml_gelu_inplace(ctx, gate);
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out]
return x;
}
};
class FeedForward : public GGMLBlock {
public:
FeedForward(int64_t dim,
int64_t dim_out,
int64_t mult = 4) {
int64_t inner_dim = dim * mult;
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
// net_1 is nn.Dropout(), skip for inference
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [ne3, ne2, ne1, dim]
// return: [ne3, ne2, ne1, dim_out]
auto net_0 = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]);
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
return x;
}
};
class CrossAttention : public GGMLBlock {
protected:
int64_t query_dim;
int64_t context_dim;
int64_t n_head;
int64_t d_head;
bool flash_attn;
public:
CrossAttention(int64_t query_dim,
int64_t context_dim,
int64_t n_head,
int64_t d_head,
bool flash_attn = false)
: n_head(n_head),
d_head(d_head),
query_dim(query_dim),
context_dim(context_dim),
flash_attn(flash_attn) {
int64_t inner_dim = d_head * n_head;
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, query_dim));
// to_out_1 is nn.Dropout(), skip for inference
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
int64_t n = x->ne[2];
int64_t n_token = x->ne[1];
int64_t n_context = context->ne[1];
int64_t inner_dim = d_head * n_head;
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
}
};
class BasicTransformerBlock : public GGMLBlock {
protected:
int64_t n_head;
int64_t d_head;
bool ff_in;
public:
BasicTransformerBlock(int64_t dim,
int64_t n_head,
int64_t d_head,
int64_t context_dim,
bool ff_in = false,
bool flash_attn = false)
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
// disable_self_attn is always False
// disable_temporal_crossattention is always False
// switch_temporal_ca_to_sa is always False
// inner_dim is always None or equal to dim
// gated_ff is always True
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
if (ff_in) {
blocks["norm_in"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["ff_in"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
auto attn1 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn1"]);
auto attn2 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn2"]);
auto ff = std::dynamic_pointer_cast<FeedForward>(blocks["ff"]);
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto norm3 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm3"]);
if (ff_in) {
auto norm_in = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_in"]);
auto ff_in = std::dynamic_pointer_cast<FeedForward>(blocks["ff_in"]);
auto x_skip = x;
x = norm_in->forward(ctx, x);
x = ff_in->forward(ctx, x);
// self.is_res is always True
x = ggml_add(ctx, x, x_skip);
}
auto r = x;
x = norm1->forward(ctx, x);
x = attn1->forward(ctx, x, x); // self-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm2->forward(ctx, x);
x = attn2->forward(ctx, x, context); // cross-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm3->forward(ctx, x);
x = ff->forward(ctx, x);
x = ggml_add(ctx, x, r);
return x;
}
};
class SpatialTransformer : public GGMLBlock {
protected:
int64_t in_channels; // mult * model_channels
int64_t n_head;
int64_t d_head;
int64_t depth = 1; // 1
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
public:
SpatialTransformer(int64_t in_channels,
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim,
bool flash_attn = false)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
depth(depth),
context_dim(context_dim) {
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
// disable_self_attn is always False
int64_t inner_dim = n_head * d_head; // in_channels
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, in_channels, h, w]
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto x_in = x;
int64_t n = x->ne[3];
int64_t h = x->ne[1];
int64_t w = x->ne[0];
int64_t inner_dim = n_head * d_head;
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
x = transformer_block->forward(ctx, x, context);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_add(ctx, x, x_in);
return x;
}
};
class AlphaBlender : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
// Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
}
float get_alpha() {
// image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,]
// so learned_with_images is same as learned
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha);
}
public:
AlphaBlender() {
// merge_strategy is always learned_with_images
// for inference, we don't need to set alpha
// since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x_spatial,
struct ggml_tensor* x_temporal) {
// image_only_indicator is always tensor([0.])
float alpha = get_alpha();
auto x = ggml_add(ctx,
ggml_scale(ctx, x_spatial, alpha),
ggml_scale(ctx, x_temporal, 1.0f - alpha));
return x;
}
};
class VideoResBlock : public ResBlock {
public:
VideoResBlock(int channels,
int emb_channels,
int out_channels,
std::pair<int, int> kernel_size = {3, 3},
int64_t video_kernel_size = 3,
int dims = 2) // always 2
: ResBlock(channels, emb_channels, out_channels, kernel_size, dims) {
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, emb_channels, out_channels, kernel_size, 3, true));
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* emb,
int num_video_frames) {
// x: [N, channels, h, w] aka [b*t, channels, h, w]
// emb: [N, emb_channels] aka [b*t, emb_channels]
// image_only_indicator is always tensor([0.])
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
auto time_mixer = std::dynamic_pointer_cast<AlphaBlender>(blocks["time_mixer"]);
x = ResBlock::forward(ctx, x, emb);
int64_t T = num_video_frames;
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
x = time_stack->forward(ctx, x, emb); // b t c (h w)
x = time_mixer->forward(ctx, x_mix, x); // b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}
};
#endif // __COMMON_HPP__