1extern crate alloc;
2use crate::axes::IntoAxes;
3use crate::dtype::DType;
4use crate::error::ZyxError;
5use crate::scalar::Scalar;
6use crate::shape::Shape;
7use crate::utils::SizedIterator;
8use crate::{backend::Backend, node::Node};
9use alloc::{boxed::Box, collections::BTreeSet, vec::Vec};
10use core::{
11 cmp::Ordering,
12 iter::repeat,
13 ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, SubAssign},
14};
15
/// Opaque identifier of a node; a thin wrapper around a `usize`.
#[derive(Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Debug)]
pub struct Id(usize);

/// Builds an [`Id`] from a raw `usize` value.
pub const fn id(id: usize) -> Id {
    Id(id)
}

impl Id {
    /// Returns the raw `usize` stored inside this id.
    pub const fn i(self) -> usize {
        let Id(raw) = self;
        raw
    }
}

impl core::fmt::Display for Id {
    /// Prints the id using its `Debug` representation (e.g. `Id(3)`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl SubAssign<usize> for Id {
    /// Decrements the wrapped value by `rhs`.
    fn sub_assign(&mut self, rhs: usize) {
        let Id(raw) = self;
        *raw -= rhs;
    }
}
43
/// Conversion of the various range syntaxes (and a single `i64` index) into
/// a half-open `Range<i64>`.
///
/// Unbounded ends are encoded as `0` (start) and `i64::MAX` (end).
pub trait IntoRange: Clone {
    /// Converts `self` into a half-open `Range<i64>`.
    fn into_range(self) -> Range<i64>;
}

impl IntoRange for RangeFull {
    fn into_range(self) -> Range<i64> {
        Range {
            start: 0,
            end: i64::MAX,
        }
    }
}

impl IntoRange for RangeFrom<i64> {
    fn into_range(self) -> Range<i64> {
        let RangeFrom { start } = self;
        start..i64::MAX
    }
}

impl IntoRange for RangeTo<i64> {
    fn into_range(self) -> Range<i64> {
        let RangeTo { end } = self;
        0..end
    }
}

impl IntoRange for RangeInclusive<i64> {
    fn into_range(self) -> Range<i64> {
        // Inclusive upper bound becomes exclusive by adding one.
        let (first, last) = self.into_inner();
        first..last + 1
    }
}

impl IntoRange for RangeToInclusive<i64> {
    fn into_range(self) -> Range<i64> {
        let RangeToInclusive { end: last } = self;
        0..last + 1
    }
}

impl IntoRange for Range<i64> {
    fn into_range(self) -> Range<i64> {
        self
    }
}

impl IntoRange for i64 {
    /// A single index `k` is the one-element range `k..k + 1`.
    fn into_range(self) -> Range<i64> {
        let next = self + 1;
        self..next
    }
}
91
92pub trait IntoIndex {
94 fn into_index(self) -> impl IntoIterator<Item = Range<i64>>;
96}
97
98impl<I: IntoRange> IntoIndex for &[I] {
99 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
100 self.iter().cloned().map(IntoRange::into_range)
101 }
102}
103
104impl<I0: IntoRange> IntoIndex for I0 {
105 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
106 [self.into_range()].into_iter()
107 }
108}
109
110impl<I0: IntoRange, I1: IntoRange> IntoIndex for (I0, I1) {
111 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
112 [self.0.into_range(), self.1.into_range()].into_iter()
113 }
114}
115
116impl<I0: IntoRange, I1: IntoRange, I2: IntoRange> IntoIndex for (I0, I1, I2) {
117 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
118 [
119 self.0.into_range(),
120 self.1.into_range(),
121 self.2.into_range(),
122 ]
123 .into_iter()
124 }
125}
126
127impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange> IntoIndex for (I0, I1, I2, I3) {
128 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
129 [
130 self.0.into_range(),
131 self.1.into_range(),
132 self.2.into_range(),
133 self.3.into_range(),
134 ]
135 .into_iter()
136 }
137}
138
139impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange> IntoIndex
140 for (I0, I1, I2, I3, I4)
141{
142 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
143 [
144 self.0.into_range(),
145 self.1.into_range(),
146 self.2.into_range(),
147 self.3.into_range(),
148 self.4.into_range(),
149 ]
150 .into_iter()
151 }
152}
153
154impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange, I5: IntoRange>
155 IntoIndex for (I0, I1, I2, I3, I4, I5)
156{
157 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
158 [
159 self.0.into_range(),
160 self.1.into_range(),
161 self.2.into_range(),
162 self.3.into_range(),
163 self.4.into_range(),
164 self.5.into_range(),
165 ]
166 .into_iter()
167 }
168}
169
170impl<
171 I0: IntoRange,
172 I1: IntoRange,
173 I2: IntoRange,
174 I3: IntoRange,
175 I4: IntoRange,
176 I5: IntoRange,
177 I6: IntoRange,
178 > IntoIndex for (I0, I1, I2, I3, I4, I5, I6)
179{
180 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
181 [
182 self.0.into_range(),
183 self.1.into_range(),
184 self.2.into_range(),
185 self.3.into_range(),
186 self.4.into_range(),
187 self.5.into_range(),
188 self.6.into_range(),
189 ]
190 .into_iter()
191 }
192}
193
194impl<
195 I0: IntoRange,
196 I1: IntoRange,
197 I2: IntoRange,
198 I3: IntoRange,
199 I4: IntoRange,
200 I5: IntoRange,
201 I6: IntoRange,
202 I7: IntoRange,
203 > IntoIndex for (I0, I1, I2, I3, I4, I5, I6, I7)
204{
205 fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
206 [
207 self.0.into_range(),
208 self.1.into_range(),
209 self.2.into_range(),
210 self.3.into_range(),
211 self.4.into_range(),
212 self.5.into_range(),
213 self.6.into_range(),
214 self.7.into_range(),
215 ]
216 .into_iter()
217 }
218}
219
/// Range expressions that select a contiguous run of axes to flatten.
///
/// Every implementation yields a finite sequence of non-negative axis
/// indices smaller than `rank` (for arguments that pass the debug
/// assertions). Negative bounds count from the end, i.e. `-1` is the last
/// axis.
pub trait FlattenAxes {
    /// Converts `self` into the axes it selects for a tensor of rank `rank`.
    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64>;
}

