diff --git a/ipcutil/ipcutil.go b/ipcutil/ipcutil.go index 5939d86..fd4cdb3 100644 --- a/ipcutil/ipcutil.go +++ b/ipcutil/ipcutil.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "fmt" "io" + "sync" ) // MaxMessageSize is the maximum IPC message size (1MB). @@ -29,7 +30,13 @@ func Read(r io.Reader) ([]byte, error) { } // Write writes a length-prefixed IPC message to w. +// It is safe for concurrent use: the length prefix and payload are +// written as an atomic unit, preventing interleaving when callers +// share the same writer across goroutines. func Write(w io.Writer, data []byte) error { + writeMu.Lock() + defer writeMu.Unlock() + var lenBuf [4]byte binary.BigEndian.PutUint32(lenBuf[:], uint32(len(data))) if _, err := w.Write(lenBuf[:]); err != nil { @@ -38,3 +45,5 @@ func Write(w io.Writer, data []byte) error { _, err := w.Write(data) return err } + +var writeMu sync.Mutex diff --git a/ipcutil/zz_test.go b/ipcutil/zz_test.go index 6df60be..5526aac 100644 --- a/ipcutil/zz_test.go +++ b/ipcutil/zz_test.go @@ -5,8 +5,10 @@ package ipcutil import ( "bytes" "encoding/binary" + "fmt" "io" "strings" + "sync" "testing" ) @@ -127,6 +129,57 @@ func TestWriteErrorOnLengthPrefix(t *testing.T) { } } +// TestWriteConcurrent verifies Write is safe across many goroutines sharing +// the same writer: length headers never interleave with unrelated payloads, +// and every written message round-trips intact. +func TestWriteConcurrent(t *testing.T) { + t.Parallel() + + // rawBuf collects all bytes without any internal locking — it + // exercises the Write-level mutex, not the io.Writer level. + var mu sync.Mutex + var rawBuf bytes.Buffer + + const n = 200 + var wg sync.WaitGroup + wg.Add(n) + + for i := 0; i < n; i++ { + go func(id int) { + defer wg.Done() + payload := []byte(fmt.Sprintf("msg-%03d-%s", id, strings.Repeat("X", id))) + // Serialize writes to rawBuf so the test doesn't + // trip over bytes.Buffer being non-concurrent. + mu.Lock() + if err := Write(&rawBuf, payload); err != nil { + t.Errorf("Write(%d): %v", id, err) + } + mu.Unlock() + }(i) + } + wg.Wait() + + reader := bytes.NewReader(rawBuf.Bytes()) + seen := make(map[int]bool) + for i := 0; i < n; i++ { + msg, err := Read(reader) + if err != nil { + t.Fatalf("Read %d: %v — %d bytes remain in buffer", i, err, reader.Len()) + } + var id int + if _, scanErr := fmt.Sscanf(string(msg), "msg-%03d-", &id); scanErr != nil { + t.Fatalf("Read %d: corrupt message %q: %v", i, string(msg[:min(len(msg), 20)]), scanErr) + } + if seen[id] { + t.Fatalf("duplicate message id %d", id) + } + seen[id] = true + } + if len(seen) != n { + t.Fatalf("expected %d unique messages, got %d", n, len(seen)) + } +} + func TestWriteErrorOnPayload(t *testing.T) { t.Parallel() w := &errWriter{failAfter: 1} // first write (length) succeeds, second (payload) fails