1use std::sync::{Arc, RwLock};
2
3use crate::backend::{Backend, BinaryOp, CmpOp, ReduceOp, UnaryOp};
4use crate::dtype::DType;
5use crate::error::{Error, Result};
6use crate::layout::Layout;
7use crate::op::{Op, TensorId};
8use crate::shape::Shape;
9
/// Shared, reference-counted state behind a [`Tensor`] handle.
///
/// Cloning a `Tensor` only bumps the `Arc` around this struct, so all clones
/// observe the same storage, layout, and autograd metadata.
struct TensorInner<B: Backend> {
    // Unique identifier of this tensor node; a fresh id is minted for each
    // op/view result (except `set_variable`, which preserves it).
    id: TensorId,
    // Underlying device buffer; shared between views via `Arc` and guarded by
    // a `RwLock` so `update_data_inplace` can swap it out.
    storage: Arc<RwLock<B::Storage>>,
    // Shape/strides/offset describing how this tensor indexes `storage`.
    layout: Layout,
    // Element type of the tensor.
    dtype: DType,
    // Device the storage lives on.
    device: B::Device,
    // Operation that produced this tensor (autograd graph edge).
    op: Op<B>,
    // Marks trainable leaves; set via `set_variable`.
    is_variable: bool,
}
63
/// A cheaply clonable handle to a tensor.
///
/// Wraps [`TensorInner`] in an `Arc`; views created by `transpose`, `narrow`,
/// `reshape`, etc. share the same storage with a different layout.
pub struct Tensor<B: Backend> {
    inner: Arc<TensorInner<B>>,
}
84
85impl<B: Backend> Clone for Tensor<B> {
87 fn clone(&self) -> Self {
88 Tensor {
89 inner: Arc::clone(&self.inner),
90 }
91 }
92}
93
94impl<B: Backend> std::fmt::Debug for Tensor<B> {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(
97 f,
98 "Tensor(id={:?}, shape={}, dtype={}, device={:?})",
99 self.inner.id,
100 self.inner.layout.shape(),
101 self.inner.dtype,
102 self.inner.device,
103 )
104 }
105}
106
107impl<B: Backend> Tensor<B> {
108 pub(crate) fn from_storage(
112 storage: B::Storage,
113 layout: Layout,
114 dtype: DType,
115 device: B::Device,
116 op: Op<B>,
117 ) -> Self {
118 Tensor {
119 inner: Arc::new(TensorInner {
120 id: TensorId::new(),
121 storage: Arc::new(RwLock::new(storage)),
122 layout,
123 dtype,
124 device,
125 op,
126 is_variable: false,
127 }),
128 }
129 }
130
131 fn view_with_layout(&self, layout: Layout, op: Op<B>) -> Self {
133 Tensor {
134 inner: Arc::new(TensorInner {
135 id: TensorId::new(),
136 storage: Arc::clone(&self.inner.storage),
137 layout,
138 dtype: self.inner.dtype,
139 device: self.inner.device.clone(),
140 op,
141 is_variable: false,
142 }),
143 }
144 }
145
    /// Unique id of this tensor node in the autograd graph.
    pub fn id(&self) -> TensorId {
        self.inner.id
    }

    /// Shape of the tensor.
    pub fn shape(&self) -> &Shape {
        self.inner.layout.shape()
    }

    /// Dimension sizes as a slice.
    pub fn dims(&self) -> &[usize] {
        self.inner.layout.dims()
    }

    /// Number of dimensions.
    pub fn rank(&self) -> usize {
        self.inner.layout.rank()
    }

    /// Total number of elements.
    pub fn elem_count(&self) -> usize {
        self.inner.layout.elem_count()
    }

    /// Element type.
    pub fn dtype(&self) -> DType {
        self.inner.dtype
    }

    /// Device this tensor's storage lives on.
    pub fn device(&self) -> &B::Device {
        &self.inner.device
    }

    /// Layout (shape, strides, offset) into the underlying storage.
    pub fn layout(&self) -> &Layout {
        &self.inner.layout
    }

    /// Whether the layout is contiguous in memory.
    pub fn is_contiguous(&self) -> bool {
        self.inner.layout.is_contiguous()
    }

    /// Whether this tensor is a trainable variable (see `set_variable`).
    pub fn is_variable(&self) -> bool {
        self.inner.is_variable
    }

    /// Read guard on the underlying storage.
    ///
    /// NOTE(review): this public accessor panics on a poisoned lock, while
    /// the private `read_storage`/`write_storage` map poisoning to an
    /// `Error` — confirm whether the panic is intended for public callers.
    pub fn storage(&self) -> std::sync::RwLockReadGuard<'_, B::Storage> {
        self.inner.storage.read().expect("storage lock poisoned")
    }

    /// Fallible read guard; lock poisoning becomes an `Error`.
    fn read_storage(&self) -> Result<std::sync::RwLockReadGuard<'_, B::Storage>> {
        self.inner
            .storage
            .read()
            .map_err(|_| Error::msg("storage lock poisoned"))
    }

    /// Fallible write guard; lock poisoning becomes an `Error`.
    fn write_storage(&self) -> Result<std::sync::RwLockWriteGuard<'_, B::Storage>> {
        self.inner
            .storage
            .write()
            .map_err(|_| Error::msg("storage lock poisoned"))
    }

    /// Operation that produced this tensor.
    pub fn op(&self) -> &Op<B> {
        &self.inner.op
    }
223
224 pub fn update_data_inplace(&self, new_data: &[f64]) -> Result<()> {
239 let expected = self.elem_count();
240 if new_data.len() != expected {
241 return Err(Error::msg(format!(
242 "update_data_inplace: expected {} elements, got {}",
243 expected,
244 new_data.len()
245 )));
246 }
247 let new_storage = B::from_f64_slice(new_data, self.dtype(), self.device())?;
248 let mut guard = self.write_storage()?;
249 *guard = new_storage;
250 Ok(())
251 }
252
253 pub fn zeros(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
257 let shape = shape.into();
258 let layout = Layout::contiguous(shape.clone());
259 let storage = B::zeros(&shape, dtype, device)?;
260 Ok(Self::from_storage(
261 storage,
262 layout,
263 dtype,
264 device.clone(),
265 Op::None,
266 ))
267 }
268
269 pub fn ones(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
271 let shape = shape.into();
272 let layout = Layout::contiguous(shape.clone());
273 let storage = B::ones(&shape, dtype, device)?;
274 Ok(Self::from_storage(
275 storage,
276 layout,
277 dtype,
278 device.clone(),
279 Op::None,
280 ))
281 }
282
283 pub fn full(
285 shape: impl Into<Shape>,
286 val: f64,
287 dtype: DType,
288 device: &B::Device,
289 ) -> Result<Self> {
290 let shape = shape.into();
291 let layout = Layout::contiguous(shape.clone());
292 let storage = B::full(&shape, val, dtype, device)?;
293 Ok(Self::from_storage(
294 storage,
295 layout,
296 dtype,
297 device.clone(),
298 Op::None,
299 ))
300 }
301
302 pub fn from_f64_slice(
305 data: &[f64],
306 shape: impl Into<Shape>,
307 dtype: DType,
308 device: &B::Device,
309 ) -> Result<Self> {
310 let shape = shape.into();
311 if data.len() != shape.elem_count() {
312 return Err(Error::ElementCountMismatch {
313 shape: shape.clone(),
314 expected: shape.elem_count(),
315 got: data.len(),
316 });
317 }
318 let layout = Layout::contiguous(shape);
319 let storage = B::from_f64_slice(data, dtype, device)?;
320 Ok(Self::from_storage(
321 storage,
322 layout,
323 dtype,
324 device.clone(),
325 Op::None,
326 ))
327 }
328
329 pub fn rand(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
331 let shape = shape.into();
332 let layout = Layout::contiguous(shape.clone());
333 let storage = B::rand_uniform(&shape, dtype, device)?;
334 Ok(Self::from_storage(
335 storage,
336 layout,
337 dtype,
338 device.clone(),
339 Op::None,
340 ))
341 }
342
343 pub fn randn(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
345 let shape = shape.into();
346 let layout = Layout::contiguous(shape.clone());
347 let storage = B::rand_normal(&shape, dtype, device)?;
348 Ok(Self::from_storage(
349 storage,
350 layout,
351 dtype,
352 device.clone(),
353 Op::None,
354 ))
355 }
356
357 pub fn linspace(
364 start: f64,
365 end: f64,
366 steps: usize,
367 dtype: DType,
368 device: &B::Device,
369 ) -> Result<Self> {
370 if steps == 0 {
371 return Err(Error::msg("linspace requires steps >= 1"));
372 }
373 if steps == 1 {
374 return Self::from_f64_slice(&[start], 1, dtype, device);
375 }
376 let step = (end - start) / (steps as f64 - 1.0);
377 let data: Vec<f64> = (0..steps).map(|i| start + step * i as f64).collect();
378 Self::from_f64_slice(&data, steps, dtype, device)
379 }
380
381 pub fn eye(n: usize, dtype: DType, device: &B::Device) -> Result<Self> {
390 let mut data = vec![0.0f64; n * n];
391 for i in 0..n {
392 data[i * n + i] = 1.0;
393 }
394 Self::from_f64_slice(&data, (n, n), dtype, device)
395 }
396
    /// A zeros tensor with the same shape, dtype, and device as `other`.
    pub fn zeros_like(other: &Self) -> Result<Self> {
        Self::zeros(other.shape().clone(), other.dtype(), other.device())
    }

    /// A ones tensor with the same shape, dtype, and device as `other`.
    pub fn ones_like(other: &Self) -> Result<Self> {
        Self::ones(other.shape().clone(), other.dtype(), other.device())
    }

    /// A constant tensor with the same shape, dtype, and device as `other`.
    pub fn full_like(other: &Self, val: f64) -> Result<Self> {
        Self::full(other.shape().clone(), val, other.dtype(), other.device())
    }
