1use super::aligned::AlignedVec;
2use super::error::TensorError;
3use super::shape::{
4 broadcast_offset, broadcast_shape, compute_strides, increment_coords, shape_element_count,
5};
6use super::simd;
7use super::tensor::Tensor;
8
9impl Tensor {
10 pub fn add(&self, rhs: &Self) -> Result<Self, TensorError> {
14 if self.shape() == rhs.shape() {
15 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Add);
16 }
17 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Add) {
18 return result;
19 }
20 self.binary_broadcast_op(rhs, |l, r| l + r)
21 }
22
23 pub fn sub(&self, rhs: &Self) -> Result<Self, TensorError> {
25 if self.shape() == rhs.shape() {
26 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Sub);
27 }
28 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Sub) {
29 return result;
30 }
31 self.binary_broadcast_op(rhs, |l, r| l - r)
32 }
33
34 pub fn mul(&self, rhs: &Self) -> Result<Self, TensorError> {
36 if self.shape() == rhs.shape() {
37 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Mul);
38 }
39 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Mul) {
40 return result;
41 }
42 self.binary_broadcast_op(rhs, |l, r| l * r)
43 }
44
45 pub fn div(&self, rhs: &Self) -> Result<Self, TensorError> {
47 if self.shape() == rhs.shape() {
48 return self.binary_same_shape_simd(rhs, simd::BinaryKind::Div);
49 }
50 if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Div) {
51 return result;
52 }
53 self.binary_broadcast_op(rhs, |l, r| l / r)
54 }
55
56 #[allow(unsafe_code)]
63 pub fn pow(&self, rhs: &Self) -> Result<Self, TensorError> {
64 let rhs_total: usize = rhs.shape().iter().product();
69 if rhs_total == 1 {
70 let exp_val = rhs.data()[0];
71 if exp_val == 2.0 {
72 return self.mul(self);
73 }
74 if exp_val == 0.5 {
75 return Ok(self.sqrt());
76 }
77 if exp_val == 1.0 {
78 return Ok(self.clone());
79 }
80 if exp_val == 0.0 {
81 return Tensor::ones(self.shape().to_vec());
82 }
83 if exp_val == -1.0 {
84 return Ok(self.reciprocal());
85 }
86 }
87 if self.shape() == rhs.shape() {
89 return self.pow_simd(rhs);
90 }
91 self.binary_broadcast_op(rhs, |l, r| l.powf(r))
92 }
93
    /// SIMD path for same-shape `pow`: computes `exp(rhs * ln(self))`.
    ///
    /// NOTE(review): this identity diverges from `powf` for non-positive
    /// bases (ln is undefined there) — presumably an accepted speed/accuracy
    /// trade-off; confirm against callers.
    #[allow(unsafe_code)]
    fn pow_simd(&self, rhs: &Self) -> Result<Self, TensorError> {
        let len = self.len();
        // Each scratch buffer is fully overwritten by its dispatch call
        // before being read.
        let mut ln_buf = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut ln_buf);
        let mut prod_buf = AlignedVec::<f32>::uninitialized(len);
        simd::binary_dispatch(&ln_buf, rhs.data(), &mut prod_buf, simd::BinaryKind::Mul);
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(&prod_buf, &mut out);
        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
109
110 #[allow(unsafe_code)]
115 pub fn atan2(&self, rhs: &Self) -> Result<Self, TensorError> {
116 if self.shape() == rhs.shape() {
117 return self.atan2_fast(rhs);
118 }
119 self.binary_broadcast_op(rhs, f32::atan2)
120 }
121
122 #[allow(unsafe_code)]
126 fn atan2_fast(&self, rhs: &Self) -> Result<Self, TensorError> {
127 let y_data = self.data();
128 let x_data = rhs.data();
129 let len = self.len();
130 let mut out = AlignedVec::<f32>::uninitialized(len);
131
132 let mut i = 0;
136 while i + 4 <= len {
137 out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
138 out[i + 1] = fast_atan2_scalar(y_data[i + 1], x_data[i + 1]);
139 out[i + 2] = fast_atan2_scalar(y_data[i + 2], x_data[i + 2]);
140 out[i + 3] = fast_atan2_scalar(y_data[i + 3], x_data[i + 3]);
141 i += 4;
142 }
143 while i < len {
144 out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
145 i += 1;
146 }
147
148 Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
149 }
150
151 pub fn minimum(&self, rhs: &Self) -> Result<Self, TensorError> {
153 if self.shape() == rhs.shape() {
154 return self.binary_same_shape(rhs, f32::min);
155 }
156 self.binary_broadcast_op(rhs, f32::min)
157 }
158
159 pub fn maximum(&self, rhs: &Self) -> Result<Self, TensorError> {
161 if self.shape() == rhs.shape() {
162 return self.binary_same_shape(rhs, f32::max);
163 }
164 self.binary_broadcast_op(rhs, f32::max)
165 }
166
    /// Element-wise negation via the SIMD unary kernel.
    pub fn neg(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Neg)
    }
173
    /// Element-wise absolute value via the SIMD unary kernel.
    pub fn abs(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Abs)
    }
178
    /// Element-wise `e^x` via the SIMD exp kernel.
    #[allow(unsafe_code)]
    pub fn exp(&self) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
189
    /// Element-wise natural logarithm via the SIMD ln kernel.
    #[allow(unsafe_code)]
    pub fn ln(&self) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
200
    /// Element-wise square root via the SIMD unary kernel.
    pub fn sqrt(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Sqrt)
    }
205
    /// Element-wise reciprocal (1/x) via the SIMD unary kernel.
    pub fn reciprocal(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Recip)
    }
210
    /// Element-wise sign via the SIMD unary kernel.
    // NOTE(review): exact outputs for 0.0/-0.0/NaN are defined by the
    // `Sign` kernel, not visible here — confirm in the simd module.
    pub fn sign(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Sign)
    }
215
    /// Element-wise floor via the SIMD unary kernel.
    pub fn floor(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Floor)
    }
220
    /// Element-wise ceiling via the SIMD unary kernel.
    pub fn ceil(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Ceil)
    }
225
    /// Element-wise rounding via the SIMD unary kernel.
    // NOTE(review): whether ties round half-away-from-zero or to-even is
    // defined by the `Round` kernel — confirm in the simd module.
    pub fn round(&self) -> Self {
        self.unary_simd_op(simd::UnaryKind::Round)
    }
230
    /// Element-wise sine via the SIMD sin kernel.
    #[allow(unsafe_code)]
    pub fn sin(&self) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::sin_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
239
    /// Element-wise cosine via the SIMD cos kernel.
    #[allow(unsafe_code)]
    pub fn cos(&self) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::cos_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
248
    /// Element-wise tangent via the SIMD tan kernel.
    #[allow(unsafe_code)]
    pub fn tan(&self) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::tan_dispatch(self.data(), &mut out);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
257
258 pub fn asin(&self) -> Self {
260 self.unary_op(f32::asin)
261 }
262
263 pub fn acos(&self) -> Self {
265 self.unary_op(f32::acos)
266 }
267
268 pub fn atan(&self) -> Self {
270 self.unary_op(f32::atan)
271 }
272
273 pub fn sinh(&self) -> Self {
275 self.unary_op(f32::sinh)
276 }
277
278 pub fn cosh(&self) -> Self {
280 self.unary_op(f32::cosh)
281 }
282
283 pub fn log2(&self) -> Self {
285 self.unary_op(f32::log2)
286 }
287
288 pub fn log10(&self) -> Self {
290 self.unary_op(f32::log10)
291 }
292
293 pub fn degrees(&self) -> Self {
295 self.unary_op(|v| v.to_degrees())
296 }
297
298 pub fn radians(&self) -> Self {
300 self.unary_op(|v| v.to_radians())
301 }
302
    /// Clamp every element to `[min, max]` via the SIMD clamp kernel.
    // NOTE(review): behavior when min > max or for NaN inputs is defined by
    // the clamp kernel, not visible here — confirm in the simd module.
    #[allow(unsafe_code)]
    pub fn clamp(&self, min: f32, max: f32) -> Self {
        let len = self.len();
        // Buffer is fully overwritten by the dispatch call before use.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::clamp_dispatch(self.data(), &mut out, min, max);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
311
312 pub fn scale(&self, factor: f32) -> Self {
314 self.unary_op(|v| v * factor)
315 }
316
317 pub fn add_scalar(&self, value: f32) -> Self {
319 self.unary_op(|v| v + value)
320 }
321
    /// Sum of all elements via the SIMD reduction kernel.
    pub fn sum(&self) -> f32 {
        simd::sum_dispatch(self.data())
    }
328
329 pub fn mean(&self) -> f32 {
331 if self.is_empty() {
332 return f32::NAN;
333 }
334 self.sum() / self.len() as f32
335 }
336
    /// Maximum element via the SIMD reduction kernel.
    // NOTE(review): result for an empty tensor is defined by the kernel —
    // confirm in the simd module before relying on it.
    pub fn max_value(&self) -> f32 {
        simd::max_dispatch(self.data())
    }
341
    /// Minimum element via the SIMD reduction kernel.
    // NOTE(review): result for an empty tensor is defined by the kernel —
    // confirm in the simd module before relying on it.
    pub fn min_value(&self) -> f32 {
        simd::min_dispatch(self.data())
    }
346
347 pub fn argmax(&self) -> Option<usize> {
349 if self.is_empty() {
350 return None;
351 }
352 let mut best = 0;
353 let mut best_val = self.data()[0];
354 for (i, &v) in self.data().iter().enumerate().skip(1) {
355 if v > best_val {
356 best_val = v;
357 best = i;
358 }
359 }
360 Some(best)
361 }
362
363 pub fn argmin(&self) -> Option<usize> {
365 if self.is_empty() {
366 return None;
367 }
368 let mut best = 0;
369 let mut best_val = self.data()[0];
370 for (i, &v) in self.data().iter().enumerate().skip(1) {
371 if v < best_val {
372 best_val = v;
373 best = i;
374 }
375 }
376 Some(best)
377 }
378
379 pub fn var(&self) -> f32 {
381 if self.is_empty() {
382 return f32::NAN;
383 }
384 let m = self.mean();
385 self.data().iter().map(|&v| (v - m) * (v - m)).sum::<f32>() / self.len() as f32
386 }
387
    /// Population standard deviation: square root of `var`; NaN when empty.
    pub fn std_dev(&self) -> f32 {
        self.var().sqrt()
    }
392
    /// Sum along `axis`, removing that dimension from the result.
    ///
    /// Rank-2 tensors take SIMD fast paths (column totals via in-place add,
    /// row totals via the horizontal sum kernel); other ranks use the
    /// generic `reduce_axis` fold seeded with 0.0.
    ///
    /// # Errors
    /// `InvalidAxis` when `axis >= rank`.
    pub fn sum_axis(&self, axis: usize) -> Result<Self, TensorError> {
        let shape = self.shape();
        let rank = shape.len();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }

        // Fast path: matrix column sums (accumulate each row into `out`).
        if rank == 2 && axis == 0 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            let mut out = vec![0.0f32; cols];
            for row in 0..rows {
                let row_start = row * cols;
                simd::add_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
            }
            return Self::from_vec(vec![cols], out);
        }

        // Fast path: matrix row sums (horizontal reduction per row).
        if rank == 2 && axis == 1 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            let mut out = Vec::with_capacity(rows);
            for row in 0..rows {
                out.push(simd::sum_dispatch(&data[row * cols..(row + 1) * cols]));
            }
            return Self::from_vec(vec![rows], out);
        }

        self.reduce_axis(axis, 0.0, |acc, v| acc + v)
    }
