From 45b966f2e809323c8567a24d21e6464139fa875d Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 23 Aug 2023 10:26:40 -0700 Subject: [PATCH] Rewire TensorFlow to rely on tf_keras target. PiperOrigin-RevId: 559470121 --- .../python/core/sparsity/keras/prune_registry.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index 36e203bc8..17c884497 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -19,12 +19,16 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer try: - from keras.engine import base_layer # pylint: disable=g-import-not-at-top + # OSS case. + import keras # pylint: disable=g-import-not-at-top + if hasattr(keras, 'src'): + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member + else: + from keras.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member except ImportError: - # Path as seen in pip packages as of TF/Keras 2.13. - from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top - -# TODO(b/139939526): move to public API. + # Internal case. + base_layer = tf._keras_internal.engine.base_layer # pylint: disable=protected-access layers = tf.keras.layers layers_compat_v1 = tf.compat.v1.keras.layers