411
412 pub fn set_variable(self) -> Self {
415 Tensor {
416 inner: Arc::new(TensorInner {
417 id: self.inner.id,
418 storage: Arc::clone(&self.inner.storage),
419 layout: self.inner.layout.clone(),
420 dtype: self.inner.dtype,
421 device: self.inner.device.clone(),
422 op: self.inner.op.clone(),
423 is_variable: true,
424 }),
425 }
426 }
427
428 pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Self> {
432 let new_layout = self.inner.layout.transpose(dim0, dim1)?;
433 let op = Op::Transpose {
434 input: self.clone(),
435 dim0,
436 dim1,
437 };
438 Ok(self.view_with_layout(new_layout, op))
439 }
440
441 pub fn t(&self) -> Result<Self> {
443 if self.rank() != 2 {
444 return Err(Error::RankMismatch {
445 expected: 2,
446 got: self.rank(),
447 });
448 }
449 self.transpose(0, 1)
450 }
451
452 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
454 let new_layout = self.inner.layout.narrow(dim, start, len)?;
455 let op = Op::Narrow {
456 input: self.clone(),
457 dim,
458 start,
459 len,
460 };
461 Ok(self.view_with_layout(new_layout, op))
462 }
463
    /// Reshapes to `new_shape`, which must have the same element count.
    ///
    /// Zero-copy when the tensor is already contiguous; otherwise the data is
    /// materialized via `contiguous()` first.
    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Self> {
        let new_shape = new_shape.into();
        let current_count = self.elem_count();
        let new_count = new_shape.elem_count();
        if current_count != new_count {
            return Err(Error::ReshapeElementMismatch {
                src: current_count,
                dst: new_count,
                dst_shape: new_shape,
            });
        }
        // Reshape is only a pure relabeling of a contiguous buffer.
        let tensor = if self.is_contiguous() {
            self.clone()
        } else {
            self.contiguous()?
        };
        let src_shape = tensor.shape().clone();
        let new_layout = Layout::contiguous(new_shape);
        // The op records the (possibly contiguous-ized) input and its shape
        // so gradients can be reshaped back.
        let op = Op::Reshape {
            input: tensor.clone(),
            src_shape,
        };
        Ok(tensor.view_with_layout(new_layout, op))
    }
491
492 pub fn contiguous(&self) -> Result<Self> {
496 if self.is_contiguous() {
497 return Ok(self.clone());
498 }
499 let storage = self.read_storage()?;
500 let new_storage = B::to_contiguous(&storage, &self.inner.layout)?;
501 let new_layout = Layout::contiguous(self.shape().clone());
502 Ok(Self::from_storage(
503 new_storage,
504 new_layout,
505 self.inner.dtype,
506 self.inner.device.clone(),
507 Op::Contiguous {
508 input: self.clone(),
509 },
510 ))
511 }
512
    /// Inserts a size-1 dimension at `dim` (`0..=rank`) as a zero-copy view.
    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
        let rank = self.rank();
        if dim > rank {
            return Err(Error::DimOutOfRange {
                dim,
                rank: rank + 1,
            });
        }
        let mut new_dims = self.dims().to_vec();
        let mut new_strides = self.layout().strides().to_vec();
        // A size-1 dim's stride never affects addressing; reuse the stride at
        // `dim` when inserting in the middle, or 1 for a trailing insert.
        let stride_val = if dim < rank { new_strides[dim] } else { 1 };
        new_dims.insert(dim, 1);
        new_strides.insert(dim, stride_val);
        let new_layout = Layout::new(Shape::new(new_dims), new_strides, self.layout().offset());
        // Element order is unchanged, so this is recorded as a Reshape.
        let op = Op::Reshape {
            input: self.clone(),
            src_shape: self.shape().clone(),
        };
        Ok(self.view_with_layout(new_layout, op))
    }
538
539 pub fn squeeze_all(&self) -> Self {
542 let new_dims: Vec<usize> = self.dims().iter().copied().filter(|&d| d != 1).collect();
543 let new_strides: Vec<usize> = self
544 .dims()
545 .iter()
546 .zip(self.layout().strides().iter())
547 .filter(|(&d, _)| d != 1)
548 .map(|(_, &s)| s)
549 .collect();
550 let new_layout = Layout::new(
551 Shape::new(if new_dims.is_empty() {
552 vec![]
553 } else {
554 new_dims
555 }),
556 new_strides,
557 self.layout().offset(),
558 );
559 let op = Op::Reshape {
560 input: self.clone(),
561 src_shape: self.shape().clone(),
562 };
563 self.view_with_layout(new_layout, op)
564 }
565
566 pub fn squeeze(&self, dim: usize) -> Result<Self> {
572 let rank = self.rank();
573 if dim >= rank {
574 return Err(Error::DimOutOfRange { dim, rank });
575 }
576 if self.dims()[dim] != 1 {
577 return Err(Error::msg(format!(
578 "squeeze: dimension {} has size {}, expected 1",
579 dim,
580 self.dims()[dim]
581 )));
582 }
583 let mut new_dims = self.dims().to_vec();
584 let mut new_strides = self.layout().strides().to_vec();
585 new_dims.remove(dim);
586 new_strides.remove(dim);
587 let new_layout = Layout::new(
588 Shape::new(if new_dims.is_empty() {
589 vec![]
590 } else {
591 new_dims
592 }),
593 new_strides,
594 self.layout().offset(),
595 );
596 let op = Op::Reshape {
597 input: self.clone(),
598 src_shape: self.shape().clone(),
599 };
600 Ok(self.view_with_layout(new_layout, op))
601 }
602
    /// Reorders dimensions according to `dims` (a permutation of `0..rank`)
    /// as a zero-copy view.
    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
        let rank = self.rank();
        if dims.len() != rank {
            return Err(Error::msg(format!(
                "permute: expected {} dimensions, got {}",
                rank,
                dims.len()
            )));
        }
        // Validate `dims` is a true permutation: in range, no duplicates.
        let mut seen = vec![false; rank];
        for &d in dims {
            if d >= rank {
                return Err(Error::DimOutOfRange { dim: d, rank });
            }
            if seen[d] {
                return Err(Error::msg(format!("permute: duplicate dimension {}", d)));
            }
            seen[d] = true;
        }

        let old_dims = self.dims();
        let old_strides = self.layout().strides();
        let new_dims: Vec<usize> = dims.iter().map(|&d| old_dims[d]).collect();
        let new_strides: Vec<usize> = dims.iter().map(|&d| old_strides[d]).collect();
        let new_layout = Layout::new(Shape::new(new_dims), new_strides, self.layout().offset());
        // NOTE(review): this records `Op::Reshape`, but unlike
        // squeeze/unsqueeze a permute changes logical element order, so a
        // reshape-based backward pass would mis-route gradients — confirm
        // against the grad code whether a dedicated Permute/Transpose op is
        // needed here.
        let op = Op::Reshape {
            input: self.clone(),
            src_shape: self.shape().clone(),
        };
        Ok(self.view_with_layout(new_layout, op))
    }
643
    /// Cumulative sum along `dim`, computed on the host in f64.
    ///
    /// The result does not record an op, so it is detached from the autograd
    /// graph (it is built via `from_f64_slice`).
    pub fn cumsum(&self, dim: usize) -> Result<Self> {
        let rank = self.rank();
        if dim >= rank {
            return Err(Error::DimOutOfRange { dim, rank });
        }
        let t = self.contiguous()?;
        let data = t.to_f64_vec()?;
        let shape = t.shape().clone();
        let dims = shape.dims();
        let mut out = data.clone();

        // Flat index decomposition for a contiguous buffer:
        // index = (outer * dim_size + d) * inner + i.
        let inner: usize = dims[dim + 1..].iter().product();
        let outer: usize = dims[..dim].iter().product();
        let dim_size = dims[dim];

        // Running sum along each 1-D lane of `dim`.
        for o in 0..outer {
            for i in 0..inner {
                for d in 1..dim_size {
                    let idx = (o * dim_size + d) * inner + i;
                    let prev = (o * dim_size + d - 1) * inner + i;
                    out[idx] += out[prev];
                }
            }
        }

        Self::from_f64_slice(&out, shape, t.dtype(), t.device())
    }
