-
Notifications
You must be signed in to change notification settings - Fork 5
Add tests/fixes for inplace rules #158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
I'd like to do a bit of tidying up with this one but it can be merged if tests pass and people feel like it, I can just make a separate PR. |
Codecov Report❌ Patch coverage is
🚀 New features to boost your workflow:
|
| MatrixAlgebraKit.zero!(darg2) | ||
| $pb(dA, A, (arg1, arg2), (darg1, darg2)) | ||
| zero!(darg1) | ||
| zero!(darg2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems like a lot more sensible order. I assume this was working before because arg were not modified in between the forward and the backward pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thus the value of testing this a little more directly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although wait, I am now confused. The copy of arg is made before the primal call. So this restores the state of arg to that before it got the actual output values assigned into it, and uses these values in the pullback call. How can this then yield the correct result?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to have the actual computed values of arg available at the time of calling the pullback (assuming that someone might have modified the return values after the primal call), while also being able to restore the state of arg to value before the primal call, it seems like we need to independent copies of arg. One before calling the primal (to restore after the pullback call), and one copy after calling the primal (as a cache to be used in the pullback, in case someone would be destroying the values) in between primal and pullback call. Is this correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is what it is supposed to do, you have to restore the state to what it was before the primal call executes, such that primal calls that would have preceded it now have the correct values again.
Think of something like:
Q, R = qr_compact!(A, (Q, R), ...)
x = f(Q, R) # some computation that depends on Q and R
Q2, R2 = qr_compact!(A, (Q, R), ...)
y = g(Q2, R2) # some computation that depends on Q2 and R2
return x + yIn the reverse pass, the second qr_compact! rrule is executed, followed by the f rrule, and in order for the f rrule to "see" the correct values of Q and R, the qr_compact! rrule has to restore them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And I would assume that we do not need to restore the values that are computed after the primal, since we should assume that no one is destroying the values without restoring them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is in line with my original interpretation, which is why I thought the original order was correct and I don't understand the new order. However, the new order got my thinking confused.
Added more tests to
test_pullbacks_matchto make sure the state of the arguments is restored, and the final argument derivatives match between inplace and non in place methods.Unfortunately, the Mooncake FD tester doesn't work well for our functions, because
Abecomes a scratch space, and the inputs are also the outputs (so get incremented twice under the FD scheme).