From 8307d12ac6ec6ff6fe39e02aca94f36e732de14d Mon Sep 17 00:00:00 2001 From: ajacoby9 Date: Thu, 15 Jan 2026 04:51:37 -0500 Subject: [PATCH 01/12] KAN implementation (#611) * Improve spline * Add KAN --------- Co-authored-by: Filippo Olivo --- pina/_src/model/spline.py | 2 +- .../kolmogorov_arnold_network/kan_layer.py | 223 ++++++++++++++++++ .../kolmogorov_arnold_network/kan_network.py | 194 +++++++++++++++ 3 files changed, 418 insertions(+), 1 deletion(-) create mode 100644 pina/model/kolmogorov_arnold_network/kan_layer.py create mode 100644 pina/model/kolmogorov_arnold_network/kan_network.py diff --git a/pina/_src/model/spline.py b/pina/_src/model/spline.py index 5e5b133c3..0cbf8df45 100644 --- a/pina/_src/model/spline.py +++ b/pina/_src/model/spline.py @@ -475,4 +475,4 @@ def knots(self, value): self._boundary_interval_idx = self._compute_boundary_interval() # Recompute derivative denominators when knots change - self._compute_derivative_denominators() + self._compute_derivative_denominators() \ No newline at end of file diff --git a/pina/model/kolmogorov_arnold_network/kan_layer.py b/pina/model/kolmogorov_arnold_network/kan_layer.py new file mode 100644 index 000000000..ddd360587 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_layer.py @@ -0,0 +1,223 @@ +"""Create the infrastructure for a KAN layer""" +import torch +import numpy as np + +from pina.model.spline import Spline + + +class KAN_layer(torch.nn.Module): + """define a KAN layer using splines""" + def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None: + """ + Initialize the KAN layer. 
+ """ + super().__init__() + self.k = k + self.input_dimensions = input_dimensions + self.output_dimensions = output_dimensions + self.inner_nodes = inner_nodes + self.num = num + self.grid_eps = grid_eps + self.grid_range = grid_range + self.grid_extension = grid_extension + + if sparse_init: + self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) + else: + self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) + + grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + + if grid_extension: + h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) + for i in range(self.k): + grid = torch.cat([grid[:, [0]] - h, grid], dim=1) + grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) + + n_coef = grid.shape[1] - (self.k + 1) + + control_points = torch.nn.Parameter( + torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale + ) + + self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension) + + self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ + scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) + self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) + self.base_function = base_function + + @staticmethod + def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor: + ''' + get sparse mask + ''' + in_coord = torch.arange(in_dimensions) * 1/in_dimensions + 1/(2*in_dimensions) + out_coord = torch.arange(out_dimensions) * 1/out_dimensions + 1/(2*out_dimensions) + + dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]) + in_nearest = torch.argmin(dist_mat, dim=0) + in_connection = 
torch.stack([torch.arange(in_dimensions), in_nearest]).permute(1,0) + out_nearest = torch.argmin(dist_mat, dim=1) + out_connection = torch.stack([out_nearest, torch.arange(out_dimensions)]).permute(1,0) + all_connection = torch.cat([in_connection, out_connection], dim=0) + mask = torch.zeros(in_dimensions, out_dimensions) + mask[all_connection[:,0], all_connection[:,1]] = 1. + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN layer. + Each input goes through: w_base*base(x) + w_spline*spline(x) + Then sum across input dimensions for each output node. + """ + if hasattr(x, 'tensor'): + x_tensor = x.tensor + else: + x_tensor = x + + base = self.base_function(x_tensor) # (batch, input_dimensions) + + basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots) + spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + base_term = self.scale_base[None, :, :] * base[:, :, None] + spline_term = self.scale_spline[None, :, :] * spline_out_per_input + combined = base_term + spline_term + combined = self.mask[None,:,:] * combined + + output = torch.sum(combined, dim=1) # (batch, output_dimensions) + + return output + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid from input samples to better fit data distribution. + Based on PyKAN implementation but with boundary preservation. 
+ """ + # Convert LabelTensor to regular tensor for spline operations + if hasattr(x, 'tensor'): + # This is a LabelTensor, extract the tensor part + x_tensor = x.tensor + else: + x_tensor = x + + with torch.no_grad(): + batch_size = x_tensor.shape[0] + x_sorted = torch.sort(x_tensor, dim=0)[0] # (batch_size, input_dimensions) + + # Get current number of intervals (excluding extensions) + if self.grid_extension: + num_interval = self.spline.knots.shape[1] - 1 - 2*self.k + else: + num_interval = self.spline.knots.shape[1] - 1 + + def get_grid(num_intervals: int): + """PyKAN-style grid creation with boundary preservation""" + ids = [int(batch_size * i / num_intervals) for i in range(num_intervals)] + [-1] + grid_adaptive = x_sorted[ids, :].transpose(0, 1) # (input_dimensions, num_intervals+1) + + original_min = self.grid_range[0] + original_max = self.grid_range[1] + + # Clamp adaptive grid to not shrink beyond original domain + grid_adaptive[:, 0] = torch.min(grid_adaptive[:, 0], + torch.full_like(grid_adaptive[:, 0], original_min)) + grid_adaptive[:, -1] = torch.max(grid_adaptive[:, -1], + torch.full_like(grid_adaptive[:, -1], original_max)) + + margin = 0.0 + h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_intervals + grid_uniform = (grid_adaptive[:, [0]] - margin + + h * torch.arange(num_intervals + 1, device=x_tensor.device, dtype=x_tensor.dtype)[None, :]) + + grid_blended = (self.grid_eps * grid_uniform + + (1 - self.grid_eps) * grid_adaptive) + + return grid_blended + + # Create augmented evaluation points: samples + boundary points + # This ensures we preserve boundary behavior while adapting to sample density + boundary_points = torch.tensor([[self.grid_range[0]], [self.grid_range[1]]], + device=x_tensor.device, dtype=x_tensor.dtype).expand(-1, self.input_dimensions) + + # Combine samples with boundary points for evaluation + x_augmented = torch.cat([x_sorted, boundary_points], dim=0) + x_augmented = torch.sort(x_augmented, dim=0)[0] 
# Re-sort with boundaries included + + # Evaluate current spline at augmented points (samples + boundaries) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Create new grid + new_grid = get_grid(num_interval) + + if mode == 'grid': + # For 'grid' mode, use denser sampling + sample_grid = get_grid(2 * num_interval) + x_augmented = sample_grid.transpose(0, 1) # (batch_size, input_dimensions) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Add grid extensions if needed + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update grid and refit coefficients + self.spline.knots = new_grid + + try: + # Refit coefficients using augmented points (preserves boundaries) + self.spline.compute_control_points(x_augmented, y_eval) + except Exception as e: + print(f"Warning: Failed to update coefficients during grid refinement: {e}") + + def update_grid_resolution(self, new_num: int): + """ + Update grid resolution to a new number of intervals. 
+ """ + with torch.no_grad(): + # Sample the current spline function on a dense grid + x_eval = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=2 * new_num, + device=self.spline.knots.device + ) + x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions) + + basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Update num and create a new grid + self.num = new_num + new_grid = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=self.num + 1, + device=self.spline.knots.device + ) + new_grid = new_grid[None, :].expand(self.input_dimensions, self.num + 1) + + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update spline with the new grid and re-compute control points + self.spline.knots = new_grid + self.spline.compute_control_points(x_eval, y_eval) + + def get_grid_statistics(self): + """Get statistics about the current grid for debugging/analysis""" + return { + 'grid_shape': self.spline.knots.shape, + 'grid_min': self.spline.knots.min().item(), + 'grid_max': self.spline.knots.max().item(), + 'grid_range': (self.spline.knots.max() - self.spline.knots.min()).mean().item(), + 'num_intervals': self.spline.knots.shape[1] - 1 - (2*self.k if self.spline.grid_extension else 0) + } \ No newline at end of file diff --git a/pina/model/kolmogorov_arnold_network/kan_network.py b/pina/model/kolmogorov_arnold_network/kan_network.py new file mode 100644 index 000000000..cd94a5894 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_network.py @@ -0,0 +1,194 @@ +"""Kolmogorov Arnold Network implementation""" +import torch +import torch.nn as nn +from typing import List + +try: + from .kan_layer import KAN_layer +except ImportError: + 
from kan_layer import KAN_layer + +class KAN_Network(torch.nn.Module): + """ + Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers. + Each layer uses learnable univariate functions (B-splines + base functions) on edges. + """ + + def __init__( + self, + layer_sizes: List[int], + k: int = 3, + num: int = 3, + grid_eps: float = 0.1, + grid_range: List[float] = [-1, 1], + grid_extension: bool = True, + noise_scale: float = 0.1, + base_function = torch.nn.SiLU(), + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + inner_nodes: int = 5, + sparse_init: bool = False, + sp_trainable: bool = True, + sb_trainable: bool = True, + save_act: bool = True + ): + """ + Initialize the KAN network. + + Args: + layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim] + k: Order of the B-spline + num: Number of grid points for B-splines + grid_eps: Epsilon for grid spacing + grid_range: Range for the grid [min, max] + grid_extension: Whether to extend the grid + noise_scale: Scale for initialization noise + base_function: Base activation function (e.g., SiLU) + scale_base_mu: Mean for base function scaling + scale_base_sigma: Std for base function scaling + scale_sp: Scale for spline functions + """ + super().__init__() + + if len(layer_sizes) < 2: + raise ValueError("Need at least input and output dimensions") + + self.layer_sizes = layer_sizes + self.num_layers = len(layer_sizes) - 1 + self.save_act = save_act + + # Create KAN layers + self.kan_layers = nn.ModuleList() + + for i in range(self.num_layers): + layer = KAN_layer( + k=k, + input_dimensions=layer_sizes[i], + output_dimensions=layer_sizes[i+1], + num=num, + grid_eps=grid_eps, + grid_range=grid_range, + grid_extension=grid_extension, + noise_scale=noise_scale, + base_function=base_function, + scale_base_mu=scale_base_mu, + scale_base_sigma=scale_base_sigma, + scale_sp=scale_sp, + 
inner_nodes=inner_nodes, + sparse_init=sparse_init, + sp_trainable=sp_trainable, + sb_trainable=sb_trainable + ) + self.kan_layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN network. + + Args: + x: Input tensor of shape (batch_size, input_dimensions) + + Returns: + Output tensor of shape (batch_size, output_dimensions) + """ + current = x + self.acts = [current] + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + + if self.save_act: + self.acts.append(current.detach()) + + return current + + def get_num_parameters(self) -> int: + """Get total number of trainable parameters""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid for all layers based on input samples. + This adapts the grid points to better fit the data distribution. + + Args: + x: Input samples, shape (batch_size, input_dimensions) + mode: 'sample' or 'grid' - determines sampling strategy + """ + current = x + + for i, layer in enumerate(self.kan_layers): + layer.update_grid_from_samples(current, mode=mode) + + if i < len(self.kan_layers) - 1: + with torch.no_grad(): + current = layer(current) + + def update_grid_resolution(self, new_num: int): + """ + Update the grid resolution for all layers. + This can be used for adaptive training where grid resolution increases over time. + + Args: + new_num: New number of grid points + """ + for layer in self.kan_layers: + layer.update_grid_resolution(new_num) + + def enable_sparsification(self, threshold: float = 1e-4): + """ + Enable sparsification by setting small weights to zero. 
+ + Args: + threshold: Threshold below which weights are set to zero + """ + with torch.no_grad(): + for layer in self.kan_layers: + # Sparsify scale parameters + layer.scale_base.data[torch.abs(layer.scale_base.data) < threshold] = 0 + layer.scale_spline.data[torch.abs(layer.scale_spline.data) < threshold] = 0 + + # Update mask + layer.mask.data = ((torch.abs(layer.scale_base) >= threshold) | + (torch.abs(layer.scale_spline) >= threshold)).float() + + def get_activation_statistics(self, x: torch.Tensor): + """ + Get statistics about activations for analysis purposes. + + Args: + x: Input tensor + + Returns: + Dictionary with activation statistics + """ + stats = {} + current = x + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + stats[f'layer_{i}'] = { + 'mean': current.mean().item(), + 'std': current.std().item(), + 'min': current.min().item(), + 'max': current.max().item() + } + + return stats + + + def get_network_grid_statistics(self): + """ + Get grid statistics for all layers in the network. 
+ + Returns: + Dictionary with grid statistics for each layer + """ + stats = {} + for i, layer in enumerate(self.kan_layers): + stats[f'layer_{i}'] = layer.get_grid_statistics() + return stats + + \ No newline at end of file From dc31498ef6b1b668b925fd951c9089e07f54f5e3 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Wed, 21 Jan 2026 14:27:14 +0100 Subject: [PATCH 02/12] KAN with non-vectorized spline --- .../model/block/kan_block.py} | 98 ++++++++--- .../model/kolmogorov_arnold_network.py} | 54 +++--- pina/_src/model/spline.py | 9 +- pina/_src/model/vectorized_spline.py | 164 ++++++++++++++++++ pina/model/__init__.py | 4 + pina/model/block/__init__.py | 2 + .../test_kolmogorov_arnold_network.py | 153 ++++++++++++++++ tests/test_model/test_spline.py | 39 +++++ 8 files changed, 473 insertions(+), 50 deletions(-) rename pina/{model/kolmogorov_arnold_network/kan_layer.py => _src/model/block/kan_block.py} (70%) rename pina/{model/kolmogorov_arnold_network/kan_network.py => _src/model/kolmogorov_arnold_network.py} (76%) create mode 100644 pina/_src/model/vectorized_spline.py create mode 100644 tests/test_model/test_kolmogorov_arnold_network.py diff --git a/pina/model/kolmogorov_arnold_network/kan_layer.py b/pina/_src/model/block/kan_block.py similarity index 70% rename from pina/model/kolmogorov_arnold_network/kan_layer.py rename to pina/_src/model/block/kan_block.py index ddd360587..ec5b5cca3 100644 --- a/pina/model/kolmogorov_arnold_network/kan_layer.py +++ b/pina/_src/model/block/kan_block.py @@ -2,14 +2,21 @@ import torch import numpy as np -from pina.model.spline import Spline +from pina._src.model.spline import Spline +from pina._src.model.vectorized_spline import VectorizedSpline -class KAN_layer(torch.nn.Module): +class KANBlock(torch.nn.Module): """define a KAN layer using splines""" - def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, 
base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None: + def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, + num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, + noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, + scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, + sb_trainable=True): """ Initialize the KAN layer. + + num is the number of intervals in the initial grid (excluding any extension knots) """ super().__init__() self.k = k @@ -20,6 +27,8 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ self.grid_eps = grid_eps self.grid_range = grid_range self.grid_extension = grid_extension + self.vec = True + # self.vec = False if sparse_init: self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) @@ -27,6 +36,7 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + knots = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1) if grid_extension: h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) @@ -34,17 +44,53 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ grid = torch.cat([grid[:, [0]] - h, grid], dim=1) grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - n_coef = grid.shape[1] - (self.k + 1) + n_control_points = len(knots) - (self.k ) - control_points = torch.nn.Parameter( - torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale - ) + # control_points = torch.nn.Parameter( + # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale + # ) 
+ # print(control_points.shape) + if self.vec: + control_points = torch.randn(self.input_dimensions * self.output_dimensions, n_control_points) + print('control points', control_points.shape) + control_points = torch.stack([ + torch.randn(n_control_points) + for _ in range(self.input_dimensions * self.output_dimensions) + ]) + print('control points', control_points.shape) + self.spline_q = VectorizedSpline( + order=self.k, + knots=knots, + control_points=control_points + ) + + else: + spline_q = [] + for q in range(self.output_dimensions): + spline_p = [] + for p in range(self.input_dimensions): + spline_ = Spline( + order=self.k, + knots=knots, + control_points=torch.randn(n_control_points) + ) + spline_p.append(spline_) + spline_p = torch.nn.ModuleList(spline_p) + spline_q.append(spline_p) + self.spline_q = torch.nn.ModuleList(spline_q) + + + # control_points = torch.nn.Parameter( + # torch.randn(n_control_points, self.output_dimensions) * noise_scale) + # print(control_points) + # print('uuu') - self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension) + # self.spline = Spline( + # order=self.k, knots=knots, control_points=control_points) - self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ - scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) - self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) + # self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ + # scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) + # self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, 
requires_grad=sp_trainable) self.base_function = base_function @staticmethod @@ -75,20 +121,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_tensor = x.tensor else: x_tensor = x - - base = self.base_function(x_tensor) # (batch, input_dimensions) - - basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots) - spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points) - base_term = self.scale_base[None, :, :] * base[:, :, None] - spline_term = self.scale_spline[None, :, :] * spline_out_per_input - combined = base_term + spline_term - combined = self.mask[None,:,:] * combined - - output = torch.sum(combined, dim=1) # (batch, output_dimensions) - return output + if self.vec: + y = self.spline_q.forward(x_tensor) # (batch, output_dimensions, input_dimensions) + y = y.reshape(y.shape[0], y.shape[1], self.output_dimensions, self.input_dimensions) + base_out = self.base_function(x_tensor) # (batch, input_dimensions) + y = y + base_out[:, :, None, None] + y = y.sum(dim=3).sum(dim=1) # sum over input dimensions + else: + y = [] + for q in range(self.output_dimensions): + y_q = [] + for p in range(self.input_dimensions): + spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions) + base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions) + y_q.append(spline_out + base_out) + y.append(torch.stack(y_q, dim=1).sum(dim=1)) + y = torch.stack(y, dim=1) + + return y def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): """ diff --git a/pina/model/kolmogorov_arnold_network/kan_network.py b/pina/_src/model/kolmogorov_arnold_network.py similarity index 76% rename from pina/model/kolmogorov_arnold_network/kan_network.py rename to pina/_src/model/kolmogorov_arnold_network.py index cd94a5894..81f0754b0 100644 --- a/pina/model/kolmogorov_arnold_network/kan_network.py +++ b/pina/_src/model/kolmogorov_arnold_network.py @@ -3,15 +3,20 @@ import torch.nn as nn 
from typing import List -try: - from .kan_layer import KAN_layer -except ImportError: - from kan_layer import KAN_layer +from pina._src.model.block.kan_block import KANBlock -class KAN_Network(torch.nn.Module): +class KolmogorovArnoldNetwork(torch.nn.Module): """ - Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers. - Each layer uses learnable univariate functions (B-splines + base functions) on edges. + Kolmogorov Arnold Network, a neural network using KAN layers instead of + traditional MLP layers. Each layer uses learnable univariate functions + (B-splines + base functions) on edges. + + .. references:: + + Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., + ... & Tegmark, M. (2024). Kan: Kolmogorov-arnold networks. arXiv + preprint arXiv:2404.19756. + """ def __init__( @@ -35,19 +40,25 @@ def __init__( ): """ Initialize the KAN network. - - Args: - layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim] - k: Order of the B-spline - num: Number of grid points for B-splines - grid_eps: Epsilon for grid spacing - grid_range: Range for the grid [min, max] - grid_extension: Whether to extend the grid - noise_scale: Scale for initialization noise - base_function: Base activation function (e.g., SiLU) - scale_base_mu: Mean for base function scaling - scale_base_sigma: Std for base function scaling - scale_sp: Scale for spline functions + + :param iterable layer_sizes: List of layer sizes including input and + output dimensions. + :param int k: Order of the B-spline. + :param int num: Number of grid points for B-splines. + :param float grid_eps: Epsilon for grid spacing. + :param list grid_range: Range for the grid [min, max]. + :param bool grid_extension: Whether to extend the grid. + :param float noise_scale: Scale for initialization noise. + :param base_function: Base activation function (e.g., SiLU). 
+ :param float scale_base_mu: Mean for base function scaling. + :param float scale_base_sigma: Std for base function scaling. + :param float scale_sp: Scale for spline functions. + :param int inner_nodes: Number of inner nodes for KAN layers. + :param bool sparse_init: Whether to use sparse initialization. + :param bool sp_trainable: Whether spline parameters are trainable. + :param bool sb_trainable: Whether base function parameters are + trainable. + :param bool save_act: Whether to save activations after each layer. """ super().__init__() @@ -62,7 +73,7 @@ def __init__( self.kan_layers = nn.ModuleList() for i in range(self.num_layers): - layer = KAN_layer( + layer = KANBlock( k=k, input_dimensions=layer_sizes[i], output_dimensions=layer_sizes[i+1], @@ -97,6 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.kan_layers): current = layer(current) + # current = torch.nn.functional.sigmoid(current) if self.save_act: self.acts.append(current.detach()) diff --git a/pina/_src/model/spline.py b/pina/_src/model/spline.py index 0cbf8df45..1b00300de 100644 --- a/pina/_src/model/spline.py +++ b/pina/_src/model/spline.py @@ -277,11 +277,8 @@ def forward(self, x): :return: The output tensor. 
:rtype: torch.Tensor """ - return torch.einsum( - "...bi, i -> ...b", - self.basis(x.as_subclass(torch.Tensor)).squeeze(-1), - self.control_points, - ) + basis = self.basis(x.as_subclass(torch.Tensor)) + return basis @ self.control_points def derivative(self, x, degree): """ @@ -475,4 +472,4 @@ def knots(self, value): self._boundary_interval_idx = self._compute_boundary_interval() # Recompute derivative denominators when knots change - self._compute_derivative_denominators() \ No newline at end of file + self._compute_derivative_denominators() diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py new file mode 100644 index 000000000..89d2a0e72 --- /dev/null +++ b/pina/_src/model/vectorized_spline.py @@ -0,0 +1,164 @@ +"""Vectorized univariate B-spline model.""" + +import torch +import torch.nn as nn + +class VectorizedSpline(nn.Module): + """ + Vectorized univariate B-spline model (shared knots, many splines). + + Notation: + - knots: shape (m,) + - order: k (degree = k-1) + - n_ctrl = m - k + - control_points: + * (S, n_ctrl) -> S splines, scalar output each + * (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels) + Input: + - x: shape (...,) or (..., B) + Output: + - if control_points is (S, n_ctrl): shape (..., S) + - if control_points is (S, O, n_ctrl): shape (..., S, O) + """ + + def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None): + super().__init__() + if not isinstance(order, int) or order <= 0: + raise ValueError("order must be a positive integer.") + if not isinstance(knots, torch.Tensor): + raise ValueError("knots must be a torch.Tensor.") + if knots.ndim != 1: + raise ValueError("knots must be 1D.") + + self.order = order + + # store sorted knots as buffer + knots_sorted = knots.sort().values + self.register_buffer("knots", knots_sorted) + + n_ctrl = knots_sorted.numel() - order + if n_ctrl <= 0: + raise ValueError(f"Need #knots > order. 
Got #knots={knots_sorted.numel()} order={order}.") + + # boundary interval idx for rightmost inclusion + self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted) + + # # control points init + # if control_points is None: + # # default: one spline + # cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device) + # self.control_points = nn.Parameter(cp, requires_grad=True) + # else: + # if not isinstance(control_points, torch.Tensor): + # raise ValueError("control_points must be a torch.Tensor or None.") + # if control_points.ndim not in (2, 3): + # raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).") + # if control_points.shape[-1] != n_ctrl: + # raise ValueError( + # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}." + # ) + self.control_points = nn.Parameter(control_points, requires_grad=True) + + @staticmethod + def _compute_boundary_interval_idx(knots: torch.Tensor) -> int: + if knots.numel() < 2: + return 0 + diffs = knots[1:] - knots[:-1] + valid = torch.nonzero(diffs > 0, as_tuple=False) + if valid.numel() == 0: + return 0 + return int(valid[-1]) + + def basis(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute B-spline basis functions of order self.order at x. 
+ + Returns: + basis: shape (..., n_ctrl) + """ + if not isinstance(x, torch.Tensor): + x = torch.as_tensor(x) + + # ensure float dtype consistent + x = x.to(dtype=self.knots.dtype, device=self.knots.device) + + # make x shape (..., 1) for broadcasting + x_exp = x.unsqueeze(-1) # (..., 1) + + # knots as (1, ..., 1, m) via unsqueeze to broadcast + # (m,) -> (1,)*x.ndim + (m,) + knots = self.knots.view(*([1] * x.ndim), -1) + + # order-1 base: indicator on intervals [t_i, t_{i+1}) + basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1) + + # include rightmost boundary in the last non-degenerate interval + j = self._boundary_interval_idx + knot_left = knots[..., j] + knot_right = knots[..., j + 1] + at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10) + if torch.any(at_right): + basis_j = basis[..., j].bool() | at_right + basis[..., j] = basis_j.to(basis.dtype) + + # Cox-de Boor recursion up to order k + # after i-th iteration, basis has length (m-1 - i) + for i in range(1, self.order): + denom1 = knots[..., i:-1] - knots[..., :-(i + 1)] + denom2 = knots[..., i + 1:] - knots[..., 1:-i] + + denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1) + denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2) + + term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1] + term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:] + basis = term1 + term2 + + # final basis length is n_ctrl = m - order + return basis # (..., n_ctrl) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Evaluate spline(s) at x. 
+ + If control_points is (S, n_ctrl): output (..., S) + If control_points is (S, O, n_ctrl): output (..., S, O) + """ + B = self.basis(x) # (..., n_ctrl) + + cp = self.control_points + if cp.ndim == 2: + # (S, n_ctrl) + # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) + out = B @ cp.transpose(0, 1) + return out + else: + # (S, O, n_ctrl) + # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S + # vectorized using einsum (yes, this one is actually appropriate) + # (..., n) * (S, O, n) -> (..., S, O) + # out = torch.einsum("...n, son -> ...so", B, cp) + out = torch.einsum("bsc,sco->bso", B, cp) + + return out + + def forward_basis(self, basis): + """ + Evaluate spline(s) given precomputed basis. + + """ + cp = self.control_points + if cp.ndim == 2: + # (S, n_ctrl) + # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) + out = basis @ cp.transpose(0, 1) + return out + else: + # (S, O, n_ctrl) + # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S + # vectorized using einsum (yes, this one is actually appropriate) + # (..., n) * (S, O, n) -> (..., S, O) + # out = torch.einsum("...n, son -> ...so", B, cp) + out = torch.einsum("bsc,sco->bso", basis, cp) + + return out \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 0310eef5c..ee221c17e 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -17,6 +17,8 @@ "EquivariantGraphNeuralOperator", "SINDy", "SplineSurface", + "VectorizedSpline", + "KolmogorovArnoldNetwork", ] from pina._src.model.feed_forward import FeedForward, ResidualFeedForward @@ -34,3 +36,5 @@ EquivariantGraphNeuralOperator, ) from pina._src.model.sindy import SINDy +from pina._src.model.vectorized_spline import VectorizedSpline +from pina._src.model.kolmogorov_arnold_network import KolmogorovArnoldNetwork diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py index 88bfd9e43..e9e8e793d 100644 --- a/pina/model/block/__init__.py +++ 
b/pina/model/block/__init__.py @@ -25,6 +25,7 @@ "RBFBlock", "GNOBlock", "PirateNetBlock", + "KANBlock", ] from pina._src.model.block.convolution_2d import ContinuousConvBlock @@ -50,3 +51,4 @@ from pina._src.model.block.rbf_block import RBFBlock from pina._src.model.block.gno_block import GNOBlock from pina._src.model.block.pirate_network_block import PirateNetBlock +from pina._src.model.block.kan_block import KANBlock diff --git a/tests/test_model/test_kolmogorov_arnold_network.py b/tests/test_model/test_kolmogorov_arnold_network.py new file mode 100644 index 000000000..42f994f71 --- /dev/null +++ b/tests/test_model/test_kolmogorov_arnold_network.py @@ -0,0 +1,153 @@ +import torch +import pytest + +from pina.model import KolmogorovArnoldNetwork + +data = torch.rand((20, 3)) +input_vars = 3 +output_vars = 1 + + +def test_constructor(): + KolmogorovArnoldNetwork([input_vars, output_vars]) + KolmogorovArnoldNetwork([input_vars, 10, 20, output_vars]) + KolmogorovArnoldNetwork( + [input_vars, 10, 20, output_vars], + k=3, + num=5 + ) + KolmogorovArnoldNetwork( + [input_vars, 10, 20, output_vars], + k=3, + num=5, + grid_eps=0.05, + grid_range=[-2, 2] + ) + KolmogorovArnoldNetwork( + [input_vars, 10, output_vars], + base_function=torch.nn.Tanh(), + scale_sp=0.5, + sparse_init=True + ) + + +def test_constructor_wrong(): + with pytest.raises(ValueError): + KolmogorovArnoldNetwork([input_vars]) + with pytest.raises(ValueError): + KolmogorovArnoldNetwork([]) + + +def test_forward(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, dim_out]) + output_ = kan(data) + assert output_.shape == (data.shape[0], dim_out) + + +def test_forward_multilayer(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, 10, 5, dim_out]) + output_ = kan(data) + assert output_.shape == (data.shape[0], dim_out) + + +def test_backward(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, dim_out]) + data.requires_grad = True + output_ = kan(data) + loss = 
torch.mean(output_) + loss.backward() + assert data._grad.shape == torch.Size([20, 3]) + + +def test_get_num_parameters(): + kan = KolmogorovArnoldNetwork([3, 5, 2]) + num_params = kan.get_num_parameters() + assert num_params > 0 + assert isinstance(num_params, int) + +from pina.problem.zoo import Poisson2DSquareProblem +from pina.solver import PINN +from pina.trainer import Trainer + +def test_train_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="random", domains="all") + + model = KolmogorovArnoldNetwork([2, 3, 1], k=3, num=5) + solver = PINN(model=model, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=10, + accelerator="cpu", + batch_size=100, + train_size=1.0, + val_size=0.0, + test_size=0.0, + ) + trainer.train() + + + +# def test_update_grid_from_samples(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# samples = torch.randn(50, 3) +# kan.update_grid_from_samples(samples, mode='sample') +# # Check that the network still works after grid update +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_update_grid_resolution(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], num=3) +# kan.update_grid_resolution(5) +# # Check that the network still works after resolution update +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_enable_sparsification(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# kan.enable_sparsification(threshold=1e-4) +# # Check that the network still works after sparsification +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_get_activation_statistics(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# stats = kan.get_activation_statistics(data) +# assert isinstance(stats, dict) +# assert 'layer_0' in stats +# assert 'layer_1' in stats +# assert 'mean' in stats['layer_0'] +# assert 'std' in stats['layer_0'] +# assert 'min' in stats['layer_0'] +# assert 'max' in stats['layer_0'] + + +# def 
test_get_network_grid_statistics(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# stats = kan.get_network_grid_statistics() +# assert isinstance(stats, dict) +# assert 'layer_0' in stats +# assert 'layer_1' in stats + + +# def test_save_act(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=True) +# output = kan(data) +# assert hasattr(kan, 'acts') +# assert len(kan.acts) == 3 # input + 2 layers +# assert kan.acts[0].shape == data.shape +# assert kan.acts[-1].shape == output.shape + + +# def test_save_act_disabled(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=False) +# _ = kan(data) +# assert hasattr(kan, 'acts') +# # Only the first activation (input) is saved +# assert len(kan.acts) == 1 diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index baff81940..144f71b66 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -191,3 +191,42 @@ def test_derivative(args, pts): # Check shape and value assert first_der.shape == pts.shape assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) + + +#@pytest.mark.parametrize("args", valid_args) # TODO +def test_vectorized(): + + N = 7 + cps = [] + splines = [] + for i in range(N): + cp = torch.rand(n_ctrl_pts) + cps.append(cp) + spline = Spline( + order=order, + control_points=cp + ) + splines.append(spline) + + from pina.model import VectorizedSpline + unique_cps = torch.stack(cps, dim=0) + print(unique_cps.shape) + print(cps[0].shape) + # Vectorized control points + vectorized_spline = VectorizedSpline( + order=order, + knots=splines[0].knots, + control_points=torch.stack(cps, dim=0) + ) + + x = torch.rand(100, 1) + + result_single = torch.stack([ + splines[i](x) for i in range(N) + ]) + print(result_single.shape) + result_single = result_single.permute(1, 2, 0) + out_vectorized = vectorized_spline(x) + print(out_vectorized.shape) + print(result_single.shape) + assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5) \ 
No newline at end of file From a5481614dc1dd324e841f38f1cfe5d883e2b85f3 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Fri, 20 Mar 2026 10:06:14 +0100 Subject: [PATCH 03/12] Fix minor problem, black formatter Add future todos on kan_block --- pina/_src/model/block/kan_block.py | 329 ++++++++++++------- pina/_src/model/kolmogorov_arnold_network.py | 96 +++--- pina/_src/model/vectorized_spline.py | 52 ++- pina/_src/problem/abstract_problem.py | 9 +- tests/test_model/test_spline.py | 28 +- 5 files changed, 315 insertions(+), 199 deletions(-) diff --git a/pina/_src/model/block/kan_block.py b/pina/_src/model/block/kan_block.py index ec5b5cca3..cbb7509ab 100644 --- a/pina/_src/model/block/kan_block.py +++ b/pina/_src/model/block/kan_block.py @@ -1,18 +1,39 @@ """Create the infrastructure for a KAN layer""" + import torch import numpy as np from pina._src.model.spline import Spline from pina._src.model.vectorized_spline import VectorizedSpline - +# TODO +# - Improve documentation and comments throughout the code for better clarity. +# - Remove any unused parameters or code related to the base function if it's +# not being utilized in the current implementation. +# - Clean unused code class KANBlock(torch.nn.Module): """define a KAN layer using splines""" - def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, - num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, - noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, - scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, - sb_trainable=True): + + def __init__( + self, + k, + input_dimensions, + output_dimensions, + inner_nodes, + num=3, + grid_eps=0.1, + grid_range=[-1, 1], + grid_extension=True, + noise_scale=0.1, + base_function=torch.nn.SiLU(), + scale_base_mu=0.0, + scale_base_sigma=1.0, + scale_sp=1.0, + sparse_init=True, + sp_trainable=True, + sb_trainable=True, + vectorized=True, + ): """ Initialize the KAN layer. 
@@ -27,41 +48,50 @@ def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, self.grid_eps = grid_eps self.grid_range = grid_range self.grid_extension = grid_extension - self.vec = True - # self.vec = False - + self.vectorized = vectorized + if sparse_init: - self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) + self.mask = torch.nn.Parameter( + self.sparse_mask(input_dimensions, output_dimensions) + ).requires_grad_(False) else: - self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) - - grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + self.mask = torch.nn.Parameter( + torch.ones(input_dimensions, output_dimensions) + ).requires_grad_(False) + + grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[ + None, : + ].expand(self.input_dimensions, self.num + 1) knots = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1) - + if grid_extension: h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) for i in range(self.k): grid = torch.cat([grid[:, [0]] - h, grid], dim=1) grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - - n_control_points = len(knots) - (self.k ) - + + n_control_points = len(knots) - (self.k) + # control_points = torch.nn.Parameter( # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale # ) # print(control_points.shape) - if self.vec: - control_points = torch.randn(self.input_dimensions * self.output_dimensions, n_control_points) - print('control points', control_points.shape) - control_points = torch.stack([ - torch.randn(n_control_points) - for _ in range(self.input_dimensions * self.output_dimensions) - ]) - print('control points', control_points.shape) + if self.vectorized: + control_points = torch.randn( + self.input_dimensions * self.output_dimensions, n_control_points + ) + 
print("control points", control_points.shape) + control_points = torch.stack( + [ + torch.randn(n_control_points) + for _ in range( + self.input_dimensions * self.output_dimensions + ) + ] + ) + print("control points", control_points.shape) self.spline_q = VectorizedSpline( - order=self.k, - knots=knots, - control_points=control_points + order=self.k, knots=knots, control_points=control_points ) else: @@ -72,14 +102,13 @@ def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, spline_ = Spline( order=self.k, knots=knots, - control_points=torch.randn(n_control_points) + control_points=torch.randn(n_control_points), ) spline_p.append(spline_) spline_p = torch.nn.ModuleList(spline_p) spline_q.append(spline_p) self.spline_q = torch.nn.ModuleList(spline_q) - # control_points = torch.nn.Parameter( # torch.randn(n_control_points, self.output_dimensions) * noise_scale) # print(control_points) @@ -95,20 +124,28 @@ def __init__(self, k, input_dimensions, output_dimensions, inner_nodes, @staticmethod def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor: - ''' + """ get sparse mask - ''' - in_coord = torch.arange(in_dimensions) * 1/in_dimensions + 1/(2*in_dimensions) - out_coord = torch.arange(out_dimensions) * 1/out_dimensions + 1/(2*out_dimensions) + """ + in_coord = torch.arange(in_dimensions) * 1 / in_dimensions + 1 / ( + 2 * in_dimensions + ) + out_coord = torch.arange(out_dimensions) * 1 / out_dimensions + 1 / ( + 2 * out_dimensions + ) - dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]) + dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :]) in_nearest = torch.argmin(dist_mat, dim=0) - in_connection = torch.stack([torch.arange(in_dimensions), in_nearest]).permute(1,0) + in_connection = torch.stack( + [torch.arange(in_dimensions), in_nearest] + ).permute(1, 0) out_nearest = torch.argmin(dist_mat, dim=1) - out_connection = torch.stack([out_nearest, torch.arange(out_dimensions)]).permute(1,0) + out_connection = 
torch.stack( + [out_nearest, torch.arange(out_dimensions)] + ).permute(1, 0) all_connection = torch.cat([in_connection, out_connection], dim=0) mask = torch.zeros(in_dimensions, out_dimensions) - mask[all_connection[:,0], all_connection[:,1]] = 1. + mask[all_connection[:, 0], all_connection[:, 1]] = 1.0 return mask def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -117,15 +154,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Each input goes through: w_base*base(x) + w_spline*spline(x) Then sum across input dimensions for each output node. """ - if hasattr(x, 'tensor'): + if hasattr(x, "tensor"): x_tensor = x.tensor else: x_tensor = x - - if self.vec: - y = self.spline_q.forward(x_tensor) # (batch, output_dimensions, input_dimensions) - y = y.reshape(y.shape[0], y.shape[1], self.output_dimensions, self.input_dimensions) + if self.vectorized: + y = self.spline_q.forward( + x_tensor + ) # (batch, output_dimensions, input_dimensions) + y = y.reshape( + y.shape[0], + y.shape[1], + self.output_dimensions, + self.input_dimensions, + ) base_out = self.base_function(x_tensor) # (batch, input_dimensions) y = y + base_out[:, :, None, None] y = y.sum(dim=3).sum(dim=1) # sum over input dimensions @@ -134,98 +177,148 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for q in range(self.output_dimensions): y_q = [] for p in range(self.input_dimensions): - spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions) - base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions) + spline_out = self.spline_q[q][p].forward( + x_tensor[:, p] + ) # (batch, input_dimensions, output_dimensions) + base_out = self.base_function( + x_tensor[:, p] + ) # (batch, input_dimensions) y_q.append(spline_out + base_out) y.append(torch.stack(y_q, dim=1).sum(dim=1)) y = torch.stack(y, dim=1) - + return y - def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + def update_grid_from_samples(self, x: 
torch.Tensor, mode: str = "sample"): """ Update grid from input samples to better fit data distribution. Based on PyKAN implementation but with boundary preservation. """ # Convert LabelTensor to regular tensor for spline operations - if hasattr(x, 'tensor'): + if hasattr(x, "tensor"): # This is a LabelTensor, extract the tensor part x_tensor = x.tensor else: x_tensor = x - + with torch.no_grad(): batch_size = x_tensor.shape[0] - x_sorted = torch.sort(x_tensor, dim=0)[0] # (batch_size, input_dimensions) - + x_sorted = torch.sort(x_tensor, dim=0)[ + 0 + ] # (batch_size, input_dimensions) + # Get current number of intervals (excluding extensions) if self.grid_extension: - num_interval = self.spline.knots.shape[1] - 1 - 2*self.k + num_interval = self.spline.knots.shape[1] - 1 - 2 * self.k else: num_interval = self.spline.knots.shape[1] - 1 - + def get_grid(num_intervals: int): """PyKAN-style grid creation with boundary preservation""" - ids = [int(batch_size * i / num_intervals) for i in range(num_intervals)] + [-1] - grid_adaptive = x_sorted[ids, :].transpose(0, 1) # (input_dimensions, num_intervals+1) - + ids = [ + int(batch_size * i / num_intervals) + for i in range(num_intervals) + ] + [-1] + grid_adaptive = x_sorted[ids, :].transpose( + 0, 1 + ) # (input_dimensions, num_intervals+1) + original_min = self.grid_range[0] original_max = self.grid_range[1] - + # Clamp adaptive grid to not shrink beyond original domain - grid_adaptive[:, 0] = torch.min(grid_adaptive[:, 0], - torch.full_like(grid_adaptive[:, 0], original_min)) - grid_adaptive[:, -1] = torch.max(grid_adaptive[:, -1], - torch.full_like(grid_adaptive[:, -1], original_max)) - - margin = 0.0 - h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_intervals - grid_uniform = (grid_adaptive[:, [0]] - margin + - h * torch.arange(num_intervals + 1, device=x_tensor.device, dtype=x_tensor.dtype)[None, :]) - - grid_blended = (self.grid_eps * grid_uniform + - (1 - self.grid_eps) * grid_adaptive) - + 
grid_adaptive[:, 0] = torch.min( + grid_adaptive[:, 0], + torch.full_like(grid_adaptive[:, 0], original_min), + ) + grid_adaptive[:, -1] = torch.max( + grid_adaptive[:, -1], + torch.full_like(grid_adaptive[:, -1], original_max), + ) + + margin = 0.0 + h = ( + grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin + ) / num_intervals + grid_uniform = ( + grid_adaptive[:, [0]] + - margin + + h + * torch.arange( + num_intervals + 1, + device=x_tensor.device, + dtype=x_tensor.dtype, + )[None, :] + ) + + grid_blended = ( + self.grid_eps * grid_uniform + + (1 - self.grid_eps) * grid_adaptive + ) + return grid_blended - + # Create augmented evaluation points: samples + boundary points # This ensures we preserve boundary behavior while adapting to sample density - boundary_points = torch.tensor([[self.grid_range[0]], [self.grid_range[1]]], - device=x_tensor.device, dtype=x_tensor.dtype).expand(-1, self.input_dimensions) - + boundary_points = torch.tensor( + [[self.grid_range[0]], [self.grid_range[1]]], + device=x_tensor.device, + dtype=x_tensor.dtype, + ).expand(-1, self.input_dimensions) + # Combine samples with boundary points for evaluation x_augmented = torch.cat([x_sorted, boundary_points], dim=0) - x_augmented = torch.sort(x_augmented, dim=0)[0] # Re-sort with boundaries included - + x_augmented = torch.sort(x_augmented, dim=0)[ + 0 + ] # Re-sort with boundaries included + # Evaluate current spline at augmented points (samples + boundaries) - basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) - y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) - + basis = self.spline.basis( + x_augmented, self.spline.k, self.spline.knots + ) + y_eval = torch.einsum( + "bil,iol->bio", basis, self.spline.control_points + ) + # Create new grid new_grid = get_grid(num_interval) - - if mode == 'grid': + + if mode == "grid": # For 'grid' mode, use denser sampling sample_grid = get_grid(2 * num_interval) - x_augmented = 
sample_grid.transpose(0, 1) # (batch_size, input_dimensions) - basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) - y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) - + x_augmented = sample_grid.transpose( + 0, 1 + ) # (batch_size, input_dimensions) + basis = self.spline.basis( + x_augmented, self.spline.k, self.spline.knots + ) + y_eval = torch.einsum( + "bil,iol->bio", basis, self.spline.control_points + ) + # Add grid extensions if needed if self.grid_extension: - h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / ( + new_grid.shape[1] - 1 + ) for i in range(self.k): - new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) - new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) - + new_grid = torch.cat( + [new_grid[:, [0]] - h, new_grid], dim=1 + ) + new_grid = torch.cat( + [new_grid, new_grid[:, [-1]] + h], dim=1 + ) + # Update grid and refit coefficients self.spline.knots = new_grid - + try: # Refit coefficients using augmented points (preserves boundaries) self.spline.compute_control_points(x_augmented, y_eval) except Exception as e: - print(f"Warning: Failed to update coefficients during grid refinement: {e}") + print( + f"Warning: Failed to update coefficients during grid refinement: {e}" + ) def update_grid_resolution(self, new_num: int): """ @@ -234,32 +327,42 @@ def update_grid_resolution(self, new_num: int): with torch.no_grad(): # Sample the current spline function on a dense grid x_eval = torch.linspace( - self.grid_range[0], - self.grid_range[1], - steps=2 * new_num, - device=self.spline.knots.device + self.grid_range[0], + self.grid_range[1], + steps=2 * new_num, + device=self.spline.knots.device, ) x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions) basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots) - y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + y_eval = 
torch.einsum( + "bil,iol->bio", basis, self.spline.control_points + ) # Update num and create a new grid self.num = new_num new_grid = torch.linspace( - self.grid_range[0], - self.grid_range[1], - steps=self.num + 1, - device=self.spline.knots.device + self.grid_range[0], + self.grid_range[1], + steps=self.num + 1, + device=self.spline.knots.device, + ) + new_grid = new_grid[None, :].expand( + self.input_dimensions, self.num + 1 ) - new_grid = new_grid[None, :].expand(self.input_dimensions, self.num + 1) if self.grid_extension: - h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / ( + new_grid.shape[1] - 1 + ) for i in range(self.k): - new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) - new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) - + new_grid = torch.cat( + [new_grid[:, [0]] - h, new_grid], dim=1 + ) + new_grid = torch.cat( + [new_grid, new_grid[:, [-1]] + h], dim=1 + ) + # Update spline with the new grid and re-compute control points self.spline.knots = new_grid self.spline.compute_control_points(x_eval, y_eval) @@ -267,9 +370,13 @@ def update_grid_resolution(self, new_num: int): def get_grid_statistics(self): """Get statistics about the current grid for debugging/analysis""" return { - 'grid_shape': self.spline.knots.shape, - 'grid_min': self.spline.knots.min().item(), - 'grid_max': self.spline.knots.max().item(), - 'grid_range': (self.spline.knots.max() - self.spline.knots.min()).mean().item(), - 'num_intervals': self.spline.knots.shape[1] - 1 - (2*self.k if self.spline.grid_extension else 0) - } \ No newline at end of file + "grid_shape": self.spline.knots.shape, + "grid_min": self.spline.knots.min().item(), + "grid_max": self.spline.knots.max().item(), + "grid_range": (self.spline.knots.max() - self.spline.knots.min()) + .mean() + .item(), + "num_intervals": self.spline.knots.shape[1] + - 1 + - (2 * self.k if self.spline.grid_extension else 0), + } diff --git 
a/pina/_src/model/kolmogorov_arnold_network.py b/pina/_src/model/kolmogorov_arnold_network.py index 81f0754b0..1c8c38789 100644 --- a/pina/_src/model/kolmogorov_arnold_network.py +++ b/pina/_src/model/kolmogorov_arnold_network.py @@ -1,10 +1,12 @@ """Kolmogorov Arnold Network implementation""" + import torch import torch.nn as nn from typing import List from pina._src.model.block.kan_block import KANBlock + class KolmogorovArnoldNetwork(torch.nn.Module): """ Kolmogorov Arnold Network, a neural network using KAN layers instead of @@ -18,9 +20,9 @@ class KolmogorovArnoldNetwork(torch.nn.Module): preprint arXiv:2404.19756. """ - + def __init__( - self, + self, layer_sizes: List[int], k: int = 3, num: int = 3, @@ -28,7 +30,7 @@ def __init__( grid_range: List[float] = [-1, 1], grid_extension: bool = True, noise_scale: float = 0.1, - base_function = torch.nn.SiLU(), + base_function=torch.nn.SiLU(), scale_base_mu: float = 0.0, scale_base_sigma: float = 1.0, scale_sp: float = 1.0, @@ -36,7 +38,7 @@ def __init__( sparse_init: bool = False, sp_trainable: bool = True, sb_trainable: bool = True, - save_act: bool = True + save_act: bool = True, ): """ Initialize the KAN network. @@ -61,22 +63,22 @@ def __init__( :param bool save_act: Whether to save activations after each layer. 
""" super().__init__() - + if len(layer_sizes) < 2: raise ValueError("Need at least input and output dimensions") - + self.layer_sizes = layer_sizes self.num_layers = len(layer_sizes) - 1 self.save_act = save_act - + # Create KAN layers self.kan_layers = nn.ModuleList() - + for i in range(self.num_layers): layer = KANBlock( k=k, input_dimensions=layer_sizes[i], - output_dimensions=layer_sizes[i+1], + output_dimensions=layer_sizes[i + 1], num=num, grid_eps=grid_eps, grid_range=grid_range, @@ -89,17 +91,17 @@ def __init__( inner_nodes=inner_nodes, sparse_init=sparse_init, sp_trainable=sp_trainable, - sb_trainable=sb_trainable + sb_trainable=sb_trainable, ) self.kan_layers.append(layer) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the KAN network. - + Args: x: Input tensor of shape (batch_size, input_dimensions) - + Returns: Output tensor of shape (batch_size, output_dimensions) """ @@ -109,98 +111,100 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.kan_layers): current = layer(current) # current = torch.nn.functional.sigmoid(current) - + if self.save_act: self.acts.append(current.detach()) - + return current - + def get_num_parameters(self) -> int: """Get total number of trainable parameters""" return sum(p.numel() for p in self.parameters() if p.requires_grad) - - - def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = "sample"): """ Update grid for all layers based on input samples. This adapts the grid points to better fit the data distribution. 
- + Args: x: Input samples, shape (batch_size, input_dimensions) mode: 'sample' or 'grid' - determines sampling strategy """ current = x - + for i, layer in enumerate(self.kan_layers): layer.update_grid_from_samples(current, mode=mode) - + if i < len(self.kan_layers) - 1: with torch.no_grad(): current = layer(current) - + def update_grid_resolution(self, new_num: int): """ Update the grid resolution for all layers. This can be used for adaptive training where grid resolution increases over time. - + Args: new_num: New number of grid points """ for layer in self.kan_layers: layer.update_grid_resolution(new_num) - + def enable_sparsification(self, threshold: float = 1e-4): """ Enable sparsification by setting small weights to zero. - + Args: threshold: Threshold below which weights are set to zero """ with torch.no_grad(): for layer in self.kan_layers: # Sparsify scale parameters - layer.scale_base.data[torch.abs(layer.scale_base.data) < threshold] = 0 - layer.scale_spline.data[torch.abs(layer.scale_spline.data) < threshold] = 0 - + layer.scale_base.data[ + torch.abs(layer.scale_base.data) < threshold + ] = 0 + layer.scale_spline.data[ + torch.abs(layer.scale_spline.data) < threshold + ] = 0 + # Update mask - layer.mask.data = ((torch.abs(layer.scale_base) >= threshold) | - (torch.abs(layer.scale_spline) >= threshold)).float() + layer.mask.data = ( + (torch.abs(layer.scale_base) >= threshold) + | (torch.abs(layer.scale_spline) >= threshold) + ).float() def get_activation_statistics(self, x: torch.Tensor): """ Get statistics about activations for analysis purposes. 
- + Args: x: Input tensor - + Returns: Dictionary with activation statistics """ stats = {} current = x - + for i, layer in enumerate(self.kan_layers): current = layer(current) - stats[f'layer_{i}'] = { - 'mean': current.mean().item(), - 'std': current.std().item(), - 'min': current.min().item(), - 'max': current.max().item() + stats[f"layer_{i}"] = { + "mean": current.mean().item(), + "std": current.std().item(), + "min": current.min().item(), + "max": current.max().item(), } - + return stats - - + def get_network_grid_statistics(self): """ Get grid statistics for all layers in the network. - + Returns: Dictionary with grid statistics for each layer """ stats = {} for i, layer in enumerate(self.kan_layers): - stats[f'layer_{i}'] = layer.get_grid_statistics() + stats[f"layer_{i}"] = layer.get_grid_statistics() return stats - - \ No newline at end of file diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index 89d2a0e72..7bf48256a 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn + class VectorizedSpline(nn.Module): """ Vectorized univariate B-spline model (shared knots, many splines). @@ -21,7 +22,12 @@ class VectorizedSpline(nn.Module): - if control_points is (S, O, n_ctrl): shape (..., S, O) """ - def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None): + def __init__( + self, + order: int, + knots: torch.Tensor, + control_points: torch.Tensor | None = None, + ): super().__init__() if not isinstance(order, int) or order <= 0: raise ValueError("order must be a positive integer.") @@ -38,10 +44,14 @@ def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor n_ctrl = knots_sorted.numel() - order if n_ctrl <= 0: - raise ValueError(f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}.") + raise ValueError( + f"Need #knots > order. 
Got #knots={knots_sorted.numel()} order={order}." + ) # boundary interval idx for rightmost inclusion - self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted) + self._boundary_interval_idx = self._compute_boundary_interval_idx( + knots_sorted + ) # # control points init # if control_points is None: @@ -90,13 +100,17 @@ def basis(self, x: torch.Tensor) -> torch.Tensor: knots = self.knots.view(*([1] * x.ndim), -1) # order-1 base: indicator on intervals [t_i, t_{i+1}) - basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1) + basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to( + x_exp.dtype + ) # (..., m-1) # include rightmost boundary in the last non-degenerate interval j = self._boundary_interval_idx knot_left = knots[..., j] knot_right = knots[..., j + 1] - at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10) + at_right = (x >= knot_left.squeeze(-1)) & torch.isclose( + x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10 + ) if torch.any(at_right): basis_j = basis[..., j].bool() | at_right basis[..., j] = basis_j.to(basis.dtype) @@ -104,14 +118,20 @@ def basis(self, x: torch.Tensor) -> torch.Tensor: # Cox-de Boor recursion up to order k # after i-th iteration, basis has length (m-1 - i) for i in range(1, self.order): - denom1 = knots[..., i:-1] - knots[..., :-(i + 1)] - denom2 = knots[..., i + 1:] - knots[..., 1:-i] - - denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1) - denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2) - - term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1] - term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:] + denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] + denom2 = knots[..., i + 1 :] - knots[..., 1:-i] + + denom1 = torch.where( + denom1.abs() < 1e-8, torch.ones_like(denom1), denom1 + ) + denom2 = torch.where( + denom2.abs() < 1e-8, 
torch.ones_like(denom2), denom2 + ) + + term1 = ((x_exp - knots[..., : -(i + 1)]) / denom1) * basis[ + ..., :-1 + ] + term2 = ((knots[..., i + 1 :] - x_exp) / denom2) * basis[..., 1:] basis = term1 + term2 # final basis length is n_ctrl = m - order @@ -143,9 +163,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out def forward_basis(self, basis): - """ + """ Evaluate spline(s) given precomputed basis. - + """ cp = self.control_points if cp.ndim == 2: @@ -161,4 +181,4 @@ def forward_basis(self, basis): # out = torch.einsum("...n, son -> ...so", B, cp) out = torch.einsum("bsc,sco->bso", basis, cp) - return out \ No newline at end of file + return out diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index 5dbba18c2..28bccf089 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -289,15 +289,10 @@ def move_discretisation_into_conditions(self): if not self.are_all_domains_discretised: warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { + warning_message = "\n".join([f"""{" " * 13} ---> Domain {key} { "sampled" if key in self.discretised_domains else - "not sampled"}""" - for key in self.domains - ] - ) + "not sampled"}""" for key in self.domains]) warnings.warn( "Some of the domains are still not sampled. 
Consider calling " "problem.discretise_domain function for all domains before " diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index 144f71b66..d375f92ef 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -2,7 +2,7 @@ import pytest from scipy.interpolate import BSpline from pina.operator import grad -from pina.model import Spline +from pina.model import Spline, VectorizedSpline from pina import LabelTensor # Utility quantities for testing @@ -193,30 +193,23 @@ def test_derivative(args, pts): assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) -#@pytest.mark.parametrize("args", valid_args) # TODO -def test_vectorized(): +@pytest.mark.parametrize("args", valid_args) +@pytest.mark.parametrize("N", [1, 4, 7]) +def test_vectorized(args, N): - N = 7 cps = [] splines = [] + for i in range(N): - cp = torch.rand(n_ctrl_pts) - cps.append(cp) - spline = Spline( - order=order, - control_points=cp - ) + spline = Spline(**args) splines.append(spline) + cps.append(spline.control_points) - from pina.model import VectorizedSpline unique_cps = torch.stack(cps, dim=0) - print(unique_cps.shape) - print(cps[0].shape) - # Vectorized control points vectorized_spline = VectorizedSpline( - order=order, + order=args["order"], knots=splines[0].knots, - control_points=torch.stack(cps, dim=0) + control_points=unique_cps ) x = torch.rand(100, 1) @@ -224,9 +217,6 @@ def test_vectorized(): result_single = torch.stack([ splines[i](x) for i in range(N) ]) - print(result_single.shape) result_single = result_single.permute(1, 2, 0) out_vectorized = vectorized_spline(x) - print(out_vectorized.shape) - print(result_single.shape) assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5) \ No newline at end of file From c1d7e2613c49fc1e1bc190012276ef7652f1c665 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Thu, 26 Mar 2026 09:38:28 +0100 Subject: [PATCH 04/12] fix output dimension for vectorized 
spline --- pina/_src/model/spline.py | 1 + pina/_src/model/vectorized_spline.py | 25 +++++++++++++++++++------ tests/test_model/test_spline.py | 5 ++++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/pina/_src/model/spline.py b/pina/_src/model/spline.py index 1b00300de..5df52c106 100644 --- a/pina/_src/model/spline.py +++ b/pina/_src/model/spline.py @@ -278,6 +278,7 @@ def forward(self, x): :rtype: torch.Tensor """ basis = self.basis(x.as_subclass(torch.Tensor)) + # print("normal forward, cp:", self.control_points) return basis @ self.control_points def derivative(self, x, degree): diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index 7bf48256a..f60ac8a2b 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -24,9 +24,10 @@ class VectorizedSpline(nn.Module): def __init__( self, - order: int, - knots: torch.Tensor, - control_points: torch.Tensor | None = None, + order, + knots, + control_points=None, + aggregate_output=None, ): super().__init__() if not isinstance(order, int) or order <= 0: @@ -68,6 +69,7 @@ def __init__( # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}." 
# ) self.control_points = nn.Parameter(control_points, requires_grad=True) + self.aggregate_output = aggregate_output @staticmethod def _compute_boundary_interval_idx(knots: torch.Tensor) -> int: @@ -90,7 +92,8 @@ def basis(self, x: torch.Tensor) -> torch.Tensor: x = torch.as_tensor(x) # ensure float dtype consistent - x = x.to(dtype=self.knots.dtype, device=self.knots.device) + # x = x.to(dtype=self.knots.dtype, device=self.knots.device) + x = x.to(dtype=self.knots.dtype, device=self.knots.device).as_subclass(torch.Tensor) # make x shape (..., 1) for broadcasting x_exp = x.unsqueeze(-1) # (..., 1) @@ -147,11 +150,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: B = self.basis(x) # (..., n_ctrl) cp = self.control_points + # print("vectorized forward, cp:", cp) if cp.ndim == 2: # (S, n_ctrl) # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) + # print('B shape:', B.shape, 'cp shape:', cp.shape) + #out = (B @ cp.transpose(0, 1)).squeeze(-1) out = B @ cp.transpose(0, 1) - return out + # out = B @ cp[0] else: # (S, O, n_ctrl) # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S @@ -160,7 +166,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # out = torch.einsum("...n, son -> ...so", B, cp) out = torch.einsum("bsc,sco->bso", B, cp) - return out + if self.aggregate_output == "mean": + out = out.mean(dim=-1) # aggregate over O dimension if present + elif self.aggregate_output == "sum": + out = out.sum(dim=-1) + + # print("vectorized forward, out:", out.shape) + + return out def forward_basis(self, basis): """ diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index d375f92ef..2191f6ee4 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -217,6 +217,9 @@ def test_vectorized(args, N): result_single = torch.stack([ splines[i](x) for i in range(N) ]) - result_single = result_single.permute(1, 2, 0) + result_single = result_single.permute(1, 2, 0) # shape (100, N) 
out_vectorized = vectorized_spline(x) + print("result single shape:", result_single.shape) + print("out vectorized shape:", out_vectorized.shape) + assert out_vectorized.shape == (100, 1, N) assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5) \ No newline at end of file From e2ec4d0a4571bb90320154a634dceabd4be7e42a Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Thu, 2 Apr 2026 17:19:33 +0200 Subject: [PATCH 05/12] fix index mismatch and remove unused function --- pina/_src/model/vectorized_spline.py | 31 ++++++---------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index f60ac8a2b..737b52fcc 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -93,7 +93,9 @@ def basis(self, x: torch.Tensor) -> torch.Tensor: # ensure float dtype consistent # x = x.to(dtype=self.knots.dtype, device=self.knots.device) - x = x.to(dtype=self.knots.dtype, device=self.knots.device).as_subclass(torch.Tensor) + x = x.as_subclass(torch.Tensor).to( + dtype=self.knots.dtype, device=self.knots.device + ) # make x shape (..., 1) for broadcasting x_exp = x.unsqueeze(-1) # (..., 1) @@ -155,7 +157,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # (S, n_ctrl) # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) # print('B shape:', B.shape, 'cp shape:', cp.shape) - #out = (B @ cp.transpose(0, 1)).squeeze(-1) + # out = (B @ cp.transpose(0, 1)).squeeze(-1) out = B @ cp.transpose(0, 1) # out = B @ cp[0] else: @@ -164,7 +166,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # vectorized using einsum (yes, this one is actually appropriate) # (..., n) * (S, O, n) -> (..., S, O) # out = torch.einsum("...n, son -> ...so", B, cp) - out = torch.einsum("bsc,sco->bso", B, cp) + out = torch.einsum("bsc,soc->bso", B, cp) if self.aggregate_output == "mean": out = out.mean(dim=-1) # aggregate over O dimension if present @@ -172,26 
+174,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = out.sum(dim=-1) # print("vectorized forward, out:", out.shape) - - return out - def forward_basis(self, basis): - """ - Evaluate spline(s) given precomputed basis. - - """ - cp = self.control_points - if cp.ndim == 2: - # (S, n_ctrl) - # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) - out = basis @ cp.transpose(0, 1) - return out - else: - # (S, O, n_ctrl) - # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S - # vectorized using einsum (yes, this one is actually appropriate) - # (..., n) * (S, O, n) -> (..., S, O) - # out = torch.einsum("...n, son -> ...so", B, cp) - out = torch.einsum("bsc,sco->bso", basis, cp) - - return out + return out From 88e4bbd7a59225dcd42d13da778b8b7c9a392cce Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Thu, 2 Apr 2026 17:39:46 +0200 Subject: [PATCH 06/12] Fix vectorized splines and implement working KAN --- pina/_src/model/block/kan_block.py | 464 ++++------------ pina/_src/model/kolmogorov_arnold_network.py | 254 +++------ pina/_src/model/spline.py | 2 +- pina/_src/model/vectorized_spline.py | 554 +++++++++++++++---- 4 files changed, 613 insertions(+), 661 deletions(-) diff --git a/pina/_src/model/block/kan_block.py b/pina/_src/model/block/kan_block.py index cbb7509ab..08655f467 100644 --- a/pina/_src/model/block/kan_block.py +++ b/pina/_src/model/block/kan_block.py @@ -1,382 +1,146 @@ -"""Create the infrastructure for a KAN layer""" +"""Module for the Kolmogorov-Arnold Network block.""" import torch -import numpy as np - -from pina._src.model.spline import Spline from pina._src.model.vectorized_spline import VectorizedSpline +from pina._src.core.utils import check_consistency, check_positive_integer + -# TODO -# - Improve documentation and comments throughout the code for better clarity. -# - Remove any unused parameters or code related to the base function if it's -# not being utilized in the current implementation. 
-# - Clean unused code class KANBlock(torch.nn.Module): - """define a KAN layer using splines""" + """ + TODO: docstring. + """ def __init__( self, - k, input_dimensions, output_dimensions, - inner_nodes, - num=3, - grid_eps=0.1, - grid_range=[-1, 1], - grid_extension=True, - noise_scale=0.1, - base_function=torch.nn.SiLU(), - scale_base_mu=0.0, - scale_base_sigma=1.0, - scale_sp=1.0, - sparse_init=True, - sp_trainable=True, - sb_trainable=True, - vectorized=True, + spline_order=3, + n_knots=10, + grid_range=[0, 1], + base_function=torch.nn.SiLU, + use_base_linear=True, + use_bias=True, + init_scale_spline=1e-2, + init_scale_base=1.0, ): """ - Initialize the KAN layer. - - num è il numero di intervalli nella griglia iniziale (esclusi gli eventuali nodi di estensione) + Initialization of the :class:`KANBlock` class. + + :param int input_dimensions: The number of input features. + :param int output_dimensions: The number of output features. + :param int spline_order: The order of each spline basis function. + Default is 3 (quadratic splines). + :param int n_knots: The number of knots for each spline basis function. + Default is 10. + :param grid_range: The range for the spline knots. It must be either a + list or a tuple of the form [min, max]. Default is [0, 1]. + :type grid_range: list | tuple. + :param torch.nn.Module base_function: The base activation function to be + applied to the input before the linear transformation. Default is + :class:`torch.nn.SiLU`. + :param bool use_base_linear: Whether to include a linear transformation + of the base function output. Default is True. + :param bool use_bias: Whether to include a bias term in the output. + Default is True. + :param init_scale_spline: The scale for initializing each spline + control points. Default is 1e-2. + :type init_scale_spline: float | int. + :param init_scale_base: The scale for initializing the base linear + weights. Default is 1.0. + :type init_scale_base: float | int. 
+ :raises ValueError: If ``grid_range`` is not of length 2. """ super().__init__() - self.k = k - self.input_dimensions = input_dimensions - self.output_dimensions = output_dimensions - self.inner_nodes = inner_nodes - self.num = num - self.grid_eps = grid_eps - self.grid_range = grid_range - self.grid_extension = grid_extension - self.vectorized = vectorized - - if sparse_init: - self.mask = torch.nn.Parameter( - self.sparse_mask(input_dimensions, output_dimensions) - ).requires_grad_(False) - else: - self.mask = torch.nn.Parameter( - torch.ones(input_dimensions, output_dimensions) - ).requires_grad_(False) - - grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[ - None, : - ].expand(self.input_dimensions, self.num + 1) - knots = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1) - - if grid_extension: - h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) - for i in range(self.k): - grid = torch.cat([grid[:, [0]] - h, grid], dim=1) - grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - n_control_points = len(knots) - (self.k) - - # control_points = torch.nn.Parameter( - # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale - # ) - # print(control_points.shape) - if self.vectorized: - control_points = torch.randn( - self.input_dimensions * self.output_dimensions, n_control_points - ) - print("control points", control_points.shape) - control_points = torch.stack( - [ - torch.randn(n_control_points) - for _ in range( - self.input_dimensions * self.output_dimensions - ) - ] - ) - print("control points", control_points.shape) - self.spline_q = VectorizedSpline( - order=self.k, knots=knots, control_points=control_points + # Check consistency + check_consistency(base_function, torch.nn.Module, subclass=True) + check_positive_integer(input_dimensions, strict=True) + check_positive_integer(output_dimensions, strict=True) + check_positive_integer(spline_order, strict=True) + 
check_positive_integer(n_knots, strict=True) + check_consistency(use_base_linear, bool) + check_consistency(use_bias, bool) + check_consistency(init_scale_spline, (int, float)) + check_consistency(init_scale_base, (int, float)) + check_consistency(grid_range, (int, float)) + + # Raise error if grid_range is not valid + if len(grid_range) != 2: + raise ValueError("Grid must be a list or tuple with two elements.") + + # Knots for the spline basis functions + initial_knots = torch.ones(spline_order) * grid_range[0] + final_knots = torch.ones(spline_order) * grid_range[1] + + # Number of internal knots + n_internal = max(0, n_knots - 2 * spline_order) + + # Internal knots are uniformly spaced in the grid range + internal_knots = torch.linspace( + grid_range[0], grid_range[1], n_internal + 2 + )[1:-1] + + # Define the knots + knots = torch.cat((initial_knots, internal_knots, final_knots)) + knots = knots.unsqueeze(0).repeat(input_dimensions, 1) + + # Define the control points for the spline basis functions + control_points = ( + torch.randn( + input_dimensions, + output_dimensions, + knots.shape[-1] - spline_order, ) - - else: - spline_q = [] - for q in range(self.output_dimensions): - spline_p = [] - for p in range(self.input_dimensions): - spline_ = Spline( - order=self.k, - knots=knots, - control_points=torch.randn(n_control_points), - ) - spline_p.append(spline_) - spline_p = torch.nn.ModuleList(spline_p) - spline_q.append(spline_p) - self.spline_q = torch.nn.ModuleList(spline_q) - - # control_points = torch.nn.Parameter( - # torch.randn(n_control_points, self.output_dimensions) * noise_scale) - # print(control_points) - # print('uuu') - - # self.spline = Spline( - # order=self.k, knots=knots, control_points=control_points) - - # self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ - # scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) - # 
self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) - self.base_function = base_function - - @staticmethod - def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor: - """ - get sparse mask - """ - in_coord = torch.arange(in_dimensions) * 1 / in_dimensions + 1 / ( - 2 * in_dimensions + * init_scale_spline ) - out_coord = torch.arange(out_dimensions) * 1 / out_dimensions + 1 / ( - 2 * out_dimensions + + # Define the vectorized spline module + self.spline = VectorizedSpline( + order=spline_order, knots=knots, control_points=control_points ) - dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :]) - in_nearest = torch.argmin(dist_mat, dim=0) - in_connection = torch.stack( - [torch.arange(in_dimensions), in_nearest] - ).permute(1, 0) - out_nearest = torch.argmin(dist_mat, dim=1) - out_connection = torch.stack( - [out_nearest, torch.arange(out_dimensions)] - ).permute(1, 0) - all_connection = torch.cat([in_connection, out_connection], dim=0) - mask = torch.zeros(in_dimensions, out_dimensions) - mask[all_connection[:, 0], all_connection[:, 1]] = 1.0 - return mask + # Initialize the base function + self.base_function = base_function() - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass through the KAN layer. - Each input goes through: w_base*base(x) + w_spline*spline(x) - Then sum across input dimensions for each output node. 
- """ - if hasattr(x, "tensor"): - x_tensor = x.tensor - else: - x_tensor = x - - if self.vectorized: - y = self.spline_q.forward( - x_tensor - ) # (batch, output_dimensions, input_dimensions) - y = y.reshape( - y.shape[0], - y.shape[1], - self.output_dimensions, - self.input_dimensions, + # Initialize the base linear weights if needed + if use_base_linear: + self.base_weight = torch.nn.Parameter( + torch.randn(output_dimensions, input_dimensions) + * (init_scale_base / (input_dimensions**0.5)) ) - base_out = self.base_function(x_tensor) # (batch, input_dimensions) - y = y + base_out[:, :, None, None] - y = y.sum(dim=3).sum(dim=1) # sum over input dimensions else: - y = [] - for q in range(self.output_dimensions): - y_q = [] - for p in range(self.input_dimensions): - spline_out = self.spline_q[q][p].forward( - x_tensor[:, p] - ) # (batch, input_dimensions, output_dimensions) - base_out = self.base_function( - x_tensor[:, p] - ) # (batch, input_dimensions) - y_q.append(spline_out + base_out) - y.append(torch.stack(y_q, dim=1).sum(dim=1)) - y = torch.stack(y, dim=1) + self.register_parameter("base_weight", None) - return y - - def update_grid_from_samples(self, x: torch.Tensor, mode: str = "sample"): - """ - Update grid from input samples to better fit data distribution. - Based on PyKAN implementation but with boundary preservation. 
- """ - # Convert LabelTensor to regular tensor for spline operations - if hasattr(x, "tensor"): - # This is a LabelTensor, extract the tensor part - x_tensor = x.tensor + # Initialize the bias term if needed + if use_bias: + self.bias = torch.nn.Parameter(torch.zeros(output_dimensions)) else: - x_tensor = x - - with torch.no_grad(): - batch_size = x_tensor.shape[0] - x_sorted = torch.sort(x_tensor, dim=0)[ - 0 - ] # (batch_size, input_dimensions) - - # Get current number of intervals (excluding extensions) - if self.grid_extension: - num_interval = self.spline.knots.shape[1] - 1 - 2 * self.k - else: - num_interval = self.spline.knots.shape[1] - 1 - - def get_grid(num_intervals: int): - """PyKAN-style grid creation with boundary preservation""" - ids = [ - int(batch_size * i / num_intervals) - for i in range(num_intervals) - ] + [-1] - grid_adaptive = x_sorted[ids, :].transpose( - 0, 1 - ) # (input_dimensions, num_intervals+1) - - original_min = self.grid_range[0] - original_max = self.grid_range[1] - - # Clamp adaptive grid to not shrink beyond original domain - grid_adaptive[:, 0] = torch.min( - grid_adaptive[:, 0], - torch.full_like(grid_adaptive[:, 0], original_min), - ) - grid_adaptive[:, -1] = torch.max( - grid_adaptive[:, -1], - torch.full_like(grid_adaptive[:, -1], original_max), - ) - - margin = 0.0 - h = ( - grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin - ) / num_intervals - grid_uniform = ( - grid_adaptive[:, [0]] - - margin - + h - * torch.arange( - num_intervals + 1, - device=x_tensor.device, - dtype=x_tensor.dtype, - )[None, :] - ) + self.register_parameter("bias", None) - grid_blended = ( - self.grid_eps * grid_uniform - + (1 - self.grid_eps) * grid_adaptive - ) - - return grid_blended - - # Create augmented evaluation points: samples + boundary points - # This ensures we preserve boundary behavior while adapting to sample density - boundary_points = torch.tensor( - [[self.grid_range[0]], [self.grid_range[1]]], - device=x_tensor.device, 
- dtype=x_tensor.dtype, - ).expand(-1, self.input_dimensions) - - # Combine samples with boundary points for evaluation - x_augmented = torch.cat([x_sorted, boundary_points], dim=0) - x_augmented = torch.sort(x_augmented, dim=0)[ - 0 - ] # Re-sort with boundaries included - - # Evaluate current spline at augmented points (samples + boundaries) - basis = self.spline.basis( - x_augmented, self.spline.k, self.spline.knots - ) - y_eval = torch.einsum( - "bil,iol->bio", basis, self.spline.control_points - ) - - # Create new grid - new_grid = get_grid(num_interval) - - if mode == "grid": - # For 'grid' mode, use denser sampling - sample_grid = get_grid(2 * num_interval) - x_augmented = sample_grid.transpose( - 0, 1 - ) # (batch_size, input_dimensions) - basis = self.spline.basis( - x_augmented, self.spline.k, self.spline.knots - ) - y_eval = torch.einsum( - "bil,iol->bio", basis, self.spline.control_points - ) - - # Add grid extensions if needed - if self.grid_extension: - h = (new_grid[:, [-1]] - new_grid[:, [0]]) / ( - new_grid.shape[1] - 1 - ) - for i in range(self.k): - new_grid = torch.cat( - [new_grid[:, [0]] - h, new_grid], dim=1 - ) - new_grid = torch.cat( - [new_grid, new_grid[:, [-1]] + h], dim=1 - ) - - # Update grid and refit coefficients - self.spline.knots = new_grid + def forward(self, x): + """ + Forward pass of the :class:`KANBlock`. It transforms the input using a + vectorized spline basis and optionally adds a linear transformation of a + base activation function. - try: - # Refit coefficients using augmented points (preserves boundaries) - self.spline.compute_control_points(x_augmented, y_eval) - except Exception as e: - print( - f"Warning: Failed to update coefficients during grid refinement: {e}" - ) + The input is expected to have shape (batch_size, input_dimensions) and + the output will have shape (batch_size, output_dimensions). - def update_grid_resolution(self, new_num: int): - """ - Update grid resolution to a new number of intervals. 
+ :param torch.Tensor x: The input tensor for the model. + :return: The output tensor of the model. + :rtype: torch.Tensor """ - with torch.no_grad(): - # Sample the current spline function on a dense grid - x_eval = torch.linspace( - self.grid_range[0], - self.grid_range[1], - steps=2 * new_num, - device=self.spline.knots.device, - ) - x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions) - - basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots) - y_eval = torch.einsum( - "bil,iol->bio", basis, self.spline.control_points - ) + y = self.spline(x) - # Update num and create a new grid - self.num = new_num - new_grid = torch.linspace( - self.grid_range[0], - self.grid_range[1], - steps=self.num + 1, - device=self.spline.knots.device, - ) - new_grid = new_grid[None, :].expand( - self.input_dimensions, self.num + 1 - ) + if self.base_weight is not None: + base_x = self.base_function(x) + base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight) + y = y + base_out - if self.grid_extension: - h = (new_grid[:, [-1]] - new_grid[:, [0]]) / ( - new_grid.shape[1] - 1 - ) - for i in range(self.k): - new_grid = torch.cat( - [new_grid[:, [0]] - h, new_grid], dim=1 - ) - new_grid = torch.cat( - [new_grid, new_grid[:, [-1]] + h], dim=1 - ) + # aggregate contributions from all input dimensions + y = y.sum(dim=1) - # Update spline with the new grid and re-compute control points - self.spline.knots = new_grid - self.spline.compute_control_points(x_eval, y_eval) + if self.bias is not None: + y = y + self.bias - def get_grid_statistics(self): - """Get statistics about the current grid for debugging/analysis""" - return { - "grid_shape": self.spline.knots.shape, - "grid_min": self.spline.knots.min().item(), - "grid_max": self.spline.knots.max().item(), - "grid_range": (self.spline.knots.max() - self.spline.knots.min()) - .mean() - .item(), - "num_intervals": self.spline.knots.shape[1] - - 1 - - (2 * self.k if self.spline.grid_extension else 0), - } + return y 
diff --git a/pina/_src/model/kolmogorov_arnold_network.py b/pina/_src/model/kolmogorov_arnold_network.py index 1c8c38789..dec01569c 100644 --- a/pina/_src/model/kolmogorov_arnold_network.py +++ b/pina/_src/model/kolmogorov_arnold_network.py @@ -1,210 +1,86 @@ -"""Kolmogorov Arnold Network implementation""" - import torch -import torch.nn as nn -from typing import List - from pina._src.model.block.kan_block import KANBlock +from pina._src.core.utils import check_consistency class KolmogorovArnoldNetwork(torch.nn.Module): """ - Kolmogorov Arnold Network, a neural network using KAN layers instead of - traditional MLP layers. Each layer uses learnable univariate functions - (B-splines + base functions) on edges. - - .. references:: - - Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., - ... & Tegmark, M. (2024). Kan: Kolmogorov-arnold networks. arXiv - preprint arXiv:2404.19756. - + TODO: add docstring. """ def __init__( self, - layer_sizes: List[int], - k: int = 3, - num: int = 3, - grid_eps: float = 0.1, - grid_range: List[float] = [-1, 1], - grid_extension: bool = True, - noise_scale: float = 0.1, - base_function=torch.nn.SiLU(), - scale_base_mu: float = 0.0, - scale_base_sigma: float = 1.0, - scale_sp: float = 1.0, - inner_nodes: int = 5, - sparse_init: bool = False, - sp_trainable: bool = True, - sb_trainable: bool = True, - save_act: bool = True, + layers, + spline_order=3, + n_knots=10, + grid_range=[-1, 1], + base_function=torch.nn.SiLU, + use_base_linear=True, + use_bias=True, + init_scale_spline=1e-2, + init_scale_base=1.0, ): """ - Initialize the KAN network. - - :param iterable layer_sizes: List of layer sizes including input and - output dimensions. - :param int k: Order of the B-spline. - :param int num: Number of grid points for B-splines. - :param float grid_eps: Epsilon for grid spacing. - :param list grid_range: Range for the grid [min, max]. - :param bool grid_extension: Whether to extend the grid. 
- :param float noise_scale: Scale for initialization noise. - :param base_function: Base activation function (e.g., SiLU). - :param float scale_base_mu: Mean for base function scaling. - :param float scale_base_sigma: Std for base function scaling. - :param float scale_sp: Scale for spline functions. - :param int inner_nodes: Number of inner nodes for KAN layers. - :param bool sparse_init: Whether to use sparse initialization. - :param bool sp_trainable: Whether spline parameters are trainable. - :param bool sb_trainable: Whether base function parameters are - trainable. - :param bool save_act: Whether to save activations after each layer. + Initialization of the :class:`KolmogorovArnoldNetwork` class. + + :param layers: A list of integers specifying the sizes of each layer, + including input and output dimensions. + :type layers: list | tuple. + :param int spline_order: The order of each spline basis function. + Default is 3 (quadratic splines). + :param int n_knots: The number of knots for each spline basis function. + Default is 10. + :param grid_range: The range for the spline knots. It must be either a + list or a tuple of the form [min, max]. Default is [-1, 1]. + :type grid_range: list | tuple. + :param torch.nn.Module base_function: The base activation function to be + applied to the input before the linear transformation. Default is + :class:`torch.nn.SiLU`. + :param bool use_base_linear: Whether to include a linear transformation + of the base function output. Default is True. + :param bool use_bias: Whether to include a bias term in the output. + Default is True. + :param init_scale_spline: The scale for initializing each spline + control points. Default is 1e-2. + :type init_scale_spline: float | int. + :param init_scale_base: The scale for initializing the base linear + weights. Default is 1.0. + :type init_scale_base: float | int. + :raises ValueError: If ``grid_range`` is not of length 2. 
""" super().__init__() - if len(layer_sizes) < 2: - raise ValueError("Need at least input and output dimensions") - - self.layer_sizes = layer_sizes - self.num_layers = len(layer_sizes) - 1 - self.save_act = save_act - - # Create KAN layers - self.kan_layers = nn.ModuleList() - - for i in range(self.num_layers): - layer = KANBlock( - k=k, - input_dimensions=layer_sizes[i], - output_dimensions=layer_sizes[i + 1], - num=num, - grid_eps=grid_eps, - grid_range=grid_range, - grid_extension=grid_extension, - noise_scale=noise_scale, - base_function=base_function, - scale_base_mu=scale_base_mu, - scale_base_sigma=scale_base_sigma, - scale_sp=scale_sp, - inner_nodes=inner_nodes, - sparse_init=sparse_init, - sp_trainable=sp_trainable, - sb_trainable=sb_trainable, + # Check consistency -- all other checks are performed in KANBlock + check_consistency(layers, int) + if len(layers) < 2: + raise ValueError( + "`Provide at least two elements for layers (input and output)." ) - self.kan_layers.append(layer) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass through the KAN network. - - Args: - x: Input tensor of shape (batch_size, input_dimensions) - - Returns: - Output tensor of shape (batch_size, output_dimensions) - """ - current = x - self.acts = [current] - - for i, layer in enumerate(self.kan_layers): - current = layer(current) - # current = torch.nn.functional.sigmoid(current) - - if self.save_act: - self.acts.append(current.detach()) - - return current - - def get_num_parameters(self) -> int: - """Get total number of trainable parameters""" - return sum(p.numel() for p in self.parameters() if p.requires_grad) - - def update_grid_from_samples(self, x: torch.Tensor, mode: str = "sample"): - """ - Update grid for all layers based on input samples. - This adapts the grid points to better fit the data distribution. 
- Args: - x: Input samples, shape (batch_size, input_dimensions) - mode: 'sample' or 'grid' - determines sampling strategy + # Initialize KAN blocks + self.kan_layers = torch.nn.ModuleList( + [ + KANBlock( + input_dimensions=layers[i], + output_dimensions=layers[i + 1], + spline_order=spline_order, + n_knots=n_knots, + grid_range=grid_range, + base_function=base_function, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + for i in range(len(layers) - 1) + ] + ) + + def forward(self, x): """ - current = x - - for i, layer in enumerate(self.kan_layers): - layer.update_grid_from_samples(current, mode=mode) - - if i < len(self.kan_layers) - 1: - with torch.no_grad(): - current = layer(current) - - def update_grid_resolution(self, new_num: int): - """ - Update the grid resolution for all layers. - This can be used for adaptive training where grid resolution increases over time. - - Args: - new_num: New number of grid points + TODO: add docstring. """ for layer in self.kan_layers: - layer.update_grid_resolution(new_num) - - def enable_sparsification(self, threshold: float = 1e-4): - """ - Enable sparsification by setting small weights to zero. - - Args: - threshold: Threshold below which weights are set to zero - """ - with torch.no_grad(): - for layer in self.kan_layers: - # Sparsify scale parameters - layer.scale_base.data[ - torch.abs(layer.scale_base.data) < threshold - ] = 0 - layer.scale_spline.data[ - torch.abs(layer.scale_spline.data) < threshold - ] = 0 - - # Update mask - layer.mask.data = ( - (torch.abs(layer.scale_base) >= threshold) - | (torch.abs(layer.scale_spline) >= threshold) - ).float() + x = layer(x) - def get_activation_statistics(self, x: torch.Tensor): - """ - Get statistics about activations for analysis purposes. 
- - Args: - x: Input tensor - - Returns: - Dictionary with activation statistics - """ - stats = {} - current = x - - for i, layer in enumerate(self.kan_layers): - current = layer(current) - stats[f"layer_{i}"] = { - "mean": current.mean().item(), - "std": current.std().item(), - "min": current.min().item(), - "max": current.max().item(), - } - - return stats - - def get_network_grid_statistics(self): - """ - Get grid statistics for all layers in the network. - - Returns: - Dictionary with grid statistics for each layer - """ - stats = {} - for i, layer in enumerate(self.kan_layers): - stats[f"layer_{i}"] = layer.get_grid_statistics() - return stats + return x diff --git a/pina/_src/model/spline.py b/pina/_src/model/spline.py index 5df52c106..4fd3bfd24 100644 --- a/pina/_src/model/spline.py +++ b/pina/_src/model/spline.py @@ -278,7 +278,7 @@ def forward(self, x): :rtype: torch.Tensor """ basis = self.basis(x.as_subclass(torch.Tensor)) - # print("normal forward, cp:", self.control_points) + return basis @ self.control_points def derivative(self, x, degree): diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index 737b52fcc..fe48fb8c5 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -1,131 +1,304 @@ -"""Vectorized univariate B-spline model.""" +"""Vectorized univariate B-spline model with per-spline knots.""" +import warnings import torch -import torch.nn as nn +from pina._src.core.utils import check_consistency, check_positive_integer -class VectorizedSpline(nn.Module): - """ - Vectorized univariate B-spline model (shared knots, many splines). 
- - Notation: - - knots: shape (m,) - - order: k (degree = k-1) - - n_ctrl = m - k - - control_points: - * (S, n_ctrl) -> S splines, scalar output each - * (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels) - Input: - - x: shape (...,) or (..., B) - Output: - - if control_points is (S, n_ctrl): shape (..., S) - - if control_points is (S, O, n_ctrl): shape (..., S, O) +class VectorizedSpline(torch.nn.Module): + r""" + The vectorized B-spline model class. + + A :class:`VectorizedSpline` represents a vector spline, i.e., a collection + of independent univariate B-splines evaluated in parallel. Each univariate + spline has its own knot vector and its own control points, and acts on one + input feature. + + Given ``s`` univariate splines, the vector spline maps an input + :math:`x = (x^{(1)}, \dots, x^{(s)}) \in \mathbb{R}^s` to an output obtained + by evaluating each univariate spline on its corresponding scalar input + :math:`x^{(j)}`. + + For the :math:`j`-th univariate spline of order :math:`k`, the output is + defined as + + .. math:: + + S^{(j)}(x^{(j)}) = \sum_{i=1}^{n_j} B_{i,k}^{(j)}(x^{(j)}) C_i^{(j)}, + + where: + + - :math:`C^{(j)}` are the control points of the :math:`j`-th univariate + spline. In the scalar-output case, :math:`C^{(j)} \in \mathbb{R}^{n_j}`. + More generally, each univariate spline may have output dimension + :math:`o`, so :math:`C^{(j)} \in \mathbb{R}^{o \times n_j}`. + - :math:`B_{i,k}^{(j)}(x)` are the B-spline basis functions of order + :math:`k`, i.e., piecewise polynomials of degree :math:`k-1`, associated + with the knot vector of the :math:`j`-th univariate spline. + - :math:`X^{(j)} = \{x_1^{(j)}, x_2^{(j)}, \dots, x_{m_j}^{(j)}\}` is the + non-decreasing knot vector of the :math:`j`-th univariate spline. + + If the first and last knots of a given univariate spline are repeated + :math:`k` times, then that univariate spline interpolates its first and last + control points. 
+ + The full vector spline evaluates all univariate splines in parallel. If each + univariate spline has output dimension :math:`o`, then before optional + aggregation the output has shape ``[batch, s, o]``. + + .. note:: + + Each univariate spline is forced to be zero outside the interval defined + by the first and last knots of its own knot vector. + + .. note:: + + This class does not represent a single multivariate spline + :math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate + basis. Instead, it represents a vector spline built from ``s`` + independent univariate splines, one for each input feature. + + :Example: + + >>> from pina.model import VectorizedSpline + >>> import torch + + >>> knt1 = torch.tensor([ + ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0], + ... [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0], + ... ]) + >>> spline1 = VectorizedSpline(order=3, knots=knt1, control_points=None) + + >>> knt2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto", "n_splines": 2} + >>> spline2 = VectorizedSpline(order=3, knots=knt2, control_points=None) + + >>> knt3 = torch.tensor([ + ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0], + ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0], + ... ]) + >>> ctrl3 = torch.tensor([ + ... [0.0, 1.0, 3.0, 2.0], + ... [1.0, 0.0, 2.0, 1.0], + ... ]) + >>> spline3 = VectorizedSpline(order=3, knots=knt3, control_points=ctrl3) """ def __init__( self, - order, - knots, + order=4, + knots=None, control_points=None, aggregate_output=None, ): + """ + Initialization of the :class:`VectorizedSpline` class. + + :param int order: The order of each univariate spline. The corresponding + basis functions are polynomials of degree ``order - 1``. + Default is 4. + :param knots: The knots of the spline. If a tensor is provided, it must + have shape ``[s, n]``, where ``s`` is the number of univariate + splines and ``n`` is the number of knots per univariate spline. 
If a + dictionary is provided, it must contain the keys ``"n"``, ``"min"``, + ``"max"``, ``"mode"``, and ``"n_splines"``. Here, ``"n"`` specifies + the number of knots for each univariate spline, ``"min"`` and + ``"max"`` define the interval, ``"mode"`` selects the sampling + strategy, and ``"n_splines"`` specifies the number of univariate + splines. The supported modes are ``"uniform"``, where the knots are + evenly spaced over :math:`[min, max]`, and ``"auto"``, where knots + are constructed to ensure that each univariate spline interpolates + the first and last control points. In this case, the number of knots + is adjusted if :math:`n < 2 * order`. If None is given, knots are + initialized automatically over :math:`[0, 1]` ensuring interpolation + of the first and last control points. Default is None. + :type knots: torch.Tensor | dict + :param torch.Tensor control_points: The control points tensor. The + tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where + each univariate spline has ``c`` control points and output dimension + ``o``. In the latter case, the control points are expanded to shape + ``[s, 1, c]``. If None, control points are initialized to learnable + parameters with zero initial value. Default is None. + :param str aggregate_output: If None, the output of each univariate + spline is returned separately, resulting in an output of shape + ``[batch, s, o]``, where ``s`` is the number of univariate splines + and ``o`` is the output dimension of each univariate spline. If set + to ``"mean"`` or ``"sum"``, the output is aggregated accordingly + across the last dimension, resulting in an output of shape + ``[batch, s]``. Default is None. + :raises AssertionError: If ``order`` is not a positive integer. + :raises ValueError: If ``knots`` is neither a torch.Tensor nor a + dictionary, when provided. + :raises ValueError: If ``control_points`` is not a torch.Tensor, + when provided. 
+ :raises ValueError: If both ``knots`` and ``control_points`` are None. + :raises ValueError: If ``knots`` is not two-dimensional. + :raises ValueError: If ``control_points``, after expansion when + two-dimensional, is not three-dimensional. + :raises ValueError: If, for each univariate spline, the number of + ``knots`` is not equal to the sum of ``order`` and the number of + ``control_points.`` + :raises UserWarning: If, for each univariate spline, the number of + ``control_points`` is lower than the ``order``, resulting in a + degenerate spline. + :raises ValueError: If the number of univariate splines in ``knots`` and + ``control_points`` do not match. + """ + super().__init__() - if not isinstance(order, int) or order <= 0: - raise ValueError("order must be a positive integer.") - if not isinstance(knots, torch.Tensor): - raise ValueError("knots must be a torch.Tensor.") - if knots.ndim != 1: - raise ValueError("knots must be 1D.") + # Check consistency + check_positive_integer(value=order, strict=True) + check_consistency(knots, (type(None), torch.Tensor, dict)) + check_consistency(control_points, (type(None), torch.Tensor)) + + # Raise error if neither knots nor control points are provided + if knots is None and control_points is None: + raise ValueError("knots and control_points cannot both be None.") + + # Initialize knots if not provided + if knots is None and control_points is not None: + knots = { + "n": control_points.shape[-1] + order, + "min": 0, + "max": 1, + "n_splines": control_points.shape[0], + "mode": "auto", + } + + # Initialization - knots and control points managed by their setters self.order = order + self.knots = knots + self.control_points = control_points + self.aggregate_output = aggregate_output + + # Check dimensionality of knots + if self.knots.ndim != 2: + raise ValueError("knots must be two-dimensional.") - # store sorted knots as buffer - knots_sorted = knots.sort().values - self.register_buffer("knots", knots_sorted) + # Check 
dimensionality of control points + if self.control_points.ndim != 3: + raise ValueError("control_points must be three-dimensional.") - n_ctrl = knots_sorted.numel() - order - if n_ctrl <= 0: + # Raise error if #knots != order + #control_points + if self.knots.shape[-1] != self.order + self.control_points.shape[-1]: raise ValueError( - f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}." + f" The number of knots per spline must be equal to order + the" + f" number of control points. Got {self.knots.shape[-1]} knots" + f" per spline, {self.control_points.shape[-1]} control points," + f" and {self.order} order." ) - # boundary interval idx for rightmost inclusion - self._boundary_interval_idx = self._compute_boundary_interval_idx( - knots_sorted - ) + # Raise warning if spline is degenerate + if self.control_points.shape[-1] < self.order: + warnings.warn( + "The number of control points per spline is smaller than the" + " spline order. This creates a degenerate spline with limited" + " flexibility.", + UserWarning, + ) - # # control points init - # if control_points is None: - # # default: one spline - # cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device) - # self.control_points = nn.Parameter(cp, requires_grad=True) - # else: - # if not isinstance(control_points, torch.Tensor): - # raise ValueError("control_points must be a torch.Tensor or None.") - # if control_points.ndim not in (2, 3): - # raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).") - # if control_points.shape[-1] != n_ctrl: - # raise ValueError( - # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}." 
- # ) - self.control_points = nn.Parameter(control_points, requires_grad=True) - self.aggregate_output = aggregate_output + # Raise error if knots and control points have different # of splines + if self.knots.shape[0] != self.control_points.shape[0]: + raise ValueError( + f"The number of splines must be the same for knots and" + f" control points. Got {self.knots.shape[0]} splines for knots" + f" and {self.control_points.shape[0]} splines for control" + f" points." + ) - @staticmethod - def _compute_boundary_interval_idx(knots: torch.Tensor) -> int: - if knots.numel() < 2: - return 0 - diffs = knots[1:] - knots[:-1] - valid = torch.nonzero(diffs > 0, as_tuple=False) - if valid.numel() == 0: - return 0 - return int(valid[-1]) + # Precompute boundary interval index + self._boundary_interval_idx = self._compute_boundary_interval() - def basis(self, x: torch.Tensor) -> torch.Tensor: + def _compute_boundary_interval(self): """ - Compute B-spline basis functions of order self.order at x. + Precompute the index of the rightmost non-degenerate interval to improve + performance, eliminating the need to perform a search loop in the basis + function on each call. - Returns: - basis: shape (..., n_ctrl) + :return: The index of the rightmost non-degenerate interval for each + univariate spline. 
+ :rtype: torch.Tensor """ - if not isinstance(x, torch.Tensor): - x = torch.as_tensor(x) + # Compute the differences between consecutive knots for each spline + diffs = self._knots[:, 1:] - self._knots[:, :-1] + valid = diffs > 0 - # ensure float dtype consistent - # x = x.to(dtype=self.knots.dtype, device=self.knots.device) - x = x.as_subclass(torch.Tensor).to( - dtype=self.knots.dtype, device=self.knots.device + # Initialize idx tensor to store the last valid interval for each spline + idx = torch.zeros( + self._knots.shape[0], dtype=torch.long, device=self._knots.device ) - # make x shape (..., 1) for broadcasting - x_exp = x.unsqueeze(-1) # (..., 1) + # For each spline, find the last idx where interval is non-degenerate + for s in range(self._knots.shape[0]): + valid_s = torch.nonzero(valid[s], as_tuple=False) + idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0 - # knots as (1, ..., 1, m) via unsqueeze to broadcast - # (m,) -> (1,)*x.ndim + (m,) - knots = self.knots.view(*([1] * x.ndim), -1) + return idx - # order-1 base: indicator on intervals [t_i, t_{i+1}) - basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to( - x_exp.dtype - ) # (..., m-1) + def basis(self, x): + """ + Evaluate the B-spline basis functions for each univariate spline. + + This method applies the Cox-de Boor recursion in vectorized form across + all univariate splines of the vector spline. + + :param torch.Tensor x: The points to be evaluated. + :raises ValueError: If ``x`` is not two-dimensional. + :raises ValueError: If the number of input features does not match + the number of univariate splines. + :return: The basis functions evaluated at x. + :rtype: torch.Tensor + """ + # Ensure x is a tensor of the same dtype as knots + x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype) + + # Raise error if x does not have shape (batch, s) + if x.ndim != 2: + raise ValueError( + f"The input must have shape (batch, s). Got {x.shape}." 
+ ) + + # Raise error if x has different number of splines than knots + if x.shape[1] != self.knots.shape[0]: + raise ValueError( + f"The number of input features must be the same as the number" + f" of univariate splines. Got {x.shape[1]} input features," + f" but {self.knots.shape[0]} univariate splines." + ) - # include rightmost boundary in the last non-degenerate interval - j = self._boundary_interval_idx - knot_left = knots[..., j] - knot_right = knots[..., j + 1] - at_right = (x >= knot_left.squeeze(-1)) & torch.isclose( - x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10 + # Add a final dimension to x for broadcasting + x = x.unsqueeze(-1) + + # Add an initial dimension to knots for broadcasting + knots = self.knots.unsqueeze(0) + + # Base case of recursion: indicator functions for the intervals + basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) + basis = basis.to(x.dtype) + + # Extract left and right knots of the boundary interval for each spline + range_tensor = torch.arange(self.knots.shape[0], device=x.device) + knot_left = self.knots[range_tensor, self._boundary_interval_idx] + knot_right = self.knots[range_tensor, self._boundary_interval_idx + 1] + + # Identify points at the rightmost boundary + at_rightmost_boundary = (x >= knot_left.unsqueeze(0)) & torch.isclose( + x, knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10 ) - if torch.any(at_right): - basis_j = basis[..., j].bool() | at_right - basis[..., j] = basis_j.to(basis.dtype) - # Cox-de Boor recursion up to order k - # after i-th iteration, basis has length (m-1 - i) + # Ensure the correct value is set at the rightmost boundary + if torch.any(at_rightmost_boundary): + b_idx, s_idx = torch.nonzero(at_rightmost_boundary, as_tuple=True) + basis[b_idx, s_idx, self._boundary_interval_idx[s_idx]] = 1.0 + + # Cox-de Boor recursion -- iterative case for i in range(1, self.order): + + # Compute the denominators for both terms of the recursion denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] denom2 = 
knots[..., i + 1 :] - knots[..., 1:-i] + # Ensure no division by zero denom1 = torch.where( denom1.abs() < 1e-8, torch.ones_like(denom1), denom1 ) @@ -133,46 +306,185 @@ def basis(self, x: torch.Tensor) -> torch.Tensor: denom2.abs() < 1e-8, torch.ones_like(denom2), denom2 ) - term1 = ((x_exp - knots[..., : -(i + 1)]) / denom1) * basis[ - ..., :-1 - ] - term2 = ((knots[..., i + 1 :] - x_exp) / denom2) * basis[..., 1:] + # Compute the two terms of the recursion + term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1] + term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:] + + # Combine terms to get the new basis basis = term1 + term2 - # final basis length is n_ctrl = m - order - return basis # (..., n_ctrl) + return basis - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x): """ - Evaluate spline(s) at x. + Forward pass for the :class:`VectorizedSpline` model. Each univariate + spline is evaluated independently on its corresponding input feature. + + The input is expected to have shape ``[batch, s]``, where ``s`` is the + number of univariate splines. The output has shape ``[batch, s, o]``, + where ``o`` is the output dimension of each univariate spline, unless an + aggregation method is specified. If ``aggregate_output`` is set to + ``"mean"`` or ``"sum"``, the output is aggregated across the last + dimension, resulting in an output of shape ``[batch, s]``. - If control_points is (S, n_ctrl): output (..., S) - If control_points is (S, O, n_ctrl): output (..., S, O) + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The output tensor. 
+ :rtype: torch.Tensor """ - B = self.basis(x) # (..., n_ctrl) + # Compute the basis functions at x + basis = self.basis(x) - cp = self.control_points - # print("vectorized forward, cp:", cp) - if cp.ndim == 2: - # (S, n_ctrl) - # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) - # print('B shape:', B.shape, 'cp shape:', cp.shape) - # out = (B @ cp.transpose(0, 1)).squeeze(-1) - out = B @ cp.transpose(0, 1) - # out = B @ cp[0] - else: - # (S, O, n_ctrl) - # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S - # vectorized using einsum (yes, this one is actually appropriate) - # (..., n) * (S, O, n) -> (..., S, O) - # out = torch.einsum("...n, son -> ...so", B, cp) - out = torch.einsum("bsc,soc->bso", B, cp) + # Compute the output for each spline + out = torch.einsum("bsc,soc->bso", basis, self.control_points) + # Aggregate output if needed if self.aggregate_output == "mean": - out = out.mean(dim=-1) # aggregate over O dimension if present + out = out.mean(dim=-1) elif self.aggregate_output == "sum": out = out.sum(dim=-1) - # print("vectorized forward, out:", out.shape) - return out + + @property + def control_points(self): + """ + The control points of the spline. + + :return: The control points. + :rtype: torch.Tensor + """ + return self._control_points + + @control_points.setter + def control_points(self, control_points): + """ + Set the control points of the spline. + + :param torch.Tensor control_points: The control points tensor. The + tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where + each univariate spline has ``c`` control points and output dimension + ``o``. In the latter case, the control points are expanded to shape + ``[s, 1, c]``. + :raises ValueError: If there are not enough knots to define the control + points, due to the relation: #knots = order + #control_points. 
+ """ + # If control points are not provided, initialize them + if control_points is None: + + # Check that there are enough knots to define control points + if self.knots.shape[-1] < self.order + 1: + raise ValueError( + f"Not enough knots to define control points. Got" + f" {self.knots.shape[-1]} knots for each univariate spline," + f" but need at least {self.order + 1}." + ) + + # Initialize control points to zero + control_points = torch.zeros( + self.knots.shape[0], 1, self.knots.shape[-1] - self.order + ) + + # If a the control points are 2D, add an output dimension of size 1 + if control_points.ndim == 2: + control_points = control_points.unsqueeze(1) + + # Set control points + self._control_points = torch.nn.Parameter( + control_points, requires_grad=True + ) + + @property + def knots(self): + """ + The knots of the spline. + + :return: The knots. + :rtype: torch.Tensor + """ + return self._knots + + @knots.setter + def knots(self, value): + """ + Set the knots of the spline. + :param value: The knots of the spline. If a tensor is provided, it must + have shape ``[s, n]``, where ``s`` is the number of univariate + splines and ``n`` is the number of knots per univariate spline. If a + dictionary is provided, it must contain the keys ``"n"``, ``"min"``, + ``"max"``, ``"mode"``, and ``"n_splines"``. Here, ``"n"`` specifies + the number of knots for each univariate spline, ``"min"`` and + ``"max"`` define the interval, ``"mode"`` selects the sampling + strategy, and ``"n_splines"`` specifies the number of univariate + splines. The supported modes are ``"uniform"``, where the knots are + evenly spaced over :math:`[min, max]`, and ``"auto"``, where knots + are constructed to ensure that each univariate spline interpolates + the first and last control points. In this case, the number of knots + is adjusted if :math:`n < 2 * order`. 
If None is given, knots are + initialized automatically over :math:`[0, 1]` ensuring interpolation + of the first and last control points. + :type value: torch.Tensor | dict + :raises ValueError: If a dictionary is provided but does not contain + the required keys. + :raises ValueError: If the mode specified in the dictionary is invalid. + """ + # If a dictionary is provided, initialize knots accordingly + if isinstance(value, dict): + + # Check that required keys are present + required_keys = {"n", "min", "max", "mode", "n_splines"} + if not required_keys.issubset(value.keys()): + raise ValueError( + f"When providing knots as a dictionary, the following " + f"keys must be present: {required_keys}. Got " + f"{value.keys()}." + ) + + # Save number of splines for later use + n_splines = value["n_splines"] + + # Uniform sampling of knots + if value["mode"] == "uniform": + value = torch.linspace(value["min"], value["max"], value["n"]) + + # Automatic sampling of interpolating knots + elif value["mode"] == "auto": + + # Repeat the first and last knots 'order' times + initial_knots = torch.ones(self.order) * value["min"] + final_knots = torch.ones(self.order) * value["max"] + + # Number of internal knots + n_internal = value["n"] - 2 * self.order + + # If no internal knots are needed, just concatenate boundaries + if n_internal <= 0: + value = torch.cat((initial_knots, final_knots)) + + # Else, sample internal knots uniformly and exclude boundaries + # Recover the correct number of internal knots when slicing by + # adding 2 to n_internal + else: + internal_knots = torch.linspace( + value["min"], value["max"], n_internal + 2 + )[1:-1] + value = torch.cat( + (initial_knots, internal_knots, final_knots) + ) + + # Raise error if mode is invalid + else: + raise ValueError( + f"Invalid mode for knots initialization. Got " + f"{value['mode']}, but expected 'uniform' or 'auto'." 
+ ) + + # Repeat the knot vector for each spline + value = value.unsqueeze(0).repeat(n_splines, 1) + + # Set knots + self.register_buffer("_knots", value.sort(dim=1).values) + + # Recompute boundary interval when knots change + if hasattr(self, "_boundary_interval_idx"): + self._boundary_interval_idx = self._compute_boundary_interval() From d9b59ff88a5cc587df47f6b3575bba6c6d40c61a Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 7 Apr 2026 11:47:15 +0200 Subject: [PATCH 07/12] minor fix to output dimension in vector splines --- pina/_src/model/vectorized_spline.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index fe48fb8c5..70ea3718e 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -128,6 +128,8 @@ def __init__( :raises AssertionError: If ``order`` is not a positive integer. :raises ValueError: If ``knots`` is neither a torch.Tensor nor a dictionary, when provided. + :raises ValueError: If ``aggregate_output`` is not None, "mean", or + "sum". :raises ValueError: If ``control_points`` is not a torch.Tensor, when provided. :raises ValueError: If both ``knots`` and ``control_points`` are None. @@ -155,6 +157,13 @@ def __init__( if knots is None and control_points is None: raise ValueError("knots and control_points cannot both be None.") + # Raise error if aggregate_output is not None, "mean", or "sum" + if aggregate_output not in (None, "mean", "sum"): + raise ValueError( + f"aggregate_output must be None, 'mean', or 'sum'." + f" Got {aggregate_output}." + ) + # Initialize knots if not provided if knots is None and control_points is not None: knots = { @@ -323,9 +332,11 @@ def forward(self, x): The input is expected to have shape ``[batch, s]``, where ``s`` is the number of univariate splines. 
The output has shape ``[batch, s, o]``, where ``o`` is the output dimension of each univariate spline, unless an - aggregation method is specified. If ``aggregate_output`` is set to - ``"mean"`` or ``"sum"``, the output is aggregated across the last - dimension, resulting in an output of shape ``[batch, s]``. + aggregation method is specified. If both ``s`` and ``o`` are 1, the + output is aggregated across the last dimension, resulting in an output + of shape ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or + ``"sum"``, the output is aggregated across the last dimension, resulting + in an output of shape ``[batch, s]``. :param x: The input tensor. :type x: torch.Tensor | LabelTensor @@ -343,6 +354,8 @@ def forward(self, x): out = out.mean(dim=-1) elif self.aggregate_output == "sum": out = out.sum(dim=-1) + elif out.shape[1] == 1 and out.shape[2] == 1: + out = out.squeeze(-1) return out @@ -483,7 +496,7 @@ def knots(self, value): value = value.unsqueeze(0).repeat(n_splines, 1) # Set knots - self.register_buffer("_knots", value.sort(dim=1).values) + self.register_buffer("_knots", value.sort(dim=-1).values) # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): From ad8a27f0b0fc428069ba450be08af6c08fd6fdd6 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 7 Apr 2026 12:09:53 +0200 Subject: [PATCH 08/12] add tests --- tests/test_block/test_kan_block.py | 146 ++++++++++ .../test_kolmogorov_arnold_network.py | 204 +++++--------- tests/test_model/test_spline.py | 34 +-- tests/test_model/test_vectorized_spline.py | 259 ++++++++++++++++++ 4 files changed, 473 insertions(+), 170 deletions(-) create mode 100644 tests/test_block/test_kan_block.py create mode 100644 tests/test_model/test_vectorized_spline.py diff --git a/tests/test_block/test_kan_block.py b/tests/test_block/test_kan_block.py new file mode 100644 index 000000000..08f549f40 --- /dev/null +++ b/tests/test_block/test_kan_block.py @@ -0,0 +1,146 @@ +import 
torch +import pytest +from pina.model.block import KANBlock + +# Data +input_dim = 3 +data = torch.rand((10, input_dim)) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_constructor( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + + # Should fail if input_dimensions is not a positive integer + with pytest.raises(AssertionError): + KANBlock(input_dimensions=-1, output_dimensions=output_dimensions) + + # Should fail if output_dimensions is not a positive integer + with pytest.raises(AssertionError): + KANBlock(input_dimensions=data.shape[1], output_dimensions=-1) + + # Should fail if spline_order is not a positive integer + with pytest.raises(AssertionError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=-1, + ) + + # Should fail if n_knots is not a positive integer + with pytest.raises(AssertionError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + n_knots=-1, + ) + + # Should fail if grid_range is not of length 2 + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + grid_range=[-1, 0, 1], + ) + + # Should fail if base_function is not a torch.nn.Module subclass + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + base_function="not_a_module", + ) + + # Should fail if use_base_linear is not a boolean + with pytest.raises(ValueError): + KANBlock( + 
input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + use_base_linear="not_a_bool", + ) + + # Should fail if use_bias is not a boolean + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + use_bias="not_a_bool", + ) + + # Should fail if init_scale_spline is not a float or int + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + init_scale_spline="not_a_number", + ) + + # Should fail if init_scale_base is not a float or int + with pytest.raises(ValueError): + KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + init_scale_base="not_a_number", + ) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_forward( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + model = KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + init_scale_base=init_scale_base, + ) + + output_ = model(data) + assert output_.shape == (data.shape[0], output_dimensions) + + +@pytest.mark.parametrize("output_dimensions", [1, 5]) +@pytest.mark.parametrize("spline_order", [3, 4]) +@pytest.mark.parametrize("n_knots", [10, 20]) +@pytest.mark.parametrize("init_scale_spline", [1e-2, 1e-1]) +@pytest.mark.parametrize("init_scale_base", [1.0, 0.1]) +def test_backward( + output_dimensions, spline_order, n_knots, init_scale_spline, init_scale_base +): + + model = KANBlock( + input_dimensions=data.shape[1], + output_dimensions=output_dimensions, + spline_order=spline_order, + n_knots=n_knots, + init_scale_spline=init_scale_spline, + 
init_scale_base=init_scale_base, + ) + + data.requires_grad_() + output_ = model(data) + + loss = torch.mean(output_) + loss.backward() + assert data.grad.shape == data.shape diff --git a/tests/test_model/test_kolmogorov_arnold_network.py b/tests/test_model/test_kolmogorov_arnold_network.py index 42f994f71..ec26f81d2 100644 --- a/tests/test_model/test_kolmogorov_arnold_network.py +++ b/tests/test_model/test_kolmogorov_arnold_network.py @@ -1,153 +1,83 @@ import torch import pytest - from pina.model import KolmogorovArnoldNetwork -data = torch.rand((20, 3)) -input_vars = 3 -output_vars = 1 +# Data +input_dim = 3 +data = torch.rand((10, input_dim)) -def test_constructor(): - KolmogorovArnoldNetwork([input_vars, output_vars]) - KolmogorovArnoldNetwork([input_vars, 10, 20, output_vars]) - KolmogorovArnoldNetwork( - [input_vars, 10, 20, output_vars], - k=3, - num=5 - ) - KolmogorovArnoldNetwork( - [input_vars, 10, 20, output_vars], - k=3, - num=5, - grid_eps=0.05, - grid_range=[-2, 2] - ) +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_constructor(use_base_linear, use_bias, grid_range, layers): + + # Constructor KolmogorovArnoldNetwork( - [input_vars, 10, output_vars], - base_function=torch.nn.Tanh(), - scale_sp=0.5, - sparse_init=True + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, ) - -def test_constructor_wrong(): + # Should fail if grid_range is not of length 2 with pytest.raises(ValueError): - KolmogorovArnoldNetwork([input_vars]) - with pytest.raises(ValueError): - KolmogorovArnoldNetwork([]) - - -def test_forward(): - dim_in, dim_out = 3, 2 - kan = KolmogorovArnoldNetwork([dim_in, dim_out]) - output_ 
= kan(data) - assert output_.shape == (data.shape[0], dim_out) + KolmogorovArnoldNetwork(layers=layers, grid_range=[-1, 0, 1]) + # Should fail if layers has less than 2 elements + with pytest.raises(ValueError): + KolmogorovArnoldNetwork(layers=[input_dim]) + + +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_forward(use_base_linear, use_bias, grid_range, layers): + + model = KolmogorovArnoldNetwork( + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, + ) -def test_forward_multilayer(): - dim_in, dim_out = 3, 2 - kan = KolmogorovArnoldNetwork([dim_in, 10, 5, dim_out]) - output_ = kan(data) - assert output_.shape == (data.shape[0], dim_out) + output_ = model(data) + assert output_.shape == (data.shape[0], layers[-1]) + + +@pytest.mark.parametrize("use_base_linear", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("grid_range", [[-1, 1], [0, 2]]) +@pytest.mark.parametrize("layers", [[input_dim, 5, 1], [input_dim, 2]]) +def test_backward(use_base_linear, use_bias, grid_range, layers): + + model = KolmogorovArnoldNetwork( + layers=layers, + spline_order=3, + n_knots=10, + grid_range=grid_range, + base_function=torch.nn.SiLU, + use_base_linear=use_base_linear, + use_bias=use_bias, + init_scale_spline=1e-2, + init_scale_base=1.0, + ) + data.requires_grad_() + output_ = model(data) -def test_backward(): - dim_in, dim_out = 3, 2 - kan = KolmogorovArnoldNetwork([dim_in, dim_out]) - data.requires_grad = True - output_ = kan(data) loss = torch.mean(output_) loss.backward() - assert data._grad.shape == torch.Size([20, 3]) - - -def test_get_num_parameters(): - kan = 
KolmogorovArnoldNetwork([3, 5, 2]) - num_params = kan.get_num_parameters() - assert num_params > 0 - assert isinstance(num_params, int) - -from pina.problem.zoo import Poisson2DSquareProblem -from pina.solver import PINN -from pina.trainer import Trainer - -def test_train_poisson(): - problem = Poisson2DSquareProblem() - problem.discretise_domain(n=10, mode="random", domains="all") - - model = KolmogorovArnoldNetwork([2, 3, 1], k=3, num=5) - solver = PINN(model=model, problem=problem) - trainer = Trainer( - solver=solver, - max_epochs=10, - accelerator="cpu", - batch_size=100, - train_size=1.0, - val_size=0.0, - test_size=0.0, - ) - trainer.train() - - - -# def test_update_grid_from_samples(): -# kan = KolmogorovArnoldNetwork([3, 5, 2]) -# samples = torch.randn(50, 3) -# kan.update_grid_from_samples(samples, mode='sample') -# # Check that the network still works after grid update -# output = kan(data) -# assert output.shape == (data.shape[0], 2) - - -# def test_update_grid_resolution(): -# kan = KolmogorovArnoldNetwork([3, 5, 2], num=3) -# kan.update_grid_resolution(5) -# # Check that the network still works after resolution update -# output = kan(data) -# assert output.shape == (data.shape[0], 2) - - -# def test_enable_sparsification(): -# kan = KolmogorovArnoldNetwork([3, 5, 2]) -# kan.enable_sparsification(threshold=1e-4) -# # Check that the network still works after sparsification -# output = kan(data) -# assert output.shape == (data.shape[0], 2) - - -# def test_get_activation_statistics(): -# kan = KolmogorovArnoldNetwork([3, 5, 2]) -# stats = kan.get_activation_statistics(data) -# assert isinstance(stats, dict) -# assert 'layer_0' in stats -# assert 'layer_1' in stats -# assert 'mean' in stats['layer_0'] -# assert 'std' in stats['layer_0'] -# assert 'min' in stats['layer_0'] -# assert 'max' in stats['layer_0'] - - -# def test_get_network_grid_statistics(): -# kan = KolmogorovArnoldNetwork([3, 5, 2]) -# stats = kan.get_network_grid_statistics() -# assert 
isinstance(stats, dict) -# assert 'layer_0' in stats -# assert 'layer_1' in stats - - -# def test_save_act(): -# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=True) -# output = kan(data) -# assert hasattr(kan, 'acts') -# assert len(kan.acts) == 3 # input + 2 layers -# assert kan.acts[0].shape == data.shape -# assert kan.acts[-1].shape == output.shape - - -# def test_save_act_disabled(): -# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=False) -# _ = kan(data) -# assert hasattr(kan, 'acts') -# # Only the first activation (input) is saved -# assert len(kan.acts) == 1 + assert data.grad.shape == data.shape diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index 2191f6ee4..baff81940 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -2,7 +2,7 @@ import pytest from scipy.interpolate import BSpline from pina.operator import grad -from pina.model import Spline, VectorizedSpline +from pina.model import Spline from pina import LabelTensor # Utility quantities for testing @@ -191,35 +191,3 @@ def test_derivative(args, pts): # Check shape and value assert first_der.shape == pts.shape assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) - - -@pytest.mark.parametrize("args", valid_args) -@pytest.mark.parametrize("N", [1, 4, 7]) -def test_vectorized(args, N): - - cps = [] - splines = [] - - for i in range(N): - spline = Spline(**args) - splines.append(spline) - cps.append(spline.control_points) - - unique_cps = torch.stack(cps, dim=0) - vectorized_spline = VectorizedSpline( - order=args["order"], - knots=splines[0].knots, - control_points=unique_cps - ) - - x = torch.rand(100, 1) - - result_single = torch.stack([ - splines[i](x) for i in range(N) - ]) - result_single = result_single.permute(1, 2, 0) # shape (100, N) - out_vectorized = vectorized_spline(x) - print("result single shape:", result_single.shape) - print("out vectorized shape:", out_vectorized.shape) - assert out_vectorized.shape 
== (100, 1, N) - assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5) \ No newline at end of file diff --git a/tests/test_model/test_vectorized_spline.py b/tests/test_model/test_vectorized_spline.py new file mode 100644 index 000000000..373333895 --- /dev/null +++ b/tests/test_model/test_vectorized_spline.py @@ -0,0 +1,259 @@ +import torch +import pytest +from pina.model import VectorizedSpline, Spline +from pina import LabelTensor + + +# Utility quantities for testing +order = torch.randint(3, 6, (1,)).item() +n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() +n_knots = order + n_ctrl_pts +n_splines = torch.randint(2, 5, (1,)).item() +output_dim = torch.randint(1, 4, (1,)).item() + +# Input points +labels = [f"x{i}" for i in range(n_splines)] +pts = torch.rand(10, n_splines).requires_grad_(True) +pts = LabelTensor(pts, labels) + + +# Define all possible combinations of valid arguments for VectorizedSpline class +valid_args = [ + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": torch.linspace(0, 1, n_knots) + .unsqueeze(0) + .repeat(n_splines, 1), + }, + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "auto", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "uniform", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": None, + "knots": torch.linspace(0, 1, n_knots) + .unsqueeze(0) + .repeat(n_splines, 1), + }, + { + "order": order, + "control_points": None, + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "auto", + "n_splines": n_splines, + }, + }, + { + "order": order, + "control_points": None, + "knots": { + "n": n_knots, + "min": 0, + "max": 1, + "mode": "uniform", + "n_splines": n_splines, + }, + }, + { + 
"order": order, + "control_points": torch.rand(n_splines, output_dim, n_ctrl_pts), + "knots": None, + }, +] + + +@pytest.mark.parametrize("args", valid_args) +@pytest.mark.parametrize("aggregate_output", ["mean", "sum", None]) +def test_constructor(args, aggregate_output): + VectorizedSpline(**args, aggregate_output=aggregate_output) + + # Should fail if order is not a positive integer + with pytest.raises(AssertionError): + VectorizedSpline( + order=-1, + control_points=args["control_points"], + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if control_points is not None or a torch.Tensor + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=[1, 2, 3], + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if knots is not None, a torch.Tensor, or a dict + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=5, + aggregate_output=aggregate_output, + ) + + # Should fail if aggregate_output is not None, "mean", or "sum" + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=args["knots"], + aggregate_output="invalid", + ) + + # Should fail if both knots and control_points are None + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=None, + knots=None, + aggregate_output=aggregate_output, + ) + + # Should fail if knots is not two-dimensional + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=torch.rand(n_knots), + aggregate_output=aggregate_output, + ) + + # Should fail if control_points is not three-dimensional + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=torch.rand(n_ctrl_pts), + knots=args["knots"], + aggregate_output=aggregate_output, + ) + + # Should fail if the 
number of knots != order + number of control points + # If control points are None, they are initialized to fulfill this condition + if args["control_points"] is not None: + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots=torch.linspace(0, 1, n_knots + 1) + .unsqueeze(0) + .repeat(n_splines, 1), + aggregate_output=aggregate_output, + ) + + # Should fail if the knot dict is missing required keys + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1}, + aggregate_output=aggregate_output, + ) + + # Should fail if the knot dict has invalid 'mode' key + with pytest.raises(ValueError): + VectorizedSpline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"}, + aggregate_output=aggregate_output, + ) + + # Should fail if knots and control points have different number of splines + with pytest.raises(ValueError): + VectorizedSpline( + order=3, + control_points=torch.rand(5, 4, 5), + knots=torch.linspace(0, 1, 8).unsqueeze(0).repeat(3, 1), + aggregate_output=aggregate_output, + ) + + +@pytest.mark.parametrize("args", valid_args) +def test_forward(args): + + # Define the model + model = VectorizedSpline(**args) + + # Evaluate the model + output_ = model(pts) + + # Check output shape + if model.aggregate_output is None: + assert output_.shape == ( + pts.shape[0], + pts.shape[1], + model.control_points.shape[1], + ) + else: + assert output_.shape == pts.shape + + +@pytest.mark.parametrize("args", valid_args) +def test_backward(args): + + # Define the model + model = VectorizedSpline(**args) + + # Evaluate the model + output_ = model(pts) + loss = torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape + + +def test_1d_vs_vectorized(): + + control_points = torch.rand(1, 1, 
n_ctrl_pts) + knots = torch.linspace(0, 1, n_knots).unsqueeze(0) + + # Classical 1D spline + + spline = Spline( + order=order, + control_points=control_points.squeeze(), + knots=knots.squeeze(), + ) + + # Create a VectorizedSpline instance with the same control pts and knots + vectorized_spline = VectorizedSpline( + order=order, + knots=knots, + control_points=control_points, + aggregate_output=None, + ) + + # Input points + x = LabelTensor(torch.rand(10, 1), labels=["x"]) + + # Evaluate both models on the same input + out_spline = spline(x) + out_vectorized = vectorized_spline(x) + + assert out_vectorized.shape == out_spline.shape + assert torch.allclose(out_vectorized, out_spline, atol=1e-5, rtol=1e-5) From d4dfb65042913bdc365ef3b7669a271781e22b35 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 7 Apr 2026 12:18:36 +0200 Subject: [PATCH 09/12] add rst files --- docs/source/_rst/_code.rst | 3 +++ docs/source/_rst/model/block/kan_block.rst | 7 +++++++ docs/source/_rst/model/kolmogorov_arnold_network.rst | 7 +++++++ docs/source/_rst/model/vectorized_spline.rst | 7 +++++++ 4 files changed, 24 insertions(+) create mode 100644 docs/source/_rst/model/block/kan_block.rst create mode 100644 docs/source/_rst/model/kolmogorov_arnold_network.rst create mode 100644 docs/source/_rst/model/vectorized_spline.rst diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 7d992d1ca..e4a5f8a61 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -110,6 +110,8 @@ Models PirateNet EquivariantGraphNeuralOperator SINDy + Vectorized Spline + Kolmogorov-Arnold Network Blocks ------------- @@ -128,6 +130,7 @@ Blocks Continuous Convolution Block Orthogonal Block PirateNet Block + KAN Block Message Passing ------------------- diff --git a/docs/source/_rst/model/block/kan_block.rst b/docs/source/_rst/model/block/kan_block.rst new file mode 100644 index 000000000..95ca239eb --- /dev/null +++ b/docs/source/_rst/model/block/kan_block.rst @@ -0,0 
+1,7 @@ +KANBlock +======================= +.. currentmodule:: pina.model.block.kan_block + +.. autoclass:: pina._src.model.block.kan_block.KANBlock + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/model/kolmogorov_arnold_network.rst b/docs/source/_rst/model/kolmogorov_arnold_network.rst new file mode 100644 index 000000000..0211611f4 --- /dev/null +++ b/docs/source/_rst/model/kolmogorov_arnold_network.rst @@ -0,0 +1,7 @@ +KolmogorovArnoldNetwork +=========================== +.. currentmodule:: pina.model.kolmogorov_arnold_network + +.. autoclass:: pina._src.model.kolmogorov_arnold_network.KolmogorovArnoldNetwork + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/model/vectorized_spline.rst b/docs/source/_rst/model/vectorized_spline.rst new file mode 100644 index 000000000..08522bc54 --- /dev/null +++ b/docs/source/_rst/model/vectorized_spline.rst @@ -0,0 +1,7 @@ +VectorizedSpline +======================= +.. currentmodule:: pina.model.vectorized_spline + +.. autoclass:: pina._src.model.vectorized_spline.VectorizedSpline + :members: + :show-inheritance: \ No newline at end of file From 5bd5902a990377df876e9110ea3a0e7f0850cafa Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 7 Apr 2026 15:25:21 +0200 Subject: [PATCH 10/12] add docstrings --- pina/_src/model/block/kan_block.py | 30 ++++++++++++++------ pina/_src/model/kolmogorov_arnold_network.py | 23 +++++++++++++-- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/pina/_src/model/block/kan_block.py b/pina/_src/model/block/kan_block.py index 08655f467..77597d310 100644 --- a/pina/_src/model/block/kan_block.py +++ b/pina/_src/model/block/kan_block.py @@ -7,7 +7,20 @@ class KANBlock(torch.nn.Module): """ - TODO: docstring. + The inner block of the Kolmogorov-Arnold Network (KAN). 
+ + The block applies a spline transformation to the input, optionally combined + with a linear transformation of a base activation function. The output is + aggregated across input dimensions to produce the final output. + + .. seealso:: + + **Original reference**: + Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., + Hou T., Tegmark M. (2025). + *KAN: Kolmogorov-Arnold Networks*. + DOI: `arXiv preprint arXiv:2404.19756. + `_ """ def __init__( @@ -119,16 +132,15 @@ def __init__( def forward(self, x): """ - Forward pass of the :class:`KANBlock`. It transforms the input using a - vectorized spline basis and optionally adds a linear transformation of a - base activation function. - - The input is expected to have shape (batch_size, input_dimensions) and - the output will have shape (batch_size, output_dimensions). + Forward pass of the Kolmogorov-Arnold block. The input is passed through + the spline transformation, optionally combined with a linear + transformation of the base function output, and then aggregated across + input dimensions to produce the final output. - :param torch.Tensor x: The input tensor for the model. + :param x: The input tensor for the model. + :type x: torch.Tensor | LabelTensor :return: The output tensor of the model. - :rtype: torch.Tensor + :rtype: torch.Tensor | LabelTensor """ y = self.spline(x) diff --git a/pina/_src/model/kolmogorov_arnold_network.py b/pina/_src/model/kolmogorov_arnold_network.py index dec01569c..1782aab4b 100644 --- a/pina/_src/model/kolmogorov_arnold_network.py +++ b/pina/_src/model/kolmogorov_arnold_network.py @@ -5,7 +5,20 @@ class KolmogorovArnoldNetwork(torch.nn.Module): """ - TODO: add docstring. + Implementation of Kolmogorov-Arnold Network (KAN). + + The model consists of a sequence of KAN blocks, where each block applies a + spline transformation to the input, optionally combined with a linear + transformation of a base activation function. + + .. 
seealso:: + + **Original reference**: + Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., + Hou T., Tegmark M. (2025). + *KAN: Kolmogorov-Arnold Networks*. + DOI: `arXiv preprint arXiv:2404.19756. + `_ """ def __init__( @@ -78,7 +91,13 @@ def __init__( def forward(self, x): """ - TODO: add docstring. + Forward pass of the KolmogorovArnoldNetwork model. It passes the input + through each KAN block in the network and returns the final output. + + :param x: The input tensor for the model. + :type x: torch.Tensor | LabelTensor + :return: The output tensor of the model. + :rtype: torch.Tensor | LabelTensor """ for layer in self.kan_layers: x = layer(x) From 6eb49fb3d822b7ec37663feb86552d3dd3cd14dc Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Tue, 7 Apr 2026 17:56:12 +0200 Subject: [PATCH 11/12] implement derivatives for vector splines --- pina/_src/model/spline.py | 8 +- pina/_src/model/vectorized_spline.py | 159 +++++++++++++++++++-- tests/test_model/test_vectorized_spline.py | 28 ++++ 3 files changed, 184 insertions(+), 11 deletions(-) diff --git a/pina/_src/model/spline.py b/pina/_src/model/spline.py index 4fd3bfd24..5e5b133c3 100644 --- a/pina/_src/model/spline.py +++ b/pina/_src/model/spline.py @@ -277,9 +277,11 @@ def forward(self, x): :return: The output tensor. :rtype: torch.Tensor """ - basis = self.basis(x.as_subclass(torch.Tensor)) - - return basis @ self.control_points + return torch.einsum( + "...bi, i -> ...b", + self.basis(x.as_subclass(torch.Tensor)).squeeze(-1), + self.control_points, + ) def derivative(self, x, degree): """ diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index 70ea3718e..c2fe54eba 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -55,9 +55,18 @@ class VectorizedSpline(torch.nn.Module): This class does not represent a single multivariate spline :math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate - basis. 
Instead, it represents a vector spline built from ``s`` + basis. Instead, it represents a vector of splines built from ``s`` independent univariate splines, one for each input feature. + .. note:: + + When using the :meth:`derivative` method of this class, derivatives are + computed directly in vectorized form and returned with the correct + shape. In contrast, when relying on ``autograd``, derivatives must be + computed separately for each output dimension of each univariate spline + and then combined, since autograd does not natively handle this + vectorized structure. + :Example: >>> from pina.model import VectorizedSpline @@ -133,7 +142,8 @@ def __init__( :raises ValueError: If ``control_points`` is not a torch.Tensor, when provided. :raises ValueError: If both ``knots`` and ``control_points`` are None. - :raises ValueError: If ``knots`` is not two-dimensional. + :raises ValueError: If ``knots`` is not two-dimensional, after + processing. :raises ValueError: If ``control_points``, after expansion when two-dimensional, is not three-dimensional. 
:raises ValueError: If, for each univariate spline, the number of @@ -180,10 +190,6 @@ def __init__( self.control_points = control_points self.aggregate_output = aggregate_output - # Check dimensionality of knots - if self.knots.ndim != 2: - raise ValueError("knots must be two-dimensional.") - # Check dimensionality of control points if self.control_points.ndim != 3: raise ValueError("control_points must be three-dimensional.") @@ -218,6 +224,9 @@ def __init__( # Precompute boundary interval index self._boundary_interval_idx = self._compute_boundary_interval() + # Precompute denominators used in derivative formulas + self._compute_derivative_denominators() + def _compute_boundary_interval(self): """ Precompute the index of the rightmost non-degenerate interval to improve @@ -243,8 +252,36 @@ def _compute_boundary_interval(self): idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0 return idx + + def _compute_derivative_denominators(self): + """ + Precompute the denominators used in the derivatives for all orders up to + the spline order to avoid redundant calculations. + """ + # Precompute for order 2 to k + for i in range(2, self.order + 1): + + # Denominators for the derivative recurrence relations + left_den = self.knots[:, i - 1 : -1] - self.knots[:, :-i] + right_den = self.knots[:, i:] - self.knots[:, 1 : -i + 1] + + # If consecutive knots are equal, set left and right factors to zero + left_fac = torch.where( + torch.abs(left_den) > 1e-10, + (i - 1) / left_den, + torch.zeros_like(left_den), + ) + right_fac = torch.where( + torch.abs(right_den) > 1e-10, + (i - 1) / right_den, + torch.zeros_like(right_den), + ) - def basis(self, x): + # Register buffers + self.register_buffer(f"_left_factor_order_{i}", left_fac) + self.register_buffer(f"_right_factor_order_{i}", right_fac) + + def basis(self, x, collection=False): """ Evaluate the B-spline basis functions for each univariate spline. 
@@ -252,12 +289,18 @@ def basis(self, x): all univariate splines of the vector spline. :param torch.Tensor x: The points to be evaluated. + :param bool collection: If True, returns a list of basis functions for + all orders up to the spline order. Default is False. + :raise ValueError: If ``collection`` is not a boolean. :raises ValueError: If ``x`` is not two-dimensional. :raises ValueError: If the number of input features does not match the number of univariate splines. :return: The basis functions evaluated at x. :rtype: torch.Tensor """ + # Check consistency + check_consistency(collection, bool) + # Ensure x is a tensor of the same dtype as knots x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype) @@ -300,6 +343,10 @@ def basis(self, x): b_idx, s_idx = torch.nonzero(at_rightmost_boundary, as_tuple=True) basis[b_idx, s_idx, self._boundary_interval_idx[s_idx]] = 1.0 + # If returning the whole collection, initialize list + if collection: + basis_collection = [None, basis] + # Cox-de Boor recursion -- iterative case for i in range(1, self.order): @@ -322,7 +369,10 @@ def basis(self, x): # Combine terms to get the new basis basis = term1 + term2 - return basis + if collection: + basis_collection.append(basis) + + return basis_collection if collection else basis def forward(self, x): """ @@ -358,6 +408,91 @@ def forward(self, x): out = out.squeeze(-1) return out + + def derivative(self, x, degree): + """ + Compute the ``degree``-th derivative of each univariate spline at the + given input points. + + The output has shape ``[batch, s, o]``, where ``o`` is the output + dimension of each univariate spline, unless an aggregation method is + specified. If both ``s`` and ``o`` are 1, the output is aggregated + across the last dimension, resulting in an output of shape + ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or + ``"sum"``, the output is aggregated across the last dimension, resulting + in an output of shape ``[batch, s]``. 
+ + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :param int degree: The derivative degree to compute. + :return: The derivative tensor. + :rtype: torch.Tensor + """ + # Check consistency + check_positive_integer(degree, strict=False) + + # Compute basis derivative + der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree) + + # Compute the output for each spline + out = torch.einsum("bsc,soc->bso", der, self.control_points) + + # Aggregate output if needed + if self.aggregate_output == "mean": + out = out.mean(dim=-1) + elif self.aggregate_output == "sum": + out = out.sum(dim=-1) + elif out.shape[1] == 1 and out.shape[2] == 1: + out = out.squeeze(-1) + + return out + + def _basis_derivative(self, x, degree): + """ + Compute the ``degree``-th derivative of the vectorized spline basis + functions at the given input points using an iterative approach. + + :param torch.Tensor x: The points to be evaluated. + :param int degree: The derivative degree to compute. + :return: The derivative of the basis functions of order ``self.order``. 
+ :rtype: torch.Tensor + """ + # Compute the whole basis collection + basis = self.basis(x, collection=True) + + # Derivatives initialization (dummy at index 0 for convenience) + derivatives = [None] + [basis[o] for o in range(1, self.order + 1)] + + # Iterate over derivative degrees + for _ in range(1, degree + 1): + + # Current degree derivatives (with dummy at index 0 for convenience) + current_der = [None] * (self.order + 1) + current_der[1] = torch.zeros_like(derivatives[1]) + + # Iterate over basis orders + for o in range(2, self.order + 1): + + # Retrieve precomputed factors + left_fac = getattr(self, f"_left_factor_order_{o}") + right_fac = getattr(self, f"_right_factor_order_{o}") + + # derivatives[o - 1] has shape [b, s, m] + # Slice previous derivatives to align + left_part = derivatives[o - 1][..., :-1] + right_part = derivatives[o - 1][..., 1:] + + # Broadcast factors over batch dims + left_fac = left_fac.unsqueeze(0) + right_fac = right_fac.unsqueeze(0) + + # Compute current derivatives + current_der[o] = left_fac * left_part - right_fac * right_part + + # Update derivatives for next degree + derivatives = current_der + + return derivatives[self.order] @property def control_points(self): @@ -440,6 +575,7 @@ def knots(self, value): :raises ValueError: If a dictionary is provided but does not contain the required keys. :raises ValueError: If the mode specified in the dictionary is invalid. + :raises ValueError: If knots is not two-dimensional after processing. 
""" # If a dictionary is provided, initialize knots accordingly if isinstance(value, dict): @@ -498,6 +634,13 @@ def knots(self, value): # Set knots self.register_buffer("_knots", value.sort(dim=-1).values) + # Check dimensionality of knots + if self.knots.ndim != 2: + raise ValueError("knots must be two-dimensional.") + # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): self._boundary_interval_idx = self._compute_boundary_interval() + + # Recompute derivative denominators when knots change + self._compute_derivative_denominators() diff --git a/tests/test_model/test_vectorized_spline.py b/tests/test_model/test_vectorized_spline.py index 373333895..cc0107073 100644 --- a/tests/test_model/test_vectorized_spline.py +++ b/tests/test_model/test_vectorized_spline.py @@ -1,6 +1,7 @@ import torch import pytest from pina.model import VectorizedSpline, Spline +from pina.operator import grad from pina import LabelTensor @@ -227,6 +228,33 @@ def test_backward(args): assert model.control_points.grad.shape == model.control_points.shape +@pytest.mark.parametrize("args", valid_args) +def test_derivative(args): + + # Define and evaluate the model + model = VectorizedSpline(**args) + pts.requires_grad_(True) + output_ = model(pts) + + # Compute analytical derivatives + first_der = model.derivative(x=pts, degree=1) + + # Compute autograd derivatives -- we need to loop over output dimensions + # since autograd doesn't support vectorized outputs + gradients = [] + for j in range(output_.shape[2]): + out = output_[:, :, j].squeeze(-1) + out = LabelTensor(out, [f"u{j}" for j in range(out.shape[1])]) + gradients.append( + grad(out, pts)[[f"du{j}dx{j}" for j in range(pts.shape[1])]] + ) + first_der_auto = torch.stack(gradients, dim=-1) + + # Check shape and value + assert first_der.shape == first_der_auto.shape + assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) + + def test_1d_vs_vectorized(): control_points = torch.rand(1, 1, 
n_ctrl_pts) From 6caa8734cdb509ad7ba8f82869ae5b007dec8134 Mon Sep 17 00:00:00 2001 From: GiovanniCanali Date: Wed, 8 Apr 2026 10:51:47 +0200 Subject: [PATCH 12/12] fix minor shape bug --- pina/_src/model/vectorized_spline.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py index c2fe54eba..0fd7c2535 100644 --- a/pina/_src/model/vectorized_spline.py +++ b/pina/_src/model/vectorized_spline.py @@ -222,7 +222,9 @@ def __init__( ) # Precompute boundary interval index - self._boundary_interval_idx = self._compute_boundary_interval() + self.register_buffer( + "_boundary_interval_idx", self._compute_boundary_interval() + ) # Precompute denominators used in derivative formulas self._compute_derivative_denominators() @@ -252,7 +254,7 @@ def _compute_boundary_interval(self): idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0 return idx - + def _compute_derivative_denominators(self): """ Precompute the denominators used in the derivatives for all orders up to @@ -334,8 +336,10 @@ def basis(self, x, collection=False): knot_right = self.knots[range_tensor, self._boundary_interval_idx + 1] # Identify points at the rightmost boundary - at_rightmost_boundary = (x >= knot_left.unsqueeze(0)) & torch.isclose( - x, knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10 + at_rightmost_boundary = ( + x.squeeze(-1) >= knot_left.unsqueeze(0) + ) & torch.isclose( + x.squeeze(-1), knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10 ) # Ensure the correct value is set at the rightmost boundary @@ -408,12 +412,12 @@ def forward(self, x): out = out.squeeze(-1) return out - + def derivative(self, x, degree): """ Compute the ``degree``-th derivative of each univariate spline at the - given input points. - + given input points. + The output has shape ``[batch, s, o]``, where ``o`` is the output dimension of each univariate spline, unless an aggregation method is specified. 
If both ``s`` and ``o`` are 1, the output is aggregated @@ -472,7 +476,7 @@ def _basis_derivative(self, x, degree): # Iterate over basis orders for o in range(2, self.order + 1): - + # Retrieve precomputed factors left_fac = getattr(self, f"_left_factor_order_{o}") right_fac = getattr(self, f"_right_factor_order_{o}") @@ -640,7 +644,9 @@ def knots(self, value): # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): - self._boundary_interval_idx = self._compute_boundary_interval() + self.register_buffer( + "_boundary_interval_idx", self._compute_boundary_interval() + ) # Recompute derivative denominators when knots change self._compute_derivative_denominators()