1use super::aligned::AlignedVec;
2use super::error::TensorError;
3use super::shape::{
4 broadcast_offset, broadcast_shape, compute_strides, increment_coords, shape_element_count,
5};
6use super::simd;
7use super::tensor::Tensor;
8
9impl Tensor {
10 pub fn add(&self, rhs: &Self) -> Result<Self, TensorError> {
14 if self.shape() == rhs.shape() {
15 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Add);
16 }
17 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Add) {
18 return result;
19 }
20 self.binary_broadcast_op(rhs, |l, r| l + r)
21 }
22
23 pub fn sub(&self, rhs: &Self) -> Result<Self, TensorError> {
25 if self.shape() == rhs.shape() {
26 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Sub);
27 }
28 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Sub) {
29 return result;
30 }
31 self.binary_broadcast_op(rhs, |l, r| l - r)
32 }
33
34 pub fn mul(&self, rhs: &Self) -> Result<Self, TensorError> {
36 if self.shape() == rhs.shape() {
37 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Mul);
38 }
39 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Mul) {
40 return result;
41 }
42 self.binary_broadcast_op(rhs, |l, r| l * r)
43 }
44
45 pub fn div(&self, rhs: &Self) -> Result<Self, TensorError> {
47 if self.shape() == rhs.shape() {
48 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Div);
49 }
50 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Div) {
51 return result;
52 }
53 self.binary_broadcast_op(rhs, |l, r| l / r)
54 }
55
56 #[allow(unsafe_code)]
63 pub fn pow(&self, rhs: &Self) -> Result<Self, TensorError> {
64 let rhs_total: usize = rhs.shape().iter().product();
69 if rhs_total == 1 {
70 let exp_val = rhs.data()[0];
71 if exp_val == 2.0 {
72 return self.mul(self);
73 }
74 if exp_val == 0.5 {
75 return Ok(self.sqrt());
76 }
77 if exp_val == 1.0 {
78 return Ok(self.clone());
79 }
80 if exp_val == 0.0 {
81 return Tensor::ones(self.shape().to_vec());
82 }
83 if exp_val == -1.0 {
84 return Ok(self.reciprocal());
85 }
86 }
87 if self.shape() == rhs.shape() {
89 return self.pow_simd(rhs);
90 }
91 self.binary_broadcast_op(rhs, |l, r| l.powf(r))
92 }
93
    /// Same-shape power computed as `exp(ln(self) * rhs)` via SIMD kernels.
    ///
    /// NOTE(review): the log/exp identity only holds for positive bases;
    /// presumably negative `self` values produce NaN here — confirm this
    /// matches the intended `powf` semantics.
    #[allow(unsafe_code)]
    fn pow_simd(&self, rhs: &Self) -> Result<Self, TensorError> {
        let len = self.len();
        // Scratch buffers start uninitialized; each dispatch below is
        // expected to write every slot before it is read — TODO confirm.
        let mut ln_buf = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut ln_buf);
        let mut prod_buf = AlignedVec::<f32>::uninitialized(len);
        simd::binary_dispatch(&ln_buf, rhs.data(), &mut prod_buf, simd::BinaryKind::Mul);
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(&prod_buf, &mut out);
        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
109
110 #[allow(unsafe_code)]
115 pub fn atan2(&self, rhs: &Self) -> Result<Self, TensorError> {
116 if self.shape() == rhs.shape() {
117 return self.atan2_fast(rhs);
118 }
119 self.binary_broadcast_op(rhs, f32::atan2)
120 }
121
    /// Same-shape atan2 using `fast_atan2_scalar` (defined elsewhere in this
    /// module), manually unrolled by four to help the optimizer.
    #[allow(unsafe_code)]
    fn atan2_fast(&self, rhs: &Self) -> Result<Self, TensorError> {
        let y_data = self.data();
        let x_data = rhs.data();
        let len = self.len();
        // Uninitialized output; the two loops below together cover all `len`
        // elements, so every slot is written exactly once.
        let mut out = AlignedVec::<f32>::uninitialized(len);

        let mut i = 0;
        // Main body: four elements per iteration.
        while i + 4 <= len {
            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
            out[i + 1] = fast_atan2_scalar(y_data[i + 1], x_data[i + 1]);
            out[i + 2] = fast_atan2_scalar(y_data[i + 2], x_data[i + 2]);
            out[i + 3] = fast_atan2_scalar(y_data[i + 3], x_data[i + 3]);
            i += 4;
        }
        // Remainder (up to 3 trailing elements).
        while i < len {
            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
            i += 1;
        }

        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
150
151 pub fn minimum(&self, rhs: &Self) -> Result<Self, TensorError> {
153 if self.shape() == rhs.shape() {
154 return self.binary_same_shape(rhs, f32::min);
155 }
156 self.binary_broadcast_op(rhs, f32::min)
157 }
158
159 pub fn maximum(&self, rhs: &Self) -> Result<Self, TensorError> {
161 if self.shape() == rhs.shape() {
162 return self.binary_same_shape(rhs, f32::max);
163 }
164 self.binary_broadcast_op(rhs, f32::max)
165 }
166
    /// Element-wise negation (`-x`) via the SIMD unary kernel.
    pub fn neg(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Neg)
    }
173
    /// Element-wise absolute value via the SIMD unary kernel.
    pub fn abs(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Abs)
    }
178
    /// Element-wise natural exponential (`e^x`) via the SIMD dispatch kernel.
    ///
    /// The output buffer starts uninitialized; `exp_dispatch` is expected to
    /// write every element — TODO(review): confirm that contract.
    #[allow(unsafe_code)]
    pub fn exp(&self) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
189
    /// Element-wise natural logarithm via the SIMD dispatch kernel.
    ///
    /// The output buffer starts uninitialized; `ln_dispatch` is expected to
    /// write every element — TODO(review): confirm that contract.
    #[allow(unsafe_code)]
    pub fn ln(&self) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
200
    /// Element-wise square root via the SIMD unary kernel.
    pub fn sqrt(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Sqrt)
    }
205
    /// Element-wise reciprocal (`1/x`) via the SIMD unary kernel.
    pub fn reciprocal(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Recip)
    }
210
    /// Element-wise sign via the SIMD unary kernel.
    pub fn sign(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Sign)
    }
215
    /// Element-wise floor via the SIMD unary kernel.
    pub fn floor(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Floor)
    }
220
    /// Element-wise ceiling via the SIMD unary kernel.
    pub fn ceil(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Ceil)
    }
225
    /// Element-wise rounding via the SIMD unary kernel.
    pub fn round(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Round)
    }
230
    /// Element-wise sine via the SIMD dispatch kernel.
    ///
    /// Output starts uninitialized; `sin_dispatch` is expected to fill it.
    #[allow(unsafe_code)]
    pub fn sin(&self) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::sin_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
239
    /// Element-wise cosine via the SIMD dispatch kernel.
    ///
    /// Output starts uninitialized; `cos_dispatch` is expected to fill it.
    #[allow(unsafe_code)]
    pub fn cos(&self) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::cos_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
248
    /// Element-wise tangent via the SIMD dispatch kernel.
    ///
    /// Output starts uninitialized; `tan_dispatch` is expected to fill it.
    #[allow(unsafe_code)]
    pub fn tan(&self) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::tan_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
257
    /// Element-wise arcsine, applied per element (scalar path).
    pub fn asin(&self) -> Self {
        self.unary_op(f32::asin)
    }
262
    /// Element-wise arccosine, applied per element (scalar path).
    pub fn acos(&self) -> Self {
        self.unary_op(f32::acos)
    }
267
    /// Element-wise arctangent, applied per element (scalar path).
    pub fn atan(&self) -> Self {
        self.unary_op(f32::atan)
    }
272
    /// Element-wise hyperbolic sine, applied per element (scalar path).
    pub fn sinh(&self) -> Self {
        self.unary_op(f32::sinh)
    }
277
    /// Element-wise hyperbolic cosine, applied per element (scalar path).
    pub fn cosh(&self) -> Self {
        self.unary_op(f32::cosh)
    }
282
    /// Element-wise base-2 logarithm, applied per element (scalar path).
    pub fn log2(&self) -> Self {
        self.unary_op(f32::log2)
    }
287
    /// Element-wise base-10 logarithm, applied per element (scalar path).
    pub fn log10(&self) -> Self {
        self.unary_op(f32::log10)
    }
292
    /// Convert every element from radians to degrees.
    pub fn degrees(&self) -> Self {
        self.unary_op(|v| v.to_degrees())
    }
297
    /// Convert every element from degrees to radians.
    pub fn radians(&self) -> Self {
        self.unary_op(|v| v.to_radians())
    }
