diff --git a/packages/durabletask-js/src/worker/task-hub-grpc-worker.ts b/packages/durabletask-js/src/worker/task-hub-grpc-worker.ts index 03a2e12..1d70675 100644 --- a/packages/durabletask-js/src/worker/task-hub-grpc-worker.ts +++ b/packages/durabletask-js/src/worker/task-hub-grpc-worker.ts @@ -823,7 +823,23 @@ export class TaskHubGrpcWorker { } /** - * Executes an entity batch request. + * Executes an entity batch request and tracks it as a pending work item. + */ + private _executeEntity( + req: pb.EntityBatchRequest, + completionToken: string, + stub: stubs.TaskHubSidecarServiceClient, + operationInfos?: pb.OperationInfo[], + ): void { + const workPromise = this._executeEntityInternal(req, completionToken, stub, operationInfos); + this._pendingWorkItems.add(workPromise); + workPromise.finally(() => { + this._pendingWorkItems.delete(workPromise); + }); + } + + /** + * Internal implementation of entity batch execution. * * @param req - The entity batch request from the sidecar. * @param completionToken - The completion token for the work item. @@ -834,7 +850,7 @@ export class TaskHubGrpcWorker { * This method looks up the entity by name, creates a TaskEntityShim, executes the batch, * and sends the result back to the sidecar. */ - private async _executeEntity( + private async _executeEntityInternal( req: pb.EntityBatchRequest, completionToken: string, stub: stubs.TaskHubSidecarServiceClient, @@ -907,7 +923,22 @@ export class TaskHubGrpcWorker { } /** - * Executes an entity request (V2 format). + * Executes an entity request (V2 format) and tracks it as a pending work item. + */ + private _executeEntityV2( + req: pb.EntityRequest, + completionToken: string, + stub: stubs.TaskHubSidecarServiceClient, + ): void { + const workPromise = this._executeEntityV2Internal(req, completionToken, stub); + this._pendingWorkItems.add(workPromise); + workPromise.finally(() => { + this._pendingWorkItems.delete(workPromise); + }); + } + + /** + * Internal implementation of V2 entity execution. * * @param req - The entity request (V2) from the sidecar. * @param completionToken - The completion token for the work item. @@ -918,7 +949,7 @@ export class TaskHubGrpcWorker { * instead of OperationRequest. It converts the V2 format to V1 format * (EntityBatchRequest) and delegates to the existing execution logic. */ - private async _executeEntityV2( + private async _executeEntityV2Internal( req: pb.EntityRequest, completionToken: string, stub: stubs.TaskHubSidecarServiceClient, @@ -1002,7 +1033,7 @@ export class TaskHubGrpcWorker { batchRequest.setOperationsList(operations); // Delegate to the V1 execution logic with V2 operationInfos - await this._executeEntity(batchRequest, completionToken, stub, operationInfos); + await this._executeEntityInternal(batchRequest, completionToken, stub, operationInfos); } /** diff --git a/packages/durabletask-js/test/worker-entity.spec.ts b/packages/durabletask-js/test/worker-entity.spec.ts index 94d3225..7fc73ad 100644 --- a/packages/durabletask-js/test/worker-entity.spec.ts +++ b/packages/durabletask-js/test/worker-entity.spec.ts @@ -5,6 +5,9 @@ import { TaskHubGrpcWorker } from "../src/worker/task-hub-grpc-worker"; import { TaskEntity } from "../src/entities/task-entity"; import { ITaskEntity, EntityFactory } from "../src/entities/task-entity"; import { TaskEntityOperation } from "../src/entities/task-entity-operation"; +import * as pb from "../src/proto/orchestrator_service_pb"; +import * as stubs from "../src/proto/orchestrator_service_grpc_pb"; +import { NoOpLogger } from "../src/types/logger.type"; /** * Test entity for worker tests. @@ -20,6 +23,70 @@ class CounterEntity extends TaskEntity { } } +const COMPLETION_TOKEN = "test-completion-token"; + +/** + * Creates a mock gRPC stub that captures the EntityBatchResult passed to + * completeEntityTask. + */ +function createMockStub(): { + stub: stubs.TaskHubSidecarServiceClient; + capturedResult: pb.EntityBatchResult | null; +} { + let capturedResult: pb.EntityBatchResult | null = null; + + const stub = { + completeEntityTask: ( + result: pb.EntityBatchResult, + metadata: any, + callback: (err: any, res: any) => void, + ) => { + capturedResult = result; + callback(null, {}); + }, + } as unknown as stubs.TaskHubSidecarServiceClient; + + return { + stub, + get capturedResult() { + return capturedResult; + }, + }; +} + +/** + * Creates a minimal EntityBatchRequest for testing. + */ +function createEntityBatchRequest(entityName: string, entityKey: string): pb.EntityBatchRequest { + const req = new pb.EntityBatchRequest(); + req.setInstanceid(`@${entityName}@${entityKey}`); + + const opRequest = new pb.OperationRequest(); + opRequest.setOperation("increment"); + opRequest.setRequestid("req-1"); + req.setOperationsList([opRequest]); + + return req; +} + +/** + * Creates a minimal EntityRequest (V2) for testing. + */ +function createEntityRequestV2(entityName: string, entityKey: string): pb.EntityRequest { + const req = new pb.EntityRequest(); + req.setInstanceid(`@${entityName}@${entityKey}`); + + const historyEvent = new pb.HistoryEvent(); + const signaled = new pb.EntityOperationSignaledEvent(); + signaled.setOperation("increment"); + signaled.setRequestid("req-1"); + historyEvent.setEntityoperationsignaled(signaled); + req.setOperationrequestsList([historyEvent]); + + return req; +} + + describe("TaskHubGrpcWorker", () => { describe("Entity Registration", () => { describe("addEntity", () => { @@ -144,4 +211,123 @@ describe("TaskHubGrpcWorker", () => { expect(true).toBe(true); }); }); + + describe("Entity Execution Tracking", () => { + it("should track V1 entity execution in _pendingWorkItems", async () => { + // Arrange + const worker = new TaskHubGrpcWorker({ logger: new NoOpLogger() }); + const factory: EntityFactory = () => new CounterEntity(); + worker.addNamedEntity("counter", factory); + + const mockStub = createMockStub(); + const req = createEntityBatchRequest("counter", "key1"); + + // Act - call _executeEntity via the wrapper (which tracks the work item) + (worker as any)._executeEntity(req, COMPLETION_TOKEN, mockStub.stub); + + // Assert - the promise should be tracked while executing + const pendingWorkItems: Set> = (worker as any)._pendingWorkItems; + expect(pendingWorkItems.size).toBe(1); + + // Wait for completion + await Promise.all(pendingWorkItems); + + // After completion, it should be removed + expect(pendingWorkItems.size).toBe(0); + }); + + it("should remove V1 entity execution from _pendingWorkItems after completion", async () => { + // Arrange + const worker = new TaskHubGrpcWorker({ logger: new NoOpLogger() }); + const factory: EntityFactory = () => new CounterEntity(); + worker.addNamedEntity("counter", factory); + + const mockStub = createMockStub(); + const req = createEntityBatchRequest("counter", "key1"); + + // Act + (worker as any)._executeEntity(req, COMPLETION_TOKEN, mockStub.stub); + + const pendingWorkItems: Set> = (worker as any)._pendingWorkItems; + + // Wait for completion + await Promise.all(pendingWorkItems); + + // Assert - should have been cleaned up + expect(pendingWorkItems.size).toBe(0); + expect(mockStub.capturedResult).not.toBeNull(); + expect(mockStub.capturedResult!.getCompletiontoken()).toBe(COMPLETION_TOKEN); + }); + + it("should track V2 entity execution in _pendingWorkItems", async () => { + // Arrange + const worker = new TaskHubGrpcWorker({ logger: new NoOpLogger() }); + const factory: EntityFactory = () => new CounterEntity(); + worker.addNamedEntity("counter", factory); + + const mockStub = createMockStub(); + const req = createEntityRequestV2("counter", "key1"); + + // Act - call _executeEntityV2 via the wrapper (which tracks the work item) + (worker as any)._executeEntityV2(req, COMPLETION_TOKEN, mockStub.stub); + + // Assert - the promise should be tracked while executing + const pendingWorkItems: Set> = (worker as any)._pendingWorkItems; + expect(pendingWorkItems.size).toBe(1); + + // Wait for completion + await Promise.all(pendingWorkItems); + + // After completion, it should be removed + expect(pendingWorkItems.size).toBe(0); + }); + + it("should remove V1 entity execution from _pendingWorkItems even when entity is not found", async () => { + // Arrange + const worker = new TaskHubGrpcWorker({ logger: new NoOpLogger() }); + // Do NOT register any entity — the entity lookup will fail + + const mockStub = createMockStub(); + const req = createEntityBatchRequest("nonexistent", "key1"); + + // Act + (worker as any)._executeEntity(req, COMPLETION_TOKEN, mockStub.stub); + + const pendingWorkItems: Set> = (worker as any)._pendingWorkItems; + expect(pendingWorkItems.size).toBe(1); + + // Wait for completion + await Promise.all(pendingWorkItems); + + // Assert - should be cleaned up even on error path + expect(pendingWorkItems.size).toBe(0); + expect(mockStub.capturedResult).not.toBeNull(); + }); + + it("should track multiple concurrent entity executions in _pendingWorkItems", async () => { + // Arrange + const worker = new TaskHubGrpcWorker({ logger: new NoOpLogger() }); + const factory: EntityFactory = () => new CounterEntity(); + worker.addNamedEntity("counter", factory); + + const mockStub1 = createMockStub(); + const mockStub2 = createMockStub(); + const req1 = createEntityBatchRequest("counter", "key1"); + const req2 = createEntityBatchRequest("counter", "key2"); + + // Act - fire two concurrent entity executions + (worker as any)._executeEntity(req1, "token-1", mockStub1.stub); + (worker as any)._executeEntity(req2, "token-2", mockStub2.stub); + + // Assert - both should be tracked + const pendingWorkItems: Set> = (worker as any)._pendingWorkItems; + expect(pendingWorkItems.size).toBe(2); + + // Wait for all to complete + await Promise.all(pendingWorkItems); + + // Both should be cleaned up + expect(pendingWorkItems.size).toBe(0); + }); + }); });