678
    /// Sorts values along `dim`, returning `(sorted_values, indices)`.
    ///
    /// `indices` holds each sorted element's original position along `dim`,
    /// stored as f64 in a tensor with the input's dtype. Computed on the
    /// host; the results are detached from the autograd graph.
    /// Incomparable values (NaN) are treated as equal, so their order is
    /// unspecified.
    pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self)> {
        let rank = self.rank();
        if dim >= rank {
            return Err(Error::DimOutOfRange { dim, rank });
        }
        let t = self.contiguous()?;
        let data = t.to_f64_vec()?;
        let shape = t.shape().clone();
        let dims = shape.dims();
        let dim_size = dims[dim];
        // Flat index decomposition: (o * dim_size + d) * inner + i.
        let inner: usize = dims[dim + 1..].iter().product();
        let outer: usize = dims[..dim].iter().product();

        let mut sorted_data = data.clone();
        let mut indices = vec![0.0f64; data.len()];

        // Sort each 1-D lane along `dim` independently.
        for o in 0..outer {
            for i in 0..inner {
                let mut slice: Vec<(f64, usize)> = (0..dim_size)
                    .map(|d| {
                        let idx = (o * dim_size + d) * inner + i;
                        (data[idx], d)
                    })
                    .collect();

                if descending {
                    slice
                        .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
                } else {
                    slice
                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                }

                // Scatter sorted values and their source positions back.
                for (d, (val, orig_idx)) in slice.into_iter().enumerate() {
                    let idx = (o * dim_size + d) * inner + i;
                    sorted_data[idx] = val;
                    indices[idx] = orig_idx as f64;
                }
            }
        }

        let vals = Self::from_f64_slice(&sorted_data, shape.clone(), t.dtype(), t.device())?;
        let idxs = Self::from_f64_slice(&indices, shape, t.dtype(), t.device())?;
        Ok((vals, idxs))
    }
730
731 pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self> {
737 let (_, indices) = self.sort(dim, descending)?;
738 Ok(indices)
739 }
740
    /// Elementwise addition (with broadcasting, see `binary_op`).
    pub fn add(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Add)
    }

    /// Elementwise subtraction.
    pub fn sub(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Sub)
    }

    /// Elementwise multiplication.
    pub fn mul(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Mul)
    }

    /// Elementwise division.
    pub fn div(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Div)
    }
762
    /// Shared implementation for elementwise binary arithmetic.
    ///
    /// Requires matching dtypes; the result shape is the broadcast of the two
    /// operand shapes, and the op is recorded for autograd.
    fn binary_op(&self, rhs: &Self, op: BinaryOp) -> Result<Self> {
        if self.dtype() != rhs.dtype() {
            return Err(Error::DTypeMismatch {
                expected: self.dtype(),
                got: rhs.dtype(),
            });
        }
        let storage_lhs = self.read_storage()?;
        let storage_rhs = rhs.read_storage()?;
        let result = B::binary_op(
            op,
            &storage_lhs,
            &self.inner.layout,
            &storage_rhs,
            &rhs.inner.layout,
        )?;
        // Computed after the backend call; an incompatible-shape error is
        // expected to surface from the backend first.
        let result_shape = Shape::broadcast_shape(self.shape(), rhs.shape())?;
        let result_layout = Layout::contiguous(result_shape);
        let result_op = Op::Binary {
            lhs: self.clone(),
            rhs: rhs.clone(),
            op,
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            self.inner.dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
796
    /// Elementwise equality; returns a U8 mask (see `cmp_op`).
    pub fn eq(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Eq)
    }

    /// Elementwise inequality mask.
    pub fn ne(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Ne)
    }

    /// Elementwise greater-than mask.
    pub fn gt(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Gt)
    }

    /// Elementwise greater-or-equal mask.
    pub fn ge(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Ge)
    }

    /// Elementwise less-than mask.
    pub fn lt(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Lt)
    }

    /// Elementwise less-or-equal mask.
    pub fn le(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Le)
    }
828
    /// Shared implementation for elementwise comparisons.
    ///
    /// The result is a U8 mask with the broadcast shape, recorded with
    /// `Op::None` (comparisons are treated as non-differentiable).
    ///
    /// NOTE(review): unlike `binary_op`, there is no dtype-mismatch check
    /// here — presumably the backend tolerates or rejects it; confirm.
    fn cmp_op(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
        let storage_lhs = self.read_storage()?;
        let storage_rhs = rhs.read_storage()?;
        let result = B::cmp_op(
            op,
            &storage_lhs,
            &self.inner.layout,
            &storage_rhs,
            &rhs.inner.layout,
        )?;
        let result_shape = Shape::broadcast_shape(self.shape(), rhs.shape())?;
        let result_layout = Layout::contiguous(result_shape);
        Ok(Self::from_storage(
            result,
            result_layout,
            DType::U8,
            self.inner.device.clone(),
            Op::None,
        ))
    }
852
    /// Elementwise negation.
    pub fn neg(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Neg)
    }

    /// Elementwise absolute value.
    pub fn abs(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Abs)
    }

    /// Elementwise exponential.
    pub fn exp(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Exp)
    }

    /// Elementwise natural logarithm.
    pub fn log(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Log)
    }

    /// Elementwise square root.
    pub fn sqrt(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sqrt)
    }

    /// Elementwise square.
    pub fn square(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Square)
    }

    /// Rectified linear unit.
    pub fn relu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Relu)
    }

    /// Logistic sigmoid.
    pub fn sigmoid(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sigmoid)
    }

    /// Hyperbolic tangent.
    pub fn tanh(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Tanh)
    }

    /// Gaussian error linear unit.
    pub fn gelu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Gelu)
    }

    /// Sigmoid linear unit (swish).
    pub fn silu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Silu)
    }

    /// Elementwise sine.
    pub fn sin(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sin)
    }

    /// Elementwise cosine.
    pub fn cos(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Cos)
    }

    /// Elementwise floor.
    pub fn floor(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Floor)
    }

    /// Elementwise ceiling.
    pub fn ceil(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Ceil)
    }

    /// Elementwise rounding.
    pub fn round(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Round)
    }
934
935 pub fn powf(&self, exponent: f64) -> Result<Self> {
937 let storage = self.read_storage()?;
938 let result = B::powf(&storage, &self.inner.layout, exponent)?;
939 let result_layout = Layout::contiguous(self.shape().clone());
940 let result_op = Op::Powf {
941 input: self.clone(),
942 exponent,
943 };
944 Ok(Self::from_storage(
945 result,
946 result_layout,
947 self.inner.dtype,
948 self.inner.device.clone(),
949 result_op,
950 ))
951 }
952
953 pub fn clamp(&self, min: f64, max: f64) -> Result<Self> {
955 let storage = self.read_storage()?;
956 let result = B::clamp(&storage, &self.inner.layout, min, max)?;
957 let result_layout = Layout::contiguous(self.shape().clone());
958 let result_op = Op::Clamp {
959 input: self.clone(),
960 min,
961 max,
962 };
963 Ok(Self::from_storage(
964 result,
965 result_layout,
966 self.inner.dtype,
967 self.inner.device.clone(),
968 result_op,
969 ))
970 }
971
    /// Elementwise select: where `mask` is non-zero take `on_true`, else
    /// `on_false`.
    ///
    /// NOTE(review): the result layout uses `on_true`'s shape, which assumes
    /// the backend output has exactly that shape (i.e. no broadcasting beyond
    /// it) — confirm the backend contract.
    pub fn where_cond(mask: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
        let mask_s = mask.read_storage()?;
        let true_s = on_true.read_storage()?;
        let false_s = on_false.read_storage()?;
        let result = B::where_cond(
            &mask_s,
            &mask.inner.layout,
            &true_s,
            &on_true.inner.layout,
            &false_s,
            &on_false.inner.layout,
        )?;
        let result_layout = Layout::contiguous(on_true.shape().clone());
        // All three inputs are recorded so gradients can flow into both
        // branches.
        let result_op = Op::WhereCond {
            mask: mask.clone(),
            on_true: on_true.clone(),
            on_false: on_false.clone(),
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            on_true.inner.dtype,
            on_true.inner.device.clone(),
            result_op,
        ))
    }
