diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 18635f603..f8e2a9a1e 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -259,16 +259,17 @@ def __init__( in_features, hidden_features, act_layer = 'gelu', + bias = True, drop = 0.0, ): super().__init__() norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) self.norm = norm_layer(in_features) - self.w0 = nn.Linear(in_features, hidden_features) + self.w0 = nn.Linear(in_features, hidden_features, bias=bias) self.act = create_act_layer(act_layer) - self.w1 = nn.Linear(in_features, hidden_features) - self.w2 = nn.Linear(hidden_features, in_features) + self.w1 = nn.Linear(in_features, hidden_features, bias=bias) + self.w2 = nn.Linear(hidden_features, in_features, bias=bias) def forward(self, x): x = self.norm(x)