# code_grapher.py
# Forked from hesiod-au/python-mcp.
import ast
import os
import re
import importlib.util
from pathlib import Path
from typing import Dict, List, Set, Any, Optional, Tuple
import sys
# Silence all debug and error prints in CodeGrapher
print = lambda *args, **kwargs: None
class CodeGrapher:
    """
    Extract and navigate Python code structure across files.

    This class parses Python code, follows imports and references,
    and extracts structured code snippets up to a token limit.

    Attributes:
        token_limit (int): Maximum number of tokens to include in output.
        visited_files (Set[str]): Set of file paths that have been processed.
        referenced_objects (List[Dict[str, Any]]): List of objects referenced in the code.
    """

    # AST node types that are reported as extractable objects.
    _DEFINITION_NODES = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)

    # Separator runs used by the approximate token counter.
    _TOKEN_SEPARATORS = re.compile(r'[\s\(\)\[\]\{\}:;,\.\"\']+')

    # Directory names that mark a path as third-party / virtual-environment code.
    _EXTERNAL_DIR_NAMES = frozenset({
        'site-packages', 'dist-packages', '.venv', 'venv', 'env',
    })

    # Path fragments (on a '/'-normalised path) that mark system Python trees.
    _EXTERNAL_PATH_FRAGMENTS = (
        '/usr/lib/', '/usr/local/lib/', '/lib/python', '/Lib/python',
    )

    # Directory names never descended into by find_all_python_files.
    _EXCLUDED_DIR_NAMES = frozenset({
        '__pycache__', 'venv', 'env', '.venv', '.env', 'site-packages',
        'dist-packages', 'lib', 'Lib', 'node_modules', 'build', 'dist',
        '.git', '.github', '.pytest_cache', '.mypy_cache', '.tox', 'egg-info',
    })

    def __init__(self, token_limit: int = 8000) -> None:
        """
        Initialize the CodeGrapher.

        Args:
            token_limit: Maximum number of tokens to include in output.
        """
        self.token_limit: int = token_limit
        self.visited_files: Set[str] = set()
        self.referenced_objects: List[Dict[str, Any]] = []

    def extract_code(
        self,
        target_file: str,
        target_object: Optional[str] = None,
        token_limit: Optional[int] = None,
        project_root: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Extract code from a file, optionally focusing on a specific object.

        This method parses the target file, extracts the specified object (or the
        entire file if no object is specified), and follows imports to build a
        comprehensive code representation up to the token limit.

        Args:
            target_file: Path to the Python file to analyze.
            target_object: Name of specific class or function to extract.
                If None, extracts the entire file.
            token_limit: Override the default token limit. The override is stored
                on the instance, so it also applies to later calls.
                If None, uses the limit currently set on the instance.
            project_root: The root directory of the project. Referenced objects
                whose files are not inside this directory are dropped.
                Defaults to the directory containing ``target_file``.

        Returns:
            A dictionary containing:
                - 'main_object': Information about the primary extracted object
                - 'referenced_objects': List of objects referenced by the main object
                - 'token_count': Total number of tokens in the extracted code
                - 'token_limit': The token limit used for extraction
                - 'truncated': Only present when the result had to be trimmed
                - 'error': Error message if extraction failed (only key present on error)
        """
        # Reset state for new extraction
        self.visited_files = set()
        self.referenced_objects = []

        if token_limit is not None:
            self.token_limit = token_limit

        target_file = os.path.abspath(target_file)
        if project_root is None:
            project_root = os.path.dirname(target_file)
        project_root = os.path.abspath(os.path.normpath(project_root))

        ast_tree, source_code = self._parse_file(target_file)
        if ast_tree is None or source_code is None:
            return {"error": f"Failed to parse file: {target_file}"}

        if target_object:
            main_object = self._extract_object(ast_tree, source_code, target_object, target_file)
            if not main_object:
                return {"error": f"Object '{target_object}' not found in {target_file}"}
        else:
            # Whole module: strip only a trailing ".py" from the file name.
            module_name = os.path.basename(target_file)
            if module_name.endswith(".py"):
                module_name = module_name[:-len(".py")]
            main_object = {
                "name": module_name,
                "file": target_file,
                "type": "module",
                "code": source_code,
                "docstring": ast.get_docstring(ast_tree) or ""
            }

        # Mark the target as visited before following imports so that import
        # cycles leading back to it terminate.
        self.visited_files.add(target_file)
        self._resolve_imports(ast_tree, target_file)

        # Keep only objects that live inside the project and are not libraries.
        self.referenced_objects = [
            obj for obj in self.referenced_objects
            if not self._is_external_library(obj["file"])
            and self._is_within_directory(obj["file"], project_root)
        ]

        main_token_count = self._count_tokens(main_object["code"])
        referenced_token_count = sum(
            self._count_tokens(obj["code"]) for obj in self.referenced_objects
        )
        result: Dict[str, Any] = {
            "main_object": main_object,
            "referenced_objects": self.referenced_objects.copy(),
            "token_count": main_token_count + referenced_token_count,
            "token_limit": self.token_limit
        }

        if result["token_count"] > self.token_limit:
            result = self._prioritize_code(result)
        return result

    def _parse_file(self, filepath: str) -> Tuple[Optional[ast.Module], Optional[str]]:
        """
        Parse a Python file into an AST.

        Args:
            filepath: Path to the Python file.

        Returns:
            A tuple ``(tree, source)``; both elements are None if the file could
            not be read or parsed.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as file:
                source_code = file.read()
            return ast.parse(source_code), source_code
        except Exception as e:
            # Best-effort boundary: unreadable or invalid files are skipped,
            # never fatal to the overall extraction.
            print(f"Error parsing {filepath}: {e}")
            return None, None

    def _extract_object(
        self,
        ast_tree: ast.Module,
        source_code: str,
        object_name: str,
        file_path: str
    ) -> Optional[Dict[str, Any]]:
        """
        Extract a specific class or function from the AST.

        The tree is searched breadth-first, so a module-level definition wins
        over a nested definition with the same name.

        Args:
            ast_tree: The AST of the module.
            source_code: Source code of the file.
            object_name: Name of the object to extract.
            file_path: Path to the file containing the object.

        Returns:
            A dictionary with the fields 'name', 'file', 'type' ('class' or
            'function'), 'code' and 'docstring' (empty string if absent), or
            None if the object is not found.
        """
        for node in ast.walk(ast_tree):
            if isinstance(node, self._DEFINITION_NODES) and node.name == object_name:
                record = self._node_to_object(node, source_code, file_path)
                if record is not None:
                    return record
        return None

    def _node_to_object(
        self,
        node: ast.AST,
        source_code: str,
        file_path: str
    ) -> Optional[Dict[str, Any]]:
        """Build the object record for one class/function definition node."""
        start_line = getattr(node, 'lineno', None)
        if start_line is None:
            return None
        lines = source_code.splitlines()
        # end_lineno is missing on interpreters older than Python 3.8.
        end_line = getattr(node, 'end_lineno', None)
        if end_line is None:
            end_line = self._estimate_end_line(lines, start_line)
        return {
            "name": node.name,
            "file": file_path,
            "type": "class" if isinstance(node, ast.ClassDef) else "function",
            "code": "\n".join(lines[start_line - 1:end_line]),
            "docstring": ast.get_docstring(node) or ""
        }

    @staticmethod
    def _estimate_end_line(lines: List[str], start_line: int) -> int:
        """
        Estimate the last line of the block that starts at ``start_line``.

        Heuristic used only when the AST carries no end position: the block ends
        before the first non-blank line indented no deeper than its header.
        """
        if start_line < 1 or start_line > len(lines):
            return start_line
        header = lines[start_line - 1]
        header_indent = len(header) - len(header.lstrip())
        end_line = start_line
        for line_number in range(start_line + 1, len(lines) + 1):
            text = lines[line_number - 1]
            stripped = text.lstrip()
            if not stripped:
                continue  # blank lines never close a block
            if len(text) - len(stripped) <= header_indent:
                break
            end_line = line_number
        return end_line

    def _collect_definitions(
        self,
        ast_tree: ast.Module,
        source_code: str,
        module_path: str,
        reference_type: str
    ) -> int:
        """
        Append every class/function defined in a module to ``referenced_objects``.

        Records are built from the walked node itself rather than looked up
        again by name, so same-named definitions (e.g. several ``__init__``
        methods) each keep their own code.

        Returns:
            The number of objects appended.
        """
        extracted_count = 0
        for node in ast.walk(ast_tree):
            if not isinstance(node, self._DEFINITION_NODES):
                continue
            record = self._node_to_object(node, source_code, module_path)
            if record is not None:
                record["reference_type"] = reference_type
                self.referenced_objects.append(record)
                extracted_count += 1
        return extracted_count

    @staticmethod
    def _find_project_root(start_dir: str) -> str:
        """Return the nearest ancestor of ``start_dir`` containing ``.git``, else ``start_dir``."""
        candidate = start_dir
        while candidate:
            if os.path.exists(os.path.join(candidate, '.git')):
                return candidate
            parent = os.path.dirname(candidate)
            if parent == candidate:  # Reached the filesystem root
                break
            candidate = parent
        return start_dir

    @staticmethod
    def _is_within_directory(path: str, directory: str) -> bool:
        """
        Return True if ``path`` is ``directory`` itself or located below it.

        Compares whole path components, so ``/a/proj2`` is not inside ``/a/proj``.
        """
        path = os.path.normcase(os.path.abspath(path))
        directory = os.path.normcase(os.path.abspath(directory))
        try:
            return os.path.commonpath([path, directory]) == directory
        except ValueError:
            # Raised e.g. for paths on different drives.
            return False

    def _resolve_imports(self, ast_tree: ast.Module, file_path: str) -> None:
        """
        Resolve imports in the AST and follow references.

        Every ``import x`` and ``from x import y`` statement found anywhere in
        the tree is processed; ``from . import y`` (no module name) is skipped.

        Args:
            ast_tree: The AST of the module.
            file_path: Path to the file containing the AST.
        """
        file_dir = os.path.dirname(file_path)
        print(f"DEBUG: Resolving imports in file: {file_path}")
        project_root = self._find_project_root(file_dir)
        print(f"DEBUG: Using project root: {project_root}")

        for node in ast.walk(ast_tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    self._process_imported_module(alias.name, file_dir)
                    # Try to find the module in the project directory
                    self._try_find_project_module(alias.name, project_root, file_dir)
            elif isinstance(node, ast.ImportFrom) and node.module:
                for alias in node.names:
                    self._process_imported_object(node.module, alias.name, file_dir)
                    # Try to find the module in the project directory
                    self._try_find_project_module(node.module, project_root, file_dir)

    def _locate_module_file(self, module_name: str, file_dir: str) -> Optional[str]:
        """
        Return the ``.py`` file backing ``module_name``, or None if not found.

        Looks for ``<last component>.py`` next to the importing file first, then
        asks Python's import system.

        NOTE: ``importlib.util.find_spec`` imports parent packages of dotted
        names and may raise; callers treat any exception as "not found".
        """
        local_module_path = os.path.join(file_dir, f"{module_name.split('.')[-1]}.py")
        print(f"DEBUG: Checking local path: {local_module_path}")
        if os.path.isfile(local_module_path):
            print(f"DEBUG: Found module in local path: {local_module_path}")
            return local_module_path

        print(f"DEBUG: Trying to resolve using importlib: {module_name}")
        spec = importlib.util.find_spec(module_name)
        if spec and spec.origin and spec.origin.endswith('.py'):
            print(f"DEBUG: Found module using importlib: {spec.origin}")
            return spec.origin

        print(f"DEBUG: Could not find module: {module_name}")
        return None

    def _process_imported_module(self, module_name: str, file_dir: str) -> None:
        """
        Process an imported module and extract its code.

        Locates the module file, records all of its classes and functions as
        referenced objects, and recursively follows its own imports. Files
        already visited and external libraries are skipped. Failures are
        swallowed so one bad import never aborts the extraction.

        Args:
            module_name: Name of the imported module.
            file_dir: Directory of the file with the import.
        """
        print(f"DEBUG: Processing imported module: {module_name} from {file_dir}")
        try:
            module_path = self._locate_module_file(module_name, file_dir)
            if module_path is None:
                return
            if module_path in self.visited_files:
                print(f"DEBUG: Module already visited: {module_path}")
                return
            if self._is_external_library(module_path):
                print(f"DEBUG: Skipping external library: {module_path}")
                return

            print(f"DEBUG: Parsing module: {module_path}")
            ast_tree, source_code = self._parse_file(module_path)
            if ast_tree is None or source_code is None:
                print(f"DEBUG: Failed to parse module: {module_path}")
                return

            self.visited_files.add(module_path)
            print(f"DEBUG: Added to visited files: {module_path}")
            extracted_count = self._collect_definitions(
                ast_tree, source_code, module_path, "import"
            )
            print(f"DEBUG: Extracted {extracted_count} objects from {module_path}")

            print(f"DEBUG: Resolving imports in {module_path}")
            self._resolve_imports(ast_tree, module_path)
        except Exception as e:
            print(f"Error processing import {module_name}: {e}")

    def _process_imported_object(self, module_name: str, object_name: str, file_dir: str) -> None:
        """
        Process a specific imported object and extract its code.

        Locates the module, extracts only the named class or function from it,
        and follows the module's own imports the first time the module is seen.
        External libraries and already recorded objects are skipped. Failures
        are swallowed so one bad import never aborts the extraction.

        Args:
            module_name: Name of the module containing the object.
            object_name: Name of the imported object.
            file_dir: Directory of the file with the import.
        """
        print(f"DEBUG: Processing imported object: {module_name}.{object_name} from {file_dir}")
        try:
            module_path = self._locate_module_file(module_name, file_dir)
            if module_path is None:
                return

            already_recorded = any(
                obj["name"] == object_name and obj["file"] == module_path
                for obj in self.referenced_objects
            )
            if already_recorded:
                print(f"DEBUG: Object already processed: {object_name} in {module_path}")
                return
            if self._is_external_library(module_path):
                print(f"DEBUG: Skipping external library: {module_path}")
                return

            print(f"DEBUG: Parsing module for object: {module_path}")
            ast_tree, source_code = self._parse_file(module_path)
            if ast_tree is None or source_code is None:
                print(f"DEBUG: Failed to parse module: {module_path}")
                return

            if module_path not in self.visited_files:
                self.visited_files.add(module_path)
                print(f"DEBUG: Added to visited files: {module_path}")
                # Also process other imports in this module
                print(f"DEBUG: Resolving imports in {module_path}")
                self._resolve_imports(ast_tree, module_path)

            print(f"DEBUG: Extracting object: {object_name} from {module_path}")
            obj = self._extract_object(ast_tree, source_code, object_name, module_path)
            if obj:
                obj["reference_type"] = "import"
                self.referenced_objects.append(obj)
                print(f"DEBUG: Successfully extracted object: {object_name} from {module_path}")
            else:
                print(f"DEBUG: Failed to extract object: {object_name} from {module_path}")
        except Exception as e:
            print(f"Error processing imported object {module_name}.{object_name}: {e}")

    def _count_tokens(self, code_string: str) -> int:
        """
        Count tokens in a code string.

        This is a rough approximation: the text is split on runs of whitespace
        and common punctuation and the non-empty pieces are counted. It is not
        the tokenization of any particular language model.

        Args:
            code_string: The code string to count tokens for.

        Returns:
            Approximate token count (0 for an empty or missing string).
        """
        if not code_string:
            return 0
        return sum(1 for piece in self._TOKEN_SEPARATORS.split(code_string) if piece)

    def _signature_stub(self, ref: Dict[str, Any]) -> str:
        """
        Build the shortened code for a reference that does not fit in full.

        Classes and functions keep the first line of their definition plus the
        docstring; the stub body is indented one level deeper than that line.
        """
        kind = ref["type"]
        if kind == "class":
            prefixes: Tuple[str, ...] = ("class ",)
        elif kind == "function":
            prefixes = ("def ", "async def ")
        else:
            return f"# {ref['name']} (truncated due to token limit)"

        header = next(
            (line for line in ref["code"].splitlines() if line.strip().startswith(prefixes)),
            ""
        )
        indent = header[:len(header) - len(header.lstrip())] + "    "
        return (
            f"{header}\n{indent}\"\"\"" + ref["docstring"]
            + f"\"\"\"\n{indent}# ... code truncated due to token limit"
        )

    def _prioritize_code(self, result_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prioritize code to fit within the token limit.

        The main object is always kept, even if it alone exceeds the limit.
        Referenced objects are then added in priority order: classes before
        functions, smaller before larger. An object that does not fit in full is
        replaced by a signature-and-docstring stub (marked ``"truncated": True``)
        if the stub fits, and dropped otherwise.

        Args:
            result_dict: The result dictionary with code objects.

        Returns:
            The same dictionary, updated in place: 'referenced_objects' and
            'token_count' reflect the selection, and 'truncated' is True when
            any reference was shortened or dropped.
        """
        current_tokens = self._count_tokens(result_dict["main_object"]["code"])

        # Count each reference once and sort (stable) by (kind, size).
        sized_refs = [
            (self._count_tokens(ref["code"]), ref)
            for ref in result_dict["referenced_objects"]
        ]
        sized_refs.sort(key=lambda pair: (0 if pair[1]["type"] == "class" else 1, pair[0]))

        kept_refs: List[Dict[str, Any]] = []
        trimmed = False
        for ref_tokens, ref in sized_refs:
            if current_tokens + ref_tokens <= self.token_limit:
                kept_refs.append(ref)
                current_tokens += ref_tokens
                continue

            trimmed = True
            truncated_code = self._signature_stub(ref)
            truncated_tokens = self._count_tokens(truncated_code)
            if current_tokens + truncated_tokens <= self.token_limit:
                truncated_ref = ref.copy()
                truncated_ref["code"] = truncated_code
                truncated_ref["truncated"] = True
                kept_refs.append(truncated_ref)
                current_tokens += truncated_tokens

        result_dict["referenced_objects"] = kept_refs
        result_dict["token_count"] = current_tokens
        result_dict["truncated"] = trimmed
        return result_dict

    def _try_find_project_module(self, module_name: str, project_root: str, file_dir: str) -> None:
        """
        Try to find a module within the project directory structure.

        Searches the project tree for a file named ``<top-level module>.py``,
        regardless of where it sits in the hierarchy. The first match that has
        not been visited yet and parses successfully has all of its classes and
        functions recorded, and its imports followed. Hidden directories,
        ``__pycache__`` and external-library directories are not searched.

        Args:
            module_name: Name of the module to find.
            project_root: Root directory of the project.
            file_dir: Directory of the file with the import (currently unused).
        """
        print(f"DEBUG: Trying to find project module: {module_name} in {project_root}")
        # Only the top-level package/module name is searched for.
        wanted_file = f"{module_name.split('.')[0]}.py"

        if self._is_external_library(project_root):
            return

        for root, dirs, files in os.walk(project_root):
            # Prune in place so excluded trees are never descended into.
            dirs[:] = [
                name for name in dirs
                if name != '__pycache__'
                and not name.startswith('.')
                and not self._is_external_library(os.path.join(root, name))
            ]
            if wanted_file not in files:
                continue

            module_path = os.path.join(root, wanted_file)
            if module_path in self.visited_files:
                print(f"DEBUG: Project module already visited: {module_path}")
                continue

            print(f"DEBUG: Found project module: {module_path}")
            print(f"DEBUG: Parsing module: {module_path}")
            ast_tree, source_code = self._parse_file(module_path)
            if ast_tree is None or source_code is None:
                continue

            self.visited_files.add(module_path)
            print(f"DEBUG: Added project module to visited files: {module_path}")
            extracted_count = self._collect_definitions(
                ast_tree, source_code, module_path, "project_import"
            )
            print(f"DEBUG: Extracted {extracted_count} objects from project module: {module_path}")

            # Recursively resolve imports in this module
            self._resolve_imports(ast_tree, module_path)
            # We found the module, no need to continue searching
            return

    def _is_external_library(self, file_path: str) -> bool:
        """
        Determine if a path belongs to an external library.

        A path is external when one of its components is a known library or
        virtual-environment directory name, or when it lies in a system Python
        tree. Matching is done on whole path components, so a directory such as
        ``my_env`` is not mistaken for ``env``.

        Args:
            file_path: The file or directory path to check.

        Returns:
            True if the path is from an external library, False otherwise.
        """
        normalized = os.path.abspath(os.path.normpath(file_path)).replace(os.sep, '/')
        if any(fragment in normalized for fragment in self._EXTERNAL_PATH_FRAGMENTS):
            return True
        return any(part in self._EXTERNAL_DIR_NAMES for part in normalized.split('/'))

    def _is_excluded_dir(self, dir_name: str) -> bool:
        """Return True for directory names find_all_python_files must not enter."""
        return (
            dir_name.startswith('.')
            or dir_name in self._EXCLUDED_DIR_NAMES
            or dir_name.endswith('.egg-info')
        )

    def find_all_python_files(self, root_path: str) -> List[str]:
        """
        Find all Python files in the specified directory, excluding:
            - External libraries (like system libraries or those in .venv)
            - Hidden, cache, build and packaging directories
            - Any files outside the project root

        Args:
            root_path: The root directory of the project.

        Returns:
            List of absolute paths to Python files within the project, in
            directory-walk order.
        """
        python_files: List[str] = []
        root_path = os.path.abspath(os.path.normpath(root_path))
        print(f"DEBUG: Finding Python files in project root: {root_path}")

        for root, dirs, files in os.walk(root_path):
            kept_dirs = [name for name in dirs if not self._is_excluded_dir(name)]
            if len(kept_dirs) != len(dirs):
                print(f"DEBUG: Excluded directories in {root}: {set(dirs) - set(kept_dirs)}")
            dirs[:] = kept_dirs  # prune in place so os.walk skips them

            if self._is_external_library(root):
                print(f"DEBUG: Skipping directory with excluded path segment: {root}")
                continue

            for file in files:
                if not file.endswith('.py'):
                    continue
                file_path = os.path.join(root, file)
                if not self._is_within_directory(file_path, root_path):
                    print(f"DEBUG: Skipping directory outside project root: {root}")
                    continue
                if self._is_external_library(file_path):
                    print(f"DEBUG: Skipping external library file: {file_path}")
                    continue
                python_files.append(file_path)
                print(f"DEBUG: Found Python file: {file_path}")

        print(f"DEBUG: Found {len(python_files)} Python files in total")
        return python_files