diff --git a/file_upload.go b/file_upload.go index d13e073..21097d6 100644 --- a/file_upload.go +++ b/file_upload.go @@ -2,12 +2,12 @@ package proton_api_bridge import ( "bufio" - "bytes" "context" "crypto/sha1" "crypto/sha256" "encoding/base64" "encoding/hex" + "fmt" "io" "mime" "os" @@ -25,6 +25,16 @@ func (protonDrive *ProtonDrive) handleRevisionConflict(ctx context.Context, link draftRevision, err := protonDrive.GetRevisions(ctx, link, proton.RevisionStateDraft) if err != nil { + // If we can't list revisions and the link is already in draft state, + // it's a broken/incomplete upload from a previous failed attempt with + // no recoverable state. Always delete it and retry from scratch. + if link.State == proton.LinkStateDraft { + err = protonDrive.c.DeleteChildren(ctx, protonDrive.MainShare.ShareID, link.ParentLinkID, linkID) + if err != nil { + return "", false, err + } + return "", true, nil + } return "", false, err } @@ -251,6 +261,18 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n return nil, 0, nil, "", ErrMissingInputUploadAndCollectBlockData } + // Fetch the per-revision verification code required by Proton's storage backend. + // Each block's Verifier.Token is produced by XOR-ing this code with the first + // bytes of that block's ciphertext (per the Proton Drive JS SDK spec). + revVerification, err := protonDrive.c.GetRevisionVerification(ctx, protonDrive.MainShare.VolumeID, linkID, revisionID) + if err != nil { + return nil, 0, nil, "", fmt.Errorf("uploadAndCollectBlockData: get revision verification: %w", err) + } + verificationCode, err := base64.StdEncoding.DecodeString(revVerification.VerificationCode) + if err != nil { + return nil, 0, nil, "", fmt.Errorf("uploadAndCollectBlockData: decode verification code: %w", err) + } + totalFileSize := int64(0) pendingUploadBlocks := make([]PendingUploadBlocks, 0) @@ -277,28 +299,41 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n return err } - errChan := make(chan error) - uploadBlockWrapper := func(ctx context.Context, errChan chan error, bareURL, token string, block io.Reader) { - // log.Println("Before semaphore") - if err := protonDrive.blockUploadSemaphore.Acquire(ctx, 1); err != nil { + // Use a per-batch cancellable context so that when one block upload + // fails, all sibling goroutines are cancelled promptly and release + // their semaphore slots before the outer retry begins. + batchCtx, batchCancel := context.WithCancel(ctx) + defer batchCancel() + + // Buffered so every goroutine can always send without blocking, + // even after the first error has been received and batchCancel called. + errChan := make(chan error, len(blockUploadResp)) + uploadBlockWrapper := func(bareURL, token string, block []byte) { + if err := protonDrive.blockUploadSemaphore.Acquire(batchCtx, 1); err != nil { errChan <- err + return // must not defer-Release a slot we never acquired } defer protonDrive.blockUploadSemaphore.Release(1) - // log.Println("After semaphore") - // defer log.Println("Release semaphore") - errChan <- protonDrive.c.UploadBlock(ctx, bareURL, token, block) + errChan <- protonDrive.c.UploadBlock(batchCtx, bareURL, token, block) } for i := range blockUploadResp { - go uploadBlockWrapper(ctx, errChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData)) + go uploadBlockWrapper(blockUploadResp[i].BareURL, blockUploadResp[i].Token, pendingUploadBlocks[i].encData) } + // Drain all goroutines. Cancel on first error so the rest stop quickly, + // but still wait for all of them so semaphore slots are fully released + // before we return. + var firstErr error for i := 0; i < len(blockUploadResp); i++ { - err := <-errChan - if err != nil { - return err + if err := <-errChan; err != nil && firstErr == nil { + firstErr = err + batchCancel() } } + if firstErr != nil { + return firstErr + } pendingUploadBlocks = pendingUploadBlocks[:0] @@ -310,7 +345,7 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n blockSizes := make([]int64, 0) for i := 1; shouldContinue; i++ { if (i-1) > 0 && (i-1)%UPLOAD_BATCH_BLOCK_SIZE == 0 { - err := uploadPendingBlocks() + err = uploadPendingBlocks() if err != nil { return nil, 0, nil, "", err } @@ -366,17 +401,31 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n } manifestSignatureData = append(manifestSignatureData, hash...) + // Compute per-block verifier token: XOR verificationCode with the + // leading bytes of the encrypted block (zero-padded if block is shorter). + verificationToken := make([]byte, len(verificationCode)) + for j, v := range verificationCode { + var b byte + if j < len(encData) { + b = encData[j] + } + verificationToken[j] = v ^ b + } + pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{ blockUploadInfo: proton.BlockUploadInfo{ Index: i, // iOS drive: BE starts with 1 Size: int64(len(encData)), EncSignature: encSignatureStr, Hash: base64Hash, + Verifier: proton.BlockUploadVerifier{ + Token: base64.StdEncoding.EncodeToString(verificationToken), + }, }, encData: encData, }) } - err := uploadPendingBlocks() + err = uploadPendingBlocks() if err != nil { return nil, 0, nil, "", err }