diff --git a/canopen/objectdictionary/__init__.py b/canopen/objectdictionary/__init__.py index 40ab8fbf..dbc6c855 100644 --- a/canopen/objectdictionary/__init__.py +++ b/canopen/objectdictionary/__init__.py @@ -6,7 +6,7 @@ import logging import struct -from collections.abc import Iterator, Mapping, MutableMapping +from collections.abc import Collection, Iterator, Mapping, MutableMapping from typing import Optional, TextIO, Union from canopen.objectdictionary.datatypes import * @@ -516,27 +516,44 @@ def encode_desc(self, desc: str) -> int: raise ValueError( f"No value corresponds to '{desc}'. Valid values are: {valid_values}") - def decode_bits(self, value: int, bits: list[int]) -> int: - try: + def decode_bits(self, value: int, bits: Union[str, Collection[int]]) -> int: + """Isolate and right-shift the specified bits from a given integer. + + :param value: Variable value holding the bits + :param bits: Registered lookup name or concrete list of bit offsets + :return: Extracted bits, right-shifted to cut off to lowest specified offset + :raises KeyError: For unknown lookup names + """ + if isinstance(bits, str): bits = self.bit_definitions[bits] - except (TypeError, KeyError): - pass mask = 0 for bit in bits: mask |= 1 << bit return (value & mask) >> min(bits) - def encode_bits(self, original_value: int, bits: list[int], bit_value: int): - try: + def encode_bits( + self, original_value: int, bits: Union[str, Collection[int]], bit_value: int + ) -> int: + """Replace the specified bits with the given (unshifted) pattern. + + The bit offsets sequence may be non-contiguous, but the replacement pattern + must specify all bits including the "holes". It is only shifted once, so the + LSB lands at the lowest specified bit offset. + + :param original_value: Variable value holding the bits + :param bits: Registered lookup name or concrete list of bit offsets + :param bit_value: Source pattern to overwrite with + :return: Merged value with the bits replaced + :raises KeyError: For unknown lookup names + """ + if isinstance(bits, str): bits = self.bit_definitions[bits] - except (TypeError, KeyError): - pass temp = original_value mask = 0 for bit in bits: mask |= 1 << bit temp &= ~mask - temp |= bit_value << min(bits) + temp |= (bit_value << min(bits)) & mask return temp diff --git a/canopen/variable.py b/canopen/variable.py index 639a1839..ede20607 100644 --- a/canopen/variable.py +++ b/canopen/variable.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from collections.abc import Mapping +from collections.abc import Collection, Mapping from typing import Union from canopen import objectdictionary @@ -118,7 +120,7 @@ def desc(self, desc: str): self.raw = self.od.encode_desc(desc) @property - def bits(self) -> "Bits": + def bits(self) -> Bits: """Access bits using integers, slices, or bit descriptions.""" return Bits(self) @@ -169,23 +171,23 @@ def write( class Bits(Mapping): def __init__(self, variable: Variable): + assert variable.od.data_type in objectdictionary.datatypes.INTEGER_TYPES self.variable = variable self.read() + self.raw: int @staticmethod - def _get_bits(key): + def _get_bits(key: Union[slice, int, str, Collection[int]]) -> Union[str, Collection[int]]: if isinstance(key, slice): - bits = range(key.start, key.stop, key.step) - elif isinstance(key, int): - bits = [key] - else: - bits = key - return bits - - def __getitem__(self, key) -> int: + return range(key.start, key.stop, key.step) + if isinstance(key, int): + return [key] + return key + + def __getitem__(self, key: Union[slice, int, str, Collection[int]]) -> int: return self.variable.od.decode_bits(self.raw, self._get_bits(key)) - def __setitem__(self, key, value: int): + def __setitem__(self, key: Union[slice, int, str, Collection[int]], value: int): self.raw = self.variable.od.encode_bits( self.raw, self._get_bits(key), value) self.write() @@ -197,7 +199,8 @@ def __len__(self): return len(self.variable.od.bit_definitions) def read(self): - self.raw = self.variable.raw + assert isinstance(raw_int := self.variable.raw, int) + self.raw = raw_int def write(self): self.variable.raw = self.raw diff --git a/test/test_od.py b/test/test_od.py index d6e3e984..8aa41d01 100644 --- a/test/test_od.py +++ b/test/test_od.py @@ -238,10 +238,26 @@ def test_bits(self): self.assertEqual(var.decode_bits(1, "BIT 0"), 1) self.assertEqual(var.decode_bits(1, [1]), 0) self.assertEqual(var.decode_bits(0xf, [0, 1, 2, 3]), 15) + self.assertEqual(var.decode_bits(0xf, range(4)), 15) self.assertEqual(var.decode_bits(8, "BIT 2 and 3"), 2) self.assertEqual(var.encode_bits(0xf, [1], 0), 0xd) self.assertEqual(var.encode_bits(0, "BIT 0", 1), 1) + with self.assertRaises(KeyError): + var.decode_bits(0, "DOES NOT EXIST") + with self.assertRaises(KeyError): + var.encode_bits(0, "DOES NOT EXIST", 0) + + def test_bits_sparse(self): + var = od.ODVariable("Test UNSIGNED8", 0x1000) + var.data_type = od.UNSIGNED8 + + self.assertEqual(var.decode_bits(0b11111111, [2, 5]), 0b1001) + self.assertEqual(var.decode_bits(0b11011011, [2, 5]), 0) + self.assertEqual(var.encode_bits(0b11111111, [2, 5], 0), 0b11011011) + self.assertEqual(var.encode_bits(0b00000000, [2, 5], 0b1001), 0b00100100) + self.assertEqual(var.encode_bits(0b00000000, [2, 5], 0b1111), 0b00100100) + class TestObjectDictionary(unittest.TestCase):