Skip to content
Open

[WIP] #4170

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions logservice/logpuller/region_req_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ func newRequestCache(maxPendingCount int) *requestCache {
return res
}

// markDropped releases one pending slot for a request that will never be
// tracked in sentRequests. Callers invoke it when a request popped from
// pendingQueue is discarded instead of sent (for example, a stop task), or
// when the send fails before markSent could run.
func (c *requestCache) markDropped() {
	c.decPendingCount()
	// Wake a blocked add, if any. A dropped signal is harmless: the channel
	// already holding a notification means a waiter will be woken anyway.
	select {
	case c.spaceAvailable <- struct{}{}:
	default:
	}
}

// add adds a new region request to the cache
// It blocks if pendingCount >= maxPendingCount until there's space or ctx is cancelled
func (c *requestCache) add(ctx context.Context, region regionInfo, force bool) (bool, error) {
Expand Down Expand Up @@ -255,9 +267,15 @@ func (c *requestCache) clearStaleRequest() {
}
}

if reqCount == 0 && c.pendingCount.Load() != 0 {
log.Info("region worker pending request count is not equal to actual region request count, correct it", zap.Int("pendingCount", int(c.pendingCount.Load())), zap.Int("actualReqCount", reqCount))
c.pendingCount.Store(0)
actualReqCount := int64(reqCount) + int64(len(c.pendingQueue))
pendingCount := c.pendingCount.Load()
// One request can be "in flight" (popped from pendingQueue but not yet marked sent),
// so we tolerate a small mismatch to avoid false corrections.
if pendingCount < actualReqCount || pendingCount-actualReqCount > 1 {
log.Info("region worker pending request count is not equal to actual region request count, correct it",
zap.Int("pendingCount", int(pendingCount)),
zap.Int64("actualReqCount", actualReqCount))
c.pendingCount.Store(actualReqCount)
Comment on lines +270 to +278
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Normalization logic is reasonable but the correction itself is racy.

clearStaleRequest reads pendingCount (line 271) and len(c.pendingQueue) (line 270) without holding any synchronization that covers both the atomic counter and the channel length together. A concurrent markDropped/add call between these two reads can cause an incorrect correction. Since sentRequests.Lock is held, only the queue side is unprotected.

Additionally, the tolerance comment says "up to 1" for an in-flight request, but if pendingCount < actualReqCount (line 274, first branch) the correction blindly stores actualReqCount, which may be stale by then. Given this runs on a 10-second timer and is a safety net rather than primary accounting, the impact is low, but consider logging a warning rather than auto-correcting when the mismatch is small.

🤖 Prompt for AI Agents
In `@logservice/logpuller/region_req_cache.go` around lines 270 - 278, The
correction in clearStaleRequest is racy because it reads len(c.pendingQueue) and
c.pendingCount separately without a consistent lock; fix it by making the
comparison and potential fix under a single synchronization point (e.g., acquire
sentRequests.Lock around reading len(c.pendingQueue) and pendingCount or
otherwise atomically snapshot both values), and change the behavior when the
delta is 1 to only log a warning instead of blindly calling
c.pendingCount.Store(actualReqCount); ensure references to clearStaleRequest,
c.pendingQueue, c.pendingCount, sentRequests.Lock, and the markDropped/add code
paths are considered so concurrent updates won't cause incorrect corrections.

}

c.lastCheckStaleRequestTime.Store(time.Now())
Expand Down
6 changes: 6 additions & 0 deletions logservice/logpuller/region_request_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func newRegionRequestWorker(
*worker.preFetchForConnecting = region.regionInfo
return nil
} else {
worker.requestCache.markDropped()
continue
}
}
Expand Down Expand Up @@ -375,8 +376,10 @@ func (s *regionRequestWorker) processRegionSendTask(
FilterLoop: region.filterLoop,
}
if err := doSend(req); err != nil {
s.requestCache.markDropped()
return err
}
s.requestCache.markDropped()
for _, state := range s.takeRegionStates(subID) {
state.markStopped(&requestCancelledErr{})
regionEvent := regionEvent{
Expand All @@ -390,11 +393,13 @@ func (s *regionRequestWorker) processRegionSendTask(
// the stopped subscribedTable, or the special singleRegionInfo for stopping
// the table will be handled later.
s.client.onRegionFail(newRegionErrorInfo(region, &sendRequestToStoreErr{}))
s.requestCache.markDropped()
} else {
state := newRegionFeedState(region, uint64(subID), s)
state.start()
s.addRegionState(subID, region.verID.GetID(), state)
if err := doSend(s.createRegionRequest(region)); err != nil {
s.requestCache.markDropped()
return err
}
s.requestCache.markSent(regionReq)
Expand Down Expand Up @@ -485,6 +490,7 @@ func (s *regionRequestWorker) clearPendingRegions() []regionInfo {
region := *s.preFetchForConnecting
s.preFetchForConnecting = nil
regions = append(regions, region)
s.requestCache.markDropped()
}

// Clear all regions from cache
Expand Down
123 changes: 123 additions & 0 deletions logservice/logpuller/region_request_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,19 @@
package logpuller

import (
"context"
"errors"
"io"
"testing"
"time"

"github.com/pingcap/kvproto/pkg/cdcpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/ticdc/heartbeatpb"
"github.com/pingcap/ticdc/logservice/logpuller/regionlock"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/tikv"
"google.golang.org/grpc/metadata"
)

func TestRegionStatesOperation(t *testing.T) {
Expand All @@ -38,3 +48,116 @@ func TestRegionStatesOperation(t *testing.T) {
require.Nil(t, worker.getRegionState(1, 2))
require.Equal(t, 0, len(worker.requestedRegions.subscriptions))
}

// fakeEventFeedV2Client is a test stub for the event-feed stream client.
// Send runs sendHook (when set) and then returns the configured sendErr;
// Recv always ends the stream with io.EOF.
type fakeEventFeedV2Client struct {
	sendErr  error
	sendHook func(*cdcpb.ChangeDataRequest)
	ctx      context.Context
}

// Send invokes the hook for observation, then reports the injected error.
func (c *fakeEventFeedV2Client) Send(req *cdcpb.ChangeDataRequest) error {
	if hook := c.sendHook; hook != nil {
		hook(req)
	}
	return c.sendErr
}

// Recv terminates the stream immediately.
func (c *fakeEventFeedV2Client) Recv() (*cdcpb.ChangeDataEvent, error) {
	return nil, io.EOF
}

func (c *fakeEventFeedV2Client) Header() (metadata.MD, error) { return nil, nil }
func (c *fakeEventFeedV2Client) Trailer() metadata.MD         { return nil }
func (c *fakeEventFeedV2Client) CloseSend() error             { return nil }

// Context returns the injected context, falling back to context.Background.
func (c *fakeEventFeedV2Client) Context() context.Context {
	if c.ctx == nil {
		return context.Background()
	}
	return c.ctx
}

func (c *fakeEventFeedV2Client) SendMsg(m any) error { return nil }
func (c *fakeEventFeedV2Client) RecvMsg(m any) error { return nil }

// TestRegionRequestWorkerSendErrorDoesNotLeakPendingCount verifies that when
// sending a normal region request fails, the worker releases the pending slot
// so requestCache.pendingCount does not leak.
func TestRegionRequestWorkerSendErrorDoesNotLeakPendingCount(t *testing.T) {
	t.Parallel()

	span := &subscribedSpan{subID: 1}
	rpcCtx := &tikv.RPCContext{
		Addr: "store-1",
		Meta: &metapb.Region{RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}},
	}
	tableSpan := heartbeatpb.TableSpan{StartKey: []byte("a"), EndKey: []byte("b")}
	info := newRegionInfo(tikv.NewRegionVerID(100, 1, 1), tableSpan, rpcCtx, span, false)
	info.lockedRangeState = &regionlock.LockedRangeState{}

	w := &regionRequestWorker{
		workerID:              1,
		client:                &subscriptionClient{clusterID: 1},
		store:                 &requestedStore{storeAddr: "store-1"},
		preFetchForConnecting: &info,
		requestCache:          newRequestCache(16),
	}
	w.requestedRegions.subscriptions = make(map[SubscriptionID]regionFeedStates)
	// One request is accounted as pending before the send attempt.
	w.requestCache.pendingCount.Store(1)

	cc := &ConnAndClient{
		Client: &fakeEventFeedV2Client{sendErr: errors.New("send failed")},
	}

	require.Error(t, w.processRegionSendTask(context.Background(), cc))
	require.Equal(t, int64(0), w.requestCache.pendingCount.Load())
}

func TestRegionRequestWorkerStopTaskDoesNotLeakPendingCount(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sendCalled := make(chan struct{})
conn := &ConnAndClient{
Client: &fakeEventFeedV2Client{
sendHook: func(*cdcpb.ChangeDataRequest) {
select {
case <-sendCalled:
default:
close(sendCalled)
}
cancel()
},
},
}

stopRegion := regionInfo{
subscribedSpan: &subscribedSpan{subID: 1},
lockedRangeState: nil,
}

worker := &regionRequestWorker{
workerID: 1,
client: &subscriptionClient{clusterID: 1},
store: &requestedStore{storeAddr: "store-1"},
preFetchForConnecting: &stopRegion,
requestCache: newRequestCache(16),
}
worker.requestedRegions.subscriptions = make(map[SubscriptionID]regionFeedStates)
worker.requestCache.pendingCount.Store(1)

errCh := make(chan error, 1)
go func() {
errCh <- worker.processRegionSendTask(ctx, conn)
}()

select {
case <-sendCalled:
case <-time.After(2 * time.Second):
t.Fatal("send is not called in time")
}

err := <-errCh
require.Error(t, err)
require.Equal(t, int64(0), worker.requestCache.pendingCount.Load())
}