Add approximate parameter to GELU activation function #1548
alinpahontu2912 wants to merge 1 commit into dotnet:main
Conversation
Add support for the 'approximate' parameter in GELU, matching PyTorch's torch.nn.GELU(approximate='tanh') functionality.

Changes:
- Add GELU.Approximate enum with 'none' and 'tanh' values
- Thread the approximate parameter through all layers: native C++, P/Invoke, Tensor methods, functional API, and module factory
- Add new overloads (no breaking changes to the existing API)
- Add a test for the tanh approximation mode

Fixes dotnet#1368

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Adds support for PyTorch’s approximate mode to GELU (notably "tanh"), threading the option through the native (C++), P/Invoke, Tensor, functional, and module APIs, and adding a regression test.
Changes:
- Introduces Modules.GELU.Approximate (none/tanh) and plumbs it through nn.GELU and nn.functional.gelu.
- Extends Tensor gelu/gelu_ to accept an approximation mode and updates the corresponding native/P/Invoke signatures.
- Adds a unit test validating the tanh approximation path and checking that it differs from the exact mode.
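For reference, the two modes compute genuinely different values: exact GELU is x·Φ(x) with Φ the standard normal CDF, while the tanh mode uses the well-known tanh approximation. A minimal standalone sketch in plain Python (independent of TorchSharp, for illustration only) shows the formulas and why a test can assert the two modes differ:

```python
import math

def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

x = 1.5
print(gelu_exact(x), gelu_tanh(x))
```

The two values are close but not bitwise equal, which is exactly what the regression test in this PR relies on.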
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/TorchSharpTest/NN.cs | Adds a test covering GELU tanh approximation behavior. |
| src/TorchSharp/Tensor/Tensor.cs | Adds gelu/gelu_ overloads that pass approximation through to native. |
| src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs | Updates P/Invoke signatures to accept the approximation string. |
| src/TorchSharp/NN/Activation/GELU.cs | Adds approximation enum + overloads in module factory and functional API. |
| src/Native/LibTorchSharp/THSTensor.h | Updates native exports for GELU to accept an approximation parameter. |
| src/Native/LibTorchSharp/THSTensor.cpp | Passes approximation through to torch::gelu / torch::gelu_. |
```diff
 [DllImport("LibTorchSharp")]
-internal static extern IntPtr THSTensor_gelu(IntPtr tensor);
+internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);

 [DllImport("LibTorchSharp")]
-internal static extern IntPtr THSTensor_gelu_(IntPtr tensor);
+internal static extern IntPtr THSTensor_gelu_(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);
```
The new P/Invoke declarations for THSTensor_gelu/THSTensor_gelu_ introduce an LPStr string parameter but don’t specify CharSet/BestFitMapping/ThrowOnUnmappableChar like the other LPStr-based imports in this file (e.g., THSTensor_load/meshgrid/div). This can lead to inconsistent marshaling behavior across platforms and re-enables best-fit character mapping. Consider updating these DllImport attributes to match the existing pattern used for other string parameters in LibTorchSharp.THSTensor.cs.
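A sketch of what the reviewer is suggesting, assuming the other LPStr imports in this file use the usual defensive attribute set (the exact values should be copied from the existing declarations rather than from this example):

```csharp
// Hedged sketch: CharSet/BestFitMapping/ThrowOnUnmappableChar values are assumed
// to match the existing LPStr-based imports in LibTorchSharp.THSTensor.cs.
[DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string approximate);
```

Disabling best-fit mapping and throwing on unmappable characters makes the ANSI marshaling fail loudly instead of silently substituting characters, which is why the rest of the file follows this pattern.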
```diff
+public Tensor gelu(TorchSharp.Modules.GELU.Approximate approximate)
+{
+    var res = NativeMethods.THSTensor_gelu(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none");
+    if (res == IntPtr.Zero)
+        CheckForErrors();
+    return new Tensor(res);
+}

 public Tensor gelu_()
 {
-    var res = NativeMethods.THSTensor_gelu_(Handle);
+    var res = NativeMethods.THSTensor_gelu_(Handle, "none");
     if (res == IntPtr.Zero)
         CheckForErrors();
     return new Tensor(res);
 }

+public Tensor gelu_(TorchSharp.Modules.GELU.Approximate approximate)
+{
+    var res = NativeMethods.THSTensor_gelu_(Handle, approximate == TorchSharp.Modules.GELU.Approximate.tanh ? "tanh" : "none");
```
The Tensor.gelu overloads take TorchSharp.Modules.GELU.Approximate, which is a nested enum on an nn.Module type. That makes a core Tensor API depend on the Modules layer and forces callers of tensor.gelu(...) / functional.gelu(...) to reference Modules.GELU for what is essentially an ATen algorithm option. Consider moving the approximation enum to a more neutral location (e.g., torch.nn or torch) and having the Tensor/functional overloads use that type (keeping the current overload as a forwarding shim if you want to preserve source compatibility).
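One way to realize this suggestion, sketched with illustrative names that are not part of the actual TorchSharp API (the enum name, namespace placement, and forwarding shim are all assumptions):

```csharp
// Hypothetical neutral home for the algorithm option (names are illustrative).
namespace TorchSharp
{
    public static partial class torch
    {
        public enum GeluApproximation
        {
            none,
            tanh
        }
    }
}

// The Tensor overload would then take the neutral type, and the existing
// Modules.GELU.Approximate-based overload could remain as a forwarding shim
// to preserve source compatibility:
public Tensor gelu(torch.GeluApproximation approximate)
    => gelu_native(approximate == torch.GeluApproximation.tanh ? "tanh" : "none");
```

This keeps the core Tensor API free of a dependency on the Modules layer while leaving existing call sites compiling unchanged.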