tenrso_exec/executor/functions.rs
1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::{BinaryOp, ElemOp, ReduceOp, ScatterMode};
6use crate::hints::ExecHints;
7use anyhow::Result;
8use scirs2_core::numeric::Num;
9use tenrso_core::{Axis, TensorHandle};
10/// Main executor trait for tensor operations
11pub trait TenrsoExecutor<T>
12where
13 T: Clone + Num + 'static,
14{
15 /// Execute an einsum contraction
16 fn einsum(
17 &mut self,
18 spec: &str,
19 inputs: &[TensorHandle<T>],
20 hints: &ExecHints,
21 ) -> Result<TensorHandle<T>>;
22 /// Apply unary element-wise operation
23 fn elem_op(&mut self, op: ElemOp, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
24 /// Apply binary element-wise operation
25 fn binary_op(
26 &mut self,
27 op: BinaryOp,
28 x: &TensorHandle<T>,
29 y: &TensorHandle<T>,
30 ) -> Result<TensorHandle<T>>;
31 /// Apply reduction operation
32 fn reduce(
33 &mut self,
34 op: ReduceOp,
35 x: &TensorHandle<T>,
36 axes: &[Axis],
37 ) -> Result<TensorHandle<T>>;
38 /// Clip tensor values to be within [min_val, max_val]
39 fn clip(&mut self, x: &TensorHandle<T>, min_val: T, max_val: T) -> Result<TensorHandle<T>>;
40 /// Softmax operation along specified axis
41 /// Computes: exp(x) / sum(exp(x), axis)
42 /// Uses numerically stable implementation: exp(x - max(x))
43 fn softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
44 /// Log-softmax operation along specified axis (numerically stable)
45 /// Computes: x - log(sum(exp(x), axis))
46 fn log_softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
47 /// Transpose/permute tensor axes
48 /// Reorders the dimensions according to the provided axes permutation
49 fn transpose(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>>;
50 /// Reshape tensor to new shape
51 /// Total number of elements must remain constant
52 fn reshape(&mut self, x: &TensorHandle<T>, new_shape: &[usize]) -> Result<TensorHandle<T>>;
53 /// Concatenate tensors along specified axis
54 fn concatenate(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>>;
55 /// Split tensor along specified axis into chunks
56 fn split(
57 &mut self,
58 x: &TensorHandle<T>,
59 num_splits: usize,
60 axis: Axis,
61 ) -> Result<Vec<TensorHandle<T>>>;
62 /// Layer normalization (fused operation)
63 /// Normalizes over the last dimension: (x - mean) / sqrt(var + eps)
64 fn layer_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>>;
65 /// Batch normalization (fused operation)
66 /// Normalizes over the batch dimension (first axis)
67 fn batch_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>>;
68 /// Conditional selection: where(condition, x, y)
69 /// Returns x where condition is true (>0), y otherwise
70 /// All tensors must have the same shape
71 fn where_op(
72 &mut self,
73 condition: &TensorHandle<T>,
74 x: &TensorHandle<T>,
75 y: &TensorHandle<T>,
76 ) -> Result<TensorHandle<T>>;
77 /// Masked selection: select values from x where mask is true (>0)
78 /// Returns a 1D tensor containing selected values
79 fn masked_select(
80 &mut self,
81 x: &TensorHandle<T>,
82 mask: &TensorHandle<T>,
83 ) -> Result<TensorHandle<T>>;
84 /// Element-wise modulo operation: x % divisor
85 fn modulo(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>>;
86 /// Element-wise remainder operation (same as modulo for positive numbers)
87 fn remainder(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>>;
88 /// Max pooling operation (1D)
89 /// Applies max pooling with specified kernel size and stride
90 fn max_pool_1d(
91 &mut self,
92 x: &TensorHandle<T>,
93 kernel_size: usize,
94 stride: usize,
95 ) -> Result<TensorHandle<T>>;
96 /// Average pooling operation (1D)
97 /// Applies average pooling with specified kernel size and stride
98 fn avg_pool_1d(
99 &mut self,
100 x: &TensorHandle<T>,
101 kernel_size: usize,
102 stride: usize,
103 ) -> Result<TensorHandle<T>>;
104 /// Max pooling operation (2D)
105 /// Applies max pooling with specified kernel size and stride
106 fn max_pool_2d(
107 &mut self,
108 x: &TensorHandle<T>,
109 kernel_size: (usize, usize),
110 stride: (usize, usize),
111 ) -> Result<TensorHandle<T>>;
112 /// Average pooling operation (2D)
113 /// Applies average pooling with specified kernel size and stride
114 fn avg_pool_2d(
115 &mut self,
116 x: &TensorHandle<T>,
117 kernel_size: (usize, usize),
118 stride: (usize, usize),
119 ) -> Result<TensorHandle<T>>;
120 /// 1D Convolution operation
121 /// Applies 1D convolution: output\[i\] = sum_j(input\[i+j\] * kernel\[j\])
122 ///
123 /// # Arguments
124 /// * `x` - Input tensor of shape \[batch, in_channels, length\]
125 /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_size\]
126 /// * `bias` - Optional bias of shape \[out_channels\]
127 /// * `stride` - Stride for the convolution
128 /// * `padding` - Padding to apply (left, right)
129 fn conv1d(
130 &mut self,
131 x: &TensorHandle<T>,
132 kernel: &TensorHandle<T>,
133 bias: Option<&TensorHandle<T>>,
134 stride: usize,
135 padding: (usize, usize),
136 ) -> Result<TensorHandle<T>>;
137 /// 2D Convolution operation
138 /// Applies 2D convolution for image processing
139 ///
140 /// # Arguments
141 /// * `x` - Input tensor of shape \[batch, in_channels, height, width\]
142 /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_h, kernel_w\]
143 /// * `bias` - Optional bias of shape \[out_channels\]
144 /// * `stride` - Stride for the convolution (stride_h, stride_w)
145 /// * `padding` - Padding to apply (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
146 fn conv2d(
147 &mut self,
148 x: &TensorHandle<T>,
149 kernel: &TensorHandle<T>,
150 bias: Option<&TensorHandle<T>>,
151 stride: (usize, usize),
152 padding: (usize, usize, usize, usize),
153 ) -> Result<TensorHandle<T>>;
154 /// 3D Convolution operation
155 /// Applies 3D convolution for volumetric data (video, medical imaging, etc.)
156 ///
157 /// # Arguments
158 /// * `x` - Input tensor of shape \[batch, in_channels, depth, height, width\]
159 /// * `kernel` - Convolution kernel of shape \[out_channels, in_channels, kernel_d, kernel_h, kernel_w\]
160 /// * `bias` - Optional bias of shape \[out_channels\]
161 /// * `stride` - Stride for the convolution (stride_d, stride_h, stride_w)
162 /// * `padding` - Padding to apply (pad_d_front, pad_d_back, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
163 fn conv3d(
164 &mut self,
165 x: &TensorHandle<T>,
166 kernel: &TensorHandle<T>,
167 bias: Option<&TensorHandle<T>>,
168 stride: (usize, usize, usize),
169 padding: (usize, usize, usize, usize, usize, usize),
170 ) -> Result<TensorHandle<T>>;
171 /// Gather operation - selects values along an axis using indices
172 ///
173 /// # Arguments
174 /// * `x` - Input tensor
175 /// * `axis` - Axis along which to gather
176 /// * `indices` - Integer indices to gather (as Float tensor, will be cast to usize)
177 fn gather(
178 &mut self,
179 x: &TensorHandle<T>,
180 axis: Axis,
181 indices: &TensorHandle<T>,
182 ) -> Result<TensorHandle<T>>;
183 /// Scatter operation - writes values to an output tensor using indices
184 ///
185 /// # Arguments
186 /// * `shape` - Shape of the output tensor
187 /// * `axis` - Axis along which to scatter
188 /// * `indices` - Integer indices where to write (as Float tensor, will be cast to usize)
189 /// * `values` - Values to write at the indices
190 fn scatter(
191 &mut self,
192 shape: &[usize],
193 axis: Axis,
194 indices: &TensorHandle<T>,
195 values: &TensorHandle<T>,
196 ) -> Result<TensorHandle<T>>;
197 /// Matrix determinant operation
198 /// Computes the determinant of a square matrix or batch of matrices
199 ///
200 /// # Arguments
201 /// * `x` - Input tensor of shape [..., N, N] (last two dimensions must be square)
202 ///
203 /// # Returns
204 /// Tensor of shape [...] containing determinants
205 fn determinant(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
206 /// Matrix inverse operation
207 /// Computes the inverse of a square matrix or batch of matrices
208 ///
209 /// # Arguments
210 /// * `x` - Input tensor of shape [..., N, N] (last two dimensions must be square)
211 ///
212 /// # Returns
213 /// Tensor of shape [..., N, N] containing matrix inverses
214 fn matrix_inverse(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>>;
215 /// Solve linear system Ax = b
216 /// Solves the linear system of equations for x
217 ///
218 /// # Arguments
219 /// * `a` - Coefficient matrix of shape [..., N, N]
220 /// * `b` - Right-hand side of shape [..., N] or [..., N, K]
221 ///
222 /// # Returns
223 /// Solution tensor x of the same shape as b
224 fn solve(&mut self, a: &TensorHandle<T>, b: &TensorHandle<T>) -> Result<TensorHandle<T>>;
225
226 /// Advanced gather operation with negative indices support
227 /// Gathers values from `x` along the specified `axis` using `indices`.
228 ///
229 /// # Arguments
230 /// * `x` - Input tensor
231 /// * `axis` - Axis along which to gather
232 /// * `indices` - Integer indices (as Float tensor)
233 /// * `allow_negative` - Whether to allow Python-style negative indices
234 ///
235 /// # Returns
236 /// Tensor with gathered values
237 fn advanced_gather(
238 &mut self,
239 x: &TensorHandle<T>,
240 axis: Axis,
241 indices: &TensorHandle<T>,
242 allow_negative: bool,
243 ) -> Result<TensorHandle<T>>;
244
245 /// Advanced scatter operation with accumulation modes
246 /// Scatters `values` into an output tensor along the specified `axis` using `indices`.
247 ///
248 /// # Arguments
249 /// * `shape` - Shape of the output tensor
250 /// * `axis` - Axis along which to scatter
251 /// * `indices` - Integer indices where to write
252 /// * `values` - Values to scatter
253 /// * `mode` - Scatter mode (Replace, Add, Max, Min)
254 ///
255 /// # Returns
256 /// Output tensor with scattered values
257 fn advanced_scatter(
258 &mut self,
259 shape: &[usize],
260 axis: Axis,
261 indices: &TensorHandle<T>,
262 values: &TensorHandle<T>,
263 mode: ScatterMode,
264 ) -> Result<TensorHandle<T>>;
265
266 /// Fancy indexing with boolean masks
267 /// Selects elements from `x` where `mask` is true (> 0).
268 ///
269 /// # Arguments
270 /// * `x` - Input tensor
271 /// * `mask` - Boolean mask tensor (same shape as x)
272 ///
273 /// # Returns
274 /// 1D tensor containing selected elements
275 fn fancy_index_mask(
276 &mut self,
277 x: &TensorHandle<T>,
278 mask: &TensorHandle<T>,
279 ) -> Result<TensorHandle<T>>;
280
281 /// Tile operation - repeats a tensor along each dimension
282 /// Constructs an array by repeating `x` the number of times given by `reps`.
283 ///
284 /// # Arguments
285 /// * `x` - Input tensor
286 /// * `reps` - Number of repetitions along each dimension
287 ///
288 /// # Returns
289 /// Tiled tensor
290 fn tile(&mut self, x: &TensorHandle<T>, reps: &[usize]) -> Result<TensorHandle<T>>;
291
292 /// Pad operation - pads a tensor with a constant value
293 /// Pads the input tensor with the specified value.
294 ///
295 /// # Arguments
296 /// * `x` - Input tensor
297 /// * `pad_width` - Number of values padded to edges of each axis (before, after) for each dimension
298 /// * `constant_value` - The value to set the padded values
299 ///
300 /// # Returns
301 /// Padded tensor
302 fn pad(
303 &mut self,
304 x: &TensorHandle<T>,
305 pad_width: &[(usize, usize)],
306 constant_value: T,
307 ) -> Result<TensorHandle<T>>;
308
309 /// Flip operation - reverses the order of elements along specified axes
310 /// Reverses the order of elements in an array along the given axes.
311 ///
312 /// # Arguments
313 /// * `x` - Input tensor
314 /// * `axes` - Axes along which to flip
315 ///
316 /// # Returns
317 /// Flipped tensor
318 fn flip(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>>;
319
320 /// Squeeze operation - removes dimensions of size 1
321 /// Removes single-dimensional entries from the shape of an array.
322 ///
323 /// # Arguments
324 /// * `x` - Input tensor
325 /// * `axes` - Optional axes to squeeze. If None, all axes of size 1 are removed
326 ///
327 /// # Returns
328 /// Squeezed tensor
329 fn squeeze(&mut self, x: &TensorHandle<T>, axes: Option<&[Axis]>) -> Result<TensorHandle<T>>;
330
331 /// Unsqueeze/expand_dims operation - adds a dimension of size 1
332 /// Expands the shape of an array by inserting a new axis.
333 ///
334 /// # Arguments
335 /// * `x` - Input tensor
336 /// * `axis` - Position where new axis is to be inserted
337 ///
338 /// # Returns
339 /// Tensor with expanded dimensions
340 fn unsqueeze(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
341
342 /// Stack operation - joins tensors along a new axis
343 /// Joins a sequence of tensors along a new axis.
344 ///
345 /// # Arguments
346 /// * `tensors` - Sequence of tensors to stack
347 /// * `axis` - Axis along which to stack
348 ///
349 /// # Returns
350 /// Stacked tensor
351 fn stack(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>>;
352
353 /// Repeat operation - repeats elements of an array
354 /// Repeat elements of an array along each axis.
355 ///
356 /// # Arguments
357 /// * `x` - Input tensor
358 /// * `repeats` - Number of repetitions for each element along each axis
359 /// * `axis` - Axis along which to repeat values
360 ///
361 /// # Returns
362 /// Repeated tensor
363 fn repeat(
364 &mut self,
365 x: &TensorHandle<T>,
366 repeats: usize,
367 axis: Axis,
368 ) -> Result<TensorHandle<T>>;
369
370 /// Roll operation - rolls array elements along an axis
371 /// Shifts elements along an axis, wrapping around at the boundaries.
372 ///
373 /// # Arguments
374 /// * `x` - Input tensor
375 /// * `shift` - Number of places to shift (positive or negative)
376 /// * `axis` - Axis along which to roll
377 ///
378 /// # Returns
379 /// Rolled tensor
380 fn roll(&mut self, x: &TensorHandle<T>, shift: isize, axis: Axis) -> Result<TensorHandle<T>>;
381
382 /// ArgMax operation - indices of maximum values along an axis
383 /// Returns the indices of the maximum values along an axis.
384 ///
385 /// # Arguments
386 /// * `x` - Input tensor
387 /// * `axis` - Axis along which to find argmax
388 ///
389 /// # Returns
390 /// Tensor containing indices of maximum values
391 fn argmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
392
393 /// ArgMin operation - indices of minimum values along an axis
394 /// Returns the indices of the minimum values along an axis.
395 ///
396 /// # Arguments
397 /// * `x` - Input tensor
398 /// * `axis` - Axis along which to find argmin
399 ///
400 /// # Returns
401 /// Tensor containing indices of minimum values
402 fn argmin(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>>;
403}