Better support for target models on the ensemble attack #135
Conversation
📝 Walkthrough: The PR introduces a new abstraction model for ensemble attack training by creating an … Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes. 🚥 Pre-merge checks: ✅ 3 passed.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tests/unit/attacks/ensemble/test_process_data_split.py (1)
41-44: ⚠️ Potential issue | 🟡 Minor — Duplicate assertion detected.

Line 44 duplicates the assertion from line 43 (both check `real_test.csv`). This appears to be a copy-paste error and doesn't add test coverage.

🔧 Suggested fix

```diff
  # Assert that the split real data files are saved in the provided path
  assert (output_dir / "real_train.csv").exists()
  assert (output_dir / "real_val.csv").exists()
  assert (output_dir / "real_test.csv").exists()
- assert (output_dir / "real_test.csv").exists()
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `tests/unit/attacks/ensemble/test_process_data_split.py` around lines 41-44, remove the duplicate assertion that repeats checking `(output_dir / "real_test.csv").exists()` so the test only asserts `real_train.csv`, `real_val.csv`, and `real_test.csv` once each (or, if the intent was to verify a different file, replace the duplicate with the correct filename instead of duplicating the `real_test.csv` assertion).

examples/ensemble_attack/test_attack_model.py (1)
345-353: ⚠️ Potential issue | 🟡 Minor — Typo in variable names: "mataclassifier" should be "metaclassifier".

Lines 345-351 have consistent typos: `mataclassifier_path` and `trained_mataclassifier_model`. While this doesn't affect functionality, it reduces code readability.

📝 Suggested fix

```diff
- mataclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
- assert mataclassifier_path.exists(), (
-     f"No metaclassifier model found at {mataclassifier_path}. Make sure to run the training script first."
+ metaclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
+ assert metaclassifier_path.exists(), (
+     f"No metaclassifier model found at {metaclassifier_path}. Make sure to run the training script first."
  )
- with open(mataclassifier_path, "rb") as f:
-     trained_mataclassifier_model = pickle.load(f)
+ with open(metaclassifier_path, "rb") as f:
+     trained_metaclassifier_model = pickle.load(f)
- log(INFO, f"Metaclassifier model loaded from {mataclassifier_path}, starting the test...")
+ log(INFO, f"Metaclassifier model loaded from {metaclassifier_path}, starting the test...")
```

Also update line 422:

```diff
- blending_attacker.trained_model = trained_mataclassifier_model
+ blending_attacker.trained_model = trained_metaclassifier_model
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `examples/ensemble_attack/test_attack_model.py` around lines 345-353, rename the misspelled variables in the test to use "metaclassifier" consistently: change `mataclassifier_path` to `metaclassifier_path` (the Path construction and the existence assertion) and change `trained_mataclassifier_model` to `trained_metaclassifier_model` (the `pickle.load` assignment and any subsequent uses, including the referenced occurrence around line 422); update all references so variable names match (e.g., in the `open(...)` block and the log message) to improve readability and avoid typos.

src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
53-80: ⚠️ Potential issue | 🟡 Minor — Docstring references removed parameters.

Lines 67 and 73-74 document parameters `fine_tuning_config` and `number_of_points_to_synthesize` that no longer exist in the function signature after the refactor.

📝 Suggested fix

```diff
  training_json_config_paths: Configuration dictionary containing paths to the data JSON config files.
      An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are:
      - table_domain_file_path (str): Path to the table domain json file.
      - dataset_meta_file_path (str): Path to dataset meta json file.
      - training_config_path (str): Path to table's training config json file.
- fine_tuning_config: Configuration dictionary containing shadow model fine-tuning specific information.
  init_model_id: An ID to assign to the pre-trained initial models. This can be used to save
      multiple pre-trained models with different IDs.
  table_name: Name of the main table to be used for training the TabDDPM model.
  id_column_name: Name of the ID column in the data.
  pre_training_data_size: Size of the initial training set, defaults to 60,000.
- number_of_points_to_synthesize: Size of the synthetic data to be generated by each shadow model,
-     defaults to 20,000.
  init_data_seed: Random seed for the initial training set.
  random_seed: Random seed used for reproducibility, defaults to None.
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py` around lines 53-80, the docstring for the shadow model training function still documents removed parameters `fine_tuning_config` and `number_of_points_to_synthesize`; update the function's docstring to match the current signature by removing any references to those parameters (or replacing them with the correct current parameter names), ensure the Args section lists only existing parameters such as `model_runner`, `n_models`, `n_reps`, `population_data`, `master_challenge_data`, `shadow_models_output_path`, `training_json_config_paths`, `init_model_id`, `table_name`, `id_column_name`, `pre_training_data_size`, `init_data_seed`, and `random_seed`, and adjust the description of Returns if needed so the docstring accurately reflects the function's current behavior.
🧹 Nitpick comments (3)
examples/ensemble_attack/run_attack.py (1)
87-96: Consider adding error handling for JSON config loading.

The JSON is loaded directly into `EnsembleAttackTabDDPMTrainingConfig` without validation. If the JSON structure doesn't match the expected schema, the error message may be unclear.

💡 Optional: Add a try-except for clearer error messages

```diff
  with open(config.shadow_training.training_json_config_paths.training_config_path, "r") as file:
-     training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+     try:
+         training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+     except (TypeError, ValueError) as e:
+         raise ValueError(
+             f"Failed to parse training config from "
+             f"{config.shadow_training.training_json_config_paths.training_config_path}: {e}"
+         ) from e
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `examples/ensemble_attack/run_attack.py` around lines 87-96, the JSON loading for `EnsembleAttackTabDDPMTrainingConfig` is unprotected, so malformed or schema-mismatched JSON will raise unclear errors; wrap the open/`json.load`/`EnsembleAttackTabDDPMTrainingConfig(...)` call in a try-except that catches `json.JSONDecodeError` and `TypeError`/`ValueError` (or a generic `Exception`) and re-raise or log a clear message including the config path and the original exception, then only proceed to set the fine_tuning fields and instantiate `EnsembleAttackTabDDPMModelRunner` when parsing succeeds.

src/midst_toolkit/attacks/ensemble/models.py (1)
115-117: Hardcoded table name "trans" should be parameterized.

The `load_tables` call uses a hardcoded `"trans"` key in the `train_data` dictionary. This assumes the table is always named "trans", but other datasets may use different table names.

Consider either:
- Adding a `table_name` field to `EnsembleAttackTabDDPMTrainingConfig` (similar to how `EnsembleAttackCTGANTrainingConfig` has it)
- Deriving the table name from the dataset metadata

♻️ Suggested approach

```diff
+class EnsembleAttackTabDDPMTrainingConfig(ClavaDDPMTrainingConfig, EnsembleAttackTrainingConfig):
+    fine_tuning_diffusion_iterations: int = 100
+    fine_tuning_classifier_iterations: int = 10
+    table_name: str = "trans"  # Default for backward compatibility
```

Then in `train_or_fine_tune_and_synthesize`:

```diff
- tables, relation_order, _ = load_tables(self.training_config.general.data_dir, train_data={"trans": dataset})
+ tables, relation_order, _ = load_tables(
+     self.training_config.general.data_dir,
+     train_data={self.training_config.table_name: dataset}
+ )
```

And update synthesis line 182:

```diff
- result.synthetic_data = cleaned_tables["trans"]
+ result.synthetic_data = cleaned_tables[self.training_config.table_name]
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `src/midst_toolkit/attacks/ensemble/models.py` around lines 115-117, the `load_tables` call in `train_or_fine_tune_and_synthesize` uses a hardcoded `"trans"` key which breaks datasets with different table names; add a `table_name` field to `EnsembleAttackTabDDPMTrainingConfig` (mirroring `EnsembleAttackCTGANTrainingConfig`) or derive the table name from the provided dataset metadata, then replace the hardcoded `"trans"` with that value (e.g., use `self.training_config.table_name` or `dataset.metadata.table_name`) in the `load_tables` call and anywhere later (including the synthesis usage around the previous line 182) so the code uses the configured/derived table name instead of `"trans"`.

examples/ensemble_attack/run_shadow_model_training.py (1)
61-72: Shared `model_runner` state is mutated by multiple callers.

Based on the context snippets, the same `model_runner` instance flows through `run_shadow_model_training` → `train_three_sets_of_shadow_models` → multiple shadow training functions, each of which overwrites `model_runner.training_config.general.*` fields. Then the same instance is passed to `run_target_model_training`, which resets these fields here (lines 68-70).

While this appears to work correctly because the reset happens before the target model uses the config, this pattern is fragile. If the call order changes or a caller forgets to reset, stale paths could cause models to be saved to unexpected locations.

Consider either:
- Documenting this behavior explicitly in the function docstring
- Creating a fresh config copy for each training phase instead of mutating the shared instance
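The second option can be sketched with `copy.deepcopy`. The config and runner classes below are simplified stand-ins for the toolkit's real classes, and the field names are purely illustrative:

```python
import copy
from dataclasses import dataclass, field


# Simplified stand-ins for the toolkit's TrainingConfig/ModelRunner classes.
@dataclass
class GeneralConfig:
    output_path: str = ""


@dataclass
class TrainingConfig:
    general: GeneralConfig = field(default_factory=GeneralConfig)


class ModelRunner:
    def __init__(self, training_config: TrainingConfig) -> None:
        self.training_config = training_config


def config_for_phase(runner: ModelRunner, output_path: str) -> TrainingConfig:
    """Deep-copy the shared config so each phase can set its own paths without side effects."""
    phase_config = copy.deepcopy(runner.training_config)
    phase_config.general.output_path = output_path
    return phase_config


runner = ModelRunner(TrainingConfig())
shadow_config = config_for_phase(runner, "shadow_models/")
target_config = config_for_phase(runner, "target_model/")

# The shared instance is never mutated, so call order no longer matters.
assert runner.training_config.general.output_path == ""
assert shadow_config.general.output_path == "shadow_models/"
```

If the real `TrainingConfig` happens to be a pydantic model, `model_copy(deep=True)` would serve the same role as `copy.deepcopy` here.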
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ensemble_attack/run_shadow_model_training.py` around lines 61 - 72, The code mutates the shared model_runner.training_config.general for different training phases (see model_runner, training_config.general, save_additional_training_config, train_or_fine_tune_and_synthesize) which is fragile; fix by creating and using a fresh config copy for each phase instead of mutating the shared instance—e.g., deep-copy model_runner.training_config (or construct a new TrainingConfig from save_additional_training_config's return) and assign that copy to a separate ModelRunner or pass it into train_or_fine_tune_and_synthesize, leaving the original model_runner unchanged; update run_shadow_model_training, train_three_sets_of_shadow_models, and run_target_model_training to operate on their own config copies.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gan/ensemble_attack/utils.py`:
- Around line 32-43: The docstring for function make_training_config contains a
typo "attacktraining"; update the docstring text to read "attack training" (in
the description line and any other occurrences within the function's docstring)
so it clearly states "Make the ensemble attack training config for the CTGAN
model..." and preserves existing punctuation and formatting.
In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py`:
- Around line 19-34: Update the docstring for the function that modifies and
loads training configurations (the function taking config_type, data_dir,
training_config_json_path, final_config_json_path, experiment_name,
workspace_name) to remove the specific "TabDDPM" mention and describe it as a
generic modifier/loader for any EnsembleAttackTrainingConfig subclass;
explicitly state that config_type accepts an EnsembleAttackTrainingConfig
subclass, and ensure the Args and Returns sections reflect the generic behavior
and returned configs/save_dir values instead of TabDDPM-specific language.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8fdb8e7f-6866-45c8-bf8d-cb1f088f2238
📒 Files selected for processing (15)
- .gitignore
- examples/ensemble_attack/run_attack.py
- examples/ensemble_attack/run_metaclassifier_training.py
- examples/ensemble_attack/run_shadow_model_training.py
- examples/ensemble_attack/test_attack_model.py
- examples/gan/ensemble_attack/test_attack_model.py
- examples/gan/ensemble_attack/train_attack_model.py
- examples/gan/ensemble_attack/utils.py
- src/midst_toolkit/attacks/ensemble/models.py
- src/midst_toolkit/attacks/ensemble/process_split_data.py
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py
- tests/integration/attacks/ensemble/assets/data_configs/trans.json
- tests/integration/attacks/ensemble/test_shadow_model_training.py
- tests/unit/attacks/ensemble/test_process_data_split.py
emersodb
left a comment
A really nice step in the right direction! I like the design you chose.
```python
        config.shadow_training.fine_tuning_config.fine_tune_classifier_iterations
    )

model_runner = EnsembleAttackTabDDPMModelRunner(training_config=training_config)
```
Perhaps you've already thought of this, but should the code above be part of the base for the ModelRunner? That is, should lines 87-94 actually happen inside that class rather than in the attack script here?
This would also slightly simplify the process of subbing out the model, since you would just need to sub the runner class instead of both the running and the config class? I might be missing a complexity though.
Not sure if I understood your idea, but I thought maybe if I pass the config dictionary to the init of the model runner class we would be able to skip making the config. Is that it?
Sort of. My thought was that you could simply have the EnsembleAttackTabDDPMModelRunner init take a path to the configuration file. Then you could load the file and do all of the steps to properly construct EnsembleAttackTabDDPMTrainingConfig object within the runner class? That way a user doesn't have to do that themselves.
It's possible I'm missing something where that would be a bad idea though 🙂
Let me know if my explanation of what I was trying to suggest isn't clear. We can talk about it together.
The only problem I see with that is the attack changes the config a few times so any runner must have a config object in order for this to run properly (not great practice, though, I want to refactor that out at some point). I think I can use OOP to our advantage here and add the requirement to have a config for any classes implementing a model runner. TBD.
Gotcha. Okay, I'll defer to your judgement/plan here!
It ended up being a bigger change than I wanted but it looks simpler now. The only downside I see is the code inside the main library being more dependent on the config file.
I think we should move away from stuff being set in config files and move into parameters with default values or config classes. It will make it easier for people outside our team to use the library. It's a bigger effort, though, but I plan to do it little by little as I work on experiments. We can talk more in our 1 on 1 today.
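The idea floated earlier in this thread — having the runner load and construct its own config from a file path, so users never assemble the config themselves — could look roughly like this. The class names come from the PR; the JSON handling and config fields are assumptions for illustration:

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class EnsembleAttackTabDDPMTrainingConfig:
    """Stand-in for the toolkit's config class; real field validation is assumed elsewhere."""

    def __init__(self, **fields: Any) -> None:
        for name, value in fields.items():
            setattr(self, name, value)


class EnsembleAttackTabDDPMModelRunner:
    """Sketch: the runner builds its own config from a path, so callers pass only the path."""

    def __init__(self, training_config_path: Path) -> None:
        try:
            with open(training_config_path, "r") as file:
                raw = json.load(file)
        except (OSError, json.JSONDecodeError) as error:
            raise ValueError(
                f"Failed to load training config from {training_config_path}: {error}"
            ) from error
        self.training_config = EnsembleAttackTabDDPMTrainingConfig(**raw)


# Usage: write a toy JSON config and let the runner construct the config object itself.
with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "training_config.json"
    config_path.write_text(json.dumps({"diffusion_iterations": 100}))
    runner = EnsembleAttackTabDDPMModelRunner(config_path)
    assert runner.training_config.diffusion_iterations == 100
```

Swapping the target model then means swapping only the runner class, since each runner knows how to build its own config, which matches the simplification the review suggests.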
emersodb
left a comment
The changes you made and most of the responses you gave make sense to me! Just a few final pieces I think.
sarakodeiri
left a comment
Minor suggestions and questions for my own learning. Thanks Marcelo!
emersodb
left a comment
Really like the changes!
PR Type
Feature
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868h6nmfc
Refactoring the ensemble attack to allow for more flexibility when setting the target model.
The way I went about this was to use inheritance, with the introduction of the following abstract classes in `midst_toolkit/attacks/ensemble/models.py`:

- `EnsembleAttackModelRunner`: the main class responsible for running the model (training, fine-tuning, and synthesizing). It also stores the configs. Its `train_or_fine_tune_and_synthesize` method is the main method that any model we run the ensemble attack against needs to implement.
- `EnsembleAttackTrainingConfig`: stores the configs for training, synthesizing, and fine-tuning. Inherits from `midst_toolkit.common.config.TrainingConfig`.
- `EnsembleAttackTrainingResult`: stores the result of the ensemble attack model training.

Each of these classes has respective TabDDPM and ClavaDDPM counterparts, with their implementations moved from the utils file into the `train_or_fine_tune_and_synthesize` functions.

Most of the other changes are just general low-hanging-fruit refactorings.
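A minimal sketch of the abstraction described above — method bodies and config fields here are placeholders for illustration, not the toolkit's actual implementation:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class EnsembleAttackTrainingConfig:
    """Holds training, fine-tuning, and synthesis settings (field is a placeholder)."""

    fine_tuning_iterations: int = 100


@dataclass
class EnsembleAttackTrainingResult:
    """Stores the trained model and the data it synthesized (fields are placeholders)."""

    trained_model: Any = None
    synthetic_data: Any = None


class EnsembleAttackModelRunner(ABC):
    """Runs a model for the ensemble attack; subclasses supply per-model logic."""

    def __init__(self, training_config: EnsembleAttackTrainingConfig) -> None:
        self.training_config = training_config

    @abstractmethod
    def train_or_fine_tune_and_synthesize(self, dataset: Any) -> EnsembleAttackTrainingResult:
        """The one method every target model (TabDDPM, CTGAN, ...) must implement."""


class ToyModelRunner(EnsembleAttackModelRunner):
    """Trivial subclass showing the contract; a real runner would train TabDDPM/ClavaDDPM."""

    def train_or_fine_tune_and_synthesize(self, dataset: Any) -> EnsembleAttackTrainingResult:
        # A real implementation trains/fine-tunes and synthesizes; here we just echo the data.
        return EnsembleAttackTrainingResult(trained_model="toy", synthetic_data=dataset)


runner = ToyModelRunner(EnsembleAttackTrainingConfig())
result = runner.train_or_fine_tune_and_synthesize([1, 2, 3])
assert result.synthetic_data == [1, 2, 3]
```

Because `EnsembleAttackModelRunner` is abstract, the attack pipeline can depend only on `train_or_fine_tune_and_synthesize`, and supporting a new target model reduces to writing one subclass plus its config counterpart.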
Tests Added
Just fixing the currently existing tests