Layers in nested list in custom layers can not be saved/loaded correctly. #20598

Open
yubori opened this issue Dec 5, 2024 · 6 comments
Labels
keras-team-review-pending Pending review by a Keras team member. type:Bug


@yubori

yubori commented Dec 5, 2024

Dear all,

I ran into a bug: layers stored in a nested list inside a custom layer are not saved/loaded correctly.
The following code reproduces the bug.

If any other information is necessary, please let me know.

Best Regards,

Environment

  • Keras: 3.7.0
  • Tensorflow: 2.17.1

Code

import keras
import numpy as np

@keras.saving.register_keras_serializable(package="MyLayers")
class NGLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Dense sublayers are stored in a nested list (list of lists).
        self.l1 = []
        for i in range(2):
            l2 = []
            self.l1.append(l2)
            for j in range(2):
                l2.append(keras.layers.Dense(10, name=f'dense_{i}_{j}'))

    def call(self, x):
        for l in self.l1:
            for d in l:
                x = d(x)
        return x

@keras.saving.register_keras_serializable(package="MyLayers")
class OKLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Dense sublayers are stored in a flat list.
        self.l1 = []
        for i in range(4):
            self.l1.append(keras.layers.Dense(10, name=f'dense_{i}'))

    def call(self, x):
        for d in self.l1:
            x = d(x)
        return x

# Create the model.
def get_model(is_OK=True):
    inputs = keras.Input(shape=(4,))
    if is_OK:
        mid = OKLayer()(inputs)
    else:
        mid = NGLayer()(inputs)
    outputs = keras.layers.Dense(1, activation='relu')(mid)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mean_squared_error")
    return model

# Train the model.
def train_model(model):
    x = np.random.random((4, 4))
    target = np.random.random((4, 1))
    model.fit(x, target)
    return model

test_input = np.random.random((4, 4))
test_target = np.random.random((4, 1))

for is_OK in (True, False):
    model = get_model(is_OK)
    model = train_model(model)
    model_path = f"custom_model_{'ok' if is_OK else 'ng'}.keras"
    model.save(model_path)

    reconstructed_model = keras.models.load_model(model_path)
    # Loading only the weights into a freshly built model gives the same result:
    #reconstructed_model = get_model(is_OK)
    #reconstructed_model.load_weights(model_path)

    print(f'{is_OK=}')
    np.testing.assert_allclose(
        model.predict(test_input), reconstructed_model.predict(test_input)
    )

Result

2024-12-05 17:47:57.955651: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-05 17:47:57.970112: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-05 17:47:57.974527: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-05 17:47:57.984982: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-05 17:47:58.611755: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-12-05 17:48:00.285751: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22651 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:65:00.0, compute capability: 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1733388481.573550 1218505 service.cc:146] XLA service 0x7725d800a6b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733388481.573574 1218505 service.cc:154]   StreamExecutor device (0): NVIDIA TITAN RTX, Compute Capability 7.5
2024-12-05 17:48:01.588804: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-12-05 17:48:02.757904: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907
I0000 00:00:1733388483.012265 1218505 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 0.3580
is_OK=True
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 119ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 166ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 414ms/step - loss: 0.3595
is_OK=False
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step
Traceback (most recent call last):
  File "/home/vori/tmp/tmp.py", line 72, in <module>
    np.testing.assert_allclose(
  File "/home/me/.pyenv/versions/miniconda3-4.7.12/envs/keras3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/vori/.pyenv/versions/miniconda3-4.7.12/envs/keras3/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/me/.pyenv/versions/miniconda3-4.7.12/envs/keras3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 4 / 4 (100%)
Max absolute difference: 1.5574965
Max relative difference: 1.
 x: array([[0.],
       [0.],
       [0.],
       [0.]], dtype=float32)
 y: array([[0.193633],
       [0.999855],
       [1.557497],
       [0.802406]], dtype=float32)
@sonali-kumari1 added the keras-team-review-pending label Dec 16, 2024
@mehtamansi29 removed the keras-team-review-pending label Dec 16, 2024
@mehtamansi29
Collaborator

Hi @yubori -

Thanks for reporting the issue. In Keras 3, deeply nested inputs in functional models are not supported. Here you can find more details about it.

@yubori
Author

yubori commented Dec 17, 2024

Hi @mehtamansi29

Thank you for your response.
However, this issue is not related to the deeply nested inputs issue.
Please recheck the code.

@mehtamansi29
Collaborator

Hi @yubori -

Oh! Your error is the AssertionError: Not equal to tolerance.
When calling assert_allclose, use the tolerances rtol=1e-1, atol=1e-1, and pass kernel_initializer='random_normal', bias_initializer='zeros' when defining the Dense layers in both classes; the code will then work fine.

I've attached a gist for reference as well.

@yubori
Copy link
Author

yubori commented Dec 19, 2024

Hi @mehtamansi29,

What's the point of relaxing the error bounds just so that no error occurs?
There's no point in sweeping the bug under the rug.
The actual bug, that a model with its saved weights loaded produces different output from the original model, remains unsolved.

@mehtamansi29
Collaborator

Hi @yubori-

Here there is an assertion error, AssertionError: Not equal to tolerance rtol=1e-07, atol=0. It is raised by np.testing.assert_allclose.

np.testing.assert_allclose raises an AssertionError if two objects are not equal up to the desired tolerance. You can find more details about it here.
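A minimal sketch of the check np.testing.assert_allclose performs, using made-up numbers (not the predictions from this issue). It passes only if |actual - desired| <= atol + rtol * |desired| holds element-wise, with defaults rtol=1e-07 and atol=0. The second call shows how much error the loosened tolerances rtol=1e-1, atol=1e-1 would accept.

```python
import numpy as np

desired = np.array([1.0, 2.0])
actual = np.array([1.15, 2.25])  # made-up values, 15% and 12.5% off

# Defaults rtol=1e-07, atol=0: the allowed error is essentially zero, so this raises.
try:
    np.testing.assert_allclose(actual, desired)
    default_passed = True
except AssertionError:
    default_passed = False

# rtol=1e-1, atol=1e-1: the allowed error is 0.1 + 0.1 * |desired|,
# i.e. 0.2 for 1.0 and 0.3 for 2.0, so both elements pass.
try:
    np.testing.assert_allclose(actual, desired, rtol=1e-1, atol=1e-1)
    loose_passed = True
except AssertionError:
    loose_passed = False

print(default_passed, loose_passed)  # False True
```

With rtol=atol=0.1, elements that are 12.5% to 15% apart still pass, so the looser setting can also accept genuinely different outputs.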

Because the two predictions differ by more than this tolerance, it raises an error.

Here, instead of np.testing.assert_allclose, you can use np.allclose. It returns True if two arrays are element-wise equal within a tolerance. You can find more details about it here.
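A small sketch applying np.allclose to the two prediction arrays printed in the AssertionError above (x from the original model, y from the reloaded model, values copied from the traceback). np.allclose returns a bool instead of raising, and its defaults are rtol=1e-05, atol=1e-08. It uses the same element-wise bound |x - y| <= atol + rtol * |y|.

```python
import numpy as np

# Arrays copied from the reported AssertionError (NGLayer case).
x = np.array([[0.0], [0.0], [0.0], [0.0]], dtype=np.float32)  # original model
y = np.array([[0.193633], [0.999855], [1.557497], [0.802406]], dtype=np.float32)  # reloaded model

# Default tolerances.
default_close = np.allclose(x, y)
# The loosened tolerances suggested earlier in the thread.
loose_close = np.allclose(x, y, rtol=1e-1, atol=1e-1)
# Largest element-wise gap between the two predictions.
max_abs_diff = float(np.max(np.abs(x - y)))

print(default_close, loose_close)  # False False
print(max_abs_diff)
```

For these particular arrays, neither tolerance setting treats the predictions as close. For example, the first element needs |0 - 0.193633| <= 0.1 + 0.1 * 0.193633, which does not hold.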

Also, when we compare the model weights, both are the same. You can see this in the gist.

@yubori
Author

yubori commented Dec 20, 2024

Hi @mehtamansi29,

As you can see in your last gist, there is the message "Predictions differ." That message compares the predictions of the base model and the reloaded model.
Are you saying this is not a bug?

Please dump the NG model's weights: there are no weights for NGLayer at all. Keras 3.7 fails to handle layer weights stored in a nested list.

OKLayer's weights (via print(model.weights)):

[<KerasVariable shape=(4, 10), dtype=float32, path=ok_layer/dense_0/kernel>, <KerasVariable shape=(10,), dtype=float32, path=ok_layer/dense_0/bias>, <KerasVariable shape=(10, 10), dtype=float32, path=ok_layer/dense_1/kernel>, <KerasVariable shape=(10,), dtype=float32, path=ok_layer/dense_1/bias>, <KerasVariable shape=(10, 10), dtype=float32, path=ok_layer/dense_2/kernel>, <KerasVariable shape=(10,), dtype=float32, path=ok_layer/dense_2/bias>, <KerasVariable shape=(10, 10), dtype=float32, path=ok_layer/dense_3/kernel>, <KerasVariable shape=(10,), dtype=float32, path=ok_layer/dense_3/bias>, <KerasVariable shape=(10, 1), dtype=float32, path=dense/kernel>, <KerasVariable shape=(1,), dtype=float32, path=dense/bias>]

NGLayer's weights (via print(model.weights)):

[<KerasVariable shape=(10, 1), dtype=float32, path=dense_1/kernel>, <KerasVariable shape=(1,), dtype=float32, path=dense_1/bias>]
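A hypothetical helper (count_by_top_layer is not a Keras API) that groups the variable paths from the two dumps above by their top-level layer name. This makes it easy to see that no variable belongs to the custom layer in the NG model. In a live session the paths could be collected with [v.path for v in model.weights], since Keras 3 variables expose the path attribute shown in the dumps. Here they are copied by hand.

```python
from collections import Counter

# Variable paths copied from the two print(model.weights) dumps above.
ok_paths = [
    "ok_layer/dense_0/kernel", "ok_layer/dense_0/bias",
    "ok_layer/dense_1/kernel", "ok_layer/dense_1/bias",
    "ok_layer/dense_2/kernel", "ok_layer/dense_2/bias",
    "ok_layer/dense_3/kernel", "ok_layer/dense_3/bias",
    "dense/kernel", "dense/bias",
]
ng_paths = ["dense_1/kernel", "dense_1/bias"]


def count_by_top_layer(paths):
    """Hypothetical helper: count variables per top-level layer (first path component)."""
    return Counter(p.split("/", 1)[0] for p in paths)


ok_counts = count_by_top_layer(ok_paths)
ng_counts = count_by_top_layer(ng_paths)
print(dict(ok_counts))  # {'ok_layer': 8, 'dense': 2}
print(dict(ng_counts))  # {'dense_1': 2}
```

In the NG model only the final Dense(1) layer owns variables, which matches the missing NGLayer weights described above.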

@mehtamansi29 added the keras-team-review-pending label Dec 20, 2024

4 participants