
ZarrsCodecPipeline deadlocks in forked child processes (torch DataLoader with num_workers>0) #171

@srivarra

Description

Reading or writing a Zarr array through ZarrsCodecPipeline deadlocks in a forked child process if the parent process has already encoded or decoded a chunk with it. As a result, torch.utils.data.DataLoader with num_workers>0, which forks its workers by default on Linux, deadlocks.

Small example(s)

Forking without PyTorch

```python
import os
import tempfile

import numpy as np
import zarr

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})

path = os.path.join(tempfile.mkdtemp(), "a.zarr")
a = zarr.create_array(store=path, shape=(256, 256), chunks=(32, 32),
                      dtype="int16", zarr_format=3)
a[:] = np.arange(256 * 256, dtype="int16").reshape(256, 256)  # parent uses the codec

pid = os.fork()
if pid == 0:
    zarr.open_array(path, mode="r")[:]  # child: hangs forever
    os._exit(0)

os.waitpid(pid, 0)  # parent: never returns, because the child never exits
```

Forking with PyTorch DataLoader

```python
# Pass "serial" (default), "fork", or "spawn" as the first command-line argument.
import sys, tempfile
import numpy as np, torch, zarr
from torch.utils.data import DataLoader, Dataset

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})


class ImageDataset(Dataset):
    def __init__(self, root: str, n: int):
        self.root, self.n = root, n   # instance attrs travel to spawned workers

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> torch.Tensor:
        arr = zarr.open_array(f"{self.root}/img_{idx}.zarr", mode="r")
        return torch.from_numpy(np.asarray(arr[:]))   # full read -> parallel codec


def main(mode: str) -> None:
    root = tempfile.mkdtemp()
    n = 8
    for i in range(n):
        a = zarr.create_array(store=f"{root}/img_{i}.zarr", shape=(256, 256),
                              chunks=(32, 32), dtype="int16", zarr_format=3)
        a[:] = np.full((256, 256), i, dtype="int16")
    zarr.open_array(f"{root}/img_0.zarr", mode="r")[:]   # parent arms the codec pool

    kwargs = {"batch_size": 2}
    if mode == "fork":
        kwargs |= {"num_workers": 4, "multiprocessing_context": "fork"}
    elif mode == "spawn":
        kwargs |= {"num_workers": 4, "multiprocessing_context": "spawn"}
    # "serial": num_workers=0 (default)

    loader = DataLoader(ImageDataset(root, n), **kwargs)
    total = sum(int(batch.shape[0]) for batch in loader)
    print(f"[{mode}] read {total} images")


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else "serial")
```

From what I've debugged so far (with some help from Claude), the first codec call spins up a global thread pool. After a fork, the child inherits the pool's state but none of its worker threads, so the next decode blocks waiting on a thread that never runs.
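To illustrate the suspected failure mode, here is a pure-Python analogue. It is not zarrs' code (zarrs is a Rust library, so its pool lives in native code), and `run_in_pool`, `_tasks`, `_started`, and `fork_and_call` are hypothetical names. The pool is started lazily on first use; after `os.fork()` the child still sees `_started == True` but has no worker thread, so the task it submits is never picked up. The child waits with a timeout so the demo reports the hang instead of blocking forever.

```python
import os
import queue
import threading
import time
import warnings

# Hypothetical stand-in for a lazily created, process-global thread pool.
_tasks = queue.Queue()
_started = False


def _worker():
    # Runs forever in the background, executing submitted callables.
    while True:
        fn, out = _tasks.get()
        out.put(fn())


def run_in_pool(fn, timeout=None):
    """Submit fn to the global pool and wait for its result."""
    global _started
    if not _started:  # the first call spins up the pool
        threading.Thread(target=_worker, daemon=True).start()
        _started = True
    out = queue.Queue(maxsize=1)
    _tasks.put((fn, out))
    return out.get(timeout=timeout)  # raises queue.Empty if no result arrives


def fork_and_call(timeout=1.0):
    """Call run_in_pool in a forked child; return "ok" or "hung"."""
    run_in_pool(lambda: 0)  # parent "arms" the pool, like the first decode
    time.sleep(0.1)         # let the worker go back to waiting on _tasks
    r, w = os.pipe()
    with warnings.catch_warnings():
        # Python 3.12+ warns that fork() in a multi-threaded process may deadlock.
        warnings.simplefilter("ignore", DeprecationWarning)
        pid = os.fork()
    if pid == 0:  # child: _started is True, but the worker thread was not copied
        os.close(r)
        try:
            run_in_pool(lambda: 42, timeout=timeout)
            os.write(w, b"ok")
        except queue.Empty:
            os.write(w, b"hung")  # a real decode would block here forever
        os._exit(0)
    os.close(w)
    os.waitpid(pid, 0)
    status = os.read(r, 16).decode()
    os.close(r)
    return status
```

On Linux, `fork_and_call()` returns `"hung"`, while `run_in_pool` keeps working in the parent, whose worker thread is still alive. The `time.sleep` before forking only keeps the demo deterministic by letting the worker return to its idle wait instead of being caught mid-operation.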

The annbatch docs recommend setting multiprocessing_context="spawn" (or num_workers=0), and either change avoids the deadlock. It would still be nice if the codec pipeline worked with fork.
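A common pattern for making a lazily started pool fork-safe is to drop the inherited pool in the child and let the next call rebuild it. The sketch below shows that pattern in pure Python with `os.register_at_fork(after_in_child=...)`. The pool and its names (`run_in_pool`, `_tasks`, `_started`, `_reset_pool_in_child`, `fork_and_call`) are hypothetical, and this is not how zarrs behaves today; a fix in zarrs itself would presumably need an equivalent hook in its native code, which is an assumption rather than a known plan.

```python
import os
import queue
import threading
import warnings

# Hypothetical lazily started, process-global pool.
_tasks = queue.Queue()
_started = False


def _worker(tasks):
    # Execute submitted callables forever, sending each result back.
    while True:
        fn, out = tasks.get()
        out.put(fn())


def run_in_pool(fn, timeout=5.0):
    """Submit fn to the pool, starting it on first use, and wait for the result."""
    global _started
    if not _started:
        threading.Thread(target=_worker, args=(_tasks,), daemon=True).start()
        _started = True
    out = queue.Queue(maxsize=1)
    _tasks.put((fn, out))
    return out.get(timeout=timeout)  # raises queue.Empty if nothing answers


def _reset_pool_in_child():
    # Worker threads are not copied by fork, so forget the inherited pool.
    # A fresh queue also avoids any lock a vanished thread held at fork time.
    global _tasks, _started
    _tasks = queue.Queue()
    _started = False


os.register_at_fork(after_in_child=_reset_pool_in_child)


def fork_and_call():
    """Use the pool in the parent, then in a forked child; return the child's result."""
    run_in_pool(lambda: 0)  # parent starts the pool first
    r, w = os.pipe()
    with warnings.catch_warnings():
        # Python 3.12+ warns about fork() in a multi-threaded process.
        warnings.simplefilter("ignore", DeprecationWarning)
        pid = os.fork()
    if pid == 0:  # child: the hook has already reset the pool
        os.close(r)
        try:
            os.write(w, str(run_in_pool(lambda: 42)).encode())
        except queue.Empty:
            os.write(w, b"hung")
        os._exit(0)
    os.close(w)
    os.waitpid(pid, 0)
    result = os.read(r, 16).decode()
    os.close(r)
    return result
```

Here `fork_and_call()` returns `"42"`: the hook swaps in a fresh queue, so the child starts its own worker on first use instead of waiting on one that was never copied. `register_at_fork` hooks run only when the fork goes through the interpreter, i.e. `os.fork()` and therefore `multiprocessing`'s fork start method, which is what DataLoader uses with `multiprocessing_context="fork"`.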
