diff --git a/infrastructure/modules/lambda/lambda.tf b/infrastructure/modules/lambda/lambda.tf index 889cd6d2a..6485be675 100644 --- a/infrastructure/modules/lambda/lambda.tf +++ b/infrastructure/modules/lambda/lambda.tf @@ -17,14 +17,15 @@ resource "aws_lambda_function" "eligibility_signposting_lambda" { environment { variables = { - PERSON_TABLE_NAME = var.eligibility_status_table_name, - RULES_BUCKET_NAME = var.eligibility_rules_bucket_name, - KINESIS_AUDIT_STREAM_TO_S3 = var.kinesis_audit_stream_to_s3_name - ENV = var.environment - LOG_LEVEL = var.log_level - ENABLE_XRAY_PATCHING = var.enable_xray_patching - API_DOMAIN_NAME = var.api_domain_name - HASHING_SECRET_NAME = var.hashing_secret_name + PERSON_TABLE_NAME = var.eligibility_status_table_name, + RULES_BUCKET_NAME = var.eligibility_rules_bucket_name, + CONSUMER_MAPPING_BUCKET_NAME = var.eligibility_consumer_mappings_bucket_name, + KINESIS_AUDIT_STREAM_TO_S3 = var.kinesis_audit_stream_to_s3_name + ENV = var.environment + LOG_LEVEL = var.log_level + ENABLE_XRAY_PATCHING = var.enable_xray_patching + API_DOMAIN_NAME = var.api_domain_name + HASHING_SECRET_NAME = var.hashing_secret_name } } diff --git a/infrastructure/modules/lambda/variables.tf b/infrastructure/modules/lambda/variables.tf index 85b639862..6f238e149 100644 --- a/infrastructure/modules/lambda/variables.tf +++ b/infrastructure/modules/lambda/variables.tf @@ -44,6 +44,11 @@ variable "eligibility_rules_bucket_name" { type = string } +variable "eligibility_consumer_mappings_bucket_name" { + description = "consumer mappings bucket name" + type = string +} + variable "eligibility_status_table_name" { description = "eligibility datastore table name" type = string diff --git a/infrastructure/stacks/api-layer/iam_policies.tf b/infrastructure/stacks/api-layer/iam_policies.tf index 0fd67e453..b798fa48f 100644 --- a/infrastructure/stacks/api-layer/iam_policies.tf +++ b/infrastructure/stacks/api-layer/iam_policies.tf @@ -104,6 +104,60 @@ data "aws_iam_policy_document" "rules_s3_bucket_policy" { } } +# Policy doc for S3 Consumer Mappings bucket +data "aws_iam_policy_document" "s3_consumer_mapping_bucket_policy" { + statement { + sid = "AllowSSLRequestsOnly" + actions = [ + "s3:GetObject", + "s3:ListBucket", + ] + resources = [ + module.s3_consumer_mappings_bucket.storage_bucket_arn, + "${module.s3_consumer_mappings_bucket.storage_bucket_arn}/*", + ] + condition { + test = "Bool" + values = ["true"] + variable = "aws:SecureTransport" + } + } +} + +# ensure only secure transport is allowed + +resource "aws_s3_bucket_policy" "consumer_mapping_s3_bucket" { + bucket = module.s3_consumer_mappings_bucket.storage_bucket_id + policy = data.aws_iam_policy_document.consumer_mapping_s3_bucket_policy.json +} + +data "aws_iam_policy_document" "consumer_mapping_s3_bucket_policy" { + statement { + sid = "AllowSslRequestsOnly" + actions = [ + "s3:*", + ] + effect = "Deny" + resources = [ + module.s3_consumer_mappings_bucket.storage_bucket_arn, + "${module.s3_consumer_mappings_bucket.storage_bucket_arn}/*", + ] + principals { + type = "*" + identifiers = ["*"] + } + condition { + test = "Bool" + values = [ + "false", + ] + + variable = "aws:SecureTransport" + } + } +} + +# audit bucket resource "aws_s3_bucket_policy" "audit_s3_bucket" { bucket = module.s3_audit_bucket.storage_bucket_id policy = data.aws_iam_policy_document.audit_s3_bucket_policy.json @@ -136,12 +190,18 @@ data "aws_iam_policy_document" "audit_s3_bucket_policy" { } # Attach s3 read policy to Lambda role -resource "aws_iam_role_policy" "lambda_s3_read_policy" { +resource "aws_iam_role_policy" "lambda_s3_rules_read_policy" { name = "S3ReadAccess" role = aws_iam_role.eligibility_lambda_role.id policy = data.aws_iam_policy_document.s3_rules_bucket_policy.json } +resource "aws_iam_role_policy" "lambda_s3_mapping_read_policy" { + name = "S3ConsumerMappingReadAccess" + role = aws_iam_role.eligibility_lambda_role.id + policy = data.aws_iam_policy_document.s3_consumer_mapping_bucket_policy.json +} + # Attach s3 write policy to kinesis firehose role resource "aws_iam_role_policy" "kinesis_firehose_s3_write_policy" { name = "S3WriteAccess" @@ -290,6 +350,41 @@ resource "aws_kms_key_policy" "s3_rules_kms_key" { policy = data.aws_iam_policy_document.s3_rules_kms_key_policy.json } +data "aws_iam_policy_document" "s3_consumer_mapping_kms_key_policy" { + #checkov:skip=CKV_AWS_111: Root user needs full KMS key management + #checkov:skip=CKV_AWS_356: Root user needs full KMS key management + #checkov:skip=CKV_AWS_109: Root user needs full KMS key management + statement { + sid = "EnableIamUserPermissions" + effect = "Allow" + principals { + type = "AWS" + identifiers = ["arn:aws:iam::${data.aws_caller_identity.current.account_id}:root"] + } + actions = ["kms:*"] + resources = ["*"] + } + + #checkov:skip=CKV_AWS_111: Permission boundary enforces restrictions for this policy + #checkov:skip=CKV_AWS_356: Permission boundary enforces resource-level controls + #checkov:skip=CKV_AWS_109: Permission boundary governs write-access constraints + statement { + sid = "AllowLambdaDecrypt" + effect = "Allow" + principals { + type = "AWS" + identifiers = [aws_iam_role.eligibility_lambda_role.arn] + } + actions = ["kms:Decrypt"] + resources = ["*"] + } +} + +resource "aws_kms_key_policy" "s3_consumer_mapping_kms_key" { + key_id = module.s3_consumer_mappings_bucket.storage_bucket_kms_key_id + policy = data.aws_iam_policy_document.s3_consumer_mapping_kms_key_policy.json +} + resource "aws_iam_role_policy" "splunk_firehose_policy" { #checkov:skip=CKV_AWS_290: Firehose requires write access to dynamic log streams without static constraints #checkov:skip=CKV_AWS_355: Firehose logging requires wildcard resource for CloudWatch log groups/streams diff --git a/infrastructure/stacks/api-layer/lambda.tf b/infrastructure/stacks/api-layer/lambda.tf index 9b31fee49..f87c36588 100644 --- a/infrastructure/stacks/api-layer/lambda.tf +++ b/infrastructure/stacks/api-layer/lambda.tf @@ -11,27 +11,28 @@ data "aws_subnet" "private_subnets" { } module "eligibility_signposting_lambda_function" { - source = "../../modules/lambda" - eligibility_lambda_role_arn = aws_iam_role.eligibility_lambda_role.arn - eligibility_lambda_role_name = aws_iam_role.eligibility_lambda_role.name - workspace = local.workspace - environment = var.environment - runtime = "python3.13" - lambda_func_name = "${terraform.workspace == "default" ? "" : "${terraform.workspace}-"}eligibility_signposting_api" + source = "../../modules/lambda" + eligibility_lambda_role_arn = aws_iam_role.eligibility_lambda_role.arn + eligibility_lambda_role_name = aws_iam_role.eligibility_lambda_role.name + workspace = local.workspace + environment = var.environment + runtime = "python3.13" + lambda_func_name = "${terraform.workspace == "default" ? "" : "${terraform.workspace}-"}eligibility_signposting_api" security_group_ids = [data.aws_security_group.main_sg.id] - vpc_intra_subnets = [for v in data.aws_subnet.private_subnets : v.id] - file_name = "../../../dist/lambda.zip" - handler = "eligibility_signposting_api.app.lambda_handler" - eligibility_rules_bucket_name = module.s3_rules_bucket.storage_bucket_name - eligibility_status_table_name = module.eligibility_status_table.table_name - kinesis_audit_stream_to_s3_name = module.eligibility_audit_firehose_delivery_stream.firehose_stream_name - hashing_secret_name = module.secrets_manager.aws_hashing_secret_name - lambda_insights_extension_version = 38 - log_level = "INFO" - enable_xray_patching = "true" - stack_name = local.stack_name - provisioned_concurrency_count = 5 - api_domain_name = local.api_domain_name + vpc_intra_subnets = [for v in data.aws_subnet.private_subnets : v.id] + file_name = "../../../dist/lambda.zip" + handler = "eligibility_signposting_api.app.lambda_handler" + eligibility_rules_bucket_name = module.s3_rules_bucket.storage_bucket_name + eligibility_consumer_mappings_bucket_name = module.s3_consumer_mappings_bucket.storage_bucket_name + eligibility_status_table_name = module.eligibility_status_table.table_name + kinesis_audit_stream_to_s3_name = module.eligibility_audit_firehose_delivery_stream.firehose_stream_name + hashing_secret_name = module.secrets_manager.aws_hashing_secret_name + lambda_insights_extension_version = 38 + log_level = "INFO" + enable_xray_patching = "true" + stack_name = local.stack_name + provisioned_concurrency_count = 5 + api_domain_name = local.api_domain_name } # ----------------------------------------------------------------------------- diff --git a/infrastructure/stacks/api-layer/s3_buckets.tf b/infrastructure/stacks/api-layer/s3_buckets.tf index 1a94f7284..276e71354 100644 --- a/infrastructure/stacks/api-layer/s3_buckets.tf +++ b/infrastructure/stacks/api-layer/s3_buckets.tf @@ -7,6 +7,15 @@ module "s3_rules_bucket" { workspace = terraform.workspace } +module "s3_consumer_mappings_bucket" { + source = "../../modules/s3" + bucket_name = "eli-consumer-map" + environment = var.environment + project_name = var.project_name + stack_name = local.stack_name + workspace = terraform.workspace +} + module "s3_audit_bucket" { source = "../../modules/s3" bucket_name = "eli-audit" diff --git a/src/eligibility_signposting_api/common/api_error_response.py b/src/eligibility_signposting_api/common/api_error_response.py index 90d8d1909..2fe98bbe3 100644 --- a/src/eligibility_signposting_api/common/api_error_response.py +++ b/src/eligibility_signposting_api/common/api_error_response.py @@ -135,3 +135,11 @@ def log_and_generate_response( fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED, fhir_display_message="Access has been denied to process this request.", ) + +CONSUMER_ID_NOT_PROVIDED_ERROR = APIErrorResponse( + status_code=HTTPStatus.FORBIDDEN, + fhir_issue_code=FHIRIssueCode.FORBIDDEN, + fhir_issue_severity=FHIRIssueSeverity.ERROR, + fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED, + fhir_display_message="Access has been denied to process this request.", +) diff --git a/src/eligibility_signposting_api/common/request_validator.py b/src/eligibility_signposting_api/common/request_validator.py index 79e967c3c..75de288fd 100644 --- a/src/eligibility_signposting_api/common/request_validator.py +++ b/src/eligibility_signposting_api/common/request_validator.py @@ -7,12 +7,13 @@ from flask.typing import ResponseReturnValue from eligibility_signposting_api.common.api_error_response import ( + CONSUMER_ID_NOT_PROVIDED_ERROR, INVALID_CATEGORY_ERROR, INVALID_CONDITION_FORMAT_ERROR, INVALID_INCLUDE_ACTIONS_ERROR, NHS_NUMBER_ERROR, ) -from eligibility_signposting_api.config.constants import NHS_NUMBER_HEADER +from eligibility_signposting_api.config.constants import CONSUMER_ID, NHS_NUMBER_HEADER logger = logging.getLogger(__name__) @@ -50,6 +51,13 @@ def validate_request_params() -> Callable: def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> ResponseReturnValue: # noqa:ANN002,ANN003 + consumer_id = request.headers.get(CONSUMER_ID) + if not consumer_id: + message = "You are not authorised to request" + return CONSUMER_ID_NOT_PROVIDED_ERROR.log_and_generate_response( + log_message=message, diagnostics=message + ) + path_nhs_number = str(kwargs.get("nhs_number")) if kwargs.get("nhs_number") else None if not path_nhs_number: diff --git a/src/eligibility_signposting_api/config/config.py b/src/eligibility_signposting_api/config/config.py index 6be1840aa..52f3111cc 100644 --- a/src/eligibility_signposting_api/config/config.py +++ b/src/eligibility_signposting_api/config/config.py @@ -22,6 +22,7 @@ def config() -> dict[str, Any]: person_table_name = TableName(os.getenv("PERSON_TABLE_NAME", "test_eligibility_datastore")) rules_bucket_name = BucketName(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket")) + consumer_mapping_bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket")) audit_bucket_name = BucketName(os.getenv("AUDIT_BUCKET_NAME", "test-audit-bucket")) hashing_secret_name = HashSecretName(os.getenv("HASHING_SECRET_NAME", "test_secret")) aws_default_region = AwsRegion(os.getenv("AWS_DEFAULT_REGION", "eu-west-1")) @@ -41,6 +42,7 @@ def config() -> dict[str, Any]: "s3_endpoint": None, "rules_bucket_name": rules_bucket_name, "audit_bucket_name": audit_bucket_name, + "consumer_mapping_bucket_name": consumer_mapping_bucket_name, "firehose_endpoint": None, "kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3, "enable_xray_patching": enable_xray_patching, @@ -59,6 +61,7 @@ def config() -> dict[str, Any]: "s3_endpoint": URL(os.getenv("S3_ENDPOINT", local_stack_endpoint)), "rules_bucket_name": rules_bucket_name, "audit_bucket_name": audit_bucket_name, + "consumer_mapping_bucket_name": consumer_mapping_bucket_name, "firehose_endpoint": URL(os.getenv("FIREHOSE_ENDPOINT", local_stack_endpoint)), "kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3, "enable_xray_patching": enable_xray_patching, diff --git a/src/eligibility_signposting_api/config/constants.py b/src/eligibility_signposting_api/config/constants.py index 3aa45fd35..6deb0b4ec 100644 --- a/src/eligibility_signposting_api/config/constants.py +++ b/src/eligibility_signposting_api/config/constants.py @@ -3,4 +3,5 @@ URL_PREFIX = "patient-check" RULE_STOP_DEFAULT = False NHS_NUMBER_HEADER = "nhs-login-nhs-number" +CONSUMER_ID = "nhsd-application-id" # "Nhsd-Application-Id" ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"] diff --git a/src/eligibility_signposting_api/model/consumer_mapping.py b/src/eligibility_signposting_api/model/consumer_mapping.py new file mode 100644 index 000000000..a0ee8d415 --- /dev/null +++ b/src/eligibility_signposting_api/model/consumer_mapping.py @@ -0,0 +1,17 @@ +from typing import NewType + +from pydantic import BaseModel, Field, RootModel + +from eligibility_signposting_api.model.campaign_config import CampaignID + +ConsumerId = NewType("ConsumerId", str) + + +class ConsumerCampaign(BaseModel): + campaign_config_id: CampaignID = Field(alias="CampaignConfigID") + description: str | None = Field(default=None, alias="Description") + + +class ConsumerMapping(RootModel[dict[ConsumerId, list[ConsumerCampaign]]]): + def get(self, key: ConsumerId, default: list[ConsumerCampaign] | None = None) -> list[ConsumerCampaign] | None: + return self.root.get(key, default) diff --git a/src/eligibility_signposting_api/repos/consumer_mapping_repo.py b/src/eligibility_signposting_api/repos/consumer_mapping_repo.py new file mode 100644 index 000000000..122291aac --- /dev/null +++ b/src/eligibility_signposting_api/repos/consumer_mapping_repo.py @@ -0,0 +1,41 @@ +import json +from typing import Annotated, NewType + +from botocore.client import BaseClient +from wireup import Inject, service + +from eligibility_signposting_api.model.campaign_config import CampaignID +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping + +BucketName = NewType("BucketName", str) + + +@service +class ConsumerMappingRepo: + """Repository class for Consumer Mapping""" + + def __init__( + self, + s3_client: Annotated[BaseClient, Inject(qualifier="s3")], + bucket_name: Annotated[BucketName, Inject(param="consumer_mapping_bucket_name")], + ) -> None: + super().__init__() + self.s3_client = s3_client + self.bucket_name = bucket_name + + def get_permitted_campaign_ids(self, consumer_id: ConsumerId) -> list[CampaignID] | None: + objects = self.s3_client.list_objects(Bucket=self.bucket_name).get("Contents") + + if not objects: + return None + + consumer_mappings_obj = objects[0] + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=consumer_mappings_obj["Key"]) + body = response["Body"].read() + + mapping_result = ConsumerMapping.model_validate(json.loads(body)).get(consumer_id) + + if mapping_result is None: + return None + + return [item.campaign_config_id for item in mapping_result] diff --git a/src/eligibility_signposting_api/services/eligibility_services.py b/src/eligibility_signposting_api/services/eligibility_services.py index 79934e174..13b701d61 100644 --- a/src/eligibility_signposting_api/services/eligibility_services.py +++ b/src/eligibility_signposting_api/services/eligibility_services.py @@ -3,7 +3,10 @@ from wireup import service from eligibility_signposting_api.model import eligibility_status +from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId from eligibility_signposting_api.repos import CampaignRepo, NotFoundError, PersonRepo +from eligibility_signposting_api.repos.consumer_mapping_repo import ConsumerMappingRepo from eligibility_signposting_api.services.calculators import eligibility_calculator as calculator logger = logging.getLogger(__name__) @@ -23,12 +26,14 @@ def __init__( self, person_repo: PersonRepo, campaign_repo: CampaignRepo, + consumer_mapping_repo: ConsumerMappingRepo, calculator_factory: calculator.EligibilityCalculatorFactory, ) -> None: super().__init__() self.person_repo = person_repo self.campaign_repo = campaign_repo self.calculator_factory = calculator_factory + self.consumer_mapping = consumer_mapping_repo def get_eligibility_status( self, @@ -36,16 +41,33 @@ def get_eligibility_status( include_actions: str, conditions: list[str], category: str, + consumer_id: str, ) -> eligibility_status.EligibilityStatus: """Calculate a person's eligibility for vaccination given an NHS number.""" if nhs_number: try: person_data = self.person_repo.get_eligibility_data(nhs_number) - campaign_configs = list(self.campaign_repo.get_campaign_configs()) except NotFoundError as e: raise UnknownPersonError from e else: - calc: calculator.EligibilityCalculator = self.calculator_factory.get(person_data, campaign_configs) + campaign_configs: list[CampaignConfig] = list(self.campaign_repo.get_campaign_configs()) + permitted_campaign_configs = self.__collect_permitted_campaign_configs( + campaign_configs, ConsumerId(consumer_id) + ) + calc: calculator.EligibilityCalculator = self.calculator_factory.get( + person_data, permitted_campaign_configs + ) return calc.get_eligibility_status(include_actions, conditions, category) raise UnknownPersonError # pragma: no cover + + def __collect_permitted_campaign_configs( + self, campaign_configs: list[CampaignConfig], consumer_id: ConsumerId + ) -> list[CampaignConfig]: + permitted_campaign_ids = self.consumer_mapping.get_permitted_campaign_ids(ConsumerId(consumer_id)) + if permitted_campaign_ids: + permitted_campaign_configs: list[CampaignConfig] = [ + campaign for campaign in campaign_configs if campaign.id in permitted_campaign_ids + ] + return permitted_campaign_configs + return [] diff --git a/src/eligibility_signposting_api/views/eligibility.py b/src/eligibility_signposting_api/views/eligibility.py index eb2b706ea..b935678f6 100644 --- a/src/eligibility_signposting_api/views/eligibility.py +++ b/src/eligibility_signposting_api/views/eligibility.py @@ -11,9 +11,12 @@ from eligibility_signposting_api.audit.audit_context import AuditContext from eligibility_signposting_api.audit.audit_service import AuditService -from eligibility_signposting_api.common.api_error_response import NHS_NUMBER_NOT_FOUND_ERROR +from eligibility_signposting_api.common.api_error_response import ( + NHS_NUMBER_NOT_FOUND_ERROR, +) from eligibility_signposting_api.common.request_validator import validate_request_params -from eligibility_signposting_api.config.constants import URL_PREFIX +from eligibility_signposting_api.config.constants import CONSUMER_ID, URL_PREFIX +from eligibility_signposting_api.model.consumer_mapping import ConsumerId from eligibility_signposting_api.model.eligibility_status import Condition, EligibilityStatus, NHSNumber, Status from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.views.response_model import eligibility_response @@ -47,13 +50,17 @@ def check_eligibility( nhs_number: NHSNumber, eligibility_service: Injected[EligibilityService], audit_service: Injected[AuditService] ) -> ResponseReturnValue: logger.info("checking nhs_number %r in %r", nhs_number, eligibility_service, extra={"nhs_number": nhs_number}) + + query_params = _get_or_default_query_params() + consumer_id = _get_consumer_id_from_headers() + try: - query_params = get_or_default_query_params() eligibility_status = eligibility_service.get_eligibility_status( nhs_number, query_params["includeActions"], query_params["conditions"], query_params["category"], + consumer_id, ) except UnknownPersonError: return handle_unknown_person_error(nhs_number) @@ -63,7 +70,14 @@ def check_eligibility( return make_response(response.model_dump(by_alias=True, mode="json", exclude_none=True), HTTPStatus.OK) -def get_or_default_query_params() -> dict[str, Any]: +def _get_consumer_id_from_headers() -> ConsumerId: + """ + @validate_request_params() ensures the consumer ID is never null at this stage. + """ + return ConsumerId(request.headers.get(CONSUMER_ID, "")) + + +def _get_or_default_query_params() -> dict[str, Any]: default_query_params = {"category": "ALL", "conditions": ["ALL"], "includeActions": "Y"} if not request.args: diff --git a/tests/fixtures/builders/model/rule.py b/tests/fixtures/builders/model/rule.py index bf62de900..2793ea032 100644 --- a/tests/fixtures/builders/model/rule.py +++ b/tests/fixtures/builders/model/rule.py @@ -93,7 +93,7 @@ class IterationFactory(ModelFactory[Iteration]): class RawCampaignConfigFactory(ModelFactory[CampaignConfig]): iterations = Use(IterationFactory.batch, size=2) - + id = "42-hi5tch-hi5kers-gu5ide-t2o-t3he-gal6axy" start_date = Use(past_date) end_date = Use(future_date) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e76dc9c6f..9ba0968d3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -30,6 +30,7 @@ StartDate, StatusText, ) +from eligibility_signposting_api.model.consumer_mapping import ConsumerCampaign, ConsumerId, ConsumerMapping from eligibility_signposting_api.processors.hashing_service import HashingService, HashSecretName from eligibility_signposting_api.repos import SecretRepo from eligibility_signposting_api.repos.campaign_repo import BucketName @@ -49,6 +50,8 @@ AWS_CURRENT_SECRET = "test_value" # noqa: S105 AWS_PREVIOUS_SECRET = "test_value_old" # noqa: S105 +UNIQUE_CONSUMER_HEADER = "nhsd-application-id" + @pytest.fixture(scope="session") def localstack(request: pytest.FixtureRequest) -> URL: @@ -690,6 +693,22 @@ def rules_bucket(s3_client: BaseClient) -> Generator[BucketName]: bucket_name = BucketName(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket")) s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) yield bucket_name + response = s3_client.list_objects_v2(Bucket=bucket_name) + if "Contents" in response: + objects_to_delete = [{"Key": obj["Key"]} for obj in response["Contents"]] + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": objects_to_delete}) + s3_client.delete_bucket(Bucket=bucket_name) + + +@pytest.fixture(scope="session") +def consumer_mapping_bucket(s3_client: BaseClient) -> Generator[BucketName]: + bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket")) + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + yield bucket_name + response = s3_client.list_objects_v2(Bucket=bucket_name) + if "Contents" in response: + objects_to_delete = [{"Key": obj["Key"]} for obj in response["Contents"]] + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": objects_to_delete}) s3_client.delete_bucket(Bucket=bucket_name) @@ -722,7 +741,7 @@ def firehose_delivery_stream(firehose_client: BaseClient, audit_bucket: BucketNa @pytest.fixture(scope="class") -def campaign_config(s3_client: BaseClient, rules_bucket: BucketName) -> Generator[CampaignConfig]: +def rsv_campaign_config(s3_client: BaseClient, rules_bucket: BucketName) -> Generator[CampaignConfig]: campaign: CampaignConfig = rule.CampaignConfigFactory.build( target="RSV", iterations=[ @@ -1263,6 +1282,307 @@ def campaign_config_with_missing_descriptions_missing_rule_text( s3_client.delete_object(Bucket=rules_bucket, Key=f"{campaign.name}.json") +@pytest.fixture +def campaign_configs(request, s3_client: BaseClient, rules_bucket: BucketName) -> Generator[list[CampaignConfig]]: + """Create and upload multiple campaign configs to S3, then clean up after tests.""" + campaigns, campaign_data_keys = [], [] # noqa: F841 + + raw = getattr( + request, "param", [("RSV", "RSV_campaign_id"), ("COVID", "COVID_campaign_id"), ("FLU", "FLU_campaign_id")] + ) + + targets = [] + campaign_id = [] + status = [] + + for t, _id, *rest in raw: + targets.append(t) + campaign_id.append(_id) + status.append(rest[0] if rest else None) + + for i in range(len(targets)): + campaign: CampaignConfig = rule.CampaignConfigFactory.build( + name=f"campaign_{i}", + id=campaign_id[i], + target=targets[i], + type="V", + iterations=[ + rule.IterationFactory.build( + iteration_rules=[ + rule.PostcodeSuppressionRuleFactory.build(type=RuleType.filter), + rule.PersonAgeSuppressionRuleFactory.build(), + rule.PersonAgeSuppressionRuleFactory.build(name="Exclude 76 rolling", description=""), + ], + iteration_cohorts=[ + rule.IterationCohortFactory.build( + cohort_label="cohort1", + cohort_group="cohort_group1", + positive_description="", + negative_description="", + ) + ], + status_text=None, + ) + ], + ) + + if status[i] == "inactive": + campaign.iterations[0].iteration_date = datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(days=7) + + campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)} + key = f"{campaign.name}.json" + s3_client.put_object( + Bucket=rules_bucket, Key=key, Body=json.dumps(campaign_data), ContentType="application/json" + ) + campaign_id.append(campaign) + campaign_data_keys.append(key) + + yield campaign_id + + for key in campaign_data_keys: + s3_client.delete_object(Bucket=rules_bucket, Key=key) + + +@pytest.fixture(scope="class") +def consumer_id() -> ConsumerId: + return ConsumerId("23-mic7heal-jor6don") + + +def create_and_put_consumer_mapping_in_s3( + campaign_config: CampaignConfig, consumer_id: str, consumer_mapping_bucket, s3_client +) -> ConsumerMapping: + consumer_mapping = ConsumerMapping.model_validate({}) + campaign_entry = ConsumerCampaign( + CampaignConfigID=campaign_config.id, Description="Test description for campaign mapping" + ) + + consumer_mapping.root[ConsumerId(consumer_id)] = [campaign_entry] + consumer_mapping_data = consumer_mapping.model_dump(by_alias=True) + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(consumer_mapping_data), + ContentType="application/json", + ) + return consumer_mapping + + +@pytest.fixture(scope="class") +def consumer_to_active_campaign_having_invalid_tokens_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_invalid_tokens: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_invalid_tokens, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_to_active_campaign_having_tokens_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_tokens: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_tokens, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_to_active_rsv_campaign_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + rsv_campaign_config: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + rsv_campaign_config, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_to_active_campaign_having_and_rule_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_and_rule: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_and_rule, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_missing_descriptions_missing_rule_text, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_having_rules_with_rule_code_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_rules_having_rule_code: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_rules_having_rule_code, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_having_rules_with_rule_mapper_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_rules_having_rule_mapper: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_rules_having_rule_mapper, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_having_only_virtual_cohort_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_virtual_cohort: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_virtual_cohort, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_config_with_derived_values_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_derived_values: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_derived_values, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_config_with_derived_values_formatted_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_derived_values_formatted: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_derived_values_formatted, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_active_campaign_config_with_multiple_add_days_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_multiple_add_days: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_multiple_add_days, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_to_campaign_having_inactive_iteration_mapping( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + inactive_iteration_config: list[CampaignConfig], + consumer_id: ConsumerId, +): + mapping = ConsumerMapping.model_validate({}) + mapping.root[consumer_id] = [ + ConsumerCampaign(CampaignConfigID=cc.id, Description=f"Description for {cc.id}") + for cc in inactive_iteration_config + ] + + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(mapping.model_dump(by_alias=True)), + ContentType="application/json", + ) + yield mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_to_multiple_campaign_configs_mapping( + multiple_campaign_configs: list[CampaignConfig], + consumer_id: ConsumerId, + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, +) -> Generator[ConsumerMapping]: + mapping = ConsumerMapping.model_validate({}) + mapping.root[consumer_id] = [ + ConsumerCampaign(CampaignConfigID=cc.id, Description=f"Description for {cc.id}") + for cc in multiple_campaign_configs + ] + + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(mapping.model_dump(by_alias=True)), + ContentType="application/json", + ) + yield mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mappings( + request, s3_client: BaseClient, consumer_mapping_bucket: BucketName +) -> Generator[ConsumerMapping]: + consumer_mapping = ConsumerMapping.model_validate(getattr(request, "param", {})) + consumer_mapping_data = consumer_mapping.model_dump(by_alias=True) + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(consumer_mapping_data), + ContentType="application/json", + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + # If you put StubSecretRepo in a separate module, import it instead class StubSecretRepo(SecretRepo): # def __init__(self, current: str = AWS_CURRENT_SECRET, previous: str = AWS_PREVIOUS_SECRET): diff --git a/tests/integration/in_process/test_derived_values.py b/tests/integration/in_process/test_derived_values.py index 29574e921..2996b0daa 100644 --- a/tests/integration/in_process/test_derived_values.py +++ b/tests/integration/in_process/test_derived_values.py @@ -43,8 +43,9 @@ none, ) -from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping from eligibility_signposting_api.model.eligibility_status import NHSNumber +from tests.integration.conftest import UNIQUE_CONSUMER_HEADER class TestDerivedValues: @@ -54,7 +55,8 @@ def test_add_days_calculates_next_dose_due_from_last_successful_date( self, client: FlaskClient, person_with_covid_vaccination: NHSNumber, - campaign_config_with_derived_values: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_campaign_config_with_derived_values_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): """ @@ -77,7 +79,7 @@ def test_add_days_calculates_next_dose_due_from_last_successful_date( ] """ # Given - headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination)} + headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination), UNIQUE_CONSUMER_HEADER: str(consumer_id)} # When response = client.get( @@ -126,7 +128,8 @@ def test_add_days_with_formatted_date_output( self, client: FlaskClient, person_with_covid_vaccination: NHSNumber, - campaign_config_with_derived_values_formatted: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_campaign_config_with_derived_values_formatted_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): """ @@ -140,7 +143,7 @@ def test_add_days_with_formatted_date_output( - DateOfNextEarliestVaccination shows "29 April 2026" (formatted output) """ # Given - headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination)} + headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination), UNIQUE_CONSUMER_HEADER: str(consumer_id)} # When response = client.get( @@ -179,7 +182,8 @@ class TestMultipleActionsWithAddDays: def test_multiple_actions_with_different_add_days_parameters( self, client: FlaskClient, - campaign_config_with_multiple_add_days: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_campaign_config_with_multiple_add_days_mapping: ConsumerMapping, # noqa: ARG002 person_with_covid_vaccination: NHSNumber, ): """ @@ -207,7 +211,7 @@ def test_multiple_actions_with_different_add_days_parameters( function with different parameters. """ # Given - headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination)} + headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination), UNIQUE_CONSUMER_HEADER: str(consumer_id)} # When response = client.get( diff --git a/tests/integration/in_process/test_eligibility_endpoint.py b/tests/integration/in_process/test_eligibility_endpoint.py index 6f01e3147..7e25c2813 100644 --- a/tests/integration/in_process/test_eligibility_endpoint.py +++ b/tests/integration/in_process/test_eligibility_endpoint.py @@ -1,3 +1,4 @@ +import json from http import HTTPStatus import pytest @@ -8,16 +9,21 @@ from hamcrest import ( assert_that, contains_exactly, + contains_inanyorder, equal_to, has_entries, has_entry, has_key, ) -from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.campaign_config import CampaignConfig, RuleComparator +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping from eligibility_signposting_api.model.eligibility_status import ( NHSNumber, ) +from eligibility_signposting_api.repos.campaign_repo import BucketName +from tests.fixtures.builders.model import rule +from tests.integration.conftest import UNIQUE_CONSUMER_HEADER class TestBaseLine: @@ -25,11 +31,12 @@ def test_nhs_number_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}", headers=headers) @@ -50,13 +57,15 @@ def test_nhs_number_given_in_path_but_no_nhs_number_header_present( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 headers: dict, ): # Given # When - response = client.get(f"/patient-check/{persisted_person}", headers=headers) + response = client.get( + f"/patient-check/{persisted_person}", headers={UNIQUE_CONSUMER_HEADER: "some-id"} | headers + ) # Then assert_that( @@ -71,16 +80,19 @@ def test_nhs_number_given_in_path_but_no_nhs_number_header_present( {"nhs-login-nhs-number": ""}, # header present but blank, invalid ], ) - def test_nhs_number_in_path_and_header_present_but_empty_or_none( + def test_nhs_number_in_path_and_header_present_but_empty_or_none( # noqa: PLR0913 self, headers: dict, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): # When - response = client.get(f"/patient-check/{persisted_person}", headers=headers) + response = client.get( + f"/patient-check/{persisted_person}", headers={UNIQUE_CONSUMER_HEADER: consumer_id} | headers + ) # Then assert_that( @@ -94,14 +106,17 @@ def test_nhs_number_given_but_header_nhs_number_doesnt_match( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given headers = {"nhs-login-nhs-number": f"123{persisted_person!s}"} # When - response = client.get(f"/patient-check/{persisted_person}", headers=headers) + response = client.get( + f"/patient-check/{persisted_person}", headers={UNIQUE_CONSUMER_HEADER: consumer_id} | headers + ) # Then assert_that( @@ -129,7 +144,6 @@ def test_no_nhs_number_given_but_header_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 ): # Given headers = {"nhs-login-nhs-number": str(persisted_person)} @@ -151,10 +165,12 @@ def test_not_base_eligible( self, client: FlaskClient, persisted_person_no_cohorts: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts)} + headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_no_cohorts}?includeActions=Y", headers=headers) @@ -195,10 +211,12 @@ def test_not_eligible_by_rule( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -239,10 +257,12 @@ def test_not_actionable_and_check_response_when_no_rule_code_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -289,9 +309,11 @@ def test_actionable( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -340,10 +362,12 @@ def test_actionable_with_and_rule( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_and_rule: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_to_active_campaign_having_and_rule_mapping: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -394,10 +418,12 @@ def test_not_eligible_by_rule_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_having_only_virtual_cohort_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -438,10 +464,12 @@ def test_not_actionable_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_having_only_virtual_cohort_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -488,10 +516,12 @@ def test_actionable_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_having_only_virtual_cohort_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -542,10 +572,12 @@ def test_not_base_eligible( self, client: FlaskClient, persisted_person_no_cohorts: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts)} + headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_no_cohorts}?includeActions=Y", headers=headers) @@ -580,10 +612,12 @@ def test_not_eligible_by_rule( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -618,10 +652,12 @@ def test_not_actionable( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -662,10 +698,12 @@ def test_actionable( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -708,10 +746,12 @@ def test_actionable_no_actions( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_missing_descriptions_and_rule_text_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=N", headers=headers) @@ -782,10 +822,12 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_absent_but_rule_c self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_rules_having_rule_code: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_having_rules_with_rule_code_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -832,10 +874,12 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_rules_having_rule_mapper: CampaignConfig, # noqa: ARG002 + consumer_to_active_campaign_having_rules_with_rule_mapper_mapping: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -877,3 +921,461 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_given( ) ), ) + + @pytest.mark.parametrize( + ( + "campaign_configs", + "consumer_mappings", + "consumer_id", + "requested_conditions", + "requested_category", + "expected_targets", + ), + [ + # ============================================================ + # Group 1: Consumer is mapped, campaign exists in S3, requesting + # ============================================================ + # 1.1 Consumer is mapped; multiple active campaigns exist; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "ALL", + "VACCINATIONS", + ["RSV", "COVID"], + ), + # 1.2 Consumer is mapped; requested single campaign exists and is mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV", + "VACCINATIONS", + ["RSV"], + ), + # 1.3 Consumer is mapped; requested multiple campaigns exist and are mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV,COVID", + "VACCINATIONS", + ["RSV", "COVID"], + ), + # ============================================================ + # Group 2: Consumer is mapped, campaign does NOT exist in S3 + # ============================================================ + # 2.1 Consumer is mapped; requested campaign exists in S3 but not mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "FLU", + "VACCINATIONS", + [], + ), + # 2.2 Consumer is mapped, but none of the mapped campaigns exist in S3 + ( + [ + ("MMR", "MMR_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 2.3 Consumer is mapped; requested mapped campaign does not exist in S3 + ( + [ + ("MMR", "MMR_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + # ============================================================ + # Group 3: Consumer is NOT mapped, campaign exists in S3 + # ============================================================ + # 3.1 Consumer not mapped; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "another-consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 3.2 Consumer not mapped; requesting specific campaign + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "another-consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + # ============================================================ + # Group 4: Consumer NOT mapped, campaign does NOT exist in S3 + # ============================================================ + # 4.1 Consumer mapped; requested campaign does not exist + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"CampaignConfigID": "RSV_campaign_id"}, + {"CampaignConfigID": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "HPV", + "VACCINATIONS", + [], + ), + # 4.2 No consumer mappings; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + {}, + "consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 4.3 No consumer mappings; requesting specific campaign + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + {}, + "consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + ], + indirect=["campaign_configs", "consumer_mappings"], + ) + def test_valid_response_when_consumer_has_a_valid_campaign_config_mapping( # noqa: PLR0913 + self, + client: FlaskClient, + persisted_person: NHSNumber, + secretsmanager_client: BaseClient, # noqa: ARG002 + campaign_configs: CampaignConfig, # noqa: ARG002 + consumer_mappings: ConsumerMapping, # noqa: ARG002 + consumer_id: str, + requested_conditions: str, + requested_category: str, + expected_targets: list[str], + ): + # Given + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} + + # When + response = client.get( + f"/patient-check/{persisted_person}?includeActions=Y&category={requested_category}&conditions={requested_conditions}", + headers=headers, + ) + + assert_that( + response, + is_response() + .with_status_code(HTTPStatus.OK) + .and_text( + is_json_that( + has_entry( + "processedSuggestions", + # This ensures ONLY these items exist, no extras like FLU + contains_inanyorder(*[has_entry("condition", i) for i in expected_targets]), + ) + ) + ), + ) + + @pytest.mark.parametrize( + ("consumer_id", "expected_campaign_id"), + [ + # Consumer requesting for ALL + # Consumer is mapped only to RSV_campaign_id_1 + ("consumer-id-1", "RSV_campaign_id_1"), + # Consumer is mapped only to RSV_campaign_id_2 + ("consumer-id-2", "RSV_campaign_id_2"), + # Edge-case : Consumer-id-3a is mapped to multiple active campaigns, so only one taken. + ("consumer-id-3a", "RSV_campaign_id_3"), + # Edge-case : Consumer-id-3b is mapped to multiple active campaigns, so only one taken. + ("consumer-id-3b", "RSV_campaign_id_3"), + # Edge-case : Consumer is mapped to inactive inactive_RSV_campaign_id_5 and active RSV_campaign_id_6 + ("consumer-id-4", "RSV_campaign_id_6"), + # Edge-case : Consumer is mapped only to inactive RSV_campaign_id_5 + ("consumer-id-5", None), + ], + ) + @pytest.mark.parametrize( + ("campaign_configs", "consumer_mappings", "requested_conditions", "requested_category"), + # consumer mappings and campaign configs are static here + [ + ( + [ + # Campaign configs in S3 + ("RSV", "RSV_campaign_id_1"), + ("RSV", "RSV_campaign_id_2"), + ("RSV", "RSV_campaign_id_3"), + ("RSV", "RSV_campaign_id_4"), + ("RSV", "inactive_RSV_campaign_id_5", "inactive"), # inactive iteration + ("RSV", "RSV_campaign_id_6"), + ], + { + # Consumer mappings in S3 + "consumer-id-1": [{"CampaignConfigID": "RSV_campaign_id_1"}], + "consumer-id-2": [{"CampaignConfigID": "RSV_campaign_id_2"}], + "consumer-id-3a": [ + {"CampaignConfigID": "RSV_campaign_id_3"}, + {"CampaignConfigID": "RSV_campaign_id_4"}, + ], + "consumer-id-3b": [ + {"CampaignConfigID": "RSV_campaign_id_4"}, + {"CampaignConfigID": "RSV_campaign_id_3"}, + ], + "consumer-id-4": [ + {"CampaignConfigID": "inactive_RSV_campaign_id_5"}, + {"CampaignConfigID": "RSV_campaign_id_6"}, + ], + "consumer-id-5": [{"CampaignConfigID": "inactive_RSV_campaign_id_5"}], + }, + "RSV", + "VACCINATIONS", + ) + ], + indirect=["campaign_configs", "consumer_mappings"], + ) + def test_if_correct_campaign_is_chosen_for_the_consumer_when_multiple_campaign_exists_per_target_giving_same_status( # noqa : PLR0913 + self, + client: FlaskClient, + persisted_person: NHSNumber, + secretsmanager_client: BaseClient, # noqa: ARG002 + audit_bucket: BucketName, + s3_client: BaseClient, + campaign_configs: CampaignConfig, # noqa: ARG002 + consumer_mappings: ConsumerMapping, # noqa: ARG002 + consumer_id: str, + requested_conditions: str, + requested_category: str, + expected_campaign_id: list[str], + ): + # Given + headers = {"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id} + + # When + client.get( + f"/patient-check/{persisted_person}?includeActions=Y&category={requested_category}&conditions={requested_conditions}", + headers=headers, + ) + + objects = s3_client.list_objects_v2(Bucket=audit_bucket).get("Contents", []) + object_keys = [obj["Key"] for obj in objects] + latest_key = sorted(object_keys)[-1] + audit_data = json.loads(s3_client.get_object(Bucket=audit_bucket, Key=latest_key)["Body"].read()) + + # Then + if expected_campaign_id is not None: + assert_that(len(audit_data["response"]["condition"]), equal_to(1)) + assert_that(audit_data["response"]["condition"][0].get("campaignId"), equal_to(expected_campaign_id)) + else: + assert_that(len(audit_data["response"]["condition"]), equal_to(0)) + + def test_if_campaign_having_best_status_is_chosen_if_there_exists_multiple_campaign_per_target( # noqa : PLR0913 + self, + client: FlaskClient, + persisted_person_pc_sw19: NHSNumber, + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + rules_bucket: BucketName, + secretsmanager_client: BaseClient, # noqa: ARG002 + ): + # Given + consumer_id = "consumer-n3bs-jo4hn-ce4na" + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), UNIQUE_CONSUMER_HEADER: consumer_id} + + # Consumer Mapping Data + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps( + { + consumer_id: [ + {"CampaignConfigID": "RSV_campaign_id_not_actionable"}, + {"CampaignConfigID": "RSV_campaign_id_actionable"}, + ], + } + ), + ContentType="application/json", + ) + + # Campaign configs + campaign_1 = rule.CampaignConfigFactory.build( + id="RSV_campaign_id_not_actionable", + target="RSV", + type="V", + iterations=[ + rule.IterationFactory.build( + iteration_rules=[ + rule.PostcodeSuppressionRuleFactory.build(name="Exclude SW19", description=""), + ], + iteration_cohorts=[ + rule.IterationCohortFactory.build( + cohort_label="cohort1", + cohort_group="cohort_group1", + positive_description="positive_description", + ) + ], + status_text=None, + ) + ], + ) + + campaign_2 = rule.CampaignConfigFactory.build( + id="RSV_campaign_id_actionable", + target="RSV", + type="V", + iterations=[ + rule.IterationFactory.build( + iteration_rules=[ + rule.PostcodeSuppressionRuleFactory.build(name="Exclude M4", comparator=RuleComparator("M4")), + ], + iteration_cohorts=[ + rule.IterationCohortFactory.build( + cohort_label="cohort1", + cohort_group="cohort_group1", + positive_description="positive_description", + ) + ], + status_text=None, + ) + ], + ) + + for campaign in [campaign_1, campaign_2]: + s3_client.put_object( + Bucket=rules_bucket, + Key=f"{campaign.id}.json", + Body=json.dumps({"CampaignConfig": campaign.model_dump(by_alias=True)}), + ContentType="application/json", + ) + + # When + response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) + + # Then + assert_that( + response, + is_response() + .with_status_code(HTTPStatus.OK) + .and_text( + is_json_that( + has_entry( + "processedSuggestions", + equal_to( + [ + { + "condition": "RSV", + "status": "Actionable", + "eligibilityCohorts": [ + { + "cohortCode": "cohort_group1", + "cohortStatus": "Actionable", + "cohortText": "positive_description", + } + ], + "actions": [ + { + "actionCode": "action_code", + "actionType": "defaultcomms", + "description": "", + "urlLabel": "", + "urlLink": "", + } + ], + "suitabilityRules": [], + "statusText": "You should have the RSV vaccine", + } + ] + ), + ) + ) + ), + ) diff --git a/tests/integration/lambda/test_app_running_as_lambda.py b/tests/integration/lambda/test_app_running_as_lambda.py index 46572689a..09953f4ca 100644 --- a/tests/integration/lambda/test_app_running_as_lambda.py +++ b/tests/integration/lambda/test_app_running_as_lambda.py @@ -24,8 +24,10 @@ from yarl import URL from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping from eligibility_signposting_api.model.eligibility_status import NHSNumber from eligibility_signposting_api.repos.campaign_repo import BucketName +from tests.integration.conftest import UNIQUE_CONSUMER_HEADER logger = logging.getLogger(__name__) @@ -34,7 +36,8 @@ def test_install_and_call_lambda_flask( lambda_client: BaseClient, flask_function: str, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, ): """Given lambda installed into localstack, run it via boto3 lambda client""" # Given @@ -49,6 +52,7 @@ def test_install_and_call_lambda_flask( "accept": "application/json", "content-type": "application/json", "nhs-login-nhs-number": str(persisted_person), + UNIQUE_CONSUMER_HEADER: consumer_id, }, "pathParameters": {"id": str(persisted_person)}, "requestContext": { @@ -85,7 +89,8 @@ def test_install_and_call_lambda_flask( def test_install_and_call_flask_lambda_over_http( persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): """Given api-gateway and lambda installed into localstack, run it via http""" @@ -94,7 +99,7 @@ def test_install_and_call_flask_lambda_over_http( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(persisted_person)}, + headers={"nhs-login-nhs-number": str(persisted_person), UNIQUE_CONSUMER_HEADER: consumer_id}, timeout=10, ) @@ -105,12 +110,14 @@ def test_install_and_call_flask_lambda_over_http( ) -def test_install_and_call_flask_lambda_with_unknown_nhs_number( +def test_install_and_call_flask_lambda_with_unknown_nhs_number( # noqa: PLR0913 flask_function: str, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, logs_client: BaseClient, api_gateway_endpoint: URL, + secretsmanager_client: BaseClient, # noqa: ARG001 ): """Given lambda installed into localstack, run it via http, with a nonexistent NHS number specified""" # Given @@ -120,7 +127,7 @@ def test_install_and_call_flask_lambda_with_unknown_nhs_number( invoke_url = f"{api_gateway_endpoint}/patient-check/{nhs_number}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(nhs_number)}, + headers={"nhs-login-nhs-number": str(nhs_number), UNIQUE_CONSUMER_HEADER: consumer_id}, timeout=10, ) @@ -181,7 +188,9 @@ def get_log_messages(flask_function: str, logs_client: BaseClient) -> list[str]: def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_if_audited( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, + rsv_campaign_config: CampaignConfig, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -198,7 +207,7 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", - "nhsd_application_id": "nhsd_application_id", + "nhsd-application-id": consumer_id, }, params={"includeActions": "Y"}, timeout=10, @@ -220,19 +229,19 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i "xRequestId": "x_request_id", "xCorrelationId": "x_correlation_id", "nhsdEndUserOrganisationOds": "nhsd_end_user_organisation_ods", - "nhsdApplicationId": "nhsd_application_id", + "nhsdApplicationId": consumer_id, } expected_query_params = {"category": None, "conditions": None, "includeActions": "Y"} expected_conditions = [ { - "campaignId": campaign_config.id, - "campaignVersion": campaign_config.version, - "iterationId": campaign_config.iterations[0].id, - "iterationVersion": campaign_config.iterations[0].version, - "conditionName": campaign_config.target, + "campaignId": rsv_campaign_config.id, + "campaignVersion": rsv_campaign_config.version, + "iterationId": rsv_campaign_config.iterations[0].id, + "iterationVersion": rsv_campaign_config.iterations[0].version, + "conditionName": rsv_campaign_config.target, "status": "not_actionable", - "statusText": f"You should have the {campaign_config.target} vaccine", + "statusText": f"You should have the {rsv_campaign_config.target} vaccine", "eligibilityCohorts": [{"cohortCode": "cohort1", "cohortStatus": "not_actionable"}], "eligibilityCohortGroups": [ { @@ -277,7 +286,6 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_results_in_error_response( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -285,7 +293,7 @@ def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_resu invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": f"123{persisted_person!s}"}, + headers={"nhs-login-nhs-number": f"123{persisted_person!s}", UNIQUE_CONSUMER_HEADER: "test_consumer_id"}, timeout=10, ) @@ -324,7 +332,6 @@ def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_resu def test_given_nhs_number_not_present_in_headers_results_in_valid_for_application_restricted_users( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -332,6 +339,7 @@ def test_given_nhs_number_not_present_in_headers_results_in_valid_for_applicatio invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, + headers={UNIQUE_CONSUMER_HEADER: "test_consumer_id"}, timeout=10, ) @@ -344,7 +352,8 @@ def test_given_nhs_number_not_present_in_headers_results_in_valid_for_applicatio def test_given_nhs_number_key_present_in_headers_have_no_value_results_in_error_response( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 + consumer_id: ConsumerId, + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -352,7 +361,7 @@ def test_given_nhs_number_key_present_in_headers_have_no_value_results_in_error_ invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": ""}, + headers={"nhs-login-nhs-number": "", UNIQUE_CONSUMER_HEADER: consumer_id}, timeout=10, ) @@ -390,7 +399,8 @@ def test_given_nhs_number_key_present_in_headers_have_no_value_results_in_error_ def test_validation_of_query_params_when_all_are_valid( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 + consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): # Given @@ -398,7 +408,7 @@ def test_validation_of_query_params_when_all_are_valid( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": persisted_person}, + headers={"nhs-login-nhs-number": persisted_person, UNIQUE_CONSUMER_HEADER: consumer_id}, params={"category": "VACCINATIONS", "conditions": "COVID19", "includeActions": "N"}, timeout=10, ) @@ -410,7 +420,6 @@ def test_validation_of_query_params_when_all_are_valid( def test_validation_of_query_params_when_invalid_conditions_is_specified( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -418,7 +427,7 @@ def test_validation_of_query_params_when_invalid_conditions_is_specified( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": persisted_person}, + headers={"nhs-login-nhs-number": persisted_person, UNIQUE_CONSUMER_HEADER: "test_consumer_id"}, params={"category": "ALL", "conditions": "23-097"}, timeout=10, ) @@ -431,9 +440,12 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # lambda_client: BaseClient, # noqa:ARG001 persisted_person_all_cohorts: NHSNumber, multiple_campaign_configs: list[CampaignConfig], + consumer_to_multiple_campaign_configs_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, + secretsmanager_client: BaseClient, # noqa: ARG001 ): invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person_all_cohorts}" response = httpx.get( @@ -443,7 +455,7 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", - "nhsd_application_id": "nhsd_application_id", + "nhsd_application_id": consumer_id, }, params={"includeActions": "Y", "category": "VACCINATIONS", "conditions": "COVID,FLU,RSV"}, timeout=10, @@ -463,7 +475,7 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # "xRequestId": "x_request_id", "xCorrelationId": "x_correlation_id", "nhsdEndUserOrganisationOds": "nhsd_end_user_organisation_ods", - "nhsdApplicationId": "nhsd_application_id", + "nhsdApplicationId": consumer_id, } expected_query_params = {"category": "VACCINATIONS", "conditions": "COVID,FLU,RSV", "includeActions": "Y"} @@ -573,7 +585,8 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # def test_no_active_iteration_returns_empty_processed_suggestions( lambda_client: BaseClient, # noqa:ARG001 persisted_person_all_cohorts: NHSNumber, - inactive_iteration_config: list[CampaignConfig], # noqa:ARG001 + consumer_to_campaign_having_inactive_iteration_mapping: ConsumerMapping, # noqa:ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person_all_cohorts}" @@ -584,7 +597,7 @@ def test_no_active_iteration_returns_empty_processed_suggestions( "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", - "nhsd_application_id": "nhsd_application_id", + "nhsd_application_id": consumer_id, }, params={"includeActions": "Y", "category": "VACCINATIONS", "conditions": "COVID,FLU,RSV"}, timeout=10, @@ -609,7 +622,8 @@ def test_no_active_iteration_returns_empty_processed_suggestions( def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 person_with_all_data: NHSNumber, - campaign_config_with_tokens: CampaignConfig, # noqa:ARG001 + consumer_to_active_campaign_having_tokens_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -619,7 +633,7 @@ def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 invoke_url = f"{api_gateway_endpoint}/patient-check/{person_with_all_data}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(person_with_all_data)}, + headers={"nhs-login-nhs-number": str(person_with_all_data), UNIQUE_CONSUMER_HEADER: consumer_id}, timeout=10, ) @@ -659,7 +673,8 @@ def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 def test_incorrect_token_causes_internal_server_error( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 person_with_all_data: NHSNumber, - campaign_config_with_invalid_tokens: CampaignConfig, # noqa:ARG001 + consumer_to_active_campaign_having_invalid_tokens_mapping: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -669,7 +684,7 @@ def test_incorrect_token_causes_internal_server_error( # noqa: PLR0913 invoke_url = f"{api_gateway_endpoint}/patient-check/{person_with_all_data}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(person_with_all_data)}, + headers={"nhs-login-nhs-number": str(person_with_all_data), UNIQUE_CONSUMER_HEADER: consumer_id}, timeout=10, ) diff --git a/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json b/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json new file mode 100644 index 000000000..ed4879d56 --- /dev/null +++ b/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json @@ -0,0 +1,23 @@ + +{ + "consumer-id-123": [ + { + "CampaignConfigID": "8fcb742b-45fa-4e0d-8f2f-9c2efb1f46d0", + "Description": "RSV Ongoing for My Vaccines" + }, + { + "CampaignConfigID": "COVID_campaign_id", + "Description": "COVID Ongoing for My Vaccines" + } + ], + "consumer-id-456": [ + { + "CampaignConfigID": "RSV_campaign_id_NBS", + "Description": "RSV Ongoing for NBS" + }, + { + "CampaignConfigID": "COVID_campaign_id_NBS", + "Description": "RSV Ongoing for NBS" + } + ] +} diff --git a/tests/unit/common/test_request_validator.py b/tests/unit/common/test_request_validator.py index 4c600ad57..d3025f7a7 100644 --- a/tests/unit/common/test_request_validator.py +++ b/tests/unit/common/test_request_validator.py @@ -7,6 +7,7 @@ from eligibility_signposting_api.common import request_validator from eligibility_signposting_api.common.request_validator import logger +from tests.integration.conftest import UNIQUE_CONSUMER_HEADER @pytest.fixture(autouse=True) @@ -42,7 +43,7 @@ class TestValidateRequestParams: @pytest.mark.parametrize( "headers", [ - {}, # header missing entirely - request from application restricted consumer + {}, # nhs header missing entirely - request from application restricted consumer {"nhs-login-nhs-number": "1234567890"}, # valid request from consumer ], ) @@ -54,7 +55,7 @@ def test_validate_request_params_success(self, headers, app, caplog): with app.test_request_context( "/dummy?id=1234567890", - headers=headers, + headers={UNIQUE_CONSUMER_HEADER: "some-consumer"} | headers, method="GET", ): with caplog.at_level(logging.INFO): @@ -80,7 +81,7 @@ def test_validate_request_params_nhs_mismatch(self, headers, app, caplog): with app.test_request_context( "/dummy?id=1234567890", - headers=headers, + headers={UNIQUE_CONSUMER_HEADER: "some-id"} | headers, method="GET", ): with caplog.at_level(logging.INFO): @@ -98,6 +99,58 @@ def test_validate_request_params_nhs_mismatch(self, headers, app, caplog): assert issue["diagnostics"] == "You are not authorised to request information for the supplied NHS Number" assert response.headers["Content-Type"] == "application/fhir+json" + def test_validate_request_params_consumer_id_present(self, app, caplog): + mock_api = MagicMock(return_value="ok") + + decorator = request_validator.validate_request_params() + dummy_route = decorator(mock_api) + + with ( + app.test_request_context( + "/dummy?id=1234567890", + headers={ + UNIQUE_CONSUMER_HEADER: "some-consumer", + "nhs-login-nhs-number": "1234567890", + }, + method="GET", + ), + caplog.at_level(logging.INFO), + ): + response = dummy_route(nhs_number=request.args.get("id")) + + mock_api.assert_called_once() + assert response == "ok" + assert not any(record.levelname == "ERROR" for record in caplog.records) + + def test_validate_request_params_missing_consumer_id(self, app, caplog): + mock_api = MagicMock() + + decorator = request_validator.validate_request_params() + dummy_route = decorator(mock_api) + + with ( + app.test_request_context( + "/dummy?id=1234567890", + headers={"nhs-login-nhs-number": "1234567890"}, # no consumer ID + method="GET", + ), + caplog.at_level(logging.ERROR), + ): + response = dummy_route(nhs_number=request.args.get("id")) + + mock_api.assert_not_called() + + assert response is not None + assert response.status_code == HTTPStatus.FORBIDDEN + response_json = response.json + + issue = response_json["issue"][0] + assert issue["code"] == "forbidden" + assert issue["details"]["coding"][0]["code"] == "ACCESS_DENIED" + assert issue["details"]["coding"][0]["display"] == "Access has been denied to process this request." + assert issue["diagnostics"] == "You are not authorised to request" + assert response.headers["Content-Type"] == "application/fhir+json" + class TestValidateQueryParameters: @pytest.mark.parametrize( diff --git a/tests/unit/repos/test_consumer_mapping_repo.py b/tests/unit/repos/test_consumer_mapping_repo.py new file mode 100644 index 000000000..1c6d60e8d --- /dev/null +++ b/tests/unit/repos/test_consumer_mapping_repo.py @@ -0,0 +1,62 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from eligibility_signposting_api.model.consumer_mapping import ConsumerId +from eligibility_signposting_api.repos.consumer_mapping_repo import BucketName, ConsumerMappingRepo + + +class TestConsumerMappingRepo: + @pytest.fixture + def mock_s3_client(self): + return MagicMock() + + @pytest.fixture + def repo(self, mock_s3_client): + return ConsumerMappingRepo(s3_client=mock_s3_client, bucket_name=BucketName("test-bucket")) + + def test_get_permitted_campaign_ids_success(self, repo, mock_s3_client): + # Given + consumer_id = "user-123" + + # The expected output is just the IDs + expected_campaign_ids = ["flu-2024", "covid-2024"] + + # The mocked S3 data must match the new schema (objects with description) + mapping_data = { + consumer_id: [ + {"CampaignConfigID": "flu-2024", "Description": "Flu Shot Description"}, + {"CampaignConfigID": "covid-2024", "Description": "Covid Shot Description"}, + ] + } + + mock_s3_client.list_objects.return_value = {"Contents": [{"Key": "mappings.json"}]} + + body_json = json.dumps(mapping_data).encode("utf-8") + mock_s3_client.get_object.return_value = {"Body": MagicMock(read=lambda: body_json)} + + # When + result = repo.get_permitted_campaign_ids(ConsumerId(consumer_id)) + + # Then + assert result == expected_campaign_ids + mock_s3_client.list_objects.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.get_object.assert_called_once_with(Bucket="test-bucket", Key="mappings.json") + + def test_get_permitted_campaign_ids_returns_none_when_missing(self, repo, mock_s3_client): + """ + Setup data where the consumer_id doesn't exist + We must still use the valid schema (dicts inside the list) to pass Pydantic validation + """ + valid_schema_data = {"other-user": [{"CampaignConfigID": "camp-1", "Description": "Some description"}]} + + mock_s3_client.list_objects.return_value = {"Contents": [{"Key": "mappings.json"}]} + body_json = json.dumps(valid_schema_data).encode("utf-8") + mock_s3_client.get_object.return_value = {"Body": MagicMock(read=lambda: body_json)} + + # When + result = repo.get_permitted_campaign_ids(ConsumerId("missing-user")) + + # Then + assert result is None diff --git a/tests/unit/services/test_eligibility_services.py b/tests/unit/services/test_eligibility_services.py index 504888f12..3d3b787cd 100644 --- a/tests/unit/services/test_eligibility_services.py +++ b/tests/unit/services/test_eligibility_services.py @@ -3,23 +3,43 @@ import pytest from hamcrest import assert_that, empty +from eligibility_signposting_api.model.campaign_config import CampaignConfig, CampaignID from eligibility_signposting_api.model.eligibility_status import NHSNumber from eligibility_signposting_api.repos import CampaignRepo, NotFoundError, PersonRepo +from eligibility_signposting_api.repos.consumer_mapping_repo import ConsumerMappingRepo from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.services.calculators.eligibility_calculator import EligibilityCalculatorFactory from tests.fixtures.matchers.eligibility import is_eligibility_status +@pytest.fixture +def mock_repos(): + return { + "person": MagicMock(spec=PersonRepo), + "campaign": MagicMock(spec=CampaignRepo), + "consumer": MagicMock(spec=ConsumerMappingRepo), + "factory": MagicMock(spec=EligibilityCalculatorFactory), + } + + +@pytest.fixture +def service(mock_repos): + return EligibilityService( + mock_repos["person"], mock_repos["campaign"], mock_repos["consumer"], mock_repos["factory"] + ) + + def test_eligibility_service_returns_from_repo(): # Given person_repo = MagicMock(spec=PersonRepo) campaign_repo = MagicMock(spec=CampaignRepo) + consumer_mapping_repo = MagicMock(spec=ConsumerMappingRepo) person_repo.get_eligibility = MagicMock(return_value=[]) - service = EligibilityService(person_repo, campaign_repo, EligibilityCalculatorFactory()) + service = EligibilityService(person_repo, campaign_repo, consumer_mapping_repo, EligibilityCalculatorFactory()) # When actual = service.get_eligibility_status( - NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL" + NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL", consumer_id="test_consumer_id" ) # Then @@ -30,9 +50,53 @@ def test_eligibility_service_for_nonexistent_nhs_number(): # Given person_repo = MagicMock(spec=PersonRepo) campaign_repo = MagicMock(spec=CampaignRepo) + consumer_mapping_repo = MagicMock(spec=ConsumerMappingRepo) person_repo.get_eligibility_data = MagicMock(side_effect=NotFoundError) - service = EligibilityService(person_repo, campaign_repo, EligibilityCalculatorFactory()) + service = EligibilityService(person_repo, campaign_repo, consumer_mapping_repo, EligibilityCalculatorFactory()) # When with pytest.raises(UnknownPersonError): - service.get_eligibility_status(NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL") + service.get_eligibility_status( + NHSNumber("1234567890"), + include_actions="Y", + conditions=["ALL"], + category="ALL", + consumer_id="test_consumer_id", + ) + + +def test_get_eligibility_status_filters_permitted_campaigns(service, mock_repos): + """Tests that ONLY permitted campaigns reach the calculator factory.""" + # Given + nhs_number = NHSNumber("1234567890") + person_data = {"age": 65, "vulnerable": True} + mock_repos["person"].get_eligibility_data.return_value = person_data + + # Available campaigns in system + camp_a = MagicMock(spec=CampaignConfig, id=CampaignID("CAMP_A")) + camp_b = MagicMock(spec=CampaignConfig, id=CampaignID("CAMP_B")) + mock_repos["campaign"].get_campaign_configs.return_value = [camp_a, camp_b] + + # Consumer is only permitted to see CAMP_B + mock_repos["consumer"].get_permitted_campaign_ids.return_value = [CampaignID("CAMP_B")] + + # Mock calculator behavior + mock_calc = MagicMock() + mock_repos["factory"].get.return_value = mock_calc + mock_calc.get_eligibility_status.return_value = "eligible_result" + + # When + result = service.get_eligibility_status(nhs_number, "Y", ["FLU"], "G1", "consumer_xyz") + + # Then + # Verify the factory was called ONLY with camp_b + mock_repos["factory"].get.assert_called_once_with(person_data, [camp_b]) + assert result == "eligible_result" + + +def test_raises_unknown_person_error_on_repo_not_found(service, mock_repos): + """Tests that NotFoundError from repo is translated to UnknownPersonError.""" + mock_repos["person"].get_eligibility_data.side_effect = NotFoundError + + with pytest.raises(UnknownPersonError): + service.get_eligibility_status(NHSNumber("999"), "Y", [], "", "any") diff --git a/tests/unit/views/test_eligibility.py b/tests/unit/views/test_eligibility.py index 5c323a7b2..3a0c2304c 100644 --- a/tests/unit/views/test_eligibility.py +++ b/tests/unit/views/test_eligibility.py @@ -29,10 +29,10 @@ ) from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.views.eligibility import ( + _get_or_default_query_params, build_actions, build_eligibility_cohorts, build_suitability_results, - get_or_default_query_params, ) from eligibility_signposting_api.views.response_model import eligibility_response from tests.fixtures.builders.model.eligibility import ( @@ -41,6 +41,7 @@ EligibilityStatusFactory, ) from tests.fixtures.matchers.eligibility import is_eligibility_cohort +from tests.integration.conftest import UNIQUE_CONSUMER_HEADER logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ def get_eligibility_status( _include_actions: str, _conditions: list[str], _category: str, + _consumer_id: str, ) -> EligibilityStatus: return EligibilityStatusFactory.build() @@ -74,6 +76,7 @@ def get_eligibility_status( _include_actions: str, _conditions: list[str], _category: str, + _consumer_id: str, ) -> EligibilityStatus: raise UnknownPersonError @@ -100,7 +103,7 @@ def test_security_headers_present_on_successful_response(app: Flask, client: Fla get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # When - headers = {"nhs-login-nhs-number": "9876543210"} + headers = {"nhs-login-nhs-number": "9876543210", UNIQUE_CONSUMER_HEADER: "test_consumer_id"} response = client.get("/patient-check/9876543210", headers=headers) # Then @@ -128,7 +131,7 @@ def test_security_headers_present_on_error_response(app: Flask, client: FlaskCli get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # When - headers = {"nhs-login-nhs-number": "9876543210"} + headers = {"nhs-login-nhs-number": "9876543210", UNIQUE_CONSUMER_HEADER: "test_customer_id"} response = client.get("/patient-check/9876543210", headers=headers) # Then @@ -177,7 +180,7 @@ def test_nhs_number_given(app: Flask, client: FlaskClient): get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), UNIQUE_CONSUMER_HEADER: "test_customer_id"} # When response = client.get("/patient-check/12345", headers=headers) @@ -190,7 +193,7 @@ def test_no_nhs_number_given(app: Flask, client: FlaskClient): # Given with get_app_container(app).override.service(EligibilityService, new=FakeUnknownPersonEligibilityService()): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), UNIQUE_CONSUMER_HEADER: "test_customer_id"} # When response = client.get("/patient-check/", headers=headers) @@ -229,7 +232,7 @@ def test_no_nhs_number_given(app: Flask, client: FlaskClient): def test_unexpected_error(app: Flask, client: FlaskClient): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), UNIQUE_CONSUMER_HEADER: "test_customer_id"} with get_app_container(app).override.service(EligibilityService, new=FakeUnexpectedErrorEligibilityService()): response = client.get("/patient-check/12345", headers=headers) @@ -439,7 +442,10 @@ def test_excludes_nulls_via_build_response(client: FlaskClient): return_value=mocked_response, ), ): - response = client.get("/patient-check/12345", headers={"nhs-login-nhs-number": str(12345)}) + response = client.get( + "/patient-check/12345", + headers={"nhs-login-nhs-number": str(12345), UNIQUE_CONSUMER_HEADER: "test_customer_id"}, + ) assert response.status_code == HTTPStatus.OK payload = json.loads(response.data) @@ -491,7 +497,10 @@ def test_build_response_include_values_that_are_not_null(client: FlaskClient): return_value=mocked_response, ), ): - response = client.get("/patient-check/12345", headers={"nhs-login-nhs-number": str(12345)}) + response = client.get( + "/patient-check/12345", + headers={"nhs-login-nhs-number": str(12345), UNIQUE_CONSUMER_HEADER: "test_customer_id"}, + ) assert response.status_code == HTTPStatus.OK payload = json.loads(response.data) @@ -507,7 +516,7 @@ def test_build_response_include_values_that_are_not_null(client: FlaskClient): def test_get_or_default_query_params_with_no_args(app: Flask): with app.test_request_context("/patient-check"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"category": "ALL", "conditions": ["ALL"], "includeActions": "Y"} @@ -516,7 +525,7 @@ def test_get_or_default_query_params_with_no_args(app: Flask): def test_get_or_default_query_params_with_all_args(app: Flask): with app.test_request_context("/patient-check?includeActions=Y&category=VACCINATIONS&conditions=FLU"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "Y", "category": "VACCINATIONS", "conditions": ["FLU"]} @@ -525,7 +534,7 @@ def test_get_or_default_query_params_with_all_args(app: Flask): def test_get_or_default_query_params_with_partial_args(app: Flask): with app.test_request_context("/patient-check?includeActions=N"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "N", "category": "ALL", "conditions": ["ALL"]} @@ -534,13 +543,13 @@ def test_get_or_default_query_params_with_partial_args(app: Flask): def test_get_or_default_query_params_with_lowercase_y(app: Flask): with app.test_request_context("/patient-check?includeActions=y"): - result = get_or_default_query_params() + result = _get_or_default_query_params() assert_that(result["includeActions"], is_("Y")) def test_get_or_default_query_params_missing_include_actions(app: Flask): with app.test_request_context("/patient-check?category=SCREENING&conditions=COVID19,FLU"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "Y", "category": "SCREENING", "conditions": ["COVID19", "FLU"]} @@ -581,3 +590,30 @@ def test_status_endpoint(app: Flask, client: FlaskClient): ) ), ) + + +def test_consumer_id_is_passed_to_service(app: Flask, client: FlaskClient): + """ + Verifies that the consumer ID from the header is actually passed + to the eligibility service call. + """ + # Given + mock_service = MagicMock(spec=EligibilityService) + mock_service.get_eligibility_status.return_value = EligibilityStatusFactory.build() + + with ( + get_app_container(app).override.service(EligibilityService, new=mock_service), + get_app_container(app).override.service(AuditService, new=FakeAuditService()), + ): + headers = {"nhs-login-nhs-number": "1234567890", UNIQUE_CONSUMER_HEADER: "specific_consumer_123"} + + # When + client.get("/patient-check/1234567890", headers=headers) + + # Then + # Verify the 5th positional argument or the keyword argument 'consumer_id' + mock_service.get_eligibility_status.assert_called_once() + args, _kwargs = mock_service.get_eligibility_status.call_args + + # Check that 'specific_consumer_123' was the consumer_id passed + assert args[4] == "specific_consumer_123"