1002
    /// Gathers values along `dim` using `index`; the result takes `index`'s
    /// shape and this tensor's dtype.
    pub fn gather(&self, dim: usize, index: &Self) -> Result<Self> {
        let input_s = self.read_storage()?;
        let index_s = index.read_storage()?;
        let result = B::gather(
            &input_s,
            &self.inner.layout,
            &index_s,
            &index.inner.layout,
            dim,
        )?;
        // Output shape mirrors the index tensor.
        let result_layout = Layout::contiguous(index.shape().clone());
        let result_op = Op::Gather {
            input: self.clone(),
            index: index.clone(),
            dim,
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            self.inner.dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
1033
1034 pub fn masked_fill(&self, mask: &Self, value: f64) -> Result<Self> {
1040 let fill = Self::full(self.shape().clone(), value, self.dtype(), self.device())?;
1041 Self::where_cond(mask, &fill, self)
1042 }
1043
1044 pub fn pad(&self, padding: &[[usize; 2]], value: f64) -> Result<Self> {
1051 let rank = self.rank();
1052 if padding.len() > rank {
1053 return Err(Error::msg(format!(
1054 "pad: {} padding pairs but tensor rank is {}",
1055 padding.len(),
1056 rank
1057 )));
1058 }
1059
1060 let mut full_pad = vec![[0usize; 2]; rank];
1062 let offset = rank - padding.len();
1063 for (i, p) in padding.iter().enumerate() {
1064 full_pad[offset + i] = *p;
1065 }
1066
1067 let in_dims = self.dims();
1069 let out_dims: Vec<usize> = in_dims
1070 .iter()
1071 .zip(full_pad.iter())
1072 .map(|(&d, &[b, a])| d + b + a)
1073 .collect();
1074
1075 if full_pad.iter().all(|&[b, a]| b == 0 && a == 0) {
1077 return Ok(self.clone());
1078 }
1079
1080 let mut current = self.clone();
1082 for d in (0..rank).rev() {
1083 let [before, after] = full_pad[d];
1084 if before == 0 && after == 0 {
1085 continue;
1086 }
1087 let mut cur_dims = current.dims().to_vec();
1088 let mut parts: Vec<Self> = Vec::new();
1089
1090 if before > 0 {
1091 cur_dims[d] = before;
1092 let pad_before = Self::full(
1093 Shape::new(cur_dims.clone()),
1094 value,
1095 self.dtype(),
1096 self.device(),
1097 )?;
1098 cur_dims[d] = current.dims()[d]; parts.push(pad_before);
1100 }
1101 parts.push(current);
1102 if after > 0 {
1103 cur_dims[d] = after;
1104 let pad_after = Self::full(
1105 Shape::new(cur_dims.clone()),
1106 value,
1107 self.dtype(),
1108 self.device(),
1109 )?;
1110 parts.push(pad_after);
1111 }
1112 current = Self::cat(&parts, d)?;
1113 }
1114
1115 let result_layout = Layout::contiguous(Shape::new(out_dims));
1117 let pad_op = Op::Pad {
1118 input: self.clone(),
1119 padding: full_pad,
1120 };
1121
1122 let storage = current.read_storage()?;
1124 Ok(Self::from_storage(
1125 storage.clone(),
1126 result_layout,
1127 self.inner.dtype,
1128 self.inner.device.clone(),
1129 pad_op,
1130 ))
1131 }
1132
    #[allow(clippy::needless_range_loop)]
    /// Returns the `k` largest values along `dim` plus their flat source
    /// positions along that dim.
    ///
    /// Computed on the host in f64; the values tensor is detached from the
    /// autograd graph, and indices are returned as a plain `Vec<usize>`
    /// (flattened in output order) rather than a tensor.
    pub fn topk(&self, k: usize, dim: usize) -> Result<(Self, Vec<usize>)> {
        if dim >= self.rank() {
            return Err(Error::DimOutOfRange {
                dim,
                rank: self.rank(),
            });
        }
        let dims = self.dims();
        let dim_size = dims[dim];
        if k > dim_size {
            return Err(Error::msg(format!(
                "topk: k={} exceeds dim {} size {}",
                k, dim, dim_size
            )));
        }

        let data = self.contiguous()?.to_f64_vec()?;

        // Output shape is the input with `dim` shrunk to k.
        let mut out_dims = dims.to_vec();
        out_dims[dim] = k;
        let out_size: usize = out_dims.iter().product();
        let mut out_values = vec![0.0f64; out_size];
        let mut out_indices = vec![0usize; out_size];

        // Flat index decomposition: o * (dim_size * inner) + d * inner + i.
        let outer: usize = dims[..dim].iter().product();
        let inner: usize = dims[dim + 1..].iter().product();

        for o in 0..outer {
            for i in 0..inner {
                // Collect the lane, sort descending, keep the first k.
                let mut slice: Vec<(f64, usize)> = (0..dim_size)
                    .map(|d| {
                        let flat = o * (dim_size * inner) + d * inner + i;
                        (data[flat], d)
                    })
                    .collect();
                slice.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

                for j in 0..k {
                    let out_flat = o * (k * inner) + j * inner + i;
                    out_values[out_flat] = slice[j].0;
                    out_indices[out_flat] = slice[j].1;
                }
            }
        }

        let values = Self::from_f64_slice(
            &out_values,
            Shape::new(out_dims),
            self.dtype(),
            self.device(),
        )?;
        Ok((values, out_indices))
    }
1198
1199 fn unary_op(&self, op: UnaryOp) -> Result<Self> {
1201 let storage = self.read_storage()?;
1202 let result = B::unary_op(op, &storage, &self.inner.layout)?;
1203 let result_layout = Layout::contiguous(self.shape().clone());
1204 let result_op = Op::Unary {
1205 input: self.clone(),
1206 op,
1207 };
1208 Ok(Self::from_storage(
1209 result,
1210 result_layout,
1211 self.inner.dtype,
1212 self.inner.device.clone(),
1213 result_op,
1214 ))
1215 }
1216
    /// Sum of all elements (scalar result).
    pub fn sum_all(&self) -> Result<Self> {
        self.reduce_op(ReduceOp::Sum, &[], false)
    }

    /// Sum along `dim`.
    pub fn sum(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::Sum, &[dim], keep_dim)
    }

    /// Mean of all elements (scalar result).
    pub fn mean_all(&self) -> Result<Self> {
        self.reduce_op(ReduceOp::Mean, &[], false)
    }

    /// Mean along `dim`.
    pub fn mean(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::Mean, &[dim], keep_dim)
    }

    /// Maximum along `dim`.
    pub fn max(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::Max, &[dim], keep_dim)
    }

    /// Minimum along `dim`.
    pub fn min(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::Min, &[dim], keep_dim)
    }

    /// Index of the maximum along `dim` (I64 result).
    pub fn argmax(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::ArgMax, &[dim], keep_dim)
    }

    /// Index of the minimum along `dim` (I64 result).
    pub fn argmin(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.reduce_op(ReduceOp::ArgMin, &[dim], keep_dim)
    }
