use super::functions::*;
use super::types::*;
use crate::hints::ExecHints;
use anyhow::{anyhow, Result};
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use tenrso_core::{Axis, DenseND, TensorHandle};
use tenrso_planner::EinsumSpec;

impl Default for CpuExecutor {
    fn default() -> Self {
        Self::new()
    }
}

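// Dense-only CPU backend: every trait method below first extracts a DenseND view
// and returns an error for non-dense tensor handles.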
impl<T> TenrsoExecutor<T> for CpuExecutor
where
    T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
{
    fn einsum(
        &mut self,
        spec: &str,
        inputs: &[TensorHandle<T>],
        hints: &ExecHints,
    ) -> Result<TensorHandle<T>> {
        let parsed_spec = EinsumSpec::parse(spec)?;
        if parsed_spec.num_inputs() != inputs.len() {
            return Err(anyhow!(
                "Spec expects {} inputs, got {}",
                parsed_spec.num_inputs(),
                inputs.len()
            ));
        }
        let dense_inputs: Vec<&DenseND<T>> = inputs
            .iter()
            .map(|h| {
                h.as_dense()
                    .ok_or_else(|| anyhow!("Only dense tensors supported for now"))
            })
            .collect::<Result<Vec<_>>>()?;
        let dense_inputs_owned: Vec<DenseND<T>> = dense_inputs.iter().map(|&t| t.clone()).collect();
        let result = self.execute_einsum_with_planner(&parsed_spec, &dense_inputs_owned, hints)?;
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn elem_op(&mut self, op: ElemOp, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for elem_op"))?;
        let result_data = match op {
            ElemOp::Neg => dense.view().mapv(|v| -v),
            ElemOp::Abs => dense.view().mapv(|v| v.abs()),
            ElemOp::Exp => dense.view().mapv(|v| v.exp()),
            ElemOp::Log => dense.view().mapv(|v| v.ln()),
            ElemOp::Sin => dense.view().mapv(|v| v.sin()),
            ElemOp::Cos => dense.view().mapv(|v| v.cos()),
            ElemOp::Sqrt => dense.view().mapv(|v| v.sqrt()),
            ElemOp::Sqr => dense.view().mapv(|v| v * v),
            ElemOp::Recip => dense.view().mapv(|v| v.recip()),
            ElemOp::Tanh => dense.view().mapv(|v| v.tanh()),
            ElemOp::Sigmoid => dense.view().mapv(|v| {
                let one = T::one();
                one / (one + (-v).exp())
            }),
            ElemOp::ReLU => dense.view().mapv(|v| {
                let zero = T::zero();
                if v > zero {
                    v
                } else {
                    zero
                }
            }),
            ElemOp::Gelu => dense.view().mapv(|v| {
                let half = T::from_f64(0.5).unwrap_or_else(T::one);
                let one = T::one();
                let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
                let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
                let x_cubed = v * v * v;
                let inner = coeff * (v + cubic_coeff * x_cubed);
                half * v * (one + inner.tanh())
            }),
            ElemOp::Elu => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                if v > zero {
                    v
                } else {
                    v.exp() - one
                }
            }),
            ElemOp::Selu => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
                let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
                if v > zero {
                    scale * v
                } else {
                    scale * alpha * (v.exp() - one)
                }
            }),
            ElemOp::Softplus => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let abs_v = v.abs();
                let max_part = if v > zero { v } else { zero };
                max_part + (one + (-abs_v).exp()).ln()
            }),
            ElemOp::Sign => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let neg_one = -one;
                if v > zero {
                    one
                } else if v < zero {
                    neg_one
                } else {
                    zero
                }
            }),
        };
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
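    // Binary ops with mismatched shapes fall back to the broadcasting path;
    // same-shape inputs are combined element-wise below.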
    fn binary_op(
        &mut self,
        op: BinaryOp,
        x: &TensorHandle<T>,
        y: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for binary_op"))?;
        let dense_y = y
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for binary_op"))?;
        if dense_x.shape() != dense_y.shape() {
            return self.binary_op_with_broadcast(op, dense_x, dense_y);
        }
        use scirs2_core::ndarray_ext::Zip;
        let result_data = match op {
            BinaryOp::Add => &dense_x.view() + &dense_y.view(),
            BinaryOp::Sub => &dense_x.view() - &dense_y.view(),
            BinaryOp::Mul => &dense_x.view() * &dense_y.view(),
            BinaryOp::Div => &dense_x.view() / &dense_y.view(),
            BinaryOp::Pow => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = x_val.powf(y_val);
                    });
                result
            }
            BinaryOp::Maximum => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = if x_val > y_val { x_val } else { y_val };
                    });
                result
            }
            BinaryOp::Minimum => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = if x_val < y_val { x_val } else { y_val };
                    });
                result
            }
        };
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
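    // Multi-axis reductions are applied one axis at a time, highest axis index
    // first, so the remaining axis indices stay valid after each pass.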
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &TensorHandle<T>,
        axes: &[Axis],
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for reduce"))?;
        if axes.is_empty() {
            return Err(anyhow!("No axes specified for reduction"));
        }
        let axis_indices: Vec<usize> = axes.to_vec();
        let ndim = dense.shape().len();
        for &axis_idx in &axis_indices {
            if axis_idx >= ndim {
                return Err(anyhow!(
                    "Axis index {} out of range for tensor with {} dimensions",
                    axis_idx,
                    ndim
                ));
            }
        }
        let mut result = dense.view().to_owned();
        let mut sorted_axes = axis_indices.clone();
        sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
        for &axis_idx in &sorted_axes {
            let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
            result = match op {
                ReduceOp::Sum => result.sum_axis(axis),
                ReduceOp::Max => result.map_axis(axis, |view| {
                    view.iter()
                        .cloned()
                        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .unwrap_or_else(T::default)
                }),
                ReduceOp::Min => result.map_axis(axis, |view| {
                    view.iter()
                        .cloned()
                        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .unwrap_or_else(T::default)
                }),
                ReduceOp::Mean => result.mean_axis(axis).ok_or_else(|| {
                    anyhow!(
                        "Mean reduction failed - axis might be empty or type doesn't support division"
                    )
                })?,
                ReduceOp::Prod => result.map_axis(axis, |view| {
                    view.iter().cloned().fold(T::one(), |acc, x| acc * x)
                }),
                ReduceOp::All => result.map_axis(axis, |view| {
                    let all_nonzero = view.iter().all(|&x| x != T::zero());
                    if all_nonzero { T::one() } else { T::zero() }
                }),
                ReduceOp::Any => result.map_axis(axis, |view| {
                    let any_nonzero = view.iter().any(|&x| x != T::zero());
                    if any_nonzero { T::one() } else { T::zero() }
                }),
                ReduceOp::ArgMax | ReduceOp::ArgMin => {
                    return Err(anyhow!(
                        "ArgMax and ArgMin should use dedicated argmax/argmin methods, not reduce"
                    ));
                }
            };
        }
        let result_tensor = DenseND::from_array(result);
        Ok(TensorHandle::from_dense_auto(result_tensor))
    }
    fn clip(&mut self, x: &TensorHandle<T>, min_val: T, max_val: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for clip"))?;
        if min_val > max_val {
            return Err(anyhow!("Invalid clip bounds: min_val > max_val"));
        }
        let result_data = dense.view().mapv(|v| {
            if v < min_val {
                min_val
            } else if v > max_val {
                max_val
            } else {
                v
            }
        });
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for softmax"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!(
                "Axis {} out of range for tensor with {} dimensions",
                axis,
                ndim
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let axis_obj = scirs2_core::ndarray_ext::Axis(axis);
        let max_vals = dense.view().map_axis(axis_obj, |view| {
            view.iter()
                .cloned()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::zero)
        });
        let mut exp_vals = dense.view().to_owned();
        Zip::from(exp_vals.lanes_mut(axis_obj))
            .and(max_vals.view())
            .for_each(|mut lane, &max_val| {
                lane.mapv_inplace(|v| (v - max_val).exp());
            });
        let sum_exp = exp_vals.sum_axis(axis_obj);
        let mut result = exp_vals;
        Zip::from(result.lanes_mut(axis_obj))
            .and(sum_exp.view())
            .for_each(|mut lane, &sum_val| {
                lane.mapv_inplace(|v| v / sum_val);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn log_softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for log_softmax"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!(
                "Axis {} out of range for tensor with {} dimensions",
                axis,
                ndim
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let axis_obj = scirs2_core::ndarray_ext::Axis(axis);
        let max_vals = dense.view().map_axis(axis_obj, |view| {
            view.iter()
                .cloned()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::zero)
        });
        let mut exp_vals = dense.view().to_owned();
        Zip::from(exp_vals.lanes_mut(axis_obj))
            .and(max_vals.view())
            .for_each(|mut lane, &max_val| {
                lane.mapv_inplace(|v| (v - max_val).exp());
            });
        let sum_exp = exp_vals.sum_axis(axis_obj);
        let log_sum_exp = sum_exp.mapv(|v| v.ln());
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(axis_obj))
            .and(max_vals.view())
            .and(log_sum_exp.view())
            .for_each(|mut lane, &max_val, &lse| {
                lane.mapv_inplace(|v| v - max_val - lse);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn transpose(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for transpose"))?;
        let ndim = dense.shape().len();
        if axes.len() != ndim {
            return Err(anyhow!(
                "Axes length ({}) must match tensor dimensionality ({})",
                axes.len(),
                ndim
            ));
        }
        let mut seen = vec![false; ndim];
        for &axis in axes {
            if axis >= ndim {
                return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
            }
            if seen[axis] {
                return Err(anyhow!("Duplicate axis {} in permutation", axis));
            }
            seen[axis] = true;
        }
        let permuted = dense.view().permuted_axes(axes);
        let result = DenseND::from_array(permuted.to_owned());
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn reshape(&mut self, x: &TensorHandle<T>, new_shape: &[usize]) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for reshape"))?;
        let old_size: usize = dense.shape().iter().product();
        let new_size: usize = new_shape.iter().product();
        if old_size != new_size {
            return Err(anyhow!(
                "Cannot reshape tensor of size {} to size {}",
                old_size,
                new_size
            ));
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let data: Vec<T> = dense.view().iter().cloned().collect();
        let reshaped = Array::from_shape_vec(IxDyn(new_shape), data)
            .map_err(|e| anyhow!("Reshape failed: {}", e))?;
        let result = DenseND::from_array(reshaped);
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn concatenate(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>> {
        if tensors.is_empty() {
            return Err(anyhow!("Cannot concatenate empty tensor list"));
        }
        let dense_tensors: Vec<&DenseND<T>> = tensors
            .iter()
            .map(|t| {
                t.as_dense()
                    .ok_or_else(|| anyhow!("Only dense tensors supported for concatenate"))
            })
            .collect::<Result<Vec<_>>>()?;
        let ndim = dense_tensors[0].shape().len();
        for t in dense_tensors.iter().skip(1) {
            if t.shape().len() != ndim {
                return Err(anyhow!(
                    "All tensors must have same number of dimensions, got {} and {}",
                    ndim,
                    t.shape().len()
                ));
            }
        }
        if axis >= ndim {
            return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
        }
        for dim in 0..ndim {
            if dim != axis {
                let expected_size = dense_tensors[0].shape()[dim];
                for (i, t) in dense_tensors.iter().enumerate().skip(1) {
                    if t.shape()[dim] != expected_size {
                        return Err(anyhow!(
                            "Dimension {} mismatch: tensor 0 has size {}, tensor {} has size {}",
                            dim,
                            expected_size,
                            i,
                            t.shape()[dim]
                        ));
                    }
                }
            }
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let mut output_shape = dense_tensors[0].shape().to_vec();
        for t in dense_tensors.iter().skip(1) {
            output_shape[axis] += t.shape()[axis];
        }
        let output_size: usize = output_shape.iter().product();

        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
        output_data.clear();
        output_data.reserve(output_size);

        for flat_idx in 0..output_size {
            let out_idx = self.flat_to_multidim(flat_idx, &output_shape);
            let mut cumulative_axis_size = 0;
            let mut tensor_idx = 0;
            let mut local_axis_pos = out_idx[axis];
            for (i, t) in dense_tensors.iter().enumerate() {
                let t_axis_size = t.shape()[axis];
                if local_axis_pos < cumulative_axis_size + t_axis_size {
                    tensor_idx = i;
                    local_axis_pos -= cumulative_axis_size;
                    break;
                }
                cumulative_axis_size += t_axis_size;
            }
            let mut src_idx = out_idx.clone();
            src_idx[axis] = local_axis_pos;
            let val = dense_tensors[tensor_idx].view()[src_idx.as_slice()];
            output_data.push(val);
        }

        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
            .map_err(|e| anyhow!("Concatenation failed: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output_data);
        let result = DenseND::from_array(result_array);
        Ok(TensorHandle::from_dense_auto(result))
    }
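    // split divides one axis into equal-sized chunks; the axis length must be
    // divisible by num_splits.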
    fn split(
        &mut self,
        x: &TensorHandle<T>,
        num_splits: usize,
        axis: Axis,
    ) -> Result<Vec<TensorHandle<T>>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for split"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
        }
        let axis_size = dense.shape()[axis];
        if axis_size % num_splits != 0 {
            return Err(anyhow!(
                "Cannot split axis of size {} into {} equal parts",
                axis_size,
                num_splits
            ));
        }
        let split_size = axis_size / num_splits;
        let mut results = Vec::with_capacity(num_splits);
        use scirs2_core::ndarray_ext::Axis as NdAxis;
        for i in 0..num_splits {
            let start = i * split_size;
            let end = start + split_size;
            let sliced = dense
                .view()
                .slice_axis(NdAxis(axis), (start..end).into())
                .to_owned();
            results.push(TensorHandle::from_dense_auto(DenseND::from_array(sliced)));
        }
        Ok(results)
    }
    fn layer_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for layer_norm"))?;
        let ndim = dense.shape().len();
        if ndim == 0 {
            return Err(anyhow!("Cannot normalize scalar tensor"));
        }
        let last_axis = ndim - 1;
        use scirs2_core::ndarray_ext::{Axis as NdAxis, Zip};
        let mean = dense.view().mean_axis(NdAxis(last_axis)).ok_or_else(|| {
            anyhow!(
                "Mean computation failed - axis might be empty or type doesn't support division"
            )
        })?;
        let mut variance = dense.view().to_owned();
        Zip::from(variance.lanes_mut(NdAxis(last_axis)))
            .and(mean.view())
            .for_each(|mut lane, &m| {
                lane.mapv_inplace(|v| {
                    let diff = v - m;
                    diff * diff
                });
            });
        let variance = variance
            .mean_axis(NdAxis(last_axis))
            .ok_or_else(|| anyhow!("Variance computation failed"))?;
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(NdAxis(last_axis)))
            .and(mean.view())
            .and(variance.view())
            .for_each(|mut lane, &m, &v| {
                let std = (v + eps).sqrt();
                lane.mapv_inplace(|x_val| (x_val - m) / std);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn batch_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for batch_norm"))?;
        let ndim = dense.shape().len();
        if ndim == 0 {
            return Err(anyhow!("Cannot normalize scalar tensor"));
        }
        let batch_axis = 0;
        use scirs2_core::ndarray_ext::{Axis as NdAxis, Zip};
        let mean = dense.view().mean_axis(NdAxis(batch_axis)).ok_or_else(|| {
            anyhow!(
                "Mean computation failed - axis might be empty or type doesn't support division"
            )
        })?;
        let mut variance = dense.view().to_owned();
        Zip::from(variance.lanes_mut(NdAxis(batch_axis)))
            .and(mean.view())
            .for_each(|mut lane, &m| {
                lane.mapv_inplace(|v| {
                    let diff = v - m;
                    diff * diff
                });
            });
        let variance = variance
            .mean_axis(NdAxis(batch_axis))
            .ok_or_else(|| anyhow!("Variance computation failed"))?;
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(NdAxis(batch_axis)))
            .and(mean.view())
            .and(variance.view())
            .for_each(|mut lane, &m, &v| {
                let std = (v + eps).sqrt();
                lane.mapv_inplace(|x_val| (x_val - m) / std);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn where_op(
        &mut self,
        condition: &TensorHandle<T>,
        x: &TensorHandle<T>,
        y: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let cond_dense = condition
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        let x_dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        let y_dense = y
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        if cond_dense.shape() != x_dense.shape() || x_dense.shape() != y_dense.shape() {
            return Err(anyhow!(
                "Shape mismatch: condition={:?}, x={:?}, y={:?}",
                cond_dense.shape(),
                x_dense.shape(),
                y_dense.shape()
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let mut result = x_dense.view().to_owned();
        Zip::from(&mut result)
            .and(&cond_dense.view())
            .and(&x_dense.view())
            .and(&y_dense.view())
            .for_each(|r, &c, &x_val, &y_val| {
                *r = if c > T::zero() { x_val } else { y_val };
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn masked_select(
        &mut self,
        x: &TensorHandle<T>,
        mask: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let x_dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for masked_select"))?;
        let mask_dense = mask
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for masked_select"))?;
        if x_dense.shape() != mask_dense.shape() {
            return Err(anyhow!(
                "Shape mismatch: x={:?}, mask={:?}",
                x_dense.shape(),
                mask_dense.shape()
            ));
        }
        let mut selected = Vec::new();
        for (x_val, mask_val) in x_dense.view().iter().zip(mask_dense.view().iter()) {
            if *mask_val > T::zero() {
                selected.push(*x_val);
            }
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[selected.len()]), selected)
            .map_err(|e| anyhow!("Failed to create result array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn modulo(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for modulo"))?;
        if divisor == T::zero() {
            return Err(anyhow!("Division by zero in modulo operation"));
        }
        let result_data = dense.view().mapv(|v| {
            let quot = (v / divisor).floor();
            v - quot * divisor
        });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_data,
        )))
    }
    fn remainder(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>> {
        self.modulo(x, divisor)
    }
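    // 1D/2D pooling uses valid (no padding) windows: output length along an axis
    // is (input - kernel) / stride + 1.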
    fn max_pool_1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: usize,
        stride: usize,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for max_pool_1d"))?;
        if kernel_size == 0 || stride == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 1 {
            return Err(anyhow!(
                "Expected 1D tensor for max_pool_1d, got {:?}D",
                shape.len()
            ));
        }
        let input_len = shape[0];
        if kernel_size > input_len {
            return Err(anyhow!(
                "Kernel size {} larger than input length {}",
                kernel_size,
                input_len
            ));
        }
        let output_len = (input_len - kernel_size) / stride + 1;
        let mut output = Vec::with_capacity(output_len);
        let view = dense.view();
        for i in 0..output_len {
            let start = i * stride;
            let end = start + kernel_size;
            let max_val = (start..end)
                .map(|j| view[[j]])
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::default);
            output.push(max_val);
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[output_len]), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn avg_pool_1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: usize,
        stride: usize,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for avg_pool_1d"))?;
        if kernel_size == 0 || stride == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 1 {
            return Err(anyhow!(
                "Expected 1D tensor for avg_pool_1d, got {:?}D",
                shape.len()
            ));
        }
        let input_len = shape[0];
        if kernel_size > input_len {
            return Err(anyhow!(
                "Kernel size {} larger than input length {}",
                kernel_size,
                input_len
            ));
        }
        let output_len = (input_len - kernel_size) / stride + 1;
        let mut output = Vec::with_capacity(output_len);
        let view = dense.view();
        let kernel_size_t = T::from_usize(kernel_size).unwrap_or_else(T::one);
        for i in 0..output_len {
            let start = i * stride;
            let end = start + kernel_size;
            let mut sum = T::zero();
            for j in start..end {
                sum += view[[j]];
            }
            let avg = sum / kernel_size_t;
            output.push(avg);
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[output_len]), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn max_pool_2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for max_pool_2d"))?;
        let (kh, kw) = kernel_size;
        let (sh, sw) = stride;
        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 2 {
            return Err(anyhow!(
                "Expected 2D tensor for max_pool_2d, got {:?}D",
                shape.len()
            ));
        }
        let (h, w) = (shape[0], shape[1]);
        if kh > h || kw > w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than input ({}, {})",
                kh,
                kw,
                h,
                w
            ));
        }
        let out_h = (h - kh) / sh + 1;
        let out_w = (w - kw) / sw + 1;
        let output_shape = [out_h, out_w];

        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear();
        output.reserve(out_h * out_w);

        let view = dense.view();
        for i in 0..out_h {
            for j in 0..out_w {
                let start_h = i * sh;
                let start_w = j * sw;
                let mut max_val = T::default();
                let mut first = true;
                for di in 0..kh {
                    for dj in 0..kw {
                        let val = view[[start_h + di, start_w + dj]];
                        if first || val > max_val {
                            max_val = val;
                            first = false;
                        }
                    }
                }
                output.push(max_val);
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn avg_pool_2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for avg_pool_2d"))?;
        let (kh, kw) = kernel_size;
        let (sh, sw) = stride;
        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 2 {
            return Err(anyhow!(
                "Expected 2D tensor for avg_pool_2d, got {:?}D",
                shape.len()
            ));
        }
        let (h, w) = (shape[0], shape[1]);
        if kh > h || kw > w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than input ({}, {})",
                kh,
                kw,
                h,
                w
            ));
        }
        let out_h = (h - kh) / sh + 1;
        let out_w = (w - kw) / sw + 1;
        let output_shape = [out_h, out_w];

        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear();
        output.reserve(out_h * out_w);

        let view = dense.view();
        let kernel_count = T::from_usize(kh * kw).unwrap_or_else(T::one);
        for i in 0..out_h {
            for j in 0..out_w {
                let start_h = i * sh;
                let start_w = j * sw;
                let mut sum = T::zero();
                for di in 0..kh {
                    for dj in 0..kw {
                        sum += view[[start_h + di, start_w + dj]];
                    }
                }
                let avg = sum / kernel_count;
                output.push(avg);
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
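    // Naive direct convolutions: zero padding is modeled by skipping input
    // positions that fall outside the unpadded tensor.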
    fn conv1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel: &TensorHandle<T>,
        bias: Option<&TensorHandle<T>>,
        stride: usize,
        padding: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv1d"))?;
        let dense_kernel = kernel
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv1d kernel"))?;
        if stride == 0 {
            return Err(anyhow!("Stride must be positive"));
        }
        let x_shape = dense_x.shape();
        let k_shape = dense_kernel.shape();
        if x_shape.len() != 3 {
            return Err(anyhow!(
                "Expected 3D input tensor [batch, in_channels, length], got {:?}D",
                x_shape.len()
            ));
        }
        if k_shape.len() != 3 {
            return Err(anyhow!(
                "Expected 3D kernel tensor [out_channels, in_channels, kernel_size], got {:?}D",
                k_shape.len()
            ));
        }
        let (batch, in_channels, in_length) = (x_shape[0], x_shape[1], x_shape[2]);
        let (out_channels, k_in_channels, kernel_size) = (k_shape[0], k_shape[1], k_shape[2]);
        if in_channels != k_in_channels {
            return Err(anyhow!(
                "Input channels mismatch: input has {}, kernel expects {}",
                in_channels,
                k_in_channels
            ));
        }
        if let Some(bias_tensor) = bias {
            let bias_dense = bias_tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
            let bias_shape = bias_dense.shape();
            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
                return Err(anyhow!(
                    "Expected bias shape [{}], got {:?}",
                    out_channels,
                    bias_shape
                ));
            }
        }
        let (pad_left, pad_right) = padding;
        let padded_length = in_length + pad_left + pad_right;
        if kernel_size > padded_length {
            return Err(anyhow!(
                "Kernel size {} larger than padded input length {}",
                kernel_size,
                padded_length
            ));
        }
        let out_length = (padded_length - kernel_size) / stride + 1;
        let output_shape = [batch, out_channels, out_length];

        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear();
        output.resize(batch * out_channels * out_length, T::zero());

        let x_view = dense_x.view();
        let k_view = dense_kernel.view();
        for b in 0..batch {
            for oc in 0..out_channels {
                for o in 0..out_length {
                    let mut sum = T::zero();
                    let in_start = (o * stride) as isize - pad_left as isize;
                    for ic in 0..in_channels {
                        for k in 0..kernel_size {
                            let in_pos = in_start + k as isize;
                            if in_pos >= 0 && (in_pos as usize) < in_length {
                                let x_val = x_view[[b, ic, in_pos as usize]];
                                let k_val = k_view[[oc, ic, k]];
                                sum += x_val * k_val;
                            }
                        }
                    }
                    if let Some(bias_tensor) = bias {
                        let bias_dense = bias_tensor.as_dense().unwrap();
                        let bias_view = bias_dense.view();
                        sum += bias_view[[oc]];
                    }
                    output[b * out_channels * out_length + oc * out_length + o] = sum;
                }
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn conv2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel: &TensorHandle<T>,
        bias: Option<&TensorHandle<T>>,
        stride: (usize, usize),
        padding: (usize, usize, usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv2d"))?;
        let dense_kernel = kernel
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv2d kernel"))?;
        let (stride_h, stride_w) = stride;
        if stride_h == 0 || stride_w == 0 {
            return Err(anyhow!("Stride must be positive"));
        }
        let x_shape = dense_x.shape();
        let k_shape = dense_kernel.shape();
        if x_shape.len() != 4 {
            return Err(anyhow!(
                "Expected 4D input tensor [batch, in_channels, height, width], got {:?}D",
                x_shape.len()
            ));
        }
        if k_shape.len() != 4 {
            return Err(anyhow!(
                "Expected 4D kernel tensor [out_channels, in_channels, kernel_h, kernel_w], got {:?}D",
                k_shape.len()
            ));
        }
        let (batch, in_channels, in_h, in_w) = (x_shape[0], x_shape[1], x_shape[2], x_shape[3]);
        let (out_channels, k_in_channels, kernel_h, kernel_w) =
            (k_shape[0], k_shape[1], k_shape[2], k_shape[3]);
        if in_channels != k_in_channels {
            return Err(anyhow!(
                "Input channels mismatch: input has {}, kernel expects {}",
                in_channels,
                k_in_channels
            ));
        }
        if let Some(bias_tensor) = bias {
            let bias_dense = bias_tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
            let bias_shape = bias_dense.shape();
            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
                return Err(anyhow!(
                    "Expected bias shape [{}], got {:?}",
                    out_channels,
                    bias_shape
                ));
            }
        }
        let (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right) = padding;
        let padded_h = in_h + pad_h_top + pad_h_bottom;
        let padded_w = in_w + pad_w_left + pad_w_right;
        if kernel_h > padded_h || kernel_w > padded_w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than padded input ({}, {})",
                kernel_h,
                kernel_w,
                padded_h,
                padded_w
            ));
        }
        let out_h = (padded_h - kernel_h) / stride_h + 1;
        let out_w = (padded_w - kernel_w) / stride_w + 1;
        let output_shape = [batch, out_channels, out_h, out_w];

        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear();
        output.resize(batch * out_channels * out_h * out_w, T::zero());

        let x_view = dense_x.view();
        let k_view = dense_kernel.view();
        for b in 0..batch {
            for oc in 0..out_channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = T::zero();
                        let in_start_h = (oh * stride_h) as isize - pad_h_top as isize;
                        let in_start_w = (ow * stride_w) as isize - pad_w_left as isize;
                        for ic in 0..in_channels {
                            for kh in 0..kernel_h {
                                for kw in 0..kernel_w {
                                    let in_h_pos = in_start_h + kh as isize;
                                    let in_w_pos = in_start_w + kw as isize;
                                    if in_h_pos >= 0
                                        && (in_h_pos as usize) < in_h
                                        && in_w_pos >= 0
                                        && (in_w_pos as usize) < in_w
                                    {
                                        let x_val =
                                            x_view[[b, ic, in_h_pos as usize, in_w_pos as usize]];
                                        let k_val = k_view[[oc, ic, kh, kw]];
                                        sum += x_val * k_val;
                                    }
                                }
                            }
                        }
                        if let Some(bias_tensor) = bias {
                            let bias_dense = bias_tensor.as_dense().unwrap();
                            let bias_view = bias_dense.view();
                            sum += bias_view[[oc]];
                        }
                        let out_idx = ((b * out_channels + oc) * out_h + oh) * out_w + ow;
                        output[out_idx] = sum;
                    }
                }
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn gather(
        &mut self,
        x: &TensorHandle<T>,
        axis: Axis,
        indices: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for gather"))?;
        let dense_indices = indices
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for indices"))?;
        let x_shape = dense_x.shape();
        let axis_idx = axis;
        if axis_idx >= x_shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis_idx,
                x_shape.len()
            ));
        }
        let indices_shape = dense_indices.shape();
        let num_indices: usize = indices_shape.iter().product();
        let mut output_shape = Vec::new();
        for (i, &dim) in x_shape.iter().enumerate() {
            if i == axis_idx {
                output_shape.extend_from_slice(indices_shape);
            } else {
                output_shape.push(dim);
            }
        }
        let output_size: usize = output_shape.iter().product();
        let mut output = Vec::with_capacity(output_size);
        let x_view = dense_x.view();
        let indices_view = dense_indices.view();
        if axis_idx == 0 && indices_shape.len() == 1 {
            let axis_size = x_shape[0];
            let elements_per_item: usize = x_shape[1..].iter().product();
            for idx_flat in 0..num_indices {
                let idx_multi = self.flat_to_multidim(idx_flat, indices_shape);
                let idx_val = indices_view[idx_multi.as_slice()];
                let idx = idx_val
                    .to_usize()
                    .ok_or_else(|| anyhow!("Invalid index value: cannot convert to usize"))?;
                if idx >= axis_size {
                    return Err(anyhow!(
                        "Index {} out of bounds for axis {} with size {}",
                        idx,
                        axis_idx,
                        axis_size
                    ));
                }
                for elem_idx in 0..elements_per_item {
                    let mut x_index = vec![idx];
                    let elem_multi = self.flat_to_multidim(elem_idx, &x_shape[1..]);
                    x_index.extend(elem_multi);
                    output.push(x_view[x_index.as_slice()]);
                }
            }
        } else {
            return Err(anyhow!(
                "Gather only supports axis=0 with 1D indices in this implementation"
            ));
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn scatter(
        &mut self,
        shape: &[usize],
        axis: Axis,
        indices: &TensorHandle<T>,
        values: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let dense_indices = indices
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for indices"))?;
        let dense_values = values
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for values"))?;
        let axis_idx = axis;
        if axis_idx >= shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for output shape with {} dimensions",
                axis_idx,
                shape.len()
            ));
        }
        let indices_shape = dense_indices.shape();
        let values_shape = dense_values.shape();
        let num_indices: usize = indices_shape.iter().product();
        let mut expected_values_shape = Vec::new();
        expected_values_shape.extend_from_slice(&shape[..axis_idx]);
        expected_values_shape.extend_from_slice(indices_shape);
        expected_values_shape.extend_from_slice(&shape[axis_idx + 1..]);
        if values_shape != expected_values_shape.as_slice() {
            return Err(anyhow!(
                "Values shape {:?} doesn't match expected shape {:?}",
                values_shape,
                expected_values_shape
            ));
        }
        let output_size: usize = shape.iter().product();
        let mut output = vec![T::zero(); output_size];
        let indices_view = dense_indices.view();
        let values_view = dense_values.view();
        if axis_idx == 0 && indices_shape.len() == 1 {
            let axis_size = shape[0];
            let elements_per_item: usize = shape[1..].iter().product();
            for idx_flat in 0..num_indices {
                let idx_multi = self.flat_to_multidim(idx_flat, indices_shape);
                let idx_val = indices_view[idx_multi.as_slice()];
                let idx = idx_val
                    .to_usize()
                    .ok_or_else(|| anyhow!("Invalid index value: cannot convert to usize"))?;
                if idx >= axis_size {
                    return Err(anyhow!(
                        "Index {} out of bounds for axis {} with size {}",
                        idx,
                        axis_idx,
                        axis_size
                    ));
                }
                for elem_idx in 0..elements_per_item {
                    let mut values_index = vec![idx_flat];
                    let elem_multi = self.flat_to_multidim(elem_idx, &shape[1..]);
                    values_index.extend(elem_multi.clone());
                    let mut out_index = vec![idx];
                    out_index.extend(elem_multi);
                    let out_flat = self.multidim_to_flat(&out_index, shape);
                    output[out_flat] = values_view[values_index.as_slice()];
                }
            }
        } else {
            return Err(anyhow!(
                "Scatter only supports axis=0 with 1D indices in this implementation"
            ));
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(shape), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn conv3d(
        &mut self,
        x: &TensorHandle<T>,
        kernel: &TensorHandle<T>,
        bias: Option<&TensorHandle<T>>,
        stride: (usize, usize, usize),
        padding: (usize, usize, usize, usize, usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv3d"))?;
        let dense_kernel = kernel
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv3d kernel"))?;
        let (stride_d, stride_h, stride_w) = stride;
        if stride_d == 0 || stride_h == 0 || stride_w == 0 {
            return Err(anyhow!("Stride must be positive"));
        }
        let x_shape = dense_x.shape();
        let k_shape = dense_kernel.shape();
        if x_shape.len() != 5 {
            return Err(anyhow!(
                "Expected 5D input tensor [batch, in_channels, depth, height, width], got {:?}D",
                x_shape.len()
            ));
        }
        if k_shape.len() != 5 {
            return Err(anyhow!(
                "Expected 5D kernel tensor [out_channels, in_channels, kernel_d, kernel_h, kernel_w], got {:?}D",
                k_shape.len()
            ));
        }
        let (batch, in_channels, in_d, in_h, in_w) =
            (x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]);
        let (out_channels, k_in_channels, kernel_d, kernel_h, kernel_w) =
            (k_shape[0], k_shape[1], k_shape[2], k_shape[3], k_shape[4]);
        if in_channels != k_in_channels {
            return Err(anyhow!(
                "Input channels mismatch: input has {}, kernel expects {}",
                in_channels,
                k_in_channels
            ));
        }
        if let Some(bias_tensor) = bias {
            let bias_dense = bias_tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
            let bias_shape = bias_dense.shape();
            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
                return Err(anyhow!(
                    "Expected bias shape [{}], got {:?}",
                    out_channels,
                    bias_shape
                ));
            }
        }
        let (pad_d_front, pad_d_back, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right) = padding;
        let padded_d = in_d + pad_d_front + pad_d_back;
        let padded_h = in_h + pad_h_top + pad_h_bottom;
        let padded_w = in_w + pad_w_left + pad_w_right;
        if kernel_d > padded_d || kernel_h > padded_h || kernel_w > padded_w {
            return Err(anyhow!(
                "Kernel size ({}, {}, {}) larger than padded input ({}, {}, {})",
                kernel_d,
                kernel_h,
                kernel_w,
                padded_d,
                padded_h,
                padded_w
            ));
        }
        let out_d = (padded_d - kernel_d) / stride_d + 1;
        let out_h = (padded_h - kernel_h) / stride_h + 1;
        let out_w = (padded_w - kernel_w) / stride_w + 1;
        let output_shape = [batch, out_channels, out_d, out_h, out_w];

        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear();
        output.resize(batch * out_channels * out_d * out_h * out_w, T::zero());

        let x_view = dense_x.view();
        let k_view = dense_kernel.view();
        for b in 0..batch {
            for oc in 0..out_channels {
                for od in 0..out_d {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let mut sum = T::zero();
                            let in_start_d = (od * stride_d) as isize - pad_d_front as isize;
                            let in_start_h = (oh * stride_h) as isize - pad_h_top as isize;
                            let in_start_w = (ow * stride_w) as isize - pad_w_left as isize;
                            for ic in 0..in_channels {
                                for kd in 0..kernel_d {
                                    for kh in 0..kernel_h {
                                        for kw in 0..kernel_w {
                                            let in_d_pos = in_start_d + kd as isize;
                                            let in_h_pos = in_start_h + kh as isize;
                                            let in_w_pos = in_start_w + kw as isize;
                                            if in_d_pos >= 0
                                                && (in_d_pos as usize) < in_d
                                                && in_h_pos >= 0
                                                && (in_h_pos as usize) < in_h
                                                && in_w_pos >= 0
                                                && (in_w_pos as usize) < in_w
                                            {
                                                let x_val = x_view[[
                                                    b,
                                                    ic,
                                                    in_d_pos as usize,
                                                    in_h_pos as usize,
                                                    in_w_pos as usize,
                                                ]];
                                                let k_val = k_view[[oc, ic, kd, kh, kw]];
                                                sum += x_val * k_val;
                                            }
                                        }
                                    }
                                }
                            }
                            if let Some(bias_tensor) = bias {
                                let bias_dense = bias_tensor.as_dense().unwrap();
                                let bias_view = bias_dense.view();
                                sum += bias_view[[oc]];
                            }
                            let out_idx = ((((b * out_channels + oc) * out_d + od) * out_h + oh)
                                * out_w)
                                + ow;
                            output[out_idx] = sum;
                        }
                    }
                }
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
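    // determinant and matrix_inverse treat leading dimensions as a batch of
    // independent matrices; solve currently handles only a 2D A with a 1D b.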
    fn determinant(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for determinant"))?;
        let shape = dense.shape();
        if shape.len() < 2 {
            return Err(anyhow!("Input must be at least 2D for determinant"));
        }
        let n = shape[shape.len() - 1];
        let m = shape[shape.len() - 2];
        if n != m {
            return Err(anyhow!(
                "Last two dimensions must be square for determinant, got {}x{}",
                m,
                n
            ));
        }
        if shape.len() == 2 {
            use scirs2_core::ndarray_ext::Array2;
            let view = dense.view();
            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| view[[i, j]]);
            let det = self.compute_determinant_2d(&matrix)?;
            return Ok(TensorHandle::from_dense_auto(DenseND::from_vec(
                vec![det],
                &[],
            )?));
        }
        let batch_size: usize = shape[..shape.len() - 2].iter().product();
        let mut determinants = Vec::with_capacity(batch_size);
        let view = dense.view();
        for batch_idx in 0..batch_size {
            let batch_multi = self.flat_to_multidim(batch_idx, &shape[..shape.len() - 2]);
            use scirs2_core::ndarray_ext::Array2;
            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| {
                let mut idx = batch_multi.clone();
                idx.push(i);
                idx.push(j);
                view[idx.as_slice()]
            });
            let det = self.compute_determinant_2d(&matrix)?;
            determinants.push(det);
        }
        let output_shape = &shape[..shape.len() - 2];
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(output_shape), determinants)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn matrix_inverse(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for matrix_inverse"))?;
        let shape = dense.shape();
        if shape.len() < 2 {
            return Err(anyhow!("Input must be at least 2D for matrix inverse"));
        }
        let n = shape[shape.len() - 1];
        let m = shape[shape.len() - 2];
        if n != m {
            return Err(anyhow!(
                "Last two dimensions must be square for matrix inverse, got {}x{}",
                m,
                n
            ));
        }
        if shape.len() == 2 {
            use scirs2_core::ndarray_ext::Array2;
            let view = dense.view();
            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| view[[i, j]]);
            let inv = self.compute_inverse_2d(&matrix)?;
            let inv_dyn = inv.into_dyn();
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(inv_dyn)));
        }
        let batch_size: usize = shape[..shape.len() - 2].iter().product();
        let output_size = batch_size * n * n;
        let mut output = Vec::with_capacity(output_size);
        let view = dense.view();
        for batch_idx in 0..batch_size {
            let batch_multi = self.flat_to_multidim(batch_idx, &shape[..shape.len() - 2]);
            use scirs2_core::ndarray_ext::Array2;
            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| {
                let mut idx = batch_multi.clone();
                idx.push(i);
                idx.push(j);
                view[idx.as_slice()]
            });
            let inv = self.compute_inverse_2d(&matrix)?;
            output.extend(inv.iter().copied());
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(shape), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn solve(&mut self, a: &TensorHandle<T>, b: &TensorHandle<T>) -> Result<TensorHandle<T>> {
        let dense_a = a
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for solve (A)"))?;
        let dense_b = b
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for solve (b)"))?;
        let a_shape = dense_a.shape();
        let b_shape = dense_b.shape();
        if a_shape.len() < 2 {
            return Err(anyhow!("Matrix A must be at least 2D"));
        }
        if b_shape.is_empty() {
            return Err(anyhow!("Vector/matrix b must be at least 1D"));
        }
        let n = a_shape[a_shape.len() - 1];
        let m = a_shape[a_shape.len() - 2];
        if n != m {
            return Err(anyhow!("Matrix A must be square, got {}x{}", m, n));
        }
        let b_rows = if b_shape.len() == a_shape.len() - 1 {
            b_shape[b_shape.len() - 1]
        } else {
            b_shape[b_shape.len() - 2]
        };
        if b_rows != n {
            return Err(anyhow!(
                "Dimension mismatch: A is {}x{}, b has {} rows",
                m,
                n,
                b_rows
            ));
        }
        if a_shape.len() == 2 && b_shape.len() == 1 {
            use scirs2_core::ndarray_ext::{Array1, Array2};
            let a_view = dense_a.view();
            let b_view = dense_b.view();
            let a_matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| a_view[[i, j]]);
            let b_vector: Array1<T> = Array1::from_shape_fn(n, |i| b_view[[i]]);
            let x = self.solve_2d_1d(&a_matrix, &b_vector)?;
            let x_dyn = x.into_dyn();
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(x_dyn)));
        }
        Err(anyhow!(
            "Solve only supports 2D matrix A with 1D vector b in this implementation"
        ))
    }

    fn advanced_gather(
        &mut self,
        x: &TensorHandle<T>,
        axis: Axis,
        indices: &TensorHandle<T>,
        allow_negative: bool,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for advanced_gather"))?;
        let indices_dense = indices
            .as_dense()
            .ok_or_else(|| anyhow!("Indices must be dense tensor"))?;

        let result =
            super::advanced_indexing::advanced_gather(dense, axis, indices_dense, allow_negative)?;
        Ok(TensorHandle::from_dense_auto(result))
    }

    fn advanced_scatter(
        &mut self,
        shape: &[usize],
        axis: Axis,
        indices: &TensorHandle<T>,
        values: &TensorHandle<T>,
        mode: ScatterMode,
    ) -> Result<TensorHandle<T>> {
        let indices_dense = indices
            .as_dense()
            .ok_or_else(|| anyhow!("Indices must be dense tensor"))?;
        let values_dense = values
            .as_dense()
            .ok_or_else(|| anyhow!("Values must be dense tensor"))?;

        let result = super::advanced_indexing::advanced_scatter(
            shape,
            axis,
            indices_dense,
            values_dense,
            mode,
        )?;
        Ok(TensorHandle::from_dense_auto(result))
    }

    fn fancy_index_mask(
        &mut self,
        x: &TensorHandle<T>,
        mask: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for fancy_index_mask"))?;
        let mask_dense = mask
            .as_dense()
            .ok_or_else(|| anyhow!("Mask must be dense tensor"))?;

        let result = super::advanced_indexing::fancy_index_mask(dense, mask_dense)?;
        Ok(TensorHandle::from_dense_auto(result))
    }

    fn tile(&mut self, x: &TensorHandle<T>, reps: &[usize]) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for tile"))?;

        let input_shape = dense.shape();

        if reps.len() != input_shape.len() {
            return Err(anyhow!(
                "Reps length {} must match input dimensions {}",
                reps.len(),
                input_shape.len()
            ));
        }

        let output_shape: Vec<usize> = input_shape
            .iter()
            .zip(reps.iter())
            .map(|(&dim, &rep)| dim * rep)
            .collect();

        let input_view = dense.view();
        let output_size: usize = output_shape.iter().product();

        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
        output_data.clear();
        output_data.reserve(output_size);

        for i in 0..output_size {
            let out_idx = self.flat_to_multidim(i, &output_shape);
            let in_idx: Vec<usize> = out_idx
                .iter()
                .zip(input_shape.iter())
                .map(|(&o, &s)| o % s)
                .collect();
            output_data.push(input_view[in_idx.as_slice()]);
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
            .map_err(|e| anyhow!("Failed to create tiled array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output_data);

        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }

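    // pad fills the output with the constant value first, then copies the input
    // block into its offset position inside the padded output.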
1725 fn pad(
1726 &mut self,
1727 x: &TensorHandle<T>,
1728 pad_width: &[(usize, usize)],
1729 constant_value: T,
1730 ) -> Result<TensorHandle<T>> {
1731 let dense = x
1732 .as_dense()
1733 .ok_or_else(|| anyhow!("Only dense tensors supported for pad"))?;
1734
1735 let input_shape = dense.shape();
1736
1737 if pad_width.len() != input_shape.len() {
1738 return Err(anyhow!(
1739 "Pad width length {} must match input dimensions {}",
1740 pad_width.len(),
1741 input_shape.len()
1742 ));
1743 }
1744
1745 let output_shape: Vec<usize> = input_shape
1747 .iter()
1748 .zip(pad_width.iter())
1749 .map(|(&dim, &(before, after))| dim + before + after)
1750 .collect();
1751
1752 let input_view = dense.view();
1753 let output_size: usize = output_shape.iter().product();
1754
1755 let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
1757 output_data.clear(); output_data.resize(output_size, constant_value); let input_size: usize = input_shape.iter().product();
1762 for i in 0..input_size {
1763 let in_idx = self.flat_to_multidim(i, input_shape);
1764 let out_idx: Vec<usize> = in_idx
1766 .iter()
1767 .zip(pad_width.iter())
1768 .map(|(&idx, &(before, _))| idx + before)
1769 .collect();
1770 let out_flat = self.multidim_to_flat(&out_idx, &output_shape);
1771 output_data[out_flat] = input_view[in_idx.as_slice()];
1772 }
1773
1774 use scirs2_core::ndarray_ext::{Array, IxDyn};
1775 let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
1776 .map_err(|e| anyhow!("Failed to create padded array: {}", e))?;
1777 self.release_pooled_generic::<T>(&output_shape, output_data);
1778
1779 Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1780 result_array,
1781 )))
1782 }
1783
1784 fn flip(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>> {
1785 let dense = x
1786 .as_dense()
1787 .ok_or_else(|| anyhow!("Only dense tensors supported for flip"))?;
1788
1789 let shape = dense.shape();
1790
1791 for &axis in axes {
1793 if axis >= shape.len() {
1794 return Err(anyhow!(
1795 "Axis {} out of bounds for tensor with {} dimensions",
1796 axis,
1797 shape.len()
1798 ));
1799 }
1800 }
1801
1802 let input_view = dense.view();
1803 let total_elements: usize = shape.iter().product();
1804
1805 let mut output_data = self.acquire_pooled_generic::<T>(shape);
1807 output_data.clear(); output_data.reserve(total_elements);
1809
1810 for i in 0..total_elements {
1812 let out_idx = self.flat_to_multidim(i, shape);
1813 let in_idx: Vec<usize> = out_idx
1815 .iter()
1816 .enumerate()
1817 .map(|(axis, &idx)| {
1818 if axes.contains(&axis) {
1819 shape[axis] - 1 - idx
1821 } else {
1822 idx
1823 }
1824 })
1825 .collect();
1826 output_data.push(input_view[in_idx.as_slice()]);
1827 }
1828
1829 use scirs2_core::ndarray_ext::{Array, IxDyn};
1830 let result_array = Array::from_shape_vec(IxDyn(shape), output_data.clone())
1831 .map_err(|e| anyhow!("Failed to create flipped array: {}", e))?;
1832 self.release_pooled_generic::<T>(shape, output_data);
1833
1834 Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1835 result_array,
1836 )))
1837 }
1838
    fn squeeze(&mut self, x: &TensorHandle<T>, axes: Option<&[Axis]>) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for squeeze"))?;

        let shape = dense.shape();

        let axes_to_squeeze: Vec<usize> = if let Some(ax) = axes {
            for &axis in ax {
                if axis >= shape.len() {
                    return Err(anyhow!(
                        "Axis {} out of bounds for tensor with {} dimensions",
                        axis,
                        shape.len()
                    ));
                }
                if shape[axis] != 1 {
                    return Err(anyhow!(
                        "Cannot squeeze axis {} with size {}",
                        axis,
                        shape[axis]
                    ));
                }
            }
            ax.to_vec()
        } else {
            shape
                .iter()
                .enumerate()
                .filter_map(|(i, &s)| if s == 1 { Some(i) } else { None })
                .collect()
        };

        let new_shape: Vec<usize> = shape
            .iter()
            .enumerate()
            .filter_map(|(i, &s)| {
                if axes_to_squeeze.contains(&i) {
                    None
                } else {
                    Some(s)
                }
            })
            .collect();

        if new_shape.len() == shape.len() {
            return Ok(x.clone());
        }

        self.reshape(x, &new_shape)
    }

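    /// Inserts a new dimension of size 1 at position `axis`. `axis` may equal
    /// the current rank, which appends the new dimension at the end; e.g.
    /// shape `[2, 3]` unsqueezed at axis 2 becomes `[2, 3, 1]`. Implemented
    /// in terms of `reshape`.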
    fn unsqueeze(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for unsqueeze"))?;

        let shape = dense.shape();

        if axis > shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for unsqueeze (max {})",
                axis,
                shape.len()
            ));
        }

        let mut new_shape = shape.to_vec();
        new_shape.insert(axis, 1);

        self.reshape(x, &new_shape)
    }

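    /// Stacks tensors of identical shape along a new dimension at `axis` by
    /// unsqueezing each input and concatenating the results. Stacking `n`
    /// tensors of shape `[2, 3]` along axis 0 produces shape `[n, 2, 3]`.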
    fn stack(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>> {
        if tensors.is_empty() {
            return Err(anyhow!("Cannot stack empty sequence of tensors"));
        }

        let first_shape = tensors[0]
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for stack"))?
            .shape();

        for (i, tensor) in tensors.iter().enumerate().skip(1) {
            let shape = tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for stack"))?
                .shape();
            if shape != first_shape {
                return Err(anyhow!(
                    "All tensors must have the same shape for stacking. Tensor 0: {:?}, Tensor {}: {:?}",
                    first_shape,
                    i,
                    shape
                ));
            }
        }

        if axis > first_shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for stack (max {})",
                axis,
                first_shape.len()
            ));
        }

        let mut unsqueezed = Vec::new();
        for tensor in tensors {
            unsqueezed.push(self.unsqueeze(tensor, axis)?);
        }

        self.concatenate(&unsqueezed, axis)
    }

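    /// Repeats each element `repeats` times consecutively along `axis`
    /// (element-wise repetition, not tiling): output index `k` on `axis`
    /// reads input index `k / repeats`, so `[a, b]` with `repeats = 2` on
    /// axis 0 becomes `[a, a, b, b]`.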
    fn repeat(
        &mut self,
        x: &TensorHandle<T>,
        repeats: usize,
        axis: Axis,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for repeat"))?;

        let shape = dense.shape();

        if axis >= shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis,
                shape.len()
            ));
        }

        if repeats == 0 {
            return Err(anyhow!("Repeat count must be greater than 0"));
        }

        if repeats == 1 {
            return Ok(x.clone());
        }

        let mut output_shape = shape.to_vec();
        output_shape[axis] *= repeats;

        let input_view = dense.view();
        let total_elements: usize = output_shape.iter().product();
        let mut output_data = Vec::with_capacity(total_elements);

        for i in 0..total_elements {
            let out_idx = self.flat_to_multidim(i, &output_shape);
            let mut in_idx = out_idx.clone();
            in_idx[axis] /= repeats;
            output_data.push(input_view[in_idx.as_slice()]);
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data)
            .map_err(|e| anyhow!("Failed to create repeated array: {}", e))?;

        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }

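    /// Circularly shifts elements along `axis`. A positive `shift` moves
    /// elements toward higher indices, wrapping around the end; negative
    /// shifts and shifts larger than the axis length are normalized first.
    /// Rolling `[1, 2, 3, 4]` by `shift = 1` on axis 0 gives `[4, 1, 2, 3]`.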
    fn roll(&mut self, x: &TensorHandle<T>, shift: isize, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for roll"))?;

        let shape = dense.shape();

        if axis >= shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis,
                shape.len()
            ));
        }

        if shift == 0 {
            return Ok(x.clone());
        }

        let axis_size = shape[axis] as isize;
        // Guard against a zero-length axis, which would make the modulo below divide by zero.
        if axis_size == 0 {
            return Ok(x.clone());
        }
        let normalized_shift = ((shift % axis_size) + axis_size) % axis_size;

        if normalized_shift == 0 {
            return Ok(x.clone());
        }

        let input_view = dense.view();
        let total_elements: usize = shape.iter().product();
        let mut output_data = Vec::with_capacity(total_elements);

        for i in 0..total_elements {
            let out_idx = self.flat_to_multidim(i, shape);
            let mut in_idx = out_idx.clone();
            let old_idx = out_idx[axis] as isize;
            let new_idx = ((old_idx - normalized_shift + axis_size) % axis_size) as usize;
            in_idx[axis] = new_idx;
            output_data.push(input_view[in_idx.as_slice()]);
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(shape), output_data)
            .map_err(|e| anyhow!("Failed to create rolled array: {}", e))?;

        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }

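    /// Returns the index of the maximum value along `axis`, reducing that
    /// dimension away; indices are stored in `T` via `from_usize`. For a 1-D
    /// input the result is a 0-dimensional (scalar) tensor. Ties resolve to
    /// the first occurrence.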
    fn argmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for argmax"))?;

        let shape = dense.shape();

        if axis >= shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis,
                shape.len()
            ));
        }

        let output_shape: Vec<usize> = shape
            .iter()
            .enumerate()
            .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
            .collect();

        let input_view = dense.view();
        let output_size: usize = if output_shape.is_empty() {
            1
        } else {
            output_shape.iter().product()
        };
        let mut output_data = Vec::with_capacity(output_size);

        for i in 0..output_size {
            let base_idx = if output_shape.is_empty() {
                vec![]
            } else {
                self.flat_to_multidim(i, &output_shape)
            };

            let mut max_val = T::from_f64(f64::NEG_INFINITY).unwrap();
            let mut max_idx = 0usize;

            for j in 0..shape[axis] {
                let mut idx = Vec::new();
                let mut out_pos = 0;
                for (dim, &_size) in shape.iter().enumerate() {
                    if dim == axis {
                        idx.push(j);
                    } else {
                        idx.push(base_idx[out_pos]);
                        out_pos += 1;
                    }
                }
                let val = input_view[idx.as_slice()];
                if val > max_val {
                    max_val = val;
                    max_idx = j;
                }
            }

            output_data.push(T::from_usize(max_idx).unwrap());
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_shape = if output_shape.is_empty() {
            vec![]
        } else {
            output_shape
        };
        let result_array = Array::from_shape_vec(IxDyn(&result_shape), output_data)
            .map_err(|e| anyhow!("Failed to create argmax array: {}", e))?;

        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }

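    /// Returns the index of the minimum value along `axis`; the counterpart
    /// of `argmax` above, with the same output shape and tie-breaking
    /// behavior (first occurrence wins).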
    fn argmin(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for argmin"))?;

        let shape = dense.shape();

        if axis >= shape.len() {
            return Err(anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis,
                shape.len()
            ));
        }

        let output_shape: Vec<usize> = shape
            .iter()
            .enumerate()
            .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
            .collect();

        let input_view = dense.view();
        let output_size: usize = if output_shape.is_empty() {
            1
        } else {
            output_shape.iter().product()
        };
        let mut output_data = Vec::with_capacity(output_size);

        for i in 0..output_size {
            let base_idx = if output_shape.is_empty() {
                vec![]
            } else {
                self.flat_to_multidim(i, &output_shape)
            };

            let mut min_val = T::from_f64(f64::INFINITY).unwrap();
            let mut min_idx = 0usize;

            for j in 0..shape[axis] {
                let mut idx = Vec::new();
                let mut out_pos = 0;
                for (dim, &_size) in shape.iter().enumerate() {
                    if dim == axis {
                        idx.push(j);
                    } else {
                        idx.push(base_idx[out_pos]);
                        out_pos += 1;
                    }
                }
                let val = input_view[idx.as_slice()];
                if val < min_val {
                    min_val = val;
                    min_idx = j;
                }
            }

            output_data.push(T::from_usize(min_idx).unwrap());
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_shape = if output_shape.is_empty() {
            vec![]
        } else {
            output_shape
        };
        let result_array = Array::from_shape_vec(IxDyn(&result_shape), output_data)
            .map_err(|e| anyhow!("Failed to create argmin array: {}", e))?;

        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
}