302
    /// Clamp every element into `[min, max]` via the SIMD dispatch kernel.
    ///
    /// Output starts uninitialized; `clamp_dispatch` is expected to fill it.
    #[allow(unsafe_code)]
    pub fn clamp(&self, min: f32, max: f32) -> Self {
        let len = self.len();
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::clamp_dispatch(self.data(), &mut out, min, max);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
311
    /// Multiply every element by `factor`.
    pub fn scale(&self, factor: f32) -> Self {
        self.unary_op(|v| v * factor)
    }
316
    /// Add `value` to every element.
    pub fn add_scalar(&self, value: f32) -> Self {
        self.unary_op(|v| v + value)
    }
321
    /// Sum of all elements via the SIMD reduction kernel.
    pub fn sum(&self) -> f32 {
        simd::sum_dispatch(self.data())
    }
328
329 pub fn mean(&self) -> f32 {
331 if self.is_empty() {
332 return f32::NAN;
333 }
334 self.sum() / self.len() as f32
335 }
336
    /// Maximum over all elements via the SIMD reduction kernel.
    /// NOTE(review): behavior on an empty tensor depends on `max_dispatch` — confirm.
    pub fn max_value(&self) -> f32 {
        simd::max_dispatch(self.data())
    }
341
    /// Minimum over all elements via the SIMD reduction kernel.
    /// NOTE(review): behavior on an empty tensor depends on `min_dispatch` — confirm.
    pub fn min_value(&self) -> f32 {
        simd::min_dispatch(self.data())
    }
346
347 pub fn argmax(&self) -> Option<usize> {
349 if self.is_empty() {
350 return None;
351 }
352 let mut best = 0;
353 let mut best_val = self.data()[0];
354 for (i, &v) in self.data().iter().enumerate().skip(1) {
355 if v > best_val {
356 best_val = v;
357 best = i;
358 }
359 }
360 Some(best)
361 }
362
363 pub fn argmin(&self) -> Option<usize> {
365 if self.is_empty() {
366 return None;
367 }
368 let mut best = 0;
369 let mut best_val = self.data()[0];
370 for (i, &v) in self.data().iter().enumerate().skip(1) {
371 if v < best_val {
372 best_val = v;
373 best = i;
374 }
375 }
376 Some(best)
377 }
378
379 pub fn var(&self) -> f32 {
381 if self.is_empty() {
382 return f32::NAN;
383 }
384 let m = self.mean();
385 self.data().iter().map(|&v| (v - m) * (v - m)).sum::<f32>() / self.len() as f32
386 }
387
    /// Population standard deviation: square root of `var()`.
    pub fn std_dev(&self) -> f32 {
        self.var().sqrt()
    }
392
393 pub fn sum_axis(&self, axis: usize) -> Result<Self, TensorError> {
395 let shape = self.shape();
396 let rank = shape.len();
397 if axis >= rank {
398 return Err(TensorError::InvalidAxis { axis, rank });
399 }
400
401 if rank == 2 && axis == 0 {
403 let (rows, cols) = (shape[0], shape[1]);
404 let data = self.data();
405 let mut out = vec![0.0f32; cols];
406 for row in 0..rows {
407 let row_start = row * cols;
408 simd::add_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
409 }
410 return Self::from_vec(vec![cols], out);
411 }
412
413 if rank == 2 && axis == 1 {
415 let (rows, cols) = (shape[0], shape[1]);
416 let data = self.data();
417 let mut out = Vec::with_capacity(rows);
418 for row in 0..rows {
419 out.push(simd::sum_dispatch(&data[row * cols..(row + 1) * cols]));
420 }
421 return Self::from_vec(vec![rows], out);
422 }
423
424 self.reduce_axis(axis, 0.0, |acc, v| acc + v)
425 }
426
427 pub fn mean_axis(&self, axis: usize) -> Result<Self, TensorError> {
429 if axis >= self.rank() {
430 return Err(TensorError::InvalidAxis {
431 axis,
432 rank: self.rank(),
433 });
434 }
435 let axis_len = self.shape()[axis] as f32;
436 let sum = self.sum_axis(axis)?;
437 Ok(sum.scale(1.0 / axis_len))
438 }
439
440 pub fn max_axis(&self, axis: usize) -> Result<Self, TensorError> {
442 let shape = self.shape();
443 let rank = shape.len();
444 if axis >= rank {
445 return Err(TensorError::InvalidAxis { axis, rank });
446 }
447
448 if rank == 2 && axis == 0 {
450 let (rows, cols) = (shape[0], shape[1]);
451 let data = self.data();
452 let mut out = data[..cols].to_vec();
453 for row in 1..rows {
454 let row_start = row * cols;
455 simd::max_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
456 }
457 return Self::from_vec(vec![cols], out);
458 }
459
460 if rank == 2 && axis == 1 {
462 let (rows, cols) = (shape[0], shape[1]);
463 let data = self.data();
464 let mut out = Vec::with_capacity(rows);
465 for row in 0..rows {
466 out.push(simd::max_dispatch(&data[row * cols..(row + 1) * cols]));
467 }
468 return Self::from_vec(vec![rows], out);
469 }
470
471 self.reduce_axis(axis, f32::NEG_INFINITY, f32::max)
472 }
473
474 pub fn min_axis(&self, axis: usize) -> Result<Self, TensorError> {
476 self.reduce_axis(axis, f32::INFINITY, f32::min)
477 }
478
479 pub fn var_axis(&self, axis: usize) -> Result<Self, TensorError> {
481 let m = self.mean_axis(axis)?;
482 let diff = self.sub(&m.unsqueeze(axis)?)?;
483 let sq = diff.mul(&diff)?;
484 sq.mean_axis(axis)
485 }
486
487 pub fn median(&self) -> f32 {
489 let mut sorted = self.data().to_vec();
490 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491 let n = sorted.len();
492 if n == 0 {
493 return 0.0;
494 }
495 if n % 2 == 1 {
496 sorted[n / 2]
497 } else {
498 (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
499 }
500 }
501
502 pub fn median_axis(&self, axis: usize) -> Result<Self, TensorError> {
504 let shape = self.shape();
505 let rank = shape.len();
506 if axis >= rank {
507 return Err(TensorError::InvalidAxis { axis, rank });
508 }
509 let axis_len = shape[axis];
510 let outer: usize = shape[..axis].iter().product();
511 let inner: usize = shape[axis + 1..].iter().product();
512 let mut new_shape = shape.to_vec();
513 new_shape.remove(axis);
514 if new_shape.is_empty() {
515 new_shape.push(1);
516 }
517 let data = self.data();
518 let mut result = Vec::with_capacity(outer * inner);
519 for o in 0..outer {
520 for i in 0..inner {
521 let mut vals: Vec<f32> = (0..axis_len)
522 .map(|a| data[o * axis_len * inner + a * inner + i])
523 .collect();
524 vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
525 let n = vals.len();
526 let med = if n % 2 == 1 {
527 vals[n / 2]
528 } else {
529 (vals[n / 2 - 1] + vals[n / 2]) / 2.0
530 };
531 result.push(med);
532 }
533 }
534 Self::from_vec(new_shape, result)
535 }
536
    /// Returns true if any element is nonzero.
    pub fn any(&self) -> bool {
        self.data().iter().any(|&v| v != 0.0)
    }
541
    /// Returns true if every element is nonzero (vacuously true when empty).
    pub fn all(&self) -> bool {
        self.data().iter().all(|&v| v != 0.0)
    }
546
547 pub fn quantile(&self, q: f32) -> f32 {
549 let mut sorted = self.data().to_vec();
550 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
551 let n = sorted.len();
552 if n == 0 {
553 return 0.0;
554 }
555 let idx = q * (n - 1) as f32;
556 let lo = idx.floor() as usize;
557 let hi = idx.ceil() as usize;
558 if lo == hi || hi >= n {
559 sorted[lo.min(n - 1)]
560 } else {
561 let frac = idx - lo as f32;
562 sorted[lo] * (1.0 - frac) + sorted[hi] * frac
563 }
564 }
565
    /// Transpose a rank-2 tensor using 8x8 cache-friendly tiling.
    ///
    /// The output buffer starts uninitialized; the tiled loops plus the two
    /// remainder loops together write every one of the `rows * cols` slots.
    ///
    /// # Errors
    /// `InvalidAxis` when the tensor is not rank 2.
    #[allow(unsafe_code)]
    pub fn transpose_2d(&self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::InvalidAxis {
                axis: 1,
                rank: self.rank(),
            });
        }
        let rows = self.shape()[0];
        let cols = self.shape()[1];
        let mut out_data = AlignedVec::<f32>::uninitialized(rows * cols);
        let src = self.data();

        const TILE: usize = 8;
        // Largest multiples of TILE that fit: the tiled region.
        let rr = rows / TILE * TILE;
        let cc = cols / TILE * TILE;

        for ii in (0..rr).step_by(TILE) {
            // Full 8x8 tiles.
            for jj in (0..cc).step_by(TILE) {
                for r in ii..ii + TILE {
                    for c in jj..jj + TILE {
                        out_data[c * rows + r] = src[r * cols + c];
                    }
                }
            }
            // Right edge: columns beyond the last full tile, same row band.
            for r in ii..ii + TILE {
                for c in cc..cols {
                    out_data[c * rows + r] = src[r * cols + c];
                }
            }
        }
        // Bottom edge: rows beyond the last full tile band, all columns.
        for r in rr..rows {
            for c in 0..cols {
                out_data[c * rows + r] = src[r * cols + c];
            }
        }

        Tensor::from_aligned(vec![cols, rows], out_data)
    }
616
    /// Reorder axes: output axis `dst` takes its extent and data from input
    /// axis `axes[dst]` (same convention as NumPy `transpose`).
    ///
    /// # Errors
    /// `InvalidIndexRank` when `axes.len() != rank`; `InvalidAxis` when an
    /// entry is out of range or `axes` is not a permutation; `SizeOverflow`
    /// when the element count or strides overflow.
    pub fn permute(&self, axes: &[usize]) -> Result<Self, TensorError> {
        if axes.len() != self.rank() {
            return Err(TensorError::InvalidIndexRank {
                expected: self.rank(),
                got: axes.len(),
            });
        }
        let rank = self.rank();
        // Validate that `axes` is a permutation of 0..rank.
        let mut seen = vec![false; rank];
        for &a in axes {
            if a >= rank {
                return Err(TensorError::InvalidAxis { axis: a, rank });
            }
            seen[a] = true;
        }
        if seen.iter().any(|&s| !s) {
            return Err(TensorError::InvalidAxis { axis: 0, rank });
        }

        let src_shape = self.shape();
        let mut out_shape = vec![0usize; rank];
        for (dst, &src_axis) in axes.iter().enumerate() {
            out_shape[dst] = src_shape[src_axis];
        }
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        let mut out_data = vec![0.0f32; out_count];

        // Walk the source in row-major order, scattering each value to the
        // offset its coordinates map to under the permutation.
        let mut in_coords = vec![0usize; rank];
        for &val in self.data().iter() {
            let mut out_offset = 0usize;
            for (dst_axis, &src_axis) in axes.iter().enumerate() {
                out_offset += in_coords[src_axis] * out_strides[dst_axis];
            }
            out_data[out_offset] = val;
            increment_coords(&mut in_coords, src_shape);
        }

        Tensor::from_vec(out_shape, out_data)
    }