1258
    /// Shared implementation for reductions.
    ///
    /// An empty `dims` means "reduce everything" and yields a scalar.
    /// With `keep_dim` the reduced dims become size 1; otherwise they are
    /// dropped. ArgMax/ArgMin produce I64; other ops keep the input dtype.
    fn reduce_op(&self, op: ReduceOp, dims: &[usize], keep_dim: bool) -> Result<Self> {
        // Validate dims up front, before touching the backend.
        for &d in dims {
            if d >= self.rank() {
                return Err(Error::DimOutOfRange {
                    dim: d,
                    rank: self.rank(),
                });
            }
        }
        let storage = self.read_storage()?;
        let result = B::reduce_op(op, &storage, &self.inner.layout, dims, keep_dim)?;

        let result_shape = if dims.is_empty() {
            // Full reduction: scalar shape.
            Shape::from(())
        } else if keep_dim {
            // Reduced dims collapse to size 1.
            let mut new_dims = self.dims().to_vec();
            for &d in dims {
                new_dims[d] = 1;
            }
            Shape::new(new_dims)
        } else {
            // Reduced dims are removed; reducing away all dims yields a
            // scalar.
            let new_dims: Vec<usize> = self
                .dims()
                .iter()
                .enumerate()
                .filter(|(i, _)| !dims.contains(i))
                .map(|(_, &d)| d)
                .collect();
            if new_dims.is_empty() {
                Shape::from(())
            } else {
                Shape::new(new_dims)
            }
        };

        let result_layout = Layout::contiguous(result_shape);
        let result_dtype = match op {
            ReduceOp::ArgMax | ReduceOp::ArgMin => DType::I64,
            _ => self.inner.dtype,
        };
        let result_op = Op::Reduce {
            input: self.clone(),
            op,
            dims: dims.to_vec(),
            keep_dim,
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            result_dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
1317
1318 pub fn softmax(&self, dim: usize) -> Result<Self> {
1326 let max_val = self.max(dim, true)?;
1328 let max_detached = max_val.detach();
1330 let shifted = self.sub(&max_detached)?; let exp_x = shifted.exp()?;
1332 let sum_exp = exp_x.sum(dim, true)?;
1333 exp_x.div(&sum_exp)
1334 }
1335
1336 pub fn log_softmax(&self, dim: usize) -> Result<Self> {
1340 let max_val = self.max(dim, true)?.detach();
1341 let shifted = self.sub(&max_val)?;
1342 let exp_x = shifted.exp()?;
1343 let sum_exp = exp_x.sum(dim, true)?;
1344 let log_sum_exp = sum_exp.log()?;
1345 shifted.sub(&log_sum_exp)
1346 }
1347
1348 pub fn var(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1350 let mu = self.mean(dim, true)?;
1351 let centered = self.sub(&mu)?;
1352 let sq = centered.square()?;
1353 sq.mean(dim, keep_dim)
1354 }
1355
1356 pub fn cat(tensors: &[Self], dim: usize) -> Result<Self> {
1361 if tensors.is_empty() {
1362 return Err(Error::msg("cat: empty tensor list"));
1363 }
1364 if tensors.len() == 1 {
1365 return Ok(tensors[0].clone());
1366 }
1367
1368 let first = &tensors[0];
1369 let rank = first.rank();
1370 if dim >= rank {
1371 return Err(Error::DimOutOfRange { dim, rank });
1372 }
1373
1374 for (i, t) in tensors.iter().enumerate().skip(1) {
1376 if t.rank() != rank {
1377 return Err(Error::msg(format!(
1378 "cat: tensor {} has rank {} but expected {}",
1379 i,
1380 t.rank(),
1381 rank
1382 )));
1383 }
1384 if t.dtype() != first.dtype() {
1385 return Err(Error::DTypeMismatch {
1386 expected: first.dtype(),
1387 got: t.dtype(),
1388 });
1389 }
1390 for d in 0..rank {
1391 if d != dim && t.dims()[d] != first.dims()[d] {
1392 return Err(Error::msg(format!(
1393 "cat: tensor {} has size {} at dim {} but expected {}",
1394 i,
1395 t.dims()[d],
1396 d,
1397 first.dims()[d]
1398 )));
1399 }
1400 }
1401 }
1402
1403 let cat_size: usize = tensors.iter().map(|t| t.dims()[dim]).sum();
1405 let mut out_dims = first.dims().to_vec();
1406 out_dims[dim] = cat_size;
1407 let out_shape = Shape::new(out_dims.clone());
1408
1409 let sizes: Vec<usize> = tensors.iter().map(|t| t.dims()[dim]).collect();
1411
1412 let inner_guards: Vec<_> = tensors
1414 .iter()
1415 .map(|t| t.inner.storage.read().unwrap())
1416 .collect();
1417 let pairs: Vec<(&B::Storage, &Layout)> = tensors
1418 .iter()
1419 .enumerate()
1420 .map(|(i, t)| (&*inner_guards[i], &t.inner.layout))
1421 .collect();
1422
1423 let storage = B::cat(&pairs, &out_shape, dim)?;
1424 let layout = Layout::contiguous(out_shape);
1425 let op = Op::Cat {
1426 inputs: tensors.to_vec(),
1427 dim,
1428 sizes,
1429 };
1430 Ok(Self::from_storage(
1431 storage,
1432 layout,
1433 first.dtype(),
1434 first.device().clone(),
1435 op,
1436 ))
1437 }
1438
1439 pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Self>> {
1442 if dim >= self.rank() {
1443 return Err(Error::DimOutOfRange {
1444 dim,
1445 rank: self.rank(),
1446 });
1447 }
1448 let dim_size = self.dims()[dim];
1449 let chunk_size = dim_size.div_ceil(n);
1450 let mut chunks = Vec::new();
1451 let mut start = 0;
1452 while start < dim_size {
1453 let len = chunk_size.min(dim_size - start);
1454 chunks.push(self.narrow(dim, start, len)?);
1455 start += len;
1456 }
1457 Ok(chunks)
1458 }
1459
    /// Broadcasts size-1 dimensions up to `target_shape` as a zero-copy view
    /// (expanded dims get stride 0). Ranks must match.
    ///
    /// NOTE(review): the result is recorded with `Op::None`, so gradients do
    /// not flow back through an expand — confirm whether that is intentional
    /// or a missing Broadcast op.
    pub fn expand(&self, target_shape: impl Into<Shape>) -> Result<Self> {
        let target = target_shape.into();
        let self_dims = self.dims();
        let target_dims = target.dims();

        if self_dims.len() != target_dims.len() {
            return Err(Error::msg(format!(
                "expand: rank mismatch — self {:?} vs target {:?}",
                self_dims, target_dims
            )));
        }

        // Only size-1 dims may grow; other dims must match exactly.
        for (i, (&sd, &td)) in self_dims.iter().zip(target_dims.iter()).enumerate() {
            if sd != td && sd != 1 {
                return Err(Error::msg(format!(
                    "expand: can only expand size-1 dims, but dim {} has size {}",
                    i, sd
                )));
            }
        }

        // Stride 0 makes every index along an expanded dim read the same
        // element — broadcasting without copying.
        let self_strides = self.inner.layout.strides();
        let mut new_strides = self_strides.to_vec();
        for d in 0..target_dims.len() {
            if self_dims[d] == 1 && target_dims[d] > 1 {
                new_strides[d] = 0;
            }
        }

        let new_layout = Layout::new(target, new_strides, self.inner.layout.offset());

        Ok(Tensor {
            inner: Arc::new(TensorInner {
                id: TensorId::new(),
                storage: Arc::clone(&self.inner.storage),
                layout: new_layout,
                dtype: self.inner.dtype,
                device: self.inner.device.clone(),
                op: Op::None,
                is_variable: false,
            }),
        })
    }
1510
1511 pub fn stack(tensors: &[Self], dim: usize) -> Result<Self> {
1518 if tensors.is_empty() {
1519 return Err(Error::msg("stack: empty tensor list"));
1520 }
1521 let first_shape = tensors[0].shape().clone();
1522 let rank = first_shape.dims().len();
1523 if dim > rank {
1524 return Err(Error::DimOutOfRange {
1525 dim,
1526 rank: rank + 1,
1527 });
1528 }
1529 for (i, t) in tensors.iter().enumerate().skip(1) {
1531 if t.shape() != &first_shape {
1532 return Err(Error::msg(format!(
1533 "stack: tensor {} has shape {:?} but expected {:?}",
1534 i,
1535 t.dims(),
1536 first_shape.dims()
1537 )));
1538 }
1539 }
1540 let unsqueezed: Vec<Self> = tensors
1542 .iter()
1543 .map(|t| t.unsqueeze(dim))
1544 .collect::<Result<Vec<_>>>()?;
1545 Self::cat(&unsqueezed, dim)
1546 }
1547
1548 pub fn arange(n: usize, dtype: DType, device: &B::Device) -> Result<Self> {
1552 let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
1553 Self::from_f64_slice(&data, n, dtype, device)
1554 }
1555
1556 pub fn arange_step(
1558 start: f64,
1559 end: f64,
1560 step: f64,
1561 dtype: DType,
1562 device: &B::Device,
1563 ) -> Result<Self> {
1564 if step == 0.0 {
1565 return Err(Error::msg("arange_step: step cannot be zero"));
1566 }
1567 let mut data = Vec::new();
1568 let mut v = start;
1569 if step > 0.0 {
1570 while v < end {
1571 data.push(v);
1572 v += step;
1573 }
1574 } else {
1575 while v > end {
1576 data.push(v);
1577 v += step;
1578 }
1579 }
1580 let len = data.len();
1581 Self::from_f64_slice(&data, len, dtype, device)
1582 }
1583
1584 pub fn triu(
1591 n: usize,
1592 m: usize,
1593 diagonal: i64,
1594 dtype: DType,
1595 device: &B::Device,
1596 ) -> Result<Self> {
1597 let mut data = vec![0.0f64; n * m];
1598 for i in 0..n {
1599 for j in 0..m {
1600 if (j as i64) >= (i as i64) + diagonal {
1601 data[i * m + j] = 1.0;
1602 }
1603 }
1604 }
1605 Self::from_f64_slice(&data, (n, m), dtype, device)
1606 }
1607
1608 pub fn tril(
1611 n: usize,
1612 m: usize,
1613 diagonal: i64,
1614 dtype: DType,
1615 device: &B::Device,
1616 ) -> Result<Self> {
1617 let mut data = vec![0.0f64; n * m];
1618 for i in 0..n {
1619 for j in 0..m {
1620 if (j as i64) <= (i as i64) + diagonal {
1621 data[i * m + j] = 1.0;
1622 }
1623 }
1624 }
1625 Self::from_f64_slice(&data, (n, m), dtype, device)
1626 }
1627
    /// Matrix multiplication contracting `self`'s last dim with `rhs`'s
    /// second-to-last dim.
    ///
    /// Result shape: `self`'s leading (batch) dims followed by `[m, n]`.
    /// NOTE(review): leading batch dims are taken from `self` only and are
    /// not checked against `rhs` here — presumably `B::matmul` validates or
    /// broadcasts them; confirm in the backend.
    ///
    /// # Errors
    /// `DTypeMismatch` when dtypes differ, `RankMismatch` when either
    /// operand has rank < 2, `MatmulShapeMismatch` when the contraction
    /// dims disagree.
    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
        if self.dtype() != rhs.dtype() {
            return Err(Error::DTypeMismatch {
                expected: self.dtype(),
                got: rhs.dtype(),
            });
        }
        if self.rank() < 2 || rhs.rank() < 2 {
            return Err(Error::RankMismatch {
                expected: 2,
                got: self.rank().min(rhs.rank()),
            });
        }
        let lhs_dims = self.dims();
        let rhs_dims = rhs.dims();
        // Contraction dims: k1 = lhs columns, k2 = rhs rows.
        let k1 = lhs_dims[lhs_dims.len() - 1];
        let k2 = rhs_dims[rhs_dims.len() - 2];
        if k1 != k2 {
            let m = lhs_dims[lhs_dims.len() - 2];
            let n = rhs_dims[rhs_dims.len() - 1];
            return Err(Error::MatmulShapeMismatch { m, k1, k2, n });
        }

        let storage_lhs = self.read_storage()?;
        let storage_rhs = rhs.read_storage()?;
        let result = B::matmul(
            &storage_lhs,
            &self.inner.layout,
            &storage_rhs,
            &rhs.inner.layout,
        )?;

        // Result shape: lhs batch dims + [m, n]; Op::Matmul records both
        // operands for the backward pass.
        let m = lhs_dims[lhs_dims.len() - 2];
        let n = rhs_dims[rhs_dims.len() - 1];
        let mut result_dims: Vec<usize> = lhs_dims[..lhs_dims.len() - 2].to_vec();
        result_dims.push(m);
        result_dims.push(n);
        let result_layout = Layout::contiguous(Shape::new(result_dims));
        let result_op = Op::Matmul {
            lhs: self.clone(),
            rhs: rhs.clone(),
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            self.inner.dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
1686
1687 #[allow(clippy::needless_range_loop)]
1700 pub fn conv2d(
1701 &self,
1702 weight: &Self,
1703 bias: Option<&Self>,
1704 stride: [usize; 2],
1705 padding: [usize; 2],
1706 ) -> Result<Self> {
1707 if self.rank() != 4 {
1709 return Err(Error::msg(format!(
1710 "conv2d input must be 4D [N,C,H,W], got rank {}",
1711 self.rank()
1712 )));
1713 }
1714 if weight.rank() != 4 {
1715 return Err(Error::msg(format!(
1716 "conv2d weight must be 4D [C_out,C_in,kH,kW], got rank {}",
1717 weight.rank()
1718 )));
1719 }
1720
1721 let in_dims = self.dims();
1722 let w_dims = weight.dims();
1723 let (n, c_in, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
1724 let (c_out, wc_in, kh, kw) = (w_dims[0], w_dims[1], w_dims[2], w_dims[3]);
1725
1726 if c_in != wc_in {
1727 return Err(Error::msg(format!(
1728 "conv2d: input channels {} != weight channels {}",
1729 c_in, wc_in
1730 )));
1731 }
1732
1733 let [sh, sw] = stride;
1734 let [ph, pw] = padding;
1735
1736 if h + 2 * ph < kh || w + 2 * pw < kw {
1737 return Err(Error::msg("conv2d: kernel larger than padded input"));
1738 }
1739
1740 let h_out = (h + 2 * ph - kh) / sh + 1;
1741 let w_out = (w + 2 * pw - kw) / sw + 1;
1742
1743 let input_data = self.contiguous()?.to_f64_vec()?;
1745 let weight_data = weight.contiguous()?.to_f64_vec()?;
1746 let bias_data = match bias {
1747 Some(b) => Some(b.contiguous()?.to_f64_vec()?),
1748 None => None,
1749 };
1750
1751 let out_size = n * c_out * h_out * w_out;
1752 let mut output = vec![0.0f64; out_size];
1753
1754 let col_rows = c_in * kh * kw;
1759 let col_cols = h_out * w_out;
1760 let mut columns = vec![0.0f64; col_rows * col_cols];
1761 let sample_size = c_in * h * w;
1762
1763 for ni in 0..n {
1764 let in_offset = ni * sample_size;
1766 im2col(
1767 &input_data[in_offset..in_offset + sample_size],
1768 c_in,
1769 h,
1770 w,
1771 kh,
1772 kw,
1773 sh,
1774 sw,
1775 ph,
1776 pw,
1777 h_out,
1778 w_out,
1779 &mut columns,
1780 );
1781
1782 let out_offset = ni * c_out * h_out * w_out;
1784 gemm(
1785 &weight_data,
1786 &columns,
1787 &mut output[out_offset..out_offset + c_out * col_cols],
1788 c_out,
1789 col_cols,
1790 col_rows,
1791 );
1792
1793 if let Some(ref bd) = bias_data {
1795 for co in 0..c_out {
1796 let row_start = out_offset + co * col_cols;
1797 for j in 0..col_cols {
1798 output[row_start + j] += bd[co];
1799 }
1800 }
1801 }
1802 }
1803
1804 let result_shape = Shape::new(vec![n, c_out, h_out, w_out]);
1805 let result_op = Op::Conv2d {
1806 input: self.clone(),
1807 weight: weight.clone(),
1808 bias: bias.cloned(),
1809 stride,
1810 padding,
1811 };
1812 Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
1813 Self::from_storage(
1814 {
1815 let s = t.inner.storage.read().expect("storage lock poisoned");
1816 s.clone()
1817 },
1818 Layout::contiguous(result_shape),
1819 self.inner.dtype,
1820 self.inner.device.clone(),
1821 result_op,
1822 )
1823 })
1824 }
1825
    /// 2-D max pooling over a [N, C, H, W] input (CPU reference path).
    ///
    /// Records, per output element, the flat index of the winning input
    /// element in `Op::MaxPool2d::indices` so the backward pass can route
    /// gradients to the argmax positions.
    ///
    /// Padding positions are treated as -inf (they can never win); a
    /// window that covers only padding yields -inf with index 0.
    pub fn max_pool2d(
        &self,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> Result<Self> {
        if self.rank() != 4 {
            return Err(Error::msg(format!(
                "max_pool2d input must be 4D [N,C,H,W], got rank {}",
                self.rank()
            )));
        }

        let dims = self.dims();
        let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);
        let [kh, kw] = kernel_size;
        let [sh, sw] = stride;
        let [ph, pw] = padding;

        if h + 2 * ph < kh || w + 2 * pw < kw {
            return Err(Error::msg("max_pool2d: kernel larger than padded input"));
        }

        // Standard pooling output size formula (floor division).
        let h_out = (h + 2 * ph - kh) / sh + 1;
        let w_out = (w + 2 * pw - kw) / sw + 1;

        let input_data = self.contiguous()?.to_f64_vec()?;
        let out_size = n * c * h_out * w_out;
        let mut output = vec![f64::NEG_INFINITY; out_size];
        let mut indices = vec![0usize; out_size];

        for ni in 0..n {
            for ci in 0..c {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
                        let mut max_val = f64::NEG_INFINITY;
                        let mut max_idx = 0usize;
                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Shift by padding; negative/overflowing
                                // coordinates fall in the pad region and
                                // are skipped.
                                let ih = (oh * sh + ki) as isize - ph as isize;
                                let iw = (ow * sw + kj) as isize - pw as isize;
                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
                                    let ih = ih as usize;
                                    let iw = iw as usize;
                                    let in_idx = ((ni * c + ci) * h + ih) * w + iw;
                                    // Strict `>` keeps the first maximum on
                                    // ties.
                                    if input_data[in_idx] > max_val {
                                        max_val = input_data[in_idx];
                                        max_idx = in_idx;
                                    }
                                }
                            }
                        }
                        output[out_idx] = max_val;
                        indices[out_idx] = max_idx;
                    }
                }
            }
        }

        let result_shape = Shape::new(vec![n, c, h_out, w_out]);
        let result_op = Op::MaxPool2d {
            input: self.clone(),
            kernel_size,
            stride,
            padding,
            indices: indices.clone(),
        };
        // Build the result tensor, then re-wrap its storage so the graph
        // records Op::MaxPool2d instead of a leaf node.
        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
            Self::from_storage(
                {
                    let s = t.inner.storage.read().expect("storage lock poisoned");
                    s.clone()
                },
                Layout::contiguous(result_shape),
                self.inner.dtype,
                self.inner.device.clone(),
                result_op,
            )
        })
    }
