-
Notifications
You must be signed in to change notification settings - Fork 89
Pyrtl floating point library #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: development
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode | ||
| from .floatoperations import ( | ||
| BFloat16Operations, | ||
| Float16Operations, | ||
| Float32Operations, | ||
| Float64Operations, | ||
| FloatOperations, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "FloatingPointType", | ||
| "FPTypeProperties", | ||
| "PyrtlFloatConfig", | ||
| "RoundingMode", | ||
| "FloatOperations", | ||
| "BFloat16Operations", | ||
| "Float16Operations", | ||
| "Float32Operations", | ||
| "Float64Operations", | ||
| ] |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| import pyrtl | ||
|
|
||
| from ._types import FPTypeProperties | ||
|
|
||
|
|
||
| def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the sign bit of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the sign bit. | ||
| """ | ||
| return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] | ||
|
|
||
|
|
||
| def get_exponent(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the exponent bits of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the exponent bits. | ||
| """ | ||
| return wire[ | ||
| fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits | ||
| + fp_prop.num_exponent_bits | ||
| ] | ||
|
|
||
|
|
||
| def get_mantissa(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the mantissa bits of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the mantissa bits. | ||
| """ | ||
| return wire[: fp_prop.num_mantissa_bits] | ||
|
|
||
|
|
||
| def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like we always run all these smaller_kinds = check_kinds(smaller_operand)
...
with smaller_kinds.is_nan:
...I think the kinds are mutually exclusive? If that's right, it could make sense to encode them into 2 bits instead of using 4 one-hot bits. |
||
| """ | ||
| Returns whether the floating point number is zero. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is zero. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) == 0) & (get_exponent(fp_prop, wire) == 0) | ||
|
|
||
|
|
||
| def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is infinity. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is infinity. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) == 0) & ( | ||
| get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 | ||
| ) | ||
|
|
||
|
|
||
| def is_denormalized( | ||
| fp_prop: FPTypeProperties, wire: pyrtl.WireVector | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is denormalized. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A reference that defines denormalized numbers would be very helpful. https://en.wikipedia.org/wiki/Subnormal_number seems good, but maybe you've found something better? |
||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is denormalized. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) != 0) & (get_exponent(fp_prop, wire) == 0) | ||
|
|
||
|
|
||
| def is_nan(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is NaN. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is NaN. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) != 0) & ( | ||
| get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 | ||
| ) | ||
|
|
||
|
|
||
| def make_denormals_zero( | ||
| fp_prop: FPTypeProperties, wire: pyrtl.WireVector | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Returns zero if denormalized, else original number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the resulting floating point number. | ||
| """ | ||
| out = pyrtl.WireVector( | ||
| bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 | ||
| ) | ||
| with pyrtl.conditional_assignment: | ||
| with get_exponent(fp_prop, wire) == 0: | ||
| out |= pyrtl.concat( | ||
| get_sign(fp_prop, wire), | ||
| get_exponent(fp_prop, wire), | ||
| pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), | ||
| ) | ||
| with pyrtl.otherwise: | ||
| out |= wire | ||
| return out | ||
|
|
||
|
|
||
| def make_inf( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a little confusing that I'd recommend just returning |
||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent infinity. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 1 | ||
| mantissa |= 0 | ||
|
|
||
|
|
||
| def make_nan( | ||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent NaN. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 1 | ||
| mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) | ||
|
|
||
|
|
||
| def make_zero(exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent zero. | ||
|
|
||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= 0 | ||
| mantissa |= 0 | ||
|
|
||
|
|
||
| def make_largest_finite_number( | ||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent the largest finite number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 2 | ||
| mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| import pyrtl | ||
|
|
||
| from ._float_utills import ( | ||
| get_exponent, | ||
| get_mantissa, | ||
| get_sign, | ||
| is_denormalized, | ||
| is_inf, | ||
| is_nan, | ||
| is_zero, | ||
| make_denormals_zero, | ||
| make_inf, | ||
| make_largest_finite_number, | ||
| make_nan, | ||
| make_zero, | ||
| ) | ||
| from ._types import PyrtlFloatConfig, RoundingMode | ||
|
|
||
|
|
||
| def mul( | ||
| config: PyrtlFloatConfig, | ||
| operand_a: pyrtl.WireVector, | ||
| operand_b: pyrtl.WireVector, | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Performs floating point multiplication of two WireVectors. | ||
|
|
||
| :param config: Configuration for the floating point type and rounding mode. | ||
| :param operand_a: The first floating point operand as a WireVector. | ||
| :param operand_b: The second floating point operand as a WireVector. | ||
| :return: The result of the multiplication as a WireVector. | ||
| """ | ||
| fp_type_props = config.fp_type_properties | ||
| rounding_mode = config.rounding_mode | ||
| num_exp_bits = fp_type_props.num_exponent_bits | ||
| num_mant_bits = fp_type_props.num_mantissa_bits | ||
|
|
||
| # Denormalized numbers are not supported, so we flush them to zero. | ||
| operands = (operand_a, operand_b) | ||
| operands_daz = tuple(make_denormals_zero(fp_type_props, op) for op in operands) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left a similar comment on Similarly, this is a big function, so think about how someone new to this code will read it, and look for ways to help them navigate and understand what's important. Don't forget that you have the Curse of Knowledge: you wrote all this code, so it already makes perfect sense to you :) To really test your code's readability, you have to show it to other people. |
||
|
|
||
| # Extract the sign and exponent of both operands. | ||
| signs = tuple(get_sign(fp_type_props, op) for op in operands_daz) | ||
| exponents = tuple(get_exponent(fp_type_props, op) for op in operands_daz) | ||
|
|
||
| result_sign = signs[0] ^ signs[1] | ||
|
|
||
| # IEEE-754 floating point numbers have a bias: | ||
| # https://en.wikipedia.org/wiki/Exponent_bias | ||
| # real_exponent = stored_exponent - bias, so stored_exponent = real + bias | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe an easier way to think about this (for me, at least :)
Which is a double-biased result. We can correct that to a single-bias result by subtracting off |
||
| # Therefore, stored_exponent_product = real_exponent_product + bias | ||
| # = (real_exponent_a + real_exponent_b) + bias | ||
| # = (stored_exponent_a - bias + stored_exponent_b - bias) + bias | ||
| # = stored_exponent_a + stored_exponent_b - bias | ||
| operand_exponent_sums = exponents[0] + exponents[1] | ||
| exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 | ||
| product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) | ||
|
|
||
| # Extract the mantissa of both operands and add the implicit leading 1. | ||
| mantissas = tuple( | ||
| pyrtl.concat(pyrtl.Const(1), get_mantissa(fp_type_props, op)) | ||
| for op in operands_daz | ||
| ) | ||
| product_mantissa = mantissas[0] * mantissas[1] | ||
|
|
||
| normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) | ||
| normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
|
|
||
| # We need to normalize (shift right) if the leading bit is 1. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's worth explaining why normalization requires at most one right shift. It took me a few minutes to figure out, even with the reference below which doesn't really spell it out. Or maybe we can find a better reference? IIUC, we're multiplying two equal-bitwidth numbers that both have their most significant bit set to 1, with binary points just to the right of the MSB. So the inputs both look like The product's binary point will be just after its second-most significant bit, so the output will look like A properly normalized output must have the form We'll never need to shift right more than one position because the biggest output we can produce is We'll never need to shift left because you can't multiply two |
||
| # https://numeral-systems.com/ieee-754-multiply/ | ||
| need_to_normalize = product_mantissa[-1] | ||
|
|
||
| if rounding_mode == RoundingMode.RNE: | ||
| guard = pyrtl.WireVector(bitwidth=1) | ||
| sticky = pyrtl.WireVector(bitwidth=1) | ||
| last = pyrtl.WireVector(bitwidth=1) # Last bit of the mantissa before rounding. | ||
|
|
||
| # Assign the normalized mantissa, exponent, guard, sticky, and last bits | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems clearer to always mention |
||
| # based on whether normalization is needed. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a reference to https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/ for |
||
| with pyrtl.conditional_assignment: | ||
| with need_to_normalize: | ||
| normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this logic can be simplified by only aligning the # Align the mantissa's binary point into the standard 1.<something>
# format.
with pyrtl.conditional_assignment:
with need_to_normalize:
# product_mantissa's MSB is 1 so we just reinterpret the 1b.cdef...
# format as 1.bcdef...
aligned_mantissa |= product_mantissa
normalized_product_exponent |= product_exponent + 1
with pyrtl.otherwise:
# Output is 01.<something> so shift it left to drop the leading
# zero and put it in the standard 1.<something> format.
aligned_mantissa |= pyrtl.Concat(product_mantissa[:-1], pyrtl.Const(0, bitwidth=1))
normalized_product_exponent |= product_exponent
# Extract the parts we need from the aligned_mantissa.
normalized_product_mantissa |= aligned_mantissa[-num_mant_bits - 1 :]
last = aligned_mantissa[-num_mant_bits - 1]
guard = aligned_mantissa[-num_mant_bits - 2]
sticky = aligned_mantissa[: -num_mant_bits - 2] != 0 |
||
| normalized_product_exponent |= product_exponent + 1 | ||
| if rounding_mode == RoundingMode.RNE: | ||
| guard |= product_mantissa[-num_mant_bits - 2] | ||
| sticky |= product_mantissa[: -num_mant_bits - 2] != 0 | ||
| last |= product_mantissa[-num_mant_bits - 1] | ||
| with pyrtl.otherwise: | ||
| normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] | ||
| normalized_product_exponent |= product_exponent | ||
| if rounding_mode == RoundingMode.RNE: | ||
| guard |= product_mantissa[-num_mant_bits - 3] | ||
| sticky |= product_mantissa[: -num_mant_bits - 3] != 0 | ||
| last |= product_mantissa[-num_mant_bits - 2] | ||
|
|
||
| if rounding_mode == RoundingMode.RNE: | ||
| rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
| rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) | ||
| # Whether exponent was incremented due to rounding (for overflow check). | ||
| exponent_incremented = pyrtl.WireVector(bitwidth=1) | ||
| # If guard bit is not set, number is closer to smaller value: no round. | ||
| # If guard and sticky are set, round up. | ||
| # If guard is set but sticky is not, value is exactly halfway. | ||
| # Following round-to-nearest ties-to-even, round up if last bit is 1. | ||
| round_up = guard & (last | sticky) | ||
| with pyrtl.conditional_assignment: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks identical to the |
||
| with round_up: | ||
| with normalized_product_mantissa == (1 << num_mant_bits) - 1: | ||
| rounded_product_mantissa |= 0 | ||
| rounded_product_exponent |= normalized_product_exponent + 1 | ||
| exponent_incremented |= 1 | ||
| with pyrtl.otherwise: | ||
| rounded_product_mantissa |= normalized_product_mantissa + 1 | ||
| rounded_product_exponent |= normalized_product_exponent | ||
| exponent_incremented |= 0 | ||
| with pyrtl.otherwise: | ||
| rounded_product_mantissa |= normalized_product_mantissa | ||
| rounded_product_exponent |= normalized_product_exponent | ||
| exponent_incremented |= 0 | ||
|
|
||
| result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) | ||
| result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
|
|
||
| # Check whether operands are special: NaN, infinity, zero, or denormalized. | ||
| operand_nans = tuple(is_nan(fp_type_props, op) for op in operands_daz) | ||
| operand_infs = tuple(is_inf(fp_type_props, op) for op in operands_daz) | ||
| operand_zeros = tuple(is_zero(fp_type_props, op) for op in operands_daz) | ||
| operand_denorms = tuple(is_denormalized(fp_type_props, op) for op in operands_daz) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Try to be consistent with your names, this breaks the pattern established on lines 125-127 ( |
||
|
|
||
| # We check for overflow and underflow by computing max and min exponent | ||
| # values of the sum of operands before rounding and normalization. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should "sum of operands" say "sum of operands' exponents" (we're adding the exponents to multiply the operands) ? |
||
| # These values depend on the operands. If the result requires | ||
| # normalization, the exponent is incremented by 1. Additionally, rounding | ||
| # may further increase the exponent. Therefore, we subtract these | ||
| # potential increments from the absolute maximum exponent, which is one | ||
| # less than the all-1s exponent (reserved for inf/NaN) plus bias. | ||
| # Similarly, we subtract these increments from the absolute minimum | ||
| # exponent, which is 1 plus the exponent bias. | ||
| sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) | ||
| sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) | ||
| if rounding_mode == RoundingMode.RNE: | ||
| exponent_max_value = ( | ||
| sum_exponent_max_value - need_to_normalize - exponent_incremented | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works, but think about the generated logic: each So we're really just selecting between:
which shouldn't require instantiating any adders at all. If you refactor this, definitely add a comment explaining why it's not written as two subtractions, so someone doesn't try to "fix" it later :) Similar comments below, also for the similar code in |
||
| ) | ||
| exponent_min_value = ( | ||
| sum_exponent_min_value - need_to_normalize - exponent_incremented | ||
| ) | ||
| else: | ||
| exponent_max_value = sum_exponent_max_value - need_to_normalize | ||
| exponent_min_value = sum_exponent_min_value - need_to_normalize | ||
|
|
||
| # Assign the raw result's exponent and mantissa depending on whether RNE rounding | ||
| # is used. The calculated exponent WireVector has an extra bit due to the carry-out | ||
| # from addition, so we take only the lower num_exp_bits to remove this extra bit. | ||
| if rounding_mode == RoundingMode.RNE: | ||
| raw_result_exponent = rounded_product_exponent[:num_exp_bits] | ||
| raw_result_mantissa = rounded_product_mantissa | ||
| else: | ||
| raw_result_exponent = normalized_product_exponent[:num_exp_bits] | ||
| raw_result_mantissa = normalized_product_mantissa | ||
|
|
||
| with pyrtl.conditional_assignment: | ||
| # If either operand is NaN, or if one operand is infinity and the other is | ||
| # zero, the result is NaN. | ||
| with ( | ||
| operand_nans[0] | ||
| | operand_nans[1] | ||
| | (operand_infs[0] & operand_zeros[1]) | ||
| | (operand_zeros[0] & operand_infs[1]) | ||
| ): | ||
| make_nan(fp_type_props, result_exponent, result_mantissa) | ||
| # If either operand is infinity, the result is infinity. | ||
| with operand_infs[0] | operand_infs[1]: | ||
| make_inf(fp_type_props, result_exponent, result_mantissa) | ||
| # Detect overflow. | ||
| with operand_exponent_sums > exponent_max_value: | ||
| if rounding_mode == RoundingMode.RNE: | ||
| make_inf(fp_type_props, result_exponent, result_mantissa) | ||
| else: | ||
| make_largest_finite_number( | ||
| fp_type_props, result_exponent, result_mantissa | ||
| ) | ||
| # If either operand is zero, if underflow occurred, or if either operand is | ||
| # denormalized, the result is zero. | ||
| with ( | ||
| operand_zeros[0] | ||
| | operand_zeros[1] | ||
| | (operand_exponent_sums < exponent_min_value) | ||
| | operand_denorms[0] | ||
| | operand_denorms[1] | ||
| ): | ||
| make_zero(result_exponent, result_mantissa) | ||
| with pyrtl.otherwise: | ||
| result_exponent |= raw_result_exponent | ||
| result_mantissa |= raw_result_mantissa | ||
|
|
||
| return pyrtl.concat(result_sign, result_exponent, result_mantissa) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be called
_float_utils.py? (only one L in utilities :)