Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions dreadnode/integrations/axolotl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Axolotl integration for Dreadnode.

This plugin enables logging training metrics and parameters to Dreadnode Strikes
when using Axolotl for fine-tuning. It injects the existing DreadnodeCallback
into Axolotl's trainer via the plugin system.

Usage in axolotl config:

    plugins:
      - dreadnode.integrations.axolotl.DreadnodePlugin

    dreadnode_project: "my-project"
    dreadnode_workspace: "my-workspace"  # optional
    dreadnode_run_name: "training-run-v1"  # optional
    dreadnode_tags:  # optional
      - "experiment-1"
"""

from __future__ import annotations

import logging
import typing as t

# ruff: noqa: ARG002, FBT001, FBT002

# Module-level logger; its name comes from this module's ``__name__``.
LOG = logging.getLogger(__name__)


class DreadnodePlugin:
    """
    Axolotl plugin that integrates DreadnodeCallback for training metrics logging.

    The class implements the Axolotl plugin hook methods by duck typing. Only
    a handful of hooks do any work:

    * ``register`` / ``pre_model_load`` log whether the integration is active.
    * ``get_input_args`` points at the pydantic args class for this plugin.
    * ``add_callbacks_post_trainer`` builds the ``DreadnodeCallback`` (imported
      from ``dreadnode.integrations.transformers``) and returns it.
    * ``post_train`` logs completion.

    Every other hook is an intentional no-op that returns ``None`` (or an
    empty list) so that default behavior is used.

    The integration is enabled only when ``dreadnode_project`` is set to a
    truthy value in the config.
    """

    def __init__(self) -> None:
        # The callback instance created in ``add_callbacks_post_trainer``;
        # stays ``None`` when the integration is disabled or skipped.
        self._callback: t.Any = None
        # Flipped to True once the callback has been created.
        self._initialized = False

    def register(self, cfg: dict[str, t.Any]) -> None:
        """Called during plugin registration with unparsed config dict.

        Logs the project name when ``dreadnode_project`` is present and truthy.
        """
        project = cfg.get("dreadnode_project")
        if project:
            LOG.info("Dreadnode plugin registered for project: %s", project)

    def get_input_args(self) -> str:
        """Return the dotted path to args class for config validation."""
        return "dreadnode.integrations.axolotl.args.DreadnodeAxolotlArgs"

    def get_training_args_mixin(self) -> None:
        """Returns training args mixin class - not used by Dreadnode plugin."""

    def load_datasets(self, cfg: t.Any, preprocess: bool = False) -> None:
        """Not used by Dreadnode plugin - returns None to use default loading."""

    def pre_model_load(self, cfg: t.Any) -> None:
        """Called before model loading - logs whether the integration is on.

        The integration counts as disabled when ``cfg.dreadnode_project`` is
        missing or falsy.
        """
        project = getattr(cfg, "dreadnode_project", None)
        if not project:
            LOG.debug("Dreadnode integration disabled (no dreadnode_project set)")
            return

        LOG.info("Dreadnode integration enabled for project: %s", project)

    def post_model_build(self, cfg: t.Any, model: t.Any) -> None:
        """Called after model is built - not used by Dreadnode plugin."""

    def pre_lora_load(self, cfg: t.Any, model: t.Any) -> None:
        """Called before LoRA loading - not used by Dreadnode plugin."""

    def post_lora_load(self, cfg: t.Any, model: t.Any) -> None:
        """Called after LoRA loading - not used by Dreadnode plugin."""

    def post_model_load(self, cfg: t.Any, model: t.Any) -> None:
        """Called after model loading - not used by Dreadnode plugin."""

    def get_trainer_cls(self, cfg: t.Any) -> None:
        """Returns custom trainer class - not used by Dreadnode plugin."""

    def post_trainer_create(self, cfg: t.Any, trainer: t.Any) -> None:
        """Called after trainer creation - not used by Dreadnode plugin."""

    def get_training_args(self, cfg: t.Any) -> None:
        """Returns custom training args - not used by Dreadnode plugin."""

    def get_collator_cls_and_kwargs(self, cfg: t.Any, is_eval: bool = False) -> None:
        """Returns custom collator - not used by Dreadnode plugin."""

    def create_optimizer(self, cfg: t.Any, trainer: t.Any) -> None:
        """Returns custom optimizer - not used by Dreadnode plugin."""

    def create_lr_scheduler(
        self, cfg: t.Any, trainer: t.Any, optimizer: t.Any, num_training_steps: int
    ) -> None:
        """Returns custom LR scheduler - not used by Dreadnode plugin."""

    def add_callbacks_pre_trainer(self, cfg: t.Any, model: t.Any) -> list[t.Any]:
        """Returns callbacks before trainer creation - not used by Dreadnode plugin."""
        return []

    def add_callbacks_post_trainer(self, cfg: t.Any, trainer: t.Any) -> list[t.Any]:
        """
        Hook called after trainer is created - inject our callback here.

        Returns:
            A one-element list holding the new ``DreadnodeCallback``, or an
            empty list when ``dreadnode_project`` is unset/falsy or when
            axolotl reports that this is not the main process.
        """
        project = getattr(cfg, "dreadnode_project", None)
        if not project:
            return []

        # Only initialize on the main process in distributed training.
        try:
            from axolotl.utils.distributed import (  # type: ignore[import-untyped,import-not-found]
                is_main_process,
            )

            if not is_main_process():
                LOG.debug("Skipping Dreadnode callback on non-main process")
                return []
        except ImportError:
            # Deliberate best-effort: if the helper cannot be imported
            # (older axolotl versions or non-distributed setups), proceed as
            # if this were the main process.
            pass

        # Imported lazily so that importing this module does not pull in the
        # callback's dependencies unless the integration is actually used.
        import dreadnode as dn
        from dreadnode.integrations.transformers import DreadnodeCallback

        # Configure workspace if provided
        workspace = getattr(cfg, "dreadnode_workspace", None)
        if workspace:
            dn.configure(workspace=workspace)

        run_name = getattr(cfg, "dreadnode_run_name", None)
        self._callback = DreadnodeCallback(
            project=project,
            run_name=run_name,
            tags=getattr(cfg, "dreadnode_tags", None),
        )
        self._initialized = True

        # ``run_name`` is usually present-but-None (the args schema defaults
        # it to None), so a getattr default would never apply; fall back with
        # ``or`` so the log shows "auto" rather than "None".
        LOG.info(
            "Registered DreadnodeCallback: project=%s, run_name=%s",
            project,
            run_name or "auto",
        )

        return [self._callback]

    def post_train(self, cfg: t.Any, model: t.Any) -> None:
        """Called after training completes.

        Only emits a log line, and only if the callback was created.
        """
        # NOTE(review): nothing is closed here; presumably the callback itself
        # finalizes the run - confirm against DreadnodeCallback.
        if self._initialized:
            LOG.info("Training complete, Dreadnode run finalized")

    def post_train_unload(self, cfg: t.Any) -> None:
        """Called after training unload - not used by Dreadnode plugin."""


# Explicit public API: the only name exported by ``from ... import *``.
__all__ = ["DreadnodePlugin"]
42 changes: 42 additions & 0 deletions dreadnode/integrations/axolotl/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Pydantic args for Axolotl config validation.

These fields are merged into the main Axolotl config schema when the
DreadnodePlugin is enabled.
"""