1913
1914 pub fn avg_pool2d(
1918 &self,
1919 kernel_size: [usize; 2],
1920 stride: [usize; 2],
1921 padding: [usize; 2],
1922 ) -> Result<Self> {
1923 if self.rank() != 4 {
1924 return Err(Error::msg(format!(
1925 "avg_pool2d input must be 4D [N,C,H,W], got rank {}",
1926 self.rank()
1927 )));
1928 }
1929
1930 let dims = self.dims();
1931 let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);
1932 let [kh, kw] = kernel_size;
1933 let [sh, sw] = stride;
1934 let [ph, pw] = padding;
1935
1936 if h + 2 * ph < kh || w + 2 * pw < kw {
1937 return Err(Error::msg("avg_pool2d: kernel larger than padded input"));
1938 }
1939
1940 let h_out = (h + 2 * ph - kh) / sh + 1;
1941 let w_out = (w + 2 * pw - kw) / sw + 1;
1942
1943 let input_data = self.contiguous()?.to_f64_vec()?;
1944 let out_size = n * c * h_out * w_out;
1945 let mut output = vec![0.0f64; out_size];
1946
1947 for ni in 0..n {
1948 for ci in 0..c {
1949 for oh in 0..h_out {
1950 for ow in 0..w_out {
1951 let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
1952 let mut sum = 0.0f64;
1953 let mut count = 0usize;
1954 for ki in 0..kh {
1955 for kj in 0..kw {
1956 let ih = (oh * sh + ki) as isize - ph as isize;
1957 let iw = (ow * sw + kj) as isize - pw as isize;
1958 if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1959 let in_idx =
1960 ((ni * c + ci) * h + ih as usize) * w + iw as usize;
1961 sum += input_data[in_idx];
1962 count += 1;
1963 }
1964 }
1965 }
1966 output[out_idx] = if count > 0 { sum / count as f64 } else { 0.0 };
1967 }
1968 }
1969 }
1970 }
1971
1972 let result_shape = Shape::new(vec![n, c, h_out, w_out]);
1973 let result_op = Op::AvgPool2d {
1974 input: self.clone(),
1975 kernel_size,
1976 stride,
1977 padding,
1978 };
1979 Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
1980 Self::from_storage(
1981 {
1982 let s = t.inner.storage.read().expect("storage lock poisoned");
1983 s.clone()
1984 },
1985 Layout::contiguous(result_shape),
1986 self.inner.dtype,
1987 self.inner.device.clone(),
1988 result_op,
1989 )
1990 })
1991 }
1992
    /// 1-D convolution over a [N, C_in, L] input with a [C_out, C_in, K]
    /// weight (CPU reference path).
    ///
    /// Reuses the 2-D `im2col`/`gemm` helpers by treating the signal as a
    /// 1×L image (h = 1, kH = 1, sH = 1, pH = 0). The result is re-wrapped
    /// with `Op::Conv1d` so autograd sees the convolution.
    ///
    /// # Errors
    /// Fails when either tensor is not 3-D, channel counts disagree, the
    /// bias length does not match `C_out`, or the kernel exceeds the
    /// padded input.
    #[allow(clippy::needless_range_loop)]
    pub fn conv1d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: usize,
        padding: usize,
    ) -> Result<Self> {
        if self.rank() != 3 {
            return Err(Error::msg(format!(
                "conv1d input must be 3D [N,C_in,L], got rank {}",
                self.rank()
            )));
        }
        if weight.rank() != 3 {
            return Err(Error::msg(format!(
                "conv1d weight must be 3D [C_out,C_in,K], got rank {}",
                weight.rank()
            )));
        }

        let in_dims = self.dims();
        let w_dims = weight.dims();
        let (n, c_in, l) = (in_dims[0], in_dims[1], in_dims[2]);
        let (c_out, wc_in, k) = (w_dims[0], w_dims[1], w_dims[2]);

        if c_in != wc_in {
            return Err(Error::msg(format!(
                "conv1d: input channels {} != weight channels {}",
                c_in, wc_in
            )));
        }
        if let Some(b) = bias {
            if b.elem_count() != c_out {
                return Err(Error::msg(format!(
                    "conv1d: bias size {} != output channels {}",
                    b.elem_count(),
                    c_out
                )));
            }
        }

        if l + 2 * padding < k {
            return Err(Error::msg("conv1d: kernel larger than padded input"));
        }

        let l_out = (l + 2 * padding - k) / stride + 1;

        let input_data = self.contiguous()?.to_f64_vec()?;
        let weight_data = weight.contiguous()?.to_f64_vec()?;
        // NOTE(review): unlike conv2d, the bias skips contiguous() here —
        // presumably fine because to_f64_vec walks the layout; confirm.
        let bias_data: Option<Vec<f64>> = match bias {
            Some(b) => Some(b.to_f64_vec()?),
            None => None,
        };

        let out_size = n * c_out * l_out;
        let mut output = vec![0.0f64; out_size];

        // The im2col buffer is allocated once and reused across the batch.
        let col_rows = c_in * k;
        let col_cols = l_out;
        let mut columns = vec![0.0f64; col_rows * col_cols];
        let sample_size = c_in * l;

        for ni in 0..n {
            let in_offset = ni * sample_size;
            // Treat the signal as a 1×L image: h = 1, kh = 1, sh = 1, ph = 0.
            im2col(
                &input_data[in_offset..in_offset + sample_size],
                c_in,
                1,
                l,
                1,
                k,
                1,
                stride,
                0,
                padding,
                1,
                l_out,
                &mut columns,
            );

            // [C_out, col_rows] × [col_rows, l_out] → [C_out, l_out].
            let out_offset = ni * c_out * l_out;
            gemm(
                &weight_data,
                &columns,
                &mut output[out_offset..out_offset + c_out * col_cols],
                c_out,
                col_cols,
                col_rows,
            );

            // Broadcast the per-channel bias over every output position.
            if let Some(ref bd) = bias_data {
                for co in 0..c_out {
                    let row_start = out_offset + co * col_cols;
                    for j in 0..col_cols {
                        output[row_start + j] += bd[co];
                    }
                }
            }
        }

        let result_shape = Shape::new(vec![n, c_out, l_out]);
        let result_op = Op::Conv1d {
            input: self.clone(),
            weight: weight.clone(),
            bias: bias.cloned(),
            stride,
            padding,
        };
        // Build the result tensor, then re-wrap its storage so the graph
        // records Op::Conv1d instead of a leaf node.
        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
            Self::from_storage(
                {
                    let s = t.inner.storage.read().expect("storage lock poisoned");
                    s.clone()
                },
                Layout::contiguous(result_shape),
                self.inner.dtype,
                self.inner.device.clone(),
                result_op,
            )
        })
    }