663
664 pub fn unsqueeze(&self, axis: usize) -> Result<Self, TensorError> {
666 if axis > self.rank() {
667 return Err(TensorError::InvalidAxis {
668 axis,
669 rank: self.rank() + 1,
670 });
671 }
672 let mut new_shape = self.shape().to_vec();
673 new_shape.insert(axis, 1);
674 self.reshape(new_shape)
675 }
676
677 pub fn squeeze(&self, axis: usize) -> Result<Self, TensorError> {
679 if axis >= self.rank() {
680 return Err(TensorError::InvalidAxis {
681 axis,
682 rank: self.rank(),
683 });
684 }
685 if self.shape()[axis] != 1 {
686 return Err(TensorError::InvalidAxis {
687 axis,
688 rank: self.rank(),
689 });
690 }
691 let mut new_shape = self.shape().to_vec();
692 new_shape.remove(axis);
693 self.reshape(new_shape)
694 }
695
696 pub fn cat(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
699 if tensors.is_empty() {
700 return Err(TensorError::SizeMismatch {
701 shape: vec![],
702 data_len: 0,
703 });
704 }
705 let rank = tensors[0].rank();
706 if axis >= rank {
707 return Err(TensorError::InvalidAxis { axis, rank });
708 }
709 for t in &tensors[1..] {
710 if t.rank() != rank {
711 return Err(TensorError::ShapeMismatch {
712 left: tensors[0].shape().to_vec(),
713 right: t.shape().to_vec(),
714 });
715 }
716 for (a, (&d0, &di)) in tensors[0].shape().iter().zip(t.shape().iter()).enumerate() {
717 if a != axis && d0 != di {
718 return Err(TensorError::ShapeMismatch {
719 left: tensors[0].shape().to_vec(),
720 right: t.shape().to_vec(),
721 });
722 }
723 }
724 }
725
726 let mut out_shape = tensors[0].shape().to_vec();
727 out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
728 let out_count =
729 shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
730 shape: out_shape.clone(),
731 })?;
732
733 let outer: usize = out_shape[..axis].iter().product();
734 let inner: usize = out_shape[axis + 1..].iter().product();
735 let mut out_data = Vec::with_capacity(out_count);
736
737 for o in 0..outer {
738 for t in tensors {
739 let t_axis_len = t.shape()[axis];
740 let chunk_start = o * t_axis_len * inner;
741 let chunk_end = chunk_start + t_axis_len * inner;
742 out_data.extend_from_slice(&t.data()[chunk_start..chunk_end]);
743 }
744 }
745
746 Tensor::from_vec(out_shape, out_data)
747 }
748
749 pub fn stack(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
751 if tensors.is_empty() {
752 return Err(TensorError::SizeMismatch {
753 shape: vec![],
754 data_len: 0,
755 });
756 }
757 if axis > tensors[0].rank() {
758 return Err(TensorError::InvalidAxis {
759 axis,
760 rank: tensors[0].rank() + 1,
761 });
762 }
763 let expanded: Vec<Self> = tensors
764 .iter()
765 .map(|t| t.unsqueeze(axis))
766 .collect::<Result<_, _>>()?;
767 let refs: Vec<&Self> = expanded.iter().collect();
768 Self::cat(&refs, axis)
769 }
770
771 pub fn select(&self, axis: usize, index: usize) -> Result<Self, TensorError> {
773 if axis >= self.rank() {
774 return Err(TensorError::InvalidAxis {
775 axis,
776 rank: self.rank(),
777 });
778 }
779 if index >= self.shape()[axis] {
780 return Err(TensorError::IndexOutOfBounds {
781 axis,
782 index,
783 dim: self.shape()[axis],
784 });
785 }
786 let outer: usize = self.shape()[..axis].iter().product();
787 let axis_len = self.shape()[axis];
788 let inner: usize = self.shape()[axis + 1..].iter().product();
789
790 let mut out_data = Vec::with_capacity(outer * inner);
791 for o in 0..outer {
792 let base = o * axis_len * inner + index * inner;
793 out_data.extend_from_slice(&self.data()[base..base + inner]);
794 }
795
796 let mut out_shape = self.shape().to_vec();
797 out_shape.remove(axis);
798 Tensor::from_vec(out_shape, out_data)
799 }
800
801 pub fn narrow(&self, axis: usize, start: usize, length: usize) -> Result<Self, TensorError> {
803 if axis >= self.rank() {
804 return Err(TensorError::InvalidAxis {
805 axis,
806 rank: self.rank(),
807 });
808 }
809 if start + length > self.shape()[axis] {
810 return Err(TensorError::IndexOutOfBounds {
811 axis,
812 index: start + length,
813 dim: self.shape()[axis],
814 });
815 }
816 let outer: usize = self.shape()[..axis].iter().product();
817 let axis_len = self.shape()[axis];
818 let inner: usize = self.shape()[axis + 1..].iter().product();
819
820 let mut out_data = Vec::with_capacity(outer * length * inner);
821 for o in 0..outer {
822 let base = o * axis_len * inner + start * inner;
823 out_data.extend_from_slice(&self.data()[base..base + length * inner]);
824 }
825
826 let mut out_shape = self.shape().to_vec();
827 out_shape[axis] = length;
828 Tensor::from_vec(out_shape, out_data)
829 }
830
831 pub fn eq_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
835 if self.shape() == rhs.shape() {
836 return self.binary_same_shape(rhs, |l, r| {
837 if (l - r).abs() < f32::EPSILON {
838 1.0
839 } else {
840 0.0
841 }
842 });
843 }
844 self.binary_broadcast_op(rhs, |l, r| {
845 if (l - r).abs() < f32::EPSILON {
846 1.0
847 } else {
848 0.0
849 }
850 })
851 }
852
853 pub fn gt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
855 if self.shape() == rhs.shape() {
856 return self.binary_same_shape(rhs, |l, r| if l > r { 1.0 } else { 0.0 });
857 }
858 self.binary_broadcast_op(rhs, |l, r| if l > r { 1.0 } else { 0.0 })
859 }
860
861 pub fn lt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
863 if self.shape() == rhs.shape() {
864 return self.binary_same_shape(rhs, |l, r| if l < r { 1.0 } else { 0.0 });
865 }
866 self.binary_broadcast_op(rhs, |l, r| if l < r { 1.0 } else { 0.0 })
867 }
868
    /// Write the `self > rhs` mask into `output` without allocating.
    ///
    /// Shapes are only checked in debug builds; callers must supply three
    /// same-shaped tensors.
    pub fn gt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Gt,
        );
    }
881
    /// Write the equality mask into `output` without allocating.
    ///
    /// Shapes are only checked in debug builds. NOTE(review): this uses the
    /// SIMD `CmpKind::Eq` kernel, while `eq_tensor` uses an epsilon
    /// tolerance — confirm the two are intended to differ.
    pub fn eq_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Eq,
        );
    }
894
    /// Write the `self < rhs` mask into `output` without allocating.
    ///
    /// Shapes are only checked in debug builds; callers must supply three
    /// same-shaped tensors.
    pub fn lt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Lt,
        );
    }
907
908 pub fn ne_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
910 if self.shape() == rhs.shape() {
911 return self.binary_same_shape(
912 rhs,
913 |l, r| {
914 if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 }
915 },
916 );
917 }
918 self.binary_broadcast_op(rhs, |l, r| if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 })
919 }
920
921 pub fn le_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
923 if self.shape() == rhs.shape() {
924 return self.binary_same_shape(rhs, |l, r| if l <= r { 1.0 } else { 0.0 });
925 }
926 self.binary_broadcast_op(rhs, |l, r| if l <= r { 1.0 } else { 0.0 })
927 }
928
929 pub fn ge_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
931 if self.shape() == rhs.shape() {
932 return self.binary_same_shape(rhs, |l, r| if l >= r { 1.0 } else { 0.0 });
933 }
934 self.binary_broadcast_op(rhs, |l, r| if l >= r { 1.0 } else { 0.0 })
935 }
936
    /// Returns true when no element is NaN or infinite.
    pub fn all_finite(&self) -> bool {
        self.data().iter().all(|v| v.is_finite())
    }
