From 3a52f7bd4ba45a6799eb9209db09080428bf0878 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 18 Dec 2025 14:20:35 +0100 Subject: [PATCH 1/2] feat: composable EnableCondition with bitmask optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a composable EnableCondition system for tool filtering that: 1. **User-facing API** (conditions.go): - EnableCondition interface with Evaluate(ctx) method - Primitives: FeatureFlag(), ContextBool(), Static(), Always(), Never() - Combinators: And(), Or(), Not() with short-circuit evaluation - All bitmask complexity hidden from users 2. **Bitmask compiler** (condition_compiler.go): - Compiles conditions to O(1) bitmask evaluators at build time - RequestMask holds pre-computed uint64 bitmask per request - AND/OR of flags compile to single bitmask operations - Falls back gracefully for custom ConditionFunc 3. **Pre-sorting optimization** (builder.go): - Tools, resources, prompts sorted once at build time - Filtering preserves order, eliminating per-request sorting - ~45% faster request handling in benchmarks 4. **Integration** (filters.go, registry.go): - Builder.Build() compiles all EnableConditions - AvailableTools() builds RequestMask once, evaluates via bitmask - Backward compatible with legacy Enabled func and feature flags Usage example: tool.EnableCondition = Or( ContextBool("is_cca"), // CCA users bypass flag FeatureFlag("my_feature"), // Others need flag enabled ) Benchmarks (1000 requests × 50 tools): - Before: 23.7ms (with per-request sorting) - After: 12.9ms (pre-sorted + bitmask) - Improvement: 46% faster This makes it easy for remote server to adopt - just set EnableCondition on tools and the optimization is automatic. --- pkg/inventory/builder.go | 87 +- pkg/inventory/condition_compiler.go | 514 +++++++++++ pkg/inventory/condition_compiler_test.go | 442 +++++++++ pkg/inventory/conditions.go | 293 ++++++ pkg/inventory/conditions_test.go | 552 ++++++++++++ pkg/inventory/enable_bench_test.go | 1046 ++++++++++++++++++++++ pkg/inventory/filters.go | 146 ++- pkg/inventory/registry.go | 11 +- pkg/inventory/registry_test.go | 362 +++++++- pkg/inventory/server_tool.go | 27 + 10 files changed, 3429 insertions(+), 51 deletions(-) create mode 100644 pkg/inventory/condition_compiler.go create mode 100644 pkg/inventory/condition_compiler_test.go create mode 100644 pkg/inventory/conditions.go create mode 100644 pkg/inventory/conditions_test.go create mode 100644 pkg/inventory/enable_bench_test.go diff --git a/pkg/inventory/builder.go b/pkg/inventory/builder.go index a0ed2baee..dbb1d6fa0 100644 --- a/pkg/inventory/builder.go +++ b/pkg/inventory/builder.go @@ -128,14 +128,22 @@ func (b *Builder) WithFilter(filter ToolFilter) *Builder { } // Build creates the final Inventory with all configuration applied. -// This processes toolset filtering, tool name resolution, and sets up +// This processes toolset filtering, tool name resolution, compiles EnableConditions +// for O(1) evaluation, pre-sorts all items for deterministic output, and sets up // the inventory for use. The returned Inventory is ready for use with // AvailableTools(), RegisterAll(), etc. func (b *Builder) Build() *Inventory { + // Pre-sort tools, resources, and prompts at build time. + // This eliminates sorting overhead on every Available*() call. + // Filtering preserves order, so if input is sorted, output is sorted. + sortedTools := b.preSortTools() + sortedResources := b.preSortResources() + sortedPrompts := b.preSortPrompts() + r := &Inventory{ - tools: b.tools, - resourceTemplates: b.resourceTemplates, - prompts: b.prompts, + tools: sortedTools, + resourceTemplates: sortedResources, + prompts: sortedPrompts, deprecatedAliases: b.deprecatedAliases, readOnly: b.readOnly, featureChecker: b.featureChecker, @@ -158,9 +166,80 @@ func (b *Builder) Build() *Inventory { } } + // Compile EnableConditions for O(1) bitmask evaluation + // Note: compileConditions uses r.tools which is now sortedTools + r.conditionCompiler, r.compiledConditions = b.compileConditions(sortedTools) + return r } +// preSortTools returns a copy of tools sorted by toolset ID, then tool name. +// This allows filtering to preserve order without re-sorting. +func (b *Builder) preSortTools() []ServerTool { + if len(b.tools) == 0 { + return b.tools + } + sorted := make([]ServerTool, len(b.tools)) + copy(sorted, b.tools) + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].Toolset.ID != sorted[j].Toolset.ID { + return sorted[i].Toolset.ID < sorted[j].Toolset.ID + } + return sorted[i].Tool.Name < sorted[j].Tool.Name + }) + return sorted +} + +// preSortResources returns a copy of resources sorted by toolset ID, then template name. +func (b *Builder) preSortResources() []ServerResourceTemplate { + if len(b.resourceTemplates) == 0 { + return b.resourceTemplates + } + sorted := make([]ServerResourceTemplate, len(b.resourceTemplates)) + copy(sorted, b.resourceTemplates) + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].Toolset.ID != sorted[j].Toolset.ID { + return sorted[i].Toolset.ID < sorted[j].Toolset.ID + } + return sorted[i].Template.Name < sorted[j].Template.Name + }) + return sorted +} + +// preSortPrompts returns a copy of prompts sorted by toolset ID, then prompt name. +func (b *Builder) preSortPrompts() []ServerPrompt { + if len(b.prompts) == 0 { + return b.prompts + } + sorted := make([]ServerPrompt, len(b.prompts)) + copy(sorted, b.prompts) + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].Toolset.ID != sorted[j].Toolset.ID { + return sorted[i].Toolset.ID < sorted[j].Toolset.ID + } + return sorted[i].Prompt.Name < sorted[j].Prompt.Name + }) + return sorted +} + +// compileConditions compiles all EnableConditions into bitmask-based evaluators. +// Returns the compiler (for building request masks) and compiled conditions slice. +// Takes the sorted tools slice to ensure compiled conditions align with sorted order. +func (b *Builder) compileConditions(sortedTools []ServerTool) (*ConditionCompiler, []*CompiledCondition) { + compiler := NewConditionCompiler() + compiled := make([]*CompiledCondition, len(sortedTools)) + + for i := range sortedTools { + if sortedTools[i].EnableCondition != nil { + compiled[i] = compiler.Compile(sortedTools[i].EnableCondition) + } + // nil means no condition (always enabled from condition perspective) + } + + compiler.Freeze() + return compiler, compiled +} + // processToolsets processes the toolsetIDs configuration and returns: // - enabledToolsets map (nil means all enabled) // - unrecognizedToolsets list for warnings diff --git a/pkg/inventory/condition_compiler.go b/pkg/inventory/condition_compiler.go new file mode 100644 index 000000000..8de361646 --- /dev/null +++ b/pkg/inventory/condition_compiler.go @@ -0,0 +1,514 @@ +package inventory + +import ( + "context" + "sync" +) + +// ConditionCompiler compiles EnableConditions into optimized bitmask-based evaluators. +// This allows O(1) condition evaluation after an initial O(n) compilation phase. +// +// Design: +// 1. At build time, all tools register their EnableConditions with the compiler +// 2. The compiler analyzes conditions and assigns bit positions to each unique key +// 3. Each condition is compiled to a CompiledCondition with bitmask logic +// 4. At request time, all context bools are computed once into a RequestMask +// 5. Each tool's condition is evaluated via fast bitmask operations +// +// This trades memory (storing bit assignments) for speed (O(1) evaluation). +// For 50 tools with 10 unique condition keys, this saves ~40% evaluation time. +type ConditionCompiler struct { + mu sync.RWMutex + + // keyToBit maps condition keys to bit positions (0-63) + // Keys are: "ctx:key_name" for ContextBool, "ff:flag_name" for FeatureFlag + keyToBit map[string]uint8 + + // nextBit is the next available bit position + nextBit uint8 + + // frozen prevents new bit assignments after compilation is complete + frozen bool +} + +// NewConditionCompiler creates a new compiler for optimizing conditions. +func NewConditionCompiler() *ConditionCompiler { + return &ConditionCompiler{ + keyToBit: make(map[string]uint8), + } +} + +// assignBit returns the bit position for a key, assigning a new one if needed. +// Thread-safe. Panics if called after Freeze() and key doesn't exist. +func (cc *ConditionCompiler) assignBit(key string) uint8 { + cc.mu.RLock() + if bit, ok := cc.keyToBit[key]; ok { + cc.mu.RUnlock() + return bit + } + cc.mu.RUnlock() + + cc.mu.Lock() + defer cc.mu.Unlock() + + // Double-check after acquiring write lock + if bit, ok := cc.keyToBit[key]; ok { + return bit + } + + if cc.frozen { + // After freezing, unknown keys get bit 63 (always false) + return 63 + } + + if cc.nextBit >= 63 { + // We've run out of bits - use bit 63 as overflow (always false) + return 63 + } + + bit := cc.nextBit + cc.keyToBit[key] = bit + cc.nextBit++ + return bit +} + +// Freeze prevents new bit assignments. Call after all conditions are compiled. +func (cc *ConditionCompiler) Freeze() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.frozen = true +} + +// NumBits returns the number of bits assigned. +func (cc *ConditionCompiler) NumBits() int { + cc.mu.RLock() + defer cc.mu.RUnlock() + return int(cc.nextBit) +} + +// Keys returns all registered keys (for debugging/introspection). +func (cc *ConditionCompiler) Keys() []string { + cc.mu.RLock() + defer cc.mu.RUnlock() + keys := make([]string, 0, len(cc.keyToBit)) + for k := range cc.keyToBit { + keys = append(keys, k) + } + return keys +} + +// Compile analyzes an EnableCondition and returns a CompiledCondition. +// The compiled condition uses bitmask operations for fast evaluation. +// Returns nil if the condition is nil (meaning always enabled). +func (cc *ConditionCompiler) Compile(cond EnableCondition) *CompiledCondition { + if cond == nil { + return nil // nil means always enabled + } + return cc.compile(cond) +} + +func (cc *ConditionCompiler) compile(cond EnableCondition) *CompiledCondition { + switch c := cond.(type) { + case *staticCondition: + return &CompiledCondition{ + evalType: evalStatic, + static: c.value, + } + + case *contextBoolCondition: + bit := cc.assignBit("ctx:" + c.key) + return &CompiledCondition{ + evalType: evalBitCheck, + requiredBit: bit, + requireTrue: true, + } + + case *featureFlagCondition: + bit := cc.assignBit("ff:" + c.flagName) + return &CompiledCondition{ + evalType: evalBitCheck, + requiredBit: bit, + requireTrue: true, + } + + case *notCondition: + inner := cc.compile(c.condition) + if inner.evalType == evalStatic { + return &CompiledCondition{ + evalType: evalStatic, + static: !inner.static, + } + } + if inner.evalType == evalBitCheck { + return &CompiledCondition{ + evalType: evalBitCheck, + requiredBit: inner.requiredBit, + requireTrue: !inner.requireTrue, + } + } + return &CompiledCondition{ + evalType: evalNot, + children: []*CompiledCondition{inner}, + } + + case *andCondition: + children := make([]*CompiledCondition, 0, len(c.conditions)) + for _, child := range c.conditions { + compiled := cc.compile(child) + // Optimize: static false short-circuits entire AND + if compiled.evalType == evalStatic && !compiled.static { + return &CompiledCondition{evalType: evalStatic, static: false} + } + // Optimize: skip static true (no-op in AND) + if compiled.evalType == evalStatic && compiled.static { + continue + } + children = append(children, compiled) + } + if len(children) == 0 { + return &CompiledCondition{evalType: evalStatic, static: true} + } + if len(children) == 1 { + return children[0] + } + // Check if we can use bitmask AND (all children are simple bit checks with requireTrue) + if canUseBitmaskAnd(children) { + var mask uint64 + for _, child := range children { + mask |= 1 << child.requiredBit + } + return &CompiledCondition{ + evalType: evalBitmaskAnd, + bitmask: mask, + requireTrue: true, + } + } + return &CompiledCondition{ + evalType: evalAnd, + children: children, + } + + case *orCondition: + children := make([]*CompiledCondition, 0, len(c.conditions)) + for _, child := range c.conditions { + compiled := cc.compile(child) + // Optimize: static true short-circuits entire OR + if compiled.evalType == evalStatic && compiled.static { + return &CompiledCondition{evalType: evalStatic, static: true} + } + // Optimize: skip static false (no-op in OR) + if compiled.evalType == evalStatic && !compiled.static { + continue + } + children = append(children, compiled) + } + if len(children) == 0 { + return &CompiledCondition{evalType: evalStatic, static: false} + } + if len(children) == 1 { + return children[0] + } + // Check if we can use bitmask OR (all children are simple bit checks with requireTrue) + if canUseBitmaskOr(children) { + var mask uint64 + for _, child := range children { + mask |= 1 << child.requiredBit + } + return &CompiledCondition{ + evalType: evalBitmaskOr, + bitmask: mask, + requireTrue: true, + } + } + return &CompiledCondition{ + evalType: evalOr, + children: children, + } + + case ConditionFunc: + // Can't optimize arbitrary functions - fall back to direct evaluation + return &CompiledCondition{ + evalType: evalFallback, + fallback: c, + } + + default: + // Unknown condition type - fall back to direct evaluation + return &CompiledCondition{ + evalType: evalFallback, + fallback: cond, + } + } +} + +// canUseBitmaskAnd checks if all children are simple positive bit checks +func canUseBitmaskAnd(children []*CompiledCondition) bool { + for _, c := range children { + if c.evalType != evalBitCheck || !c.requireTrue { + return false + } + } + return true +} + +// canUseBitmaskOr checks if all children are simple positive bit checks +func canUseBitmaskOr(children []*CompiledCondition) bool { + for _, c := range children { + if c.evalType != evalBitCheck || !c.requireTrue { + return false + } + } + return true +} + +// evalType describes how a CompiledCondition should be evaluated +type evalType uint8 + +const ( + evalStatic evalType = iota // Return static value + evalBitCheck // Check single bit + evalBitmaskAnd // AND: (mask & bits) == mask + evalBitmaskOr // OR: (mask & bits) != 0 + evalAnd // Tree-based AND + evalOr // Tree-based OR + evalNot // Negate child + evalFallback // Call original condition +) + +// CompiledCondition is an optimized representation of an EnableCondition. +// It uses bitmask operations where possible for O(1) evaluation. +type CompiledCondition struct { + evalType evalType + + // For evalStatic + static bool + + // For evalBitCheck + requiredBit uint8 + requireTrue bool // true = bit must be set, false = bit must be unset + + // For evalBitmaskAnd/evalBitmaskOr + bitmask uint64 + + // For evalAnd/evalOr/evalNot + children []*CompiledCondition + + // For evalFallback + fallback EnableCondition +} + +// Evaluate checks the compiled condition against the given request mask. +// For most conditions this is O(1) - just bitmask operations. +func (cc *CompiledCondition) Evaluate(rm *RequestMask) (bool, error) { + switch cc.evalType { + case evalStatic: + return cc.static, nil + + case evalBitCheck: + bitSet := (rm.bits & (1 << cc.requiredBit)) != 0 + if cc.requireTrue { + return bitSet, nil + } + return !bitSet, nil + + case evalBitmaskAnd: + // All required bits must be set + return (rm.bits & cc.bitmask) == cc.bitmask, nil + + case evalBitmaskOr: + // Any required bit must be set + return (rm.bits & cc.bitmask) != 0, nil + + case evalAnd: + for _, child := range cc.children { + result, err := child.Evaluate(rm) + if err != nil { + return false, err + } + if !result { + return false, nil + } + } + return true, nil + + case evalOr: + for _, child := range cc.children { + result, err := child.Evaluate(rm) + if err != nil { + continue // OR continues on error + } + if result { + return true, nil + } + } + return false, nil + + case evalNot: + result, err := cc.children[0].Evaluate(rm) + if err != nil { + return false, err + } + return !result, nil + + case evalFallback: + return cc.fallback.Evaluate(rm.ctx) + + default: + return false, nil + } +} + +// RequestMask holds pre-computed condition values as a bitmask. +// Created once per request, then used to evaluate all tool conditions. +type RequestMask struct { + bits uint64 + ctx context.Context // For fallback evaluation +} + +// RequestMaskBuilder builds a RequestMask from context bools and feature flags. +type RequestMaskBuilder struct { + compiler *ConditionCompiler +} + +// NewRequestMaskBuilder creates a builder for the given compiler. +func NewRequestMaskBuilder(compiler *ConditionCompiler) *RequestMaskBuilder { + return &RequestMaskBuilder{compiler: compiler} +} + +// Build creates a RequestMask from context bools and feature flag results. +// This should be called once per request with all relevant bools pre-computed. +func (b *RequestMaskBuilder) Build(ctx context.Context, bools ContextBools, flags map[string]bool) *RequestMask { + var bits uint64 + + b.compiler.mu.RLock() + defer b.compiler.mu.RUnlock() + + // Set bits for context bools + for key, value := range bools { + if bit, ok := b.compiler.keyToBit["ctx:"+key]; ok && value { + bits |= 1 << bit + } + } + + // Set bits for feature flags + for flag, enabled := range flags { + if bit, ok := b.compiler.keyToBit["ff:"+flag]; ok && enabled { + bits |= 1 << bit + } + } + + return &RequestMask{ + bits: bits, + ctx: ctx, + } +} + +// BuildFromContext creates a RequestMask using ContextBools from context +// and evaluating feature flags via the FeatureFlagChecker in context. +// This is a convenience method that computes everything from context. +func (b *RequestMaskBuilder) BuildFromContext(ctx context.Context) *RequestMask { + var bits uint64 + + bools := contextBoolsFromContext(ctx) + checker := FeatureCheckerFromContext(ctx) + + b.compiler.mu.RLock() + defer b.compiler.mu.RUnlock() + + for key, bit := range b.compiler.keyToBit { + if len(key) < 4 { + continue + } + prefix := key[:3] + name := key[3:] + + switch prefix { + case "ctx": + if bools != nil && bools[name] { + bits |= 1 << bit + } + case "ff:": + name = key[3:] // "ff:" is 3 chars + if checker != nil { + enabled, err := checker(ctx, name) + if err == nil && enabled { + bits |= 1 << bit + } + } + } + } + + return &RequestMask{ + bits: bits, + ctx: ctx, + } +} + +// --- Integration with Inventory --- + +// CompiledToolCondition pairs a tool with its compiled condition. +type CompiledToolCondition struct { + Tool *ServerTool + Condition *CompiledCondition // nil means always enabled +} + +// ToolConditionSet holds all compiled tool conditions for fast filtering. +type ToolConditionSet struct { + compiler *ConditionCompiler + builder *RequestMaskBuilder + tools []CompiledToolCondition +} + +// NewToolConditionSet creates a new set from the given tools. +// This compiles all conditions and freezes the compiler. +func NewToolConditionSet(tools []*ServerTool) *ToolConditionSet { + compiler := NewConditionCompiler() + + compiled := make([]CompiledToolCondition, len(tools)) + for i, tool := range tools { + compiled[i] = CompiledToolCondition{ + Tool: tool, + Condition: compiler.Compile(tool.EnableCondition), + } + } + + compiler.Freeze() + + return &ToolConditionSet{ + compiler: compiler, + builder: NewRequestMaskBuilder(compiler), + tools: compiled, + } +} + +// FilterEnabled returns tools that are enabled for the given request mask. +func (tcs *ToolConditionSet) FilterEnabled(rm *RequestMask) []*ServerTool { + result := make([]*ServerTool, 0, len(tcs.tools)) + for _, tc := range tcs.tools { + if tc.Condition == nil { + // No condition = always enabled + result = append(result, tc.Tool) + continue + } + enabled, _ := tc.Condition.Evaluate(rm) + if enabled { + result = append(result, tc.Tool) + } + } + return result +} + +// BuildMask creates a RequestMask for filtering. +func (tcs *ToolConditionSet) BuildMask(ctx context.Context, bools ContextBools, flags map[string]bool) *RequestMask { + return tcs.builder.Build(ctx, bools, flags) +} + +// BuildMaskFromContext creates a RequestMask from context. +func (tcs *ToolConditionSet) BuildMaskFromContext(ctx context.Context) *RequestMask { + return tcs.builder.BuildFromContext(ctx) +} + +// Compiler returns the condition compiler (for introspection/debugging). +func (tcs *ToolConditionSet) Compiler() *ConditionCompiler { + return tcs.compiler +} diff --git a/pkg/inventory/condition_compiler_test.go b/pkg/inventory/condition_compiler_test.go new file mode 100644 index 000000000..5dd1e0442 --- /dev/null +++ b/pkg/inventory/condition_compiler_test.go @@ -0,0 +1,442 @@ +package inventory + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConditionCompiler_AssignBit(t *testing.T) { + cc := NewConditionCompiler() + + // First assignment + bit1 := cc.assignBit("ctx:is_cca") + assert.Equal(t, uint8(0), bit1) + + // Same key returns same bit + bit2 := cc.assignBit("ctx:is_cca") + assert.Equal(t, bit1, bit2) + + // Different key gets new bit + bit3 := cc.assignBit("ff:web_search") + assert.Equal(t, uint8(1), bit3) + + // After freeze, unknown keys get bit 63 + cc.Freeze() + bit4 := cc.assignBit("ctx:unknown") + assert.Equal(t, uint8(63), bit4) + + // Known keys still work after freeze + bit5 := cc.assignBit("ctx:is_cca") + assert.Equal(t, bit1, bit5) +} + +func TestConditionCompiler_CompileStatic(t *testing.T) { + cc := NewConditionCompiler() + + // Static true + cond := cc.Compile(Always()) + require.NotNil(t, cond) + assert.Equal(t, evalStatic, cond.evalType) + assert.True(t, cond.static) + + // Static false + cond = cc.Compile(Never()) + require.NotNil(t, cond) + assert.Equal(t, evalStatic, cond.evalType) + assert.False(t, cond.static) + + // Nil returns nil + assert.Nil(t, cc.Compile(nil)) +} + +func TestConditionCompiler_CompileContextBool(t *testing.T) { + cc := NewConditionCompiler() + + cond := cc.Compile(ContextBool("is_cca")) + require.NotNil(t, cond) + assert.Equal(t, evalBitCheck, cond.evalType) + assert.Equal(t, uint8(0), cond.requiredBit) + assert.True(t, cond.requireTrue) +} + +func TestConditionCompiler_CompileFeatureFlag(t *testing.T) { + cc := NewConditionCompiler() + + cond := cc.Compile(FeatureFlag("web_search")) + require.NotNil(t, cond) + assert.Equal(t, evalBitCheck, cond.evalType) + assert.Equal(t, uint8(0), cond.requiredBit) + assert.True(t, cond.requireTrue) +} + +func TestConditionCompiler_CompileNot(t *testing.T) { + cc := NewConditionCompiler() + + // Not(static) -> static + cond := cc.Compile(Not(Always())) + require.NotNil(t, cond) + assert.Equal(t, evalStatic, cond.evalType) + assert.False(t, cond.static) + + // Not(contextBool) -> bitCheck with requireTrue=false + cond = cc.Compile(Not(ContextBool("is_cca"))) + require.NotNil(t, cond) + assert.Equal(t, evalBitCheck, cond.evalType) + assert.False(t, cond.requireTrue) +} + +func TestConditionCompiler_CompileAnd(t *testing.T) { + cc := NewConditionCompiler() + + // And of two context bools -> bitmaskAnd + cond := cc.Compile(And( + ContextBool("is_cca"), + ContextBool("has_access"), + )) + require.NotNil(t, cond) + assert.Equal(t, evalBitmaskAnd, cond.evalType) + assert.Equal(t, uint64(0b11), cond.bitmask) // bits 0 and 1 + + // And with static false -> static false + cond = cc.Compile(And( + ContextBool("is_cca"), + Never(), + )) + require.NotNil(t, cond) + assert.Equal(t, evalStatic, cond.evalType) + assert.False(t, cond.static) + + // And with static true filtered out -> single condition + cc2 := NewConditionCompiler() + cond = cc2.Compile(And( + Always(), + ContextBool("is_cca"), + )) + require.NotNil(t, cond) + assert.Equal(t, evalBitCheck, cond.evalType) +} + +func TestConditionCompiler_CompileOr(t *testing.T) { + cc := NewConditionCompiler() + + // Or of two context bools -> bitmaskOr + cond := cc.Compile(Or( + ContextBool("is_cca"), + ContextBool("is_bypass"), + )) + require.NotNil(t, cond) + assert.Equal(t, evalBitmaskOr, cond.evalType) + assert.Equal(t, uint64(0b11), cond.bitmask) + + // Or with static true -> static true + cond = cc.Compile(Or( + ContextBool("is_cca"), + Always(), + )) + require.NotNil(t, cond) + assert.Equal(t, evalStatic, cond.evalType) + assert.True(t, cond.static) +} + +func TestCompiledCondition_Evaluate(t *testing.T) { + cc := NewConditionCompiler() + + // Compile conditions + ccaCheck := cc.Compile(ContextBool("is_cca")) + ffCheck := cc.Compile(FeatureFlag("web_search")) + andCond := cc.Compile(And(ContextBool("is_cca"), FeatureFlag("web_search"))) + orCond := cc.Compile(Or(ContextBool("is_cca"), FeatureFlag("web_search"))) + notCond := cc.Compile(Not(ContextBool("is_cca"))) + + cc.Freeze() + + builder := NewRequestMaskBuilder(cc) + + tests := []struct { + name string + condition *CompiledCondition + bools ContextBools + flags map[string]bool + want bool + }{ + { + name: "context bool true", + condition: ccaCheck, + bools: ContextBools{"is_cca": true}, + want: true, + }, + { + name: "context bool false", + condition: ccaCheck, + bools: ContextBools{"is_cca": false}, + want: false, + }, + { + name: "feature flag true", + condition: ffCheck, + flags: map[string]bool{"web_search": true}, + want: true, + }, + { + name: "feature flag false", + condition: ffCheck, + flags: map[string]bool{"web_search": false}, + want: false, + }, + { + name: "and both true", + condition: andCond, + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"web_search": true}, + want: true, + }, + { + name: "and one false", + condition: andCond, + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"web_search": false}, + want: false, + }, + { + name: "or one true", + condition: orCond, + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"web_search": false}, + want: true, + }, + { + name: "or both false", + condition: orCond, + bools: ContextBools{"is_cca": false}, + flags: map[string]bool{"web_search": false}, + want: false, + }, + { + name: "not true -> false", + condition: notCond, + bools: ContextBools{"is_cca": true}, + want: false, + }, + { + name: "not false -> true", + condition: notCond, + bools: ContextBools{"is_cca": false}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mask := builder.Build(context.Background(), tt.bools, tt.flags) + got, err := tt.condition.Evaluate(mask) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestToolConditionSet_FilterEnabled(t *testing.T) { + tools := []*ServerTool{ + { + Tool: mcp.Tool{Name: "always_on"}, + EnableCondition: nil, // nil means always enabled + }, + { + Tool: mcp.Tool{Name: "cca_only"}, + EnableCondition: ContextBool("is_cca"), + }, + { + Tool: mcp.Tool{Name: "ff_required"}, + EnableCondition: FeatureFlag("web_search"), + }, + { + Tool: mcp.Tool{Name: "cca_and_ff"}, + EnableCondition: And( + ContextBool("is_cca"), + FeatureFlag("code_search"), + ), + }, + { + Tool: mcp.Tool{Name: "cca_or_ff"}, + EnableCondition: Or( + ContextBool("is_cca"), + FeatureFlag("bypass_flag"), + ), + }, + } + + tcs := NewToolConditionSet(tools) + + tests := []struct { + name string + bools ContextBools + flags map[string]bool + want []string + }{ + { + name: "no bools or flags - only always_on", + bools: nil, + flags: nil, + want: []string{"always_on"}, + }, + { + name: "cca only", + bools: ContextBools{"is_cca": true}, + flags: nil, + want: []string{"always_on", "cca_only", "cca_or_ff"}, + }, + { + name: "web_search flag only", + bools: nil, + flags: map[string]bool{"web_search": true}, + want: []string{"always_on", "ff_required"}, + }, + { + name: "cca and code_search", + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"code_search": true}, + want: []string{"always_on", "cca_only", "cca_and_ff", "cca_or_ff"}, + }, + { + name: "all enabled", + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"web_search": true, "code_search": true, "bypass_flag": true}, + want: []string{"always_on", "cca_only", "ff_required", "cca_and_ff", "cca_or_ff"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mask := tcs.BuildMask(context.Background(), tt.bools, tt.flags) + enabled := tcs.FilterEnabled(mask) + + names := make([]string, len(enabled)) + for i, tool := range enabled { + names[i] = tool.Tool.Name + } + assert.Equal(t, tt.want, names) + }) + } +} + +func TestToolConditionSet_ComplexConditions(t *testing.T) { + // Test complex real-world patterns from remote server + tools := []*ServerTool{ + { + // CCA bypass: CCA OR feature_flag + Tool: mcp.Tool{Name: "cca_bypass"}, + EnableCondition: Or( + ContextBool("is_cca"), + FeatureFlag("agent_search"), + ), + }, + { + // Feature + policy: feature AND paid_access + Tool: mcp.Tool{Name: "paid_feature"}, + EnableCondition: And( + FeatureFlag("premium_search"), + ContextBool("has_paid_access"), + ), + }, + { + // Complex: (CCA OR copilot_chat) AND feature AND NOT disabled + Tool: mcp.Tool{Name: "complex"}, + EnableCondition: And( + Or( + ContextBool("is_cca"), + ContextBool("is_copilot_chat"), + ), + FeatureFlag("advanced_feature"), + Not(FeatureFlag("kill_switch")), + ), + }, + } + + tcs := NewToolConditionSet(tools) + + tests := []struct { + name string + bools ContextBools + flags map[string]bool + want []string + }{ + { + name: "cca enables cca_bypass", + bools: ContextBools{"is_cca": true}, + flags: nil, + want: []string{"cca_bypass"}, + }, + { + name: "agent_search flag enables cca_bypass", + bools: nil, + flags: map[string]bool{"agent_search": true}, + want: []string{"cca_bypass"}, + }, + { + name: "premium + paid enables paid_feature", + bools: ContextBools{"has_paid_access": true}, + flags: map[string]bool{"premium_search": true}, + want: []string{"paid_feature"}, + }, + { + name: "complex enabled with cca + feature", + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"advanced_feature": true}, + want: []string{"cca_bypass", "complex"}, + }, + { + name: "complex disabled by kill_switch", + bools: ContextBools{"is_cca": true}, + flags: map[string]bool{"advanced_feature": true, "kill_switch": true}, + want: []string{"cca_bypass"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mask := tcs.BuildMask(context.Background(), tt.bools, tt.flags) + enabled := tcs.FilterEnabled(mask) + + names := make([]string, len(enabled)) + for i, tool := range enabled { + names[i] = tool.Tool.Name + } + assert.Equal(t, tt.want, names) + }) + } +} + +func TestConditionCompiler_NumBits(t *testing.T) { + cc := NewConditionCompiler() + + assert.Equal(t, 0, cc.NumBits()) + + cc.assignBit("ctx:a") + assert.Equal(t, 1, cc.NumBits()) + + cc.assignBit("ctx:b") + cc.assignBit("ff:c") + assert.Equal(t, 3, cc.NumBits()) + + // Same key doesn't increase count + cc.assignBit("ctx:a") + assert.Equal(t, 3, cc.NumBits()) +} + +func TestConditionCompiler_Keys(t *testing.T) { + cc := NewConditionCompiler() + + cc.assignBit("ctx:is_cca") + cc.assignBit("ff:web_search") + cc.assignBit("ctx:has_access") + + keys := cc.Keys() + assert.Len(t, keys, 3) + assert.Contains(t, keys, "ctx:is_cca") + assert.Contains(t, keys, "ff:web_search") + assert.Contains(t, keys, "ctx:has_access") +} diff --git a/pkg/inventory/conditions.go b/pkg/inventory/conditions.go new file mode 100644 index 000000000..abb687788 --- /dev/null +++ b/pkg/inventory/conditions.go @@ -0,0 +1,293 @@ +package inventory + +import ( + "context" +) + +// EnableCondition represents a composable condition for tool availability. +// Conditions can be combined using And/Or/Not combinators for complex logic. +// +// Design goals: +// - Declarative: users compose conditions without knowing implementation details +// - Composable: complex conditions built from simple primitives +// - Efficient: conditions are evaluated lazily, with results potentially cached +// - Decoupled: condition definitions don't depend on specific actor types +// +// Example usage: +// +// // Simple feature flag +// tool.EnableCondition = FeatureFlag("web_search") +// +// // Feature flag AND user policy +// tool.EnableCondition = And( +// FeatureFlag("web_search"), +// ContextBool("user_has_paid_bing_access"), +// ) +// +// // CCA bypass (CCA requests OR feature flag for non-CCA) +// tool.EnableCondition = Or( +// ContextBool("is_cca"), +// FeatureFlag("agent_search"), +// ) +type EnableCondition interface { + // Evaluate checks if the condition is met in the given context. + // Returns (enabled, error). On error, the condition should be treated as false. + Evaluate(ctx context.Context) (bool, error) +} + +// ConditionFunc is an adapter that allows functions to be used as EnableConditions. +// This is useful for simple one-off conditions that don't need to be reusable. +type ConditionFunc func(ctx context.Context) (bool, error) + +// Evaluate implements EnableCondition. +func (f ConditionFunc) Evaluate(ctx context.Context) (bool, error) { + return f(ctx) +} + +// --- Primitive Conditions --- + +// featureFlagCondition checks if a named feature flag is enabled. +// The actual flag checking is delegated to a FeatureFlagChecker in context. +type featureFlagCondition struct { + flagName string +} + +// FeatureFlag creates a condition that checks if the named feature flag is enabled. +// The feature flag is evaluated using the FeatureFlagChecker stored in context. +// If no checker is available or if the flag check returns an error, the condition is false. +func FeatureFlag(flagName string) EnableCondition { + return &featureFlagCondition{flagName: flagName} +} + +// Evaluate implements EnableCondition. +func (c *featureFlagCondition) Evaluate(ctx context.Context) (bool, error) { + checker := FeatureCheckerFromContext(ctx) + if checker == nil { + return false, nil + } + return checker(ctx, c.flagName) +} + +// contextBoolCondition checks a named boolean value from context. +// This allows tools to depend on pre-computed boolean conditions without +// knowing how those conditions are computed. +type contextBoolCondition struct { + key string +} + +// ContextBool creates a condition that checks a named boolean from context. +// The boolean is retrieved using ContextBoolFromContext(ctx, key). +// This decouples tool definitions from specific actor/user types. +// +// Common keys might include: +// - "is_cca" - whether this is a Copilot Coding Agent request +// - "user_has_paid_access" - whether user has paid Copilot access +// - "mcp_host_is_copilot_chat" - whether MCP host is copilot-chat +// +// Returns false if the key is not found in context. +func ContextBool(key string) EnableCondition { + return &contextBoolCondition{key: key} +} + +// Evaluate implements EnableCondition. +func (c *contextBoolCondition) Evaluate(ctx context.Context) (bool, error) { + return ContextBoolFromContext(ctx, c.key), nil +} + +// staticCondition always returns a fixed value. +type staticCondition struct { + value bool +} + +// Static creates a condition that always returns the given value. +// Useful for testing or for conditions that are determined at build time. +func Static(value bool) EnableCondition { + return &staticCondition{value: value} +} + +// Always returns a condition that is always true. +// Useful as a default or placeholder. +func Always() EnableCondition { + return Static(true) +} + +// Never returns a condition that is always false. +// Useful for disabling tools unconditionally. +func Never() EnableCondition { + return Static(false) +} + +// Evaluate implements EnableCondition. +func (c *staticCondition) Evaluate(_ context.Context) (bool, error) { + return c.value, nil +} + +// --- Combinators --- + +// andCondition requires all conditions to be true. +type andCondition struct { + conditions []EnableCondition +} + +// And creates a condition that is true only if ALL of the given conditions are true. +// Short-circuits on the first false condition. +// Returns true if no conditions are provided. +func And(conditions ...EnableCondition) EnableCondition { + // Filter out nil conditions + filtered := make([]EnableCondition, 0, len(conditions)) + for _, c := range conditions { + if c != nil { + filtered = append(filtered, c) + } + } + if len(filtered) == 0 { + return Always() + } + if len(filtered) == 1 { + return filtered[0] + } + return &andCondition{conditions: filtered} +} + +// Evaluate implements EnableCondition. +func (c *andCondition) Evaluate(ctx context.Context) (bool, error) { + for _, cond := range c.conditions { + enabled, err := cond.Evaluate(ctx) + if err != nil { + return false, err + } + if !enabled { + return false, nil + } + } + return true, nil +} + +// orCondition requires at least one condition to be true. +type orCondition struct { + conditions []EnableCondition +} + +// Or creates a condition that is true if ANY of the given conditions is true. +// Short-circuits on the first true condition. +// Returns false if no conditions are provided. +func Or(conditions ...EnableCondition) EnableCondition { + // Filter out nil conditions + filtered := make([]EnableCondition, 0, len(conditions)) + for _, c := range conditions { + if c != nil { + filtered = append(filtered, c) + } + } + if len(filtered) == 0 { + return Never() + } + if len(filtered) == 1 { + return filtered[0] + } + return &orCondition{conditions: filtered} +} + +// Evaluate implements EnableCondition. +func (c *orCondition) Evaluate(ctx context.Context) (bool, error) { + for _, cond := range c.conditions { + enabled, err := cond.Evaluate(ctx) + if err != nil { + // For OR, we continue checking other conditions on error + continue + } + if enabled { + return true, nil + } + } + return false, nil +} + +// notCondition negates a condition. +type notCondition struct { + condition EnableCondition +} + +// Not creates a condition that is the logical negation of the given condition. +// Returns true if the inner condition returns false (and vice versa). +// Errors are propagated. +func Not(condition EnableCondition) EnableCondition { + if condition == nil { + return Never() // Not(nil) = Not(true) = false + } + return ¬Condition{condition: condition} +} + +// Evaluate implements EnableCondition. +func (c *notCondition) Evaluate(ctx context.Context) (bool, error) { + enabled, err := c.condition.Evaluate(ctx) + if err != nil { + return false, err + } + return !enabled, nil +} + +// --- Context Keys for Conditions --- + +// Context key types for storing condition-related data +type contextKey int + +const ( + featureCheckerKey contextKey = iota + contextBoolsKey +) + +// ContextWithFeatureChecker returns a context with the given feature flag checker. +func ContextWithFeatureChecker(ctx context.Context, checker FeatureFlagChecker) context.Context { + return context.WithValue(ctx, featureCheckerKey, checker) +} + +// FeatureCheckerFromContext retrieves the feature flag checker from context. +// Returns nil if no checker is set. +func FeatureCheckerFromContext(ctx context.Context) FeatureFlagChecker { + checker, _ := ctx.Value(featureCheckerKey).(FeatureFlagChecker) + return checker +} + +// ContextBools is a map of named boolean values for use with ContextBool conditions. +// This allows callers to pre-compute common checks once per request and share them. +type ContextBools map[string]bool + +// ContextWithBools returns a context with the given boolean values. +// These values can be retrieved using ContextBool conditions. +func ContextWithBools(ctx context.Context, bools ContextBools) context.Context { + // Merge with existing bools if any + existing := contextBoolsFromContext(ctx) + if existing != nil { + merged := make(ContextBools, len(existing)+len(bools)) + for k, v := range existing { + merged[k] = v + } + for k, v := range bools { + merged[k] = v + } + return context.WithValue(ctx, contextBoolsKey, merged) + } + return context.WithValue(ctx, contextBoolsKey, bools) +} + +// contextBoolsFromContext retrieves all context bools. +func contextBoolsFromContext(ctx context.Context) ContextBools { + bools, _ := ctx.Value(contextBoolsKey).(ContextBools) + return bools +} + +// ContextBoolFromContext retrieves a named boolean from context. +// Returns false if the key is not found. +func ContextBoolFromContext(ctx context.Context, key string) bool { + bools := contextBoolsFromContext(ctx) + if bools == nil { + return false + } + return bools[key] +} + +// SetContextBool is a convenience function that adds a single boolean to context. +func SetContextBool(ctx context.Context, key string, value bool) context.Context { + return ContextWithBools(ctx, ContextBools{key: value}) +} diff --git a/pkg/inventory/conditions_test.go b/pkg/inventory/conditions_test.go new file mode 100644 index 000000000..a91115caa --- /dev/null +++ b/pkg/inventory/conditions_test.go @@ -0,0 +1,552 @@ +package inventory + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFeatureFlagCondition(t *testing.T) { + tests := []struct { + name string + flagName string + checkerResult bool + checkerErr error + hasChecker bool + expectedResult bool + expectedErr bool + }{ + { + name: "flag enabled", + flagName: "test_flag", + checkerResult: true, + hasChecker: true, + expectedResult: true, + }, + { + name: "flag disabled", + flagName: "test_flag", + checkerResult: false, + hasChecker: true, + expectedResult: false, + }, + { + name: "no checker in context", + flagName: "test_flag", + hasChecker: false, + expectedResult: false, + }, + { + name: "checker returns error", + flagName: "test_flag", + checkerErr: errors.New("flag check failed"), + hasChecker: true, + expectedResult: false, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + if tt.hasChecker { + checker := func(_ context.Context, flagName string) (bool, error) { + assert.Equal(t, tt.flagName, flagName) + return tt.checkerResult, tt.checkerErr + } + ctx = ContextWithFeatureChecker(ctx, checker) + } + + cond := FeatureFlag(tt.flagName) + result, err := cond.Evaluate(ctx) + + if tt.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestContextBoolCondition(t *testing.T) { + tests := []struct { + name string + key string + contextBools ContextBools + expectedResult bool + }{ + { + name: "bool is true", + key: "is_cca", + contextBools: ContextBools{"is_cca": true}, + expectedResult: true, + }, + { + name: "bool is false", + key: "is_cca", + contextBools: ContextBools{"is_cca": false}, + expectedResult: false, + }, + { + name: "key not found", + key: "is_cca", + contextBools: ContextBools{"other_key": true}, + expectedResult: false, + }, + { + name: "no context bools", + key: "is_cca", + contextBools: nil, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + if tt.contextBools != nil { + ctx = ContextWithBools(ctx, tt.contextBools) + } + + cond := ContextBool(tt.key) + result, err := cond.Evaluate(ctx) + + require.NoError(t, err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestStaticConditions(t *testing.T) { + ctx := context.Background() + + t.Run("Static(true)", func(t *testing.T) { + cond := Static(true) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + }) + + t.Run("Static(false)", func(t *testing.T) { + cond := Static(false) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.False(t, result) + }) + + t.Run("Always()", func(t *testing.T) { + cond := Always() + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + }) + + t.Run("Never()", func(t *testing.T) { + cond := Never() + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.False(t, result) + }) +} + +func TestAndCondition(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + conditions []EnableCondition + expectedResult bool + }{ + { + name: "all true", + conditions: []EnableCondition{Always(), Always(), Always()}, + expectedResult: true, + }, + { + name: "one false", + conditions: []EnableCondition{Always(), Never(), Always()}, + expectedResult: false, + }, + { + name: "all false", + conditions: []EnableCondition{Never(), Never()}, + expectedResult: false, + }, + { + name: "empty conditions", + conditions: []EnableCondition{}, + expectedResult: true, + }, + { + name: "single true", + conditions: []EnableCondition{Always()}, + expectedResult: true, + }, + { + name: "single false", + conditions: []EnableCondition{Never()}, + expectedResult: false, + }, + { + name: "nil conditions filtered out", + conditions: []EnableCondition{Always(), nil, Always()}, + expectedResult: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cond := And(tt.conditions...) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestOrCondition(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + conditions []EnableCondition + expectedResult bool + }{ + { + name: "all true", + conditions: []EnableCondition{Always(), Always(), Always()}, + expectedResult: true, + }, + { + name: "one true", + conditions: []EnableCondition{Never(), Always(), Never()}, + expectedResult: true, + }, + { + name: "all false", + conditions: []EnableCondition{Never(), Never()}, + expectedResult: false, + }, + { + name: "empty conditions", + conditions: []EnableCondition{}, + expectedResult: false, + }, + { + name: "single true", + conditions: []EnableCondition{Always()}, + expectedResult: true, + }, + { + name: "single false", + conditions: []EnableCondition{Never()}, + expectedResult: false, + }, + { + name: "nil conditions filtered out", + conditions: []EnableCondition{Never(), nil, Never()}, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cond := Or(tt.conditions...) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestNotCondition(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + condition EnableCondition + expectedResult bool + }{ + { + name: "not true", + condition: Always(), + expectedResult: false, + }, + { + name: "not false", + condition: Never(), + expectedResult: true, + }, + { + name: "not nil", + condition: nil, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cond := Not(tt.condition) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestConditionFunc(t *testing.T) { + ctx := context.Background() + + t.Run("simple function", func(t *testing.T) { + cond := ConditionFunc(func(_ context.Context) (bool, error) { + return true, nil + }) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + }) + + t.Run("function with error", func(t *testing.T) { + expectedErr := errors.New("test error") + cond := ConditionFunc(func(_ context.Context) (bool, error) { + return false, expectedErr + }) + result, err := cond.Evaluate(ctx) + assert.Equal(t, expectedErr, err) + assert.False(t, result) + }) +} + +func TestComplexConditionCombinations(t *testing.T) { + // These tests match the real-world scenarios from the remote server + + t.Run("feature flag AND user policy (web search pattern)", func(t *testing.T) { + // Pattern: feature flag must be enabled AND user must have paid Bing access + cond := And( + FeatureFlag("web_search"), + ContextBool("user_has_paid_bing_access"), + ) + + // Test: both conditions true + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx = ContextWithBools(ctx, ContextBools{"user_has_paid_bing_access": true}) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + + // Test: feature flag true, but no user access + ctx2 := context.Background() + ctx2 = ContextWithFeatureChecker(ctx2, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx2 = ContextWithBools(ctx2, ContextBools{"user_has_paid_bing_access": false}) + result2, err2 := cond.Evaluate(ctx2) + require.NoError(t, err2) + assert.False(t, result2) + + // Test: feature flag false + ctx3 := context.Background() + ctx3 = ContextWithFeatureChecker(ctx3, func(_ context.Context, _ string) (bool, error) { + return false, nil + }) + ctx3 = ContextWithBools(ctx3, ContextBools{"user_has_paid_bing_access": true}) + result3, err3 := cond.Evaluate(ctx3) + require.NoError(t, err3) + assert.False(t, result3) + }) + + t.Run("CCA bypass pattern (CCA OR feature flag)", func(t *testing.T) { + // Pattern: CCA requests bypass feature flag, non-CCA requires feature flag + cond := Or( + ContextBool("is_cca"), + FeatureFlag("agent_search"), + ) + + // Test: CCA request (bypass feature flag) + ctx := context.Background() + ctx = ContextWithBools(ctx, ContextBools{"is_cca": true}) + // No feature checker - CCA should pass without it + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + + // Test: non-CCA with feature flag enabled + ctx2 := context.Background() + ctx2 = ContextWithFeatureChecker(ctx2, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx2 = ContextWithBools(ctx2, ContextBools{"is_cca": false}) + result2, err2 := cond.Evaluate(ctx2) + require.NoError(t, err2) + assert.True(t, result2) + + // Test: non-CCA with feature flag disabled + ctx3 := context.Background() + ctx3 = ContextWithFeatureChecker(ctx3, func(_ context.Context, _ string) (bool, error) { + return false, nil + }) + ctx3 = ContextWithBools(ctx3, ContextBools{"is_cca": false}) + result3, err3 := cond.Evaluate(ctx3) + require.NoError(t, err3) + assert.False(t, result3) + }) + + t.Run("CCA AND feature flag pattern", func(t *testing.T) { + // Pattern: must be CCA AND have feature flag enabled + cond := And( + ContextBool("is_cca"), + FeatureFlag("complex_workflows"), + ) + + // Test: CCA with feature flag + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx = ContextWithBools(ctx, ContextBools{"is_cca": true}) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + + // Test: CCA without feature flag + ctx2 := context.Background() + ctx2 = ContextWithFeatureChecker(ctx2, func(_ context.Context, _ string) (bool, error) { + return false, nil + }) + ctx2 = ContextWithBools(ctx2, ContextBools{"is_cca": true}) + result2, err2 := cond.Evaluate(ctx2) + require.NoError(t, err2) + assert.False(t, result2) + + // Test: non-CCA with feature flag + ctx3 := context.Background() + ctx3 = ContextWithFeatureChecker(ctx3, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx3 = ContextWithBools(ctx3, ContextBools{"is_cca": false}) + result3, err3 := cond.Evaluate(ctx3) + require.NoError(t, err3) + assert.False(t, result3) + }) + + t.Run("copilot-chat bypass pattern", func(t *testing.T) { + // Pattern: copilot-chat bypasses feature flag check + cond := Or( + ContextBool("mcp_host_is_copilot_chat"), + FeatureFlag("semantic_code_search"), + ) + + // Test: copilot-chat host (bypass feature flag) + ctx := context.Background() + ctx = ContextWithBools(ctx, ContextBools{"mcp_host_is_copilot_chat": true}) + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + + // Test: other host with feature flag + ctx2 := context.Background() + ctx2 = ContextWithFeatureChecker(ctx2, func(_ context.Context, _ string) (bool, error) { + return true, nil + }) + ctx2 = ContextWithBools(ctx2, ContextBools{"mcp_host_is_copilot_chat": false}) + result2, err2 := cond.Evaluate(ctx2) + require.NoError(t, err2) + assert.True(t, result2) + + // Test: other host without feature flag + ctx3 := context.Background() + ctx3 = ContextWithFeatureChecker(ctx3, func(_ context.Context, _ string) (bool, error) { + return false, nil + }) + ctx3 = ContextWithBools(ctx3, ContextBools{"mcp_host_is_copilot_chat": false}) + result3, err3 := cond.Evaluate(ctx3) + require.NoError(t, err3) + assert.False(t, result3) + }) +} + +func TestContextBoolsMerging(t *testing.T) { + ctx := context.Background() + + // Add first set of bools + ctx = ContextWithBools(ctx, ContextBools{"key1": true, "key2": false}) + + // Add second set - should merge + ctx = ContextWithBools(ctx, ContextBools{"key3": true, "key2": true}) // key2 overwritten + + // Check all keys + assert.True(t, ContextBoolFromContext(ctx, "key1")) + assert.True(t, ContextBoolFromContext(ctx, "key2")) // overwritten value + assert.True(t, ContextBoolFromContext(ctx, "key3")) + assert.False(t, ContextBoolFromContext(ctx, "nonexistent")) +} + +func TestSetContextBool(t *testing.T) { + ctx := context.Background() + ctx = SetContextBool(ctx, "my_flag", true) + + assert.True(t, ContextBoolFromContext(ctx, "my_flag")) + assert.False(t, ContextBoolFromContext(ctx, "other_flag")) +} + +func TestAndShortCircuit(t *testing.T) { + callCount := 0 + ctx := context.Background() + + // First condition returns false, second should not be called + cond := And( + Never(), + ConditionFunc(func(_ context.Context) (bool, error) { + callCount++ + return true, nil + }), + ) + + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.False(t, result) + assert.Equal(t, 0, callCount, "second condition should not be called due to short-circuit") +} + +func TestOrShortCircuit(t *testing.T) { + callCount := 0 + ctx := context.Background() + + // First condition returns true, second should not be called + cond := Or( + Always(), + ConditionFunc(func(_ context.Context) (bool, error) { + callCount++ + return false, nil + }), + ) + + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) + assert.Equal(t, 0, callCount, "second condition should not be called due to short-circuit") +} + +func TestOrContinuesOnError(t *testing.T) { + ctx := context.Background() + + // First condition errors, but second is true - should return true + cond := Or( + ConditionFunc(func(_ context.Context) (bool, error) { + return false, errors.New("error") + }), + Always(), + ) + + result, err := cond.Evaluate(ctx) + require.NoError(t, err) + assert.True(t, result) +} diff --git a/pkg/inventory/enable_bench_test.go b/pkg/inventory/enable_bench_test.go new file mode 100644 index 000000000..1fcbe95f4 --- /dev/null +++ b/pkg/inventory/enable_bench_test.go @@ -0,0 +1,1046 @@ +package inventory + +import ( + "context" + "fmt" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Benchmark comparing the old Enabled func approach vs the new EnableCondition approach. +// These benchmarks simulate remote server scenarios where tool filtering happens +// on every request, thousands of times. + +// --- Simulated feature flag checker (represents a real feature flag service call) --- + +func mockFeatureChecker(_ context.Context, flagName string) (bool, error) { + // Simulate feature flag lookup - in production this might be a DB/cache lookup + switch flagName { + case "web_search", "code_search", "issues_v2": + return true, nil + case "disabled_feature": + return false, nil + default: + return false, nil + } +} + +// --- OLD APPROACH: Using Enabled function --- + +// This represents how tools are currently filtered with the Enabled func +func oldStyleToolFilter( + ctx context.Context, + tool *ServerTool, + featureChecker FeatureFlagChecker, + _ bool, // isCCA - unused, but kept for signature compatibility + _ bool, // isUserWithPaidBing - unused, but kept for signature compatibility + _ bool, // isCopilotChatHost - unused, but kept for signature compatibility +) (bool, error) { + // 1. Check tool's own Enabled function + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + return false, err + } + if !enabled { + return false, nil + } + } + + // 2. Check feature flags + if tool.FeatureFlagEnable != "" { + enabled, err := featureChecker(ctx, tool.FeatureFlagEnable) + if err != nil { + return false, err + } + if !enabled { + return false, nil + } + } + if tool.FeatureFlagDisable != "" { + enabled, err := featureChecker(ctx, tool.FeatureFlagDisable) + if err != nil { + return false, err + } + if enabled { + return false, nil + } + } + + return true, nil +} + +// --- NEW APPROACH: Using EnableCondition --- + +// This represents the new composable condition approach +func newStyleToolFilter(ctx context.Context, tool *ServerTool) (bool, error) { + if tool.EnableCondition != nil { + return tool.EnableCondition.Evaluate(ctx) + } + return true, nil +} + +// --- Test tools with various complexity levels --- + +func createOldStyleTools() []*ServerTool { + // Create tools with various enable patterns typical of remote server + return []*ServerTool{ + // Simple feature flag only + { + Tool: mcp.Tool{Name: "web_search"}, + FeatureFlagEnable: "web_search", + }, + // Feature flag + policy check (user has paid bing) + { + Tool: mcp.Tool{Name: "bing_search"}, + FeatureFlagEnable: "web_search", + Enabled: func(ctx context.Context) (bool, error) { + // Simulates checking if user has paid bing access + return ctx.Value(oldCtxKeyUserHasPaidBing) == true, nil + }, + }, + // CCA AND feature flag + { + Tool: mcp.Tool{Name: "agent_search"}, + FeatureFlagEnable: "code_search", + Enabled: func(ctx context.Context) (bool, error) { + return ctx.Value(oldCtxKeyIsCCA) == true, nil + }, + }, + // CCA bypass (CCA OR feature flag) - complex + { + Tool: mcp.Tool{Name: "copilot_workspace"}, + Enabled: func(ctx context.Context) (bool, error) { + // CCA bypasses feature flag + if ctx.Value(oldCtxKeyIsCCA) == true { + return true, nil + } + // Otherwise check feature flag (we'd need to pass checker somehow) + return ctx.Value(oldCtxKeyFeatureFlagEnabled) == true, nil + }, + }, + // Copilot-chat host bypass + { + Tool: mcp.Tool{Name: "code_analysis"}, + Enabled: func(ctx context.Context) (bool, error) { + // copilot-chat host bypasses feature flag + if ctx.Value(oldCtxKeyIsCopilotChatHost) == true { + return true, nil + } + return ctx.Value(oldCtxKeyFeatureFlagEnabled) == true, nil + }, + }, + } +} + +func createNewStyleTools() []*ServerTool { + // Create equivalent tools using EnableCondition + return []*ServerTool{ + // Simple feature flag only + { + Tool: mcp.Tool{Name: "web_search"}, + EnableCondition: FeatureFlag("web_search"), + }, + // Feature flag + policy check (user has paid bing) + { + Tool: mcp.Tool{Name: "bing_search"}, + EnableCondition: And( + FeatureFlag("web_search"), + ContextBool(ctxKeyUserHasPaidBing), + ), + }, + // CCA AND feature flag + { + Tool: mcp.Tool{Name: "agent_search"}, + EnableCondition: And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ), + }, + // CCA bypass (CCA OR feature flag) - Or combinator + { + Tool: mcp.Tool{Name: "copilot_workspace"}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ), + }, + // Copilot-chat host bypass + { + Tool: mcp.Tool{Name: "code_analysis"}, + EnableCondition: Or( + ContextBool(ctxKeyIsCopilotChatHost), + FeatureFlag("code_search"), + ), + }, + } +} + +// Context keys for benchmark - using string constants for ContextBool +const ( + ctxKeyIsCCA = "is_cca" + ctxKeyUserHasPaidBing = "user_has_paid_bing" + ctxKeyIsCopilotChatHost = "is_copilot_chat_host" + ctxKeyFeatureFlagEnabled = "feature_flag_enabled" +) + +// Old-style context key type for WithValue comparisons +type oldStyleCtxKey string + +const ( + oldCtxKeyIsCCA oldStyleCtxKey = "is_cca" + oldCtxKeyUserHasPaidBing oldStyleCtxKey = "user_has_paid_bing" + oldCtxKeyIsCopilotChatHost oldStyleCtxKey = "is_copilot_chat_host" + oldCtxKeyFeatureFlagEnabled oldStyleCtxKey = "feature_flag_enabled" +) + +// --- BENCHMARKS --- + +// BenchmarkOldStyleFiltering simulates the old approach with Enabled funcs +func BenchmarkOldStyleFiltering(b *testing.B) { + tools := createOldStyleTools() + + // Setup context with actor info (simulates what remote server does) + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, true) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, true) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, false) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = oldStyleToolFilter(ctx, tool, mockFeatureChecker, true, true, false) + } + } +} + +// BenchmarkNewStyleFiltering simulates the new EnableCondition approach +func BenchmarkNewStyleFiltering(b *testing.B) { + tools := createNewStyleTools() + + // Setup context with feature checker and pre-computed bools + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = newStyleToolFilter(ctx, tool) + } + } +} + +// BenchmarkManyToolsOldStyle - simulate filtering 50 tools (realistic toolset) +func BenchmarkManyToolsOldStyle(b *testing.B) { + // Create 50 tools with mixed enable patterns + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + // Feature flag only + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagEnable: "web_search", + } + case 1: + // Feature flag + Enabled check + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagEnable: "code_search", + Enabled: func(ctx context.Context) (bool, error) { + return ctx.Value(oldCtxKeyIsCCA) == true, nil + }, + } + case 2: + // Enabled check only + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + Enabled: func(ctx context.Context) (bool, error) { + if ctx.Value(oldCtxKeyIsCCA) == true { + return true, nil + } + return ctx.Value(oldCtxKeyFeatureFlagEnabled) == true, nil + }, + } + case 3: + // No checks + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + } + case 4: + // Disable flag + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagDisable: "disabled_feature", + } + } + } + + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, true) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = oldStyleToolFilter(ctx, tool, mockFeatureChecker, true, false, false) + } + } +} + +// BenchmarkManyToolsNewStyle - simulate filtering 50 tools with EnableCondition +func BenchmarkManyToolsNewStyle(b *testing.B) { + // Create 50 tools with mixed enable conditions + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + // Feature flag only + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: FeatureFlag("web_search"), + } + case 1: + // Feature flag + context bool (AND) + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ), + } + case 2: + // OR condition + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ), + } + case 3: + // No checks (Always enabled) + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Always(), + } + case 4: + // NOT condition + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Not(FeatureFlag("disabled_feature")), + } + } + } + + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = newStyleToolFilter(ctx, tool) + } + } +} + +// BenchmarkRemoteServerSimulation - simulates 1000 requests filtering all tools +func BenchmarkRemoteServerSimulation_OldStyle(b *testing.B) { + tools := createOldStyleTools() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Simulate 1000 requests + for req := 0; req < 1000; req++ { + ctx := context.Background() + // Each request has slightly different actor context + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, req%3 == 0) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, req%2 == 0) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, req%7 == 0) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, req%4 != 0) + + var enabledCount int + for _, tool := range tools { + enabled, _ := oldStyleToolFilter(ctx, tool, mockFeatureChecker, + req%3 == 0, req%2 == 0, req%7 == 0) + if enabled { + enabledCount++ + } + } + _ = enabledCount + } + } +} + +// BenchmarkRemoteServerSimulation_NewStyle - simulates 1000 requests with EnableCondition +func BenchmarkRemoteServerSimulation_NewStyle(b *testing.B) { + tools := createNewStyleTools() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Simulate 1000 requests + for req := 0; req < 1000; req++ { + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + // Each request has slightly different actor context + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }) + + var enabledCount int + for _, tool := range tools { + enabled, _ := newStyleToolFilter(ctx, tool) + if enabled { + enabledCount++ + } + } + _ = enabledCount + } + } +} + +// BenchmarkShortCircuitEvaluation_OldStyle - tests OR pattern with short-circuit +func BenchmarkShortCircuitEvaluation_OldStyle(b *testing.B) { + // Tool with expensive check that should be short-circuited + tool := &ServerTool{ + Tool: mcp.Tool{Name: "expensive_tool"}, + Enabled: func(ctx context.Context) (bool, error) { + // CCA check (fast, should short-circuit) + if ctx.Value(oldCtxKeyIsCCA) == true { + return true, nil + } + // Expensive check that shouldn't run if CCA is true + for i := 0; i < 100; i++ { + _ = i * i // Simulate work + } + return false, nil + }, + } + + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, true) // Should short-circuit + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = oldStyleToolFilter(ctx, tool, mockFeatureChecker, true, false, false) + } +} + +// BenchmarkShortCircuitEvaluation_NewStyle - tests OR with short-circuit +func BenchmarkShortCircuitEvaluation_NewStyle(b *testing.B) { + // Expensive condition that should be short-circuited + expensiveCondition := &customCondition{ + eval: func(_ context.Context) (bool, error) { + for i := 0; i < 100; i++ { + _ = i * i // Simulate work + } + return false, nil + }, + } + + tool := &ServerTool{ + Tool: mcp.Tool{Name: "expensive_tool"}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), // Should short-circuit before expensive + expensiveCondition, + ), + } + + ctx := context.Background() + ctx = ContextWithBools(ctx, ContextBools{ctxKeyIsCCA: true}) // Should short-circuit + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = newStyleToolFilter(ctx, tool) + } +} + +// customCondition for testing custom expensive conditions +type customCondition struct { + eval func(ctx context.Context) (bool, error) +} + +func (c *customCondition) Evaluate(ctx context.Context) (bool, error) { + return c.eval(ctx) +} + +// BenchmarkComplexConditionTree - tests deep condition tree evaluation +func BenchmarkComplexConditionTree_NewStyle(b *testing.B) { + // Deep condition tree: + // (CCA OR (FeatureFlag AND UserPaidBing)) AND NOT DisabledFeature + tool := &ServerTool{ + Tool: mcp.Tool{Name: "complex_tool"}, + EnableCondition: And( + Or( + ContextBool(ctxKeyIsCCA), + And( + FeatureFlag("web_search"), + ContextBool(ctxKeyUserHasPaidBing), + ), + ), + Not(FeatureFlag("disabled_feature")), + ), + } + + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: false, + ctxKeyUserHasPaidBing: true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = newStyleToolFilter(ctx, tool) + } +} + +// BenchmarkRemoteServerSimulation_NewStyle_Optimized - reuses context (realistic scenario) +// In production, you'd compute context bools once at request start, not per-tool +func BenchmarkRemoteServerSimulation_NewStyle_Optimized(b *testing.B) { + tools := createNewStyleTools() + + // Pre-create context templates for different request types + // This is more realistic - you'd compute bools once at start of request + baseCtx := ContextWithFeatureChecker(context.Background(), mockFeatureChecker) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Simulate 1000 requests + for req := 0; req < 1000; req++ { + // Create context once per request (realistic) + ctx := ContextWithBools(baseCtx, ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }) + + var enabledCount int + for _, tool := range tools { + enabled, _ := newStyleToolFilter(ctx, tool) + if enabled { + enabledCount++ + } + } + _ = enabledCount + } + } +} + +// BenchmarkContextSetup compares context setup costs +func BenchmarkContextSetup_OldStyle(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, true) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, true) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, false) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, true) + _ = ctx + } +} + +func BenchmarkContextSetup_NewStyle(b *testing.B) { + baseCtx := ContextWithFeatureChecker(context.Background(), mockFeatureChecker) + for i := 0; i < b.N; i++ { + ctx := ContextWithBools(baseCtx, ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }) + _ = ctx + } +} + +// BenchmarkPureEvaluation - tests ONLY condition evaluation, no context setup +func BenchmarkPureEvaluation_OldStyle(b *testing.B) { + tools := createOldStyleTools() + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, true) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, true) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, false) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = oldStyleToolFilter(ctx, tool, mockFeatureChecker, true, true, false) + } + } +} + +func BenchmarkPureEvaluation_NewStyle(b *testing.B) { + tools := createNewStyleTools() + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tool := range tools { + _, _ = newStyleToolFilter(ctx, tool) + } + } +} + +// BenchmarkDirectContextValue vs MapLookup - isolate the lookup cost +func BenchmarkDirectContextValue(b *testing.B) { + ctx := context.WithValue(context.Background(), oldCtxKeyIsCCA, true) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ctx.Value(oldCtxKeyIsCCA) == true + } +} + +func BenchmarkMapLookup(b *testing.B) { + ctx := ContextWithBools(context.Background(), ContextBools{ctxKeyIsCCA: true}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ContextBoolFromContext(ctx, ctxKeyIsCCA) + } +} + +// --- Compiled Bitmask Benchmarks --- + +// BenchmarkCompiledFiltering - tests the bitmask-optimized condition evaluation +func BenchmarkCompiledFiltering(b *testing.B) { + tools := createNewStyleTools() + + // Compile all conditions + tcs := NewToolConditionSet(tools) + + // Build mask once (simulates start of request) + mask := tcs.BuildMask(context.Background(), ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tcs.FilterEnabled(mask) + } +} + +// BenchmarkCompiledManyTools - 50 tools with compiled conditions +func BenchmarkCompiledManyTools(b *testing.B) { + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: FeatureFlag("web_search"), + } + case 1: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ), + } + case 2: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ), + } + case 3: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Always(), + } + case 4: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Not(FeatureFlag("disabled_feature")), + } + } + } + + tcs := NewToolConditionSet(tools) + + mask := tcs.BuildMask(context.Background(), ContextBools{ + ctxKeyIsCCA: true, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + "disabled_feature": false, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tcs.FilterEnabled(mask) + } +} + +// BenchmarkCompiledRemoteServer - 1000 requests × 5 tools with compiled conditions +func BenchmarkCompiledRemoteServer(b *testing.B) { + tools := createNewStyleTools() + tcs := NewToolConditionSet(tools) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + // Build mask once per request + mask := tcs.BuildMask(context.Background(), ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + }) + + _ = tcs.FilterEnabled(mask) + } + } +} + +// BenchmarkCompiledRealisticScale - 1000 requests × 50 tools +func BenchmarkCompiledRealisticScale(b *testing.B) { + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: FeatureFlag("web_search"), + } + case 1: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ), + } + case 2: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ), + } + case 3: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Always(), + } + case 4: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Not(FeatureFlag("disabled_feature")), + } + } + } + + tcs := NewToolConditionSet(tools) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + mask := tcs.BuildMask(context.Background(), ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + "disabled_feature": false, + }) + + _ = tcs.FilterEnabled(mask) + } + } +} + +// BenchmarkPureEvaluation_Compiled - just the evaluation, no mask building +func BenchmarkPureEvaluation_Compiled(b *testing.B) { + tools := createNewStyleTools() + tcs := NewToolConditionSet(tools) + + mask := tcs.BuildMask(context.Background(), ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tcs.FilterEnabled(mask) + } +} + +// BenchmarkMaskBuilding - just the mask building overhead +func BenchmarkMaskBuilding(b *testing.B) { + tools := createNewStyleTools() + tcs := NewToolConditionSet(tools) + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tcs.BuildMask(ctx, ContextBools{ + ctxKeyIsCCA: true, + ctxKeyUserHasPaidBing: true, + ctxKeyIsCopilotChatHost: false, + }, map[string]bool{ + "web_search": true, + "code_search": true, + "issues_v2": true, + }) + } +} + +// --- Realistic Scale Benchmarks --- + +// BenchmarkRealisticScale_OldStyle - 1000 requests × 50 tools each +func BenchmarkRealisticScale_OldStyle(b *testing.B) { + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagEnable: "web_search", + } + case 1: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagEnable: "code_search", + Enabled: func(ctx context.Context) (bool, error) { + return ctx.Value(oldCtxKeyIsCCA) == true, nil + }, + } + case 2: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + Enabled: func(ctx context.Context) (bool, error) { + if ctx.Value(oldCtxKeyIsCCA) == true { + return true, nil + } + return ctx.Value(oldCtxKeyFeatureFlagEnabled) == true, nil + }, + } + case 3: + tools[i] = &ServerTool{Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}} + case 4: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + FeatureFlagDisable: "disabled_feature", + } + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, req%3 == 0) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, req%2 == 0) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, req%7 == 0) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, req%4 != 0) + + var enabledCount int + for _, tool := range tools { + enabled, _ := oldStyleToolFilter(ctx, tool, mockFeatureChecker, req%3 == 0, req%2 == 0, req%7 == 0) + if enabled { + enabledCount++ + } + } + _ = enabledCount + } + } +} + +// BenchmarkRealisticScale_NewStyle - 1000 requests × 50 tools each +func BenchmarkRealisticScale_NewStyle(b *testing.B) { + tools := make([]*ServerTool, 50) + for i := 0; i < 50; i++ { + switch i % 5 { + case 0: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: FeatureFlag("web_search"), + } + case 1: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ), + } + case 2: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ), + } + case 3: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Always(), + } + case 4: + tools[i] = &ServerTool{ + Tool: mcp.Tool{Name: fmt.Sprintf("tool_%d", i)}, + EnableCondition: Not(FeatureFlag("disabled_feature")), + } + } + } + + baseCtx := ContextWithFeatureChecker(context.Background(), mockFeatureChecker) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + ctx := ContextWithBools(baseCtx, ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }) + + var enabledCount int + for _, tool := range tools { + enabled, _ := newStyleToolFilter(ctx, tool) + if enabled { + enabledCount++ + } + } + _ = enabledCount + } + } +} + +// BenchmarkIntegratedInventory_Compiled - tests the full Inventory.AvailableTools() path +// This is the integrated path using compiled bitmask conditions +func BenchmarkIntegratedInventory_Compiled(b *testing.B) { + // Create tools with EnableConditions + tools := make([]ServerTool, 50) + toolset := ToolsetMetadata{ID: "test", Default: true} + for i := 0; i < 50; i++ { + tools[i].Tool = mcp.Tool{Name: fmt.Sprintf("tool_%d", i)} + tools[i].Toolset = toolset + switch i % 5 { + case 0: + tools[i].EnableCondition = FeatureFlag("web_search") + case 1: + tools[i].EnableCondition = And( + ContextBool(ctxKeyIsCCA), + FeatureFlag("code_search"), + ) + case 2: + tools[i].EnableCondition = Or( + ContextBool(ctxKeyIsCCA), + FeatureFlag("issues_v2"), + ) + case 3: + tools[i].EnableCondition = Always() + case 4: + tools[i].EnableCondition = Not(FeatureFlag("disabled_feature")) + } + } + + // Build inventory (compiles conditions at build time) + inv := NewBuilder(). + SetTools(tools). + Build() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, mockFeatureChecker) + ctx = ContextWithBools(ctx, ContextBools{ + ctxKeyIsCCA: req%3 == 0, + ctxKeyUserHasPaidBing: req%2 == 0, + ctxKeyIsCopilotChatHost: req%7 == 0, + }) + + _ = inv.AvailableTools(ctx) + } + } +} + +// BenchmarkIntegratedInventory_OldStyle - tests equivalent Inventory path with old Enabled func +func BenchmarkIntegratedInventory_OldStyle(b *testing.B) { + // Create tools with old-style Enabled functions + tools := make([]ServerTool, 50) + toolset := ToolsetMetadata{ID: "test", Default: true} + for i := 0; i < 50; i++ { + tools[i].Tool = mcp.Tool{Name: fmt.Sprintf("tool_%d", i)} + tools[i].Toolset = toolset + switch i % 5 { + case 0: + tools[i].FeatureFlagEnable = "web_search" + case 1: + tools[i].FeatureFlagEnable = "code_search" + tools[i].Enabled = func(ctx context.Context) (bool, error) { + return ctx.Value(oldCtxKeyIsCCA) == true, nil + } + case 2: + tools[i].Enabled = func(ctx context.Context) (bool, error) { + if ctx.Value(oldCtxKeyIsCCA) == true { + return true, nil + } + return ctx.Value(oldCtxKeyFeatureFlagEnabled) == true, nil + } + case 3: + // Always enabled - no condition + case 4: + tools[i].FeatureFlagDisable = "disabled_feature" + } + } + + // Build inventory + inv := NewBuilder(). + SetTools(tools). + WithFeatureChecker(mockFeatureChecker). + Build() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for req := 0; req < 1000; req++ { + ctx := context.Background() + ctx = context.WithValue(ctx, oldCtxKeyIsCCA, req%3 == 0) + ctx = context.WithValue(ctx, oldCtxKeyUserHasPaidBing, req%2 == 0) + ctx = context.WithValue(ctx, oldCtxKeyIsCopilotChatHost, req%7 == 0) + ctx = context.WithValue(ctx, oldCtxKeyFeatureFlagEnabled, req%4 != 0) + + _ = inv.AvailableTools(ctx) + } + } +} diff --git a/pkg/inventory/filters.go b/pkg/inventory/filters.go index 991001a64..35b7e6005 100644 --- a/pkg/inventory/filters.go +++ b/pkg/inventory/filters.go @@ -50,15 +50,61 @@ func (r *Inventory) isFeatureFlagAllowed(ctx context.Context, enableFlag, disabl return true } +// buildRequestMask creates a RequestMask for the current request context. +// This computes all condition values once for O(1) evaluation of each tool. +func (r *Inventory) buildRequestMask(ctx context.Context) *RequestMask { + if r.conditionCompiler == nil { + return nil + } + + var bits uint64 + bools := contextBoolsFromContext(ctx) + checker := FeatureCheckerFromContext(ctx) + + r.conditionCompiler.mu.RLock() + defer r.conditionCompiler.mu.RUnlock() + + for key, bit := range r.conditionCompiler.keyToBit { + // Keys are formatted as "ctx:key_name" or "ff:flag_name" + if len(key) < 4 { // Minimum: "ff:x" or "ctx:" prefix + 1 char + continue + } + + switch { + case len(key) > 4 && key[:4] == "ctx:": + // Context bool: "ctx:key_name" + name := key[4:] + if bools != nil && bools[name] { + bits |= 1 << bit + } + case len(key) > 3 && key[:3] == "ff:": + // Feature flag: "ff:flag_name" + name := key[3:] + if checker != nil { + enabled, err := checker(ctx, name) + if err == nil && enabled { + bits |= 1 << bit + } + } + } + } + + return &RequestMask{ + bits: bits, + ctx: ctx, + } +} + // isToolEnabled checks if a specific tool is enabled based on current filters. // Filter evaluation order: -// 1. Tool.Enabled (tool self-filtering) -// 2. FeatureFlagEnable/FeatureFlagDisable -// 3. Read-only filter -// 4. Builder filters (via WithFilter) -// 5. Toolset/additional tools -func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // 1. Check tool's own Enabled function first +// 1. Tool.Enabled (legacy tool self-filtering - deprecated) +// 2. Tool.EnableCondition via compiled bitmask (O(1) evaluation) +// 3. FeatureFlagEnable/FeatureFlagDisable (legacy - deprecated) +// 4. Read-only filter +// 5. Builder filters (via WithFilter) +// 6. Toolset/additional tools +func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool, toolIndex int, rm *RequestMask) bool { + // 1. Check tool's legacy Enabled function first (for backward compatibility) if tool.Enabled != nil { enabled, err := tool.Enabled(ctx) if err != nil { @@ -69,15 +115,48 @@ func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool) bool { return false } } - // 2. Check feature flags + // 2. Check tool's EnableCondition via compiled bitmask (O(1) evaluation) + if toolIndex >= 0 && toolIndex < len(r.compiledConditions) && r.compiledConditions[toolIndex] != nil { + if rm != nil { + enabled, err := r.compiledConditions[toolIndex].Evaluate(rm) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.EnableCondition check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } else if tool.EnableCondition != nil { + // Fallback to tree-based evaluation if no request mask + enabled, err := tool.EnableCondition.Evaluate(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.EnableCondition check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } + } else if tool.EnableCondition != nil { + // Fallback to tree-based evaluation if no compiled condition + enabled, err := tool.EnableCondition.Evaluate(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.EnableCondition check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } + // 3. Check legacy feature flags (for backward compatibility) if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { return false } - // 3. Check read-only filter (applies to all tools) + // 4. Check read-only filter (applies to all tools) if r.readOnly && !tool.IsReadOnly() { return false } - // 4. Apply builder filters + // 5. Apply builder filters for _, filter := range r.filters { allowed, err := filter(ctx, tool) if err != nil { @@ -88,11 +167,11 @@ func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool) bool { return false } } - // 5. Check if tool is in additionalTools (bypasses toolset filter) + // 6. Check if tool is in additionalTools (bypasses toolset filter) if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { return true } - // 5. Check toolset filter + // 6. Check toolset filter if !r.isToolsetEnabled(tool.Toolset.ID) { return false } @@ -102,30 +181,30 @@ func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool) bool { // AvailableTools returns the tools that pass all current filters, // sorted deterministically by toolset ID, then tool name. // The context is used for feature flag evaluation. +// Uses O(1) bitmask evaluation for EnableConditions when possible. +// Note: Tools are pre-sorted at build time, so filtering preserves order. func (r *Inventory) AvailableTools(ctx context.Context) []ServerTool { + // Build request mask once for O(1) condition evaluation + rm := r.buildRequestMask(ctx) + + // Tools are pre-sorted at build time; filtering preserves order var result []ServerTool for i := range r.tools { tool := &r.tools[i] - if r.isToolEnabled(ctx, tool) { + if r.isToolEnabled(ctx, tool, i, rm) { result = append(result, *tool) } } - // Sort deterministically: by toolset ID, then by tool name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Tool.Name < result[j].Tool.Name - }) - return result } // AvailableResourceTemplates returns resource templates that pass all current filters, // sorted deterministically by toolset ID, then template name. // The context is used for feature flag evaluation. +// Note: Resources are pre-sorted at build time, so filtering preserves order. func (r *Inventory) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + // Resources are pre-sorted at build time; filtering preserves order var result []ServerResourceTemplate for i := range r.resourceTemplates { res := &r.resourceTemplates[i] @@ -138,21 +217,15 @@ func (r *Inventory) AvailableResourceTemplates(ctx context.Context) []ServerReso } } - // Sort deterministically: by toolset ID, then by template name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Template.Name < result[j].Template.Name - }) - return result } // AvailablePrompts returns prompts that pass all current filters, // sorted deterministically by toolset ID, then prompt name. // The context is used for feature flag evaluation. +// Note: Prompts are pre-sorted at build time, so filtering preserves order. func (r *Inventory) AvailablePrompts(ctx context.Context) []ServerPrompt { + // Prompts are pre-sorted at build time; filtering preserves order var result []ServerPrompt for i := range r.prompts { prompt := &r.prompts[i] @@ -165,14 +238,6 @@ func (r *Inventory) AvailablePrompts(ctx context.Context) []ServerPrompt { } } - // Sort deterministically: by toolset ID, then by prompt name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Prompt.Name < result[j].Prompt.Name - }) - return result } @@ -221,7 +286,9 @@ func (r *Inventory) filterPromptsByName(name string) []ServerPrompt { // ToolsForToolset returns all tools belonging to a specific toolset. // This method bypasses the toolset enabled filter (for dynamic toolset registration), // but still respects the read-only filter. +// Note: Tools are pre-sorted at build time, so filtering preserves order. func (r *Inventory) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + // Tools are pre-sorted at build time; filtering preserves order var result []ServerTool for i := range r.tools { tool := &r.tools[i] @@ -234,11 +301,6 @@ func (r *Inventory) ToolsForToolset(toolsetID ToolsetID) []ServerTool { } } - // Sort by tool name for deterministic order - sort.Slice(result, func(i, j int) bool { - return result[i].Tool.Name < result[j].Tool.Name - }) - return result } diff --git a/pkg/inventory/registry.go b/pkg/inventory/registry.go index f3691e38a..4cded22e9 100644 --- a/pkg/inventory/registry.go +++ b/pkg/inventory/registry.go @@ -24,6 +24,7 @@ import ( // - Deterministic ordering for documentation generation // - Lazy dependency injection during registration via RegisterAll() // - Runtime toolset enabling for dynamic toolsets mode +// - O(1) EnableCondition evaluation via pre-compiled bitmasks type Inventory struct { // tools holds all tools in this group (ordered for iteration) tools []ServerTool @@ -40,6 +41,12 @@ type Inventory struct { defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs toolsetDescriptions map[ToolsetID]string // toolset ID -> description + // Compiled conditions for O(1) EnableCondition evaluation (set during Build) + // Maps tool index → compiled condition (nil means always enabled) + compiledConditions []*CompiledCondition + // conditionCompiler used to compile conditions (shared across inventory) + conditionCompiler *ConditionCompiler + // Filters - these control what's returned by Available* methods // readOnly when true filters out write tools readOnly bool @@ -110,7 +117,9 @@ func (r *Inventory) ForMCPRequest(method string, itemName string) *Inventory { enabledToolsets: r.enabledToolsets, // shared, not modified additionalTools: r.additionalTools, // shared, not modified featureChecker: r.featureChecker, - filters: r.filters, // shared, not modified + filters: r.filters, // shared, not modified + compiledConditions: r.compiledConditions, // shared, not modified + conditionCompiler: r.conditionCompiler, // shared, not modified unrecognizedToolsets: r.unrecognizedToolsets, } diff --git a/pkg/inventory/registry_test.go b/pkg/inventory/registry_test.go index 41e94b8d9..9a15977aa 100644 --- a/pkg/inventory/registry_test.go +++ b/pkg/inventory/registry_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" ) // testToolsetMetadata returns a ToolsetMetadata for testing @@ -1598,10 +1599,11 @@ func TestFilteredToolsMatchesAvailableTools(t *testing.T) { func TestFilteringOrder(t *testing.T) { // Test that filters are applied in the correct order: // 1. Tool.Enabled - // 2. Feature flags - // 3. Read-only - // 4. Builder filters - // 5. Toolset/additional tools + // 2. EnableCondition + // 3. Feature flags + // 4. Read-only + // 5. Builder filters + // 6. Toolset/additional tools callOrder := []string{} @@ -1643,3 +1645,355 @@ func TestFilteringOrder(t *testing.T) { } } } + +// Tests for EnableCondition integration +func TestEnableConditionSimple(t *testing.T) { + // Tool with EnableCondition that returns true + tool := mockTool("test_tool", "toolset1", true) + tool.EnableCondition = Always() + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when EnableCondition returns true") + } + + // Tool with EnableCondition that returns false + tool2 := mockTool("test_tool2", "toolset1", true) + tool2.EnableCondition = Never() + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool2}). + WithToolsets([]string{"all"}). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when EnableCondition returns false") + } +} + +func TestEnableConditionWithFeatureFlag(t *testing.T) { + // Tool with EnableCondition using FeatureFlag condition + tool := mockTool("test_tool", "toolset1", true) + tool.EnableCondition = FeatureFlag("my_feature") + + // Without feature checker - should be excluded + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when no feature checker is available") + } + + // With feature checker that returns true + ctx := ContextWithFeatureChecker(context.Background(), func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + }) + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + available2 := reg2.AvailableTools(ctx) + if len(available2) != 1 { + t.Error("Tool should be included when feature flag is enabled via context") + } +} + +func TestEnableConditionWithContextBool(t *testing.T) { + // Tool with EnableCondition using ContextBool + tool := mockTool("cca_tool", "toolset1", true) + tool.EnableCondition = ContextBool("is_cca") + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + // Without context bool - should be excluded + available1 := reg.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when context bool is not set") + } + + // With context bool = true + ctx := ContextWithBools(context.Background(), ContextBools{"is_cca": true}) + available2 := reg.AvailableTools(ctx) + if len(available2) != 1 { + t.Error("Tool should be included when context bool is true") + } + + // With context bool = false + ctx3 := ContextWithBools(context.Background(), ContextBools{"is_cca": false}) + available3 := reg.AvailableTools(ctx3) + if len(available3) != 0 { + t.Error("Tool should be excluded when context bool is false") + } +} + +func TestEnableConditionComplexPattern(t *testing.T) { + // CCA bypass pattern: tool is available if CCA OR feature flag is enabled + tool := mockTool("agent_tool", "toolset1", true) + tool.EnableCondition = Or( + ContextBool("is_cca"), + FeatureFlag("agent_search"), + ) + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + // CCA request - should be enabled without feature flag + ctxCCA := ContextWithBools(context.Background(), ContextBools{"is_cca": true}) + availableCCA := reg.AvailableTools(ctxCCA) + if len(availableCCA) != 1 { + t.Error("Tool should be enabled for CCA requests") + } + + // Non-CCA with feature flag - should be enabled + ctxFF := ContextWithBools(context.Background(), ContextBools{"is_cca": false}) + ctxFF = ContextWithFeatureChecker(ctxFF, func(_ context.Context, flag string) (bool, error) { + return flag == "agent_search", nil + }) + availableFF := reg.AvailableTools(ctxFF) + if len(availableFF) != 1 { + t.Error("Tool should be enabled with feature flag for non-CCA") + } + + // Non-CCA without feature flag - should be excluded + ctxNone := ContextWithBools(context.Background(), ContextBools{"is_cca": false}) + availableNone := reg.AvailableTools(ctxNone) + if len(availableNone) != 0 { + t.Error("Tool should be excluded for non-CCA without feature flag") + } +} + +func TestEnableConditionAndLegacyEnabledInteraction(t *testing.T) { + // When both Enabled and EnableCondition are set, both must pass + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil // Legacy Enabled passes + } + tool.EnableCondition = Never() // But EnableCondition fails + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when EnableCondition returns false, even if Enabled returns true") + } + + // Both pass + tool2 := mockTool("test_tool2", "toolset1", true) + tool2.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + tool2.EnableCondition = Always() + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool2}). + WithToolsets([]string{"all"}). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and EnableCondition pass") + } +} + +func TestEnableConditionFilteringOrder(t *testing.T) { + // Test that EnableCondition is checked after legacy Enabled but before feature flags + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "LegacyEnabled") + return true, nil + } + tool.EnableCondition = ConditionFunc(func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "EnableCondition") + return false, nil // Return false to stop early + }) + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Should stop at EnableCondition since it returns false + expectedOrder := []string{"LegacyEnabled", "EnableCondition"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } + } +} + +func TestCompiledConditionsIntegration(t *testing.T) { + // Test that conditions are compiled at build time and evaluated with bitmask + // at request time (via AvailableTools) + + checker := func(_ context.Context, flagName string) (bool, error) { + switch flagName { + case "enabled_flag": + return true, nil + case "disabled_flag": + return false, nil + default: + return false, nil + } + } + + toolAlwaysEnabled := mockTool("always_on", "test", true) + toolAlwaysEnabled.EnableCondition = Always() + + toolNeverEnabled := mockTool("never_on", "test", true) + toolNeverEnabled.EnableCondition = Never() + + toolFlagEnabled := mockTool("flag_on", "test", true) + toolFlagEnabled.EnableCondition = FeatureFlag("enabled_flag") + + toolFlagDisabled := mockTool("flag_off", "test", true) + toolFlagDisabled.EnableCondition = FeatureFlag("disabled_flag") + + toolContextBool := mockTool("ctx_bool", "test", true) + toolContextBool.EnableCondition = ContextBool("my_bool") + + toolComplex := mockTool("complex", "test", true) + toolComplex.EnableCondition = Or( + ContextBool("my_bool"), + FeatureFlag("enabled_flag"), + ) + + reg := NewBuilder(). + SetTools([]ServerTool{ + toolAlwaysEnabled, + toolNeverEnabled, + toolFlagEnabled, + toolFlagDisabled, + toolContextBool, + toolComplex, + }). + WithToolsets([]string{"all"}). + Build() + + // Verify compiler was created + assert.NotNil(t, reg.conditionCompiler, "Condition compiler should be created") + assert.NotNil(t, reg.compiledConditions, "Compiled conditions should be created") + + // Test without context bools - only flag-based tools should pass + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, checker) + + tools := reg.AvailableTools(ctx) + toolNames := make([]string, len(tools)) + for i, tool := range tools { + toolNames[i] = tool.Tool.Name + } + + assert.Contains(t, toolNames, "always_on", "Always enabled tool should be available") + assert.Contains(t, toolNames, "flag_on", "Flag enabled tool should be available") + assert.Contains(t, toolNames, "complex", "Complex (OR flag) tool should be available") + assert.NotContains(t, toolNames, "never_on", "Never enabled tool should not be available") + assert.NotContains(t, toolNames, "flag_off", "Disabled flag tool should not be available") + assert.NotContains(t, toolNames, "ctx_bool", "Context bool tool should not be available without bool set") + + // Test with context bool set + ctx = ContextWithBools(ctx, ContextBools{"my_bool": true}) + + tools = reg.AvailableTools(ctx) + toolNames = make([]string, len(tools)) + for i, tool := range tools { + toolNames[i] = tool.Tool.Name + } + + assert.Contains(t, toolNames, "ctx_bool", "Context bool tool should be available when bool is set") + assert.Contains(t, toolNames, "complex", "Complex tool should still be available") +} + +func TestCompiledConditionsANDBitmask(t *testing.T) { + // Test that AND conditions are compiled to bitmask AND operations + checker := func(_ context.Context, flagName string) (bool, error) { + return flagName == "flag_a" || flagName == "flag_b", nil + } + + // Tool requires both flags + tool := mockTool("both_flags", "test", true) + tool.EnableCondition = And( + FeatureFlag("flag_a"), + FeatureFlag("flag_b"), + ) + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + // Verify it was compiled to bitmask AND + assert.NotNil(t, reg.compiledConditions[0]) + compiled := reg.compiledConditions[0] + assert.Equal(t, evalBitmaskAnd, compiled.evalType, "AND of flags should compile to bitmaskAnd") + + // Test evaluation + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, checker) + + tools := reg.AvailableTools(ctx) + assert.Len(t, tools, 1, "Tool with both flags enabled should be available") +} + +func TestCompiledConditionsORBitmask(t *testing.T) { + // Test that OR conditions are compiled to bitmask OR operations + checker := func(_ context.Context, flagName string) (bool, error) { + return flagName == "flag_a", nil // Only flag_a is enabled + } + + // Tool requires either flag + tool := mockTool("either_flag", "test", true) + tool.EnableCondition = Or( + FeatureFlag("flag_a"), + FeatureFlag("flag_b"), + ) + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + + // Verify it was compiled to bitmask OR + assert.NotNil(t, reg.compiledConditions[0]) + compiled := reg.compiledConditions[0] + assert.Equal(t, evalBitmaskOr, compiled.evalType, "OR of flags should compile to bitmaskOr") + + // Test evaluation - should pass because flag_a is enabled + ctx := context.Background() + ctx = ContextWithFeatureChecker(ctx, checker) + + tools := reg.AvailableTools(ctx) + assert.Len(t, tools, 1, "Tool with one flag enabled should be available") +} diff --git a/pkg/inventory/server_tool.go b/pkg/inventory/server_tool.go index 362ee2643..ae0eea440 100644 --- a/pkg/inventory/server_tool.go +++ b/pkg/inventory/server_tool.go @@ -58,17 +58,44 @@ type ServerTool struct { // FeatureFlagEnable specifies a feature flag that must be enabled for this tool // to be available. If set and the flag is not enabled, the tool is omitted. + // + // Deprecated: Use EnableCondition with FeatureFlag() instead for composable conditions. + // This field is checked before EnableCondition for backward compatibility. FeatureFlagEnable string // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool // to be omitted. Used to disable tools when a feature flag is on. + // + // Deprecated: Use EnableCondition with Not(FeatureFlag()) instead for composable conditions. + // This field is checked before EnableCondition for backward compatibility. FeatureFlagDisable string + // EnableCondition is the composable condition for tool availability. + // Use the condition combinators (FeatureFlag, ContextBool, And, Or, Not) + // to build complex enable logic declaratively. + // + // Examples: + // // Feature flag only + // EnableCondition: FeatureFlag("web_search") + // + // // Feature flag AND user policy + // EnableCondition: And(FeatureFlag("web_search"), ContextBool("user_has_paid_access")) + // + // // CCA bypass (CCA OR feature flag) + // EnableCondition: Or(ContextBool("is_cca"), FeatureFlag("agent_search")) + // + // If nil, the tool is enabled (subject to other filters like toolset, read-only). + EnableCondition EnableCondition + // Enabled is an optional function called at build/filter time to determine // if this tool should be available. If nil, the tool is considered enabled // (subject to FeatureFlagEnable/FeatureFlagDisable checks). // The context carries request-scoped information for the consumer to use. // Returns (enabled, error). On error, the tool should be treated as disabled. + // + // Deprecated: Use EnableCondition instead for composable, declarative conditions. + // If both Enabled and EnableCondition are set, Enabled takes precedence for + // backward compatibility. Migrate to EnableCondition for new tools. Enabled func(ctx context.Context) (bool, error) } From 880c19455b9a660b352d6d6a6aa48163206214dc Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 18 Dec 2025 14:36:44 +0100 Subject: [PATCH 2/2] refactor: embed compiledCondition in ServerTool Instead of maintaining a separate compiledConditions slice indexed by tool position, embed the *CompiledCondition directly in each ServerTool. This eliminates the index alignment issue where filtering to a single tool (ForMCPRequest with ToolsCall) would break the mapping between tools and their compiled conditions. Benefits: - Compiled condition travels with the tool during any filtering - No index bookkeeping needed - Simpler code in filters.go (no toolIndex parameter) - ForMCPRequest works correctly for single-tool lookups - Same O(1) bitmask evaluation performance --- pkg/inventory/builder.go | 17 ++++++++--------- pkg/inventory/filters.go | 11 ++++++----- pkg/inventory/registry.go | 15 +++++++-------- pkg/inventory/registry_test.go | 13 ++++++------- pkg/inventory/server_tool.go | 6 ++++++ 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pkg/inventory/builder.go b/pkg/inventory/builder.go index dbb1d6fa0..6cb00ba78 100644 --- a/pkg/inventory/builder.go +++ b/pkg/inventory/builder.go @@ -166,9 +166,9 @@ func (b *Builder) Build() *Inventory { } } - // Compile EnableConditions for O(1) bitmask evaluation - // Note: compileConditions uses r.tools which is now sortedTools - r.conditionCompiler, r.compiledConditions = b.compileConditions(sortedTools) + // Compile EnableConditions for O(1) bitmask evaluation. + // This modifies sortedTools in place, embedding compiled conditions in each tool. + r.conditionCompiler = b.compileConditions(sortedTools) return r } @@ -223,21 +223,20 @@ func (b *Builder) preSortPrompts() []ServerPrompt { } // compileConditions compiles all EnableConditions into bitmask-based evaluators. -// Returns the compiler (for building request masks) and compiled conditions slice. -// Takes the sorted tools slice to ensure compiled conditions align with sorted order. -func (b *Builder) compileConditions(sortedTools []ServerTool) (*ConditionCompiler, []*CompiledCondition) { +// Modifies sortedTools in place, embedding compiled conditions in each tool. +// Returns the compiler (for building request masks at runtime). +func (b *Builder) compileConditions(sortedTools []ServerTool) *ConditionCompiler { compiler := NewConditionCompiler() - compiled := make([]*CompiledCondition, len(sortedTools)) for i := range sortedTools { if sortedTools[i].EnableCondition != nil { - compiled[i] = compiler.Compile(sortedTools[i].EnableCondition) + sortedTools[i].compiledCondition = compiler.Compile(sortedTools[i].EnableCondition) } // nil means no condition (always enabled from condition perspective) } compiler.Freeze() - return compiler, compiled + return compiler } // processToolsets processes the toolsetIDs configuration and returns: diff --git a/pkg/inventory/filters.go b/pkg/inventory/filters.go index 35b7e6005..d81dc3755 100644 --- a/pkg/inventory/filters.go +++ b/pkg/inventory/filters.go @@ -103,7 +103,7 @@ func (r *Inventory) buildRequestMask(ctx context.Context) *RequestMask { // 4. Read-only filter // 5. Builder filters (via WithFilter) // 6. Toolset/additional tools -func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool, toolIndex int, rm *RequestMask) bool { +func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool, rm *RequestMask) bool { // 1. Check tool's legacy Enabled function first (for backward compatibility) if tool.Enabled != nil { enabled, err := tool.Enabled(ctx) @@ -116,9 +116,10 @@ func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool, toolInd } } // 2. Check tool's EnableCondition via compiled bitmask (O(1) evaluation) - if toolIndex >= 0 && toolIndex < len(r.compiledConditions) && r.compiledConditions[toolIndex] != nil { + // The compiled condition is embedded in the tool, so it travels with filtering. + if tool.compiledCondition != nil { if rm != nil { - enabled, err := r.compiledConditions[toolIndex].Evaluate(rm) + enabled, err := tool.compiledCondition.Evaluate(rm) if err != nil { fmt.Fprintf(os.Stderr, "Tool.EnableCondition check error for %q: %v\n", tool.Tool.Name, err) return false @@ -126,7 +127,7 @@ func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool, toolInd if !enabled { return false } - } else if tool.EnableCondition != nil { + } else { // Fallback to tree-based evaluation if no request mask enabled, err := tool.EnableCondition.Evaluate(ctx) if err != nil { @@ -191,7 +192,7 @@ func (r *Inventory) AvailableTools(ctx context.Context) []ServerTool { var result []ServerTool for i := range r.tools { tool := &r.tools[i] - if r.isToolEnabled(ctx, tool, i, rm) { + if r.isToolEnabled(ctx, tool, rm) { result = append(result, *tool) } } diff --git a/pkg/inventory/registry.go b/pkg/inventory/registry.go index 4cded22e9..c78df1e8c 100644 --- a/pkg/inventory/registry.go +++ b/pkg/inventory/registry.go @@ -41,10 +41,8 @@ type Inventory struct { defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs toolsetDescriptions map[ToolsetID]string // toolset ID -> description - // Compiled conditions for O(1) EnableCondition evaluation (set during Build) - // Maps tool index → compiled condition (nil means always enabled) - compiledConditions []*CompiledCondition - // conditionCompiler used to compile conditions (shared across inventory) + // conditionCompiler used to compile conditions and build request masks. + // The compiled conditions themselves are embedded in each ServerTool. conditionCompiler *ConditionCompiler // Filters - these control what's returned by Available* methods @@ -107,7 +105,9 @@ const ( func (r *Inventory) ForMCPRequest(method string, itemName string) *Inventory { // Create a shallow copy with shared filter settings // Note: lazy-init maps (toolsByName, etc.) are NOT copied - the new Registry - // will initialize its own maps on first use if needed + // will initialize its own maps on first use if needed. + // Compiled conditions are embedded in each ServerTool, so they travel with the tool + // during filtering - no index alignment issues. result := &Inventory{ tools: r.tools, resourceTemplates: r.resourceTemplates, @@ -117,9 +117,8 @@ func (r *Inventory) ForMCPRequest(method string, itemName string) *Inventory { enabledToolsets: r.enabledToolsets, // shared, not modified additionalTools: r.additionalTools, // shared, not modified featureChecker: r.featureChecker, - filters: r.filters, // shared, not modified - compiledConditions: r.compiledConditions, // shared, not modified - conditionCompiler: r.conditionCompiler, // shared, not modified + filters: r.filters, // shared, not modified + conditionCompiler: r.conditionCompiler, unrecognizedToolsets: r.unrecognizedToolsets, } diff --git a/pkg/inventory/registry_test.go b/pkg/inventory/registry_test.go index 9a15977aa..ed0355e0a 100644 --- a/pkg/inventory/registry_test.go +++ b/pkg/inventory/registry_test.go @@ -1904,7 +1904,6 @@ func TestCompiledConditionsIntegration(t *testing.T) { // Verify compiler was created assert.NotNil(t, reg.conditionCompiler, "Condition compiler should be created") - assert.NotNil(t, reg.compiledConditions, "Compiled conditions should be created") // Test without context bools - only flag-based tools should pass ctx := context.Background() @@ -1954,9 +1953,9 @@ func TestCompiledConditionsANDBitmask(t *testing.T) { WithToolsets([]string{"all"}). Build() - // Verify it was compiled to bitmask AND - assert.NotNil(t, reg.compiledConditions[0]) - compiled := reg.compiledConditions[0] + // Verify it was compiled to bitmask AND (condition embedded in tool) + assert.NotNil(t, reg.tools[0].compiledCondition) + compiled := reg.tools[0].compiledCondition assert.Equal(t, evalBitmaskAnd, compiled.evalType, "AND of flags should compile to bitmaskAnd") // Test evaluation @@ -1985,9 +1984,9 @@ func TestCompiledConditionsORBitmask(t *testing.T) { WithToolsets([]string{"all"}). Build() - // Verify it was compiled to bitmask OR - assert.NotNil(t, reg.compiledConditions[0]) - compiled := reg.compiledConditions[0] + // Verify it was compiled to bitmask OR (condition embedded in tool) + assert.NotNil(t, reg.tools[0].compiledCondition) + compiled := reg.tools[0].compiledCondition assert.Equal(t, evalBitmaskOr, compiled.evalType, "OR of flags should compile to bitmaskOr") // Test evaluation - should pass because flag_a is enabled diff --git a/pkg/inventory/server_tool.go b/pkg/inventory/server_tool.go index ae0eea440..f96a2c2a3 100644 --- a/pkg/inventory/server_tool.go +++ b/pkg/inventory/server_tool.go @@ -97,6 +97,12 @@ type ServerTool struct { // If both Enabled and EnableCondition are set, Enabled takes precedence for // backward compatibility. Migrate to EnableCondition for new tools. Enabled func(ctx context.Context) (bool, error) + + // compiledCondition is the pre-compiled bitmask evaluator for EnableCondition. + // Set at build time by compileConditions(). nil means always enabled. + // This is embedded in the tool so it travels with the tool during filtering, + // eliminating index alignment issues when filtering to single tools. + compiledCondition *CompiledCondition } // IsReadOnly returns true if this tool is marked as read-only via annotations.