2123
2124 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
2129 let storage = self.read_storage()?;
2130 let result = B::affine(&storage, &self.inner.layout, mul, add)?;
2131 let result_layout = Layout::contiguous(self.shape().clone());
2132 let result_op = Op::Affine {
2133 input: self.clone(),
2134 mul,
2135 add,
2136 };
2137 Ok(Self::from_storage(
2138 result,
2139 result_layout,
2140 self.inner.dtype,
2141 self.inner.device.clone(),
2142 result_op,
2143 ))
2144 }
2145
    /// Copies this tensor's elements into a `Vec<f64>`, converting from
    /// the storage dtype; the backend walks the layout, so views and
    /// non-contiguous tensors are read in logical order.
    pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
        let storage = self.read_storage()?;
        B::to_f64_vec(&storage, &self.inner.layout)
    }
2153
2154 pub fn to_scalar_f64(&self) -> Result<f64> {
2156 if self.elem_count() != 1 {
2157 return Err(Error::NotAScalar {
2158 shape: self.shape().clone(),
2159 });
2160 }
2161 let vec = self.to_f64_vec()?;
2162 Ok(vec[0])
2163 }
2164
2165 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2171 if self.dtype() == dtype {
2172 return Ok(self.clone());
2173 }
2174 let src_dtype = self.dtype();
2175 let guard = self.inner.storage.read().unwrap();
2176 let storage = B::cast(&*guard, &self.inner.layout, dtype, self.device())?;
2177 let layout = Layout::contiguous(self.shape().clone());
2178 let op = if self.is_variable() {
2179 Op::ToDtype {
2180 input: self.clone(),
2181 src_dtype,
2182 }
2183 } else {
2184 Op::None
2185 };
2186 Ok(Self::from_storage(
2187 storage,
2188 layout,
2189 dtype,
2190 self.device().clone(),
2191 op,
2192 ))
2193 }
2194
2195 pub fn to_string_with_data(&self) -> Result<String> {
2197 let data = self.to_f64_vec()?;
2198 Ok(format!(
2199 "Tensor(shape={}, dtype={}, data={:?})",
2200 self.shape(),
2201 self.dtype(),
2202 data
2203 ))
2204 }
2205
    /// Runs reverse-mode autodiff from this tensor, delegating to
    /// `crate::backprop::backward` and returning its gradient store.
    pub fn backward(&self) -> Result<crate::backprop::GradStore<B>> {
        crate::backprop::backward(self)
    }
2224
    /// Returns a view of this tensor with its op history erased
    /// (`Op::None`), so gradients do not flow past it. Storage is shared.
    pub fn detach(&self) -> Self {
        self.view_with_layout(self.layout().clone(), Op::None)
    }
2230
2231 pub fn freeze(&self) -> Self {
2236 Tensor {
2237 inner: Arc::new(TensorInner {
2238 id: self.inner.id,
2239 storage: Arc::clone(&self.inner.storage),
2240 layout: self.inner.layout.clone(),
2241 dtype: self.inner.dtype,
2242 device: self.inner.device.clone(),
2243 op: self.inner.op.clone(),
2244 is_variable: false,
2245 }),
2246 }
2247 }
2248
    /// Returns a variable alias of this tensor (inverse of `freeze`):
    /// same id, storage, layout, and op, but with `is_variable` set.
    pub fn unfreeze(&self) -> Self {
        self.set_variable_ref()
    }
