tenrso_exec/executor/
functions.rs

//! Auto-generated module defining the [`TenrsoExecutor`] trait.
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::{BinaryOp, ElemOp, ReduceOp, ScatterMode};
6use crate::hints::ExecHints;
7use anyhow::Result;
8use scirs2_core::numeric::Num;
9use tenrso_core::{Axis, TensorHandle};
10/// Main executor trait for tensor operations
11pub trait TenrsoExecutor<T>
12where
13    T: Clone + Num + 'static,
14{
15    /// Execute an einsum contraction
16    fn einsum(
17        &mut self,
18        spec: &str,
19        inputs: &[TensorHandle<T>],
20        hints: &ExecHints,
21    ) -> Result<TensorHandle<T>>;
22    /// Apply unary element-wise operation
23    fn elem_op(&mut self, op: ElemOp, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
24    /// Apply binary element-wise operation
25    fn binary_op(
26        &mut self,
27        op: BinaryOp,
28        x: &TensorHandle<T>,
29        y: &TensorHandle<T>,
30    ) -> Result<TensorHandle<T>>;
31    /// Apply reduction operation
32    fn reduce(
33        &mut self,
34        op: ReduceOp,
35        x: &TensorHandle<T>,
36        axes: &[Axis],
37    ) -> Result<TensorHandle<T>>;
38    /// Clip tensor values to be within [min_val, max_val]
39    fn clip(&mut self, x: &TensorHandle<T>, min_val: T, max_val: T) -> Result<TensorHandle<T>>;
40    /// Softmax operation along specified axis
41    /// Computes: exp(x) / sum(exp(x), axis)
42    /// Uses numerically stable implementation: exp(x - max(x))
43    fn softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
44    /// Log-softmax operation along specified axis (numerically stable)
45    /// Computes: x - log(sum(exp(x), axis))
46    fn log_softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
47    /// Transpose/permute tensor axes
48    /// Reorders the dimensions according to the provided axes permutation
49    fn transpose(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>>;
50    /// Reshape tensor to new shape
51    /// Total number of elements must remain constant
52    fn reshape(&mut self, x: &TensorHandle<T>, new_shape: &[usize]) -> Result<TensorHandle<T>>;
53    /// Concatenate tensors along specified axis
54    fn concatenate(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>>;
55    /// Split tensor along specified axis into chunks
56    fn split(
57        &mut self,
58        x: &TensorHandle<T>,
59        num_splits: usize,
60        axis: Axis,
61    ) -> Result<Vec<TensorHandle<T>>>;
62    /// Layer normalization (fused operation)
63    /// Normalizes over the last dimension: (x - mean) / sqrt(var + eps)
64    fn layer_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>>;
65    /// Batch normalization (fused operation)
66    /// Normalizes over the batch dimension (first axis)
67    fn batch_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>>;
68    /// Conditional selection: where(condition, x, y)
69    /// Returns x where condition is true (>0), y otherwise
70    /// All tensors must have the same shape
71    fn where_op(
72        &mut self,
73        condition: &TensorHandle<T>,
74        x: &TensorHandle<T>,
75        y: &TensorHandle<T>,
76    ) -> Result<TensorHandle<T>>;
77    /// Masked selection: select values from x where mask is true (>0)
78    /// Returns a 1D tensor containing selected values
79    fn masked_select(
80        &mut self,
81        x: &TensorHandle<T>,
82        mask: &TensorHandle<T>,
83    ) -> Result<TensorHandle<T>>;
84    /// Element-wise modulo operation: x % divisor
85    fn modulo(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>>;
86    /// Element-wise remainder operation (same as modulo for positive numbers)
87    fn remainder(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>>;
88    /// Max pooling operation (1D)
89    /// Applies max pooling with specified kernel size and stride
90    fn max_pool_1d(
91        &mut self,
92        x: &TensorHandle<T>,
93        kernel_size: usize,
94        stride: usize,
95    ) -> Result<TensorHandle<T>>;
96    /// Average pooling operation (1D)
97    /// Applies average pooling with specified kernel size and stride
98    fn avg_pool_1d(
99        &mut self,
100        x: &TensorHandle<T>,
101        kernel_size: usize,
102        stride: usize,
103    ) -> Result<TensorHandle<T>>;
104    /// Max pooling operation (2D)
105    /// Applies max pooling with specified kernel size and stride
106    fn max_pool_2d(
107        &mut self,
108        x: &TensorHandle<T>,
109        kernel_size: (usize, usize),
110        stride: (usize, usize),
111    ) -> Result<TensorHandle<T>>;
112    /// Average pooling operation (2D)
113    /// Applies average pooling with specified kernel size and stride
114    fn avg_pool_2d(
115        &mut self,
116        x: &TensorHandle<T>,
117        kernel_size: (usize, usize),
118        stride: (usize, usize),
119    ) -> Result<TensorHandle<T>>;
120    /// 1D Convolution operation
121    /// Applies 1D convolution: output\[i\] = sum_j(input\[i+j\] * kernel\[j\])
122    ///
123    /// # Arguments
124    /// * `x` - Input tensor of shape \[batch, in_channels, length\]
125    /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_size\]
126    /// * `bias` - Optional bias of shape \[out_channels\]
127    /// * `stride` - Stride for the convolution
128    /// * `padding` - Padding to apply (left, right)
129    fn conv1d(
130        &mut self,
131        x: &TensorHandle<T>,
132        kernel: &TensorHandle<T>,
133        bias: Option<&TensorHandle<T>>,
134        stride: usize,
135        padding: (usize, usize),
136    ) -> Result<TensorHandle<T>>;
137    /// 2D Convolution operation
138    /// Applies 2D convolution for image processing
139    ///
140    /// # Arguments
141    /// * `x` - Input tensor of shape \[batch, in_channels, height, width\]
142    /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_h, kernel_w\]
143    /// * `bias` - Optional bias of shape \[out_channels\]
144    /// * `stride` - Stride for the convolution (stride_h, stride_w)
145    /// * `padding` - Padding to apply (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
146    fn conv2d(
147        &mut self,
148        x: &TensorHandle<T>,
149        kernel: &TensorHandle<T>,
150        bias: Option<&TensorHandle<T>>,
151        stride: (usize, usize),
152        padding: (usize, usize, usize, usize),
153    ) -> Result<TensorHandle<T>>;
154    /// 3D Convolution operation
155    /// Applies 3D convolution for volumetric data (video, medical imaging, etc.)
156    ///
157    /// # Arguments
158    /// * `x` - Input tensor of shape \[batch, in_channels, depth, height, width\]
159    /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_d, kernel_h, kernel_w\]
160    /// * `bias` - Optional bias of shape \[out_channels\]
161    /// * `stride` - Stride for the convolution (stride_d, stride_h, stride_w)
162    /// * `padding` - Padding to apply (pad_d_front, pad_d_back, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
163    fn conv3d(
164        &mut self,
165        x: &TensorHandle<T>,
166        kernel: &TensorHandle<T>,
167        bias: Option<&TensorHandle<T>>,
168        stride: (usize, usize, usize),
169        padding: (usize, usize, usize, usize, usize, usize),
170    ) -> Result<TensorHandle<T>>;
171    /// Gather operation - selects values along an axis using indices
172    ///
173    /// # Arguments
174    /// * `x` - Input tensor
175    /// * `axis` - Axis along which to gather
176    /// * `indices` - Integer indices to gather (as Float tensor, will be cast to usize)
177    fn gather(
178        &mut self,
179        x: &TensorHandle<T>,
180        axis: Axis,
181        indices: &TensorHandle<T>,
182    ) -> Result<TensorHandle<T>>;
183    /// Scatter operation - writes values to an output tensor using indices
184    ///
185    /// # Arguments
186    /// * `shape` - Shape of the output tensor
187    /// * `axis` - Axis along which to scatter
188    /// * `indices` - Integer indices where to write (as Float tensor, will be cast to usize)
189    /// * `values` - Values to write at the indices
190    fn scatter(
191        &mut self,
192        shape: &[usize],
193        axis: Axis,
194        indices: &TensorHandle<T>,
195        values: &TensorHandle<T>,
196    ) -> Result<TensorHandle<T>>;
197    /// Matrix determinant operation
198    /// Computes the determinant of a square matrix or batch of matrices
199    ///
200    /// # Arguments
201    /// * `x` - Input tensor of shape [..., N, N] (last two dimensions must be square)
202    ///
203    /// # Returns
204    /// Tensor of shape [...] containing determinants
205    fn determinant(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
206    /// Matrix inverse operation
207    /// Computes the inverse of a square matrix or batch of matrices
208    ///
209    /// # Arguments
210    /// * `x` - Input tensor of shape [..., N, N] (last two dimensions must be square)
211    ///
212    /// # Returns
213    /// Tensor of shape [..., N, N] containing matrix inverses
214    fn matrix_inverse(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
215    /// Solve linear system Ax = b
216    /// Solves the linear system of equations for x
217    ///
218    /// # Arguments
219    /// * `a` - Coefficient matrix of shape [..., N, N]
220    /// * `b` - Right-hand side of shape [..., N] or [..., N, K]
221    ///
222    /// # Returns
223    /// Solution tensor x of the same shape as b
224    fn solve(&mut self, a: &TensorHandle<T>, b: &TensorHandle<T>) -> Result<TensorHandle<T>>;
225
226    /// Advanced gather operation with negative indices support
227    /// Gathers values from `x` along the specified `axis` using `indices`.
228    ///
229    /// # Arguments
230    /// * `x` - Input tensor
231    /// * `axis` - Axis along which to gather
232    /// * `indices` - Integer indices (as Float tensor)
233    /// * `allow_negative` - Whether to allow Python-style negative indices
234    ///
235    /// # Returns
236    /// Tensor with gathered values
237    fn advanced_gather(
238        &mut self,
239        x: &TensorHandle<T>,
240        axis: Axis,
241        indices: &TensorHandle<T>,
242        allow_negative: bool,
243    ) -> Result<TensorHandle<T>>;
244
245    /// Advanced scatter operation with accumulation modes
246    /// Scatters `values` into an output tensor along the specified `axis` using `indices`.
247    ///
248    /// # Arguments
249    /// * `shape` - Shape of the output tensor
250    /// * `axis` - Axis along which to scatter
251    /// * `indices` - Integer indices where to write
252    /// * `values` - Values to scatter
253    /// * `mode` - Scatter mode (Replace, Add, Max, Min)
254    ///
255    /// # Returns
256    /// Output tensor with scattered values
257    fn advanced_scatter(
258        &mut self,
259        shape: &[usize],
260        axis: Axis,
261        indices: &TensorHandle<T>,
262        values: &TensorHandle<T>,
263        mode: ScatterMode,
264    ) -> Result<TensorHandle<T>>;
265
266    /// Fancy indexing with boolean masks
267    /// Selects elements from `x` where `mask` is true (> 0).
268    ///
269    /// # Arguments
270    /// * `x` - Input tensor
271    /// * `mask` - Boolean mask tensor (same shape as x)
272    ///
273    /// # Returns
274    /// 1D tensor containing selected elements
275    fn fancy_index_mask(
276        &mut self,
277        x: &TensorHandle<T>,
278        mask: &TensorHandle<T>,
279    ) -> Result<TensorHandle<T>>;
280
281    /// Tile operation - repeats a tensor along each dimension
282    /// Constructs an array by repeating `x` the number of times given by `reps`.
283    ///
284    /// # Arguments
285    /// * `x` - Input tensor
286    /// * `reps` - Number of repetitions along each dimension
287    ///
288    /// # Returns
289    /// Tiled tensor
290    fn tile(&mut self, x: &TensorHandle<T>, reps: &[usize]) -> Result<TensorHandle<T>>;
291
292    /// Pad operation - pads a tensor with a constant value
293    /// Pads the input tensor with the specified value.
294    ///
295    /// # Arguments
296    /// * `x` - Input tensor
297    /// * `pad_width` - Number of values padded to edges of each axis (before, after) for each dimension
298    /// * `constant_value` - The value to set the padded values
299    ///
300    /// # Returns
301    /// Padded tensor
302    fn pad(
303        &mut self,
304        x: &TensorHandle<T>,
305        pad_width: &[(usize, usize)],
306        constant_value: T,
307    ) -> Result<TensorHandle<T>>;
308
309    /// Flip operation - reverses the order of elements along specified axes
310    /// Reverses the order of elements in an array along the given axes.
311    ///
312    /// # Arguments
313    /// * `x` - Input tensor
314    /// * `axes` - Axes along which to flip
315    ///
316    /// # Returns
317    /// Flipped tensor
318    fn flip(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>>;
319
320    /// Squeeze operation - removes dimensions of size 1
321    /// Removes single-dimensional entries from the shape of an array.
322    ///
323    /// # Arguments
324    /// * `x` - Input tensor
325    /// * `axes` - Optional axes to squeeze. If None, all axes of size 1 are removed
326    ///
327    /// # Returns
328    /// Squeezed tensor
329    fn squeeze(&mut self, x: &TensorHandle<T>, axes: Option<&[Axis]>) -> Result<TensorHandle<T>>;
330
331    /// Unsqueeze/expand_dims operation - adds a dimension of size 1
332    /// Expands the shape of an array by inserting a new axis.
333    ///
334    /// # Arguments
335    /// * `x` - Input tensor
336    /// * `axis` - Position where new axis is to be inserted
337    ///
338    /// # Returns
339    /// Tensor with expanded dimensions
340    fn unsqueeze(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
341
342    /// Stack operation - joins tensors along a new axis
343    /// Joins a sequence of tensors along a new axis.
344    ///
345    /// # Arguments
346    /// * `tensors` - Sequence of tensors to stack
347    /// * `axis` - Axis along which to stack
348    ///
349    /// # Returns
350    /// Stacked tensor
351    fn stack(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>>;
352
353    /// Repeat operation - repeats elements of an array
354    /// Repeat elements of an array along each axis.
355    ///
356    /// # Arguments
357    /// * `x` - Input tensor
358    /// * `repeats` - Number of repetitions for each element along each axis
359    /// * `axis` - Axis along which to repeat values
360    ///
361    /// # Returns
362    /// Repeated tensor
363    fn repeat(
364        &mut self,
365        x: &TensorHandle<T>,
366        repeats: usize,
367        axis: Axis,
368    ) -> Result<TensorHandle<T>>;
369
370    /// Roll operation - rolls array elements along an axis
371    /// Shifts elements along an axis, wrapping around at the boundaries.
372    ///
373    /// # Arguments
374    /// * `x` - Input tensor
375    /// * `shift` - Number of places to shift (positive or negative)
376    /// * `axis` - Axis along which to roll
377    ///
378    /// # Returns
379    /// Rolled tensor
380    fn roll(&mut self, x: &TensorHandle<T>, shift: isize, axis: Axis) -> Result<TensorHandle<T>>;
381
382    /// ArgMax operation - indices of maximum values along an axis
383    /// Returns the indices of the maximum values along an axis.
384    ///
385    /// # Arguments
386    /// * `x` - Input tensor
387    /// * `axis` - Axis along which to find argmax
388    ///
389    /// # Returns
390    /// Tensor containing indices of maximum values
391    fn argmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
392
393    /// ArgMin operation - indices of minimum values along an axis
394    /// Returns the indices of the minimum values along an axis.
395    ///
396    /// # Arguments
397    /// * `x` - Input tensor
398    /// * `axis` - Axis along which to find argmin
399    ///
400    /// # Returns
401    /// Tensor containing indices of minimum values
402    fn argmin(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
403}