941
942 pub fn where_cond(&self, condition: &Self, other: &Self) -> Result<Self, TensorError> {
947 if self.shape() != condition.shape() || self.shape() != other.shape() {
948 return Err(TensorError::ShapeMismatch {
949 left: self.shape().to_vec(),
950 right: condition.shape().to_vec(),
951 });
952 }
953 let data: Vec<f32> = condition
954 .data()
955 .iter()
956 .zip(self.data().iter())
957 .zip(other.data().iter())
958 .map(|((&c, &t), &f)| if c != 0.0 { t } else { f })
959 .collect();
960 Tensor::from_vec(self.shape().to_vec(), data)
961 }
962
963 pub fn masked_fill(&self, mask: &Self, value: f32) -> Result<Self, TensorError> {
965 if self.shape() != mask.shape() {
966 return Err(TensorError::ShapeMismatch {
967 left: self.shape().to_vec(),
968 right: mask.shape().to_vec(),
969 });
970 }
971 let data: Vec<f32> = self
972 .data()
973 .iter()
974 .zip(mask.data().iter())
975 .map(|(&v, &m)| if m != 0.0 { value } else { v })
976 .collect();
977 Tensor::from_vec(self.shape().to_vec(), data)
978 }
979
    /// Scatter `src` values into a copy of `self` along `axis`, directed by
    /// `index` (index and src must share a shape).
    ///
    /// Out-of-range indices are silently skipped rather than reported —
    /// NOTE(review): confirm that best-effort behavior is intended.
    /// Indices are f32 values truncated with `as usize`.
    pub fn scatter(&self, axis: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
        if index.shape() != src.shape() {
            return Err(TensorError::ShapeMismatch {
                left: index.shape().to_vec(),
                right: src.shape().to_vec(),
            });
        }
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        let mut out = self.data().to_vec();
        // Decompose `index`'s shape around the scatter axis.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let src_idx = (o * dim + d) * inner + i;
                    let target_d = index.data()[src_idx] as usize;
                    // Skip targets outside the destination axis or buffer.
                    if target_d < self_dim {
                        let out_idx = (o * self_dim + target_d) * self_inner + i;
                        if out_idx < out.len() {
                            out[out_idx] = src.data()[src_idx];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(self.shape().to_vec(), out)
    }
1019
    /// Gather values from `self` along `axis` at positions given by `index`;
    /// the result has `index`'s shape.
    ///
    /// Out-of-range indices leave 0.0 in the output instead of erroring —
    /// NOTE(review): confirm that best-effort behavior is intended.
    /// Indices are f32 values truncated with `as usize`.
    pub fn gather(&self, axis: usize, index: &Self) -> Result<Self, TensorError> {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Decompose `index`'s shape around the gather axis.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        let mut out = vec![0.0f32; index.len()];
        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let idx_pos = (o * dim + d) * inner + i;
                    let src_d = index.data()[idx_pos] as usize;
                    // Skip sources outside the axis or buffer; output stays 0.0.
                    if src_d < self_dim {
                        let src_pos = (o * self_dim + src_d) * self_inner + i;
                        if src_pos < self.len() {
                            out[idx_pos] = self.data()[src_pos];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
1052
1053 pub fn topk(&self, k: usize) -> Result<(Self, Self), TensorError> {
1055 if self.rank() == 0 {
1056 return Err(TensorError::InvalidAxis { axis: 0, rank: 0 });
1057 }
1058 let last_dim = *self.shape().last().expect("non-empty shape");
1059 let k = k.min(last_dim);
1060 let outer: usize = self.len() / last_dim;
1061
1062 let mut values = Vec::with_capacity(outer * k);
1063 let mut indices = Vec::with_capacity(outer * k);
1064
1065 for o in 0..outer {
1066 let start = o * last_dim;
1067 let slice = &self.data()[start..start + last_dim];
1068 let mut idx_vec: Vec<usize> = (0..last_dim).collect();
1069 idx_vec.sort_unstable_by(|&a, &b| {
1070 slice[b]
1071 .partial_cmp(&slice[a])
1072 .unwrap_or(std::cmp::Ordering::Equal)
1073 });
1074 for &i in &idx_vec[..k] {
1075 values.push(slice[i]);
1076 indices.push(i as f32);
1077 }
1078 }
1079
1080 let mut out_shape = self.shape().to_vec();
1081 *out_shape.last_mut().expect("non-empty shape") = k;
1082 let val_t = Tensor::from_vec(out_shape.clone(), values)?;
1083 let idx_t = Tensor::from_vec(out_shape, indices)?;
1084 Ok((val_t, idx_t))
1085 }
1086
1087 pub fn triu(&self, diagonal: i64) -> Result<Self, TensorError> {
1090 if self.rank() < 2 {
1091 return Err(TensorError::InvalidAxis {
1092 axis: 0,
1093 rank: self.rank(),
1094 });
1095 }
1096 let shape = self.shape();
1097 let rows = shape[shape.len() - 2];
1098 let cols = shape[shape.len() - 1];
1099 let batch: usize = shape[..shape.len() - 2].iter().product();
1100 let mut out = self.data().to_vec();
1101 for b in 0..batch {
1102 for r in 0..rows {
1103 for c in 0..cols {
1104 if (c as i64) < (r as i64) + diagonal {
1105 out[b * rows * cols + r * cols + c] = 0.0;
1106 }
1107 }
1108 }
1109 }
1110 Tensor::from_vec(shape.to_vec(), out)
1111 }
1112
1113 pub fn tril(&self, diagonal: i64) -> Result<Self, TensorError> {
1115 if self.rank() < 2 {
1116 return Err(TensorError::InvalidAxis {
1117 axis: 0,
1118 rank: self.rank(),
1119 });
1120 }
1121 let shape = self.shape();
1122 let rows = shape[shape.len() - 2];
1123 let cols = shape[shape.len() - 1];
1124 let batch: usize = shape[..shape.len() - 2].iter().product();
1125 let mut out = self.data().to_vec();
1126 for b in 0..batch {
1127 for r in 0..rows {
1128 for c in 0..cols {
1129 if (c as i64) > (r as i64) + diagonal {
1130 out[b * rows * cols + r * cols + c] = 0.0;
1131 }
1132 }
1133 }
1134 }
1135 Tensor::from_vec(shape.to_vec(), out)
1136 }
1137
1138 pub fn eye(n: usize) -> Result<Self, TensorError> {
1140 let mut data = vec![0.0f32; n * n];
1141 for i in 0..n {
1142 data[i * n + i] = 1.0;
1143 }
1144 Tensor::from_vec(vec![n, n], data)
1145 }
1146
1147 pub fn diag(vector: &Tensor) -> Result<Self, TensorError> {
1149 let shape = vector.shape();
1150 if shape.len() != 1 {
1151 return Err(TensorError::UnsupportedOperation {
1152 msg: format!("diag requires a 1D tensor, got shape {:?}", shape),
1153 });
1154 }
1155 let n = shape[0];
1156 let mut data = vec![0.0f32; n * n];
1157 for i in 0..n {
1158 data[i * n + i] = vector.data()[i];
1159 }
1160 Self::from_vec(vec![n, n], data)
1161 }
1162
1163 pub fn diag_extract(&self) -> Result<Self, TensorError> {
1165 let shape = self.shape();
1166 if shape.len() != 2 {
1167 return Err(TensorError::UnsupportedOperation {
1168 msg: format!("diag_extract requires a 2D tensor, got shape {:?}", shape),
1169 });
1170 }
1171 let n = shape[0].min(shape[1]);
1172 let cols = shape[1];
1173 let data: Vec<f32> = (0..n).map(|i| self.data()[i * cols + i]).collect();
1174 Self::from_vec(vec![n], data)
1175 }
1176
    /// Pad each axis with `(before, after)` amounts, filling new cells with
    /// `value`.
    ///
    /// # Errors
    /// `InvalidIndexRank` when `padding.len()` does not match the rank.
    pub fn pad(&self, padding: &[(usize, usize)], value: f32) -> Result<Self, TensorError> {
        let shape = self.shape();
        if padding.len() != shape.len() {
            return Err(TensorError::InvalidIndexRank {
                expected: shape.len(),
                got: padding.len(),
            });
        }
        // New extent per axis: original + before + after.
        let new_shape: Vec<usize> = shape
            .iter()
            .zip(padding)
            .map(|(&s, &(b, a))| s + b + a)
            .collect();
        let new_size: usize = new_shape.iter().product();
        let mut result = vec![value; new_size];
        let ndim = shape.len();

        // Row-major strides for the old and new shapes, computed locally.
        let mut old_strides = vec![1usize; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            old_strides[i] = old_strides[i + 1] * shape[i + 1];
        }
        let mut new_strides = vec![1usize; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
        }

        // Copy every source element to its shifted position: each coordinate
        // is offset by that axis's leading pad.
        let old_size: usize = shape.iter().product();
        let data = self.data();
        for flat_idx in 0..old_size {
            let mut remaining = flat_idx;
            let mut new_flat = 0;
            for d in 0..ndim {
                let coord = remaining / old_strides[d];
                remaining %= old_strides[d];
                new_flat += (coord + padding[d].0) * new_strides[d];
            }
            result[new_flat] = data[flat_idx];
        }

        Self::from_vec(new_shape, result)
    }
1220
1221 pub fn repeat(&self, counts: &[usize]) -> Result<Self, TensorError> {
1223 if counts.len() != self.rank() {
1224 return Err(TensorError::InvalidIndexRank {
1225 expected: self.rank(),
1226 got: counts.len(),
1227 });
1228 }
1229 let mut out = self.clone();
1230 for (axis, &count) in counts.iter().enumerate() {
1231 if count > 1 {
1232 let refs: Vec<&Tensor> = std::iter::repeat_n(&out, count).collect();
1233 out = Tensor::cat(&refs, axis)?;
1234 }
1235 }
1236 Ok(out)
1237 }
1238
1239 pub fn cumsum(&self, axis: usize) -> Result<Self, TensorError> {
1243 if axis >= self.rank() {
1244 return Err(TensorError::InvalidAxis {
1245 axis,
1246 rank: self.rank(),
1247 });
1248 }
1249 let shape = self.shape();
1250 let outer: usize = shape[..axis].iter().product();
1251 let axis_len = shape[axis];
1252 let inner: usize = shape[axis + 1..].iter().product();
1253 let mut out = self.data().to_vec();
1254
1255 for o in 0..outer {
1256 for i in 0..inner {
1257 let mut acc = 0.0f32;
1258 for d in 0..axis_len {
1259 let idx = (o * axis_len + d) * inner + i;
1260 acc += out[idx];
1261 out[idx] = acc;
1262 }
1263 }
1264 }
1265 Tensor::from_vec(shape.to_vec(), out)
1266 }
1267
1268 pub fn cumprod(&self, axis: usize) -> Result<Self, TensorError> {
1270 if axis >= self.rank() {
1271 return Err(TensorError::InvalidAxis {
1272 axis,
1273 rank: self.rank(),
1274 });
1275 }
1276 let shape = self.shape();
1277 let outer: usize = shape[..axis].iter().product();
1278 let axis_len = shape[axis];
1279 let inner: usize = shape[axis + 1..].iter().product();
1280 let mut out = self.data().to_vec();
1281
1282 for o in 0..outer {
1283 for i in 0..inner {
1284 let mut acc = 1.0f32;
1285 for d in 0..axis_len {
1286 let idx = (o * axis_len + d) * inner + i;
1287 acc *= out[idx];
1288 out[idx] = acc;
1289 }
1290 }
1291 }
1292 Tensor::from_vec(shape.to_vec(), out)
1293 }
1294
1295 pub fn to_fp16(&self) -> Vec<u16> {
1300 self.data().iter().map(|&v| f32_to_fp16(v)).collect()
1301 }
1302
1303 pub fn from_fp16(shape: Vec<usize>, fp16_data: &[u16]) -> Result<Self, TensorError> {
1305 let data: Vec<f32> = fp16_data.iter().map(|&v| fp16_to_f32(v)).collect();
1306 Tensor::from_vec(shape, data)
1307 }
1308
    /// Applies `op` to every element, producing a new tensor with the same
    /// shape and strides. The output buffer is allocated uninitialized and
    /// filled through raw pointers to skip the zero-fill pass.
    #[allow(unsafe_code)]
    fn unary_op<F>(&self, op: F) -> Self
    where
        F: Fn(f32) -> f32,
    {
        let src = self.data();
        let len = src.len();
        // Uninitialized on purpose: every slot is written exactly once below.
        let mut out_data = AlignedVec::<f32>::uninitialized(len);
        let inp = src.as_ptr();
        let outp = out_data.as_mut_ptr();
        // SAFETY: `inp` and `outp` both address `len` valid f32 slots, and
        // `i` never leaves `0..len`.
        unsafe {
            for i in 0..len {
                *outp.add(i) = op(*inp.add(i));
            }
        }
        Tensor::from_raw_parts(self.shape(), self.strides(), out_data)
    }
1330
1331 #[allow(unsafe_code)]
1332 fn unary_simd_op(&self, kind: simd::UnaryKind) -> Self {
1333 let len = self.len();
1334 let mut out = AlignedVec::<f32>::uninitialized(len);
1337 simd::unary_dispatch(self.data(), &mut out, kind);
1338 Tensor::from_raw_parts(self.shape(), self.strides(), out)
1339 }
1340
1341 fn reduce_axis<F>(&self, axis: usize, init: f32, combine: F) -> Result<Self, TensorError>
1342 where
1343 F: Fn(f32, f32) -> f32,
1344 {
1345 if axis >= self.rank() {
1346 return Err(TensorError::InvalidAxis {
1347 axis,
1348 rank: self.rank(),
1349 });
1350 }
1351
1352 let mut out_shape = self.shape().to_vec();
1353 out_shape.remove(axis);
1354 let out_count =
1355 shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
1356 shape: out_shape.clone(),
1357 })?;
1358 let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
1359 shape: out_shape.clone(),
1360 })?;
1361 let mut out_data = vec![init; out_count];
1362
1363 let mut in_coords = vec![0usize; self.rank()];
1364 for input in self.data().iter().copied() {
1365 let mut out_offset = 0usize;
1366 for (src_axis, coord) in in_coords.iter().copied().enumerate() {
1367 if src_axis == axis {
1368 continue;
1369 }
1370 let dst_axis = if src_axis < axis {
1371 src_axis
1372 } else {
1373 src_axis - 1
1374 };
1375 if !out_shape.is_empty() {
1376 out_offset += coord * out_strides[dst_axis];
1377 }
1378 }
1379 out_data[out_offset] = combine(out_data[out_offset], input);
1380 increment_coords(&mut in_coords, self.shape());
1381 }
1382
1383 Tensor::from_vec(out_shape, out_data)
1384 }
1385
1386 #[allow(unsafe_code)]
    /// Fast path for broadcasting a last-dimension vector against a full
    /// tensor (e.g. `[B, N] op [N]` or `[N] op [B, N]`), running the SIMD
    /// kernel row by row.
    ///
    /// Returns `None` when neither operand is a pure last-dim vector (all
    /// leading dims equal to 1) of matching trailing length, or when the
    /// trailing dim is 0 — the caller then falls back to the generic
    /// broadcast path.
    fn binary_broadcast_lastdim_simd(
        &self,
        rhs: &Self,
        kind: simd::BinaryKind,
    ) -> Option<Result<Self, TensorError>> {
        let lhs_shape = self.shape();
        let rhs_shape = rhs.shape();

        let lhs_last = *lhs_shape.last()?;
        // A zero-length row would make `num_rows` a division by zero below.
        if lhs_last == 0 {
            return None;
        }

        let rhs_last = *rhs_shape.last()?;

        // "last-dim vector": same trailing length as the other operand, and
        // every leading dimension equal to 1.
        let rhs_is_lastdim_vec =
            rhs_last == lhs_last && rhs_shape.iter().rev().skip(1).all(|&d| d == 1);
        let lhs_is_lastdim_vec =
            lhs_last == rhs_last && lhs_shape.iter().rev().skip(1).all(|&d| d == 1);

        if rhs_is_lastdim_vec && !lhs_is_lastdim_vec {
            // rhs is broadcast across every row of lhs.
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = lhs_last;
            let num_rows = lhs_data.len() / row_len;
            // Every slot is overwritten by the per-row dispatch below.
            let mut out_data = AlignedVec::<f32>::uninitialized(lhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[start..end],
                    &rhs_data[..row_len],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(lhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                lhs_shape,
                &out_strides,
                out_data,
            )))
        } else if lhs_is_lastdim_vec && !rhs_is_lastdim_vec {
            // Mirror case: lhs is broadcast across every row of rhs.
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = rhs_last;
            let num_rows = rhs_data.len() / row_len;
            let mut out_data = AlignedVec::<f32>::uninitialized(rhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[..row_len],
                    &rhs_data[start..end],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(rhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                rhs_shape,
                &out_strides,
                out_data,
            )))
        } else {
            // Both (or neither) are last-dim vectors: no fast path.
            None
        }
    }