from pydantic import BaseModel, Field


class DreadnodeAxolotlArgs(BaseModel):
    """
    Dreadnode-specific options accepted in an Axolotl config.

    Every field is optional and defaults to ``None``. Each field's
    human-readable description is attached through ``json_schema_extra``.
    Per the module docstring, these fields are merged into the main Axolotl
    config schema when the plugin is enabled.
    """

    # Project to log to.
    dreadnode_project: str | None = Field(
        None,
        json_schema_extra={
            "description": "Dreadnode project name. If not set, Dreadnode logging is disabled."
        },
    )

    # Workspace override.
    dreadnode_workspace: str | None = Field(
        None,
        json_schema_extra={
            "description": "Dreadnode workspace name. If not set, uses the default workspace or DREADNODE_WORKSPACE env var."
        },
    )

    # Explicit name for the run.
    dreadnode_run_name: str | None = Field(
        None,
        json_schema_extra={
            "description": "Name for this training run. Defaults to Axolotl's run_name if not specified."
        },
    )

    # Free-form tags for the run.
    dreadnode_tags: list[str] | None = Field(
        None,
        json_schema_extra={"description": "Tags to associate with this run in Strikes."},
    )
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ module = [
"dreadnode.data_types.*",
"dreadnode.scorers.*",
"dreadnode.transforms.*",
"dreadnode.integrations.axolotl.*",
]
disable_error_code = ["unused-ignore", "import-untyped", "import-not-found"]

Expand Down
Loading