426
427 pub fn mean_axis(&self, axis: usize) -> Result<Self, TensorError> {
429 if axis >= self.rank() {
430 return Err(TensorError::InvalidAxis {
431 axis,
432 rank: self.rank(),
433 });
434 }
435 let axis_len = self.shape()[axis] as f32;
436 let sum = self.sum_axis(axis)?;
437 Ok(sum.scale(1.0 / axis_len))
438 }
439
440 pub fn max_axis(&self, axis: usize) -> Result<Self, TensorError> {
442 let shape = self.shape();
443 let rank = shape.len();
444 if axis >= rank {
445 return Err(TensorError::InvalidAxis { axis, rank });
446 }
447
448 if rank == 2 && axis == 0 {
450 let (rows, cols) = (shape[0], shape[1]);
451 let data = self.data();
452 let mut out = data[..cols].to_vec();
453 for row in 1..rows {
454 let row_start = row * cols;
455 simd::max_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
456 }
457 return Self::from_vec(vec![cols], out);
458 }
459
460 if rank == 2 && axis == 1 {
462 let (rows, cols) = (shape[0], shape[1]);
463 let data = self.data();
464 let mut out = Vec::with_capacity(rows);
465 for row in 0..rows {
466 out.push(simd::max_dispatch(&data[row * cols..(row + 1) * cols]));
467 }
468 return Self::from_vec(vec![rows], out);
469 }
470
471 self.reduce_axis(axis, f32::NEG_INFINITY, f32::max)
472 }
473
    /// Minimum along `axis`: generic fold seeded with `+inf`.
    // NOTE(review): unlike sum_axis/max_axis there is no rank-2 SIMD fast
    // path here — presumably just not implemented yet.
    pub fn min_axis(&self, axis: usize) -> Result<Self, TensorError> {
        self.reduce_axis(axis, f32::INFINITY, f32::min)
    }
478
479 pub fn var_axis(&self, axis: usize) -> Result<Self, TensorError> {
481 let m = self.mean_axis(axis)?;
482 let diff = self.sub(&m.unsqueeze(axis)?)?;
483 let sq = diff.mul(&diff)?;
484 sq.mean_axis(axis)
485 }
486
487 pub fn median(&self) -> f32 {
489 let mut sorted = self.data().to_vec();
490 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491 let n = sorted.len();
492 if n == 0 {
493 return 0.0;
494 }
495 if n % 2 == 1 {
496 sorted[n / 2]
497 } else {
498 (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
499 }
500 }
501
    /// Median along `axis`, removing that dimension (scalar result keeps
    /// shape [1]). Gathers each axis-slice with strided indexing, sorts it,
    /// and takes the middle value (mean of the two middles for even length).
    ///
    /// # Errors
    /// `InvalidAxis` when `axis >= rank`.
    pub fn median_axis(&self, axis: usize) -> Result<Self, TensorError> {
        let shape = self.shape();
        let rank = shape.len();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }
        let axis_len = shape[axis];
        // Row-major decomposition: index = ((o * axis_len) + a) * inner + i.
        let outer: usize = shape[..axis].iter().product();
        let inner: usize = shape[axis + 1..].iter().product();
        let mut new_shape = shape.to_vec();
        new_shape.remove(axis);
        if new_shape.is_empty() {
            new_shape.push(1);
        }
        let data = self.data();
        let mut result = Vec::with_capacity(outer * inner);
        for o in 0..outer {
            for i in 0..inner {
                // Collect the axis-slice at (o, ·, i).
                let mut vals: Vec<f32> = (0..axis_len)
                    .map(|a| data[o * axis_len * inner + a * inner + i])
                    .collect();
                // NaNs compare as Equal, so their sorted position is unspecified.
                vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                let n = vals.len();
                let med = if n % 2 == 1 {
                    vals[n / 2]
                } else {
                    (vals[n / 2 - 1] + vals[n / 2]) / 2.0
                };
                result.push(med);
            }
        }
        Self::from_vec(new_shape, result)
    }
536
    /// True if any element is non-zero (NaN counts: `NaN != 0.0` is true).
    pub fn any(&self) -> bool {
        self.data().iter().any(|&v| v != 0.0)
    }
541
    /// True if every element is non-zero; vacuously true when empty.
    pub fn all(&self) -> bool {
        self.data().iter().all(|&v| v != 0.0)
    }
546
547 pub fn quantile(&self, q: f32) -> f32 {
549 let mut sorted = self.data().to_vec();
550 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
551 let n = sorted.len();
552 if n == 0 {
553 return 0.0;
554 }
555 let idx = q * (n - 1) as f32;
556 let lo = idx.floor() as usize;
557 let hi = idx.ceil() as usize;
558 if lo == hi || hi >= n {
559 sorted[lo.min(n - 1)]
560 } else {
561 let frac = idx - lo as f32;
562 sorted[lo] * (1.0 - frac) + sorted[hi] * frac
563 }
564 }
565
    /// Transpose a rank-2 tensor using an 8x8 tiled loop for cache locality.
    ///
    /// # Errors
    /// `InvalidAxis` when the tensor is not rank 2.
    #[allow(unsafe_code)]
    pub fn transpose_2d(&self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::InvalidAxis {
                axis: 1,
                rank: self.rank(),
            });
        }
        let rows = self.shape()[0];
        let cols = self.shape()[1];
        // Every slot is written exactly once by the loops below.
        let mut out_data = AlignedVec::<f32>::uninitialized(rows * cols);
        let src = self.data();

        const TILE: usize = 8;
        // rr/cc: largest multiples of TILE that fit; the remainder is
        // handled by the edge loops.
        let rr = rows / TILE * TILE;
        let cc = cols / TILE * TILE;

        // Full TILE x TILE blocks.
        for ii in (0..rr).step_by(TILE) {
            for jj in (0..cc).step_by(TILE) {
                for r in ii..ii + TILE {
                    for c in jj..jj + TILE {
                        out_data[c * rows + r] = src[r * cols + c];
                    }
                }
            }
            // Right edge: columns past the last full tile, same row band.
            for r in ii..ii + TILE {
                for c in cc..cols {
                    out_data[c * rows + r] = src[r * cols + c];
                }
            }
        }
        // Bottom edge: rows past the last full tile, all columns.
        for r in rr..rows {
            for c in 0..cols {
                out_data[c * rows + r] = src[r * cols + c];
            }
        }

        Tensor::from_aligned(vec![cols, rows], out_data)
    }
