From 9c15ce06f8667ad3c13fafa46be3568d02501a86 Mon Sep 17 00:00:00 2001 From: Ricardo Torres Date: Thu, 14 May 2026 17:51:45 -0700 Subject: [PATCH] This change updates the random number generation utility to use tf.uint64 for bitwise operations and constants. The seed generation function (next_seed_fn) has been rewritten to implement a proper 7-bit LFSR (PRBS7) using bitwise operations, replacing the previous power-based approximation. PiperOrigin-RevId: 915708863 --- .../tensor_encoding/utils/tf_utils.py | 61 +++++++++++++++---- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py index bf58eb85b..4e0b61d2f 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py @@ -153,12 +153,12 @@ def _cmwc_random_sequence(num_elements, seed): # Create constants needed for the algorithm. The constants and notation # follows from the above reference. - a = tf.tile(tf.constant([3636507990], tf.int64), [parallelism]) - b = tf.tile(tf.constant([2**32], tf.int64), [parallelism]) - logb_scalar = tf.constant(32, tf.int64) + a = tf.tile(tf.constant([3636507990], tf.uint64), [parallelism]) + b = tf.tile(tf.constant([2**32], tf.uint64), [parallelism]) + logb_scalar = tf.constant(32, tf.uint64) logb = tf.tile([logb_scalar], [parallelism]) - f = tf.tile(tf.constant([0], dtype=tf.int64), [parallelism]) - bits = tf.constant(0, dtype=tf.int64, name='bits') + f = tf.tile(tf.constant([0], dtype=tf.uint64), [parallelism]) + bits = tf.constant(0, dtype=tf.uint64, name='bits') # TensorArray used in tf.while_loop for efficiency. values = tf.TensorArray( @@ -166,20 +166,58 @@ def _cmwc_random_sequence(num_elements, seed): # Iteration counter. num = tf.constant(0, dtype=tf.int32, name='num') # TensorFlow constant to be used at multiple places. - val_53 = tf.constant(53, tf.int64, name='val_53') + val_53 = tf.constant(53, tf.uint64, name='val_53') # Construct initial sequence of seeds. # From a single input seed, we construct multiple starting seeds for the # sequences to be computed in parallel. def next_seed_fn(i, val, q): - val = val**7 + val**6 + 1 # PRBS7. + """Generates the next seed using a 7-bit LFSR. + + This function implements a proper 7-bit Fibonacci LFSR with the polynomial + x^7 + x^6 + 1. It takes the lower 7 bits of `val` as the current state, + computes the next state, and writes it to the TensorArray `q`. + + Args: + i: The current index in the while loop. + val: The current seed value (tf.uint64). The lower 7 bits are used as + the LFSR state. + q: The tf.TensorArray to write the generated seed into. + + Returns: + A tuple of (i + 1, new_val, q), where `new_val` is the next state of the + LFSR. + """ + state = tf.bitwise.bitwise_and(val, tf.constant(0x7F, tf.uint64)) + # Avoid zero state, which is a trapping state for this LFSR polynomial. + state = tf.bitwise.bitwise_or( + state, + tf.cast(tf.equal(state, tf.constant(0, tf.uint64)), tf.uint64) + ) + # Feedback bit = bit 7 (index 6) ^ bit 6 (index 5) + feedback = tf.bitwise.bitwise_and( + tf.bitwise.bitwise_xor( + tf.bitwise.right_shift(state, tf.constant(6, tf.uint64)), + tf.bitwise.right_shift(state, tf.constant(5, tf.uint64)) + ), + tf.constant(1, tf.uint64) + ) + # Shift left and insert feedback + val = tf.bitwise.bitwise_and( + tf.bitwise.bitwise_or( + tf.bitwise.left_shift(state, tf.constant(1, tf.uint64)), + feedback + ), + tf.constant(0x7F, tf.uint64) + ) q = q.write(i, val) return i + 1, val, q - q = tf.TensorArray(dtype=tf.int64, size=parallelism, element_shape=()) + q = tf.TensorArray(dtype=tf.uint64, size=parallelism, element_shape=()) + seed_u64 = tf.cast(seed, tf.uint64) _, _, q = tf.while_loop(lambda i, _, __: i < parallelism, next_seed_fn, - [tf.constant(0), seed, q]) + [tf.constant(0), seed_u64, q]) c = q = q.stack() # The random sequence generation code. @@ -193,9 +231,10 @@ def cmwc_step(f, bits, q, c, num, values): f.set_shape((1,)) # Correct for failed shape inference. bits += logb_scalar def add_val(bits, f, values, num): + mask_53 = tf.constant(2**53 - 1, tf.uint64) new_val = tf.cast( - tf.bitwise.bitwise_and(f, (2**val_53 - 1)), - dtype=tf.float64) * (1 / 2**val_53) + tf.bitwise.bitwise_and(f, mask_53), + dtype=tf.float64) * (1.0 / 2.0**53) values = values.write(num, new_val) f += tf.bitwise.right_shift(f, val_53) bits -= val_53