diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 7d992d1ca..e4a5f8a61 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -110,6 +110,8 @@ Models PirateNet EquivariantGraphNeuralOperator SINDy + Vectorized Spline + Kolmogorov-Arnold Network Blocks ------------- @@ -128,6 +130,7 @@ Blocks Continuous Convolution Block Orthogonal Block PirateNet Block + KAN Block Message Passing ------------------- diff --git a/docs/source/_rst/model/block/kan_block.rst b/docs/source/_rst/model/block/kan_block.rst new file mode 100644 index 000000000..95ca239eb --- /dev/null +++ b/docs/source/_rst/model/block/kan_block.rst @@ -0,0 +1,7 @@ +KANBlock +======================= +.. currentmodule:: pina.model.block.kan_block + +.. autoclass:: pina._src.model.block.kan_block.KANBlock + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/model/kolmogorov_arnold_network.rst b/docs/source/_rst/model/kolmogorov_arnold_network.rst new file mode 100644 index 000000000..0211611f4 --- /dev/null +++ b/docs/source/_rst/model/kolmogorov_arnold_network.rst @@ -0,0 +1,7 @@ +KolmogorovArnoldNetwork +=========================== +.. currentmodule:: pina.model.kolmogorov_arnold_network + +.. autoclass:: pina._src.model.kolmogorov_arnold_network.KolmogorovArnoldNetwork + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/model/vectorized_spline.rst b/docs/source/_rst/model/vectorized_spline.rst new file mode 100644 index 000000000..08522bc54 --- /dev/null +++ b/docs/source/_rst/model/vectorized_spline.rst @@ -0,0 +1,7 @@ +VectorizedSpline +======================= +.. currentmodule:: pina.model.vectorized_spline + +.. 
"""Module for the Kolmogorov-Arnold Network block."""

import torch
from pina._src.model.vectorized_spline import VectorizedSpline
from pina._src.core.utils import check_consistency, check_positive_integer


class KANBlock(torch.nn.Module):
    """
    The inner block of the Kolmogorov-Arnold Network (KAN).

    The block applies a spline transformation to the input, optionally combined
    with a linear transformation of a base activation function. The output is
    aggregated across input dimensions to produce the final output.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        input_dimensions,
        output_dimensions,
        spline_order=3,
        n_knots=10,
        grid_range=(0, 1),
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KANBlock` class.

        :param int input_dimensions: The number of input features.
        :param int output_dimensions: The number of output features.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is (0, 1).
        :type grid_range: list | tuple.
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int.
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int.
        :raises ValueError: If ``grid_range`` is not of length 2.
        """
        super().__init__()

        # Check consistency
        check_consistency(base_function, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimensions, strict=True)
        check_positive_integer(output_dimensions, strict=True)
        check_positive_integer(spline_order, strict=True)
        check_positive_integer(n_knots, strict=True)
        check_consistency(use_base_linear, bool)
        check_consistency(use_bias, bool)
        check_consistency(init_scale_spline, (int, float))
        check_consistency(init_scale_base, (int, float))
        check_consistency(grid_range, (int, float))

        # Raise error if grid_range is not valid
        if len(grid_range) != 2:
            raise ValueError("Grid must be a list or tuple with two elements.")

        # Knots for the spline basis functions: boundary knots are repeated
        # `spline_order` times so the spline interpolates its first and last
        # control points
        initial_knots = torch.ones(spline_order) * grid_range[0]
        final_knots = torch.ones(spline_order) * grid_range[1]

        # Number of internal knots
        n_internal = max(0, n_knots - 2 * spline_order)

        # Internal knots are uniformly spaced in the grid range; slicing off
        # the endpoints keeps exactly `n_internal` interior knots
        internal_knots = torch.linspace(
            grid_range[0], grid_range[1], n_internal + 2
        )[1:-1]

        # Define the knots, one identical knot vector per input feature
        knots = torch.cat((initial_knots, internal_knots, final_knots))
        knots = knots.unsqueeze(0).repeat(input_dimensions, 1)

        # Define the control points for the spline basis functions
        control_points = (
            torch.randn(
                input_dimensions,
                output_dimensions,
                knots.shape[-1] - spline_order,
            )
            * init_scale_spline
        )

        # Define the vectorized spline module
        self.spline = VectorizedSpline(
            order=spline_order, knots=knots, control_points=control_points
        )

        # Initialize the base function
        self.base_function = base_function()

        # Initialize the base linear weights if needed
        if use_base_linear:
            self.base_weight = torch.nn.Parameter(
                torch.randn(output_dimensions, input_dimensions)
                * (init_scale_base / (input_dimensions**0.5))
            )
        else:
            self.register_parameter("base_weight", None)

        # Initialize the bias term if needed
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Forward pass of the Kolmogorov-Arnold block. The input is passed through
        the spline transformation, optionally combined with a linear
        transformation of the base function output, and then aggregated across
        input dimensions to produce the final output.

        :param x: The input tensor for the model.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        # Spline output has shape [batch, input_dim, output_dim]
        y = self.spline(x)

        if self.base_weight is not None:
            base_x = self.base_function(x)
            # [batch, input_dim, output_dim] contribution of the base branch
            base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
            y = y + base_out

        # aggregate contributions from all input dimensions
        y = y.sum(dim=1)

        if self.bias is not None:
            y = y + self.bias

        return y