616
    /// Reorder dimensions: output axis `d` takes its length and data from
    /// input axis `axes[d]`. `axes` must be a permutation of `0..rank`.
    ///
    /// Common permutations (NHWC<->NCHW, batched 2-D transposes, plain 2-D
    /// transpose) are handled by tiled unsafe fast paths that write through
    /// raw pointers; everything else uses a generic coordinate walk.
    ///
    /// # Errors
    /// `InvalidIndexRank` when `axes.len() != rank`; `InvalidAxis` when an
    /// entry is out of range or the set is not a permutation;
    /// `SizeOverflow` when the element count overflows.
    pub fn permute(&self, axes: &[usize]) -> Result<Self, TensorError> {
        if axes.len() != self.rank() {
            return Err(TensorError::InvalidIndexRank {
                expected: self.rank(),
                got: axes.len(),
            });
        }
        let rank = self.rank();
        // Validate that `axes` is a permutation: every value in range and
        // (since lengths match) every axis seen exactly once.
        let mut seen = vec![false; rank];
        for &a in axes {
            if a >= rank {
                return Err(TensorError::InvalidAxis { axis: a, rank });
            }
            seen[a] = true;
        }
        if seen.iter().any(|&s| !s) {
            return Err(TensorError::InvalidAxis { axis: 0, rank });
        }

        let src_shape = self.shape();
        let mut out_shape = vec![0usize; rank];
        for (dst, &src_axis) in axes.iter().enumerate() {
            out_shape[dst] = src_shape[src_axis];
        }
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;

        // Fast path: NHWC -> NCHW. Per batch, a tiled (hw x c) transpose.
        if rank == 4 && axes == [0, 3, 1, 2] {
            let (n, h, w, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
            let hw = h * w;
            let src = self.data();
            // Every slot is written exactly once by the loops below.
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            // SAFETY(review): all offsets stay within batch * hw * c ==
            // out_count elements; bounds follow from the shape arithmetic.
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for batch in 0..n {
                    let s_base = src_ptr.add(batch * hw * c);
                    let d_base = dst_ptr.add(batch * c * hw);
                    for i0 in (0..hw).step_by(TILE) {
                        let ie = (i0 + TILE).min(hw);
                        for j0 in (0..c).step_by(TILE) {
                            let je = (j0 + TILE).min(c);
                            for i in i0..ie {
                                let s_row = s_base.add(i * c);
                                for j in j0..je {
                                    *d_base.add(j * hw + i) = *s_row.add(j);
                                }
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }
        // Fast path: NCHW -> NHWC. Per batch, a tiled (c x hw) transpose.
        if rank == 4 && axes == [0, 2, 3, 1] {
            let (n, c, h, w) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
            let hw = h * w;
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for batch in 0..n {
                    let s_base = src_ptr.add(batch * c * hw);
                    let d_base = dst_ptr.add(batch * hw * c);
                    for i0 in (0..c).step_by(TILE) {
                        let ie = (i0 + TILE).min(c);
                        for j0 in (0..hw).step_by(TILE) {
                            let je = (j0 + TILE).min(hw);
                            for i in i0..ie {
                                let s_row = s_base.add(i * hw);
                                for j in j0..je {
                                    *d_base.add(j * c + i) = *s_row.add(j);
                                }
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }
        // Fast path: batched 2-D transpose of the last two of three dims.
        if rank == 3 && axes == [0, 2, 1] {
            let (a, b, c) = (src_shape[0], src_shape[1], src_shape[2]);
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for batch in 0..a {
                    let s_base = src_ptr.add(batch * b * c);
                    let d_base = dst_ptr.add(batch * c * b);
                    for i0 in (0..b).step_by(TILE) {
                        let ie = (i0 + TILE).min(b);
                        for j0 in (0..c).step_by(TILE) {
                            let je = (j0 + TILE).min(c);
                            for i in i0..ie {
                                let s_row = s_base.add(i * c);
                                for j in j0..je {
                                    *d_base.add(j * b + i) = *s_row.add(j);
                                }
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }

        // Fast path: swap the last two of four dims; batch and first dim
        // share the same base offset on both sides.
        if rank == 4 && axes == [0, 1, 3, 2] {
            let (nn, a, b, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for n in 0..nn {
                    for aa in 0..a {
                        let base = (n * a + aa) * b * c;
                        let s_base = src_ptr.add(base);
                        let d_base = dst_ptr.add(base); for i0 in (0..b).step_by(TILE) {
                            let ie = (i0 + TILE).min(b);
                            for j0 in (0..c).step_by(TILE) {
                                let je = (j0 + TILE).min(c);
                                for i in i0..ie {
                                    let s_row = s_base.add(i * c);
                                    for j in j0..je {
                                        *d_base.add(j * b + i) = *s_row.add(j);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }
        // Fast path: swap middle dims; contiguous length-c rows move whole,
        // so memcpy each row instead of element-wise stores.
        if rank == 4 && axes == [0, 2, 1, 3] {
            let (nn, a, b, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for n in 0..nn {
                    let s_batch = src_ptr.add(n * a * b * c);
                    let d_batch = dst_ptr.add(n * b * a * c);
                    for aa in 0..a {
                        for bb in 0..b {
                            std::ptr::copy_nonoverlapping(
                                s_batch.add(aa * b * c + bb * c),
                                d_batch.add(bb * a * c + aa * c),
                                c,
                            );
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }
        // Fast path: reverse the last three dims (swap axes 1 and 3,
        // keeping axis 2); tiled over the (a x d) plane per fixed b.
        if rank == 4 && axes == [0, 3, 2, 1] {
            let (nn, a, b, d) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            let src_a_stride = b * d;
            let dst_d_stride = b * a;
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for n in 0..nn {
                    for bb in 0..b {
                        let s_base = src_ptr.add(n * a * b * d + bb * d);
                        let d_base = dst_ptr.add(n * d * b * a + bb * a);
                        for i0 in (0..a).step_by(TILE) {
                            let ie = (i0 + TILE).min(a);
                            for j0 in (0..d).step_by(TILE) {
                                let je = (j0 + TILE).min(d);
                                for i in i0..ie {
                                    for j in j0..je {
                                        *d_base.add(j * dst_d_stride + i) =
                                            *s_base.add(i * src_a_stride + j);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }
        // Fast path: plain 2-D transpose, 32x32 tiles.
        if rank == 2 && axes == [1, 0] {
            let (rows, cols) = (src_shape[0], src_shape[1]);
            let src = self.data();
            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
            const TILE: usize = 32;
            #[allow(unsafe_code)]
            unsafe {
                let src_ptr = src.as_ptr();
                let dst_ptr = dst.as_mut_ptr();
                for i0 in (0..rows).step_by(TILE) {
                    let ie = (i0 + TILE).min(rows);
                    for j0 in (0..cols).step_by(TILE) {
                        let je = (j0 + TILE).min(cols);
                        for i in i0..ie {
                            let s_row = src_ptr.add(i * cols);
                            for j in j0..je {
                                *dst_ptr.add(j * rows + i) = *s_row.add(j);
                            }
                        }
                    }
                }
            }
            return Tensor::from_aligned(out_shape, dst);
        }

        // Generic fallback: walk the input in order, scattering each value
        // to its permuted offset computed from the output strides.
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        let mut out_data = vec![0.0f32; out_count];

        let mut in_coords = vec![0usize; rank];
        for &val in self.data().iter() {
            let mut out_offset = 0usize;
            for (dst_axis, &src_axis) in axes.iter().enumerate() {
                out_offset += in_coords[src_axis] * out_strides[dst_axis];
            }
            out_data[out_offset] = val;
            increment_coords(&mut in_coords, src_shape);
        }

        Tensor::from_vec(out_shape, out_data)
    }
876
877 pub fn unsqueeze(&self, axis: usize) -> Result<Self, TensorError> {
879 if axis > self.rank() {
880 return Err(TensorError::InvalidAxis {
881 axis,
882 rank: self.rank() + 1,
883 });
884 }
885 let mut new_shape = self.shape().to_vec();
886 new_shape.insert(axis, 1);
887 self.reshape(new_shape)
888 }
889
890 pub fn squeeze(&self, axis: usize) -> Result<Self, TensorError> {
892 if axis >= self.rank() {
893 return Err(TensorError::InvalidAxis {
894 axis,
895 rank: self.rank(),
896 });
897 }
898 if self.shape()[axis] != 1 {
899 return Err(TensorError::InvalidAxis {
900 axis,
901 rank: self.rank(),
902 });
903 }
904 let mut new_shape = self.shape().to_vec();
905 new_shape.remove(axis);
906 self.reshape(new_shape)
907 }
908
    /// Concatenate tensors along `axis`. All inputs must share rank and
    /// agree on every dimension except `axis`. Assumes contiguous row-major
    /// data (consistent with the rest of this file's flat indexing).
    ///
    /// # Errors
    /// `SizeMismatch` for an empty input list, `InvalidAxis` for a bad
    /// axis, `ShapeMismatch` for incompatible shapes, `SizeOverflow` when
    /// the output element count overflows.
    pub fn cat(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
        if tensors.is_empty() {
            return Err(TensorError::SizeMismatch {
                shape: vec![],
                data_len: 0,
            });
        }
        let rank = tensors[0].rank();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }
        // Validate every tensor against the first: same rank, same dims
        // except along `axis`.
        for t in &tensors[1..] {
            if t.rank() != rank {
                return Err(TensorError::ShapeMismatch {
                    left: tensors[0].shape().to_vec(),
                    right: t.shape().to_vec(),
                });
            }
            for (a, (&d0, &di)) in tensors[0].shape().iter().zip(t.shape().iter()).enumerate() {
                if a != axis && d0 != di {
                    return Err(TensorError::ShapeMismatch {
                        left: tensors[0].shape().to_vec(),
                        right: t.shape().to_vec(),
                    });
                }
            }
        }

        let mut out_shape = tensors[0].shape().to_vec();
        out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;

        // Row-major: chunks of `axis_len * inner` interleave across inputs,
        // repeated `outer` times.
        let outer: usize = out_shape[..axis].iter().product();
        let inner: usize = out_shape[axis + 1..].iter().product();

        // Every slot is written exactly once by one of the copy loops below.
        let mut out_data = AlignedVec::<f32>::uninitialized(out_count);

        if inner == 1 && tensors.len() <= 8 {
            // Last-axis concat of a few tensors: hoist the per-tensor axis
            // lengths and copy row fragments directly.
            // NOTE(review): the `<= 8` cutoff looks like a tuning choice for
            // the Vec allocation below — confirm with benchmarks.
            let axis_lens: Vec<usize> = tensors.iter().map(|t| t.shape()[axis]).collect();
            let dst = out_data.as_mut_slice();
            let mut dst_off = 0;
            for o in 0..outer {
                for (ti, t) in tensors.iter().enumerate() {
                    let al = axis_lens[ti];
                    let src_off = o * al;
                    dst[dst_off..dst_off + al].copy_from_slice(&t.data()[src_off..src_off + al]);
                    dst_off += al;
                }
            }
        } else {
            // General case: copy one `axis_len * inner` chunk per tensor per
            // outer iteration.
            let dst = out_data.as_mut_slice();
            let mut dst_off = 0;
            for o in 0..outer {
                for t in tensors {
                    let t_axis_len = t.shape()[axis];
                    let chunk_len = t_axis_len * inner;
                    let chunk_start = o * chunk_len;
                    dst[dst_off..dst_off + chunk_len]
                        .copy_from_slice(&t.data()[chunk_start..chunk_start + chunk_len]);
                    dst_off += chunk_len;
                }
            }
        }

        Tensor::from_aligned(out_shape, out_data)
    }
983
984 pub fn stack(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
986 if tensors.is_empty() {
987 return Err(TensorError::SizeMismatch {
988 shape: vec![],
989 data_len: 0,
990 });
991 }
992 if axis > tensors[0].rank() {
993 return Err(TensorError::InvalidAxis {
994 axis,
995 rank: tensors[0].rank() + 1,
996 });
997 }
998 let expanded: Vec<Self> = tensors
999 .iter()
1000 .map(|t| t.unsqueeze(axis))
1001 .collect::<Result<_, _>>()?;
1002 let refs: Vec<&Self> = expanded.iter().collect();
1003 Self::cat(&refs, axis)
1004 }
1005
1006 pub fn select(&self, axis: usize, index: usize) -> Result<Self, TensorError> {
1008 if axis >= self.rank() {
1009 return Err(TensorError::InvalidAxis {
1010 axis,
1011 rank: self.rank(),
1012 });
1013 }
1014 if index >= self.shape()[axis] {
1015 return Err(TensorError::IndexOutOfBounds {
1016 axis,
1017 index,
1018 dim: self.shape()[axis],
1019 });
1020 }
1021 let outer: usize = self.shape()[..axis].iter().product();
1022 let axis_len = self.shape()[axis];
1023 let inner: usize = self.shape()[axis + 1..].iter().product();
1024
1025 let mut out_data = Vec::with_capacity(outer * inner);
1026 for o in 0..outer {
1027 let base = o * axis_len * inner + index * inner;
1028 out_data.extend_from_slice(&self.data()[base..base + inner]);
1029 }
1030
1031 let mut out_shape = self.shape().to_vec();
1032 out_shape.remove(axis);
1033 Tensor::from_vec(out_shape, out_data)
1034 }
1035
1036 pub fn narrow(&self, axis: usize, start: usize, length: usize) -> Result<Self, TensorError> {
1038 if axis >= self.rank() {
1039 return Err(TensorError::InvalidAxis {
1040 axis,
1041 rank: self.rank(),
1042 });
1043 }
1044 if start + length > self.shape()[axis] {
1045 return Err(TensorError::IndexOutOfBounds {
1046 axis,
1047 index: start + length,
1048 dim: self.shape()[axis],
1049 });
1050 }
1051 let outer: usize = self.shape()[..axis].iter().product();
1052 let axis_len = self.shape()[axis];
1053 let inner: usize = self.shape()[axis + 1..].iter().product();
1054
1055 let mut out_data = Vec::with_capacity(outer * length * inner);
1056 for o in 0..outer {
1057 let base = o * axis_len * inner + start * inner;
1058 out_data.extend_from_slice(&self.data()[base..base + length * inner]);
1059 }
1060
1061 let mut out_shape = self.shape().to_vec();
1062 out_shape[axis] = length;
1063 Tensor::from_vec(out_shape, out_data)
1064 }
1065
1066 pub fn eq_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1070 if self.shape() == rhs.shape() {
1071 return self.binary_same_shape(rhs, |l, r| {
1072 if (l - r).abs() < f32::EPSILON {
1073 1.0
1074 } else {
1075 0.0
1076 }
1077 });
1078 }
1079 self.binary_broadcast_op(rhs, |l, r| {
1080 if (l - r).abs() < f32::EPSILON {
1081 1.0
1082 } else {
1083 0.0
1084 }
1085 })
1086 }
1087
1088 pub fn gt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1090 if self.shape() == rhs.shape() {
1091 return self.binary_same_shape(rhs, |l, r| if l > r { 1.0 } else { 0.0 });
1092 }
1093 self.binary_broadcast_op(rhs, |l, r| if l > r { 1.0 } else { 0.0 })
1094 }
1095
1096 pub fn lt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1098 if self.shape() == rhs.shape() {
1099 return self.binary_same_shape(rhs, |l, r| if l < r { 1.0 } else { 0.0 });
1100 }
1101 self.binary_broadcast_op(rhs, |l, r| if l < r { 1.0 } else { 0.0 })
1102 }
1103
    /// Allocation-free greater-than: writes the 1.0/0.0 mask into `output`.
    /// Shapes are only checked in debug builds — callers must guarantee all
    /// three tensors share a shape.
    pub fn gt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Gt,
        );
    }
1116
    /// Allocation-free equality mask into `output`.
    /// Shapes are only checked in debug builds.
    // NOTE(review): the SIMD Eq kernel's tolerance (exact vs epsilon) is not
    // visible here — confirm it matches `eq_tensor` if callers mix the two.
    pub fn eq_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Eq,
        );
    }
1129
    /// Allocation-free less-than mask into `output`.
    /// Shapes are only checked in debug builds.
    pub fn lt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Lt,
        );
    }
1142
1143 pub fn ne_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1145 if self.shape() == rhs.shape() {
1146 return self.binary_same_shape(
1147 rhs,
1148 |l, r| {
1149 if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 }
1150 },
1151 );
1152 }
1153 self.binary_broadcast_op(rhs, |l, r| if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 })
1154 }
1155
1156 pub fn le_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1158 if self.shape() == rhs.shape() {
1159 return self.binary_same_shape(rhs, |l, r| if l <= r { 1.0 } else { 0.0 });
1160 }
1161 self.binary_broadcast_op(rhs, |l, r| if l <= r { 1.0 } else { 0.0 })
1162 }
1163
1164 pub fn ge_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1166 if self.shape() == rhs.shape() {
1167 return self.binary_same_shape(rhs, |l, r| if l >= r { 1.0 } else { 0.0 });
1168 }
1169 self.binary_broadcast_op(rhs, |l, r| if l >= r { 1.0 } else { 0.0 })
1170 }
1171
    /// True when no element is NaN or infinite; vacuously true when empty.
    pub fn all_finite(&self) -> bool {
        self.data().iter().all(|v| v.is_finite())
    }
1176
1177 pub fn where_cond(&self, condition: &Self, other: &Self) -> Result<Self, TensorError> {
1182 if self.shape() != condition.shape() || self.shape() != other.shape() {
1183 return Err(TensorError::ShapeMismatch {
1184 left: self.shape().to_vec(),
1185 right: condition.shape().to_vec(),
1186 });
1187 }
1188 let data: Vec<f32> = condition
1189 .data()
1190 .iter()
1191 .zip(self.data().iter())
1192 .zip(other.data().iter())
1193 .map(|((&c, &t), &f)| if c != 0.0 { t } else { f })
1194 .collect();
1195 Tensor::from_vec(self.shape().to_vec(), data)
1196 }
1197
1198 pub fn masked_fill(&self, mask: &Self, value: f32) -> Result<Self, TensorError> {
1200 if self.shape() != mask.shape() {
1201 return Err(TensorError::ShapeMismatch {
1202 left: self.shape().to_vec(),
1203 right: mask.shape().to_vec(),
1204 });
1205 }
1206 let data: Vec<f32> = self
1207 .data()
1208 .iter()
1209 .zip(mask.data().iter())
1210 .map(|(&v, &m)| if m != 0.0 { value } else { v })
1211 .collect();
1212 Tensor::from_vec(self.shape().to_vec(), data)
1213 }
1214
    /// Scatter: copy each `src` value into a clone of `self` at the
    /// position given by `index` along `axis` (the other coordinates are
    /// taken from the source position).
    ///
    /// Out-of-range or non-integral indices are silently skipped/truncated:
    /// the `as usize` cast truncates the float and saturates negatives to 0,
    /// and positions failing the bounds checks are dropped.
    // NOTE(review): a negative index therefore writes to position 0 rather
    // than being rejected — confirm this is intended.
    ///
    /// # Errors
    /// `ShapeMismatch` when `index` and `src` differ; `InvalidAxis` for a
    /// bad axis.
    pub fn scatter(&self, axis: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
        if index.shape() != src.shape() {
            return Err(TensorError::ShapeMismatch {
                left: index.shape().to_vec(),
                right: src.shape().to_vec(),
            });
        }
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Start from a copy of self; scattered positions overwrite it.
        let mut out = self.data().to_vec();
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let src_idx = (o * dim + d) * inner + i;
                    // Float -> usize cast truncates; negatives saturate to 0.
                    let target_d = index.data()[src_idx] as usize;
                    if target_d < self_dim {
                        let out_idx = (o * self_dim + target_d) * self_inner + i;
                        if out_idx < out.len() {
                            out[out_idx] = src.data()[src_idx];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(self.shape().to_vec(), out)
    }
1254
    /// Gather: build a tensor shaped like `index` where each element is
    /// read from `self` at the position `index` names along `axis`.
    ///
    /// Out-of-range indices leave 0.0 in the output (the buffer is
    /// zero-initialized and failing positions are skipped); the `as usize`
    /// cast truncates floats and saturates negatives to 0.
    // NOTE(review): a negative index therefore reads position 0 rather than
    // being rejected — confirm this is intended.
    ///
    /// # Errors
    /// `InvalidAxis` when `axis >= rank`.
    pub fn gather(&self, axis: usize, index: &Self) -> Result<Self, TensorError> {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        let mut out = vec![0.0f32; index.len()];
        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let idx_pos = (o * dim + d) * inner + i;
                    // Float -> usize cast truncates; negatives saturate to 0.
                    let src_d = index.data()[idx_pos] as usize;
                    if src_d < self_dim {
                        let src_pos = (o * self_dim + src_d) * self_inner + i;
                        if src_pos < self.len() {
                            out[idx_pos] = self.data()[src_pos];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
1287
    /// Returns the `k` largest values, and their positions, along the last
    /// dimension. `k` is clamped to the dimension length. Ties and NaNs
    /// compare as equal, so their relative order is unspecified.
    pub fn topk(&self, k: usize) -> Result<(Self, Self), TensorError> {
        if self.rank() == 0 {
            return Err(TensorError::InvalidAxis { axis: 0, rank: 0 });
        }
        let last_dim = *self.shape().last().expect("non-empty shape");
        let k = k.min(last_dim);
        let outer: usize = self.len() / last_dim;

        let mut values = Vec::with_capacity(outer * k);
        let mut indices = Vec::with_capacity(outer * k);

        for o in 0..outer {
            let start = o * last_dim;
            let slice = &self.data()[start..start + last_dim];
            // Sort candidate positions by value, descending, then keep k.
            let mut idx_vec: Vec<usize> = (0..last_dim).collect();
            idx_vec.sort_unstable_by(|&a, &b| {
                slice[b]
                    .partial_cmp(&slice[a])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            for &i in &idx_vec[..k] {
                values.push(slice[i]);
                // Positions are reported as f32, matching this library's
                // convention of f32-valued index tensors.
                indices.push(i as f32);
            }
        }

        // Output shape: input shape with the last dimension replaced by k.
        let mut out_shape = self.shape().to_vec();
        *out_shape.last_mut().expect("non-empty shape") = k;
        let val_t = Tensor::from_vec(out_shape.clone(), values)?;
        let idx_t = Tensor::from_vec(out_shape, indices)?;
        Ok((val_t, idx_t))
    }
1321
1322 pub fn triu(&self, diagonal: i64) -> Result<Self, TensorError> {
1325 if self.rank() < 2 {
1326 return Err(TensorError::InvalidAxis {
1327 axis: 0,
1328 rank: self.rank(),
1329 });
1330 }
1331 let shape = self.shape();
1332 let rows = shape[shape.len() - 2];
1333 let cols = shape[shape.len() - 1];
1334 let batch: usize = shape[..shape.len() - 2].iter().product();
1335 let mut out = self.data().to_vec();
1336 for b in 0..batch {
1337 for r in 0..rows {
1338 for c in 0..cols {
1339 if (c as i64) < (r as i64) + diagonal {
1340 out[b * rows * cols + r * cols + c] = 0.0;
1341 }
1342 }
1343 }
1344 }
1345 Tensor::from_vec(shape.to_vec(), out)
1346 }
1347
1348 pub fn tril(&self, diagonal: i64) -> Result<Self, TensorError> {
1350 if self.rank() < 2 {
1351 return Err(TensorError::InvalidAxis {
1352 axis: 0,
1353 rank: self.rank(),
1354 });
1355 }
1356 let shape = self.shape();
1357 let rows = shape[shape.len() - 2];
1358 let cols = shape[shape.len() - 1];
1359 let batch: usize = shape[..shape.len() - 2].iter().product();
1360 let mut out = self.data().to_vec();
1361 for b in 0..batch {
1362 for r in 0..rows {
1363 for c in 0..cols {
1364 if (c as i64) > (r as i64) + diagonal {
1365 out[b * rows * cols + r * cols + c] = 0.0;
1366 }
1367 }
1368 }
1369 }
1370 Tensor::from_vec(shape.to_vec(), out)
1371 }
1372
1373 pub fn eye(n: usize) -> Result<Self, TensorError> {
1375 let mut data = vec![0.0f32; n * n];
1376 for i in 0..n {
1377 data[i * n + i] = 1.0;
1378 }
1379 Tensor::from_vec(vec![n, n], data)
1380 }
1381
1382 pub fn diag(vector: &Tensor) -> Result<Self, TensorError> {
1384 let shape = vector.shape();
1385 if shape.len() != 1 {
1386 return Err(TensorError::UnsupportedOperation {
1387 msg: format!("diag requires a 1D tensor, got shape {:?}", shape),
1388 });
1389 }
1390 let n = shape[0];
1391 let mut data = vec![0.0f32; n * n];
1392 for i in 0..n {
1393 data[i * n + i] = vector.data()[i];
1394 }
1395 Self::from_vec(vec![n, n], data)
1396 }
1397
1398 pub fn diag_extract(&self) -> Result<Self, TensorError> {
1400 let shape = self.shape();
1401 if shape.len() != 2 {
1402 return Err(TensorError::UnsupportedOperation {
1403 msg: format!("diag_extract requires a 2D tensor, got shape {:?}", shape),
1404 });
1405 }
1406 let n = shape[0].min(shape[1]);
1407 let cols = shape[1];
1408 let data: Vec<f32> = (0..n).map(|i| self.data()[i * cols + i]).collect();
1409 Self::from_vec(vec![n], data)
1410 }
1411
    /// Pads each dimension with `value`: `padding[d] = (before, after)`
    /// amounts for dimension `d`. One pair is required per dimension.
    pub fn pad(&self, padding: &[(usize, usize)], value: f32) -> Result<Self, TensorError> {
        let shape = self.shape();
        if padding.len() != shape.len() {
            return Err(TensorError::InvalidIndexRank {
                expected: shape.len(),
                got: padding.len(),
            });
        }
        // Each output dimension grows by its before+after padding.
        let new_shape: Vec<usize> = shape
            .iter()
            .zip(padding)
            .map(|(&s, &(b, a))| s + b + a)
            .collect();
        let new_size: usize = new_shape.iter().product();
        // Fill everything with the pad value; the copy below overwrites
        // the interior region with the original data.
        let mut result = vec![value; new_size];
        let ndim = shape.len();

        // Row-major strides for the old and new shapes.
        let mut old_strides = vec![1usize; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            old_strides[i] = old_strides[i + 1] * shape[i + 1];
        }
        let mut new_strides = vec![1usize; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
        }

        let old_size: usize = shape.iter().product();
        let data = self.data();
        for flat_idx in 0..old_size {
            // Decode the old flat index into coordinates, offset each
            // coordinate by its leading pad, and re-encode into the output.
            let mut remaining = flat_idx;
            let mut new_flat = 0;
            for d in 0..ndim {
                let coord = remaining / old_strides[d];
                remaining %= old_strides[d];
                new_flat += (coord + padding[d].0) * new_strides[d];
            }
            result[new_flat] = data[flat_idx];
        }

        Self::from_vec(new_shape, result)
    }
1455
    /// Tiles the tensor `counts[axis]` times along each axis by repeated
    /// concatenation. `counts` must supply one entry per axis.
    ///
    /// NOTE(review): counts of 0 and 1 both leave the axis unchanged (only
    /// `count > 1` triggers a concat) — confirm that a count of 0 is meant
    /// to behave as "no repeat" rather than "empty result".
    pub fn repeat(&self, counts: &[usize]) -> Result<Self, TensorError> {
        if counts.len() != self.rank() {
            return Err(TensorError::InvalidIndexRank {
                expected: self.rank(),
                got: counts.len(),
            });
        }
        let mut out = self.clone();
        for (axis, &count) in counts.iter().enumerate() {
            if count > 1 {
                // Concatenate `count` references to the current result.
                let refs: Vec<&Tensor> = std::iter::repeat_n(&out, count).collect();
                out = Tensor::cat(&refs, axis)?;
            }
        }
        Ok(out)
    }
1473
1474 pub fn cumsum(&self, axis: usize) -> Result<Self, TensorError> {
1478 if axis >= self.rank() {
1479 return Err(TensorError::InvalidAxis {
1480 axis,
1481 rank: self.rank(),
1482 });
1483 }
1484 let shape = self.shape();
1485 let outer: usize = shape[..axis].iter().product();
1486 let axis_len = shape[axis];
1487 let inner: usize = shape[axis + 1..].iter().product();
1488 let mut out = self.data().to_vec();
1489
1490 for o in 0..outer {
1491 for i in 0..inner {
1492 let mut acc = 0.0f32;
1493 for d in 0..axis_len {
1494 let idx = (o * axis_len + d) * inner + i;
1495 acc += out[idx];
1496 out[idx] = acc;
1497 }
1498 }
1499 }
1500 Tensor::from_vec(shape.to_vec(), out)
1501 }
1502
1503 pub fn cumprod(&self, axis: usize) -> Result<Self, TensorError> {
1505 if axis >= self.rank() {
1506 return Err(TensorError::InvalidAxis {
1507 axis,
1508 rank: self.rank(),
1509 });
1510 }
1511 let shape = self.shape();
1512 let outer: usize = shape[..axis].iter().product();
1513 let axis_len = shape[axis];
1514 let inner: usize = shape[axis + 1..].iter().product();
1515 let mut out = self.data().to_vec();
1516
1517 for o in 0..outer {
1518 for i in 0..inner {
1519 let mut acc = 1.0f32;
1520 for d in 0..axis_len {
1521 let idx = (o * axis_len + d) * inner + i;
1522 acc *= out[idx];
1523 out[idx] = acc;
1524 }
1525 }
1526 }
1527 Tensor::from_vec(shape.to_vec(), out)
1528 }
1529
1530 pub fn to_fp16(&self) -> Vec<u16> {
1535 self.data().iter().map(|&v| f32_to_fp16(v)).collect()
1536 }
1537
1538 pub fn from_fp16(shape: Vec<usize>, fp16_data: &[u16]) -> Result<Self, TensorError> {
1540 let data: Vec<f32> = fp16_data.iter().map(|&v| fp16_to_f32(v)).collect();
1541 Tensor::from_vec(shape, data)
1542 }
1543
1544 #[allow(unsafe_code)]
1547 fn unary_op<F>(&self, op: F) -> Self
1548 where
1549 F: Fn(f32) -> f32,
1550 {
1551 let src = self.data();
1552 let len = src.len();
1553 let mut out_data = AlignedVec::<f32>::uninitialized(len);
1556 let inp = src.as_ptr();
1557 let outp = out_data.as_mut_ptr();
1558 unsafe {
1559 for i in 0..len {
1560 *outp.add(i) = op(*inp.add(i));
1561 }
1562 }
1563 Tensor::from_raw_parts(self.shape(), self.strides(), out_data)
1564 }
1565
    /// Element-wise unary op routed through the SIMD dispatch table,
    /// preserving shape and strides.
    #[allow(unsafe_code)]
    fn unary_simd_op(&self, kind: simd::UnaryKind) -> Self {
        let len = self.len();
        // Uninitialized output buffer; unary_dispatch fills every element.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::unary_dispatch(self.data(), &mut out, kind);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
1575
    /// Folds all elements along `axis` with `combine`, starting from `init`;
    /// the output shape is the input shape with `axis` removed.
    fn reduce_axis<F>(&self, axis: usize, init: f32, combine: F) -> Result<Self, TensorError>
    where
        F: Fn(f32, f32) -> f32,
    {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }

        let mut out_shape = self.shape().to_vec();
        out_shape.remove(axis);
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        let mut out_data = vec![init; out_count];

        // Single pass over the input in row-major order: project each input
        // coordinate onto the output by dropping the reduced axis.
        let mut in_coords = vec![0usize; self.rank()];
        for input in self.data().iter().copied() {
            let mut out_offset = 0usize;
            for (src_axis, coord) in in_coords.iter().copied().enumerate() {
                if src_axis == axis {
                    continue;
                }
                // Axes after the reduced one shift down by one position.
                let dst_axis = if src_axis < axis {
                    src_axis
                } else {
                    src_axis - 1
                };
                // A rank-0 output (reducing a 1-D tensor) has no strides;
                // everything folds into offset 0.
                if !out_shape.is_empty() {
                    out_offset += coord * out_strides[dst_axis];
                }
            }
            out_data[out_offset] = combine(out_data[out_offset], input);
            increment_coords(&mut in_coords, self.shape());
        }

        Tensor::from_vec(out_shape, out_data)
    }
1620
    /// SIMD fast path for the broadcast pattern where one operand is
    /// effectively a vector over the shared last dimension (every leading
    /// dimension equal to 1). Returns `None` when the pattern does not
    /// apply so the caller can fall back to the generic broadcast path.
    #[allow(unsafe_code)]
    fn binary_broadcast_lastdim_simd(
        &self,
        rhs: &Self,
        kind: simd::BinaryKind,
    ) -> Option<Result<Self, TensorError>> {
        let lhs_shape = self.shape();
        let rhs_shape = rhs.shape();

        let lhs_last = *lhs_shape.last()?;
        // A zero-length last dim would make num_rows a division by zero.
        if lhs_last == 0 {
            return None;
        }

        let rhs_last = *rhs_shape.last()?;

        // "Last-dim vector": same trailing length, all other dims are 1.
        let rhs_is_lastdim_vec =
            rhs_last == lhs_last && rhs_shape.iter().rev().skip(1).all(|&d| d == 1);
        let lhs_is_lastdim_vec =
            lhs_last == rhs_last && lhs_shape.iter().rev().skip(1).all(|&d| d == 1);

        if rhs_is_lastdim_vec && !lhs_is_lastdim_vec {
            // rhs broadcasts across every row of lhs.
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = lhs_last;
            let num_rows = lhs_data.len() / row_len;
            let mut out_data = AlignedVec::<f32>::uninitialized(lhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[start..end],
                    &rhs_data[..row_len],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(lhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                lhs_shape,
                &out_strides,
                out_data,
            )))
        } else if lhs_is_lastdim_vec && !rhs_is_lastdim_vec {
            // Mirror case: lhs broadcasts across every row of rhs.
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = rhs_last;
            let num_rows = rhs_data.len() / row_len;
            let mut out_data = AlignedVec::<f32>::uninitialized(rhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[..row_len],
                    &rhs_data[start..end],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(rhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                rhs_shape,
                &out_strides,
                out_data,
            )))
        } else {
            // Either both are vectors (same shape, handled elsewhere) or the
            // pattern is not a last-dim broadcast.
            None
        }
    }
1706
    /// Generic scalar fallback for element-wise binary ops with full
    /// broadcasting; walks the output coordinates one element at a time.
    fn binary_broadcast_op<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
    where
        F: Fn(f32, f32) -> f32,
    {
        let out_shape = broadcast_shape(self.shape(), rhs.shape()).ok_or_else(|| {
            TensorError::BroadcastIncompatible {
                left: self.shape().to_vec(),
                right: rhs.shape().to_vec(),
            }
        })?;

        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let mut out_data = vec![0.0; out_count];
        let mut coords = vec![0usize; out_shape.len()];

        // For each output coordinate, map back to the (possibly broadcast)
        // source offsets in lhs and rhs and combine the two elements.
        for value in &mut out_data {
            let left_offset = broadcast_offset(self.shape(), self.strides(), &coords);
            let right_offset = broadcast_offset(rhs.shape(), rhs.strides(), &coords);
            *value = op(self.data()[left_offset], rhs.data()[right_offset]);
            increment_coords(&mut coords, &out_shape);
        }

        Tensor::from_vec(out_shape, out_data)
    }
1734
1735 #[allow(unsafe_code)]
1737 #[allow(unsafe_code)]
1738 fn binary_same_shape_simd(
1739 &self,
1740 rhs: &Self,
1741 kind: simd::BinaryKind,
1742 ) -> Result<Self, TensorError> {
1743 let len = self.len();
1744 let mut out_data = AlignedVec::<f32>::uninitialized(len);
1745
1746 if len >= 100_000 {
1749 let n = std::thread::available_parallelism()
1750 .map(|p| p.get())
1751 .unwrap_or(4);
1752 let chunk = len.div_ceil(n);
1753 let lp = self.data().as_ptr() as usize;
1754 let rp = rhs.data().as_ptr() as usize;
1755 let op = out_data.as_mut_ptr() as usize;
1756
1757 #[cfg(target_os = "macos")]
1758 {
1759 use std::ffi::c_void;
1760 #[allow(unsafe_code)]
1761 unsafe extern "C" {
1762 fn dispatch_get_global_queue(id: isize, flags: usize) -> *const c_void;
1763 fn dispatch_apply_f(
1764 n: usize,
1765 q: *const c_void,
1766 ctx: *mut c_void,
1767 work: unsafe extern "C" fn(*mut c_void, usize),
1768 );
1769 }
1770 struct Ctx {
1771 lp: usize,
1772 rp: usize,
1773 op: usize,
1774 chunk: usize,
1775 len: usize,
1776 kind: simd::BinaryKind,
1777 }
1778 let ctx = Ctx {
1779 lp,
1780 rp,
1781 op,
1782 chunk,
1783 len,
1784 kind,
1785 };
1786 unsafe extern "C" fn work(raw: *mut c_void, t: usize) {
1787 let c = unsafe { &*(raw as *const Ctx) };
1788 let start = t * c.chunk;
1789 let end = (start + c.chunk).min(c.len);
1790 if start >= end {
1791 return;
1792 }
1793 let l = unsafe {
1794 std::slice::from_raw_parts((c.lp as *const f32).add(start), end - start)
1795 };
1796 let r = unsafe {
1797 std::slice::from_raw_parts((c.rp as *const f32).add(start), end - start)
1798 };
1799 let o = unsafe {
1800 std::slice::from_raw_parts_mut((c.op as *mut f32).add(start), end - start)
1801 };
1802 simd::binary_dispatch(l, r, o, c.kind);
1803 }
1804 let q = unsafe { dispatch_get_global_queue(0, 0) };
1805 unsafe {
1806 dispatch_apply_f(n, q, &ctx as *const Ctx as *mut c_void, work);
1807 }
1808 }
1809
1810 #[cfg(not(target_os = "macos"))]
1811 {
1812 use rayon::prelude::*;
1814 (0..n).into_par_iter().for_each(|t| {
1815 let start = t * chunk;
1816 let end = (start + chunk).min(len);
1817 if start >= end {
1818 return;
1819 }
1820 let l = unsafe {
1821 std::slice::from_raw_parts((lp as *const f32).add(start), end - start)
1822 };
1823 let r = unsafe {
1824 std::slice::from_raw_parts((rp as *const f32).add(start), end - start)
1825 };
1826 let o = unsafe {
1827 std::slice::from_raw_parts_mut((op as *mut f32).add(start), end - start)
1828 };
1829 simd::binary_dispatch(l, r, o, kind);
1830 });
1831 }
1832
1833 return Ok(Tensor::from_raw_parts(
1834 self.shape(),
1835 self.strides(),
1836 out_data,
1837 ));
1838 }
1839
1840 simd::binary_dispatch(self.data(), rhs.data(), &mut out_data, kind);
1841 Ok(Tensor::from_raw_parts(
1842 self.shape(),
1843 self.strides(),
1844 out_data,
1845 ))
1846 }
1847
1848 #[allow(unsafe_code)]
1849 fn binary_same_shape<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1850 where
1851 F: Fn(f32, f32) -> f32,
1852 {
1853 let len = self.len();
1854 let mut out_data = AlignedVec::<f32>::uninitialized(len);
1857
1858 let lhs_ptr = self.data().as_ptr();
1859 let rhs_ptr = rhs.data().as_ptr();
1860 let out_ptr = out_data.as_mut_ptr();
1861
1862 unsafe {
1867 let mut index = 0usize;
1868 while index + 4 <= len {
1869 *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1870 *out_ptr.add(index + 1) = op(*lhs_ptr.add(index + 1), *rhs_ptr.add(index + 1));
1871 *out_ptr.add(index + 2) = op(*lhs_ptr.add(index + 2), *rhs_ptr.add(index + 2));
1872 *out_ptr.add(index + 3) = op(*lhs_ptr.add(index + 3), *rhs_ptr.add(index + 3));
1873 index += 4;
1874 }
1875 while index < len {
1876 *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1877 index += 1;
1878 }
1879 }
1880
1881 Ok(Tensor::from_raw_parts(
1882 self.shape(),
1883 self.strides(),
1884 out_data,
1885 ))
1886 }
1887
    /// Clamps every element to `max(x, 0)` in place via SIMD dispatch.
    pub fn relu_inplace(&mut self) {
        simd::relu_inplace_dispatch(self.data_mut());
    }
1894
    /// Element-wise `self += rhs` in place. Lengths must match (checked
    /// only in debug builds).
    pub fn add_inplace(&mut self, rhs: &Self) {
        debug_assert_eq!(self.len(), rhs.len());
        simd::add_inplace_dispatch(self.data_mut(), rhs.data());
    }
1900
    /// Adds the scalar `s` to every element in place via SIMD dispatch.
    pub fn add_scalar_inplace(&mut self, s: f32) {
        simd::add_scalar_inplace_dispatch(self.data_mut(), s);
    }
1905
    /// Multiplies every element by the scalar `s` in place via SIMD dispatch.
    pub fn mul_scalar_inplace(&mut self, s: f32) {
        simd::mul_scalar_inplace_dispatch(self.data_mut(), s);
    }
1910
    /// Writes `lhs + rhs` into a preallocated `output` without allocating.
    /// All three tensors must share a shape (checked only in debug builds).
    pub fn add_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Add,
        );
    }
1925
    /// Writes `lhs - rhs` into a preallocated `output` without allocating.
    /// All three tensors must share a shape (checked only in debug builds).
    pub fn sub_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Sub,
        );
    }
1938
    /// Writes `lhs * rhs` into a preallocated `output` without allocating.
    /// All three tensors must share a shape (checked only in debug builds).
    pub fn mul_into(output: &mut Self, lhs: &Self, rhs: &Self) {
        debug_assert_eq!(lhs.shape(), rhs.shape());
        debug_assert_eq!(lhs.shape(), output.shape());
        simd::binary_dispatch(
            lhs.data(),
            rhs.data(),
            output.data_mut(),
            simd::BinaryKind::Mul,
        );
    }
1951}
1952
/// Polynomial approximation of `atan2(y, x)`.
///
/// Folds the input into the first octant (ratio in [0, 1]), range-reduces
/// the tangent with the two standard atan identities, evaluates an odd
/// minimax polynomial, then unfolds with the swap/quadrant corrections.
#[allow(clippy::excessive_precision)]
#[inline(always)]
fn fast_atan2_scalar(y: f32, x: f32) -> f32 {
    const PI: f32 = std::f32::consts::PI;
    const FRAC_PI_2: f32 = std::f32::consts::FRAC_PI_2;
    const FRAC_PI_4: f32 = std::f32::consts::FRAC_PI_4;
    // tan(3*pi/8) and tan(pi/8): thresholds for the folding identities.
    const TAN_3PI_8: f32 = 2.414_213_6;
    const TAN_PI_8: f32 = 0.414_213_57;

    let abs_x = x.abs();
    let abs_y = y.abs();

    // Fold into the first octant so the ratio is at most 1, remembering
    // whether the roles of x and y were swapped.
    let swapped = abs_x < abs_y;
    let (num, den) = if swapped { (abs_x, abs_y) } else { (abs_y, abs_x) };
    let ratio = if den > 0.0 { num / den } else { 0.0 };

    // Range-reduce the tangent, tracking the accumulated angle bias:
    //   atan(z) = pi/2 - atan(1/z)          for large z
    //   atan(z) = pi/4 + atan((z-1)/(z+1))  for mid-range z
    let (t, bias) = if ratio > TAN_3PI_8 {
        (-1.0 / ratio, FRAC_PI_2)
    } else if ratio > TAN_PI_8 {
        ((ratio - 1.0) / (ratio + 1.0), FRAC_PI_4)
    } else {
        (ratio, 0.0)
    };

    // Odd minimax polynomial: atan(t) ~= t + t^3 * P(t^2).
    let t2 = t * t;
    let poly = 8.054_666e-02_f32
        .mul_add(t2, -1.384_895_1e-01)
        .mul_add(t2, 1.997_075_8e-01)
        .mul_add(t2, -3.333_129_8e-01);
    let folded = t.mul_add(t2 * poly, t) + bias;

    // Unfold: undo the swap, then apply the quadrant signs of x and y.
    let mut angle = if swapped { FRAC_PI_2 - folded } else { folded };
    if x < 0.0 {
        angle = PI - angle;
    }
    if y < 0.0 {
        angle = -angle;
    }
    angle
}
2015
/// Converts an `f32` to its IEEE-754 binary16 bit pattern
/// (round-toward-zero; no round-to-nearest).
///
/// NaN/infinity pass through, values above the fp16 range overflow to a
/// signed infinity, values below 2^-24 flush to a signed zero, and the
/// subnormal band [2^-24, 2^-14) is shifted into the fp16 subnormal
/// encoding.
fn f32_to_fp16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // Exponent 255: infinity (mantissa 0) or NaN (force a quiet-NaN bit so
    // the payload is never truncated to an infinity pattern).
    if exponent == 255 {
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }

    let unbiased = exponent - 127;
    if unbiased > 15 {
        return sign | 0x7C00; // overflow -> signed infinity
    }
    if unbiased < -24 {
        return sign; // underflow -> signed zero
    }
    if unbiased < -14 {
        // Subnormal result. Restore the implicit leading bit, then shift
        // right by the 13-bit mantissa narrowing PLUS the (-14 - unbiased)
        // extra positions, i.e. by (-1 - unbiased), which is 14..=23 here.
        // BUG FIX: the previous code shifted by `shift + 13` on top of a
        // `shift` that already included the 13 — over-shifting every
        // subnormal to zero and overflowing the u32 shift width (panic in
        // debug builds) for inputs near 2^-24.
        let shift = -1 - unbiased;
        let m = (mantissa | 0x0080_0000) >> shift;
        return sign | m as u16;
    }

    // Normal number: rebias the exponent and truncate the mantissa.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
2048
2049impl Tensor {
    /// Sorts along `dim`, returning `(sorted values, original positions as
    /// f32)`. NaNs compare as equal, so their placement is unspecified.
    pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self), TensorError> {
        if dim >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis: dim,
                rank: self.rank(),
            });
        }
        let shape = self.shape();
        // Decompose as (outer, dim, inner) blocks; each (outer, inner) pair
        // addresses one independent fiber to sort.
        let outer: usize = shape[..dim].iter().product();
        let dim_len = shape[dim];
        let inner: usize = shape[dim + 1..].iter().product();
        let data = self.data();

        let mut out_vals = vec![0.0f32; data.len()];
        let mut out_idxs = vec![0.0f32; data.len()];

        for o in 0..outer {
            for i in 0..inner {
                // Argsort the fiber, then permute values and record sources.
                let mut idx_vec: Vec<usize> = (0..dim_len).collect();
                idx_vec.sort_unstable_by(|&a, &b| {
                    let va = data[(o * dim_len + a) * inner + i];
                    let vb = data[(o * dim_len + b) * inner + i];
                    if descending {
                        vb.partial_cmp(&va).unwrap_or(std::cmp::Ordering::Equal)
                    } else {
                        va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
                    }
                });
                for (rank, &src) in idx_vec.iter().enumerate() {
                    let dst = (o * dim_len + rank) * inner + i;
                    let src_pos = (o * dim_len + src) * inner + i;
                    out_vals[dst] = data[src_pos];
                    out_idxs[dst] = src as f32;
                }
            }
        }

        let v = Tensor::from_vec(shape.to_vec(), out_vals)?;
        let idx = Tensor::from_vec(shape.to_vec(), out_idxs)?;
        Ok((v, idx))
    }
2096
    /// Indices (as f32 values) that would sort the tensor along `dim`;
    /// thin wrapper over `sort` that discards the sorted values.
    pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self, TensorError> {
        let (_, indices) = self.sort(dim, descending)?;
        Ok(indices)
    }
2102
    /// Returns `(sorted distinct values, per-element position in that list,
    /// per-value occurrence counts)`, all as tensors.
    ///
    /// NOTE(review): NaN inputs break the total order assumed by the binary
    /// search below and can panic on the `expect` — confirm callers never
    /// pass NaN.
    pub fn unique(&self) -> (Self, Self, Self) {
        let data = self.data();
        let mut sorted: Vec<f32> = data.to_vec();
        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        sorted.dedup();

        // Map every input element back to its slot in the deduped list and
        // tally occurrences as we go.
        let mut inverse = vec![0.0f32; data.len()];
        let mut counts = vec![0.0f32; sorted.len()];
        for (i, &v) in data.iter().enumerate() {
            let pos = sorted
                .binary_search_by(|probe| {
                    probe.partial_cmp(&v).unwrap_or(std::cmp::Ordering::Equal)
                })
                .expect("value exists in sorted list");
            inverse[i] = pos as f32;
            counts[pos] += 1.0;
        }

        let vals = Tensor::from_vec(vec![sorted.len()], sorted).expect("unique vals");
        let inv = Tensor::from_vec(self.shape().to_vec(), inverse).expect("unique inv");
        let cnt = Tensor::from_vec(vec![counts.len()], counts).expect("unique counts");
        (vals, inv, cnt)
    }
2127
    /// Coordinates of all non-zero elements as an `[n, rank]` tensor (rank
    /// treated as 1 for scalar tensors, matching the shape convention).
    pub fn nonzero(&self) -> Self {
        let shape = self.shape();
        let rank = shape.len().max(1);
        let data = self.data();
        let mut coords: Vec<Vec<usize>> = Vec::new();

        if shape.is_empty() {
            // Scalar tensor: single element reported at coordinate [0].
            if data[0] != 0.0 {
                coords.push(vec![0]);
            }
        } else {
            // Walk elements in row-major order while carrying the
            // multi-index alongside the flat position.
            let mut idx = vec![0usize; shape.len()];
            for pos in 0..data.len() {
                if data[pos] != 0.0 {
                    coords.push(idx.clone());
                }
                // Odometer-style increment of the multi-index.
                for d in (0..shape.len()).rev() {
                    idx[d] += 1;
                    if idx[d] < shape[d] {
                        break;
                    }
                    idx[d] = 0;
                }
            }
        }

        let n = coords.len();
        let mut flat = Vec::with_capacity(n * rank);
        for c in &coords {
            for &v in c {
                flat.push(v as f32);
            }
        }
        if n == 0 {
            Tensor::from_vec(vec![0, rank], flat).expect("nonzero empty")
        } else {
            Tensor::from_vec(vec![n, rank], flat).expect("nonzero")
        }
    }
2170
    /// Reverses the tensor along every axis listed in `dims`.
    pub fn flip(&self, dims: &[usize]) -> Result<Self, TensorError> {
        for &d in dims {
            if d >= self.rank() {
                return Err(TensorError::InvalidAxis {
                    axis: d,
                    rank: self.rank(),
                });
            }
        }
        let shape = self.shape();
        let data = self.data();
        let total = data.len();
        let mut out = vec![0.0f32; total];
        let rank = shape.len();

        // Walk source elements in row-major order; mirror the flipped
        // coordinates to locate each destination.
        let mut src_idx = vec![0usize; rank];
        for pos in 0..total {
            let mut dst_idx = src_idx.clone();
            for &d in dims {
                dst_idx[d] = shape[d] - 1 - src_idx[d];
            }
            // Flatten the destination multi-index (row-major).
            let mut dst_pos = 0;
            let mut stride = 1;
            for d in (0..rank).rev() {
                dst_pos += dst_idx[d] * stride;
                stride *= shape[d];
            }
            out[dst_pos] = data[pos];

            // Odometer-style increment of the source multi-index.
            for d in (0..rank).rev() {
                src_idx[d] += 1;
                if src_idx[d] < shape[d] {
                    break;
                }
                src_idx[d] = 0;
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
2216
    /// Rotates elements along `dim` by `shift` positions with wrap-around;
    /// negative shifts rotate the opposite way (`rem_euclid` keeps the
    /// destination index non-negative).
    pub fn roll(&self, shift: i64, dim: usize) -> Result<Self, TensorError> {
        if dim >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis: dim,
                rank: self.rank(),
            });
        }
        let shape = self.shape();
        // Decompose as (outer, dim, inner) blocks.
        let outer: usize = shape[..dim].iter().product();
        let dim_len = shape[dim];
        let inner: usize = shape[dim + 1..].iter().product();
        let data = self.data();

        let mut out = vec![0.0f32; data.len()];
        for o in 0..outer {
            for d in 0..dim_len {
                let dst_d = ((d as i64 + shift).rem_euclid(dim_len as i64)) as usize;
                for i in 0..inner {
                    out[(o * dim_len + dst_d) * inner + i] = data[(o * dim_len + d) * inner + i];
                }
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
2242
2243 pub fn linspace(start: f32, end: f32, steps: usize) -> Result<Self, TensorError> {
2247 if steps == 0 {
2248 return Tensor::from_vec(vec![0], vec![]);
2249 }
2250 if steps == 1 {
2251 return Tensor::from_vec(vec![1], vec![start]);
2252 }
2253 let step = (end - start) / (steps - 1) as f32;
2254 let data: Vec<f32> = (0..steps).map(|i| start + step * i as f32).collect();
2255 Tensor::from_vec(vec![steps], data)
2256 }
2257
    /// Values from `start` up to (but excluding) `end` in increments of
    /// `step`; a negative `step` produces a descending range.
    ///
    /// NOTE(review): a zero step is rejected with an empty `ShapeMismatch`,
    /// presumably as a stand-in error — confirm whether
    /// `UnsupportedOperation` was intended here.
    pub fn arange(start: f32, end: f32, step: f32) -> Result<Self, TensorError> {
        if step == 0.0 {
            return Err(TensorError::ShapeMismatch {
                left: vec![],
                right: vec![],
            });
        }
        let mut data = Vec::new();
        let mut v = start;
        if step > 0.0 {
            while v < end {
                data.push(v);
                v += step;
            }
        } else {
            while v > end {
                data.push(v);
                v += step;
            }
        }
        let n = data.len();
        Tensor::from_vec(vec![n], data)
    }
2282
    /// N-dimensional coordinate grids from N input tensors ("ij" indexing):
    /// output `k` varies along axis `k` and is constant along all others.
    ///
    /// NOTE(review): inputs are flattened via `len()`, so non-1-D tensors
    /// are silently treated as 1-D — confirm that is intended.
    pub fn meshgrid(tensors: &[Self]) -> Result<Vec<Self>, TensorError> {
        let shape: Vec<usize> = tensors.iter().map(|t| t.len()).collect();
        let total: usize = shape.iter().product();
        let n = tensors.len();
        let mut result = Vec::with_capacity(n);

        for (idx, t) in tensors.iter().enumerate() {
            let t_data = t.data();
            let mut out = vec![0.0f32; total];
            // Broadcast t along every axis except its own (idx).
            let inner: usize = shape[idx + 1..].iter().product();
            let outer: usize = shape[..idx].iter().product();
            let dim_len = shape[idx];
            for o in 0..outer {
                for d in 0..dim_len {
                    for i in 0..inner {
                        out[(o * dim_len + d) * inner + i] = t_data[d];
                    }
                }
            }
            result.push(Tensor::from_vec(shape.clone(), out)?);
        }
        Ok(result)
    }
2308
2309 pub fn boolean_mask(&self, mask: &Self) -> Result<Self, TensorError> {
2313 if self.shape() != mask.shape() {
2314 return Err(TensorError::ShapeMismatch {
2315 left: self.shape().to_vec(),
2316 right: mask.shape().to_vec(),
2317 });
2318 }
2319 let data = self.data();
2320 let m = mask.data();
2321 let out: Vec<f32> = data
2322 .iter()
2323 .zip(m.iter())
2324 .filter(|(_, mv)| **mv != 0.0)
2325 .map(|(v, _)| *v)
2326 .collect();
2327 let n = out.len();
2328 Tensor::from_vec(vec![n], out)
2329 }
2330
    /// Selects slices along `dim` at the positions in `indices` (stored as
    /// f32). Unlike `gather`, out-of-range indices are reported as errors.
    pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self, TensorError> {
        if dim >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis: dim,
                rank: self.rank(),
            });
        }
        let shape = self.shape();
        let idx_data = indices.data();
        let n_idx = idx_data.len();
        // Decompose as (outer, dim, inner) blocks; whole inner rows are
        // copied for each selected index.
        let outer: usize = shape[..dim].iter().product();
        let dim_len = shape[dim];
        let inner: usize = shape[dim + 1..].iter().product();
        let data = self.data();

        let mut out = Vec::with_capacity(outer * n_idx * inner);
        for o in 0..outer {
            for &idx_f in idx_data {
                // Index values are stored as f32; truncation toward zero is
                // the intended conversion.
                let idx = idx_f as usize;
                if idx >= dim_len {
                    return Err(TensorError::IndexOutOfBounds {
                        axis: dim,
                        index: idx,
                        dim: dim_len,
                    });
                }
                let src_start = (o * dim_len + idx) * inner;
                out.extend_from_slice(&data[src_start..src_start + inner]);
            }
        }

        let mut out_shape = shape.to_vec();
        out_shape[dim] = n_idx;
        Tensor::from_vec(out_shape, out)
    }
2367
2368 pub fn rand(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2372 let n: usize = shape.iter().product();
2373 let mut rng = seed;
2374 let data: Vec<f32> = (0..n)
2375 .map(|_| {
2376 rng ^= rng << 13;
2377 rng ^= rng >> 7;
2378 rng ^= rng << 17;
2379 (rng as f32) / (u64::MAX as f32)
2380 })
2381 .collect();
2382 Self::from_vec(shape, data)
2383 }
2384
2385 pub fn randn(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2388 let n: usize = shape.iter().product();
2389 let mut rng = seed;
2390 let mut next_rng = || -> f32 {
2391 rng ^= rng << 13;
2392 rng ^= rng >> 7;
2393 rng ^= rng << 17;
2394 ((rng as f64) / (u64::MAX as f64)).clamp(1e-15, 1.0 - 1e-15) as f32
2396 };
2397 let mut data = Vec::with_capacity(n);
2398 let mut i = 0;
2399 while i < n {
2400 let u1 = next_rng();
2401 let u2 = next_rng();
2402 let r = (-2.0 * (u1 as f64).ln()).sqrt();
2403 let theta = 2.0 * std::f64::consts::PI * u2 as f64;
2404 data.push((r * theta.cos()) as f32);
2405 i += 1;
2406 if i < n {
2407 data.push((r * theta.sin()) as f32);
2408 i += 1;
2409 }
2410 }
2411 Self::from_vec(shape, data)
2412 }
2413
2414 pub fn randint(shape: Vec<usize>, low: i64, high: i64, seed: u64) -> Result<Self, TensorError> {
2416 if high <= low {
2417 return Err(TensorError::UnsupportedOperation {
2418 msg: format!("randint requires high > low, got low={low}, high={high}"),
2419 });
2420 }
2421 let range = (high - low) as u64;
2422 let n: usize = shape.iter().product();
2423 let mut rng = seed;
2424 let data: Vec<f32> = (0..n)
2425 .map(|_| {
2426 rng ^= rng << 13;
2427 rng ^= rng >> 7;
2428 rng ^= rng << 17;
2429 (low + (rng % range) as i64) as f32
2430 })
2431 .collect();
2432 Self::from_vec(shape, data)
2433 }
2434
2435 pub fn randperm(n: usize, seed: u64) -> Result<Self, TensorError> {
2437 let mut perm: Vec<f32> = (0..n).map(|i| i as f32).collect();
2438 let mut rng = seed;
2439 for i in (1..n).rev() {
2440 rng ^= rng << 13;
2441 rng ^= rng >> 7;
2442 rng ^= rng << 17;
2443 let j = (rng as usize) % (i + 1);
2444 perm.swap(i, j);
2445 }
2446 Self::from_vec(vec![n], perm)
2447 }
2448}
2449
2450impl Tensor {
    /// Strided slice along one dimension: keeps indices `start, start+step,
    /// ...` strictly below `min(end, dim_len)`. `step` must be positive.
    pub fn step_slice(
        &self,
        dim: usize,
        start: usize,
        end: usize,
        step: usize,
    ) -> Result<Self, TensorError> {
        let rank = self.rank();
        if dim >= rank {
            return Err(TensorError::InvalidAxis { axis: dim, rank });
        }
        if step == 0 {
            return Err(TensorError::UnsupportedOperation {
                msg: "step must be > 0".to_string(),
            });
        }
        let shape = self.shape();
        let dim_len = shape[dim];
        let end = end.min(dim_len);
        if start >= end {
            // Empty selection: same shape but zero length along `dim`.
            let mut out_shape = shape.to_vec();
            out_shape[dim] = 0;
            return Tensor::from_vec(out_shape, vec![]);
        }

        let selected_indices: Vec<usize> = (start..end).step_by(step).collect();
        let new_dim = selected_indices.len();

        // Copy whole `inner`-sized rows for each kept index.
        let outer: usize = shape[..dim].iter().product();
        let inner: usize = shape[dim + 1..].iter().product();
        let data = self.data();

        let mut out = Vec::with_capacity(outer * new_dim * inner);
        for o in 0..outer {
            for &idx in &selected_indices {
                let src_start = (o * dim_len + idx) * inner;
                out.extend_from_slice(&data[src_start..src_start + inner]);
            }
        }

        let mut out_shape = shape.to_vec();
        out_shape[dim] = new_dim;
        Tensor::from_vec(out_shape, out)
    }
2499
2500 pub fn einsum(equation: &str, tensors: &[&Tensor]) -> Result<Tensor, TensorError> {
2512 let equation = equation.replace(' ', "");
2513 let parts: Vec<&str> = equation.split("->").collect();
2514 if parts.len() != 2 {
2515 return Err(TensorError::UnsupportedOperation {
2516 msg: format!("invalid einsum equation: {equation}"),
2517 });
2518 }
2519 let inputs_str = parts[0];
2520 let output_str = parts[1];
2521 let input_parts: Vec<&str> = inputs_str.split(',').collect();
2522
2523 if input_parts.len() != tensors.len() {
2524 return Err(TensorError::UnsupportedOperation {
2525 msg: format!(
2526 "einsum equation has {} inputs but {} tensors provided",
2527 input_parts.len(),
2528 tensors.len()
2529 ),
2530 });
2531 }
2532
2533 let pattern = format!(
2535 "{}->{}",
2536 input_parts
2537 .iter()
2538 .map(|s| s.to_string())
2539 .collect::<Vec<_>>()
2540 .join(","),
2541 output_str
2542 );
2543
2544 match pattern.as_str() {
2545 "ij,jk->ik" => {
2547 let a = tensors[0];
2548 let b = tensors[1];
2549 if a.rank() != 2 || b.rank() != 2 {
2550 return Err(TensorError::UnsupportedOperation {
2551 msg: "ij,jk->ik requires 2D tensors".to_string(),
2552 });
2553 }
2554 let (m, k1) = (a.shape()[0], a.shape()[1]);
2555 let (k2, n) = (b.shape()[0], b.shape()[1]);
2556 if k1 != k2 {
2557 return Err(TensorError::ShapeMismatch {
2558 left: a.shape().to_vec(),
2559 right: b.shape().to_vec(),
2560 });
2561 }
2562 let ad = a.data();
2563 let bd = b.data();
2564 let mut out = vec![0.0f32; m * n];
2565 for i in 0..m {
2566 for j in 0..n {
2567 let mut sum = 0.0f32;
2568 for p in 0..k1 {
2569 sum += ad[i * k1 + p] * bd[p * n + j];
2570 }
2571 out[i * n + j] = sum;
2572 }
2573 }
2574 Tensor::from_vec(vec![m, n], out)
2575 }
2576 "ij->ji" => {
2578 let a = tensors[0];
2579 if a.rank() != 2 {
2580 return Err(TensorError::UnsupportedOperation {
2581 msg: "ij->ji requires a 2D tensor".to_string(),
2582 });
2583 }
2584 let (rows, cols) = (a.shape()[0], a.shape()[1]);
2585 let ad = a.data();
2586 let mut out = vec![0.0f32; rows * cols];
2587 for i in 0..rows {
2588 for j in 0..cols {
2589 out[j * rows + i] = ad[i * cols + j];
2590 }
2591 }
2592 Tensor::from_vec(vec![cols, rows], out)
2593 }
2594 "ii->i" => {
2596 let a = tensors[0];
2597 if a.rank() != 2 || a.shape()[0] != a.shape()[1] {
2598 return Err(TensorError::UnsupportedOperation {
2599 msg: "ii->i requires a square 2D tensor".to_string(),
2600 });
2601 }
2602 let n = a.shape()[0];
2603 let ad = a.data();
2604 let out: Vec<f32> = (0..n).map(|i| ad[i * n + i]).collect();
2605 Tensor::from_vec(vec![n], out)
2606 }
2607 "ij->i" => {
2609 let a = tensors[0];
2610 if a.rank() != 2 {
2611 return Err(TensorError::UnsupportedOperation {
2612 msg: "ij->i requires a 2D tensor".to_string(),
2613 });
2614 }
2615 let (rows, cols) = (a.shape()[0], a.shape()[1]);
2616 let ad = a.data();
2617 let out: Vec<f32> = (0..rows)
2618 .map(|i| ad[i * cols..(i + 1) * cols].iter().sum())
2619 .collect();
2620 Tensor::from_vec(vec![rows], out)
2621 }
2622 "ij->j" => {
2624 let a = tensors[0];
2625 if a.rank() != 2 {
2626 return Err(TensorError::UnsupportedOperation {
2627 msg: "ij->j requires a 2D tensor".to_string(),
2628 });
2629 }
2630 let (rows, cols) = (a.shape()[0], a.shape()[1]);
2631 let ad = a.data();
2632 let mut out = vec![0.0f32; cols];
2633 for i in 0..rows {
2634 for j in 0..cols {
2635 out[j] += ad[i * cols + j];
2636 }
2637 }
2638 Tensor::from_vec(vec![cols], out)
2639 }
2640 "ij->" => {
2642 let a = tensors[0];
2643 if a.rank() != 2 {
2644 return Err(TensorError::UnsupportedOperation {
2645 msg: "ij-> requires a 2D tensor".to_string(),
2646 });
2647 }
2648 let sum: f32 = a.data().iter().sum();
2649 Ok(Tensor::scalar(sum))
2650 }
2651 "i,i->" => {
2653 let a = tensors[0];
2654 let b = tensors[1];
2655 if a.rank() != 1 || b.rank() != 1 {
2656 return Err(TensorError::UnsupportedOperation {
2657 msg: "i,i-> requires 1D tensors".to_string(),
2658 });
2659 }
2660 if a.shape()[0] != b.shape()[0] {
2661 return Err(TensorError::ShapeMismatch {
2662 left: a.shape().to_vec(),
2663 right: b.shape().to_vec(),
2664 });
2665 }
2666 let sum: f32 = a
2667 .data()
2668 .iter()
2669 .zip(b.data().iter())
2670 .map(|(x, y)| x * y)
2671 .sum();
2672 Ok(Tensor::scalar(sum))
2673 }
2674 "ij,ij->" => {
2676 let a = tensors[0];
2677 let b = tensors[1];
2678 if a.rank() != 2 || b.rank() != 2 {
2679 return Err(TensorError::UnsupportedOperation {
2680 msg: "ij,ij-> requires 2D tensors".to_string(),
2681 });
2682 }
2683 if a.shape() != b.shape() {
2684 return Err(TensorError::ShapeMismatch {
2685 left: a.shape().to_vec(),
2686 right: b.shape().to_vec(),
2687 });
2688 }
2689 let sum: f32 = a
2690 .data()
2691 .iter()
2692 .zip(b.data().iter())
2693 .map(|(x, y)| x * y)
2694 .sum();
2695 Ok(Tensor::scalar(sum))
2696 }
2697 _ => Err(TensorError::UnsupportedOperation {
2698 msg: format!("unsupported einsum pattern: {pattern}"),
2699 }),
2700 }
2701 }
2702
2703 pub fn chunk(&self, n_chunks: usize, axis: usize) -> Result<Vec<Self>, TensorError> {
2707 if axis >= self.rank() {
2708 return Err(TensorError::InvalidAxis {
2709 axis,
2710 rank: self.rank(),
2711 });
2712 }
2713 if n_chunks == 0 {
2714 return Err(TensorError::UnsupportedOperation {
2715 msg: "n_chunks must be > 0".to_string(),
2716 });
2717 }
2718 let dim = self.shape()[axis];
2719 let chunk_size = dim.div_ceil(n_chunks); let mut chunks = Vec::new();
2721 let mut start = 0;
2722 while start < dim {
2723 let length = chunk_size.min(dim - start);
2724 chunks.push(self.narrow(axis, start, length)?);
2725 start += length;
2726 }
2727 Ok(chunks)
2728 }
2729
2730 pub fn histogram(&self, bins: usize, min: f32, max: f32) -> Result<Self, TensorError> {
2735 let mut counts = vec![0.0f32; bins];
2736 let range = max - min;
2737 for &v in self.data() {
2738 if v >= min && v <= max {
2739 let idx = if v == max {
2740 bins - 1
2741 } else {
2742 ((v - min) / range * bins as f32) as usize
2743 };
2744 counts[idx] += 1.0;
2745 }
2746 }
2747 Tensor::from_vec(vec![bins], counts)
2748 }
2749
2750 pub fn bincount(&self, num_bins: usize) -> Result<Self, TensorError> {
2755 let mut counts = vec![0.0f32; num_bins];
2756 for &v in self.data() {
2757 let idx = v as usize;
2758 if idx < num_bins {
2759 counts[idx] += 1.0;
2760 }
2761 }
2762 Tensor::from_vec(vec![num_bins], counts)
2763 }
2764
2765 pub fn item(&self) -> Result<f32, TensorError> {
2770 if self.len() != 1 {
2771 return Err(TensorError::ShapeMismatch {
2772 left: vec![1],
2773 right: self.shape().to_vec(),
2774 });
2775 }
2776 Ok(self.data()[0])
2777 }
2778
    /// Returns `true` when the tensor contains exactly one element.
    ///
    /// Note: this is element-count based, so shapes like `[1]` or `[1, 1]`
    /// also count as scalar — not only rank-0 tensors.
    pub fn is_scalar(&self) -> bool {
        self.len() == 1
    }
2783
2784 pub fn scatter_add(&self, dim: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
2790 if dim >= self.rank() {
2791 return Err(TensorError::InvalidAxis {
2792 axis: dim,
2793 rank: self.rank(),
2794 });
2795 }
2796 if index.rank() != self.rank() {
2797 return Err(TensorError::InvalidIndexRank {
2798 expected: self.rank(),
2799 got: index.rank(),
2800 });
2801 }
2802 if src.shape() != index.shape() {
2803 return Err(TensorError::ShapeMismatch {
2804 left: src.shape().to_vec(),
2805 right: index.shape().to_vec(),
2806 });
2807 }
2808
2809 let self_shape = self.shape();
2810 let idx_shape = index.shape();
2811 let idx_data = index.data();
2812 let src_data = src.data();
2813 let ndim = self.rank();
2814
2815 let mut out = self.data().to_vec();
2816 let mut coords = vec![0usize; ndim];
2817
2818 for pos in 0..index.len() {
2819 let idx_val = idx_data[pos] as usize;
2820 if idx_val >= self_shape[dim] {
2821 return Err(TensorError::IndexOutOfBounds {
2822 axis: dim,
2823 index: idx_val,
2824 dim: self_shape[dim],
2825 });
2826 }
2827
2828 let mut dst_offset = 0;
2829 for d in 0..ndim {
2830 let c = if d == dim { idx_val } else { coords[d] };
2831 dst_offset += c * self.strides()[d];
2832 }
2833 out[dst_offset] += src_data[pos];
2834
2835 increment_coords(&mut coords, idx_shape);
2836 }
2837
2838 Tensor::from_vec(self_shape.to_vec(), out)
2839 }
2840}
2841
/// Converts an IEEE 754 half-precision (binary16) bit pattern to `f32`.
///
/// Handles all four cases: signed zero, subnormals, infinity/NaN, and
/// normal numbers. Every binary16 value is exactly representable in f32,
/// so the conversion is lossless.
///
/// Bug fix: the subnormal branch previously used a biased exponent of
/// `127 - 15 - e`, which is off by one — a subnormal half equals
/// `(1 + m/1024) * 2^(-14 - e)` after `e` normalization shifts, so the f32
/// exponent field must be `-14 - e + 127 = 113 - e`. The old formula made
/// every subnormal convert to half its true value (e.g. 0x0001 became
/// 2^-25 instead of 2^-24).
fn fp16_to_f32(half: u16) -> f32 {
    // f32 layout: 1 sign bit, 8 exponent bits (bias 127), 23 mantissa bits.
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;

    if exponent == 0 {
        if mantissa == 0 {
            // +0.0 or -0.0.
            return f32::from_bits(sign);
        }
        // Subnormal half: value = mantissa * 2^-24. Normalize by shifting
        // the mantissa left until the implicit leading bit (0x0400) is set.
        let mut m = mantissa;
        let mut e = 0i32;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        m &= 0x03FF; // drop the now-implicit leading bit
        // Value is (1 + m/1024) * 2^(-14 - e); biased f32 exponent = 113 - e.
        let f32_exp = ((113 - e) as u32) << 23;
        let f32_man = m << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Infinity (mantissa == 0) or NaN (mantissa != 0): max f32 exponent,
        // mantissa carried across so NaN payloads survive.
        let f32_exp = 0xFF << 23;
        let f32_man = mantissa << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }

    // Normal number: rebias the exponent from 15 to 127 and widen the
    // mantissa from 10 to 23 bits.
    let f32_exp = ((exponent as i32 - 15 + 127) as u32 & 0xFF) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}