Hi, this is my first time working with a transformer model, in this case a 3D vision transformer.

I am working on a 3D medical image classification task with a training set of around 300 3D images. Each input image has shape (1, 224, 224, 32), where 1 is the number of channels and 32 is the size of the z dimension. I trained a 3D EfficientNet on this dataset and reached around 80% accuracy. I then tried a 3D vision transformer, but the model does not converge. Could you please review the code below? Why does the model not learn, and am I doing something wrong? Any help or suggestions would be appreciated. Thank you in advance.
This is the forward pass:

```python
import torch
from einops import repeat


def forward(self, img):
    # img: (batch, 1, 224, 224, 32)
    print("img, input shape before patch embedding", img.shape)
    x = self.to_patch_embedding(img)  # (batch, num_patches, dim)
    print("after patch embedding", x.shape)
    b, n, _ = x.shape

    # Prepend one learnable class token per sample
    # cls_tokens = self.cls_token.expand(b, -1, -1)
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
    print("cls token shape", cls_tokens.shape)
    x = torch.cat((cls_tokens, x), dim=1)  # (batch, num_patches + 1, dim)
    print("after cls_token", x.shape)

    # Add learned positional embeddings for the class token and every patch
    x += self.pos_embedding[:, :(n + 1)]
    print("after position embedding", x.shape)
    x = self.dropout(x)

    x = self.transformer(x)
    print("after transformer", x.shape)

    # Pool by averaging all tokens or by taking the class token
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    print("after latent", x.shape)
    return self.mlp_head(x)  # (batch, num_classes)
```
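To compare against the printed shapes, here is a small pure-Python sketch that computes the shape each `print` in the forward pass should report. It assumes the input is laid out as (channels, height, width, depth) as described above, a non-overlapping patch embedding whose patches exactly tile the volume, and `to_latent` leaving the width unchanged. The patch size, embedding width and class count in the example are hypothetical values, not taken from the original configuration, and `trace_vit3d_shapes` is a made-up helper name.

```python
from typing import Dict, Tuple


def trace_vit3d_shapes(
    batch: int,
    image_shape: Tuple[int, int, int, int],  # (C, H, W, D), e.g. (1, 224, 224, 32)
    patch_shape: Tuple[int, int, int],       # hypothetical (pH, pW, pD)
    dim: int,                                # hypothetical token embedding width
    num_classes: int,                        # hypothetical number of classes
) -> Dict[str, Tuple[int, ...]]:
    """Expected shape after each step of a cls-token 3D ViT forward pass.

    Assumes non-overlapping patches that exactly tile the volume.
    """
    c, h, w, d = image_shape
    ph, pw, pd = patch_shape
    for size, p in ((h, ph), (w, pw), (d, pd)):
        if size % p != 0:
            raise ValueError(f"size {size} is not divisible by patch size {p}")

    n = (h // ph) * (w // pw) * (d // pd)  # number of patch tokens
    patch_dim = c * ph * pw * pd           # values per flattened patch

    return {
        "input": (batch, c, h, w, d),
        "flattened patches": (batch, n, patch_dim),
        "after patch embedding": (batch, n, dim),
        "cls token shape": (batch, 1, dim),
        "after cls_token": (batch, n + 1, dim),
        # pos_embedding needs at least n + 1 positions for [:, :(n + 1)]
        "after position embedding": (batch, n + 1, dim),
        "after transformer": (batch, n + 1, dim),
        # mean pooling and cls-token pooling both give one vector per sample
        "after latent": (batch, dim),
        "logits": (batch, num_classes),
    }


if __name__ == "__main__":
    shapes = trace_vit3d_shapes(2, (1, 224, 224, 32), (16, 16, 4), 512, 2)
    for step, shape in shapes.items():
        print(f"{step:>26}: {shape}")
```

With a hypothetical 16×16×4 patch, the 224×224×32 volume becomes 14 × 14 × 8 = 1568 tokens (1569 with the class token), and each flattened patch holds 1024 values before projection. If the printed shapes differ from what the configuration implies, or the positional embedding has fewer than n + 1 positions, the patch-embedding settings are a reasonable first place to look.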
3D vision transformer model configuration:

```python
)
```
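```python
import torch.optim as optim

# ViTmodel is the 3D vision transformer instance from the configuration above
optimizer = optim.Adam(ViTmodel.parameters(), lr=0.002)
```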
Input shapes from the forward pass:
Model training results:

```text
Epoch 1/10 (Training): 100%|██████████| 56/56 [00:59<00:00, 1.07s/it]
Epoch 1/10, Training Loss: 0.5908904586519513, Training Accuracy: 0.7142857142857143
Epoch 1/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 1/10, Validation Loss: 0.5275474616459438, Validation Accuracy: 0.7798165137614679
Best model saved at epoch 1
Epoch 2/10 (Training): 100%|██████████| 56/56 [00:58<00:00, 1.04s/it]
Epoch 2/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 2/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 2/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 3/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 3/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 3/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 3/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
Epoch 4/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 4/10, Training Loss: 0.5878153358186994, Training Accuracy: 0.7210884353741497
Epoch 4/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 4/10, Validation Loss: 0.5329046036515918, Validation Accuracy: 0.7798165137614679
Epoch 5/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 5/10, Training Loss: 0.6034403315612248, Training Accuracy: 0.7210884353741497
Epoch 5/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 5/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 6/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 6/10, Training Loss: 0.5878153379474368, Training Accuracy: 0.7210884353741497
Epoch 6/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 6/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
```
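In the log above, training accuracy is fixed at 0.7211 from epoch 2 onward and validation accuracy stays at 0.7798 in every epoch, while the training loss stays near 0.588 and the validation loss alternates between about 0.528 and 0.533. That pattern is consistent with (though not proof of) the model predicting the same class for every input. A minimal, framework-agnostic way to check this is sketched below, assuming you collect the predicted and true class indices for one evaluation pass as plain integer lists; `prediction_summary` and `PredictionSummary` are hypothetical names, not part of any library.

```python
from collections import Counter
from typing import NamedTuple, Sequence


class PredictionSummary(NamedTuple):
    accuracy: float           # fraction of correct predictions
    majority_rate: float      # fraction of samples in the most frequent true class
    distinct_predictions: int  # 1 means every sample got the same predicted class


def prediction_summary(preds: Sequence[int], labels: Sequence[int]) -> PredictionSummary:
    """Compare accuracy with the majority-class rate for one evaluation pass."""
    if len(preds) != len(labels) or len(labels) == 0:
        raise ValueError("preds and labels must be non-empty and the same length")
    correct = sum(p == y for p, y in zip(preds, labels))
    _, majority_count = Counter(labels).most_common(1)[0]
    return PredictionSummary(
        accuracy=correct / len(labels),
        majority_rate=majority_count / len(labels),
        distinct_predictions=len(set(preds)),
    )


if __name__ == "__main__":
    # Hypothetical labels and predictions, not taken from the post
    labels = [0, 0, 0, 1]
    preds = [0, 0, 0, 0]
    print(prediction_summary(preds, labels))
```

If `distinct_predictions` is 1 and `accuracy` equals `majority_rate` epoch after epoch, the network has most likely collapsed to the majority class, which would explain why the accuracy never moves even though training keeps running.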