impl FlattenAxes for RangeFrom<i64> {
    /// `k..` selects axes `k, k + 1, .., rank - 1`.
    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
        debug_assert!(
            if self.start > 0 {
                (self.start as usize) < rank
            } else {
                ((-self.start) as usize) <= rank
            },
            "Cannot use {self:?} as flatten axes."
        );
        let rank = rank as i64;
        let start = if self.start < 0 {
            self.start + rank
        } else {
            self.start
        };
        // The upper bound must be the rank; an open-ended range would never
        // terminate and would index past the last dimension.
        start..rank
    }
}

impl FlattenAxes for RangeTo<i64> {
    /// `..k` selects axes `0, .., k - 1`; `..rank` selects every axis.
    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
        debug_assert!(
            if self.end > 0 {
                // The end is exclusive, so `end == rank` is still valid.
                (self.end as usize) <= rank
            } else {
                ((-self.end) as usize) <= rank
            },
            "Cannot use {self:?} as flatten axes."
        );
        let end = if self.end < 0 {
            self.end + rank as i64
        } else {
            self.end
        };
        0..end
    }
}

impl FlattenAxes for RangeToInclusive<i64> {
    /// `..=k` selects axes `0, .., k`.
    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
        debug_assert!(
            if self.end > 0 {
                (self.end as usize) < rank
            } else {
                ((-self.end) as usize) <= rank
            },
            "Cannot use {self:?} as flatten axes."
        );
        let last = if self.end < 0 {
            self.end + rank as i64
        } else {
            self.end
        };
        0..last + 1
    }
}

