Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions file_upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package proton_api_bridge

import (
"bufio"
"bytes"
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"mime"
"os"
Expand All @@ -25,6 +25,16 @@ func (protonDrive *ProtonDrive) handleRevisionConflict(ctx context.Context, link

draftRevision, err := protonDrive.GetRevisions(ctx, link, proton.RevisionStateDraft)
if err != nil {
// If we can't list revisions and the link is already in draft state,
// it's a broken/incomplete upload from a previous failed attempt with
// no recoverable state. Always delete it and retry from scratch.
if link.State == proton.LinkStateDraft {
err = protonDrive.c.DeleteChildren(ctx, protonDrive.MainShare.ShareID, link.ParentLinkID, linkID)
if err != nil {
return "", false, err
}
return "", true, nil
}
return "", false, err
}

Expand Down Expand Up @@ -251,6 +261,18 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n
return nil, 0, nil, "", ErrMissingInputUploadAndCollectBlockData
}

// Fetch the per-revision verification code required by Proton's storage backend.
// Each block's Verifier.Token is produced by XOR-ing this code with the first
// bytes of that block's ciphertext (per the Proton Drive JS SDK spec).
revVerification, err := protonDrive.c.GetRevisionVerification(ctx, protonDrive.MainShare.VolumeID, linkID, revisionID)
if err != nil {
return nil, 0, nil, "", fmt.Errorf("uploadAndCollectBlockData: get revision verification: %w", err)
}
verificationCode, err := base64.StdEncoding.DecodeString(revVerification.VerificationCode)
if err != nil {
return nil, 0, nil, "", fmt.Errorf("uploadAndCollectBlockData: decode verification code: %w", err)
}

totalFileSize := int64(0)

pendingUploadBlocks := make([]PendingUploadBlocks, 0)
Expand All @@ -277,28 +299,41 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n
return err
}

errChan := make(chan error)
uploadBlockWrapper := func(ctx context.Context, errChan chan error, bareURL, token string, block io.Reader) {
// log.Println("Before semaphore")
if err := protonDrive.blockUploadSemaphore.Acquire(ctx, 1); err != nil {
// Use a per-batch cancellable context so that when one block upload
// fails, all sibling goroutines are cancelled promptly and release
// their semaphore slots before the outer retry begins.
batchCtx, batchCancel := context.WithCancel(ctx)
defer batchCancel()

// Buffered so every goroutine can always send without blocking,
// even after the first error has been received and batchCancel called.
errChan := make(chan error, len(blockUploadResp))
uploadBlockWrapper := func(bareURL, token string, block []byte) {
if err := protonDrive.blockUploadSemaphore.Acquire(batchCtx, 1); err != nil {
errChan <- err
return // must not defer-Release a slot we never acquired
}
defer protonDrive.blockUploadSemaphore.Release(1)
// log.Println("After semaphore")
// defer log.Println("Release semaphore")

errChan <- protonDrive.c.UploadBlock(ctx, bareURL, token, block)
errChan <- protonDrive.c.UploadBlock(batchCtx, bareURL, token, block)
}
for i := range blockUploadResp {
go uploadBlockWrapper(ctx, errChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData))
go uploadBlockWrapper(blockUploadResp[i].BareURL, blockUploadResp[i].Token, pendingUploadBlocks[i].encData)
}

// Drain all goroutines. Cancel on first error so the rest stop quickly,
// but still wait for all of them so semaphore slots are fully released
// before we return.
var firstErr error
for i := 0; i < len(blockUploadResp); i++ {
err := <-errChan
if err != nil {
return err
if err := <-errChan; err != nil && firstErr == nil {
firstErr = err
batchCancel()
}
}
if firstErr != nil {
return firstErr
}

pendingUploadBlocks = pendingUploadBlocks[:0]

Expand All @@ -310,7 +345,7 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n
blockSizes := make([]int64, 0)
for i := 1; shouldContinue; i++ {
if (i-1) > 0 && (i-1)%UPLOAD_BATCH_BLOCK_SIZE == 0 {
err := uploadPendingBlocks()
err = uploadPendingBlocks()
if err != nil {
return nil, 0, nil, "", err
}
Expand Down Expand Up @@ -366,17 +401,31 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n
}
manifestSignatureData = append(manifestSignatureData, hash...)

// Compute per-block verifier token: XOR verificationCode with the
// leading bytes of the encrypted block (zero-padded if block is shorter).
verificationToken := make([]byte, len(verificationCode))
for j, v := range verificationCode {
var b byte
if j < len(encData) {
b = encData[j]
}
verificationToken[j] = v ^ b
}

pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{
blockUploadInfo: proton.BlockUploadInfo{
Index: i, // iOS drive: BE starts with 1
Size: int64(len(encData)),
EncSignature: encSignatureStr,
Hash: base64Hash,
Verifier: proton.BlockUploadVerifier{
Token: base64.StdEncoding.EncodeToString(verificationToken),
},
},
encData: encData,
})
}
err := uploadPendingBlocks()
err = uploadPendingBlocks()
if err != nil {
return nil, 0, nil, "", err
}
Expand Down