[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
base: main
Changes from all commits
86b9199
1eab899
aec7013
2aca498
4d7be63
96d99ba
9a77ebf
72f2a57
@@ -39,6 +39,7 @@ namespace jax {
 struct ClampedSwigluConfig {
   float limit;
   float alpha;
+  float glu_linear_offset;
 };

 struct ActivationConfig {
@@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);

 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
                                       ::xla::ffi::StructMember<float>("limit"),
-                                      ::xla::ffi::StructMember<float>("alpha"));
+                                      ::xla::ffi::StructMember<float>("alpha"),
+                                      ::xla::ffi::StructMember<float>("glu_linear_offset"));
Collaborator
Can we add a default value for users with HLO from a previous version? Would glu_linear_offset=1 be the same as the current behavior on main?
Contributor
Author
Yes, glu_linear_offset=1 is consistent with the current behavior. Could you point me to how to add the default value for HLO? Thanks.
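For readers following the thread, here is a minimal scalar sketch of what a configurable offset could look like. It is not taken from this PR: the formula (gate clamped from above, linear branch clamped to [-limit, limit], then `swish * (x_linear + glu_linear_offset)`) is an assumption based on the commonly used clamped SwiGLU definition, and the function name `clamped_swiglu` and the example values are hypothetical.

```python
import math


def clamped_swiglu(x_glu, x_linear, limit, alpha, glu_linear_offset=1.0):
    """Scalar reference sketch of clamped SwiGLU (assumed formula, not the PR's kernel)."""
    # Assumption: the gate input is clamped from above only.
    x_glu = min(x_glu, limit)
    # Assumption: the linear input is clamped to [-limit, limit].
    x_linear = max(-limit, min(x_linear, limit))
    # Swish with a scaled sigmoid: x * sigmoid(alpha * x).
    swish = x_glu / (1.0 + math.exp(-alpha * x_glu))
    # Defaulting the offset to 1.0 mirrors the behavior the thread describes for main.
    return swish * (x_linear + glu_linear_offset)


# Omitting the offset and passing 1.0 explicitly give the same result.
default_out = clamped_swiglu(1.0, 2.0, limit=7.0, alpha=1.702)
explicit_out = clamped_swiglu(1.0, 2.0, limit=7.0, alpha=1.702, glu_linear_offset=1.0)
print(default_out == explicit_out)
```

Under these assumptions, a default of 1.0 on the Python side keeps existing callers unchanged; whether the FFI attribute decoding can supply a default for HLO serialized before this change is a separate question that this sketch does not address.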

 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
     transformer_engine::jax::ActivationConfig,

Can you update one of the tests here to use a non-default value of glu_linear_offset?
TransformerEngine/tests/jax/test_custom_call_compute.py
Line 247 in 76c2a9e