impl FlattenAxes for RangeFull {
    /// `..` selects every axis.
    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
        0..rank as i64
    }
}
273
/// Handle to a node owned by a backend.
///
/// Cloning a tensor calls `retain` on the backend and dropping it calls
/// `release`, so the handle itself only carries the node id and the backend.
pub struct Tensor<B: Backend> {
    // Id of the node this handle refers to.
    id: Id,
    // Backend that stores the node.
    backend: B,
}
280
281impl<B: Backend> Clone for Tensor<B> {
282 fn clone(&self) -> Self {
283 self.backend.retain(self.id);
284 tensor(self.id, self.backend)
285 }
286}
287
288impl<B: Backend> Drop for Tensor<B> {
289 fn drop(&mut self) {
290 self.backend.release(self.id).unwrap();
292 }
293}
294
295impl<B: Backend> core::fmt::Debug for Tensor<B> {
296 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
297 f.write_fmt(format_args!("{self}"))
298 }
300}
301
302impl<B: Backend> core::fmt::Display for Tensor<B> {
303 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
304 let precision = if let Some(precision) = f.precision() {
306 precision
307 } else {
308 3
309 };
310 let res = match self.dtype() {
311 DType::F32 => {
312 if let Ok(data) = &self.to_vec::<f32>() {
313 tensor_to_string(data, &self.shape(), precision, f.width())
314 } else {
315 "f32 tensor failed to realize".into()
316 }
317 }
318 DType::F64 => {
319 if let Ok(data) = &self.to_vec::<f64>() {
320 tensor_to_string(data, &self.shape(), precision, f.width())
321 } else {
322 "f64 tensor failed to realize".into()
323 }
324 }
325 DType::I32 => {
326 if let Ok(data) = &self.to_vec::<i32>() {
327 tensor_to_string(data, &self.shape(), precision, f.width())
328 } else {
329 "i32 tensor failed to realize".into()
330 }
331 }
332 };
333 f.write_fmt(format_args!(
334 "Tensor {} {}\n{res}",
335 self.shape(),
336 self.dtype()
337 ))
338 }
339}
340
341fn tensor_to_string<T: core::fmt::Display>(
342 data: &[T],
343 shape: &Shape,
344 precision: usize,
345 width: Option<usize>,
346) -> alloc::string::String {
347 use core::fmt::Write;
348 let n = shape.numel();
349 let ndim = shape.rank();
350 let mut res = alloc::string::String::new();
351 if data.is_empty() {
352 return "[]".into();
353 }
354 let mut w = 0;
356 if let Some(width) = width {
357 w = width;
358 } else {
359 for x in data {
360 let l = alloc::format!("{x:>.precision$}").len();
361 if l > w {
362 w = l;
363 }
364 }
365 }
366 let d0 = shape[-1];
367 for (i, x) in data.iter().enumerate() {
368 {
369 let mut var = 1;
370 let mut r = ndim;
371 while r > 0 {
372 if i % (n / var) == 0 {
373 res += &(" ".repeat(ndim - r) + "[".repeat(r - 1).as_str());
374 break;
375 }
376 var *= shape[ndim - r];
377 r -= 1;
378 }
379 }
380 let _ = write!(res, "{x:>w$.precision$}");
381 if (i + 1) % d0 != 0usize {
382 res += " ";
383 }
384 {
385 let mut var = 1;
386 let mut r = ndim;
387 while r > 0 {
388 if (i + 1) % (n / var) == 0 {
389 res += &"]".repeat(r - 1);
390 break;
391 }
392 var *= shape[ndim - r];
393 r -= 1;
394 }
395 }
396 if (i + 1) % d0 == 0usize && i != n - 1 {
397 res += "\n";
398 }
399 }
400 res
401}
402
/// Wraps an existing node `id` and its `backend` into a [`Tensor`] handle.
///
/// This only builds the struct; it does not call `retain` on the backend.
pub const fn tensor<B: Backend>(id: Id, backend: B) -> Tensor<B> {
    Tensor { id, backend }
}
408
409impl<B: Backend> Tensor<B> {
410 pub fn id(&self) -> Id {
414 self.id
415 }
416
417 #[must_use]
425 pub fn shape(&self) -> Shape {
426 self.backend.shape(self.id)
427 }
428
429 #[must_use]
437 pub fn numel(&self) -> usize {
438 self.shape().numel()
439 }
440
441 #[must_use]
449 pub fn dtype(&self) -> DType {
450 self.backend.dtype(self.id)
451 }
452
453 #[must_use]
461 pub fn rank(&self) -> usize {
462 self.shape().rank()
463 }
464
465 #[must_use]
473 pub fn backend(&self) -> B {
474 self.backend
475 }
476
477 #[must_use]
495 pub fn detach(&self) -> Tensor<B> {
496 tensor(
498 self.backend.push(Node::Detach(self.id)).unwrap(),
499 self.backend,
500 )
501 }
502
503 pub fn to_vec<T: Scalar>(&self) -> Result<Vec<T>, ZyxError> {
522 if T::dtype() != self.dtype() {
523 return Err(ZyxError::InvalidDType {
524 expected: T::dtype(),
525 found: self.dtype(),
526 });
527 }
528 self.backend.load(self.id)
529 }
530
531 pub fn item<T: Scalar>(&self) -> Result<T, ZyxError> {
543 self.backend
544 .load::<T>(self.id)?
545 .first()
546 .ok_or(ZyxError::IndexOutOfBounds { index: 0, len: 0 })
547 .cloned()
548 }
549
550 #[must_use]
561 pub fn backward<'a>(
562 &'a self,
563 sources: impl IntoIterator<Item = &'a Tensor<B>>,
564 ) -> Vec<Option<Tensor<B>>>
565 where
566 B: 'a,
567 {
568 let sources: Vec<&Tensor<B>> = sources.into_iter().collect();
569 let grads = self
570 .backend
571 .backward(self.id, &sources.iter().map(|t| t.id).collect())
572 .unwrap();
573 sources
574 .into_iter()
575 .map(move |x: &Tensor<B>| grads.get(&x.id).cloned())
576 .map(move |x| x.map(|x| tensor(x, self.backend)))
577 .collect()
578 }
579
580 #[must_use]
592 pub fn cast(&self, dtype: DType) -> Tensor<B> {
593 tensor(
594 self.backend.push(Node::Cast(self.id, dtype)).unwrap(),
595 self.backend,
596 )
597 }
598
599 #[must_use]
601 pub fn relu(&self) -> Tensor<B> {
602 tensor(
603 self.backend.push(Node::ReLU(self.id)).unwrap(),
604 self.backend,
605 )
606 }
607
608 #[must_use]
610 pub fn sin(&self) -> Tensor<B> {
611 tensor(self.backend.push(Node::Sin(self.id)).unwrap(), self.backend)
612 }
613
614 #[must_use]
616 pub fn cos(&self) -> Tensor<B> {
617 tensor(self.backend.push(Node::Cos(self.id)).unwrap(), self.backend)
618 }
619
620 #[must_use]
625 pub fn ln(&self) -> Tensor<B> {
626 tensor(self.backend.push(Node::Ln(self.id)).unwrap(), self.backend)
627 }
628
629 #[must_use]
631 pub fn exp(&self) -> Tensor<B> {
632 tensor(self.backend.push(Node::Exp(self.id)).unwrap(), self.backend)
633 }
634
635 #[must_use]
637 pub fn tanh(&self) -> Tensor<B> {
638 tensor(
639 self.backend.push(Node::Tanh(self.id)).unwrap(),
640 self.backend,
641 )
642 }
643
644 #[must_use]
649 pub fn sqrt(&self) -> Tensor<B> {
650 tensor(
651 self.backend.push(Node::Sqrt(self.id)).unwrap(),
652 self.backend,
653 )
654 }
655
656 #[must_use]
658 pub fn reciprocal(&self) -> Tensor<B> {
659 self.backend().ones(self.shape(), self.dtype()).unwrap() / self
660 }
661
662 #[must_use]
664 pub fn rsqrt(&self) -> Tensor<B> {
665 self.reciprocal().sqrt()
666 }
667
668 #[must_use]
670 pub fn dropout(&self, probability: impl Scalar) -> Tensor<B> {
671 self.backend()
672 .tensor(probability)
673 .unwrap()
674 .cmplt(self.backend().uniform(self.shape(), 0.0..1.0).unwrap()).cast(self.dtype())
675 * self
676 }
677
678 #[must_use]
680 pub fn abs(&self) -> Tensor<B> {
681 self.relu() + (-self).relu()
682 }
683
684 #[must_use]
686 pub fn sigmoid(&self) -> Tensor<B> {
687 let one = self.backend().ones(1, self.dtype()).unwrap();
688 &one / (&one + (-self).exp())
689 }
690
691 #[must_use]
693 pub fn swish(&self) -> Tensor<B> {
694 self * self.sigmoid()
695 }
696
697 #[must_use]
699 pub fn mish(&self) -> Tensor<B> {
700 self * self.softplus(1, 20).tanh()
701 }
702
703 #[must_use]
705 pub fn softplus(&self, beta: impl Scalar, threshold: impl Scalar) -> Tensor<B> {
706 let x = self * beta.clone();
707 x.cmplt(threshold)
708 .where_(((x).exp() + 1).ln() * beta.reciprocal(), x)
709 }
710
711 #[must_use]
713 pub fn tan(&self) -> Tensor<B> {
714 self.sin() / self.cos()
715 }
716
717 #[must_use]
719 pub fn leaky_relu(&self, neg_slope: impl Scalar) -> Tensor<B> {
720 self.relu() - (self * (-self.backend.tensor(neg_slope).unwrap())).relu()
721 }
722
723 #[must_use]
725 pub fn elu(&self, alpha: impl Scalar) -> Tensor<B> {
726 self.relu() - (1f32.into_tensor(self.backend) - self.exp()).relu() * alpha
727 }
728
729 #[must_use]
731 pub fn selu(&self) -> Tensor<B> {
732 1.0507009873554804934193349852946f32
733 * (self.relu()
734 - (1.6732632423543772848170429916717f32
735 * (self.backend.ones(1, self.dtype()).unwrap() - self.exp()))
736 .relu())
737 }
738
739 #[must_use]
741 pub fn celu(&self, alpha: impl Scalar) -> Tensor<B> {
742 self.relu()
743 - ((self.backend.ones(1, self.dtype()).unwrap() - (self / alpha.clone()).exp()) * alpha)
744 .relu()
745 }
746
747 #[must_use]
749 pub fn gelu(&self) -> Tensor<B> {
750 self * 0.5f32
751 * (((self + self.pow(3f32) * 0.044_715f32) * (2f32 / core::f32::consts::PI).sqrt())
752 .tanh()
753 + 1f32)
754 }
755
756 #[must_use]
758 pub fn quick_gelu(&self) -> Tensor<B> {
759 self * (1.702f32 * self).sigmoid()
760 }
761
762 #[must_use]
764 pub fn softmax(&self, axes: impl IntoAxes) -> Tensor<B> {
765 let axes = axes.into_axes(self.rank());
766 let e = (self - self.max(axes.clone())).exp();
767 &e / e.sum(axes)
768 }
769
770 #[must_use]
772 pub fn ln_softmax(&self, axes: impl IntoAxes) -> Tensor<B> {
773 let axes = axes.into_axes(self.rank());
774 let m = self - self.max(axes.clone());
775 &m - m.exp().sum(axes).ln()
776 }
777
778 #[must_use]
781 pub fn l1_loss(&self, target: impl IntoTensor<B>) -> Tensor<B> {
782 (self - target).abs()
783 }
784
785 #[must_use]
787 pub fn mse_loss(&self, target: impl IntoTensor<B>) -> Tensor<B> {
788 (self - target).pow(2)
789 }
790
791 #[must_use]
794 pub fn cross_entropy_loss(&self, target: impl IntoTensor<B>, axes: impl IntoAxes) -> Tensor<B> {
795 self.ln_softmax(axes) * target
796 }
797
798 #[must_use]
801 pub fn pow(&self, exponent: impl IntoTensor<B>) -> Tensor<B> {
802 let exponent = self.backend.tensor(exponent).unwrap();
803 if exponent.numel() == 1 {
804 let dtype = exponent.dtype();
805 if !dtype.is_floating() {
806 if exponent.item::<i32>().unwrap() == 2i32 {
808 return self * self;
809 } else if exponent.item::<i32>().unwrap() == 3i32 {
810 return self * self * self;
811 }
812 }
813 }
814 if self.dtype().is_floating() {
815 return (exponent * self.ln()).exp();
816 }
817 self.clone().binary_op(exponent, BOp::Pow)
818 }
819
820 #[must_use]
822 pub fn cmplt(&self, rhs: impl IntoTensor<B>) -> Tensor<B> {
823 self.clone().binary_op(rhs, BOp::Cmplt)
824 }
825
826 #[must_use]
828 pub fn where_(&self, if_true: impl IntoTensor<B>, if_false: impl IntoTensor<B>) -> Tensor<B> {
829 let x = self.clone();
830 let y = self.backend.tensor(if_true).unwrap();
831 let z = self.backend.tensor(if_false).unwrap();
832 let (x, y) = Tensor::broadcast(x, y);
833 let (x, z) = Tensor::broadcast(x, z);
834 let (y, z) = Tensor::broadcast(y, z);
835 tensor(
836 self.backend.push(Node::Where(x.id, y.id, z.id)).unwrap(),
837 self.backend,
838 )
839 }
840
841 #[must_use]
843 pub fn cosine_similarity(&self, rhs: impl IntoTensor<B>, eps: impl IntoTensor<B>) -> Tensor<B> {
844 let rhs = self.backend.tensor(rhs).unwrap();
845 let eps = self.backend.tensor(eps).unwrap();
846 let x = self.pow(2).sqrt() * rhs.pow(2).sqrt();
847 self * rhs / x.cmplt(&eps).where_(eps, x)
848 }
849
850 #[must_use]
860 pub fn dot(&self, rhs: impl IntoTensor<B>) -> Tensor<B> {
861 let y = self.backend.tensor(rhs).unwrap().transpose();
862 let xshape = self.shape();
863 let yshape = y.shape();
864 let yrank = yshape.rank();
865 debug_assert_eq!(
866 xshape[-1], yshape[-1],
867 "Cannot dot tensors with shapes {xshape} and {yshape}"
869 );
870 let x_shape = xshape[0..-1]
871 .iter()
872 .copied()
873 .chain([1])
874 .chain([xshape[-1]])
875 .collect::<Box<[usize]>>();
876 let y_shape = yshape[0..-2]
877 .iter()
878 .copied()
879 .chain([1])
880 .chain(yshape[-(yrank.min(2) as i64)..yrank as i64].iter().copied())
881 .collect::<Box<[usize]>>();
882 (self.reshape(x_shape) * y.reshape(y_shape))
885 .sum(-1)
886 .reshape(
887 xshape[0..-1]
888 .iter()
889 .copied()
890 .chain([yshape[-2]])
891 .collect::<Box<[usize]>>(),
892 )
893 }
894
895 #[must_use]
901 pub fn reshape(&self, shape: impl Into<Shape>) -> Tensor<B> {
902 let shape = shape.into();
903 debug_assert_eq!(
904 self.shape().numel(),
905 shape.numel(),
906 "Cannot reshape tensor with shape {} to {shape}",
907 self.shape()
908 );
909 tensor(
910 self.backend.push(Node::Reshape(self.id, shape)).unwrap(),
911 self.backend,
912 )
913 }
914
915 #[must_use]
917 pub fn expand(&self, shape: impl Into<Shape>) -> Tensor<B> {
918 let shape = shape.into();
919 let sh = self.shape();
920 debug_assert!(
921 shape
922 .iter()
923 .rev()
924 .enumerate()
925 .all(|(i, d)| if sh.rank() > i {
926 *d == sh[sh.rank() - i - 1] || sh[sh.rank() - i - 1] == 1
927 } else {
928 true
929 }),
930 "Can't expand tensor with shape {sh} to {shape}"
931 );
932 tensor(
933 self.backend.push(Node::Expand(self.id, shape)).unwrap(),
934 self.backend,
935 )
936 }
937
938 #[must_use]
973 pub fn pad(
974 &self,
975 padding: impl IntoIterator<Item = (i64, i64)>,
976 value: impl IntoTensor<B>,
977 ) -> Tensor<B> {
978 let dtype = self.dtype();
979 let value = self.backend.tensor(value).unwrap();
980 debug_assert_eq!(
981 value.dtype(),
982 dtype,
983 "Cannot pad tensor with dtype {} with value of dtype {}",
984 dtype,
985 value.dtype()
986 );
987 let padding: Box<[(i64, i64)]> = padding.into_iter().collect();
988 let sh = self.shape();
989 debug_assert!(
990 padding.len() <= sh.rank()
991 && padding
992 .iter()
993 .zip(sh.iter().rev())
994 .all(|((lp, rp), d)| if *lp < 0 {
995 ((-*lp) as usize) <= *d
996 } else {
997 true
998 } && if *rp < 0 {
999 ((-*rp) as usize) <= *d
1000 } else {
1001 true
1002 }),
1003 "Cannot pad tensor with shape {sh} with padding {padding:?}"
1004 );
1005 let psh = sh.clone().pad(&padding);
1006 let t0 = tensor(
1007 self.backend
1008 .push(Node::Pad(self.id, padding.clone(), psh.clone()))
1009 .unwrap(),
1010 self.backend,
1011 );
1012 if value.numel() == 1
1013 && match dtype {
1014 DType::F32 => value.item::<f32>().unwrap().is_equal(0f32),
1015 DType::F64 => value.item::<f64>().unwrap().is_equal(0f64),
1016 DType::I32 => value.item::<i32>().unwrap().is_equal(0i32),
1017 }
1018 {
1019 t0
1020 } else {
1021 t0 + tensor(
1022 self.backend
1023 .push(Node::Pad(
1024 self.backend.ones(sh, dtype).unwrap().id,
1025 padding,
1026 psh.clone(),
1027 ))
1028 .unwrap(),
1029 self.backend,
1030 )
1031 .where_(
1032 self.backend.zeros(self.shape(), self.dtype()).unwrap(),
1033 value,
1034 )
1035 }
1036 }
1037
1038 #[must_use]
1040 pub fn permute(&self, axes: impl IntoAxes) -> Tensor<B> {
1041 let axes = axes.into_axes(self.rank());
1042 let shape = self.shape().permute(&axes);
1043 debug_assert!(
1044 axes.len() == shape.rank(),
1045 "Cannot permute tensor with shape {shape} with axes {axes}"
1046 );
1047 tensor(
1048 self.backend
1049 .push(Node::Permute(self.id, axes, shape))
1050 .unwrap(),
1051 self.backend,
1052 )
1053 }
1054
1055 #[must_use]
1058 pub fn transpose(&self) -> Tensor<B> {
1059 let mut rank = self.rank();
1060 let x = if rank == 1 {
1061 let n = self.numel();
1062 rank = 2;
1063 self.reshape([1, n])
1064 } else {
1065 self.clone()
1066 };
1067 let mut axes: Vec<usize> = (0..rank).collect();
1068 axes.swap(rank - 1, rank - 2);
1069 x.permute(axes)
1070 }
1071
1072 #[must_use]
1074 pub fn flatten(&self, axes: impl FlattenAxes) -> Tensor<B> {
1075 let sh = self.shape();
1076 let n = sh.numel();
1077 let rank = sh.rank();
1078 let mut ld = 1;
1079 let mut first_dims = false;
1080 for a in axes.into_flatten_axes(rank) {
1081 let a = if a > 0 {
1082 a as usize
1083 } else {
1084 (a + rank as i64) as usize
1085 };
1086 if a == 0 {
1087 first_dims = true;
1088 }
1089 ld *= sh[a];
1090 }
1091 if first_dims {
1092 self.reshape([ld, n / ld])
1093 } else {
1094 self.reshape([n / ld, ld])
1095 }
1096 }
1097
1098 #[must_use]
1113 pub fn sum(&self, axes: impl IntoAxes) -> Tensor<B> {
1114 let axes = axes.into_axes(self.rank());
1115 let shape = self.shape().reduce(&axes);
1116 let mut uniq = BTreeSet::new();
1117 debug_assert!(
1118 axes.into_iter().all(move |x| uniq.insert(x)),
1119 "Cannot sum tensor with shape {:?} by axes {:?}, because axes contain duplicates.",
1120 self.shape(),
1121 axes
1122 );
1123 tensor(
1124 self.backend.push(Node::Sum(self.id, axes, shape)).unwrap(),
1125 self.backend,
1126 )
1127 }
1128
1129 #[must_use]
1143 pub fn max(&self, axes: impl IntoAxes) -> Tensor<B> {
1144 let axes = axes.into_axes(self.rank());
1145 let shape = self.shape().reduce(&axes);
1146 let mut uniq = BTreeSet::new();
1147 debug_assert!(
1148 axes.into_iter().all(move |x| uniq.insert(x)),
1149 "Cannot sum tensor with shape {:?} by axes {:?}, because axes contain duplicates.",
1150 self.shape(),
1151 axes
1152 );
1153 for a in &axes {
1154 debug_assert!(
1155 *a < shape.rank(),
1156 "Cannot sum tensor with shape {:?} by axes {:?}, because some axes are greater than rank.",
1157 self.shape(),
1158 axes
1159 );
1160 }
1161 tensor(
1162 self.backend.push(Node::Max(self.id, axes, shape)).unwrap(),
1163 self.backend,
1164 )
1165 }
1166
1167 #[must_use]
1169 pub fn mean(&self, axes: impl IntoAxes) -> Tensor<B> {
1170 let shape = self.shape();
1171 let axes = axes.into_axes(shape.rank());
1172 self.sum(axes.clone()) / axes.iter().copied().map(|a| shape[a]).product::<usize>() as i32
1173 }
1174
1175 #[must_use]
1177 pub fn var(&self, axes: impl IntoAxes) -> Tensor<B> {
1178 let axes = axes.into_axes(self.rank());
1179 (self - self.mean(axes.clone())).pow(2).sum(axes)
1180 }
1181
1182 #[must_use]
1184 pub fn std(&self, axes: impl IntoAxes) -> Tensor<B> {
1185 self.var(axes).sqrt()
1186 }
1187
1188 #[must_use]
1190 pub fn norm(&self, axes: impl IntoAxes, p: impl Scalar) -> Tensor<B> {
1191 self.pow(p.clone()).sum(axes).pow(p.reciprocal())
1192 }
1193
1194 #[must_use]
1196 pub fn product(&self, axes: impl IntoAxes) -> Tensor<B> {
1197 self.ln().sum(axes).exp()
1198 }
1199
1200 #[must_use]
1202 pub fn diagonal(&self) -> Tensor<B> {
1203 let n: usize = self.shape()[-1];
1204 self.flatten(..)
1205 .pad([(0, n as i64)], 0)
1206 .reshape([n, n + 1])
1207 .get((.., 0))
1208 }
1209
1210 #[must_use]
1249 pub fn get(&self, index: impl IntoIndex) -> Tensor<B> {
1250 let shape = self.shape();
1252 let padding: Vec<(i64, i64)> = index
1253 .into_index()
1254 .into_iter()
1255 .zip(shape.iter())
1256 .map(|(r, d)| {
1257 (
1258 if r.start >= 0 {
1259 -r.start
1260 } else {
1261 -r.start - *d as i64
1262 },
1263 if r.end == i64::MAX {
1264 0
1265 } else if r.end > 0 {
1266 -(*d as i64 - r.end)
1267 } else {
1268 r.end
1269 },
1270 )
1271 })
1272 .collect();
1273 let n = shape.rank() - padding.len();
1275 self.pad(
1276 padding
1277 .into_iter()
1278 .chain(repeat((0, 0)).take(n))
1279 .collect::<Vec<(i64, i64)>>()
1280 .into_iter()
1281 .rev(),
1282 0,
1283 )
1284 }
1285
1286 #[must_use]
1298 pub fn cat<'a>(tensors: impl IntoIterator<Item = &'a Tensor<B>>, dim: i64) -> Tensor<B>
1299 where
1300 B: 'a,
1301 {
1302 let tensors: Vec<&Tensor<B>> = tensors.into_iter().collect();
1303 let shape = tensors[0].shape();
1304 let rank = shape.rank();
1305 let dim = if dim < 0 { dim + rank as i64 } else { dim } as usize;
1306 for tensor in &tensors {
1308 for (i, (d1, d2)) in shape.iter().zip(tensor.shape().iter()).enumerate() {
1309 if i != dim {
1310 debug_assert_eq!(*d1, *d2, "Cannot concatenate these tensors.");
1311 }
1312 }
1313 }
1314 let mut offset = 0i64;
1315 let mut res = tensors[0]
1316 .backend
1317 .zeros(tensors[0].shape(), tensors[0].dtype())
1318 .unwrap();
1319 for tensor in tensors {
1320 res = res
1321 + tensor.pad(
1322 repeat((0i64, 0i64))
1323 .take(rank - dim - 1)
1324 .chain([(offset, 0i64)]),
1325 0,
1326 );
1327 offset += tensor.shape()[dim] as i64;
1328 }
1329 res
1330 }
1331
1332 }
1359
/// Binary operation selector used by `Tensor::binary_op`; each variant maps
/// to the `Node` variant of the same name.
enum BOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Cmplt,
}
1368
1369impl<B: Backend> Tensor<B> {
1371 #[must_use]
1372 fn binary_op(self, rhs: impl IntoTensor<B>, op: BOp) -> Tensor<B> {
1373 let rhs = rhs.into_tensor(self.backend);
1374 let (x, y) = Tensor::broadcast(self, rhs);
1375 tensor(
1376 x.backend
1377 .push(match op {
1378 BOp::Add => Node::Add(x.id, y.id),
1379 BOp::Sub => Node::Sub(x.id, y.id),
1380 BOp::Mul => Node::Mul(x.id, y.id),
1381 BOp::Div => Node::Div(x.id, y.id),
1382 BOp::Pow => Node::Pow(x.id, y.id),
1383 BOp::Cmplt => Node::Cmplt(x.id, y.id),
1384 })
1385 .unwrap(),
1386 x.backend,
1387 )
1388 }
1389
1390 #[must_use]
1394 fn broadcast(mut x: Tensor<B>, mut y: Tensor<B>) -> (Tensor<B>, Tensor<B>) {
1395 match (x.dtype(), y.dtype()) {
1405 (DType::F32, DType::I32) => y = y.cast(DType::F32),
1406 (DType::F32, DType::F64) => x = x.cast(DType::F64),
1407 (DType::I32, DType::F32) => x = x.cast(DType::F32),
1408 (DType::I32, DType::F64) => x = x.cast(DType::F64),
1409 (DType::F64, DType::F32) => y = y.cast(DType::F64),
1410 (DType::F64, DType::I32) => y = y.cast(DType::F64),
1411 _ => {}
1412 }
1413 let mut x_shape = x.shape();
1414 let mut y_shape = y.shape();
1415
1416 for (x, y) in x_shape.iter().rev().zip(y_shape.iter().rev()) {
1417 if x != y {
1418 debug_assert!(
1419 *x == 1 || *y == 1,
1420 "Left and right tensor shapes can not be broadcasted: {x_shape} and {y_shape}"
1421 );
1422 }
1423 }
1424
1425 let rx = x_shape.rank();
1426 let ry = y_shape.rank();
1427 match rx.cmp(&ry) {
1428 Ordering::Less => {
1429 x_shape = repeat(1)
1430 .take(ry - rx)
1431 .chain(x_shape.into_iter().copied())
1432 .collect::<Vec<usize>>()
1433 .into();
1434 }
1435 Ordering::Greater => {
1436 y_shape = repeat(1)
1437 .take(rx - ry)
1438 .chain(y_shape.into_iter().copied())
1439 .collect::<Vec<usize>>()
1440 .into();
1441 }
1442 Ordering::Equal => {}
1443 }
1444 let mut eshape = Vec::new();
1445 for (x, y) in x_shape.into_iter().zip(y_shape.into_iter()) {
1446 eshape.push(*x.max(y));
1447 }
1448 let eshape: Shape = eshape.into();
1449 if x_shape != eshape {
1450 x = x.expand(eshape.clone());
1451 }
1452 if y_shape != eshape {
1453 y = y.expand(eshape);
1454 }
1455 (x, y)
1456 }
1457}
1458
1459impl<B: Backend> core::ops::Neg for Tensor<B> {
1460 type Output = Tensor<B>;
1461 fn neg(self) -> Self::Output {
1462 tensor(self.backend.push(Node::Neg(self.id)).unwrap(), self.backend)
1463 }
1464}
1465
1466impl<B: Backend> core::ops::Neg for &Tensor<B> {
1467 type Output = Tensor<B>;
1468 fn neg(self) -> Self::Output {
1469 tensor(self.backend.push(Node::Neg(self.id)).unwrap(), self.backend)
1470 }
1471}
1472
1473impl<B: Backend, IT: IntoTensor<B>> core::ops::Add<IT> for &Tensor<B> {
1474 type Output = Tensor<B>;
1475 fn add(self, rhs: IT) -> Self::Output {
1476 self.clone().binary_op(rhs, BOp::Add)
1477 }
1478}
1479
1480impl<B: Backend, IT: IntoTensor<B>> core::ops::Add<IT> for Tensor<B> {
1481 type Output = Tensor<B>;
1482 fn add(self, rhs: IT) -> Self::Output {
1483 self.binary_op(rhs, BOp::Add)
1484 }
1485}
1486
1487impl<B: Backend, IT: IntoTensor<B>> core::ops::Sub<IT> for &Tensor<B> {
1488 type Output = Tensor<B>;
1489 fn sub(self, rhs: IT) -> Self::Output {
1490 self.clone().binary_op(rhs, BOp::Sub)
1491 }
1492}
1493
1494impl<B: Backend, IT: IntoTensor<B>> core::ops::Sub<IT> for Tensor<B> {
1495 type Output = Tensor<B>;
1496 fn sub(self, rhs: IT) -> Self::Output {
1497 self.binary_op(rhs, BOp::Sub)
1498 }
1499}
1500
1501impl<B: Backend, IT: IntoTensor<B>> core::ops::Mul<IT> for &Tensor<B> {
1502 type Output = Tensor<B>;
1503 fn mul(self, rhs: IT) -> Self::Output {
1504 self.clone().binary_op(rhs, BOp::Mul)
1505 }
1506}
1507
1508impl<B: Backend> core::ops::Mul<Tensor<B>> for f32 {
1509 type Output = Tensor<B>;
1510 fn mul(self, rhs: Tensor<B>) -> Self::Output {
1511 rhs * self
1512 }
1513}
1514
1515impl<B: Backend> core::ops::Mul<&Tensor<B>> for f32 {
1516 type Output = Tensor<B>;
1517 fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1518 rhs * self
1519 }
1520}
1521
1522impl<B: Backend> core::ops::Mul<Tensor<B>> for f64 {
1523 type Output = Tensor<B>;
1524 fn mul(self, rhs: Tensor<B>) -> Self::Output {
1525 rhs * self
1526 }
1527}
1528
1529impl<B: Backend> core::ops::Mul<&Tensor<B>> for f64 {
1530 type Output = Tensor<B>;
1531 fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1532 rhs * self
1533 }
1534}
1535
1536impl<B: Backend> core::ops::Mul<Tensor<B>> for i32 {
1537 type Output = Tensor<B>;
1538 fn mul(self, rhs: Tensor<B>) -> Self::Output {
1539 rhs * self
1540 }
1541}
1542
1543impl<B: Backend> core::ops::Mul<&Tensor<B>> for i32 {
1544 type Output = Tensor<B>;
1545 fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1546 rhs * self
1547 }
1548}
1549
1550impl<B: Backend, IT: IntoTensor<B>> core::ops::Mul<IT> for Tensor<B> {
1551 type Output = Tensor<B>;
1552 fn mul(self, rhs: IT) -> Self::Output {
1553 self.binary_op(rhs, BOp::Mul)
1554 }
1555}
1556
1557impl<B: Backend, IT: IntoTensor<B>> core::ops::Div<IT> for &Tensor<B> {
1558 type Output = Tensor<B>;
1559 fn div(self, rhs: IT) -> Self::Output {
1560 self.clone().binary_op(rhs, BOp::Div)
1561 }
1562}
1563
1564impl<B: Backend, IT: IntoTensor<B>> core::ops::Div<IT> for Tensor<B> {
1565 type Output = Tensor<B>;
1566 fn div(self, rhs: IT) -> Self::Output {
1567 self.binary_op(rhs, BOp::Div)
1568 }
1569}
1570
1571impl<B: Backend> core::ops::Div<Tensor<B>> for f32 {
1572 type Output = Tensor<B>;
1573 fn div(self, rhs: Tensor<B>) -> Self::Output {
1574 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1575 }
1576}
1577
1578impl<B: Backend> core::ops::Div<&Tensor<B>> for f32 {
1579 type Output = Tensor<B>;
1580 fn div(self, rhs: &Tensor<B>) -> Self::Output {
1581 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1582 }
1583}
1584
1585impl<B: Backend> core::ops::Div<Tensor<B>> for f64 {
1586 type Output = Tensor<B>;
1587 fn div(self, rhs: Tensor<B>) -> Self::Output {
1588 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1589 }
1590}
1591
1592impl<B: Backend> core::ops::Div<&Tensor<B>> for f64 {
1593 type Output = Tensor<B>;
1594 fn div(self, rhs: &Tensor<B>) -> Self::Output {
1595 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1596 }
1597}
1598
1599impl<B: Backend> core::ops::Div<Tensor<B>> for i32 {
1600 type Output = Tensor<B>;
1601 fn div(self, rhs: Tensor<B>) -> Self::Output {
1602 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1603 }
1604}
1605
1606impl<B: Backend> core::ops::Div<&Tensor<B>> for i32 {
1607 type Output = Tensor<B>;
1608 fn div(self, rhs: &Tensor<B>) -> Self::Output {
1609 rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1610 }
1611}
1612
/// Conversion of a value into a [`Tensor`] on a given backend.
pub trait IntoTensor<B: Backend> {
    /// Converts `self` into a tensor, using `backend` when new data has to
    /// be stored.
    fn into_tensor(self, backend: B) -> Tensor<B>;
}
1618
impl<B: Backend> IntoTensor<B> for Tensor<B> {
    /// A tensor is returned unchanged; the `_backend` argument is ignored.
    fn into_tensor(self, _backend: B) -> Tensor<B> {
        self
    }
}
1625
1626impl<B: Backend> IntoTensor<B> for &Tensor<B> {
1627 fn into_tensor(self, _backend: B) -> Tensor<B> {
1628 self.clone()
1630 }
1631}
1632
1633impl<B: Backend, T: Scalar> IntoTensor<B> for Range<T>
1634where
1635 Range<T>: Iterator<Item = T> + ExactSizeIterator,
1636{
1637 fn into_tensor(self, backend: B) -> Tensor<B> {
1638 tensor(backend.store(self).unwrap(), backend)
1639 }
1640}
1641
1642impl<B: Backend, T: Scalar> IntoTensor<B> for Vec<T> {
1643 fn into_tensor(self, backend: B) -> Tensor<B> {
1644 tensor(backend.store(self).unwrap(), backend)
1645 }
1646}
1647
1648impl<B: Backend, T: Scalar> IntoTensor<B> for &'static [T] {
1649 fn into_tensor(self, backend: B) -> Tensor<B> {
1650 tensor(backend.store(self.iter().cloned()).unwrap(), backend)
1651 }
1652}
1653
1654impl<B: Backend, T: Scalar> IntoTensor<B> for T {
1655 fn into_tensor(self, backend: B) -> Tensor<B> {
1656 tensor(backend.store([self]).unwrap(), backend)
1657 }
1658}
1659
1660impl<B: Backend, T: Scalar, const D0: usize> IntoTensor<B> for [T; D0] {
1661 fn into_tensor(self, backend: B) -> Tensor<B> {
1662 tensor(backend.store(self).unwrap(), backend)
1663 }
1664}
1665
1666impl<B: Backend, T: Scalar, const D0: usize, const D1: usize> IntoTensor<B> for [[T; D1]; D0] {
1667 fn into_tensor(self, backend: B) -> Tensor<B> {
1668 tensor(
1669 backend
1670 .store(self.into_iter().flatten().make_sized(D0 * D1))
1671 .unwrap(),
1672 backend,
1673 )
1674 .reshape([D0, D1])
1675 }
1676}
1677
1678impl<B: Backend, T: Scalar, const D0: usize, const D1: usize, const D2: usize> IntoTensor<B>
1679 for [[[T; D2]; D1]; D0]
1680{
1681 fn into_tensor(self, backend: B) -> Tensor<B> {
1682 tensor(
1683 backend
1684 .store(
1685 self.into_iter()
1686 .flatten()
1687 .flatten()
1688 .make_sized(D0 * D1 * D2),
1689 )
1690 .unwrap(),
1691 backend,
1692 )
1693 .reshape([D0, D1, D2])
1694 }
1695}
1696
1697impl<B: Backend, IT: IntoTensor<B> + Clone> PartialEq<IT> for Tensor<B> {
1698 fn eq(&self, other: &IT) -> bool {
1699 let other = self.backend.tensor(other.clone()).unwrap();
1700 let dtype = self.dtype();
1701 self.shape() == other.shape()
1702 && dtype == other.dtype()
1703 && match dtype {
1704 DType::F32 => self
1705 .to_vec::<f32>()
1706 .unwrap()
1707 .into_iter()
1708 .zip(other.to_vec::<f32>().unwrap())
1709 .all(|(x, y)| x.is_equal(y)),
1710 DType::F64 => self
1711 .to_vec::<f64>()
1712 .unwrap()
1713 .into_iter()
1714 .zip(other.to_vec::<f64>().unwrap())
1715 .all(|(x, y)| x.is_equal(y)),
1716 DType::I32 => self
1717 .to_vec::<i32>()
1718 .unwrap()
1719 .into_iter()
1720 .zip(other.to_vec::<i32>().unwrap())
1721 .all(|(x, y)| x.is_equal(y)),
1722 }
1723 }
1724}