diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index b34f4a943..87f41fbc9 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -318,6 +318,9 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         # 1. Dequantize
         # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if out is not None:
+            out.copy_(output)
+            output = out
 
         # 3. Save state
         ctx.state = quant_state