1471
1472 fn binary_broadcast_op<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1473 where
1474 F: Fn(f32, f32) -> f32,
1475 {
1476 let out_shape = broadcast_shape(self.shape(), rhs.shape()).ok_or_else(|| {
1477 TensorError::BroadcastIncompatible {
1478 left: self.shape().to_vec(),
1479 right: rhs.shape().to_vec(),
1480 }
1481 })?;
1482
1483 let out_count =
1484 shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
1485 shape: out_shape.clone(),
1486 })?;
1487 let mut out_data = vec![0.0; out_count];
1488 let mut coords = vec![0usize; out_shape.len()];
1489
1490 for value in &mut out_data {
1491 let left_offset = broadcast_offset(self.shape(), self.strides(), &coords);
1492 let right_offset = broadcast_offset(rhs.shape(), rhs.strides(), &coords);
1493 *value = op(self.data()[left_offset], rhs.data()[right_offset]);
1494 increment_coords(&mut coords, &out_shape);
1495 }
1496
1497 Tensor::from_vec(out_shape, out_data)
1498 }
1499
1500 #[allow(unsafe_code)]
1502 #[allow(unsafe_code)]
1503 fn binary_same_shape_simd(
1504 &self,
1505 rhs: &Self,
1506 kind: simd::BinaryKind,
1507 ) -> Result<Self, TensorError> {
1508 let len = self.len();
1509 let mut out_data = AlignedVec::<f32>::uninitialized(len);
1510
1511 if len >= 100_000 {
1514 let n = std::thread::available_parallelism()
1515 .map(|p| p.get())
1516 .unwrap_or(4);
1517 let chunk = len.div_ceil(n);
1518 let lp = self.data().as_ptr() as usize;
1519 let rp = rhs.data().as_ptr() as usize;
1520 let op = out_data.as_mut_ptr() as usize;
1521
1522 #[cfg(target_os = "macos")]
1523 {
1524 use std::ffi::c_void;
1525 #[allow(unsafe_code)]
1526 unsafe extern "C" {
1527 fn dispatch_get_global_queue(id: isize, flags: usize) -> *const c_void;
1528 fn dispatch_apply_f(
1529 n: usize,
1530 q: *const c_void,
1531 ctx: *mut c_void,
1532 work: unsafe extern "C" fn(*mut c_void, usize),
1533 );
1534 }
1535 struct Ctx {
1536 lp: usize,
1537 rp: usize,
1538 op: usize,
1539 chunk: usize,
1540 len: usize,
1541 kind: simd::BinaryKind,
1542 }
1543 let ctx = Ctx {
1544 lp,
1545 rp,
1546 op,
1547 chunk,
1548 len,
1549 kind,
1550 };
1551 unsafe extern "C" fn work(raw: *mut c_void, t: usize) {
1552 let c = unsafe { &*(raw as *const Ctx) };
1553 let start = t * c.chunk;
1554 let end = (start + c.chunk).min(c.len);
1555 if start >= end {
1556 return;
1557 }
1558 let l = unsafe {
1559 std::slice::from_raw_parts((c.lp as *const f32).add(start), end - start)
1560 };
1561 let r = unsafe {
1562 std::slice::from_raw_parts((c.rp as *const f32).add(start), end - start)
1563 };
1564 let o = unsafe {
1565 std::slice::from_raw_parts_mut((c.op as *mut f32).add(start), end - start)
1566 };
1567 simd::binary_dispatch(l, r, o, c.kind);
1568 }
1569 let q = unsafe { dispatch_get_global_queue(0, 0) };
1570 unsafe {
1571 dispatch_apply_f(n, q, &ctx as *const Ctx as *mut c_void, work);
1572 }
1573 }
1574
1575 #[cfg(not(target_os = "macos"))]
1576 {
1577 use rayon::prelude::*;
1579 (0..n).into_par_iter().for_each(|t| {
1580 let start = t * chunk;
1581 let end = (start + chunk).min(len);
1582 if start >= end {
1583 return;
1584 }
1585 let l = unsafe {
1586 std::slice::from_raw_parts((lp as *const f32).add(start), end - start)
1587 };
1588 let r = unsafe {
1589 std::slice::from_raw_parts((rp as *const f32).add(start), end - start)
1590 };
1591 let o = unsafe {
1592 std::slice::from_raw_parts_mut((op as *mut f32).add(start), end - start)
1593 };
1594 simd::binary_dispatch(l, r, o, kind);
1595 });
1596 }
1597
1598 return Ok(Tensor::from_raw_parts(
1599 self.shape(),
1600 self.strides(),
1601 out_data,
1602 ));
1603 }
1604
1605 simd::binary_dispatch(self.data(), rhs.data(), &mut out_data, kind);
1606 Ok(Tensor::from_raw_parts(
1607 self.shape(),
1608 self.strides(),
1609 out_data,
1610 ))
1611 }
1612
1613 #[allow(unsafe_code)]
1614 fn binary_same_shape<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1615 where
1616 F: Fn(f32, f32) -> f32,
1617 {
1618 let len = self.len();
1619 let mut out_data = AlignedVec::<f32>::uninitialized(len);
1622
1623 let lhs_ptr = self.data().as_ptr();
1624 let rhs_ptr = rhs.data().as_ptr();
1625 let out_ptr = out_data.as_mut_ptr();
1626
1627 unsafe {
1632 let mut index = 0usize;
1633 while index + 4 <= len {
1634 *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1635 *out_ptr.add(index + 1) = op(*lhs_ptr.add(index + 1), *rhs_ptr.add(index + 1));
1636 *out_ptr.add(index + 2) = op(*lhs_ptr.add(index + 2), *rhs_ptr.add(index + 2));
1637 *out_ptr.add(index + 3) = op(*lhs_ptr.add(index + 3), *rhs_ptr.add(index + 3));
1638 index += 4;
1639 }
1640 while index < len {
1641 *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1642 index += 1;
1643 }
1644 }
1645
1646 Ok(Tensor::from_raw_parts(
1647 self.shape(),
1648 self.strides(),
1649 out_data,
1650 ))
1651 }
1652
    /// Applies ReLU in place via the SIMD kernel (negative elements become 0).
    pub fn relu_inplace(&mut self) {
        simd::relu_inplace_dispatch(self.data_mut());
    }
1659
    /// Element-wise `self += rhs` in place. Lengths must match; this is only
    /// checked in debug builds.
    pub fn add_inplace(&mut self, rhs: &Self) {
        debug_assert_eq!(self.len(), rhs.len());
        simd::add_inplace_dispatch(self.data_mut(), rhs.data());
    }
1665
    /// Adds the scalar `s` to every element in place.
    pub fn add_scalar_inplace(&mut self, s: f32) {
        simd::add_scalar_inplace_dispatch(self.data_mut(), s);
    }
