1 change: 1 addition & 0 deletions README.md
@@ -487,6 +487,7 @@ To set up multi-node training:
- Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
4. Create a config file for your custom model; you only need to set `model_type` to the name you registered for your custom model (example: `configs/sba_340m.json`).
5. Training is straightforward: just use the `flame.train.py` script to train your custom model.
6. Edit `custom_models/__init__.py` to import your custom model (e.g., `from . import sba`). Set `__all__` accordingly so only valid packages are exposed.
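
For reference, the registration mentioned in step 3 might look roughly like the sketch below; the module and class names (`configuration_sba`, `modeling_sba`, `SBAConfig`, `SBAModel`, `SBAForCausalLM`) are illustrative assumptions, not necessarily the repository's actual layout:

```python
# custom_models/sba/__init__.py -- illustrative sketch; adapt names to your model
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from .configuration_sba import SBAConfig              # assumed: sets model_type = "sba"
from .modeling_sba import SBAModel, SBAForCausalLM    # assumed model classes

# Register the classes so a config with "model_type": "sba" resolves to them.
AutoConfig.register("sba", SBAConfig)
AutoModel.register(SBAConfig, SBAModel)
AutoModelForCausalLM.register(SBAConfig, SBAForCausalLM)
```

The config file from step 4 then only needs `"model_type": "sba"` plus the model's hyperparameters.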



3 changes: 3 additions & 0 deletions custom_models/__init__.py
@@ -0,0 +1,3 @@
from . import sba

__all__ = ["sba"]
Comment on lines +1 to +3

🛠️ Refactor suggestion

Import fails hard if ‘sba’ is absent – make registration resilient / configurable

`__init__.py` eagerly imports `sba`, raising `ModuleNotFoundError` the moment a user removes or renames that directory.
Given the README now tells users to edit this file for their own models, a safer pattern is to:

  1. Gate the import in a try/except and emit a clear error message, or
  2. Provide a minimal lazy-import helper that looks up modules listed in `__all__` only when accessed.

This prevents a broken install/run-time experience when someone forgets to adjust the import while experimenting.

-from . import sba
-
-__all__ = ["sba"]
+import importlib
+import warnings
+
+__all__ = ["sba"]  # update this list for your own modules
+
+def __getattr__(name):
+    if name in __all__:
+        try:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+            return module
+        except ModuleNotFoundError:
+            raise ImportError(
+                f"custom_models: requested sub-module '{name}' not found – "
+                "did you forget to create/rename it?"
+            )
+    raise AttributeError(name)

Keeps side-effects minimal and offers a clearer failure mode.
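
Since the helper relies on module-level `__getattr__` (PEP 562, Python 3.7+), a sub-module listed in `__all__` is imported only on first attribute access. Roughly:

```python
import custom_models             # no sub-module import happens yet

sba = custom_models.sba          # first access triggers importlib.import_module(".sba", ...)
custom_models.something_else     # names not in __all__ raise AttributeError
```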

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-from . import sba
-
-__all__ = ["sba"]
+import importlib
+import warnings
+
+__all__ = ["sba"]  # update this list for your own modules
+
+def __getattr__(name):
+    if name in __all__:
+        try:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+            return module
+        except ModuleNotFoundError:
+            raise ImportError(
+                f"custom_models: requested sub-module '{name}' not found – "
+                "did you forget to create/rename it?"
+            )
+    raise AttributeError(name)
🤖 Prompt for AI Agents
In custom_models/__init__.py at lines 1 to 3, the current eager import of 'sba'
causes a ModuleNotFoundError if 'sba' is missing. To fix this, wrap the import
of 'sba' in a try/except block that catches ImportError and logs a clear,
user-friendly error message explaining the missing module. Alternatively,
implement a lazy-import mechanism that only imports modules listed in __all__
when they are accessed, preventing immediate import failures and improving
resilience when users modify the models.
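
For completeness, option 1 from the comment (keep the eager import but fail with a clearer message) might look something like the following sketch; it is not part of this PR:

```python
# custom_models/__init__.py -- alternative sketch: eager import with a friendlier error
__all__ = ["sba"]  # update this list for your own modules

try:
    from . import sba  # noqa: F401
except ModuleNotFoundError as exc:
    raise ImportError(
        "custom_models: expected sub-package 'sba' was not found. "
        "Create it, or edit custom_models/__init__.py and __all__ "
        "to point at your own model package."
    ) from exc
```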

2 changes: 1 addition & 1 deletion train.sh
@@ -66,7 +66,7 @@ steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
model=$(
python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
python -c "import fla, sys, custom_models; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
)

mkdir -p $path