diff --git a/src/wled/wled.py b/src/wled/wled.py index 65b88ef6..3ba757f2 100644 --- a/src/wled/wled.py +++ b/src/wled/wled.py @@ -3,6 +3,8 @@ from __future__ import annotations import asyncio +import hashlib +import logging import re import socket import time @@ -33,6 +35,8 @@ from .const import LiveDataOverride +_LOGGER = logging.getLogger(__name__) + @dataclass class _PresetsVersion: @@ -632,7 +636,7 @@ async def nightlight( nightlight = {k: v for k, v in nightlight.items() if v is not None} await self.request("/json/state", method="POST", data={"nl": nightlight}) - async def upgrade( # noqa: PLR0912 + async def upgrade( self, *, version: str | AwesomeVersion, @@ -715,38 +719,179 @@ async def upgrade( # noqa: PLR0912 f"https://github.com/{repo}/releases/download/v{version}/{update_file}" ) + expected_sha256 = await self._fetch_firmware_digest( + repo=repo, version=str(version), asset_name=update_file + ) + await self._download_and_flash_firmware( + download_url=download_url, + flash_url=url, + update_file=update_file, + expected_sha256=expected_sha256, + ) + + async def _download_and_flash_firmware( + self, + *, + download_url: str, + flash_url: URL, + update_file: str, + expected_sha256: str | None, + ) -> None: + """Download firmware, verify its SHA256 digest, and POST it to the device.""" + assert self.session is not None # noqa: S101 # guaranteed by upgrade() + try: async with ( - asyncio.timeout( - self.request_timeout * 10, - ), + asyncio.timeout(self.request_timeout * 10), self.session.get( download_url, raise_for_status=True, ) as download, ): - form = aiohttp.FormData() - form.add_field("file", await download.read(), filename=update_file) - await self.session.post(url, data=form) + firmware = await download.read() except TimeoutError as exception: - msg = "Timeout occurred while fetching WLED version information from GitHub" + msg = f"Timeout occurred while downloading firmware from {download_url}" raise WLEDConnectionTimeoutError(msg) from exception except aiohttp.ClientResponseError as exception: if exception.status == 404: msg = f"Requested firmware file {update_file} does not exist" raise WLEDUpgradeError(msg) from exception - msg = ( - f"Could not download requested WLED version '{version}'" - f" from {download_url}" - ) + msg = f"Could not download requested WLED version from {download_url}" raise WLEDUpgradeError(msg) from exception except (aiohttp.ClientError, socket.gaierror) as exception: - msg = ( - "Error occurred while communicating with GitHub" - " for WLED version information" - ) + msg = f"Error occurred while downloading firmware from {download_url}" + raise WLEDConnectionError(msg) from exception + + if expected_sha256 is not None: + actual_sha256 = hashlib.sha256(firmware).hexdigest() + if actual_sha256 != expected_sha256: + msg = ( + f"Firmware integrity check failed for {update_file}: " + f"expected SHA256 {expected_sha256}, " + f"got {actual_sha256}" + ) + raise WLEDUpgradeError(msg) + + form = aiohttp.FormData() + form.add_field("file", firmware, filename=update_file) + try: + async with ( + asyncio.timeout(self.request_timeout * 10), + self.session.post( + flash_url, + data=form, + raise_for_status=True, + ) as flash, + ): + await flash.read() + except TimeoutError as exception: + msg = "Timeout occurred while uploading firmware to the device" + raise WLEDConnectionTimeoutError(msg) from exception + except aiohttp.ClientResponseError as exception: + msg = "Device rejected the firmware upload" + raise WLEDUpgradeError(msg) from exception + except (aiohttp.ClientError, socket.gaierror) as exception: + msg = "Error occurred while uploading firmware to the device" raise WLEDConnectionError(msg) from exception + async def _fetch_firmware_digest( + self, *, repo: str, version: str, asset_name: str + ) -> str | None: + """Fetch the expected SHA256 digest for a firmware asset from the GitHub API. + + Returns the hex digest string if available, or None if it cannot be + determined (e.g. old release without digest, tag not found, API error). + Callers should proceed with the upgrade when None is returned. + """ + assert self.session is not None # noqa: S101 # guaranteed by upgrade() + api_url = f"https://api.github.com/repos/{repo}/releases/tags/v{version}" + try: + async with ( + asyncio.timeout(self.request_timeout), + self.session.get( + api_url, + headers={"Accept": "application/json"}, + ) as response, + ): + status = response.status + content_type = response.headers.get("Content-Type", "") + contents = await response.read() + except TimeoutError: + _LOGGER.warning( + "Timeout fetching release metadata from GitHub; " + "skipping firmware integrity check" + ) + return None + except (aiohttp.ClientError, socket.gaierror): + _LOGGER.warning( + "Connection error fetching release metadata from GitHub; " + "skipping firmware integrity check" + ) + return None + + if status == 404: + _LOGGER.warning( + "Release tag v%s not found on GitHub API; " + "skipping firmware integrity check", + version, + ) + return None + + if status // 100 in [4, 5]: + _LOGGER.warning( + "GitHub API returned HTTP %d for release metadata; " + "skipping firmware integrity check", + status, + ) + return None + + if "application/json" not in content_type: + _LOGGER.warning( + "Unexpected content type '%s' from GitHub API; " + "skipping firmware integrity check", + content_type, + ) + return None + + release = orjson.loads(contents) + return self._parse_asset_digest( + release=release, asset_name=asset_name, version=version + ) + + def _parse_asset_digest( + self, *, release: dict[str, Any], asset_name: str, version: str + ) -> str | None: + """Extract the SHA256 hex digest for a named asset from a release payload.""" + for asset in release.get("assets", []): + if asset.get("name") == asset_name: + digest: str | None = asset.get("digest") + if digest is None: + _LOGGER.debug( + "No digest for asset %s in release v%s; " + "skipping firmware integrity check", + asset_name, + version, + ) + return None + prefix = "sha256:" + if digest.startswith(prefix): + return digest[len(prefix) :] + _LOGGER.warning( + "Unrecognised digest format '%s' for asset %s; " + "skipping firmware integrity check", + digest, + asset_name, + ) + return None + + _LOGGER.debug( + "Asset %s not found in release v%s assets; " + "skipping firmware integrity check", + asset_name, + version, + ) + return None + async def reset(self) -> None: """Reboot WLED device.""" await self.request("/reset") diff --git a/tests/test_wled.py b/tests/test_wled.py index 1a4649fc..4c528e25 100644 --- a/tests/test_wled.py +++ b/tests/test_wled.py @@ -1162,6 +1162,31 @@ async def test_client_error_raises_connection_error( # Section 17: WLED client - upgrade() method # ========================================================================= +FAKE_FIRMWARE = b"fake firmware" +FAKE_FIRMWARE_SHA256 = ( + "0eb580d67f17f6586407d4b9e0ae216a91b228e49cac8858b5283cd6da8ad0c1" +) + + +def mock_github_release_api( # noqa: PLR0913 # pylint: disable=too-many-arguments,too-many-positional-arguments + responses: aioresponses, + repo: str = "wled/WLED", + version: str = "0.15.0", + asset_name: str = "WLED_0.15.0_ESP32.bin", + digest: str | None = f"sha256:{FAKE_FIRMWARE_SHA256}", + status: int = 200, +) -> None: + """Mock the GitHub releases/tags API for a given repo/version.""" + asset = {"name": asset_name} + if digest is not None: + asset["digest"] = digest + responses.get( + f"https://api.github.com/repos/{repo}/releases/tags/v{version}", + status=status, + body=json.dumps({"tag_name": f"v{version}", "assets": [asset]}), + content_type="application/json", + ) + async def prepare_wled_for_upgrade( # pylint: disable=too-many-arguments, too-many-positional-arguments responses: aioresponses, @@ -1217,11 +1242,11 @@ async def test_upgrade_calls_update_when_no_device( ) -> None: """Test upgrade() calls update() if no device loaded.""" mock_json_and_presets(responses) - # Mock the download and upload + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", status=200, - body=b"fake firmware", + body=FAKE_FIRMWARE, ) responses.post( "http://example.com/update", @@ -1246,10 +1271,11 @@ async def test_upgrade_no_session_raises() -> None: async def test_upgrade_success(responses: aioresponses, wled: WLED) -> None: """Test successful upgrade.""" await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", status=200, - body=b"fake firmware", + body=FAKE_FIRMWARE, ) responses.post( "http://example.com/update", @@ -1263,10 +1289,11 @@ async def test_upgrade_success(responses: aioresponses, wled: WLED) -> None: async def test_upgrade_ethernet_board(responses: aioresponses, wled: WLED) -> None: """Test upgrade with Ethernet board (empty bssid).""" await prepare_wled_for_upgrade(responses, wled, wifi_bssid="") + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32_Ethernet.bin") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32_Ethernet.bin", status=200, - body=b"fake firmware", + body=FAKE_FIRMWARE, ) responses.post( "http://example.com/update", @@ -1286,10 +1313,11 @@ async def test_upgrade_esp02_gzip(responses: aioresponses, wled: WLED) -> None: wled_data["info"]["fs"]["t"] = 512 mock_json_and_presets(responses, wled_data) await wled.update() + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP02.bin.gz") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP02.bin.gz", status=200, - body=b"fake firmware", + body=FAKE_FIRMWARE, ) responses.post( "http://example.com/update", @@ -1303,6 +1331,7 @@ async def test_upgrade_esp02_gzip(responses: aioresponses, wled: WLED) -> None: async def test_upgrade_404(responses: aioresponses, wled: WLED) -> None: """Test upgrade with 404 download raises WLEDUpgradeError.""" await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, version="0.99.0", status=404) responses.get( "https://github.com/wled/WLED/releases/download/v0.99.0/WLED_0.99.0_ESP32.bin", status=404, @@ -1314,6 +1343,7 @@ async def test_upgrade_404(responses: aioresponses, wled: WLED) -> None: async def test_upgrade_other_http_error(responses: aioresponses, wled: WLED) -> None: """Test upgrade with non-404 HTTP error raises WLEDUpgradeError.""" await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", status=500, @@ -1325,6 +1355,7 @@ async def test_upgrade_other_http_error(responses: aioresponses, wled: WLED) -> async def test_upgrade_connection_error(responses: aioresponses, wled: WLED) -> None: """Test upgrade with connection error raises WLEDConnectionError.""" await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", exception=aiohttp.ClientError("fail"), @@ -1336,6 +1367,7 @@ async def test_upgrade_connection_error(responses: aioresponses, wled: WLED) -> async def test_upgrade_timeout(responses: aioresponses, wled: WLED) -> None: """Test upgrade with timeout raises WLEDConnectionTimeoutError.""" await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") wled.request_timeout = 0.001 responses.get( "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", @@ -1345,6 +1377,174 @@ async def test_upgrade_timeout(responses: aioresponses, wled: WLED) -> None: await wled.upgrade(version="0.15.0") +async def test_upgrade_digest_mismatch_raises( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade raises WLEDUpgradeError when firmware digest does not match.""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api( + responses, + asset_name="WLED_0.15.0_ESP32.bin", + digest="sha256:0000000000000000000000000000000000000000000000000000000000000000", + ) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + with pytest.raises(WLEDUpgradeError, match="integrity check failed"): + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_no_digest_in_asset_proceeds( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade proceeds without error when asset has no digest (null).""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api( + responses, + asset_name="WLED_0.15.0_ESP32.bin", + digest=None, + ) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=200, + body="OK", + content_type="text/plain", + ) + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_asset_not_in_release_proceeds( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade proceeds when the firmware asset is not listed in the release.""" + await prepare_wled_for_upgrade(responses, wled) + # Return a release with a different asset name (not matching our firmware file) + responses.get( + "https://api.github.com/repos/wled/WLED/releases/tags/v0.15.0", + status=200, + body=json.dumps({"tag_name": "v0.15.0", "assets": [{"name": "other.bin"}]}), + content_type="application/json", + ) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=200, + body="OK", + content_type="text/plain", + ) + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_github_api_404_proceeds( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade proceeds when the GitHub API returns 404 for the release tag.""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, version="0.15.0", status=404) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=200, + body="OK", + content_type="text/plain", + ) + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_github_api_server_error_proceeds( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade proceeds (gracefully) when GitHub API returns a 5xx error.""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, version="0.15.0", status=500) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=200, + body="OK", + content_type="text/plain", + ) + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_github_api_connection_error_proceeds( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade proceeds when GitHub API call raises a connection error.""" + await prepare_wled_for_upgrade(responses, wled) + responses.get( + "https://api.github.com/repos/wled/WLED/releases/tags/v0.15.0", + exception=aiohttp.ClientError("fail"), + ) + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=200, + body="OK", + content_type="text/plain", + ) + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_device_http_error(responses: aioresponses, wled: WLED) -> None: + """Test upgrade raises WLEDUpgradeError when the device rejects the POST.""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + status=500, + ) + with pytest.raises(WLEDUpgradeError, match="Device rejected"): + await wled.upgrade(version="0.15.0") + + +async def test_upgrade_device_connection_error( + responses: aioresponses, wled: WLED +) -> None: + """Test upgrade raises WLEDConnectionError when device POST fails.""" + await prepare_wled_for_upgrade(responses, wled) + mock_github_release_api(responses, asset_name="WLED_0.15.0_ESP32.bin") + responses.get( + "https://github.com/wled/WLED/releases/download/v0.15.0/WLED_0.15.0_ESP32.bin", + status=200, + body=FAKE_FIRMWARE, + ) + responses.post( + "http://example.com/update", + exception=aiohttp.ClientError("device gone"), + ) + with pytest.raises(WLEDConnectionError): + await wled.upgrade(version="0.15.0") + + # ========================================================================= # Section 18: WLEDReleases class # =========================================================================