1670
    /// Multiplies every element by the scalar `s` in place.
    pub fn mul_scalar_inplace(&mut self, s: f32) {
        simd::mul_scalar_inplace_dispatch(self.data_mut(), s);
    }
1675
    /// `output = lhs + rhs` element-wise, reusing `output`'s existing buffer.
    /// All three shapes must match; this is only checked in debug builds.
    pub fn add_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Add,
        );
    }
1690
    /// `output = lhs - rhs` element-wise, reusing `output`'s existing buffer.
    /// All three shapes must match; this is only checked in debug builds.
    pub fn sub_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Sub,
        );
    }
1703
    /// `output = lhs * rhs` element-wise, reusing `output`'s existing buffer.
    /// All three shapes must match; this is only checked in debug builds.
    pub fn mul_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Mul,
        );
    }
1716}
1717
/// Fast polynomial approximation of `atan2(y, x)`.
///
/// Strategy: fold the magnitude ratio into `[0, tan(pi/8)]` using the
/// identities `atan(z) = pi/2 - atan(1/z)` and
/// `atan(z) = pi/4 + atan((z - 1)/(z + 1))`, evaluate a degree-7 odd
/// polynomial, then undo the folding and fix the quadrant from the signs
/// of `x` and `y`. `atan2(0, 0)` returns 0.
#[allow(clippy::excessive_precision)]
#[inline(always)]
fn fast_atan2_scalar(y: f32, x: f32) -> f32 {
    const PI: f32 = std::f32::consts::PI;
    const FRAC_PI_2: f32 = std::f32::consts::FRAC_PI_2;
    const FRAC_PI_4: f32 = std::f32::consts::FRAC_PI_4;
    // tan(3*pi/8) and tan(pi/8): the octant-folding thresholds.
    const TAN_3PI_8: f32 = 2.414_213_6;
    const TAN_PI_8: f32 = 0.414_213_57;

    let abs_x = x.abs();
    let abs_y = y.abs();

    // Divide the smaller magnitude by the larger so the ratio is in [0, 1];
    // remember whether the arguments were swapped.
    let swapped = abs_x < abs_y;
    let (num, den) = if swapped { (abs_x, abs_y) } else { (abs_y, abs_x) };
    let ratio = if den > 0.0 { num / den } else { 0.0 };

    // Fold down to [-tan(pi/8), tan(pi/8)] and record the angle bias.
    let (t, bias) = if ratio > TAN_3PI_8 {
        (-1.0 / ratio, FRAC_PI_2)
    } else if ratio > TAN_PI_8 {
        ((ratio - 1.0) / (ratio + 1.0), FRAC_PI_4)
    } else {
        (ratio, 0.0)
    };

    // Odd polynomial: atan(t) ~ t + t^3 * P(t^2), Horner form via FMA.
    let t2 = t * t;
    let poly = 8.054_666e-02_f32
        .mul_add(t2, -1.384_895_1e-01)
        .mul_add(t2, 1.997_075_8e-01)
        .mul_add(t2, -3.333_129_8e-01);
    let folded = t.mul_add(t2 * poly, t) + bias;

    // Undo the swap, then place the angle in the correct quadrant.
    let first_quadrant = if swapped { FRAC_PI_2 - folded } else { folded };
    let signed_x = if x < 0.0 {
        PI - first_quadrant
    } else {
        first_quadrant
    };
    if y < 0.0 { -signed_x } else { signed_x }
}
1780
/// Converts an `f32` to its IEEE 754 binary16 bit pattern (mantissa is
/// truncated, not rounded to nearest).
///
/// Handling: NaN maps to a quiet-NaN payload, infinities and overflow to
/// signed infinity, magnitudes below 2^-24 flush to signed zero, and values
/// with unbiased exponent in [-24, -15] become half subnormals.
fn f32_to_fp16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // All-ones exponent: infinity or NaN; keep a nonzero mantissa bit so NaN
    // stays NaN after the conversion.
    if exponent == 255 {
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }

    let unbiased = exponent - 127;
    if unbiased > 15 {
        // Too large for half precision: saturate to infinity.
        return sign | 0x7C00;
    }
    if unbiased < -24 {
        // Below the smallest half subnormal: flush to signed zero.
        return sign;
    }
    if unbiased < -14 {
        // Subnormal range: express the value as m * 2^-24 by shifting the
        // 24-bit significand (implicit leading 1 restored) right by
        // `13 + shift` bits, 14..=23 total for `unbiased` in [-15, -24].
        //
        // BUG FIX: the shift was previously computed as `-1 - unbiased`,
        // making the total shift 27..=36 bits — a shift-overflow panic in
        // debug builds and incorrect (near-zero) subnormals in release.
        let shift = -14 - unbiased;
        let m = (mantissa | 0x0080_0000) >> (shift + 13);
        return sign | m as u16;
    }

    // Normal range: re-bias the exponent (127 -> 15) and truncate the
    // mantissa from 23 to 10 bits.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
