Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 161 additions & 16 deletions src/wled/wled.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

import asyncio
import hashlib
import logging
import re
import socket
import time
Expand Down Expand Up @@ -33,6 +35,8 @@

from .const import LiveDataOverride

_LOGGER = logging.getLogger(__name__)


@dataclass
class _PresetsVersion:
Expand Down Expand Up @@ -632,7 +636,7 @@ async def nightlight(
nightlight = {k: v for k, v in nightlight.items() if v is not None}
await self.request("/json/state", method="POST", data={"nl": nightlight})

async def upgrade( # noqa: PLR0912
async def upgrade(
self,
*,
version: str | AwesomeVersion,
Expand Down Expand Up @@ -715,38 +719,179 @@ async def upgrade( # noqa: PLR0912
f"https://github.com/{repo}/releases/download/v{version}/{update_file}"
)

expected_sha256 = await self._fetch_firmware_digest(
repo=repo, version=str(version), asset_name=update_file
)
await self._download_and_flash_firmware(
download_url=download_url,
flash_url=url,
update_file=update_file,
expected_sha256=expected_sha256,
)

async def _download_and_flash_firmware(
self,
*,
download_url: str,
flash_url: URL,
update_file: str,
expected_sha256: str | None,
) -> None:
"""Download firmware, verify its SHA256 digest, and POST it to the device."""
assert self.session is not None # noqa: S101 # guaranteed by upgrade()

try:
async with (
asyncio.timeout(
self.request_timeout * 10,
),
asyncio.timeout(self.request_timeout * 10),
self.session.get(
download_url,
raise_for_status=True,
) as download,
):
form = aiohttp.FormData()
form.add_field("file", await download.read(), filename=update_file)
await self.session.post(url, data=form)
firmware = await download.read()
except TimeoutError as exception:
msg = "Timeout occurred while fetching WLED version information from GitHub"
msg = f"Timeout occurred while downloading firmware from {download_url}"
raise WLEDConnectionTimeoutError(msg) from exception
except aiohttp.ClientResponseError as exception:
if exception.status == 404:
msg = f"Requested firmware file {update_file} does not exist"
raise WLEDUpgradeError(msg) from exception
msg = (
f"Could not download requested WLED version '{version}'"
f" from {download_url}"
)
msg = f"Could not download requested WLED version from {download_url}"
raise WLEDUpgradeError(msg) from exception
except (aiohttp.ClientError, socket.gaierror) as exception:
msg = (
"Error occurred while communicating with GitHub"
" for WLED version information"
)
msg = f"Error occurred while downloading firmware from {download_url}"
raise WLEDConnectionError(msg) from exception

if expected_sha256 is not None:
actual_sha256 = hashlib.sha256(firmware).hexdigest()
if actual_sha256 != expected_sha256:
msg = (
f"Firmware integrity check failed for {update_file}: "
f"expected SHA256 {expected_sha256}, "
f"got {actual_sha256}"
)
raise WLEDUpgradeError(msg)

form = aiohttp.FormData()
form.add_field("file", firmware, filename=update_file)
try:
async with (
asyncio.timeout(self.request_timeout * 10),
self.session.post(
flash_url,
data=form,
raise_for_status=True,
) as flash,
):
await flash.read()
except TimeoutError as exception:
msg = "Timeout occurred while uploading firmware to the device"
raise WLEDConnectionTimeoutError(msg) from exception
except aiohttp.ClientResponseError as exception:
msg = "Device rejected the firmware upload"
raise WLEDUpgradeError(msg) from exception
except (aiohttp.ClientError, socket.gaierror) as exception:
msg = "Error occurred while uploading firmware to the device"
raise WLEDConnectionError(msg) from exception
Comment on lines 759 to 795
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was addressed in the refactoring commit (f31e313). The download and upload are now in two separate try blocks:

  • Download block (lines 743–763): errors report "Error occurred while downloading firmware from {download_url}"
  • Flash block (lines 777–795): errors report "Error occurred while uploading firmware to the device"

No GitHub-specific messages leak into the device upload path.


async def _fetch_firmware_digest(
self, *, repo: str, version: str, asset_name: str
) -> str | None:
"""Fetch the expected SHA256 digest for a firmware asset from the GitHub API.

Returns the hex digest string if available, or None if it cannot be
determined (e.g. old release without digest, tag not found, API error).
Callers should proceed with the upgrade when None is returned.
"""
assert self.session is not None # noqa: S101 # guaranteed by upgrade()
api_url = f"https://api.github.com/repos/{repo}/releases/tags/v{version}"
try:
async with (
asyncio.timeout(self.request_timeout),
self.session.get(
api_url,
headers={"Accept": "application/json"},
) as response,
):
status = response.status
content_type = response.headers.get("Content-Type", "")
contents = await response.read()
except TimeoutError:
_LOGGER.warning(
"Timeout fetching release metadata from GitHub; "
"skipping firmware integrity check"
)
return None
except (aiohttp.ClientError, socket.gaierror):
_LOGGER.warning(
"Connection error fetching release metadata from GitHub; "
"skipping firmware integrity check"
)
return None

if status == 404:
_LOGGER.warning(
"Release tag v%s not found on GitHub API; "
"skipping firmware integrity check",
version,
)
return None

if status // 100 in [4, 5]:
_LOGGER.warning(
"GitHub API returned HTTP %d for release metadata; "
"skipping firmware integrity check",
status,
)
return None

if "application/json" not in content_type:
_LOGGER.warning(
"Unexpected content type '%s' from GitHub API; "
"skipping firmware integrity check",
content_type,
)
return None

release = orjson.loads(contents)
return self._parse_asset_digest(
release=release, asset_name=asset_name, version=version
)

def _parse_asset_digest(
self, *, release: dict[str, Any], asset_name: str, version: str
) -> str | None:
"""Extract the SHA256 hex digest for a named asset from a release payload."""
for asset in release.get("assets", []):
if asset.get("name") == asset_name:
digest: str | None = asset.get("digest")
if digest is None:
_LOGGER.debug(
"No digest for asset %s in release v%s; "
"skipping firmware integrity check",
asset_name,
version,
)
return None
prefix = "sha256:"
if digest.startswith(prefix):
return digest[len(prefix) :]
_LOGGER.warning(
"Unrecognised digest format '%s' for asset %s; "
"skipping firmware integrity check",
digest,
asset_name,
)
return None

_LOGGER.debug(
"Asset %s not found in release v%s assets; "
"skipping firmware integrity check",
asset_name,
version,
)
return None

async def reset(self) -> None:
"""Reboot WLED device."""
await self.request("/reset")
Expand Down
Loading
Loading