
custom loss function only receives one/first output from model declaring multiple outputs #20654

Open
r25hbgh opened this issue Dec 17, 2024 · 2 comments


r25hbgh commented Dec 17, 2024

Current behavior

Declare a model with two named outputs: one for the bounding box and one for the label classification.
Implicit declaration of the loss functions (one per named output) in the compile configuration works fine.
Explicit declaration of a single custom_loss function in the compile configuration does call custom_loss, but it only receives the first output (bbox), not the expected two (bbox + labels).
Standalone code to reproduce the issue

Model (with num_classes = 6, for example):
bbox = layers.Dense(4, name="bbox")(features)
classification_output = layers.Dense(num_classes, name="classification", activation="softmax")(features)

model = keras.Model(inputs=inputs, outputs=[bbox, classification_output], name='vit_object_detector_with_class')

Dictionary for bounding box loss

bbox_loss_dict = {
    "mse_loss": tf.keras.losses.MeanSquaredError(),  # Mean Squared Error (for bounding box regression)
    "mae_loss": tf.keras.losses.MeanAbsoluteError()  # Mean Absolute Error (alternative for bounding boxes)
}

Dictionary for classification loss

class_loss_dict = {
    "sparse_categorical_crossentropy": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  # Cross-entropy loss for multi-class classification
    "categorical_crossentropy": tf.keras.losses.CategoricalCrossentropy(from_logits=False)  # Another option for multi-class classification if using one-hot encoded labels
}

Implicit declaration of the configuration, including the losses, which works fine:

model.compile(
    optimizer='adam',  # or any optimizer
    loss={
        "bbox": bbox_loss_dict["mse_loss"], 
        "classification": class_loss_dict["sparse_categorical_crossentropy"] 
    },
    loss_weights={
        "bbox": 1.0,
        "classification": 1.5
    },
    metrics={
        "bbox": ["mse", "mae"],
        "classification": ["accuracy"]
    }
)

Training:

targets = {
    "bbox": bbox_target,            # shape, for ex: (640, 4)
    "classification": class_target  # shape, for ex: (640,)
}
model.fit(x_train, targets, epochs=10, batch_size=32)

That training works correctly.

Explicitly declaring a custom_loss function:

def custom_loss(y_true, y_pred):
    bbox_true = y_true[0]   # Bounding boxes ground truth
    class_true = y_true[1]  # Class labels ground truth

    bbox_pred = y_pred[0]   # Bounding box predictions
    class_pred = y_pred[1]  # Class predictions

    ...  # etc.

and declaring the configuration as:
model.compile(optimizer='adam', loss=custom_loss, metrics=["accuracy"])

the custom_loss function only receives the bbox data (y_pred.shape = (32, 4)), not the label classification.
It should be something like y_pred.shape = [(32, 4), (32,)].
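
For what it's worth, passing custom callables per output (the same dictionary form as the working configuration above) does get each of them called, but each one still only sees its own output. A minimal sketch of that pattern, with hypothetical custom_bbox_loss / custom_class_loss functions:

import tensorflow as tf

# Hypothetical per-output custom losses; each receives only its own head's
# (y_true, y_pred), never the other output's tensors.
def custom_bbox_loss(y_true, y_pred):    # y_pred shape: (batch, 4)
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

def custom_class_loss(y_true, y_pred):   # y_pred shape: (batch, num_classes)
    return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)

model.compile(
    optimizer="adam",
    loss={"bbox": custom_bbox_loss, "classification": custom_class_loss},
    loss_weights={"bbox": 1.0, "classification": 1.5},
)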

@sonali-kumari1

Hi @r25hbgh,

Thanks for reporting this issue. You can declare a custom_loss function by subclassing the keras.losses.Loss base class. I have checked the shapes for both the implicit loss declaration and the subclassed custom_loss, and they are the same. Attaching a gist for your reference.
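
(The gist itself is not reproduced here; as a rough illustration only, the subclassing pattern being suggested looks like this:)

import tensorflow as tf
import keras

class CustomLoss(keras.losses.Loss):
    # Illustrative body only: an MSE over whatever y_true / y_pred Keras passes in.
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

model.compile(optimizer='adam', loss=CustomLoss(), metrics=["accuracy"])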

@fagonzalezo

Subclassing has the same problem: both y_true and y_pred receive only the bbox data. Trying to unpack with bbox, label = y_pred produces an error. The contents of y_true and y_pred are, respectively:

y_true: Tensor("data_1:0", shape=(32, 4), dtype=float32)
y_pred: Tensor("functional_1_1/bbox_1/BiasAdd:0", shape=(32, 4), dtype=float32)
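
Until the per-output routing is addressed, one possible workaround (a sketch, not an official fix) is to concatenate the two heads into a single output so the full prediction reaches one loss function, then split it inside the loss. The names, the 1.5 weight, and the label cast below mirror the example above but are assumptions:

import tensorflow as tf
from tensorflow import keras

# Merge the two heads so a single tensor reaches the loss:
# combined = keras.layers.Concatenate(name="combined")([bbox, classification_output])
# model = keras.Model(inputs=inputs, outputs=combined)

def combined_loss(y_true, y_pred):
    bbox_true = y_true[:, :4]
    class_true = tf.cast(y_true[:, 4], tf.int32)   # sparse label stored as the 5th target column
    bbox_pred = y_pred[:, :4]
    class_pred = y_pred[:, 4:]                     # softmax scores for num_classes classes
    bbox_loss = tf.reduce_mean(tf.square(bbox_true - bbox_pred), axis=-1)
    class_loss = keras.losses.sparse_categorical_crossentropy(class_true, class_pred)
    return bbox_loss + 1.5 * class_loss

# Targets concatenated the same way, for example:
# y_combined = np.concatenate([bbox_target, class_target[:, None]], axis=1).astype("float32")
# model.compile(optimizer='adam', loss=combined_loss)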
