diff --git a/src/flagsmith_sql_flag_engine/translator.py b/src/flagsmith_sql_flag_engine/translator.py index bc9f0ec..c836120 100644 --- a/src/flagsmith_sql_flag_engine/translator.py +++ b/src/flagsmith_sql_flag_engine/translator.py @@ -517,26 +517,33 @@ def translate_condition(cond: SegmentCondition, ctx: TranslateContext) -> str | def translate_rule(rule: SegmentRule, ctx: TranslateContext) -> str | None: - children: list[str] = [] + cond_children: list[str] = [] for cond in rule.get("conditions") or []: sql = translate_condition(cond, ctx) if sql is None: return None - children.append(f"({sql})") + cond_children.append(f"({sql})") + rule_children: list[str] = [] for nested in rule.get("rules") or []: sql = translate_rule(nested, ctx) if sql is None: return None - children.append(f"({sql})") - - assert children, "segment rule must have at least one condition or nested rule" - match rule["type"]: - case "ALL": - return " AND ".join(children) - case "ANY": - return " OR ".join(children) - case "NONE": - return f"NOT ({' OR '.join(children)})" + rule_children.append(f"({sql})") + + # Mirror the engine's `context_matches_rule`: conditions and nested rules + # are two independent groups AND-ed together, each vacuously true when + # empty. + op = {"ALL": " AND ", "ANY": " OR ", "NONE": " OR "}[rule["type"]] + groups = [ + f"NOT ({op.join(c)})" if rule["type"] == "NONE" else op.join(c) + for c in (cond_children, rule_children) + if c + ] + if not groups: + return "TRUE" + if len(groups) == 1: + return groups[0] + return " AND ".join(f"({g})" for g in groups) def translate_segment(segment: SegmentContext, ctx: TranslateContext) -> str | None: diff --git a/tests/test_translator_unit.py b/tests/test_translator_unit.py index 079f1fa..946b5c5 100644 --- a/tests/test_translator_unit.py +++ b/tests/test_translator_unit.py @@ -282,6 +282,96 @@ def test_translate_segment__empty_rules__returns_false() -> None: assert translate_segment(seg, _ctx()) == "FALSE" +def test_translate_segment__rule_with_no_conditions_or_nested_rules__matches_all() -> None: + # Given + seg: SegmentContext = { + "key": "10", + "name": "s", + "rules": [{"type": "ALL", "conditions": [], "rules": []}], + } + + # When / Then + assert translate_segment(seg, _ctx()) == "(TRUE)" + + +def test_translate_segment__rule_with_conditions_and_nested_rule__ands_the_two_groups() -> None: + # Given a rule carrying both a condition and a (translatable) nested rule + seg: SegmentContext = { + "key": "11", + "name": "s", + "rules": [ + { + "type": "ALL", + "conditions": [{"operator": "EQUAL", "property": "plan", "value": "growth"}], + "rules": [ + { + "type": "ALL", + "conditions": [{"operator": "EQUAL", "property": "country", "value": "GB"}], + } + ], + } + ], + } + + # When translated + sql = translate_segment(seg, _ctx()) + + # Then the condition group and the nested-rule group are AND-ed together + # (mirroring the engine: matches_conditions AND matches_rules) + assert sql == ( + "((((" + "if(i.traits.`plan` IS NULL, NULL, toString(i.traits.`plan`)) IS NOT NULL" + " AND ((toString(i.traits.`plan`) = 'growth') OR (i.traits.`plan`.:Bool = true))" + "))) AND ((((" + "if(i.traits.`country` IS NULL, NULL, toString(i.traits.`country`)) IS NOT NULL" + " AND ((toString(i.traits.`country`) = 'GB') OR (i.traits.`country`.:Bool = true))" + ")))))" + ) + + +def test_translate_segment__any_rule_with_conditions_and_nested_rule__ors_within_ands_across() -> ( + None +): + # Given an ANY rule with two conditions and a nested rule + seg: SegmentContext = { + "key": "12", + "name": "s", + "rules": [ + { + "type": "ANY", + "conditions": [ + {"operator": "EQUAL", "property": "plan", "value": "growth"}, + {"operator": "EQUAL", "property": "plan", "value": "scale"}, + ], + "rules": [ + { + "type": "ALL", + "conditions": [{"operator": "EQUAL", "property": "country", "value": "GB"}], + } + ], + } + ], + } + + # When + sql = translate_segment(seg, _ctx()) + + # Then + # interpreted as `any(conditions) AND any(rules)` + assert sql == ( + "((((" + "if(i.traits.`plan` IS NULL, NULL, toString(i.traits.`plan`)) IS NOT NULL" + " AND ((toString(i.traits.`plan`) = 'growth') OR (i.traits.`plan`.:Bool = true))" + ")) OR ((" + "if(i.traits.`plan` IS NULL, NULL, toString(i.traits.`plan`)) IS NOT NULL" + " AND ((toString(i.traits.`plan`) = 'scale') OR (i.traits.`plan`.:Bool = true))" + "))) AND ((((" + "if(i.traits.`country` IS NULL, NULL, toString(i.traits.`country`)) IS NOT NULL" + " AND ((toString(i.traits.`country`) = 'GB') OR (i.traits.`country`.:Bool = true))" + ")))))" + ) + + def test_translate_segment__trait_key_with_hyphens__quotes_subcolumn_path() -> None: # Given a trait key with a hyphen (illegal as an unquoted SQL identifier) seg: SegmentContext = {