"""Module for the Kolmogorov-Arnold Network model."""

import torch
from pina._src.model.block.kan_block import KANBlock
from pina._src.core.utils import check_consistency


class KolmogorovArnoldNetwork(torch.nn.Module):
    """
    Implementation of Kolmogorov-Arnold Network (KAN).

    The model consists of a sequence of KAN blocks, where each block applies a
    spline transformation to the input, optionally combined with a linear
    transformation of a base activation function.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        layers,
        spline_order=3,
        n_knots=10,
        grid_range=(-1, 1),
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KolmogorovArnoldNetwork` class.

        :param layers: A list of integers specifying the sizes of each layer,
            including input and output dimensions.
        :type layers: list | tuple.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is (-1, 1).
        :type grid_range: list | tuple.
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int.
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int.
        :raises ValueError: If ``layers`` has fewer than two elements.
        :raises ValueError: If ``grid_range`` is not of length 2.
        """
        super().__init__()

        # Check consistency -- all other checks are performed in KANBlock
        check_consistency(layers, int)
        if len(layers) < 2:
            raise ValueError(
                "Provide at least two elements for layers (input and output)."
            )

        # Initialize one KAN block per consecutive pair of layer sizes
        self.kan_layers = torch.nn.ModuleList(
            [
                KANBlock(
                    input_dimensions=layers[i],
                    output_dimensions=layers[i + 1],
                    spline_order=spline_order,
                    n_knots=n_knots,
                    grid_range=grid_range,
                    base_function=base_function,
                    use_base_linear=use_base_linear,
                    use_bias=use_bias,
                    init_scale_spline=init_scale_spline,
                    init_scale_base=init_scale_base,
                )
                for i in range(len(layers) - 1)
            ]
        )

    def forward(self, x):
        """
        Forward pass of the KolmogorovArnoldNetwork model. It passes the input
        through each KAN block in the network and returns the final output.

        :param x: The input tensor for the model.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        for layer in self.kan_layers:
            x = layer(x)

        return x
"""Vectorized univariate B-spline model with per-spline knots."""

import warnings
import torch
from pina._src.core.utils import check_consistency, check_positive_integer


class VectorizedSpline(torch.nn.Module):
    r"""
    The vectorized B-spline model class.

    A :class:`VectorizedSpline` represents a vector spline, i.e., a collection
    of ``s`` independent univariate B-splines evaluated in parallel. Each
    univariate spline has its own knot vector and its own control points, and
    acts on one input feature.

    For the :math:`j`-th univariate spline of order :math:`k`, the output is
    defined as

    .. math::

        S^{(j)}(x^{(j)}) = \sum_{i=1}^{n_j} B_{i,k}^{(j)}(x^{(j)}) C_i^{(j)},

    where :math:`C^{(j)}` are the control points (shape
    :math:`o \times n_j` for output dimension :math:`o`),
    :math:`B_{i,k}^{(j)}` are the B-spline basis functions of order :math:`k`
    (piecewise polynomials of degree :math:`k-1`) associated with the
    non-decreasing knot vector of the :math:`j`-th univariate spline.

    If the first and last knots of a given univariate spline are repeated
    :math:`k` times, then that univariate spline interpolates its first and
    last control points. Before optional aggregation, the output has shape
    ``[batch, s, o]``.

    .. note::

        Each univariate spline is forced to be zero outside the interval
        defined by the first and last knots of its own knot vector.

    .. note::

        This class does not represent a single multivariate spline with a
        genuinely multivariate basis; it represents ``s`` independent
        univariate splines, one per input feature.

    .. note::

        The :meth:`derivative` method computes derivatives directly in
        vectorized form. Relying on ``autograd`` instead would require
        per-output-dimension computations, since autograd does not natively
        handle this vectorized structure.

    :Example:

        >>> from pina.model import VectorizedSpline
        >>> import torch

        >>> knt1 = torch.tensor([
        ...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
        ...     [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0],
        ... ])
        >>> spline1 = VectorizedSpline(order=3, knots=knt1, control_points=None)

        >>> knt2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto", "n_splines": 2}
        >>> spline2 = VectorizedSpline(order=3, knots=knt2, control_points=None)

        >>> knt3 = torch.tensor([
        ...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
        ...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
        ... ])
        >>> ctrl3 = torch.tensor([
        ...     [0.0, 1.0, 3.0, 2.0],
        ...     [1.0, 0.0, 2.0, 1.0],
        ... ])
        >>> spline3 = VectorizedSpline(order=3, knots=knt3, control_points=ctrl3)
    """

    def __init__(
        self,
        order=4,
        knots=None,
        control_points=None,
        aggregate_output=None,
    ):
        """
        Initialization of the :class:`VectorizedSpline` class.

        :param int order: The order of each univariate spline. The
            corresponding basis functions are polynomials of degree
            ``order - 1``. Default is 4.
        :param knots: The knots of the spline. If a tensor is provided, it must
            have shape ``[s, n]``, where ``s`` is the number of univariate
            splines and ``n`` is the number of knots per univariate spline. If
            a dictionary is provided, it must contain the keys ``"n"``,
            ``"min"``, ``"max"``, ``"mode"``, and ``"n_splines"``. The
            supported modes are ``"uniform"`` (evenly spaced knots) and
            ``"auto"`` (knots constructed so each univariate spline
            interpolates its first and last control points; the number of
            knots is adjusted if :math:`n < 2 * order`). If None, knots are
            initialized automatically over :math:`[0, 1]` in ``"auto"`` mode.
            Default is None.
        :type knots: torch.Tensor | dict
        :param torch.Tensor control_points: The control points tensor of shape
            ``[s, o, c]`` or ``[s, c]`` (expanded to ``[s, 1, c]``). If None,
            control points are initialized to learnable parameters with zero
            initial value. Default is None.
        :param str aggregate_output: If None, the output of each univariate
            spline is returned separately with shape ``[batch, s, o]``. If
            ``"mean"`` or ``"sum"``, the output is aggregated accordingly
            across the last dimension, resulting in shape ``[batch, s]``.
            Default is None.
        :raises AssertionError: If ``order`` is not a positive integer.
        :raises ValueError: If ``knots`` is neither a torch.Tensor nor a
            dictionary, when provided.
        :raises ValueError: If ``aggregate_output`` is not None, "mean", or
            "sum".
        :raises ValueError: If ``control_points`` is not a torch.Tensor, when
            provided.
        :raises ValueError: If both ``knots`` and ``control_points`` are None.
        :raises ValueError: If ``knots`` is not two-dimensional, after
            processing.
        :raises ValueError: If ``control_points``, after expansion when
            two-dimensional, is not three-dimensional.
        :raises ValueError: If, for each univariate spline, the number of
            ``knots`` is not equal to the sum of ``order`` and the number of
            ``control_points``.
        :raises UserWarning: If, for each univariate spline, the number of
            ``control_points`` is lower than the ``order``, resulting in a
            degenerate spline.
        :raises ValueError: If the number of univariate splines in ``knots``
            and ``control_points`` do not match.
        """

        super().__init__()

        # Check consistency
        check_positive_integer(value=order, strict=True)
        check_consistency(knots, (type(None), torch.Tensor, dict))
        check_consistency(control_points, (type(None), torch.Tensor))

        # Raise error if neither knots nor control points are provided
        if knots is None and control_points is None:
            raise ValueError("knots and control_points cannot both be None.")

        # Raise error if aggregate_output is not None, "mean", or "sum"
        if aggregate_output not in (None, "mean", "sum"):
            raise ValueError(
                f"aggregate_output must be None, 'mean', or 'sum'."
                f" Got {aggregate_output}."
            )

        # Initialize knots if not provided
        if knots is None and control_points is not None:
            knots = {
                "n": control_points.shape[-1] + order,
                "min": 0,
                "max": 1,
                "n_splines": control_points.shape[0],
                "mode": "auto",
            }

        # Initialization - knots and control points managed by their setters
        self.order = order
        self.knots = knots
        self.control_points = control_points
        self.aggregate_output = aggregate_output

        # Check dimensionality of control points
        if self.control_points.ndim != 3:
            raise ValueError("control_points must be three-dimensional.")

        # Raise error if #knots != order + #control_points
        if self.knots.shape[-1] != self.order + self.control_points.shape[-1]:
            raise ValueError(
                f" The number of knots per spline must be equal to order + the"
                f" number of control points. Got {self.knots.shape[-1]} knots"
                f" per spline, {self.control_points.shape[-1]} control points,"
                f" and {self.order} order."
            )

        # Raise warning if spline is degenerate
        if self.control_points.shape[-1] < self.order:
            warnings.warn(
                "The number of control points per spline is smaller than the"
                " spline order. This creates a degenerate spline with limited"
                " flexibility.",
                UserWarning,
            )

        # Raise error if knots and control points have different # of splines
        if self.knots.shape[0] != self.control_points.shape[0]:
            raise ValueError(
                f"The number of splines must be the same for knots and"
                f" control points. Got {self.knots.shape[0]} splines for knots"
                f" and {self.control_points.shape[0]} splines for control"
                f" points."
            )

        # Precompute boundary interval index
        self.register_buffer(
            "_boundary_interval_idx", self._compute_boundary_interval()
        )

        # NOTE: the derivative denominators are already precomputed by the
        # knots setter above, so no additional call is required here.

    def _compute_boundary_interval(self):
        """
        Precompute the index of the rightmost non-degenerate interval to
        improve performance, eliminating the need to perform a search loop in
        the basis function on each call.

        :return: The index of the rightmost non-degenerate interval for each
            univariate spline.
        :rtype: torch.Tensor
        """
        # Compute the differences between consecutive knots for each spline
        diffs = self._knots[:, 1:] - self._knots[:, :-1]
        valid = diffs > 0

        # Initialize idx tensor to store the last valid interval per spline
        idx = torch.zeros(
            self._knots.shape[0], dtype=torch.long, device=self._knots.device
        )

        # For each spline, find the last idx where interval is non-degenerate
        for s in range(self._knots.shape[0]):
            valid_s = torch.nonzero(valid[s], as_tuple=False)
            idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0

        return idx

    def _compute_derivative_denominators(self):
        """
        Precompute the denominators used in the derivatives for all orders up
        to the spline order to avoid redundant calculations.
        """
        # Precompute for order 2 to k
        for i in range(2, self.order + 1):

            # Denominators for the derivative recurrence relations
            left_den = self.knots[:, i - 1 : -1] - self.knots[:, :-i]
            right_den = self.knots[:, i:] - self.knots[:, 1 : -i + 1]

            # If consecutive knots are equal, set left/right factors to zero
            left_fac = torch.where(
                torch.abs(left_den) > 1e-10,
                (i - 1) / left_den,
                torch.zeros_like(left_den),
            )
            right_fac = torch.where(
                torch.abs(right_den) > 1e-10,
                (i - 1) / right_den,
                torch.zeros_like(right_den),
            )

            # Register buffers
            self.register_buffer(f"_left_factor_order_{i}", left_fac)
            self.register_buffer(f"_right_factor_order_{i}", right_fac)

    def basis(self, x, collection=False):
        """
        Evaluate the B-spline basis functions for each univariate spline.

        This method applies the Cox-de Boor recursion in vectorized form
        across all univariate splines of the vector spline.

        :param torch.Tensor x: The points to be evaluated.
        :param bool collection: If True, returns a list of basis functions for
            all orders up to the spline order. Default is False.
        :raises ValueError: If ``collection`` is not a boolean.
        :raises ValueError: If ``x`` is not two-dimensional.
        :raises ValueError: If the number of input features does not match
            the number of univariate splines.
        :return: The basis functions evaluated at x.
        :rtype: torch.Tensor
        """
        # Check consistency
        check_consistency(collection, bool)

        # Ensure x is a tensor of the same dtype as knots
        x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype)

        # Raise error if x does not have shape (batch, s)
        if x.ndim != 2:
            raise ValueError(
                f"The input must have shape (batch, s). Got {x.shape}."
            )

        # Raise error if x has different number of splines than knots
        if x.shape[1] != self.knots.shape[0]:
            raise ValueError(
                f"The number of input features must be the same as the number"
                f" of univariate splines. Got {x.shape[1]} input features,"
                f" but {self.knots.shape[0]} univariate splines."
            )

        # Add a final dimension to x for broadcasting
        x = x.unsqueeze(-1)

        # Add an initial dimension to knots for broadcasting
        knots = self.knots.unsqueeze(0)

        # Base case of recursion: indicator functions for the intervals
        basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
        basis = basis.to(x.dtype)

        # Extract left and right knots of the boundary interval per spline
        range_tensor = torch.arange(self.knots.shape[0], device=x.device)
        knot_left = self.knots[range_tensor, self._boundary_interval_idx]
        knot_right = self.knots[range_tensor, self._boundary_interval_idx + 1]

        # Identify points at the rightmost boundary
        at_rightmost_boundary = (
            x.squeeze(-1) >= knot_left.unsqueeze(0)
        ) & torch.isclose(
            x.squeeze(-1), knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10
        )

        # Ensure the correct value is set at the rightmost boundary
        if torch.any(at_rightmost_boundary):
            b_idx, s_idx = torch.nonzero(at_rightmost_boundary, as_tuple=True)
            basis[b_idx, s_idx, self._boundary_interval_idx[s_idx]] = 1.0

        # If returning the whole collection, initialize list
        if collection:
            basis_collection = [None, basis]

        # Cox-de Boor recursion -- iterative case
        for i in range(1, self.order):

            # Compute the denominators for both terms of the recursion
            denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
            denom2 = knots[..., i + 1 :] - knots[..., 1:-i]

            # Ensure no division by zero
            denom1 = torch.where(
                denom1.abs() < 1e-8, torch.ones_like(denom1), denom1
            )
            denom2 = torch.where(
                denom2.abs() < 1e-8, torch.ones_like(denom2), denom2
            )

            # Compute the two terms of the recursion
            term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1]
            term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:]

            # Combine terms to get the new basis
            basis = term1 + term2

            if collection:
                basis_collection.append(basis)

        return basis_collection if collection else basis

    def forward(self, x):
        """
        Forward pass for the :class:`VectorizedSpline` model. Each univariate
        spline is evaluated independently on its corresponding input feature.

        The input is expected to have shape ``[batch, s]``, where ``s`` is the
        number of univariate splines. The output has shape ``[batch, s, o]``,
        where ``o`` is the output dimension of each univariate spline, unless
        an aggregation method is specified. If both ``s`` and ``o`` are 1, or
        ``aggregate_output`` is set to ``"mean"`` or ``"sum"``, the output is
        aggregated across the last dimension, resulting in an output of shape
        ``[batch, s]``.

        :param x: The input tensor.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor.
        :rtype: torch.Tensor
        """
        # Compute the basis functions at x
        basis = self.basis(x)

        # Compute the output for each spline
        out = torch.einsum("bsc,soc->bso", basis, self.control_points)

        # Aggregate output if needed
        if self.aggregate_output == "mean":
            out = out.mean(dim=-1)
        elif self.aggregate_output == "sum":
            out = out.sum(dim=-1)
        elif out.shape[1] == 1 and out.shape[2] == 1:
            out = out.squeeze(-1)

        return out

    def derivative(self, x, degree):
        """
        Compute the ``degree``-th derivative of each univariate spline at the
        given input points.

        The output has shape ``[batch, s, o]``, where ``o`` is the output
        dimension of each univariate spline, unless an aggregation method is
        specified. If both ``s`` and ``o`` are 1, or ``aggregate_output`` is
        set to ``"mean"`` or ``"sum"``, the output is aggregated across the
        last dimension, resulting in an output of shape ``[batch, s]``.

        :param x: The input tensor.
        :type x: torch.Tensor | LabelTensor
        :param int degree: The derivative degree to compute.
        :return: The derivative tensor.
        :rtype: torch.Tensor
        """
        # Check consistency
        check_positive_integer(degree, strict=False)

        # Compute basis derivative
        der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree)

        # Compute the output for each spline
        out = torch.einsum("bsc,soc->bso", der, self.control_points)

        # Aggregate output if needed
        if self.aggregate_output == "mean":
            out = out.mean(dim=-1)
        elif self.aggregate_output == "sum":
            out = out.sum(dim=-1)
        elif out.shape[1] == 1 and out.shape[2] == 1:
            out = out.squeeze(-1)

        return out

    def _basis_derivative(self, x, degree):
        """
        Compute the ``degree``-th derivative of the vectorized spline basis
        functions at the given input points using an iterative approach.

        :param torch.Tensor x: The points to be evaluated.
        :param int degree: The derivative degree to compute.
        :return: The derivative of the basis functions of order ``self.order``.
        :rtype: torch.Tensor
        """
        # Compute the whole basis collection
        basis = self.basis(x, collection=True)

        # Derivatives initialization (dummy at index 0 for convenience)
        derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]

        # Iterate over derivative degrees
        for _ in range(1, degree + 1):

            # Current degree derivatives (dummy at index 0 for convenience)
            current_der = [None] * (self.order + 1)
            current_der[1] = torch.zeros_like(derivatives[1])

            # Iterate over basis orders
            for o in range(2, self.order + 1):

                # Retrieve precomputed factors
                left_fac = getattr(self, f"_left_factor_order_{o}")
                right_fac = getattr(self, f"_right_factor_order_{o}")

                # derivatives[o - 1] has shape [b, s, m]
                # Slice previous derivatives to align
                left_part = derivatives[o - 1][..., :-1]
                right_part = derivatives[o - 1][..., 1:]

                # Broadcast factors over batch dims
                left_fac = left_fac.unsqueeze(0)
                right_fac = right_fac.unsqueeze(0)

                # Compute current derivatives
                current_der[o] = left_fac * left_part - right_fac * right_part

            # Update derivatives for next degree
            derivatives = current_der

        return derivatives[self.order]

    @property
    def control_points(self):
        """
        The control points of the spline.

        :return: The control points.
        :rtype: torch.Tensor
        """
        return self._control_points

    @control_points.setter
    def control_points(self, control_points):
        """
        Set the control points of the spline.

        :param torch.Tensor control_points: The control points tensor. The
            tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where
            each univariate spline has ``c`` control points and output
            dimension ``o``. In the latter case, the control points are
            expanded to shape ``[s, 1, c]``.
        :raises ValueError: If there are not enough knots to define the
            control points, due to the relation:
            #knots = order + #control_points.
        """
        # If control points are not provided, initialize them
        if control_points is None:

            # Check that there are enough knots to define control points
            if self.knots.shape[-1] < self.order + 1:
                raise ValueError(
                    f"Not enough knots to define control points. Got"
                    f" {self.knots.shape[-1]} knots for each univariate spline,"
                    f" but need at least {self.order + 1}."
                )

            # Initialize control points to zero
            control_points = torch.zeros(
                self.knots.shape[0], 1, self.knots.shape[-1] - self.order
            )

        # If the control points are 2D, add an output dimension of size 1
        if control_points.ndim == 2:
            control_points = control_points.unsqueeze(1)

        # Set control points
        self._control_points = torch.nn.Parameter(
            control_points, requires_grad=True
        )

    @property
    def knots(self):
        """
        The knots of the spline.

        :return: The knots.
        :rtype: torch.Tensor
        """
        return self._knots

    @knots.setter
    def knots(self, value):
        """
        Set the knots of the spline.

        :param value: The knots of the spline. If a tensor is provided, it
            must have shape ``[s, n]``, where ``s`` is the number of
            univariate splines and ``n`` is the number of knots per univariate
            spline. If a dictionary is provided, it must contain the keys
            ``"n"``, ``"min"``, ``"max"``, ``"mode"``, and ``"n_splines"``.
            The supported modes are ``"uniform"``, where the knots are evenly
            spaced over :math:`[min, max]`, and ``"auto"``, where knots are
            constructed to ensure that each univariate spline interpolates the
            first and last control points. In this case, the number of knots
            is adjusted if :math:`n < 2 * order`.
        :type value: torch.Tensor | dict
        :raises ValueError: If a dictionary is provided but does not contain
            the required keys.
        :raises ValueError: If the mode specified in the dictionary is invalid.
        :raises ValueError: If knots is not two-dimensional after processing.
        """
        # If a dictionary is provided, initialize knots accordingly
        if isinstance(value, dict):

            # Check that required keys are present
            required_keys = {"n", "min", "max", "mode", "n_splines"}
            if not required_keys.issubset(value.keys()):
                raise ValueError(
                    f"When providing knots as a dictionary, the following "
                    f"keys must be present: {required_keys}. Got "
                    f"{value.keys()}."
                )

            # Save number of splines for later use
            n_splines = value["n_splines"]

            # Uniform sampling of knots
            if value["mode"] == "uniform":
                value = torch.linspace(value["min"], value["max"], value["n"])

            # Automatic sampling of interpolating knots
            elif value["mode"] == "auto":

                # Repeat the first and last knots 'order' times
                initial_knots = torch.ones(self.order) * value["min"]
                final_knots = torch.ones(self.order) * value["max"]

                # Number of internal knots
                n_internal = value["n"] - 2 * self.order

                # If no internal knots are needed, just concatenate boundaries
                if n_internal <= 0:
                    value = torch.cat((initial_knots, final_knots))

                # Else, sample internal knots uniformly and exclude boundaries
                # Recover the correct number of internal knots when slicing by
                # adding 2 to n_internal
                else:
                    internal_knots = torch.linspace(
                        value["min"], value["max"], n_internal + 2
                    )[1:-1]
                    value = torch.cat(
                        (initial_knots, internal_knots, final_knots)
                    )

            # Raise error if mode is invalid
            else:
                raise ValueError(
                    f"Invalid mode for knots initialization. Got "
                    f"{value['mode']}, but expected 'uniform' or 'auto'."
                )

            # Repeat the knot vector for each spline
            value = value.unsqueeze(0).repeat(n_splines, 1)

        # Set knots
        self.register_buffer("_knots", value.sort(dim=-1).values)

        # Check dimensionality of knots
        if self.knots.ndim != 2:
            raise ValueError("knots must be two-dimensional.")

        # Recompute boundary interval when knots change
        if hasattr(self, "_boundary_interval_idx"):
            self.register_buffer(
                "_boundary_interval_idx", self._compute_boundary_interval()
            )

        # Recompute derivative denominators when knots change
        self._compute_derivative_denominators()
+ ) + + # Repeat the knot vector for each spline + value = value.unsqueeze(0).repeat(n_splines, 1) + + # Set knots + self.register_buffer("_knots", value.sort(dim=-1).values) + + # Check dimensionality of knots + if self.knots.ndim != 2: + raise ValueError("knots must be two-dimensional.") + + # Recompute boundary interval when knots change + if hasattr(self, "_boundary_interval_idx"): + self.register_buffer( + "_boundary_interval_idx", self._compute_boundary_interval() + ) + + # Recompute derivative denominators when knots change + self._compute_derivative_denominators() diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index 5dbba18c2..28bccf089 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -289,15 +289,10 @@ def move_discretisation_into_conditions(self): if not self.are_all_domains_discretised: warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { + warning_message = "\n".join([f"""{" " * 13} ---> Domain {key} { "sampled" if key in self.discretised_domains else - "not sampled"}""" - for key in self.domains - ] - ) + "not sampled"}""" for key in self.domains]) warnings.warn( "Some of the domains are still not sampled. 
Consider calling " "problem.discretise_domain function for all domains before " diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 0310eef5c..ee221c17e 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -17,6 +17,8 @@ "EquivariantGraphNeuralOperator", "SINDy", "SplineSurface", + "VectorizedSpline", + "KolmogorovArnoldNetwork", ] from pina._src.model.feed_forward import FeedForward, ResidualFeedForward @@ -34,3 +36,5 @@ EquivariantGraphNeuralOperator, ) from pina._src.model.sindy import SINDy +from pina._src.model.vectorized_spline import VectorizedSpline +from pina._src.model.kolmogorov_arnold_network import KolmogorovArnoldNetwork diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py index 88bfd9e43..e9e8e793d 100644 --- a/pina/model/block/__init__.py +++ b/pina/model/block/__init__.py @@ -25,6 +25,7 @@ "RBFBlock", "GNOBlock", "PirateNetBlock", + "KANBlock", ] from pina._src.model.block.convolution_2d import ContinuousConvBlock @@ -50,3 +51,4 @@ from pina._src.model.block.rbf_block import RBFBlock from pina._src.model.block.gno_block import GNOBlock from pina._src.model.block.pirate_network_block import PirateNetBlock +from pina._src.model.block.kan_block import KANBlock diff --git a/tests/test_block/test_kan_block.py b/tests/test_block/test_kan_block.py new file mode 100644 index 000000000..08f549f40 --- /dev/null +++ b/tests/test_block/test_kan_block.py @@ -0,0 +1,146 @@ +import torch +import pytest +from pina.model.block import KANBlock + +# Data +input_dim = 3 +data = torch.rand((10, input_dim)) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_constructor( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + KANBlock( + 
input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + + # Should fail if input_dimensions is not a positive integer + with pytest.raises(AssertionError): + KANBlock(input_dimensions=-1, output_dimensions=output_dimensions) + + # Should fail if output_dimensions is not a positive integer + with pytest.raises(AssertionError): + KANBlock(input_dimensions=data.shape[1], output_dimensions=-1) + + # Should fail if spline_order is not a positive integer + with pytest.raises(AssertionError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=-1, + ) + + # Should fail if n_knots is not a positive integer + with pytest.raises(AssertionError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + n_knots=-1, + ) + + # Should fail if grid_range is not of length 2 + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + grid_range=[-1, 0, 1], + ) + + # Should fail if base_function is not a torch.nn.Module subclass + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + base_function="not_a_module", + ) + + # Should fail if use_base_linear is not a boolean + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + use_base_linear="not_a_bool", + ) + + # Should fail if use_bias is not a boolean + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + use_bias="not_a_bool", + ) + + # Should fail if init_scale_spline is not a float or int + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + init_scale_spline="not_a_number", + ) + + # Should 
fail if init_scale_base is not a float or int + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + init_scale_base="not_a_number", + ) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_forward( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + model = KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + + output_ = model(data) + assert output_.shape == (data.shape[0], output_dimensions) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_backward( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + model = KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + + data.requires_grad_() + output_ = model(data) + + loss = torch.mean(output_) + loss.backward() + assert data.grad.shape == data.shape diff --git a/tests/test_model/test_kolmogorov_arnold_network.py b/tests/test_model/test_kolmogorov_arnold_network.py new file mode 100644 index 000000000..ec26f81d2 --- /dev/null +++ b/tests/test_model/test_kolmogorov_arnold_network.py @@ -0,0 +1,83 @@ +import torch +import pytest +from pina.model import KolmogorovArnoldNetwork + +# Data +input_dim = 3 +data = 
torch.rand((10, input_dim)) + + +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_constructor(use_base_linear, use_bias, grid_range, layers): + + # Constructor + KolmogorovArnoldNetwork( + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, + ) + + # Should fail if grid_range is not of length 2 + with pytest.raises(ValueError): + KolmogorovArnoldNetwork(layers=layers, grid_range=[-1, 0, 1]) + + # Should fail if layers has less than 2 elements + with pytest.raises(ValueError): + KolmogorovArnoldNetwork(layers=[input_dim]) + + +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_forward(use_base_linear, use_bias, grid_range, layers): + + model = KolmogorovArnoldNetwork( + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, + ) + + output_ = model(data) + assert output_.shape == (data.shape[0], layers[-1]) + + +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_backward(use_base_linear, use_bias, grid_range, layers): + + model = KolmogorovArnoldNetwork( + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + 
use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, + ) + + data.requires_grad_() + output_ = model(data) + + loss = torch.mean(output_) + loss.backward() + assert data.grad.shape == data.shape diff --git a/tests/test_model/test_vectorized_spline.py b/tests/test_model/test_vectorized_spline.py new file mode 100644 index 000000000..cc0107073 --- /dev/null +++ b/tests/test_model/test_vectorized_spline.py @@ -0,0 +1,287 @@ +import torch +import pytest +from pina.model import VectorizedSpline, Spline +from pina.operator import grad +from pina import LabelTensor + + +# Utility quantities for testing +order = torch.randint(3, 6, (1,)).item() +n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() +n_knots = order + n_ctrl_pts +n_splines = torch.randint(2, 5, (1,)).item() +output_dim = torch.randint(1, 4, (1,)).item() + +# Input points +labels = [f"x{i}" for i in range(n_splines)] +pts = torch.rand(10, n_splines).requires_grad_(True) +pts = LabelTensor(pts, labels) + + +# Define all possible combinations of valid arguments for VectorizedSpline class +valid_args = [ + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": torch.linspace(0, 1, n_knots) + .unsqueeze(0) + .repeat(n_splines, 1), + }, + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "auto", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "uniform", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": None, + "knots": torch.linspace(0, 1, n_knots) + .unsqueeze(0) + .repeat(n_splines, 1), + }, + { + "order": order, + "control_points": None, + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "auto", + "n_splines": n_splines, + }, + }, + { + 
"order": order, + "control_points": None, + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "uniform", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": None, + }, +] + + +@pytest.mark.parametrize("args", valid_args) +@pytest.mark.parametrize("aggregate_output", ["mean", "sum", None]) +def test_constructor(args, aggregate_output): + VectorizedSpline(**args, aggregate_output=aggregate_output) + + # Should fail if order is not a positive integer + with pytest.raises(AssertionError): + VectorizedSpline( + order=-1, + control_points=args["control_points"], + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if control_points is not None or a torch.Tensor + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=[1, 2, 3], + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if knots is not None, a torch.Tensor, or a dict + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=5, + aggregate_output=aggregate_output, + ) + + # Should fail if aggregate_output is not None, "mean", or "sum" + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=args["knots"], + aggregate_output="invalid", + ) + + # Should fail if both knots and control_points are None + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=None, + knots=None, + aggregate_output=aggregate_output, + ) + + # Should fail if knots is not two-dimensional + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=torch.rand(n_knots), + aggregate_output=aggregate_output, + ) + + # Should fail if control_points is not three-dimensional + with pytest.raises(ValueError): + 
VectorizedSpline( + order=args["order"], + control_points=torch.rand(n_ctrl_pts), + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if the number of knots != order + number of control points + # If control points are None, they are initialized to fulfill this condition + if args["control_points"] is not None: + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=torch.linspace(0, 1, n_knots + 1) + .unsqueeze(0) + .repeat(n_splines, 1), + aggregate_output=aggregate_output, + ) + + # Should fail if the knot dict is missing required keys + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1}, + aggregate_output=aggregate_output, + ) + + # Should fail if the knot dict has invalid 'mode' key + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"}, + aggregate_output=aggregate_output, + ) + + # Should fail if knots and control points have different number of splines + with pytest.raises(ValueError): + VectorizedSpline( + order=3, + control_points=torch.rand(5, 4, 5), + knots=torch.linspace(0, 1, 8).unsqueeze(0).repeat(3, 1), + aggregate_output=aggregate_output, + ) + + +@pytest.mark.parametrize("args", valid_args) +def test_forward(args): + + # Define the model + model = VectorizedSpline(**args) + + # Evaluate the model + output_ = model(pts) + + # Check output shape + if model.aggregate_output is None: + assert output_.shape == ( + pts.shape[0], + pts.shape[1], + model.control_points.shape[1], + ) + else: + assert output_.shape == pts.shape + + +@pytest.mark.parametrize("args", valid_args) +def test_backward(args): + + # Define the model + model = VectorizedSpline(**args) + + # Evaluate the model + output_ = model(pts) + loss = 
torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape + + +@pytest.mark.parametrize("args", valid_args) +def test_derivative(args): + + # Define and evaluate the model + model = VectorizedSpline(**args) + pts.requires_grad_(True) + output_ = model(pts) + + # Compute analytical derivatives + first_der = model.derivative(x=pts, degree=1) + + # Compute autograd derivatives -- we need to loop over output dimensions + # since autograd doesn't support vectorized outputs + gradients = [] + for j in range(output_.shape[2]): + out = output_[:, :, j].squeeze(-1) + out = LabelTensor(out, [f"u{j}" for j in range(out.shape[1])]) + gradients.append( + grad(out, pts)[[f"du{j}dx{j}" for j in range(pts.shape[1])]] + ) + first_der_auto = torch.stack(gradients, dim=-1) + + # Check shape and value + assert first_der.shape == first_der_auto.shape + assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) + + +def test_1d_vs_vectorized(): + + control_points = torch.rand(1, 1, n_ctrl_pts) + knots = torch.linspace(0, 1, n_knots).unsqueeze(0) + + # Classical 1D spline + + spline = Spline( + order=order, + control_points=control_points.squeeze(), + knots=knots.squeeze(), + ) + + # Create a VectorizedSpline instance with the same control pts and knots + vectorized_spline = VectorizedSpline( + order=order, + knots=knots, + control_points=control_points, + aggregate_output=None, + ) + + # Input points + x = LabelTensor(torch.rand(10, 1), labels=["x"]) + + # Evaluate both models on the same input + out_spline = spline(x) + out_vectorized = vectorized_spline(x) + + assert out_vectorized.shape == out_spline.shape + assert torch.allclose(out_vectorized, out_spline, atol=1e-5, rtol=1e-5)