1813
1814impl Tensor {
1815 pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self), TensorError> {
1821 if dim >= self.rank() {
1822 return Err(TensorError::InvalidAxis {
1823 axis: dim,
1824 rank: self.rank(),
1825 });
1826 }
1827 let shape = self.shape();
1828 let outer: usize = shape[..dim].iter().product();
1829 let dim_len = shape[dim];
1830 let inner: usize = shape[dim + 1..].iter().product();
1831 let data = self.data();
1832
1833 let mut out_vals = vec![0.0f32; data.len()];
1834 let mut out_idxs = vec![0.0f32; data.len()];
1835
1836 for o in 0..outer {
1837 for i in 0..inner {
1838 let mut idx_vec: Vec<usize> = (0..dim_len).collect();
1839 idx_vec.sort_unstable_by(|&a, &b| {
1840 let va = data[(o * dim_len + a) * inner + i];
1841 let vb = data[(o * dim_len + b) * inner + i];
1842 if descending {
1843 vb.partial_cmp(&va).unwrap_or(std::cmp::Ordering::Equal)
1844 } else {
1845 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
1846 }
1847 });
1848 for (rank, &src) in idx_vec.iter().enumerate() {
1849 let dst = (o * dim_len + rank) * inner + i;
1850 let src_pos = (o * dim_len + src) * inner + i;
1851 out_vals[dst] = data[src_pos];
1852 out_idxs[dst] = src as f32;
1853 }
1854 }
1855 }
1856
1857 let v = Tensor::from_vec(shape.to_vec(), out_vals)?;
1858 let idx = Tensor::from_vec(shape.to_vec(), out_idxs)?;
1859 Ok((v, idx))
1860 }
1861
1862 pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self, TensorError> {
1864 let (_, indices) = self.sort(dim, descending)?;
1865 Ok(indices)
1866 }
1867
1868 pub fn unique(&self) -> (Self, Self, Self) {
1870 let data = self.data();
1871 let mut sorted: Vec<f32> = data.to_vec();
1872 sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1873 sorted.dedup();
1874
1875 let mut inverse = vec![0.0f32; data.len()];
1876 let mut counts = vec![0.0f32; sorted.len()];
1877 for (i, &v) in data.iter().enumerate() {
1878 let pos = sorted
1879 .binary_search_by(|probe| {
1880 probe.partial_cmp(&v).unwrap_or(std::cmp::Ordering::Equal)
1881 })
1882 .expect("value exists in sorted list");
1883 inverse[i] = pos as f32;
1884 counts[pos] += 1.0;
1885 }
1886
1887 let vals = Tensor::from_vec(vec![sorted.len()], sorted).expect("unique vals");
1888 let inv = Tensor::from_vec(self.shape().to_vec(), inverse).expect("unique inv");
1889 let cnt = Tensor::from_vec(vec![counts.len()], counts).expect("unique counts");
1890 (vals, inv, cnt)
1891 }
1892
1893 pub fn nonzero(&self) -> Self {
1895 let shape = self.shape();
1896 let rank = shape.len().max(1);
1897 let data = self.data();
1898 let mut coords: Vec<Vec<usize>> = Vec::new();
1899
1900 if shape.is_empty() {
1901 if data[0] != 0.0 {
1903 coords.push(vec![0]);
1904 }
1905 } else {
1906 let mut idx = vec![0usize; shape.len()];
1907 for pos in 0..data.len() {
1908 if data[pos] != 0.0 {
1909 coords.push(idx.clone());
1910 }
1911 for d in (0..shape.len()).rev() {
1913 idx[d] += 1;
1914 if idx[d] < shape[d] {
1915 break;
1916 }
1917 idx[d] = 0;
1918 }
1919 }
1920 }
1921
1922 let n = coords.len();
1923 let mut flat = Vec::with_capacity(n * rank);
1924 for c in &coords {
1925 for &v in c {
1926 flat.push(v as f32);
1927 }
1928 }
1929 if n == 0 {
1930 Tensor::from_vec(vec![0, rank], flat).expect("nonzero empty")
1931 } else {
1932 Tensor::from_vec(vec![n, rank], flat).expect("nonzero")
1933 }
1934 }
1935
1936 pub fn flip(&self, dims: &[usize]) -> Result<Self, TensorError> {
1940 for &d in dims {
1941 if d >= self.rank() {
1942 return Err(TensorError::InvalidAxis {
1943 axis: d,
1944 rank: self.rank(),
1945 });
1946 }
1947 }
1948 let shape = self.shape();
1949 let data = self.data();
1950 let total = data.len();
1951 let mut out = vec![0.0f32; total];
1952 let rank = shape.len();
1953
1954 let mut src_idx = vec![0usize; rank];
1955 for pos in 0..total {
1956 let mut dst_idx = src_idx.clone();
1958 for &d in dims {
1959 dst_idx[d] = shape[d] - 1 - src_idx[d];
1960 }
1961 let mut dst_pos = 0;
1963 let mut stride = 1;
1964 for d in (0..rank).rev() {
1965 dst_pos += dst_idx[d] * stride;
1966 stride *= shape[d];
1967 }
1968 out[dst_pos] = data[pos];
1969
1970 for d in (0..rank).rev() {
1972 src_idx[d] += 1;
1973 if src_idx[d] < shape[d] {
1974 break;
1975 }
1976 src_idx[d] = 0;
1977 }
1978 }
1979 Tensor::from_vec(shape.to_vec(), out)
1980 }
1981
1982 pub fn roll(&self, shift: i64, dim: usize) -> Result<Self, TensorError> {
1984 if dim >= self.rank() {
1985 return Err(TensorError::InvalidAxis {
1986 axis: dim,
1987 rank: self.rank(),
1988 });
1989 }
1990 let shape = self.shape();
1991 let outer: usize = shape[..dim].iter().product();
1992 let dim_len = shape[dim];
1993 let inner: usize = shape[dim + 1..].iter().product();
1994 let data = self.data();
1995
1996 let mut out = vec![0.0f32; data.len()];
1997 for o in 0..outer {
1998 for d in 0..dim_len {
1999 let dst_d = ((d as i64 + shift).rem_euclid(dim_len as i64)) as usize;
2000 for i in 0..inner {
2001 out[(o * dim_len + dst_d) * inner + i] = data[(o * dim_len + d) * inner + i];
2002 }
2003 }
2004 }
2005 Tensor::from_vec(shape.to_vec(), out)
2006 }
2007
2008 pub fn linspace(start: f32, end: f32, steps: usize) -> Result<Self, TensorError> {
2012 if steps == 0 {
2013 return Tensor::from_vec(vec![0], vec![]);
2014 }
2015 if steps == 1 {
2016 return Tensor::from_vec(vec![1], vec![start]);
2017 }
2018 let step = (end - start) / (steps - 1) as f32;
2019 let data: Vec<f32> = (0..steps).map(|i| start + step * i as f32).collect();
2020 Tensor::from_vec(vec![steps], data)
2021 }
2022
2023 pub fn arange(start: f32, end: f32, step: f32) -> Result<Self, TensorError> {
2025 if step == 0.0 {
2026 return Err(TensorError::ShapeMismatch {
2027 left: vec![],
2028 right: vec![],
2029 });
2030 }
2031 let mut data = Vec::new();
2032 let mut v = start;
2033 if step > 0.0 {
2034 while v < end {
2035 data.push(v);
2036 v += step;
2037 }
2038 } else {
2039 while v > end {
2040 data.push(v);
2041 v += step;
2042 }
2043 }
2044 let n = data.len();
2045 Tensor::from_vec(vec![n], data)
2046 }
2047
2048 pub fn meshgrid(tensors: &[Self]) -> Result<Vec<Self>, TensorError> {
2050 let shape: Vec<usize> = tensors.iter().map(|t| t.len()).collect();
2051 let total: usize = shape.iter().product();
2052 let n = tensors.len();
2053 let mut result = Vec::with_capacity(n);
2054
2055 for (idx, t) in tensors.iter().enumerate() {
2056 let t_data = t.data();
2057 let mut out = vec![0.0f32; total];
2058 let inner: usize = shape[idx + 1..].iter().product();
2060 let outer: usize = shape[..idx].iter().product();
2061 let dim_len = shape[idx];
2062 for o in 0..outer {
2063 for d in 0..dim_len {
2064 for i in 0..inner {
2065 out[(o * dim_len + d) * inner + i] = t_data[d];
2066 }
2067 }
2068 }
2069 result.push(Tensor::from_vec(shape.clone(), out)?);
2070 }
2071 Ok(result)
2072 }
2073
2074 pub fn boolean_mask(&self, mask: &Self) -> Result<Self, TensorError> {
2078 if self.shape() != mask.shape() {
2079 return Err(TensorError::ShapeMismatch {
2080 left: self.shape().to_vec(),
2081 right: mask.shape().to_vec(),
2082 });
2083 }
2084 let data = self.data();
2085 let m = mask.data();
2086 let out: Vec<f32> = data
2087 .iter()
2088 .zip(m.iter())
2089 .filter(|(_, mv)| **mv != 0.0)
2090 .map(|(v, _)| *v)
2091 .collect();
2092 let n = out.len();
2093 Tensor::from_vec(vec![n], out)
2094 }
2095
2096 pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self, TensorError> {
2098 if dim >= self.rank() {
2099 return Err(TensorError::InvalidAxis {
2100 axis: dim,
2101 rank: self.rank(),
2102 });
2103 }
2104 let shape = self.shape();
2105 let idx_data = indices.data();
2106 let n_idx = idx_data.len();
2107 let outer: usize = shape[..dim].iter().product();
2108 let dim_len = shape[dim];
2109 let inner: usize = shape[dim + 1..].iter().product();
2110 let data = self.data();
2111
2112 let mut out = Vec::with_capacity(outer * n_idx * inner);
2113 for o in 0..outer {
2114 for &idx_f in idx_data {
2115 let idx = idx_f as usize;
2116 if idx >= dim_len {
2117 return Err(TensorError::IndexOutOfBounds {
2118 axis: dim,
2119 index: idx,
2120 dim: dim_len,
2121 });
2122 }
2123 let src_start = (o * dim_len + idx) * inner;
2124 out.extend_from_slice(&data[src_start..src_start + inner]);
2125 }
2126 }
2127
2128 let mut out_shape = shape.to_vec();
2129 out_shape[dim] = n_idx;
2130 Tensor::from_vec(out_shape, out)
2131 }
2132
2133 pub fn rand(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2137 let n: usize = shape.iter().product();
2138 let mut rng = seed;
2139 let data: Vec<f32> = (0..n)
2140 .map(|_| {
2141 rng ^= rng << 13;
2142 rng ^= rng >> 7;
2143 rng ^= rng << 17;
2144 (rng as f32) / (u64::MAX as f32)
2145 })
2146 .collect();
2147 Self::from_vec(shape, data)
2148 }
2149
2150 pub fn randn(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2153 let n: usize = shape.iter().product();
2154 let mut rng = seed;
2155 let mut next_rng = || -> f32 {
2156 rng ^= rng << 13;
2157 rng ^= rng >> 7;
2158 rng ^= rng << 17;
2159 ((rng as f64) / (u64::MAX as f64)).clamp(1e-15, 1.0 - 1e-15) as f32
2161 };
2162 let mut data = Vec::with_capacity(n);
2163 let mut i = 0;
2164 while i < n {
2165 let u1 = next_rng();
2166 let u2 = next_rng();
2167 let r = (-2.0 * (u1 as f64).ln()).sqrt();
2168 let theta = 2.0 * std::f64::consts::PI * u2 as f64;
2169 data.push((r * theta.cos()) as f32);
2170 i += 1;
2171 if i < n {
2172 data.push((r * theta.sin()) as f32);
2173 i += 1;
2174 }
2175 }
2176 Self::from_vec(shape, data)
2177 }
2178
2179 pub fn randint(shape: Vec<usize>, low: i64, high: i64, seed: u64) -> Result<Self, TensorError> {
2181 if high <= low {
2182 return Err(TensorError::UnsupportedOperation {
2183 msg: format!("randint requires high > low, got low={low}, high={high}"),
2184 });
2185 }
2186 let range = (high - low) as u64;
2187 let n: usize = shape.iter().product();
2188 let mut rng = seed;
2189 let data: Vec<f32> = (0..n)
2190 .map(|_| {
2191 rng ^= rng << 13;
2192 rng ^= rng >> 7;
2193 rng ^= rng << 17;
2194 (low + (rng % range) as i64) as f32
2195 })
2196 .collect();
2197 Self::from_vec(shape, data)
2198 }
2199
2200 pub fn randperm(n: usize, seed: u64) -> Result<Self, TensorError> {
2202 let mut perm: Vec<f32> = (0..n).map(|i| i as f32).collect();
2203 let mut rng = seed;
2204 for i in (1..n).rev() {
2205 rng ^= rng << 13;
2206 rng ^= rng >> 7;
2207 rng ^= rng << 17;
2208 let j = (rng as usize) % (i + 1);
2209 perm.swap(i, j);
2210 }
2211 Self::from_vec(vec![n], perm)
2212 }
2213}
2214
2215impl Tensor {
2218 pub fn step_slice(
2220 &self,
2221 dim: usize,
2222 start: usize,
2223 end: usize,
2224 step: usize,
2225 ) -> Result<Self, TensorError> {
2226 let rank = self.rank();
2227 if dim >= rank {
2228 return Err(TensorError::InvalidAxis { axis: dim, rank });
2229 }
2230 if step == 0 {
2231 return Err(TensorError::UnsupportedOperation {
2232 msg: "step must be > 0".to_string(),
2233 });
2234 }
2235 let shape = self.shape();
2236 let dim_len = shape[dim];
2237 let end = end.min(dim_len);
2238 if start >= end {
2239 let mut out_shape = shape.to_vec();
2241 out_shape[dim] = 0;
2242 return Tensor::from_vec(out_shape, vec![]);
2243 }
2244
2245 let selected_indices: Vec<usize> = (start..end).step_by(step).collect();
2246 let new_dim = selected_indices.len();
2247
2248 let outer: usize = shape[..dim].iter().product();
2249 let inner: usize = shape[dim + 1..].iter().product();
2250 let data = self.data();
2251
2252 let mut out = Vec::with_capacity(outer * new_dim * inner);
2253 for o in 0..outer {
2254 for &idx in &selected_indices {
2255 let src_start = (o * dim_len + idx) * inner;
2256 out.extend_from_slice(&data[src_start..src_start + inner]);
2257 }
2258 }
2259
2260 let mut out_shape = shape.to_vec();
2261 out_shape[dim] = new_dim;
2262 Tensor::from_vec(out_shape, out)
2263 }
2264
    /// Minimal `einsum` supporting a fixed set of common patterns:
    /// `ij,jk->ik` (matmul), `ij->ji` (transpose), `ii->i` (diagonal),
    /// `ij->i` / `ij->j` (row/column sums), `ij->` (total sum),
    /// `i,i->` (dot product) and `ij,ij->` (element-wise inner product).
    ///
    /// # Errors
    /// `UnsupportedOperation` for a malformed equation, an input-count
    /// mismatch, an unsupported pattern, or unmet rank requirements;
    /// `ShapeMismatch` when operand shapes disagree.
    pub fn einsum(equation: &str, tensors: &[&Tensor]) -> Result<Tensor, TensorError> {
        // Whitespace is insignificant in einsum notation.
        let equation = equation.replace(' ', "");
        let parts: Vec<&str> = equation.split("->").collect();
        if parts.len() != 2 {
            return Err(TensorError::UnsupportedOperation {
                msg: format!("invalid einsum equation: {equation}"),
            });
        }
        let inputs_str = parts[0];
        let output_str = parts[1];
        let input_parts: Vec<&str> = inputs_str.split(',').collect();

        if input_parts.len() != tensors.len() {
            return Err(TensorError::UnsupportedOperation {
                msg: format!(
                    "einsum equation has {} inputs but {} tensors provided",
                    input_parts.len(),
                    tensors.len()
                ),
            });
        }

        // NOTE(review): this reconstruction always equals the normalized
        // `equation` itself (split + rejoin on the same separators); kept
        // byte-identical here for safety.
        let pattern = format!(
            "{}->{}",
            input_parts
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<_>>()
                .join(","),
            output_str
        );

        match pattern.as_str() {
            // Matrix multiply: (m, k) x (k, n) -> (m, n), naive triple loop.
            "ij,jk->ik" => {
                let a = tensors[0];
                let b = tensors[1];
                if a.rank() != 2 || b.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij,jk->ik requires 2D tensors".to_string(),
                    });
                }
                let (m, k1) = (a.shape()[0], a.shape()[1]);
                let (k2, n) = (b.shape()[0], b.shape()[1]);
                if k1 != k2 {
                    return Err(TensorError::ShapeMismatch {
                        left: a.shape().to_vec(),
                        right: b.shape().to_vec(),
                    });
                }
                let ad = a.data();
                let bd = b.data();
                let mut out = vec![0.0f32; m * n];
                for i in 0..m {
                    for j in 0..n {
                        let mut sum = 0.0f32;
                        for p in 0..k1 {
                            sum += ad[i * k1 + p] * bd[p * n + j];
                        }
                        out[i * n + j] = sum;
                    }
                }
                Tensor::from_vec(vec![m, n], out)
            }
            // Transpose of a 2-D tensor.
            "ij->ji" => {
                let a = tensors[0];
                if a.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij->ji requires a 2D tensor".to_string(),
                    });
                }
                let (rows, cols) = (a.shape()[0], a.shape()[1]);
                let ad = a.data();
                let mut out = vec![0.0f32; rows * cols];
                for i in 0..rows {
                    for j in 0..cols {
                        out[j * rows + i] = ad[i * cols + j];
                    }
                }
                Tensor::from_vec(vec![cols, rows], out)
            }
            // Main diagonal of a square matrix.
            "ii->i" => {
                let a = tensors[0];
                if a.rank() != 2 || a.shape()[0] != a.shape()[1] {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ii->i requires a square 2D tensor".to_string(),
                    });
                }
                let n = a.shape()[0];
                let ad = a.data();
                let out: Vec<f32> = (0..n).map(|i| ad[i * n + i]).collect();
                Tensor::from_vec(vec![n], out)
            }
            // Row sums.
            "ij->i" => {
                let a = tensors[0];
                if a.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij->i requires a 2D tensor".to_string(),
                    });
                }
                let (rows, cols) = (a.shape()[0], a.shape()[1]);
                let ad = a.data();
                let out: Vec<f32> = (0..rows)
                    .map(|i| ad[i * cols..(i + 1) * cols].iter().sum())
                    .collect();
                Tensor::from_vec(vec![rows], out)
            }
            // Column sums.
            "ij->j" => {
                let a = tensors[0];
                if a.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij->j requires a 2D tensor".to_string(),
                    });
                }
                let (rows, cols) = (a.shape()[0], a.shape()[1]);
                let ad = a.data();
                let mut out = vec![0.0f32; cols];
                for i in 0..rows {
                    for j in 0..cols {
                        out[j] += ad[i * cols + j];
                    }
                }
                Tensor::from_vec(vec![cols], out)
            }
            // Total sum to a scalar.
            "ij->" => {
                let a = tensors[0];
                if a.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij-> requires a 2D tensor".to_string(),
                    });
                }
                let sum: f32 = a.data().iter().sum();
                Ok(Tensor::scalar(sum))
            }
            // Dot product of two vectors.
            "i,i->" => {
                let a = tensors[0];
                let b = tensors[1];
                if a.rank() != 1 || b.rank() != 1 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "i,i-> requires 1D tensors".to_string(),
                    });
                }
                if a.shape()[0] != b.shape()[0] {
                    return Err(TensorError::ShapeMismatch {
                        left: a.shape().to_vec(),
                        right: b.shape().to_vec(),
                    });
                }
                let sum: f32 = a
                    .data()
                    .iter()
                    .zip(b.data().iter())
                    .map(|(x, y)| x * y)
                    .sum();
                Ok(Tensor::scalar(sum))
            }
            // Element-wise inner product of two matrices.
            "ij,ij->" => {
                let a = tensors[0];
                let b = tensors[1];
                if a.rank() != 2 || b.rank() != 2 {
                    return Err(TensorError::UnsupportedOperation {
                        msg: "ij,ij-> requires 2D tensors".to_string(),
                    });
                }
                if a.shape() != b.shape() {
                    return Err(TensorError::ShapeMismatch {
                        left: a.shape().to_vec(),
                        right: b.shape().to_vec(),
                    });
                }
                let sum: f32 = a
                    .data()
                    .iter()
                    .zip(b.data().iter())
                    .map(|(x, y)| x * y)
                    .sum();
                Ok(Tensor::scalar(sum))
            }
            _ => Err(TensorError::UnsupportedOperation {
                msg: format!("unsupported einsum pattern: {pattern}"),
            }),
        }
    }
