diff.py.rej
diff a/bitdelta/diff.py b/bitdelta/diff.py (rejected hunks)
@@ -73,24 +86,31 @@ def save_diff(finetuned_compressed_model, save_dir):
             diff_dict[name + ".coeff"] = module.coeff.cpu()

     for name, param in finetuned_compressed_model.named_parameters():
+        if "mlp" in name or "self_attn" in name:
+            if Pass(layers, name):
+                continue
+
         if param.requires_grad:
             diff_dict[name] = param.cpu()
-
+
+    # import pdb; pdb.set_trace()
     torch.save(diff_dict, save_dir)

 @torch.no_grad()
 def load_diff(model, diff_dir):
     device = model.device
     diff_dict = torch.load(diff_dir)
-
+
     for name, module in model.named_modules():
         if name + ".mask" in diff_dict:
             coeff = diff_dict[name + ".coeff"].to(device)
             mask = diff_dict[name + ".mask"].to(device)
-            setattr(module, "mask", mask)
-            setattr(module, "coeff", coeff)
-            # module.weight.add_((mask * coeff).to(module.weight.dtype))
+            # setattr(module, "mask", mask)
+            # setattr(module, "coeff", coeff)
+            weight = (unpack(mask) * 2 - 1) * coeff
+
+            module.weight.add_(weight.T.to(module.weight.dtype))
         elif name + ".weight" in diff_dict:
             module.weight = nn.Parameter(diff_dict[name + ".weight"].to(device).to(module.weight.dtype))
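
Context for the rejected hunk above: the save_diff change skips "mlp"/"self_attn" parameters that a Pass(layers, name) helper (defined elsewhere in the patched file, not shown here) filters out, and the load_diff change applies the compressed delta to the weights immediately, via (unpack(mask) * 2 - 1) * coeff, instead of stashing mask and coeff as attributes on the module. Below is a minimal sketch of that decode step; unpack here is a hypothetical stand-in for BitDelta's bit-unpacking helper, assumed to expand packed uint8 sign bits into a {0, 1} tensor, so the real layout and dtype may differ.

import torch

def unpack(packed: torch.Tensor) -> torch.Tensor:
    # Hypothetical unpack: expand each uint8 into its 8 bits, low bit first,
    # yielding a {0, 1} float tensor with 8x as many elements on the last dim.
    bits = torch.stack([(packed >> i) & 1 for i in range(8)], dim=-1)
    return bits.flatten(start_dim=-2).float()

# Toy packed sign mask and per-layer scale, mirroring load_diff's decode:
mask = torch.randint(0, 256, (4, 2), dtype=torch.uint8)
coeff = torch.tensor(0.01)
weight_delta = (unpack(mask) * 2 - 1) * coeff  # bits {0,1} -> signs {-1,+1}, scaled
print(weight_delta.shape)    # torch.Size([4, 16])
print(weight_delta.unique()) # tensor([-0.0100, 0.0100])

The transpose in module.weight.add_(weight.T.to(...)) suggests the mask is stored with the weight's axes swapped; whether that holds depends on how the mask is packed upstream, which this hunk does not show.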