2255
2256 pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self> {
2263 if dim >= self.rank() {
2264 return Err(Error::DimOutOfRange {
2265 dim,
2266 rank: self.rank(),
2267 });
2268 }
2269 let guard = self.inner.storage.read().unwrap();
2270 let idx_guard = indices.inner.storage.read().unwrap();
2271 let storage = B::index_select(
2272 &*guard,
2273 &self.inner.layout,
2274 &*idx_guard,
2275 &indices.inner.layout,
2276 dim,
2277 )?;
2278 let mut out_dims = self.dims().to_vec();
2279 out_dims[dim] = indices.elem_count();
2280 let layout = Layout::contiguous(Shape::new(out_dims));
2281 let op = Op::IndexSelect {
2283 input: self.clone(),
2284 indices: indices.clone(),
2285 dim,
2286 };
2287 Ok(Self::from_storage(
2288 storage,
2289 layout,
2290 self.dtype(),
2291 self.device().clone(),
2292 op,
2293 ))
2294 }
2295
2296 pub fn split(&self, split_size: usize, dim: usize) -> Result<Vec<Self>> {
2300 if dim >= self.rank() {
2301 return Err(Error::DimOutOfRange {
2302 dim,
2303 rank: self.rank(),
2304 });
2305 }
2306 if split_size == 0 {
2307 return Err(Error::msg("split: split_size must be > 0"));
2308 }
2309 let dim_size = self.dims()[dim];
2310 let mut parts = Vec::new();
2311 let mut start = 0;
2312 while start < dim_size {
2313 let len = split_size.min(dim_size - start);
2314 parts.push(self.narrow(dim, start, len)?);
2315 start += len;
2316 }
2317 Ok(parts)
2318 }
2319
2320 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Self> {
2325 let rank = self.rank();
2326 if start_dim >= rank || end_dim >= rank || start_dim > end_dim {
2327 return Err(Error::msg(format!(
2328 "flatten: invalid range [{}, {}] for rank {}",
2329 start_dim, end_dim, rank
2330 )));
2331 }
2332 let dims = self.dims();
2333 let mut new_dims: Vec<usize> = Vec::new();
2334 new_dims.extend_from_slice(&dims[..start_dim]);
2335 let flat: usize = dims[start_dim..=end_dim].iter().product();
2336 new_dims.push(flat);
2337 if end_dim + 1 < rank {
2338 new_dims.extend_from_slice(&dims[end_dim + 1..]);
2339 }
2340 self.reshape(new_dims)
2341 }
2342
    /// Standard deviation along `dim`: the square root of `var`.
    /// NOTE(review): whether a Bessel correction applies is inherited from
    /// `var`, which is defined elsewhere in this file — check there.
    pub fn std(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.var(dim, keep_dim)?.sqrt()
    }
2349
    /// Element-wise reciprocal, computed as `1 / self` so the division op
    /// is recorded and differentiable.
    pub fn reciprocal(&self) -> Result<Self> {
        let one = Self::ones(self.dims(), self.dtype(), self.device())?;
        one.div(self)
    }
2355
    /// Element-wise reciprocal square root: `1 / sqrt(self)`.
    pub fn rsqrt(&self) -> Result<Self> {
        self.sqrt()?.reciprocal()
    }
2360
    /// Smooth element-wise sign: `x / (|x| + eps)`, clamped to [-1, 1].
    ///
    /// The `eps` term avoids division by zero, so sign(0) is exactly 0 and
    /// the chain of ops stays differentiable. Values whose magnitude is
    /// near `eps` yield a proportionally shrunken result rather than ±1.
    pub fn sign(&self) -> Result<Self> {
        let eps = 1e-12;
        let abs_x = self.abs()?;
        // affine(1.0, eps) computes |x| + eps.
        let denom = abs_x.affine(1.0, eps)?;
        let raw = self.div(&denom)?;
        raw.clamp(-1.0, 1.0)
    }
2372
    /// Numerically stable `log(sum(exp(self)))` along `dim`.
    ///
    /// Uses the max-shift trick: with `m = max(self, dim)`,
    /// `logsumexp = m + log(sum(exp(self - m)))`, which prevents overflow
    /// in `exp`. `m` is detached so the piecewise-constant max does not
    /// create a second gradient path; the full gradient still flows
    /// through the shifted term.
    /// NOTE(review): an all-(-inf) slice makes `m = -inf` and the result
    /// NaN — confirm whether callers can hit that case.
    pub fn logsumexp(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        let m = self.max(dim, true)?.detach();
        let shifted = self.sub(&m)?;
        let sum_exp = shifted.exp()?.sum(dim, true)?.log()?;
        let result = m.add(&sum_exp)?;
        if keep_dim {
            Ok(result)
        } else {
            result.squeeze(dim)
        }
    }
2387
    /// Product of elements along `dim`, computed in log space as
    /// `exp(sum(log|x|))`.
    ///
    /// NOTE(review): this formulation yields only the *magnitude* of the
    /// product — the sign contributed by negative factors is dropped, and
    /// any zero element sends `log` to -inf (0/NaN downstream) instead of
    /// producing an exact 0. Only rely on this for strictly positive
    /// inputs; confirm whether a sign-correct implementation is needed.
    pub fn prod(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        let log_abs = self.abs()?.log()?;
        let sum_log = log_abs.sum(dim, keep_dim)?;
        let magnitude = sum_log.exp()?;
        Ok(magnitude)
    }
2401
    /// Internal helper: an alias of this tensor with `is_variable = true`,
    /// keeping the same id, shared storage, layout, and op (the
    /// counterpart of `freeze`).
    fn set_variable_ref(&self) -> Self {
        Tensor {
            inner: Arc::new(TensorInner {
                id: self.inner.id,
                storage: Arc::clone(&self.inner.storage),
                layout: self.inner.layout.clone(),
                dtype: self.inner.dtype,
                device: self.inner.device.clone(),
                op: self.inner.op.clone(),
                is_variable: true,
            }),
        }
    }
2416}
2417
/// Unfolds one [C, H, W] image into a `(c_in*kh*kw) × (h_out*w_out)`
/// column matrix so convolution becomes a single GEMM.
///
/// Row `(ci*kh + ki)*kw + kj` holds, for every output position, the input
/// value under kernel offset `(ki, kj)`; positions falling in the padding
/// are written as 0.0. Every element of `columns` is overwritten.
#[inline]
#[allow(clippy::too_many_arguments)]
pub(crate) fn im2col(
    input: &[f64],
    c_in: usize,
    h: usize,
    w: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    h_out: usize,
    w_out: usize,
    columns: &mut [f64],
) {
    let col_cols = h_out * w_out;
    let mut row = 0usize;
    for ci in 0..c_in {
        let channel = &input[ci * h * w..(ci + 1) * h * w];
        for ki in 0..kh {
            for kj in 0..kw {
                let dst = &mut columns[row * col_cols..(row + 1) * col_cols];
                row += 1;
                for oh in 0..h_out {
                    let ih = (oh * sh + ki) as isize - ph as isize;
                    let row_ok = ih >= 0 && ih < h as isize;
                    for ow in 0..w_out {
                        let iw = (ow * sw + kj) as isize - pw as isize;
                        dst[oh * w_out + ow] = if row_ok && iw >= 0 && iw < w as isize {
                            channel[ih as usize * w + iw as usize]
                        } else {
                            0.0
                        };
                    }
                }
            }
        }
    }
}
2474
/// Inverse of `im2col`: scatters a column matrix back onto a [C, H, W]
/// image, *accumulating* overlapping contributions with `+=`.
///
/// Contributions that fall into the padding region are discarded.
/// `output` is not cleared here — callers must zero it beforehand.
#[inline]
#[allow(clippy::too_many_arguments)]
pub(crate) fn col2im(
    columns: &[f64],
    c_in: usize,
    h: usize,
    w: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    h_out: usize,
    w_out: usize,
    output: &mut [f64],
) {
    let col_cols = h_out * w_out;
    let mut row = 0usize;
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let src = &columns[row * col_cols..(row + 1) * col_cols];
                row += 1;
                for oh in 0..h_out {
                    let ih = (oh * sh + ki) as isize - ph as isize;
                    if ih < 0 || ih >= h as isize {
                        continue;
                    }
                    let base = (ci * h + ih as usize) * w;
                    for ow in 0..w_out {
                        let iw = (ow * sw + kj) as isize - pw as isize;
                        if iw >= 0 && iw < w as isize {
                            output[base + iw as usize] += src[oh * w_out + ow];
                        }
                    }
                }
            }
        }
    }
}
2517
/// Dense f64 matrix multiply: `c += a · b`, where `a` is m×k, `b` is k×n,
/// and `c` is m×n, all row-major. `c` is accumulated into, not
/// overwritten — callers pass a zeroed buffer for a plain product.
#[inline]
pub(crate) fn gemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    for (i, c_row) in c.chunks_exact_mut(n).take(m).enumerate() {
        for (p, &a_val) in a[i * k..i * k + k].iter().enumerate() {
            let b_row = &b[p * n..p * n + n];
            for (c_elem, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_elem += a_val * b_val;
            }
        }
    }
}
2536
/// Dense f64 multiply with transposed left operand: `c += aᵀ · b`, where
/// `a` is stored k×m row-major (so `aᵀ` is m×k), `b` is k×n, and `c` is
/// m×n. Accumulates into `c` like `gemm`.
#[inline]
pub(crate) fn gemm_at_b(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    for i in 0..m {
        let c_row = &mut c[i * n..i * n + n];
        for p in 0..k {
            // aᵀ[i][p] lives at a[p * m + i] in the k×m layout.
            let a_val = a[p * m + i];
            let b_row = &b[p * n..p * n + n];
            for (c_elem, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_elem += a_val * b_val;
            }
        }
    }
}
2553
/// Dense f64 multiply with transposed right operand: `c += a · bᵀ`, where
/// `a` is m×k, `b` is stored n×k row-major (so `bᵀ` is k×n), and `c` is
/// m×n. Accumulates into `c` like `gemm`.
#[inline]
pub(crate) fn gemm_a_bt(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    for i in 0..m {
        let a_row = &a[i * k..i * k + k];
        for j in 0..n {
            // Row j of `b` is column j of bᵀ: a dot product of two
            // contiguous length-k slices.
            let dot: f64 = a_row
                .iter()
                .zip(&b[j * k..j * k + k])
                .map(|(&x, &y)| x * y)
                .sum();
            c[i * n + j] += dot;
        }
    }
}