2467
2468 pub fn chunk(&self, n_chunks: usize, axis: usize) -> Result<Vec<Self>, TensorError> {
2472 if axis >= self.rank() {
2473 return Err(TensorError::InvalidAxis {
2474 axis,
2475 rank: self.rank(),
2476 });
2477 }
2478 if n_chunks == 0 {
2479 return Err(TensorError::UnsupportedOperation {
2480 msg: "n_chunks must be > 0".to_string(),
2481 });
2482 }
2483 let dim = self.shape()[axis];
2484 let chunk_size = dim.div_ceil(n_chunks); let mut chunks = Vec::new();
2486 let mut start = 0;
2487 while start < dim {
2488 let length = chunk_size.min(dim - start);
2489 chunks.push(self.narrow(axis, start, length)?);
2490 start += length;
2491 }
2492 Ok(chunks)
2493 }
2494
2495 pub fn histogram(&self, bins: usize, min: f32, max: f32) -> Result<Self, TensorError> {
2500 let mut counts = vec![0.0f32; bins];
2501 let range = max - min;
2502 for &v in self.data() {
2503 if v >= min && v <= max {
2504 let idx = if v == max {
2505 bins - 1
2506 } else {
2507 ((v - min) / range * bins as f32) as usize
2508 };
2509 counts[idx] += 1.0;
2510 }
2511 }
2512 Tensor::from_vec(vec![bins], counts)
2513 }
2514
2515 pub fn bincount(&self, num_bins: usize) -> Result<Self, TensorError> {
2520 let mut counts = vec![0.0f32; num_bins];
2521 for &v in self.data() {
2522 let idx = v as usize;
2523 if idx < num_bins {
2524 counts[idx] += 1.0;
2525 }
2526 }
2527 Tensor::from_vec(vec![num_bins], counts)
2528 }
2529
2530 pub fn item(&self) -> Result<f32, TensorError> {
2535 if self.len() != 1 {
2536 return Err(TensorError::ShapeMismatch {
2537 left: vec![1],
2538 right: self.shape().to_vec(),
2539 });
2540 }
2541 Ok(self.data()[0])
2542 }
2543
    /// True when the tensor holds exactly one element, regardless of rank.
    pub fn is_scalar(&self) -> bool {
        self.len() == 1
    }
2548
2549 pub fn scatter_add(&self, dim: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
2555 if dim >= self.rank() {
2556 return Err(TensorError::InvalidAxis {
2557 axis: dim,
2558 rank: self.rank(),
2559 });
2560 }
2561 if index.rank() != self.rank() {
2562 return Err(TensorError::InvalidIndexRank {
2563 expected: self.rank(),
2564 got: index.rank(),
2565 });
2566 }
2567 if src.shape() != index.shape() {
2568 return Err(TensorError::ShapeMismatch {
2569 left: src.shape().to_vec(),
2570 right: index.shape().to_vec(),
2571 });
2572 }
2573
2574 let self_shape = self.shape();
2575 let idx_shape = index.shape();
2576 let idx_data = index.data();
2577 let src_data = src.data();
2578 let ndim = self.rank();
2579
2580 let mut out = self.data().to_vec();
2581 let mut coords = vec![0usize; ndim];
2582
2583 for pos in 0..index.len() {
2584 let idx_val = idx_data[pos] as usize;
2585 if idx_val >= self_shape[dim] {
2586 return Err(TensorError::IndexOutOfBounds {
2587 axis: dim,
2588 index: idx_val,
2589 dim: self_shape[dim],
2590 });
2591 }
2592
2593 let mut dst_offset = 0;
2594 for d in 0..ndim {
2595 let c = if d == dim { idx_val } else { coords[d] };
2596 dst_offset += c * self.strides()[d];
2597 }
2598 out[dst_offset] += src_data[pos];
2599
2600 increment_coords(&mut coords, idx_shape);
2601 }
2602
2603 Tensor::from_vec(self_shape.to_vec(), out)
2604 }
2605}
2606
/// Converts an IEEE 754 half-precision (binary16) bit pattern to `f32`.
///
/// Handles every class: signed zero, subnormals, normal numbers,
/// infinities, and NaN (the NaN mantissa payload is carried over into
/// the top bits of the f32 mantissa).
fn fp16_to_f32(half: u16) -> f32 {
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;

    if exponent == 0 {
        if mantissa == 0 {
            // +0.0 or -0.0.
            return f32::from_bits(sign);
        }
        // Subnormal half: value = mantissa * 2^-24. Normalize by
        // shifting until the implicit leading bit (bit 10) is set.
        let mut m = mantissa;
        let mut e = 0i32;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        m &= 0x03FF;
        // After `e` shifts the value is 1.m * 2^(-14 - e), so the
        // biased f32 exponent is 127 - 14 - e. (The previous
        // `127 - 15 - e` was off by one and halved every subnormal:
        // 0x0001 decoded to 2^-25 instead of 2^-24.)
        let f32_exp = ((127 - 14 - e) as u32) << 23;
        let f32_man = m << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        let f32_exp = 0xFF << 23;
        let f32_man = mantissa << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }

    // Normal number: rebias the exponent from 15 (half) to 127 (f32).
    let f32_exp = ((exponent as i32 - 15 + 127) as u32 & 0xFF) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}