-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: example.py
More file actions
100 lines (73 loc) · 2.17 KB
/
example.py
File metadata and controls
100 lines (73 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
"""
Example usage of GPU Performance Diagnosis Tool
with an intentionally unoptimized PyTorch model.
"""
import torch
import torch.nn as nn
from main import profile_model
# ----------------------------
# Bad / Unoptimized Model
# ----------------------------
class BadMLP(nn.Module):
    """Deliberately inefficient MLP used to exercise the GPU profiler.

    The forward pass intentionally exhibits classic performance
    anti-patterns — a per-sample Python loop, many tiny unfused kernels,
    arithmetic no-ops, and tensor-dependent host-side control flow — so
    that the diagnosis tool has real bottlenecks to detect. Do NOT
    "optimize" this model: the inefficiencies are the point.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        per_sample_outputs = []
        # ❌ Intentional: iterate the batch in Python instead of batching.
        for idx in range(x.shape[0]):
            sample = x[idx]
            # ❌ Intentional: tiny per-sample ops, nothing fused.
            hidden = torch.relu(self.fc1(sample))
            hidden = torch.relu(self.fc2(hidden))
            # ❌ Intentional: a no-op add still launches work.
            hidden = hidden + 0.0
            # ❌ Intentional: .item() forces a device→host sync per sample.
            if hidden.mean().item() > 0:
                hidden = hidden * 1.0
            else:
                hidden = hidden - 0.0
            per_sample_outputs.append(self.fc3(hidden))
        # Re-assemble the per-sample results into a batched tensor.
        return torch.stack(per_sample_outputs, dim=0)
# ----------------------------
# Main
# ----------------------------
def main():
    """Build the intentionally slow model and run it through profile_model.

    Raises:
        RuntimeError: if no CUDA device is available — the profiler
            measures GPU behavior, so a GPU is mandatory.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for profiling")

    device = "cuda"

    # Instantiate the deliberately unoptimized model on the GPU, in eval mode.
    model = BadMLP(input_dim=128, hidden_dim=256, output_dim=10).to(device)
    model.eval()

    batch_size = 32

    def input_gen():
        # profile_model requires a zero-arg callable that yields a fresh
        # input batch per step, allocated directly on the target device.
        return torch.randn(batch_size, 128, device=device)

    # Run the profiler: a few warmup steps, then the measured steps.
    results = profile_model(
        model=model,
        input_generator=input_gen,
        steps=50,
        device=device,
        warmup_steps=5,
    )

    # Pull the diagnosis out of the results and report it.
    diagnosis = results["diagnosis"]
    print("\nDetected bottlenecks:")
    for bottleneck in diagnosis.get("bottlenecks", []):
        print(f" - {bottleneck}")
    primary = diagnosis.get("primary_bottleneck", None)
    print(f"\nPrimary bottleneck: {primary}")


if __name__ == "__main__":
    main()