diff --git a/dashboard/app.py b/dashboard/app.py
index 705177e..2945147 100644
--- a/dashboard/app.py
+++ b/dashboard/app.py
@@ -13,7 +13,8 @@
sys.path.append(str(project_root))
from components.header import render_header
-from utils.data_loader import InfiniMetricsDataLoader, load_summary_file
+from utils.data_loader import InfiniMetricsDataLoader
+from common import show_data_source_info
# Page configuration
st.set_page_config(
@@ -28,6 +29,8 @@
st.session_state.data_loader = InfiniMetricsDataLoader()
if "selected_accelerators" not in st.session_state:
st.session_state.selected_accelerators = []
+if "use_mongodb" not in st.session_state:
+ st.session_state.use_mongodb = False
def main():
@@ -40,11 +43,34 @@ def main():
with st.sidebar:
st.markdown("## ⚙️ 设置")
+ # Data source selection
+ use_mongodb = st.toggle(
+ "使用 MongoDB",
+ value=st.session_state.use_mongodb,
+ help="切换到 MongoDB 数据源(需要 MongoDB 服务运行中)",
+ )
+
+ if use_mongodb != st.session_state.use_mongodb:
+ st.session_state.use_mongodb = use_mongodb
+ if use_mongodb:
+ st.session_state.data_loader = InfiniMetricsDataLoader(
+ use_mongodb=True, fallback_to_files=True
+ )
+ else:
+ st.session_state.data_loader = InfiniMetricsDataLoader()
+
+ # Show current data source
+ show_data_source_info(style="sidebar")
+
+ st.markdown("---")
+
results_dir = st.text_input(
- "测试结果目录", value="./test_output", help="包含 JSON/CSV 测试结果的目录"
+ "测试结果目录", value="../output", help="包含 JSON/CSV 测试结果的目录"
)
- if results_dir != str(st.session_state.data_loader.results_dir):
+ if not use_mongodb and results_dir != str(
+ st.session_state.data_loader.results_dir
+ ):
st.session_state.data_loader = InfiniMetricsDataLoader(results_dir)
auto_refresh = st.toggle("自动刷新", value=False)
@@ -102,8 +128,9 @@ def render_dashboard(run_id_filter: str):
">
InfiniMetrics Dashboard 用于统一展示
通信(NCCL / 集合通信)、
- 推理(Direct / Service)、
- 算子(核心算子性能)
+ 推理(直接推理 / 服务性能)、
+ 算子(核心算子性能)、
+ 硬件(内存带宽 / 缓存性能)
等 AI 加速卡性能测试结果。
测试框架输出 JSON(环境 / 配置 / 标量指标) +
@@ -230,10 +257,10 @@ def _latest(lst):
st.dataframe(df, use_container_width=True, hide_index=True)
# ========== Dispatcher summary ==========
- summaries = load_summary_file()
+ summaries = st.session_state.data_loader.load_summaries()
if not summaries:
- st.info("No dispatcher_summary file found")
+ st.info("未找到 Dispatcher 汇总记录")
return
st.markdown("### 🧾 Dispatcher 汇总记录")
diff --git a/dashboard/common.py b/dashboard/common.py
index 73b15ca..0a59f07 100644
--- a/dashboard/common.py
+++ b/dashboard/common.py
@@ -21,8 +21,35 @@ def init_page(page_title: str, page_icon: str):
# Page configuration
st.set_page_config(page_title=page_title, page_icon=page_icon, layout="wide")
- # Initialize DataLoader
+ # Initialize use_mongodb setting if not exists
+ if "use_mongodb" not in st.session_state:
+ st.session_state.use_mongodb = False
+
+ # Initialize DataLoader (respect MongoDB setting)
if "data_loader" not in st.session_state:
from utils.data_loader import InfiniMetricsDataLoader
- st.session_state.data_loader = InfiniMetricsDataLoader()
+ st.session_state.data_loader = InfiniMetricsDataLoader(
+ use_mongodb=st.session_state.use_mongodb,
+ fallback_to_files=True,
+ )
+
+
+def show_data_source_info(style: str = "caption"):
+ """
+ Display current data source info (MongoDB or file system).
+
+ Args:
+ style: Display style - "caption" for pages, "sidebar" for main app sidebar
+ """
+ dl = st.session_state.data_loader
+ if dl.source_type == "mongodb":
+ if style == "sidebar":
+ st.success("🟢 数据源: MongoDB")
+ else:
+ st.caption("数据源: MongoDB")
+ else:
+ if style == "sidebar":
+ st.info(f"📁 数据源: 文件系统 ({dl.results_dir})")
+ else:
+ st.caption(f"数据源: 文件系统 ({dl.results_dir})")
diff --git a/dashboard/pages/communication.py b/dashboard/pages/communication.py
index e0a9c9f..27d9508 100644
--- a/dashboard/pages/communication.py
+++ b/dashboard/pages/communication.py
@@ -4,7 +4,7 @@
import streamlit as st
import pandas as pd
-from common import init_page
+from common import init_page, show_data_source_info
from components.header import render_header
from utils.data_loader import get_friendly_size
from utils.metrics import extract_core_metrics
@@ -17,7 +17,7 @@
create_summary_table_infer,
)
-init_page("推理测试分析 | InfiniMetrics", "🔗")
+init_page("通信测试分析 | InfiniMetrics", "🔗")
def main():
@@ -25,6 +25,8 @@ def main():
render_header()
st.markdown("## 🔗 通信性能测试分析")
+ show_data_source_info()
+
try:
# Load communication test results
comm_runs = st.session_state.data_loader.list_test_runs("comm")
@@ -117,7 +119,9 @@ def main():
for name in selected_indices:
idx = run_options[name]
run_info = filtered_runs[idx]
- result = st.session_state.data_loader.load_test_result(run_info["path"])
+ # Use path for file source, run_id for MongoDB
+ identifier = run_info.get("path") or run_info.get("run_id")
+ result = st.session_state.data_loader.load_test_result(identifier)
run_info["data"] = result
selected_runs.append(run_info)
diff --git a/dashboard/pages/inference.py b/dashboard/pages/inference.py
index e9db234..14e0741 100644
--- a/dashboard/pages/inference.py
+++ b/dashboard/pages/inference.py
@@ -4,9 +4,8 @@
import streamlit as st
import pandas as pd
-from common import init_page
+from common import init_page, show_data_source_info
from components.header import render_header
-from utils.data_loader import InfiniMetricsDataLoader, get_friendly_size
from utils.visualizations import (
plot_timeseries_auto,
create_summary_table_infer,
@@ -19,8 +18,9 @@ def main():
render_header()
st.markdown("## 🚀 推理性能测试分析")
- dl = st.session_state.data_loader
- runs = dl.list_test_runs("infer")
+ show_data_source_info()
+
+ runs = st.session_state.data_loader.list_test_runs("infer")
if not runs:
st.info("未找到推理测试结果(testcase 需以 infer.* 开头)。")
@@ -91,7 +91,9 @@ def _mode_of(r):
selected_runs = []
for k in selected:
ri = filtered[options[k]]
- data = dl.load_test_result(ri["path"])
+ # Use path for file source, run_id for MongoDB
+ identifier = ri.get("path") or ri.get("run_id")
+ data = st.session_state.data_loader.load_test_result(identifier)
ri = dict(ri)
ri["data"] = data
selected_runs.append(ri)
diff --git a/dashboard/pages/operator.py b/dashboard/pages/operator.py
index a53fd38..5fdfd78 100644
--- a/dashboard/pages/operator.py
+++ b/dashboard/pages/operator.py
@@ -4,7 +4,7 @@
import streamlit as st
import pandas as pd
-from common import init_page
+from common import init_page, show_data_source_info
from components.header import render_header
from utils.visualizations import (
create_summary_table_ops,
@@ -18,24 +18,18 @@ def main():
render_header()
st.markdown("## ⚡ 算子测试分析")
- dl = st.session_state.data_loader
+ show_data_source_info()
- runs = dl.list_test_runs() # Load all test runs first
- # Identify operator runs by checking "operators" in path or testcase starting with operator/ops
+ runs = st.session_state.data_loader.list_test_runs()
+ # Identify operator runs by testcase starting with operator/ops
ops_runs = []
for r in runs:
- p = str(r.get("path", ""))
tc = (r.get("testcase") or "").lower()
- if (
- ("/operators/" in p.replace("\\", "/"))
- or tc.startswith("operator")
- or tc.startswith("operators")
- or tc.startswith("ops")
- ):
+ if tc.startswith(("operator", "operators", "ops")):
ops_runs.append(r)
if not ops_runs:
- st.info("未找到算子测试结果(请确认 JSON 在 test_output/operators/ 下)。")
+ st.info("未找到算子测试结果(请确认 JSON 在 output/operators 目录下)。")
return
with st.sidebar:
@@ -62,7 +56,9 @@ def main():
selected_runs = []
for k in selected:
ri = filtered[options[k]]
- data = dl.load_test_result(ri["path"])
+ # Use path for file source, run_id for MongoDB
+ identifier = ri.get("path") or ri.get("run_id")
+ data = st.session_state.data_loader.load_test_result(identifier)
ri = dict(ri)
ri["data"] = data
selected_runs.append(ri)
diff --git a/dashboard/utils/data_loader.py b/dashboard/utils/data_loader.py
index f2b9c66..5d6cccf 100644
--- a/dashboard/utils/data_loader.py
+++ b/dashboard/utils/data_loader.py
@@ -1,277 +1,168 @@
#!/usr/bin/env python3
-"""Data loading utilities for InfiniMetrics dashboard."""
+"""Unified data loader for InfiniMetrics dashboard."""
-import json
-import csv
-import pandas as pd
-from pathlib import Path
-from typing import Dict, List, Any, Optional, Tuple
import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional
-logger = logging.getLogger(__name__)
-
-
-class InfiniMetricsDataLoader:
- """Load and parse InfiniMetrics test results."""
-
- def __init__(self, results_dir: str = "./test_output"):
- self.results_dir = Path(results_dir)
-
- def list_test_runs(self, test_type: str = None) -> List[Dict[str, Any]]:
- """List all test runs, filtering out summary files."""
- runs = []
-
- # Search for JSON result files
- for json_file in self.results_dir.rglob("*.json"):
- try:
- # Skip summary files and dispatcher files
- if (
- "summary" in json_file.name.lower()
- or "dispatcher" in json_file.name.lower()
- ):
- continue
-
- with open(json_file, "r", encoding="utf-8") as f:
- data = json.load(f)
-
- # Filter: must be a test result file, not a summary file
- if not self._is_test_result_file(data):
- continue
-
- # Filter by test type if specified
- testcase = data.get("testcase", "")
- if test_type and not testcase.startswith(test_type):
- continue
-
- # Extract basic info
- run_info = self._extract_run_info(data, json_file)
-
- # Extract the accelerator card type
- run_info["accelerator_types"] = extract_accelerator_types(data)
- runs.append(run_info)
-
- except Exception as e:
- logger.debug(f"Skipping file {json_file}: {e}")
-
- # Sort by time (newest first)
- runs.sort(key=lambda x: x["time"], reverse=True)
- return runs
-
- def load_test_result(self, json_path: Path) -> Dict[str, Any]:
- """Load a single test result with all data."""
- with open(json_path, "r", encoding="utf-8") as f:
- data = json.load(f)
-
- # Load associated CSV files
- for metric in data.get("metrics", []):
- csv_url = metric.get("raw_data_url")
- if csv_url and not csv_url.startswith("http"):
- # Get the correct base directory
- base_dir = self._get_csv_base_dir(data, json_path)
- csv_path = self._resolve_csv_path(csv_url, base_dir)
-
- if csv_path and csv_path.exists():
- try:
- df = pd.read_csv(csv_path)
- metric["data"] = df
- metric["data_columns"] = list(df.columns)
- metric["csv_path"] = str(csv_path)
- except Exception as e:
- logger.warning(f"Failed to load CSV {csv_path}: {e}")
- metric["data"] = None
- else:
- logger.debug(f"CSV not found: {csv_url} (base: {base_dir})")
- metric["data"] = None
-
- return data
-
- def load_csv_data(
- self, csv_url: str, json_data: Dict[str, Any], json_path: Path
- ) -> Optional[pd.DataFrame]:
- """Load CSV data file using proper path resolution."""
- try:
- if csv_url.startswith("http"):
- return None
-
- base_dir = self._get_csv_base_dir(json_data, json_path)
- csv_path = self._resolve_csv_path(csv_url, base_dir)
+import pandas as pd
- if csv_path and csv_path.exists():
- return pd.read_csv(csv_path)
- except Exception as e:
- logger.error(f"Failed to load CSV {csv_url}: {e}")
- return None
+from .data_sources import DataSource, FileDataSource, MongoDataSource
+from .data_utils import extract_accelerator_types, extract_run_info, get_friendly_size
- def _is_test_result_file(self, data: Dict[str, Any]) -> bool:
- """Check if JSON file is a test result (not a summary)."""
- # Must have these fields
- required = ["run_id", "testcase", "config"]
- if not all(key in data for key in required):
- return False
+logger = logging.getLogger(__name__)
- # Should have metrics
- if "metrics" not in data:
- return False
- return True
+class InfiniMetricsDataLoader:
+ """
+ Unified data loader supporting multiple sources.
- def _extract_run_info(
- self, data: Dict[str, Any], json_path: Path
- ) -> Dict[str, Any]:
- """Extract run info from test result data."""
- config = data.get("config", {})
- resolved = data.get("resolved", {})
+ Supports:
+ - File-based loading (default)
+ - MongoDB-based loading
+ - Automatic fallback from MongoDB to files
+ """
- # Device used: try resolved first, then config
- device_used = (
- resolved.get("device_used")
- or config.get("device_used")
- or config.get("device_involved", 1)
- )
+ def __init__(
+ self,
+ results_dir: str = "../output",
+ use_mongodb: bool = False,
+ mongo_config=None,
+ fallback_to_files: bool = True,
+ ):
+ """
+ Initialize the data loader.
- # Nodes: try resolved first, then environment
- nodes = resolved.get("nodes") or data.get("environment", {}).get(
- "cluster_scale", 1
+ Args:
+ results_dir: Directory containing test result files
+ use_mongodb: If True, use MongoDB as primary data source
+ mongo_config: Optional MongoDB configuration
+ fallback_to_files: If True, fall back to file loading if MongoDB fails
+ """
+ self.results_dir = Path(results_dir)
+ self._fallback_to_files = fallback_to_files
+ self._use_mongodb = use_mongodb
+ self._mongo_config = mongo_config
+ self._source: Optional[DataSource] = None
+
+ if use_mongodb:
+ self._init_mongodb_source()
+ else:
+ self._source = FileDataSource(results_dir)
+
+ def _init_mongodb_source(self):
+ """Initialize MongoDB data source with optional fallback."""
+ mongo_source = MongoDataSource(self._mongo_config)
+
+ if mongo_source.is_connected():
+ self._source = mongo_source
+ elif self._fallback_to_files:
+ logger.warning("MongoDB unavailable, falling back to file-based loading")
+ self._source = FileDataSource(str(self.results_dir))
+ self._use_mongodb = False
+ else:
+ raise RuntimeError("MongoDB connection failed and fallback is disabled")
+
+ @property
+ def source_type(self) -> str:
+ """Get the current data source type."""
+ return self._source.source_type if self._source else "none"
+
+ @property
+ def is_connected(self) -> bool:
+ """Check if data source is available."""
+ return self._source is not None
+
+ @property
+ def is_using_mongodb(self) -> bool:
+ """Check if currently using MongoDB."""
+        return bool(
+            self._use_mongodb and self._source and self._source.source_type == "mongodb"
+        )
- # Success: use result_code if available, fallback to success field
- result_code = data.get("result_code", 1)
- success = result_code == 0
-
- # Extract metrics count and types
- metrics = data.get("metrics", [])
- metric_types = [
- m.get("name", "").split(".")[0] for m in metrics if m.get("name")
- ]
-
- return {
- "path": json_path,
- "testcase": data.get("testcase", "unknown"),
- "run_id": data.get("run_id", "unknown"),
- "time": data.get("time", ""),
- "success": success,
- "result_code": result_code,
- "test_type": self._extract_test_type(data.get("testcase", "")),
- "operation": self._extract_operation(data.get("testcase", "")),
- "config": config,
- "resolved": resolved,
- "device_used": device_used,
- "nodes": nodes,
- "metrics_count": len(metrics),
- "metric_types": list(set(metric_types)),
- }
-
- def _get_csv_base_dir(self, json_data: Dict[str, Any], json_path: Path) -> Path:
- """Get the correct base directory for CSV files."""
- # First try: use output_dir from config
- config = json_data.get("config", {})
- output_dir = config.get("output_dir")
-
- if output_dir:
- output_path = Path(output_dir)
- if output_path.is_absolute():
- return output_path
- # Relative path: resolve relative to JSON file location
- return json_path.parent / output_dir
-
- # Second try: use JSON file's parent directory
- return json_path.parent
-
- def _resolve_csv_path(self, csv_url: str, base_dir: Path) -> Optional[Path]:
+ def switch_to_mongodb(self, mongo_config=None) -> bool:
"""
- Resolve CSV path from raw_data_url and base_dir.
+ Switch to MongoDB data source.
- Handles cases like:
- - base_dir/output/communication + "./comm/xxx.csv" but file is actually base_dir/"xxx.csv"
- - base_dir/output/infer + "./infer/xxx.csv" but file is base_dir/"xxx.csv"
+ Returns:
+ True if switch was successful
"""
- try:
- if not csv_url:
- return None
-
- # strip leading "./"
- rel = csv_url[2:] if csv_url.startswith("./") else csv_url
- rel_path = Path(rel)
-
- candidates = []
+ if mongo_config:
+ self._mongo_config = mongo_config
- # 1) base_dir / rel
- candidates.append(base_dir / rel_path)
+ mongo_source = MongoDataSource(self._mongo_config)
- # 2) base_dir / basename (most common fallback for your current layout)
- candidates.append(base_dir / rel_path.name)
-
- # 3) base_dir.parent / rel (just in case)
- candidates.append(base_dir.parent / rel_path)
-
- # 4) base_dir.parent / basename
- candidates.append(base_dir.parent / rel_path.name)
-
- for p in candidates:
- if p.exists():
- return p
+ if mongo_source.is_connected():
+ self._source = mongo_source
+ self._use_mongodb = True
+ return True
+ elif self._fallback_to_files:
+ logger.warning("Failed to switch to MongoDB, keeping current source")
+ return False
+ else:
+ raise RuntimeError("MongoDB connection failed")
- return None
- except Exception:
- return None
+ def switch_to_files(self, results_dir: str = None):
+ """Switch to file-based data source."""
+ if results_dir:
+ self.results_dir = Path(results_dir)
+ self._source = FileDataSource(str(self.results_dir))
+ self._use_mongodb = False
- def _extract_test_type(self, testcase: str) -> str:
- """Extract test type from testcase string."""
- parts = testcase.split(".")
- if len(parts) > 0:
- return parts[0] # comm, infer, operator, etc.
- return "unknown"
+ def list_test_runs(self, test_type: str = None) -> List[Dict[str, Any]]:
+ """List all test runs."""
+ if self._source is None:
+ return []
+ return self._source.list_test_runs(test_type)
- def _extract_operation(self, testcase: str) -> str:
- """Extract operation from testcase string."""
- parts = testcase.split(".")
- if len(parts) > 2:
- return parts[2] # AllReduce, Direct, Conv, etc.
- return "unknown"
+ def load_test_result(self, identifier) -> Dict[str, Any]:
+ """
+ Load a single test result with all data.
+ Args:
+ identifier: For file source, a Path to JSON file.
+ For MongoDB source, a run_id string.
+ """
+ if self._source is None:
+ return {}
+ return self._source.load_test_result(identifier)
-def load_summary_file(summary_path: str = "./summary_output") -> List[Dict[str, Any]]:
- """Load dispatcher summary files."""
- summaries = []
- summary_dir = Path(summary_path)
+ def load_summaries(self) -> List[Dict[str, Any]]:
+ """Load dispatcher summaries from the current data source."""
+ if self._source is None:
+ return []
+ return self._source.load_summaries()
- if summary_dir.exists():
- for json_file in sorted(
- summary_dir.glob("dispatcher_summary_*.json"), reverse=True
- ):
+ def load_csv_data(
+ self, csv_url: str, json_data: Dict[str, Any], json_path: Path
+ ) -> Optional[pd.DataFrame]:
+ """Load CSV data file using proper path resolution (file source only)."""
+ if isinstance(self._source, FileDataSource):
try:
- with open(json_file, "r", encoding="utf-8") as f:
- data = json.load(f)
- data["file"] = json_file.name
- data["timestamp"] = json_file.stem.replace("dispatcher_summary_", "")
- summaries.append(data)
- except Exception as e:
- logger.warning(f"Failed to load summary {json_file}: {e}")
+ if csv_url.startswith("http"):
+ return None
- return summaries
+ base_dir = self._source._get_csv_base_dir(json_data, json_path)
+ csv_path = self._source._resolve_csv_path(csv_url, base_dir)
-
-def get_friendly_size(size_bytes: int) -> str:
- """Convert bytes to human-readable size."""
- for unit in ["B", "KB", "MB", "GB", "TB"]:
- if size_bytes < 1024.0:
- return f"{size_bytes:.1f} {unit}"
- size_bytes /= 1024.0
- return f"{size_bytes:.1f} PB"
+ if csv_path and csv_path.exists():
+ return pd.read_csv(csv_path)
+ except Exception as e:
+ logger.error(f"Failed to load CSV {csv_url}: {e}")
+ return None
-def extract_accelerator_types(result_json: dict) -> list[str]:
- """
- Extract the accelerator card type from result_json
- """
- types = set()
- try:
- clusters = result_json.get("environment", {}).get("cluster", [])
- for node in clusters:
- accs = node.get("machine", {}).get("accelerators", [])
- for acc in accs:
- if "type" in acc:
- types.add(acc["type"])
- except Exception:
- pass
- return list(types)
+# Re-export from sibling modules (duplicates the top-of-file imports; kept so existing `from utils.data_loader import ...` call sites keep working)
+from .data_sources import DataSource, FileDataSource, MongoDataSource
+from .data_utils import (
+ get_friendly_size,
+ extract_accelerator_types,
+ extract_run_info,
+)
+
+__all__ = [
+ "InfiniMetricsDataLoader",
+ "get_friendly_size",
+ "extract_accelerator_types",
+ "extract_run_info",
+]
diff --git a/dashboard/utils/data_sources.py b/dashboard/utils/data_sources.py
new file mode 100644
index 0000000..03607af
--- /dev/null
+++ b/dashboard/utils/data_sources.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+"""Data source implementations for InfiniMetrics dashboard."""
+
+import json
+import logging
+import sys
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import pandas as pd
+
+from .data_utils import extract_accelerator_types, extract_run_info, load_summary_file
+from db.utils import get_csv_base_dir, resolve_csv_path  # NOTE(review): assumes project root is already on sys.path (the dashboard entrypoint appends it) — confirm for standalone imports
+
+logger = logging.getLogger(__name__)
+
+
+class DataSource(ABC):
+ """Abstract data source for test results."""
+
+ @abstractmethod
+ def list_test_runs(self, test_type: str = None) -> List[Dict[str, Any]]:
+ """List all test runs."""
+ pass
+
+ @abstractmethod
+ def load_test_result(self, identifier) -> Dict[str, Any]:
+ """Load a single test result with full data."""
+ pass
+
+ @abstractmethod
+ def load_summaries(self) -> List[Dict[str, Any]]:
+ """Load dispatcher summaries."""
+ pass
+
+ @property
+ @abstractmethod
+ def source_type(self) -> str:
+ """Return the data source type name."""
+ pass
+
+
+class FileDataSource(DataSource):
+ """File-based data source (reads from JSON/CSV files)."""
+
+ def __init__(self, results_dir: str = "../output"):
+ self.results_dir = Path(results_dir)
+
+ @property
+ def source_type(self) -> str:
+ return "file"
+
+ def list_test_runs(self, test_type: str = None) -> List[Dict[str, Any]]:
+ """List all test runs, filtering out summary files."""
+ runs = []
+
+ for json_file in self.results_dir.rglob("*.json"):
+ try:
+ if (
+ "summary" in json_file.name.lower()
+ or "dispatcher" in json_file.name.lower()
+ ):
+ continue
+
+ with open(json_file, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ if not self._is_test_result_file(data):
+ continue
+
+ testcase = data.get("testcase", "")
+ if test_type and not testcase.startswith(test_type):
+ continue
+
+ run_info = extract_run_info(data, json_file)
+ run_info["accelerator_types"] = extract_accelerator_types(data)
+ runs.append(run_info)
+
+ except Exception as e:
+ logger.debug(f"Skipping file {json_file}: {e}")
+
+ runs.sort(key=lambda x: x["time"], reverse=True)
+ return runs
+
+ def load_test_result(self, json_path: Path) -> Dict[str, Any]:
+ """Load a single test result with all data."""
+ with open(json_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ for metric in data.get("metrics", []):
+ csv_url = metric.get("raw_data_url")
+ if csv_url and not csv_url.startswith("http"):
+ base_dir = get_csv_base_dir(data, json_path)
+ csv_path = resolve_csv_path(csv_url, base_dir)
+
+ if csv_path and csv_path.exists():
+ try:
+ df = pd.read_csv(csv_path)
+ metric["data"] = df
+ metric["data_columns"] = list(df.columns)
+ metric["csv_path"] = str(csv_path)
+ except Exception as e:
+ logger.warning(f"Failed to load CSV {csv_path}: {e}")
+ metric["data"] = None
+ else:
+ logger.debug(f"CSV not found: {csv_url} (base: {base_dir})")
+ metric["data"] = None
+
+ return data
+
+ def _is_test_result_file(self, data: Dict[str, Any]) -> bool:
+ """Check if JSON file is a test result (not a summary)."""
+ required = ["run_id", "testcase", "config"]
+ return all(key in data for key in required) and "metrics" in data
+
+ def load_summaries(self) -> List[Dict[str, Any]]:
+ """Load dispatcher summary files from summary_output directory."""
+        summary_dir = self.results_dir.parent / "summary_output"  # assumes summary_output sits beside the results dir — confirm deployment layout
+ return load_summary_file(str(summary_dir))
+
+
+class MongoDataSource(DataSource):
+ """MongoDB-based data source."""
+
+ def __init__(self, config=None):
+ self._config = config
+ self._client = None
+ self._repository = None
+ self._connected = False
+
+ def _connect(self):
+ """Lazy connection to MongoDB."""
+ if self._connected:
+ return self._connected
+
+ try:
+ project_root = Path(__file__).parent.parent.parent
+ if str(project_root) not in sys.path:
+ sys.path.insert(0, str(project_root))
+
+ from db import MongoDBClient, TestRunRepository
+
+ if self._config:
+ self._client = MongoDBClient(self._config)
+ else:
+ self._client = MongoDBClient()
+
+ if self._client.health_check():
+ from db.config import DatabaseConfig
+
+ config = self._config or DatabaseConfig.from_env()
+ self._repository = TestRunRepository(
+ self._client.get_collection(config.collection_name)
+ )
+ self._connected = True
+ logger.info("Connected to MongoDB data source")
+ else:
+ logger.warning("MongoDB health check failed")
+
+ except Exception as e:
+ logger.warning(f"Failed to connect to MongoDB: {e}")
+ self._connected = False
+
+ return self._connected
+
+ @property
+ def source_type(self) -> str:
+ return "mongodb"
+
+ def is_connected(self) -> bool:
+ """Check if MongoDB is connected."""
+ return self._connected or self._connect()
+
+ def list_test_runs(self, test_type: str = None) -> List[Dict[str, Any]]:
+ """List all test runs from MongoDB."""
+ if not self._connect():
+ logger.warning("MongoDB not connected, returning empty list")
+ return []
+
+ runs = self._repository.list_test_runs(test_type=test_type)
+ result = []
+
+ for run in runs:
+ run_info = extract_run_info(run)
+ run_info["accelerator_types"] = extract_accelerator_types(run)
+ result.append(run_info)
+
+ result.sort(key=lambda x: x["time"], reverse=True)
+ return result
+
+ def load_test_result(self, run_id: str) -> Dict[str, Any]:
+ """Load a single test result with full data from MongoDB."""
+ if not self._connect():
+ logger.warning("MongoDB not connected")
+ return {}
+
+ data = self._repository.find_by_run_id(run_id)
+ if not data:
+ return {}
+
+ for metric in data.get("metrics", []):
+ if "data" in metric and isinstance(metric["data"], list):
+ if metric["data"]:
+ metric["data"] = pd.DataFrame(metric["data"])
+ if "data_columns" not in metric:
+ metric["data_columns"] = list(metric["data"].columns)
+
+ data.pop("_id", None)
+ data.pop("_metadata", None)
+
+ return data
+
+ def load_summaries(self) -> List[Dict[str, Any]]:
+ """Load dispatcher summaries from MongoDB."""
+ if not self._connect():
+ logger.warning("MongoDB not connected, returning empty list")
+ return []
+
+ try:
+ from db import DispatcherSummaryRepository
+ from db.config import DatabaseConfig
+
+ config = self._config or DatabaseConfig.from_env()
+ summary_collection = self._client.get_collection(
+ config.summary_collection_name
+ )
+ summary_repo = DispatcherSummaryRepository(summary_collection)
+ summaries = summary_repo.list_summaries()
+
+ for s in summaries:
+ s.pop("_id", None)
+ s.pop("_metadata", None)
+
+ return summaries
+ except Exception as e:
+ logger.warning(f"Failed to load summaries from MongoDB: {e}")
+ return []
diff --git a/dashboard/utils/data_utils.py b/dashboard/utils/data_utils.py
new file mode 100644
index 0000000..33e6e50
--- /dev/null
+++ b/dashboard/utils/data_utils.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+"""Data utility functions for InfiniMetrics dashboard."""
+import json
+import logging
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+
+def load_summary_file(summary_path: str = "../summary_output") -> List[Dict[str, Any]]:
+ """Load dispatcher summary files."""
+
+ logger = logging.getLogger(__name__)
+ summaries = []
+ summary_dir = Path(summary_path)
+
+ if summary_dir.exists():
+ for json_file in sorted(
+ summary_dir.glob("dispatcher_summary_*.json"), reverse=True
+ ):
+ try:
+ with open(json_file, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ data["file"] = json_file.name
+ data["timestamp"] = json_file.stem.replace("dispatcher_summary_", "")
+ summaries.append(data)
+ except Exception as e:
+ logger.warning(f"Failed to load summary {json_file}: {e}")
+
+ return summaries
+
+
+def get_friendly_size(size_bytes: int) -> str:
+ """Convert bytes to human-readable size."""
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
+ if size_bytes < 1024.0:
+ return f"{size_bytes:.1f} {unit}"
+ size_bytes /= 1024.0
+ return f"{size_bytes:.1f} PB"
+
+
+def extract_accelerator_types(result_json: dict) -> list[str]:
+ """Extract the accelerator card type from result_json."""
+ types = set()
+ try:
+ clusters = result_json.get("environment", {}).get("cluster", [])
+ for node in clusters:
+ accs = node.get("machine", {}).get("accelerators", [])
+ for acc in accs:
+ if "type" in acc:
+ types.add(acc["type"])
+ except Exception:
+ pass
+ return list(types)
+
+
+def extract_test_type(testcase: str) -> str:
+ """Extract test type from testcase string (e.g., 'comm.nccl.allreduce' -> 'comm')."""
+ parts = testcase.split(".")
+ return parts[0] if parts else "unknown"
+
+
+def extract_operation(testcase: str) -> str:
+ """Extract operation from testcase string (e.g., 'comm.nccl.allreduce' -> 'allreduce')."""
+ parts = testcase.split(".")
+ return parts[2] if len(parts) > 2 else "unknown"
+
+
+def extract_run_info(data: Dict[str, Any], path: Path = None) -> Dict[str, Any]:
+ """
+ Extract run info from test result data.
+
+ Args:
+ data: Test result JSON data
+ path: Optional file path (for file-based sources)
+
+ Returns:
+ Dictionary with extracted run information
+ """
+ config = data.get("config", {})
+ resolved = data.get("resolved", {})
+
+ device_used = (
+ resolved.get("device_used")
+ or config.get("device_used")
+ or config.get("device_involved", 1)
+ )
+
+ nodes = resolved.get("nodes") or data.get("environment", {}).get("cluster_scale", 1)
+
+ result_code = data.get("result_code", 1)
+ success = result_code == 0
+
+ metrics = data.get("metrics", [])
+ metric_types = [m.get("name", "").split(".")[0] for m in metrics if m.get("name")]
+
+ testcase = data.get("testcase", "")
+
+ return {
+ "path": path,
+ "testcase": testcase,
+ "run_id": data.get("run_id", "unknown"),
+ "time": data.get("time", ""),
+ "success": success,
+ "result_code": result_code,
+ "test_type": extract_test_type(testcase),
+ "operation": extract_operation(testcase),
+ "config": config,
+ "resolved": resolved,
+ "device_used": device_used,
+ "nodes": nodes,
+ "metrics_count": len(metrics),
+ "metric_types": list(set(metric_types)),
+ }
diff --git a/db/__init__.py b/db/__init__.py
index b5e4703..973a712 100644
--- a/db/__init__.py
+++ b/db/__init__.py
@@ -6,6 +6,7 @@
- MongoDB connection management
- Test result repository
- Dispatcher summary repository
+- Data import from JSON/CSV files
Usage:
from db import MongoDBClient, DatabaseConfig, TestRunRepository
@@ -18,12 +19,15 @@
test_runs = TestRunRepository(client.get_collection("test_runs"))
summaries = DispatcherSummaryRepository(client.get_collection("dispatcher_summaries"))
- # Query test runs
- runs = test_runs.find_all(limit=10)
+ # Import data
+ from db import DataImporter
+ importer = DataImporter(test_runs)
+ importer.import_directory(Path("./output"))
"""
from .client import MongoDBClient, MongoDBConnectionError
from .config import DatabaseConfig
+from .importer import DataImporter
from .repository import DispatcherSummaryRepository, TestRunRepository
__all__ = [
@@ -32,4 +36,5 @@
"MongoDBConnectionError",
"TestRunRepository",
"DispatcherSummaryRepository",
+ "DataImporter",
]
diff --git a/db/importer.py b/db/importer.py
new file mode 100644
index 0000000..12b6186
--- /dev/null
+++ b/db/importer.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+"""Data importer for loading JSON/CSV test results into MongoDB."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+from .repository import TestRunRepository
+from .utils import (
+ get_csv_base_dir,
+ is_dispatcher_summary,
+ is_dispatcher_summary_file,
+ is_valid_test_result,
+ load_csv_data,
+ resolve_csv_path,
+ resolve_result_file_path,
+ should_skip_file,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DataImporter:
+ """
+ Import JSON/CSV test results to MongoDB.
+
+ Supports hierarchical structure:
+ - Dispatcher summary files (summary_output/dispatcher_summary_*.json)
+ - Individual test result files (output/*_results.json)
+ """
+
+ # Return status constants
+ STATUS_IMPORTED = "imported"
+ STATUS_SKIPPED = "skipped"
+ STATUS_FAILED = "failed"
+
+ def __init__(self, repository: TestRunRepository, base_dir: Optional[Path] = None):
+ self._repository = repository
+ self._base_dir = Path(base_dir) if base_dir else Path.cwd()
+
+ def import_dispatcher_summary(
+ self, summary_path: Path, overwrite: bool = False
+ ) -> Dict[str, Any]:
+ """Import a dispatcher summary file and all referenced test results."""
+ summary: Dict[str, Any] = {
+ "imported": [],
+ "skipped": [],
+ "failed": [],
+ "summary_file": str(summary_path),
+ }
+
+ try:
+ with open(summary_path, "r", encoding="utf-8") as f:
+ summary_data = json.load(f)
+
+ if not is_dispatcher_summary(summary_data):
+ logger.debug(f"Not a dispatcher summary: {summary_path}")
+ return summary
+
+ logger.info(
+ f"Processing dispatcher summary: {summary_path} "
+ f"({summary_data.get('total_tests', 0)} tests)"
+ )
+
+ for result_info in summary_data.get("results", []):
+ result_file = result_info.get("result_file")
+ if not result_file:
+ continue
+
+ result_path = resolve_result_file_path(
+ result_file, summary_path, self._base_dir
+ )
+
+ if not result_path or not result_path.exists():
+ logger.warning(f"Result file not found: {result_file}")
+ summary["failed"].append(result_file)
+ continue
+
+ imported_run_id, status = self.import_test_result(
+ result_path,
+ dispatcher_info={
+ "summary_file": str(summary_path),
+ "summary_timestamp": summary_data.get("timestamp"),
+ "total_tests": summary_data.get("total_tests"),
+ },
+ overwrite=overwrite,
+ )
+
+ if status == self.STATUS_IMPORTED:
+ summary["imported"].append(imported_run_id)
+ elif status == self.STATUS_SKIPPED:
+ summary["skipped"].append(
+ imported_run_id or result_info.get("run_id")
+ )
+ else:
+ summary["failed"].append(str(result_path))
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid JSON in {summary_path}: {e}")
+ except Exception as e:
+ logger.error(f"Failed to process summary {summary_path}: {e}")
+
+ logger.info(
+ f"Summary processed: {len(summary['imported'])} imported, "
+ f"{len(summary['skipped'])} skipped, {len(summary['failed'])} failed"
+ )
+ return summary
+
+ def import_test_result(
+ self,
+ result_path: Path,
+ dispatcher_info: Optional[Dict[str, Any]] = None,
+ overwrite: bool = False,
+    ) -> tuple[Optional[str], str]:
+ """
+ Import a single test result file to MongoDB.
+
+ Returns:
+            Tuple of (run_id or None, status) where status is one of:
+ - "imported": Successfully imported
+ - "skipped": Already exists or not a test result
+ - "failed": Error during import
+ """
+ try:
+ with open(result_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ if not is_valid_test_result(data):
+ logger.debug(f"Skipping non-test file: {result_path}")
+ return (None, self.STATUS_SKIPPED)
+
+ run_id = data.get("run_id")
+ if not run_id:
+ logger.warning(f"No run_id in {result_path}")
+ return (None, self.STATUS_SKIPPED)
+
+ if self._repository.exists(run_id) and not overwrite:
+ logger.debug(f"Skipping existing run_id: {run_id}")
+ return (run_id, self.STATUS_SKIPPED)
+
+ self._embed_csv_data(data, result_path)
+
+ data.setdefault("_metadata", {})
+ data["_metadata"]["source_file"] = str(result_path)
+ data["_metadata"]["version"] = "1.0"
+
+ if dispatcher_info:
+ data["_metadata"]["dispatcher"] = dispatcher_info
+
+ if overwrite and self._repository.exists(run_id):
+ inserted_id = self._repository.upsert(data)
+ else:
+ inserted_id = self._repository.insert(data)
+
+ return (run_id, self.STATUS_IMPORTED if inserted_id else self.STATUS_FAILED)
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid JSON in {result_path}: {e}")
+ return (None, self.STATUS_FAILED)
+ except Exception as e:
+ logger.error(f"Failed to import {result_path}: {e}")
+ return (None, self.STATUS_FAILED)
+
+ def import_directory(
+ self,
+ directory: Path,
+ recursive: bool = True,
+ overwrite: bool = False,
+ include_summaries: bool = True,
+ ) -> Dict[str, Any]:
+ """Import all JSON files from a directory."""
+ summary: Dict[str, Any] = {"imported": [], "skipped": [], "failed": []}
+
+ directory = Path(directory)
+ if not directory.exists():
+ logger.error(f"Directory not found: {directory}")
+ return summary
+
+ pattern = "**/*.json" if recursive else "*.json"
+ json_files = list(directory.glob(pattern))
+
+ logger.info(f"Found {len(json_files)} JSON files in {directory}")
+
+ for json_file in json_files:
+ if is_dispatcher_summary_file(json_file):
+ if include_summaries:
+ result = self.import_dispatcher_summary(json_file, overwrite)
+ summary["imported"].extend(result["imported"])
+ summary["skipped"].extend(result["skipped"])
+ summary["failed"].extend(result["failed"])
+ continue
+
+ if should_skip_file(json_file):
+ continue
+
+ run_id, status = self.import_test_result(json_file, overwrite=overwrite)
+ if status == self.STATUS_IMPORTED:
+ summary["imported"].append(run_id)
+ elif status == self.STATUS_SKIPPED:
+ if run_id:
+ summary["skipped"].append(run_id)
+ else:
+ summary["failed"].append(str(json_file))
+
+ logger.info(
+ f"Import completed: {len(summary['imported'])} imported, "
+ f"{len(summary['skipped'])} skipped, {len(summary['failed'])} failed"
+ )
+ return summary
+
+ def import_all(
+ self,
+ output_dir: Optional[Path] = None,
+ summary_dir: Optional[Path] = None,
+ overwrite: bool = False,
+ ) -> Dict[str, Any]:
+ """Import from both output and summary directories."""
+ combined: Dict[str, Any] = {"imported": [], "skipped": [], "failed": []}
+
+ if summary_dir:
+ summary_dir = Path(summary_dir)
+ if summary_dir.exists():
+ logger.info(f"Importing from summary directory: {summary_dir}")
+ for summary_file in sorted(
+ summary_dir.glob("dispatcher_summary_*.json")
+ ):
+ result = self.import_dispatcher_summary(summary_file, overwrite)
+ combined["imported"].extend(result["imported"])
+ combined["skipped"].extend(result["skipped"])
+ combined["failed"].extend(result["failed"])
+
+ if output_dir:
+ output_dir = Path(output_dir)
+ if output_dir.exists():
+ logger.info(f"Importing from output directory: {output_dir}")
+ result = self.import_directory(
+ output_dir,
+ recursive=True,
+ overwrite=overwrite,
+ include_summaries=False,
+ )
+ for run_id in result["imported"]:
+ if run_id not in combined["imported"]:
+ combined["imported"].append(run_id)
+ combined["skipped"].extend(result["skipped"])
+ combined["failed"].extend(result["failed"])
+
+ logger.info(
+ f"Total import: {len(combined['imported'])} imported, "
+ f"{len(combined['skipped'])} skipped, {len(combined['failed'])} failed"
+ )
+ return combined
+
+ def import_json_file(
+ self, json_path: Path, overwrite: bool = False
+ ) -> Optional[str]:
+ """Import a single JSON result file (legacy method)."""
+ run_id, status = self.import_test_result(json_path, overwrite=overwrite)
+ return run_id if status == self.STATUS_IMPORTED else None
+
+ def _embed_csv_data(self, data: Dict[str, Any], json_path: Path) -> None:
+ """Load CSV files and embed data into metrics."""
+ base_dir = get_csv_base_dir(data, json_path)
+
+ for metric in data.get("metrics", []):
+ csv_url = metric.get("raw_data_url")
+ if csv_url and not csv_url.startswith("http"):
+ csv_path = resolve_csv_path(csv_url, base_dir)
+ if csv_path and csv_path.exists():
+ try:
+ csv_data = load_csv_data(csv_path)
+ metric["data"] = csv_data
+ metric["data_columns"] = (
+ list(csv_data[0].keys()) if csv_data else []
+ )
+ logger.debug(f"Embedded CSV data from {csv_path}")
+ except Exception as e:
+ logger.warning(f"Failed to load CSV {csv_path}: {e}")
diff --git a/db/utils.py b/db/utils.py
new file mode 100644
index 0000000..f950f7f
--- /dev/null
+++ b/db/utils.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+"""Utility functions for file handling and data processing."""
+
+import csv
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+# ==================== CSV Utilities ====================
+
+
+def load_csv_data(csv_path: Path) -> List[Dict[str, Any]]:
+ """Load CSV file as list of dictionaries."""
+ with open(csv_path, "r", encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ return [convert_csv_row(row) for row in reader]
+
+
+def convert_csv_row(row: Dict[str, str]) -> Dict[str, Any]:
+ """Convert CSV row values to appropriate types."""
+ result = {}
+ for key, value in row.items():
+ try:
+ if "." in value:
+ result[key] = float(value)
+ else:
+ result[key] = int(value)
+ except (ValueError, TypeError):
+ result[key] = value
+ return result
+
+
+def resolve_csv_path(csv_url: str, base_dir: Path) -> Optional[Path]:
+ """Resolve CSV path with fallback strategies."""
+ if not csv_url:
+ return None
+
+ rel = csv_url[2:] if csv_url.startswith("./") else csv_url
+ rel_path = Path(rel)
+
+ candidates = [
+ base_dir / rel_path,
+ base_dir / rel_path.name,
+ base_dir.parent / rel_path,
+ base_dir.parent / rel_path.name,
+ ]
+
+ for p in candidates:
+ if p.exists():
+ return p
+ return None
+
+
+def get_csv_base_dir(data: Dict[str, Any], json_path: Path) -> Path:
+ """Get base directory for CSV resolution."""
+ config = data.get("config", {})
+ output_dir = config.get("output_dir")
+ if output_dir:
+ output_path = Path(output_dir)
+ if output_path.is_absolute():
+ return output_path
+ return json_path.parent / output_dir
+ return json_path.parent
+
+
+# ==================== File Type Detection ====================
+
+
+def is_valid_test_result(data: Dict[str, Any]) -> bool:
+ """Check if data is a valid test result."""
+ required = ["run_id", "testcase", "config"]
+ return all(k in data for k in required) and "metrics" in data
+
+
+def is_dispatcher_summary(data: Dict[str, Any]) -> bool:
+ """Check if data is a dispatcher summary file."""
+ return "results" in data and "total_tests" in data
+
+
+def is_dispatcher_summary_file(path: Path) -> bool:
+ """Check if file is a dispatcher summary based on name."""
+ name_lower = path.name.lower()
+ return "dispatcher_summary" in name_lower or (
+ "summary" in name_lower and "dispatcher" in name_lower
+ )
+
+
+def should_skip_file(path: Path) -> bool:
+ """Check if file should be skipped."""
+ name_lower = path.name.lower()
+ if "summary" in name_lower and not is_dispatcher_summary_file(path):
+ return True
+ return False
+
+
+# ==================== Path Resolution ====================
+
+
+# Directory name aliases for flexible path resolution
+_DIR_ALIASES = {
+ "comm": "communication",
+ "communication": "comm",
+ "infer": "inference",
+ "inference": "infer",
+ "hw": "hardware",
+ "hardware": "hw",
+ "op": "operators",
+ "operators": "op",
+}
+
+
+def _get_path_variants(result_path: Path) -> List[Path]:
+ """Generate path variants with directory name aliases."""
+ variants = [result_path]
+ parts = list(result_path.parts)
+
+ for i, part in enumerate(parts):
+ if part in _DIR_ALIASES:
+ new_parts = parts.copy()
+ new_parts[i] = _DIR_ALIASES[part]
+ variants.append(Path(*new_parts))
+
+ # Also try test_output <-> output mapping
+ result_str = str(result_path)
+ if result_str.startswith("test_output/"):
+ variants.append(Path(result_str.replace("test_output/", "output/", 1)))
+ elif result_str.startswith("output/"):
+ variants.append(Path(result_str.replace("output/", "test_output/", 1)))
+
+ return variants
+
+
+def resolve_result_file_path(
+ result_file: str, summary_path: Path, base_dir: Path
+) -> Optional[Path]:
+ """Resolve result file path from dispatcher summary reference."""
+ result_path = Path(result_file)
+
+ if result_path.is_absolute():
+ return result_path
+
+ # Generate path variants with directory name aliases
+ path_variants = _get_path_variants(result_path)
+
+ for variant in path_variants:
+ candidates = [
+ base_dir / variant,
+ summary_path.parent.parent / variant,
+ summary_path.parent / variant,
+ ]
+ for p in candidates:
+ if p.exists():
+ return p
+
+ # Last resort: search by filename in base_dir subdirectories
+ filename = result_path.name
+ for p in base_dir.rglob(filename):
+ return p
+
+ return None
diff --git a/infinimetrics/hardware/hardware_adapter.py b/infinimetrics/hardware/hardware_adapter.py
index 2de5b4a..426c60a 100644
--- a/infinimetrics/hardware/hardware_adapter.py
+++ b/infinimetrics/hardware/hardware_adapter.py
@@ -72,7 +72,8 @@ def process(self, test_input: Any) -> Dict[str, Any]:
logger.info(f"HardwareTestAdapter: Processing {testcase}")
- self.output_dir = Path(config.get("output_dir", "./output"))
+ # Put CSV files in hardware/ subdirectory to match JSON location
+ self.output_dir = Path(config.get("output_dir", "./output")) / "hardware"
self.output_dir.mkdir(parents=True, exist_ok=True)
device = config.get("device", "cuda").lower()
diff --git a/pyproject.toml b/pyproject.toml
index e4b3ab2..50c6741 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,9 @@ exclude = '''
)/
'''
+[tool.setuptools.packages.find]
+include = ["infinimetrics*", "dashboard*", "db*"]
+
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]