148 changes: 143 additions & 5 deletions src/openlayer/lib/data/_upload.py
@@ -4,9 +4,10 @@
different storage backends.
"""

import io
import os
from enum import Enum
from typing import Optional
from typing import BinaryIO, Dict, Optional, Union

import requests
from requests.adapters import Response
@@ -35,6 +36,135 @@ class StorageType(Enum):
VERIFY_REQUESTS = True


# ----- Low-level upload functions (work with bytes or file-like objects) ---- #
def upload_bytes(
storage: StorageType,
url: str,
data: Union[bytes, BinaryIO],
object_name: str,
content_type: str,
fields: Optional[Dict] = None,
) -> Response:
"""Upload data to the appropriate storage backend.

This is a convenience function that routes to the correct upload method
based on the storage type.

Args:
storage: The storage backend type.
url: The presigned URL to upload to.
data: The data to upload (bytes or file-like object).
object_name: The object name (used for multipart uploads).
content_type: The MIME type of the data.
fields: Additional fields for multipart uploads (S3 policy fields).

Returns:
The response from the upload request.
"""
if storage == StorageType.AWS:
return upload_bytes_multipart(
url=url,
data=data,
object_name=object_name,
content_type=content_type,
fields=fields,
)
elif storage == StorageType.GCP:
return upload_bytes_put(
url=url,
data=data,
content_type=content_type,
)
elif storage == StorageType.AZURE:
return upload_bytes_put(
url=url,
data=data,
content_type=content_type,
extra_headers={"x-ms-blob-type": "BlockBlob"},
)
else:
# Local storage uses multipart POST (no extra fields)
return upload_bytes_multipart(
url=url,
data=data,
object_name=object_name,
content_type=content_type,
)
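
# Usage sketch (illustrative values only; a real presigned URL and policy
# fields come from the backend's presigned-URL response):
#
#     response = upload_bytes(
#         storage=StorageType.AWS,
#         url="https://example-bucket.s3.amazonaws.com/",  # placeholder
#         data=b"col_a,col_b\n1,2\n",
#         object_name="dataset.csv",
#         content_type="text/csv",
#         fields={"key": "uploads/dataset.csv"},  # placeholder policy fields
#     )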


def upload_bytes_multipart(
url: str,
data: Union[bytes, BinaryIO],
object_name: str,
content_type: str,
fields: Optional[Dict] = None,
) -> Response:
"""Upload data using multipart POST (for S3 and local storage).

Args:
url: The presigned URL to upload to.
data: The data to upload (bytes or file-like object).
object_name: The object name for the file field.
content_type: The MIME type of the data.
fields: Additional fields to include in the multipart form (e.g., S3 policy fields).

Returns:
The response from the upload request.
"""
# Convert bytes to file-like object if needed
if isinstance(data, bytes):
data = io.BytesIO(data)

upload_fields = dict(fields) if fields else {}
upload_fields["file"] = (object_name, data, content_type)

encoder = MultipartEncoder(fields=upload_fields)
headers = {"Content-Type": encoder.content_type}

response = requests.post(
url,
data=encoder,
headers=headers,
verify=VERIFY_REQUESTS,
timeout=REQUESTS_TIMEOUT,
)
response.raise_for_status()
return response
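
# Usage sketch: streaming a file object through the multipart encoder instead
# of reading the whole payload into memory (path and URL are placeholders):
#
#     with open("model.tar", "rb") as f:
#         upload_bytes_multipart(
#             url="http://localhost:8080/storage/upload",  # placeholder
#             data=f,
#             object_name="model.tar",
#             content_type="application/x-tar",
#         )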


def upload_bytes_put(
url: str,
data: Union[bytes, BinaryIO],
content_type: str,
extra_headers: Optional[Dict[str, str]] = None,
) -> Response:
"""Upload data using PUT request (for GCS and Azure).

Args:
url: The presigned URL to upload to.
data: The data to upload (bytes or file-like object).
content_type: The MIME type of the data.
extra_headers: Additional headers (e.g., x-ms-blob-type for Azure).

Returns:
The response from the upload request.
"""
headers = {"Content-Type": content_type}
if extra_headers:
headers.update(extra_headers)

response = requests.put(
url,
data=data,
headers=headers,
verify=VERIFY_REQUESTS,
timeout=REQUESTS_TIMEOUT,
)
response.raise_for_status()
return response
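
# Usage sketch: Azure Blob Storage requires the x-ms-blob-type header on PUT
# uploads; the URL below is a placeholder for a real SAS URL:
#
#     upload_bytes_put(
#         url="https://example.blob.core.windows.net/container/blob",  # placeholder
#         data=b"payload",
#         content_type="application/octet-stream",
#         extra_headers={"x-ms-blob-type": "BlockBlob"},
#     )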


# --- High-level Uploader class (file-based uploads with progress tracking) -- #
class Uploader:
"""Internal class to handle http requests"""

@@ -105,7 +235,9 @@ def upload_blob_s3(
fields = presigned_url_response.fields
fields["file"] = (object_name, f, "application/x-tar")
e = MultipartEncoder(fields=fields)
m = MultipartEncoderMonitor(e, lambda monitor: t.update(min(t.total, monitor.bytes_read) - t.n))
m = MultipartEncoderMonitor(
e, lambda monitor: t.update(min(t.total, monitor.bytes_read) - t.n)
)
headers = {"Content-Type": m.content_type}
res = requests.post(
presigned_url_response.url,
@@ -116,7 +248,9 @@ def upload_blob_gcs(self, file_path: str, presigned_url_response: PresignedURLCreateResponse):
)
return res

def upload_blob_gcs(self, file_path: str, presigned_url_response: PresignedURLCreateResponse):
def upload_blob_gcs(
self, file_path: str, presigned_url_response: PresignedURLCreateResponse
):
"""Generic method to upload data to Google Cloud Storage and create the
appropriate resource in the backend.
"""
@@ -137,7 +271,9 @@ def upload_blob_azure(self, file_path: str, presigned_url_response: PresignedURLCreateResponse):
)
return res

def upload_blob_azure(self, file_path: str, presigned_url_response: PresignedURLCreateResponse):
def upload_blob_azure(
self, file_path: str, presigned_url_response: PresignedURLCreateResponse
):
"""Generic method to upload data to Azure Blob Storage and create the
appropriate resource in the backend.
"""
@@ -180,7 +316,9 @@ def upload_blob_local(
with open(file_path, "rb") as f:
fields = {"file": (object_name, f, "application/x-tar")}
e = MultipartEncoder(fields=fields)
m = MultipartEncoderMonitor(e, lambda monitor: t.update(min(t.total, monitor.bytes_read) - t.n))
m = MultipartEncoderMonitor(
e, lambda monitor: t.update(min(t.total, monitor.bytes_read) - t.n)
)
headers = {"Content-Type": m.content_type}
res = requests.post(
presigned_url_response.url,