zyx_core/tensor.rs

1extern crate alloc;
2use crate::axes::IntoAxes;
3use crate::dtype::DType;
4use crate::error::ZyxError;
5use crate::scalar::Scalar;
6use crate::shape::Shape;
7use crate::utils::SizedIterator;
8use crate::{backend::Backend, node::Node};
9use alloc::{boxed::Box, collections::BTreeSet, vec::Vec};
10use core::{
11    cmp::Ordering,
12    iter::repeat,
13    ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, SubAssign},
14};
15
16/// Id of tensor.
17#[derive(Clone, Copy, PartialOrd, PartialEq, Ord, Eq, Debug)]
18pub struct Id(usize);
19
20/// Create new id.
21pub const fn id(id: usize) -> Id {
22    Id(id)
23}
24
25impl Id {
26    /// Convert id to usize
27    pub const fn i(self) -> usize {
28        self.0
29    }
30}
31
32impl core::fmt::Display for Id {
33    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34        f.write_fmt(format_args!("{:?}", self))
35    }
36}
37
38impl SubAssign<usize> for Id {
39    fn sub_assign(&mut self, rhs: usize) {
40        self.0 -= rhs;
41    }
42}
43
44/// Into i64 range, used for indexing
45pub trait IntoRange: Clone {
    /// Convert self to a `Range<i64>`; a scalar `x` is converted to `x..x + 1`
47    fn into_range(self) -> Range<i64>;
48}
49
50impl IntoRange for RangeFull {
51    fn into_range(self) -> Range<i64> {
52        0..i64::MAX
53    }
54}
55
56impl IntoRange for RangeFrom<i64> {
57    fn into_range(self) -> Range<i64> {
58        self.start..i64::MAX
59    }
60}
61
62impl IntoRange for RangeTo<i64> {
63    fn into_range(self) -> Range<i64> {
64        0..self.end
65    }
66}
67
68impl IntoRange for RangeInclusive<i64> {
69    fn into_range(self) -> Range<i64> {
70        *self.start()..*self.end() + 1
71    }
72}
73
74impl IntoRange for RangeToInclusive<i64> {
75    fn into_range(self) -> Range<i64> {
76        0..self.end + 1
77    }
78}
79
80impl IntoRange for Range<i64> {
81    fn into_range(self) -> Range<i64> {
82        self
83    }
84}
85
86impl IntoRange for i64 {
87    fn into_range(self) -> Range<i64> {
88        self..self + 1
89    }
90}
91
92/// Implemented for objects that can be used to index tensors.
93pub trait IntoIndex {
94    /// Convert self to tensor index.
95    fn into_index(self) -> impl IntoIterator<Item = Range<i64>>;
96}
97
98impl<I: IntoRange> IntoIndex for &[I] {
99    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
100        self.iter().cloned().map(IntoRange::into_range)
101    }
102}
103
104impl<I0: IntoRange> IntoIndex for I0 {
105    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
106        [self.into_range()].into_iter()
107    }
108}
109
110impl<I0: IntoRange, I1: IntoRange> IntoIndex for (I0, I1) {
111    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
112        [self.0.into_range(), self.1.into_range()].into_iter()
113    }
114}
115
116impl<I0: IntoRange, I1: IntoRange, I2: IntoRange> IntoIndex for (I0, I1, I2) {
117    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
118        [
119            self.0.into_range(),
120            self.1.into_range(),
121            self.2.into_range(),
122        ]
123        .into_iter()
124    }
125}
126
127impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange> IntoIndex for (I0, I1, I2, I3) {
128    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
129        [
130            self.0.into_range(),
131            self.1.into_range(),
132            self.2.into_range(),
133            self.3.into_range(),
134        ]
135        .into_iter()
136    }
137}
138
139impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange> IntoIndex
140    for (I0, I1, I2, I3, I4)
141{
142    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
143        [
144            self.0.into_range(),
145            self.1.into_range(),
146            self.2.into_range(),
147            self.3.into_range(),
148            self.4.into_range(),
149        ]
150        .into_iter()
151    }
152}
153
154impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange, I5: IntoRange>
155    IntoIndex for (I0, I1, I2, I3, I4, I5)
156{
157    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
158        [
159            self.0.into_range(),
160            self.1.into_range(),
161            self.2.into_range(),
162            self.3.into_range(),
163            self.4.into_range(),
164            self.5.into_range(),
165        ]
166        .into_iter()
167    }
168}
169
170impl<
171        I0: IntoRange,
172        I1: IntoRange,
173        I2: IntoRange,
174        I3: IntoRange,
175        I4: IntoRange,
176        I5: IntoRange,
177        I6: IntoRange,
178    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6)
179{
180    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
181        [
182            self.0.into_range(),
183            self.1.into_range(),
184            self.2.into_range(),
185            self.3.into_range(),
186            self.4.into_range(),
187            self.5.into_range(),
188            self.6.into_range(),
189        ]
190        .into_iter()
191    }
192}
193
194impl<
195        I0: IntoRange,
196        I1: IntoRange,
197        I2: IntoRange,
198        I3: IntoRange,
199        I4: IntoRange,
200        I5: IntoRange,
201        I6: IntoRange,
202        I7: IntoRange,
203    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6, I7)
204{
205    fn into_index(self) -> impl IntoIterator<Item = Range<i64>> {
206        [
207            self.0.into_range(),
208            self.1.into_range(),
209            self.2.into_range(),
210            self.3.into_range(),
211            self.4.into_range(),
212            self.5.into_range(),
213            self.6.into_range(),
214            self.7.into_range(),
215        ]
216        .into_iter()
217    }
218}
219
220/// A range of axes that can be used for flattening tensors.
221pub trait FlattenAxes {
222    /// Get flatten axes
223    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64>;
224}
225
226impl FlattenAxes for RangeFrom<i64> {
227    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
228        debug_assert!(
229            if self.start > 0 {
230                (self.start as usize) < rank
231            } else {
232                ((-self.start) as usize) <= rank
233            },
234            "Cannot use {self:?} as flatten axes."
235        );
236        self.start..i64::MAX
237    }
238}
239
240impl FlattenAxes for RangeTo<i64> {
241    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
242        debug_assert!(
243            if self.end > 0 {
244                (self.end as usize) < rank
245            } else {
246                ((-self.end) as usize) <= rank
247            },
248            "Cannot use {self:?} as flatten axes."
249        );
250        0..self.end
251    }
252}
253
254impl FlattenAxes for RangeToInclusive<i64> {
255    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
256        debug_assert!(
257            if self.end > 0 {
258                (self.end as usize) < rank
259            } else {
260                ((-self.end) as usize) <= rank
261            },
262            "Cannot use {self:?} as flatten axes."
263        );
264        0..self.end + 1
265    }
266}
267
268impl FlattenAxes for RangeFull {
269    fn into_flatten_axes(self, rank: usize) -> impl IntoIterator<Item = i64> {
270        0..rank as i64
271    }
272}
273
274/// Tensor is the core object of zyx.
/// It is a multidimensional array.
276pub struct Tensor<B: Backend> {
277    id: Id,
278    backend: B,
279}
280
281impl<B: Backend> Clone for Tensor<B> {
282    fn clone(&self) -> Self {
283        self.backend.retain(self.id);
284        tensor(self.id, self.backend)
285    }
286}
287
288impl<B: Backend> Drop for Tensor<B> {
289    fn drop(&mut self) {
290        //std::println!("Dropping tensor {}", self.id);
291        self.backend.release(self.id).unwrap();
292    }
293}
294
295impl<B: Backend> core::fmt::Debug for Tensor<B> {
296    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
297        f.write_fmt(format_args!("{self}"))
298        //f.write_fmt(format_args!("Tensor {{ id = {:?} }}", self.id))
299    }
300}
301
302impl<B: Backend> core::fmt::Display for Tensor<B> {
303    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
304        // TODO don't print the whole tensor if it is too big
305        let precision = if let Some(precision) = f.precision() {
306            precision
307        } else {
308            3
309        };
310        let res = match self.dtype() {
311            DType::F32 => {
312                if let Ok(data) = &self.to_vec::<f32>() {
313                    tensor_to_string(data, &self.shape(), precision, f.width())
314                } else {
315                    "f32 tensor failed to realize".into()
316                }
317            }
318            DType::F64 => {
319                if let Ok(data) = &self.to_vec::<f64>() {
320                    tensor_to_string(data, &self.shape(), precision, f.width())
321                } else {
322                    "f64 tensor failed to realize".into()
323                }
324            }
325            DType::I32 => {
326                if let Ok(data) = &self.to_vec::<i32>() {
327                    tensor_to_string(data, &self.shape(), precision, f.width())
328                } else {
329                    "i32 tensor failed to realize".into()
330                }
331            }
332        };
333        f.write_fmt(format_args!(
334            "Tensor {} {}\n{res}",
335            self.shape(),
336            self.dtype()
337        ))
338    }
339}
340
341fn tensor_to_string<T: core::fmt::Display>(
342    data: &[T],
343    shape: &Shape,
344    precision: usize,
345    width: Option<usize>,
346) -> alloc::string::String {
347    use core::fmt::Write;
348    let n = shape.numel();
349    let ndim = shape.rank();
350    let mut res = alloc::string::String::new();
351    if data.is_empty() {
352        return "[]".into();
353    }
354    // get maximal width of single value
355    let mut w = 0;
356    if let Some(width) = width {
357        w = width;
358    } else {
359        for x in data {
360            let l = alloc::format!("{x:>.precision$}").len();
361            if l > w {
362                w = l;
363            }
364        }
365    }
366    let d0 = shape[-1];
367    for (i, x) in data.iter().enumerate() {
368        {
369            let mut var = 1;
370            let mut r = ndim;
371            while r > 0 {
372                if i % (n / var) == 0 {
373                    res += &(" ".repeat(ndim - r) + "[".repeat(r - 1).as_str());
374                    break;
375                }
376                var *= shape[ndim - r];
377                r -= 1;
378            }
379        }
380        let _ = write!(res, "{x:>w$.precision$}");
381        if (i + 1) % d0 != 0usize {
382            res += "  ";
383        }
384        {
385            let mut var = 1;
386            let mut r = ndim;
387            while r > 0 {
388                if (i + 1) % (n / var) == 0 {
389                    res += &"]".repeat(r - 1);
390                    break;
391                }
392                var *= shape[ndim - r];
393                r -= 1;
394            }
395        }
396        if (i + 1) % d0 == 0usize && i != n - 1 {
397            res += "\n";
398        }
399    }
400    res
401}
402
403/// Create new tensor from id and backend.
404/// Used mostly internally in tensor and in backends.
405pub const fn tensor<B: Backend>(id: Id, backend: B) -> Tensor<B> {
406    Tensor { id, backend }
407}
408
409impl<B: Backend> Tensor<B> {
410    // Metadata
411    /// Tensor's unique identification.
412    /// All tensors on one backend will always have different ids.
413    pub fn id(&self) -> Id {
414        self.id
415    }
416
417    /// Returns the [shape](Shape) of the self tensor.
418    /// ```
419    /// let dev = zyx_opencl::device()?;
420    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
421    /// assert_eq!(x.shape(), [2, 3]);
422    /// # Ok::<(), zyx_opencl::ZyxError>(())
423    /// ```
424    #[must_use]
425    pub fn shape(&self) -> Shape {
426        self.backend.shape(self.id)
427    }
428
429    /// Returns number of elements in the self tensor.
430    /// ```
431    /// let dev = zyx_opencl::device()?;
432    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
433    /// assert_eq!(x.numel(), 6);
434    /// # Ok::<(), zyx_opencl::ZyxError>(())
435    /// ```
436    #[must_use]
437    pub fn numel(&self) -> usize {
438        self.shape().numel()
439    }
440
441    /// Returns the [dtype](DType) of the self tensor.
442    /// ```
443    /// let dev = zyx_opencl::device()?;
444    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
445    /// assert_eq!(x.dtype(), zyx_opencl::DType::I32);
446    /// # Ok::<(), zyx_opencl::ZyxError>(())
447    /// ```
448    #[must_use]
449    pub fn dtype(&self) -> DType {
450        self.backend.dtype(self.id)
451    }
452
453    /// Returns the rank of the self tensor. This is the number of tensor's dimensions.
454    /// ```
455    /// let dev = zyx_opencl::device()?;
456    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
457    /// assert_eq!(x.rank(), 2);
458    /// # Ok::<(), zyx_opencl::ZyxError>(())
459    /// ```
460    #[must_use]
461    pub fn rank(&self) -> usize {
462        self.shape().rank()
463    }
464
465    /// Returns the [backend](Backend) of the self tensor.
466    /// ```
467    /// let dev = zyx_opencl::device()?;
468    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
469    /// let y = x.backend().randn([2, 4, 3], zyx_opencl::DType::F32);
470    /// # Ok::<(), zyx_opencl::ZyxError>(())
471    /// ```
472    #[must_use]
473    pub fn backend(&self) -> B {
474        self.backend
475    }
476
477    /// Detach gradient tape from tensor.
    /// This means that the resulting tensor is a shallow copy of self,
    /// but its gradient will be ones. The result of this operation
    /// can only be differentiated by itself.
481    /// ```rust
482    /// let dev = zyx_opencl::device()?;
483    /// let x = dev.tensor([2, 3]);
484    /// let y = &x + &x;
485    /// let g = y.backward([&x]).pop().unwrap().unwrap();
486    /// assert_eq!(g, [2, 2]);
487    /// let z = y.detach();
488    /// let g = z.backward([&z]).pop().unwrap().unwrap();
489    /// assert_eq!(g, [1, 1]);
490    /// let g = z.backward([&x]).pop().unwrap();
491    /// assert_eq!(g, None);
492    /// # Ok::<(), zyx_opencl::ZyxError>(())
493    /// ```
494    #[must_use]
495    pub fn detach(&self) -> Tensor<B> {
        // It should be possible to just optimize this away.
497        tensor(
498            self.backend.push(Node::Detach(self.id)).unwrap(),
499            self.backend,
500        )
501    }
502
503    /*
504    /// Probably just add no_grad, that is all tensors coming from no_grad tensor
505    /// are not differentiable, unless some other parameter in those ops is differentiable.
506    #[must_use]
507    pub fn no_grad(&self) {
508        // TODO
509        //self.backend.no_grad(self.id);
510    }*/
511
512    // Access methods
513    /// Load tensor from backend into vector
514    /// ```
515    /// let dev = zyx_opencl::device()?;
516    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
517    /// let xvec: Vec<i32> = x.to_vec()?;
518    /// assert_eq!(xvec, vec![2, 3, 1, 4, 1, 3]);
519    /// # Ok::<(), zyx_opencl::ZyxError>(())
520    /// ```
521    pub fn to_vec<T: Scalar>(&self) -> Result<Vec<T>, ZyxError> {
522        if T::dtype() != self.dtype() {
523            return Err(ZyxError::InvalidDType {
524                expected: T::dtype(),
525                found: self.dtype(),
526            });
527        }
528        self.backend.load(self.id)
529    }
530
531    /// Returns first element stored in this tensor.
532    /// Usually used for tensors with exactly one element.
533    /// Error is returned if self tensor contains zero elements
534    /// or if backend returns error.
535    /// ```
536    /// let dev = zyx_opencl::device()?;
537    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
538    /// let xitem: i32 = x.item()?;
539    /// assert_eq!(xitem, 2);
540    /// # Ok::<(), zyx_opencl::ZyxError>(())
541    /// ```
542    pub fn item<T: Scalar>(&self) -> Result<T, ZyxError> {
543        self.backend
544            .load::<T>(self.id)?
545            .first()
546            .ok_or(ZyxError::IndexOutOfBounds { index: 0, len: 0 })
547            .cloned()
548    }
549
550    // Backpropagation
551    /// Returns gradients of self w.r.t. sources.
552    /// ```rust
553    /// let dev = zyx_opencl::device()?;
554    /// let x = dev.tensor([3., 2., 1.]);
555    /// let y = x.exp() + &x;
556    /// let x_grad = y.backward([&x]).into_iter().next().unwrap().unwrap();
557    /// assert_eq!(x_grad, [21.0855369, 8.3890561, 3.7182818]);
558    /// # Ok::<(), zyx_opencl::ZyxError>(())
559    /// ```
560    #[must_use]
561    pub fn backward<'a>(
562        &'a self,
563        sources: impl IntoIterator<Item = &'a Tensor<B>>,
564    ) -> Vec<Option<Tensor<B>>>
565    where
566        B: 'a,
567    {
568        let sources: Vec<&Tensor<B>> = sources.into_iter().collect();
569        let grads = self
570            .backend
571            .backward(self.id, &sources.iter().map(|t| t.id).collect())
572            .unwrap();
573        sources
574            .into_iter()
575            .map(move |x: &Tensor<B>| grads.get(&x.id).cloned())
576            .map(move |x| x.map(|x| tensor(x, self.backend)))
577            .collect()
578    }
579
580    // Unary ops
581    /// Cast self into dtype.
582    /// ```rust
583    /// # use zyx_opencl::DType;
584    /// let dev = zyx_opencl::device()?;
585    /// let x = dev.tensor([[3, 4, 2], [4, 5, 2]]);
586    /// let y = x.cast(DType::F32);
587    /// assert_eq!(y.dtype(), DType::F32);
588    /// assert_eq!(y, [[3f32, 4., 2.], [4., 5., 2.]]);
589    /// # Ok::<(), zyx_opencl::ZyxError>(())
590    /// ```
591    #[must_use]
592    pub fn cast(&self, dtype: DType) -> Tensor<B> {
593        tensor(
594            self.backend.push(Node::Cast(self.id, dtype)).unwrap(),
595            self.backend,
596        )
597    }
598
599    /// Returns a new tensor with the rectified linear unit function applied to the elements of self.
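    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([-1., 0., 2.]);
    /// assert_eq!(x.relu(), [0., 0., 2.]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```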
600    #[must_use]
601    pub fn relu(&self) -> Tensor<B> {
602        tensor(
603            self.backend.push(Node::ReLU(self.id)).unwrap(),
604            self.backend,
605        )
606    }
607
608    /// Returns a new tensor with the sine of the elements of self.
609    #[must_use]
610    pub fn sin(&self) -> Tensor<B> {
611        tensor(self.backend.push(Node::Sin(self.id)).unwrap(), self.backend)
612    }
613
614    /// Returns a new tensor with the cosine of the elements of self.
615    #[must_use]
616    pub fn cos(&self) -> Tensor<B> {
617        tensor(self.backend.push(Node::Cos(self.id)).unwrap(), self.backend)
618    }
619
620    /// Returns a new tensor with the natural logarithm of the elements of self.
    /// For performance reasons, this function does not check whether self lies
    /// in the domain of ln(x). The result for out-of-domain inputs (x <= 0)
    /// is implementation defined.
624    #[must_use]
625    pub fn ln(&self) -> Tensor<B> {
626        tensor(self.backend.push(Node::Ln(self.id)).unwrap(), self.backend)
627    }
628
629    /// Returns a new tensor with the exponential of the elements of self.
630    #[must_use]
631    pub fn exp(&self) -> Tensor<B> {
632        tensor(self.backend.push(Node::Exp(self.id)).unwrap(), self.backend)
633    }
634
635    /// Returns a new tensor with the hyperbolic tangent of the elements of self.
636    #[must_use]
637    pub fn tanh(&self) -> Tensor<B> {
638        tensor(
639            self.backend.push(Node::Tanh(self.id)).unwrap(),
640            self.backend,
641        )
642    }
643
644    /// Returns a new tensor with the square root of the elements of self.
    /// For performance reasons, this function does not check whether self lies
    /// in the domain of sqrt(x). The result for out-of-domain inputs (x < 0)
    /// is implementation defined.
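    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([4., 9.]);
    /// assert_eq!(x.sqrt(), [2., 3.]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```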
648    #[must_use]
649    pub fn sqrt(&self) -> Tensor<B> {
650        tensor(
651            self.backend.push(Node::Sqrt(self.id)).unwrap(),
652            self.backend,
653        )
654    }
655
656    /// Returns 1/self
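    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([2., 4.]);
    /// assert_eq!(x.reciprocal(), [0.5, 0.25]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```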
657    #[must_use]
658    pub fn reciprocal(&self) -> Tensor<B> {
659        self.backend().ones(self.shape(), self.dtype()).unwrap() / self
660    }
661
662    /// Returns 1/self.sqrt()
663    #[must_use]
664    pub fn rsqrt(&self) -> Tensor<B> {
665        self.reciprocal().sqrt()
666    }
667
668    /// Returns a new tensor with each element of self randomly zeroed with given probability.
669    #[must_use]
670    pub fn dropout(&self, probability: impl Scalar) -> Tensor<B> {
671        self.backend()
672            .tensor(probability)
673            .unwrap()
674            .cmplt(self.backend().uniform(self.shape(), 0.0..1.0).unwrap()).cast(self.dtype())
675            * self
676    }
677
678    /// Returns a new tensor with the absolute value of the elements of self.
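    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([-2, 3, -1]);
    /// assert_eq!(x.abs(), [2, 3, 1]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```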
679    #[must_use]
680    pub fn abs(&self) -> Tensor<B> {
681        self.relu() + (-self).relu()
682    }
683
684    /// Returns a new tensor with the sigmoid (logistic function) of the elements of self.
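    /// A minimal usage sketch (sigmoid(0) == 0.5; assuming the `zyx_opencl` backend used elsewhere in these docs):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([0., 0.]);
    /// assert_eq!(x.sigmoid(), [0.5, 0.5]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```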
685    #[must_use]
686    pub fn sigmoid(&self) -> Tensor<B> {
687        let one = self.backend().ones(1, self.dtype()).unwrap();
688        &one / (&one + (-self).exp())
689    }
690
691    /// Returns a new tensor with the swish/silu of the elements of self.
692    #[must_use]
693    pub fn swish(&self) -> Tensor<B> {
694        self * self.sigmoid()
695    }
696
697    /// Returns a new tensor with the mish of the elements of self.
698    #[must_use]
699    pub fn mish(&self) -> Tensor<B> {
700        self * self.softplus(1, 20).tanh()
701    }
702
703    /// Returns a new tensor with the softplus of the elements of self.
704    #[must_use]
705    pub fn softplus(&self, beta: impl Scalar, threshold: impl Scalar) -> Tensor<B> {
706        let x = self * beta.clone();
707        x.cmplt(threshold)
708            .where_(((x).exp() + 1).ln() * beta.reciprocal(), x)
709    }
710
711    /// Returns a new tensor with the tangent of the elements of self.
712    #[must_use]
713    pub fn tan(&self) -> Tensor<B> {
714        self.sin() / self.cos()
715    }
716
717    /// Returns a new tensor with the leaky relu of the elements of self.
718    #[must_use]
719    pub fn leaky_relu(&self, neg_slope: impl Scalar) -> Tensor<B> {
720        self.relu() - (self * (-self.backend.tensor(neg_slope).unwrap())).relu()
721    }
722
723    /// Returns a new tensor with the elu of the elements of self.
724    #[must_use]
725    pub fn elu(&self, alpha: impl Scalar) -> Tensor<B> {
726        self.relu() - (1f32.into_tensor(self.backend) - self.exp()).relu() * alpha
727    }
728
729    /// Returns a new tensor with the selu of the elements of self.
730    #[must_use]
731    pub fn selu(&self) -> Tensor<B> {
732        1.0507009873554804934193349852946f32
733            * (self.relu()
734                - (1.6732632423543772848170429916717f32
735                    * (self.backend.ones(1, self.dtype()).unwrap() - self.exp()))
736                .relu())
737    }
738
739    /// Returns a new tensor with the celu of the elements of self.
740    #[must_use]
741    pub fn celu(&self, alpha: impl Scalar) -> Tensor<B> {
742        self.relu()
743            - ((self.backend.ones(1, self.dtype()).unwrap() - (self / alpha.clone()).exp()) * alpha)
744                .relu()
745    }
746
747    /// Returns a new tensor with the gelu of the elements of self.
748    #[must_use]
749    pub fn gelu(&self) -> Tensor<B> {
750        self * 0.5f32
751            * (((self + self.pow(3f32) * 0.044_715f32) * (2f32 / core::f32::consts::PI).sqrt())
752                .tanh()
753                + 1f32)
754    }
755
756    /// Returns a new tensor with the quick gelu of the elements of self.
757    #[must_use]
758    pub fn quick_gelu(&self) -> Tensor<B> {
759        self * (1.702f32 * self).sigmoid()
760    }
761
762    /// Returns a new tensor with the softmax of the elements of self.
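    /// A minimal usage sketch (equal inputs give equal probabilities; assuming the `zyx_opencl` backend used elsewhere in these docs):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([1., 1.]);
    /// assert_eq!(x.softmax(-1), [0.5, 0.5]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```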
763    #[must_use]
764    pub fn softmax(&self, axes: impl IntoAxes) -> Tensor<B> {
765        let axes = axes.into_axes(self.rank());
766        let e = (self - self.max(axes.clone())).exp();
767        &e / e.sum(axes)
768    }
769
770    /// Returns a new tensor with the log softmax of the elements of self.
771    #[must_use]
772    pub fn ln_softmax(&self, axes: impl IntoAxes) -> Tensor<B> {
773        let axes = axes.into_axes(self.rank());
774        let m = self - self.max(axes.clone());
775        &m - m.exp().sum(axes).ln()
776    }
777
    // Loss functions; all losses are returned without reduction
779    /// Measures the mean absolute error (MAE) between each element in the input self and target.
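    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([1, 2]);
    /// assert_eq!(x.l1_loss([2, 4]), [1, 2]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```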
780    #[must_use]
781    pub fn l1_loss(&self, target: impl IntoTensor<B>) -> Tensor<B> {
782        (self - target).abs()
783    }
784
785    /// Measures the mean squared error (MSE) between each element in the input self and target.
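    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([1, 2]);
    /// assert_eq!(x.mse_loss([2, 4]), [1, 4]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```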
786    #[must_use]
787    pub fn mse_loss(&self, target: impl IntoTensor<B>) -> Tensor<B> {
788        (self - target).pow(2)
789    }
790
791    /// Computes the cross entropy loss between self logits and target.
    /// This function expects target to contain probabilities for each class.
793    #[must_use]
794    pub fn cross_entropy_loss(&self, target: impl IntoTensor<B>, axes: impl IntoAxes) -> Tensor<B> {
795        self.ln_softmax(axes) * target
796    }
797
798    // Binary ops
    /// Elementwise exponentiation of self by exponent.
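    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([2, 3]);
    /// assert_eq!(x.pow(2), [4, 9]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```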
800    #[must_use]
801    pub fn pow(&self, exponent: impl IntoTensor<B>) -> Tensor<B> {
802        let exponent = self.backend.tensor(exponent).unwrap();
803        if exponent.numel() == 1 {
804            let dtype = exponent.dtype();
805            if !dtype.is_floating() {
806                // TODO other int dtypes
807                if exponent.item::<i32>().unwrap() == 2i32 {
808                    return self * self;
809                } else if exponent.item::<i32>().unwrap() == 3i32 {
810                    return self * self * self;
811                }
812            }
813        }
814        if self.dtype().is_floating() {
815            return (exponent * self.ln()).exp();
816        }
817        self.clone().binary_op(exponent, BOp::Pow)
818    }
819
    /// Elementwise less-than comparison between self and rhs.
821    #[must_use]
822    pub fn cmplt(&self, rhs: impl IntoTensor<B>) -> Tensor<B> {
823        self.clone().binary_op(rhs, BOp::Cmplt)
824    }
825
826    /// Returns a new tensor with the true values replaced with if_true and the false values replaced with if_false.
827    #[must_use]
828    pub fn where_(&self, if_true: impl IntoTensor<B>, if_false: impl IntoTensor<B>) -> Tensor<B> {
829        let x = self.clone();
830        let y = self.backend.tensor(if_true).unwrap();
831        let z = self.backend.tensor(if_false).unwrap();
832        let (x, y) = Tensor::broadcast(x, y);
833        let (x, z) = Tensor::broadcast(x, z);
834        let (y, z) = Tensor::broadcast(y, z);
835        tensor(
836            self.backend.push(Node::Where(x.id, y.id, z.id)).unwrap(),
837            self.backend,
838        )
839    }
840
    /// Returns the cosine similarity between self and rhs, using eps to avoid division by zero.
842    #[must_use]
843    pub fn cosine_similarity(&self, rhs: impl IntoTensor<B>, eps: impl IntoTensor<B>) -> Tensor<B> {
844        let rhs = self.backend.tensor(rhs).unwrap();
845        let eps = self.backend.tensor(eps).unwrap();
846        let x = self.pow(2).sqrt() * rhs.pow(2).sqrt();
847        self * rhs / x.cmplt(&eps).where_(eps, x)
848    }
849
850    /// Dot product (mathematical multiplication) of self and rhs.
851    /// ```rust
852    /// # use zyx_opencl::DType;
853    /// let dev = zyx_opencl::device()?;
854    /// let x = dev.tensor([[3, 4, 2], [4, 5, 2]]);
855    /// let y = dev.tensor([[3], [1], [4]]);
856    /// assert_eq!(x.dot(y), [[21], [25]]);
857    /// # Ok::<(), zyx_opencl::ZyxError>(())
858    /// ```
859    #[must_use]
860    pub fn dot(&self, rhs: impl IntoTensor<B>) -> Tensor<B> {
861        let y = self.backend.tensor(rhs).unwrap().transpose();
862        let xshape = self.shape();
863        let yshape = y.shape();
864        let yrank = yshape.rank();
865        debug_assert_eq!(
866            xshape[-1], yshape[-1],
867            //yshape[-(yrank.min(2) as i64)],
868            "Cannot dot tensors with shapes {xshape} and {yshape}"
869        );
870        let x_shape = xshape[0..-1]
871            .iter()
872            .copied()
873            .chain([1])
874            .chain([xshape[-1]])
875            .collect::<Box<[usize]>>();
876        let y_shape = yshape[0..-2]
877            .iter()
878            .copied()
879            .chain([1])
880            .chain(yshape[-(yrank.min(2) as i64)..yrank as i64].iter().copied())
881            .collect::<Box<[usize]>>();
882        //std::println!("{x_shape:?}");
883        //std::println!("{y_shape:?}");
884        (self.reshape(x_shape) * y.reshape(y_shape))
885            .sum(-1)
886            .reshape(
887                xshape[0..-1]
888                    .iter()
889                    .copied()
890                    .chain([yshape[-2]])
891                    .collect::<Box<[usize]>>(),
892            )
893    }
894
895    // Movement ops
896    /// Reshape self to shape.
897    /// # Panics
    /// The following must hold:
899    /// self.numel() == shape.numel()
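    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
    /// assert_eq!(x.reshape([3, 2]), [[2, 3], [1, 4], [1, 3]]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```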
900    #[must_use]
901    pub fn reshape(&self, shape: impl Into<Shape>) -> Tensor<B> {
902        let shape = shape.into();
903        debug_assert_eq!(
904            self.shape().numel(),
905            shape.numel(),
906            "Cannot reshape tensor with shape {} to {shape}",
907            self.shape()
908        );
909        tensor(
910            self.backend.push(Node::Reshape(self.id, shape)).unwrap(),
911            self.backend,
912        )
913    }
914
    /// Expand self into a bigger shape.
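    /// A minimal usage sketch (each size-1 dimension is repeated to the requested size; assuming the `zyx_opencl` backend used elsewhere in these docs):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([[1], [2]]);
    /// assert_eq!(x.expand([2, 3]), [[1, 1, 1], [2, 2, 2]]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```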
916    #[must_use]
917    pub fn expand(&self, shape: impl Into<Shape>) -> Tensor<B> {
918        let shape = shape.into();
919        let sh = self.shape();
920        debug_assert!(
921            shape
922                .iter()
923                .rev()
924                .enumerate()
925                .all(|(i, d)| if sh.rank() > i {
926                    *d == sh[sh.rank() - i - 1] || sh[sh.rank() - i - 1] == 1
927                } else {
928                    true
929                }),
930            "Can't expand tensor with shape {sh} to {shape}"
931        );
932        tensor(
933            self.backend.push(Node::Expand(self.id, shape)).unwrap(),
934            self.backend,
935        )
936    }
937
938    /// Constant padding
939    ///
940    /// This can both add and remove values from tensor. Negative padding removes values, positive padding
941    /// adds values.
942    ///
943    /// Pad last dimension by (1, 2)
944    /// ```rust
945    /// use zyx_opencl;
946    /// let dev = zyx_opencl::device()?;
947    /// let x = dev.tensor([[2, 3],
948    ///                     [4, 1]]);
949    /// let z = x.pad([(1, 2)], 0);
950    /// std::println!("{}", z);
951    /// assert_eq!(z, [[0, 2, 3, 0, 0],
952    ///                [0, 4, 1, 0, 0]]);
953    /// # Ok::<(), zyx_opencl::ZyxError>(())
954    /// ```
955    /// Pad last dimension by (2, -1) and second last dimension by (1, 1)
956    /// ```rust
957    /// # use zyx_opencl;
958    /// # let dev = zyx_opencl::device()?;
959    /// # let x = dev.tensor([[2, 3],
960    /// #                     [4, 1]]);
961    /// let z = x.pad([(2, -1), (1, 1)], 0);
962    /// println!("z: {z}");
963    /// assert_eq!(z, [[0, 0, 0],
964    ///                [0, 0, 2],
965    ///                [0, 0, 4],
966    ///                [0, 0, 0]]);
967    /// # Ok::<(), zyx_opencl::ZyxError>(())
968    /// ```
969    ///
970    /// # Panics
971    /// T must be of the same dtype as Tensor's dtype, otherwise this function panics.
972    #[must_use]
973    pub fn pad(
974        &self,
975        padding: impl IntoIterator<Item = (i64, i64)>,
976        value: impl IntoTensor<B>,
977    ) -> Tensor<B> {
978        let dtype = self.dtype();
979        let value = self.backend.tensor(value).unwrap();
980        debug_assert_eq!(
981            value.dtype(),
982            dtype,
983            "Cannot pad tensor with dtype {} with value of dtype {}",
984            dtype,
985            value.dtype()
986        );
987        let padding: Box<[(i64, i64)]> = padding.into_iter().collect();
988        let sh = self.shape();
989        debug_assert!(
990            padding.len() <= sh.rank()
991                && padding
992                    .iter()
993                    .zip(sh.iter().rev())
994                    .all(|((lp, rp), d)| if *lp < 0 {
995                        ((-*lp) as usize) <= *d
996                    } else {
997                        true
998                    } && if *rp < 0 {
999                        ((-*rp) as usize) <= *d
1000                    } else {
1001                        true
1002                    }),
1003            "Cannot pad tensor with shape {sh} with padding {padding:?}"
1004        );
1005        let psh = sh.clone().pad(&padding);
1006        let t0 = tensor(
1007            self.backend
1008                .push(Node::Pad(self.id, padding.clone(), psh.clone()))
1009                .unwrap(),
1010            self.backend,
1011        );
1012        if value.numel() == 1
1013            && match dtype {
1014                DType::F32 => value.item::<f32>().unwrap().is_equal(0f32),
1015                DType::F64 => value.item::<f64>().unwrap().is_equal(0f64),
1016                DType::I32 => value.item::<i32>().unwrap().is_equal(0i32),
1017            }
1018        {
1019            t0
1020        } else {
1021            t0 + tensor(
1022                self.backend
1023                    .push(Node::Pad(
1024                        self.backend.ones(sh, dtype).unwrap().id,
1025                        padding,
1026                        psh.clone(),
1027                    ))
1028                    .unwrap(),
1029                self.backend,
1030            )
1031            .where_(
1032                self.backend.zeros(self.shape(), self.dtype()).unwrap(),
1033                value,
1034            )
1035        }
1036    }
1037
1038    /// Reorder axes of self
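    /// A minimal usage sketch (axes given as a `Vec<usize>`, as `transpose` does internally; assuming the `zyx_opencl` backend used elsewhere in these docs):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
    /// let y = x.permute(vec![1_usize, 0]);
    /// assert_eq!(y.shape(), [3, 2]);
    /// assert_eq!(y, [[2, 4], [3, 1], [1, 3]]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```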
1039    #[must_use]
1040    pub fn permute(&self, axes: impl IntoAxes) -> Tensor<B> {
1041        let axes = axes.into_axes(self.rank());
1042        let shape = self.shape().permute(&axes);
1043        debug_assert!(
1044            axes.len() == shape.rank(),
1045            "Cannot permute tensor with shape {shape} with axes {axes}"
1046        );
1047        tensor(
1048            self.backend
1049                .push(Node::Permute(self.id, axes, shape))
1050                .unwrap(),
1051            self.backend,
1052        )
1053    }
1054
1055    /// Swap last two axes of self.
    /// If self has rank == 1 and numel == n, then the result will have shape `[n, 1]`.
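    /// A minimal usage sketch (assuming the `zyx_opencl` backend used in this file's other examples):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([[2, 3, 1], [4, 1, 3]]);
    /// assert_eq!(x.transpose(), [[2, 4], [3, 1], [1, 3]]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```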
1057    #[must_use]
1058    pub fn transpose(&self) -> Tensor<B> {
1059        let mut rank = self.rank();
1060        let x = if rank == 1 {
1061            let n = self.numel();
1062            rank = 2;
1063            self.reshape([1, n])
1064        } else {
1065            self.clone()
1066        };
1067        let mut axes: Vec<usize> = (0..rank).collect();
1068        axes.swap(rank - 1, rank - 2);
1069        x.permute(axes)
1070    }
1071
    /// Flatten. Joins the given axes into one dimension.
1073    #[must_use]
1074    pub fn flatten(&self, axes: impl FlattenAxes) -> Tensor<B> {
1075        let sh = self.shape();
1076        let n = sh.numel();
1077        let rank = sh.rank();
1078        let mut ld = 1;
1079        let mut first_dims = false;
1080        for a in axes.into_flatten_axes(rank) {
1081            let a = if a > 0 {
1082                a as usize
1083            } else {
1084                (a + rank as i64) as usize
1085            };
1086            if a == 0 {
1087                first_dims = true;
1088            }
1089            ld *= sh[a];
1090        }
1091        if first_dims {
1092            self.reshape([ld, n / ld])
1093        } else {
1094            self.reshape([n / ld, ld])
1095        }
1096    }
1097
1098    // Reduce ops
1099    /// Reduce self by summing along axes. Shape is not squeezed.
1100    /// ```rust
1101    /// use zyx_opencl;
1102    /// let dev = zyx_opencl::device()?;
1103    /// let x = dev.tensor([[2, 3], [4, 1]]);
1104    /// let z = x.sum(-1);
1105    /// assert_eq!(z.shape(), [2, 1]);
1106    /// let z = x.sum(0);
1107    /// assert_eq!(z.shape(), [1, 2]);
1108    /// let z = x.sum(..);
1109    /// assert_eq!(z.shape(), [1, 1]);
1110    /// # Ok::<(), zyx_opencl::ZyxError>(())
1111    /// ```
1112    #[must_use]
1113    pub fn sum(&self, axes: impl IntoAxes) -> Tensor<B> {
1114        let axes = axes.into_axes(self.rank());
1115        let shape = self.shape().reduce(&axes);
1116        let mut uniq = BTreeSet::new();
1117        debug_assert!(
1118            axes.into_iter().all(move |x| uniq.insert(x)),
1119            "Cannot sum tensor with shape {:?} by axes {:?}, because axes contain duplicates.",
1120            self.shape(),
1121            axes
1122        );
1123        tensor(
1124            self.backend.push(Node::Sum(self.id, axes, shape)).unwrap(),
1125            self.backend,
1126        )
1127    }
1128
1129    /// Reduce self by maximizing along axes. Shape is not squeezed.
1130    /// ```rust
1131    /// use zyx_opencl;
1132    /// let dev = zyx_opencl::device()?;
1133    /// let x = dev.tensor([[2, 3], [4, 1]]);
1134    /// let z = x.max(-1);
1135    /// assert_eq!(z.shape(), [2, 1]);
1136    /// let z = x.max(0);
1137    /// assert_eq!(z.shape(), [1, 2]);
1138    /// let z = x.max(..);
1139    /// assert_eq!(z.shape(), [1, 1]);
1140    /// # Ok::<(), zyx_opencl::ZyxError>(())
1141    /// ```
1142    #[must_use]
1143    pub fn max(&self, axes: impl IntoAxes) -> Tensor<B> {
1144        let axes = axes.into_axes(self.rank());
1145        let shape = self.shape().reduce(&axes);
1146        let mut uniq = BTreeSet::new();
1147        debug_assert!(
1148            axes.into_iter().all(move |x| uniq.insert(x)),
1149            "Cannot sum tensor with shape {:?} by axes {:?}, because axes contain duplicates.",
1150            self.shape(),
1151            axes
1152        );
1153        for a in &axes {
1154            debug_assert!(
1155                *a < shape.rank(),
1156                "Cannot sum tensor with shape {:?} by axes {:?}, because some axes are greater than rank.",
1157                self.shape(),
1158                axes
1159            );
1160        }
1161        tensor(
1162            self.backend.push(Node::Max(self.id, axes, shape)).unwrap(),
1163            self.backend,
1164        )
1165    }
1166
1167    /// Reduce self by calculating mean along axes
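    /// A minimal usage sketch (like `sum`, the reduced shape is not squeezed; assuming the `zyx_opencl` backend used elsewhere in these docs):
    /// ```rust
    /// let dev = zyx_opencl::device()?;
    /// let x = dev.tensor([[2., 3.], [4., 1.]]);
    /// assert_eq!(x.mean(-1), [[2.5], [2.5]]);
    /// # Ok::<(), zyx_opencl::ZyxError>(())
    /// ```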
1168    #[must_use]
1169    pub fn mean(&self, axes: impl IntoAxes) -> Tensor<B> {
1170        let shape = self.shape();
1171        let axes = axes.into_axes(shape.rank());
1172        self.sum(axes.clone()) / axes.iter().copied().map(|a| shape[a]).product::<usize>() as i32
1173    }
1174
1175    /// Reduce self by calculating variance along axes
1176    #[must_use]
1177    pub fn var(&self, axes: impl IntoAxes) -> Tensor<B> {
1178        let axes = axes.into_axes(self.rank());
1179        (self - self.mean(axes.clone())).pow(2).sum(axes)
1180    }
1181
1182    /// Reduce self by calculating standard deviation along axes
1183    #[must_use]
1184    pub fn std(&self, axes: impl IntoAxes) -> Tensor<B> {
1185        self.var(axes).sqrt()
1186    }
1187
1188    /// Reduce self by calculating norm along axes
1189    #[must_use]
1190    pub fn norm(&self, axes: impl IntoAxes, p: impl Scalar) -> Tensor<B> {
1191        self.pow(p.clone()).sum(axes).pow(p.reciprocal())
1192    }
1193
1194    /// Reduce self by calculating product of elements along axes
1195    #[must_use]
1196    pub fn product(&self, axes: impl IntoAxes) -> Tensor<B> {
1197        self.ln().sum(axes).exp()
1198    }
1199
1200    /// Get elements on diagonal of square matrix
1201    #[must_use]
1202    pub fn diagonal(&self) -> Tensor<B> {
1203        let n: usize = self.shape()[-1];
1204        self.flatten(..)
1205            .pad([(0, n as i64)], 0)
1206            .reshape([n, n + 1])
1207            .get((.., 0))
1208    }
1209
1210    /*
1211    /// QR decompose function
1212    #[must_use]
1213    fn qr_decompose(&self) -> (Tensor<B>, Tensor<B>) {
1214        assert_eq!(self.rank(), 2, "QR decomposition only works for 2d matrices.");
1215        let dtype = self.dtype();
1216        assert!(dtype.is_floating(), "QR decomposition only works with floating point tensors.");
1217        let [n, m] = self.shape().try_into().unwrap();
1218        let u_temp = self.get((.., 0));
1219        let mut q = Vec::new();
1220        q.push(u_temp / u_temp.norm(()));
1221        for i in 1..n {
1222            let mut u_temp = self.get((.., i));
1223            // TODO all those dot operations should be fused into one by using expand and reshape and such.
1224            for j in 0..i {
1225                let q_temp = q.get((.., j));
1226                u_temp = u_temp - self.get((.., i)).dot(&q_temp) * &q_temp;
1227            }
1228            q.push(u_temp / u_temp.norm(.., 2));
1229        }
1230        let q = Tensor::cat(q, 0);
1231        let r = q.dot(self);
1232        return (q, r)
1233    }*/
1234
1235    /// Tensor indexing.
1236    ///
1237    /// Tensors can be indexed by tuples of any combination of values or ranges of i64.
1238    /// If indexing along more than 8 dimensions, use \&\[Range\<i64\>\] or \&\[i64\]
1239    /// ```rust
1240    /// use zyx_opencl;
1241    /// let dev = zyx_opencl::device()?;
1242    /// let x = dev.tensor([[2, 3, 4],
1243    ///                     [4, 1, 8]]);
1244    /// let y: i32 = x.get((-1, -3)).item()?;
1245    /// assert_eq!(y, 4);
1246    /// # Ok::<(), zyx_opencl::ZyxError>(())
1247    /// ```
1248    #[must_use]
1249    pub fn get(&self, index: impl IntoIndex) -> Tensor<B> {
1250        // TODO asserts
1251        let shape = self.shape();
1252        let padding: Vec<(i64, i64)> = index
1253            .into_index()
1254            .into_iter()
1255            .zip(shape.iter())
1256            .map(|(r, d)| {
1257                (
1258                    if r.start >= 0 {
1259                        -r.start
1260                    } else {
1261                        -r.start - *d as i64
1262                    },
1263                    if r.end == i64::MAX {
1264                        0
1265                    } else if r.end > 0 {
1266                        -(*d as i64 - r.end)
1267                    } else {
1268                        r.end
1269                    },
1270                )
1271            })
1272            .collect();
1273        //std::println!("Get padding: {padding:?}");
1274        let n = shape.rank() - padding.len();
1275        self.pad(
1276            padding
1277                .into_iter()
1278                .chain(repeat((0, 0)).take(n))
1279                .collect::<Vec<(i64, i64)>>()
1280                .into_iter()
1281                .rev(),
1282            0,
1283        )
1284    }
1285
1286    /// Concatenate multiple tensors together along dim.
1287    // ```rust
1288    // # use zyx_opencl;
1289    // # use zyx_opencl::Tensor;
1290    // let dev = zyx_opencl::device()?;
1291    // let x = dev.tensor([[2, 3, 4], [4, 1, 8]]);
1292    // let y = dev.tensor([[2, 3], [4, 1]]);
1293    // let z = Tensor::cat([&x, &y], -1);
1294    // // assert_eq!(z, []);
1295    // # Ok::<(), zyx_opencl::ZyxError>(())
1296    // ```
1297    #[must_use]
1298    pub fn cat<'a>(tensors: impl IntoIterator<Item = &'a Tensor<B>>, dim: i64) -> Tensor<B>
1299    where
1300        B: 'a,
1301    {
1302        let tensors: Vec<&Tensor<B>> = tensors.into_iter().collect();
1303        let shape = tensors[0].shape();
1304        let rank = shape.rank();
1305        let dim = if dim < 0 { dim + rank as i64 } else { dim } as usize;
1306        // Dimension check
1307        for tensor in &tensors {
1308            for (i, (d1, d2)) in shape.iter().zip(tensor.shape().iter()).enumerate() {
1309                if i != dim {
1310                    debug_assert_eq!(*d1, *d2, "Cannot concatenate these tensors.");
1311                }
1312            }
1313        }
1314        let mut offset = 0i64;
1315        let mut res = tensors[0]
1316            .backend
1317            .zeros(tensors[0].shape(), tensors[0].dtype())
1318            .unwrap();
1319        for tensor in tensors {
1320            res = res
1321                + tensor.pad(
1322                    repeat((0i64, 0i64))
1323                        .take(rank - dim - 1)
1324                        .chain([(offset, 0i64)]),
1325                    0,
1326                );
1327            offset += tensor.shape()[dim] as i64;
1328        }
1329        res
1330    }
1331
1332    // TODO Cholesky and QR solve functions that are backend accelerated
1333
1334    /*
1335    /// Stack multiple tensors into one
1336    #[must_use]
1337    pub fn stack<'a>(tensors: impl IntoIterator<Item = &'a Tensor<B>>, dim: i64) -> Tensor<B>
1338    where
1339        B: 'a
1340    {
1341        todo!()
1342    }*/
1343
1344    /*
1345    /// Split self into multiple tensors along dim with given sizes.
1346    // TODO example
1347    #[must_use]
1348    pub fn split(&self, sizes: &[usize], dim: i64) -> Vec<Tensor<B>> {
1349        // just use negative padding
1350        todo!()
1351    }*/
1352
1353    //#[must_use]
1354    //pub fn pool(&self)
1355
1356    //#[must_use]
1357    //pub fn conv(&self)
1358}
1359
1360enum BOp {
1361    Add,
1362    Sub,
1363    Mul,
1364    Div,
1365    Pow,
1366    Cmplt,
1367}
1368
1369// Private helper functions
1370impl<B: Backend> Tensor<B> {
1371    #[must_use]
1372    fn binary_op(self, rhs: impl IntoTensor<B>, op: BOp) -> Tensor<B> {
1373        let rhs = rhs.into_tensor(self.backend);
1374        let (x, y) = Tensor::broadcast(self, rhs);
1375        tensor(
1376            x.backend
1377                .push(match op {
1378                    BOp::Add => Node::Add(x.id, y.id),
1379                    BOp::Sub => Node::Sub(x.id, y.id),
1380                    BOp::Mul => Node::Mul(x.id, y.id),
1381                    BOp::Div => Node::Div(x.id, y.id),
1382                    BOp::Pow => Node::Pow(x.id, y.id),
1383                    BOp::Cmplt => Node::Cmplt(x.id, y.id),
1384                })
1385                .unwrap(),
1386            x.backend,
1387        )
1388    }
1389
    /// Broadcasts to synchronize shapes and casts to synchronize dtypes.
    /// This does both automatic expand AND automatic casting between dtypes.
    // TODO Both of these can be disabled by changing a setting in the backend.
1393    #[must_use]
1394    fn broadcast(mut x: Tensor<B>, mut y: Tensor<B>) -> (Tensor<B>, Tensor<B>) {
1395        /*assert_eq!(
1396            graph.dtype(xid),
1397            graph.dtype(yid),
1398            "{op} parameters {xid} and {yid} have different dtypes: {} and {}",
1399            graph.dtype(xid),
1400            graph.dtype(yid)
1401        );*/
1402        // Now we just do implicit conversions. Not exactly rust style, but it's convenient.
1403        // We can later add option for backend to disable these implicit conversions.
1404        match (x.dtype(), y.dtype()) {
1405            (DType::F32, DType::I32) => y = y.cast(DType::F32),
1406            (DType::F32, DType::F64) => x = x.cast(DType::F64),
1407            (DType::I32, DType::F32) => x = x.cast(DType::F32),
1408            (DType::I32, DType::F64) => x = x.cast(DType::F64),
1409            (DType::F64, DType::F32) => y = y.cast(DType::F64),
1410            (DType::F64, DType::I32) => y = y.cast(DType::F64),
1411            _ => {}
1412        }
1413        let mut x_shape = x.shape();
1414        let mut y_shape = y.shape();
1415
1416        for (x, y) in x_shape.iter().rev().zip(y_shape.iter().rev()) {
1417            if x != y {
1418                debug_assert!(
1419                    *x == 1 || *y == 1,
1420                    "Left and right tensor shapes can not be broadcasted: {x_shape} and {y_shape}"
1421                );
1422            }
1423        }
1424
1425        let rx = x_shape.rank();
1426        let ry = y_shape.rank();
1427        match rx.cmp(&ry) {
1428            Ordering::Less => {
1429                x_shape = repeat(1)
1430                    .take(ry - rx)
1431                    .chain(x_shape.into_iter().copied())
1432                    .collect::<Vec<usize>>()
1433                    .into();
1434            }
1435            Ordering::Greater => {
1436                y_shape = repeat(1)
1437                    .take(rx - ry)
1438                    .chain(y_shape.into_iter().copied())
1439                    .collect::<Vec<usize>>()
1440                    .into();
1441            }
1442            Ordering::Equal => {}
1443        }
1444        let mut eshape = Vec::new();
1445        for (x, y) in x_shape.into_iter().zip(y_shape.into_iter()) {
1446            eshape.push(*x.max(y));
1447        }
1448        let eshape: Shape = eshape.into();
1449        if x_shape != eshape {
1450            x = x.expand(eshape.clone());
1451        }
1452        if y_shape != eshape {
1453            y = y.expand(eshape);
1454        }
1455        (x, y)
1456    }
1457}
1458
1459impl<B: Backend> core::ops::Neg for Tensor<B> {
1460    type Output = Tensor<B>;
1461    fn neg(self) -> Self::Output {
1462        tensor(self.backend.push(Node::Neg(self.id)).unwrap(), self.backend)
1463    }
1464}
1465
1466impl<B: Backend> core::ops::Neg for &Tensor<B> {
1467    type Output = Tensor<B>;
1468    fn neg(self) -> Self::Output {
1469        tensor(self.backend.push(Node::Neg(self.id)).unwrap(), self.backend)
1470    }
1471}
1472
1473impl<B: Backend, IT: IntoTensor<B>> core::ops::Add<IT> for &Tensor<B> {
1474    type Output = Tensor<B>;
1475    fn add(self, rhs: IT) -> Self::Output {
1476        self.clone().binary_op(rhs, BOp::Add)
1477    }
1478}
1479
1480impl<B: Backend, IT: IntoTensor<B>> core::ops::Add<IT> for Tensor<B> {
1481    type Output = Tensor<B>;
1482    fn add(self, rhs: IT) -> Self::Output {
1483        self.binary_op(rhs, BOp::Add)
1484    }
1485}
1486
1487impl<B: Backend, IT: IntoTensor<B>> core::ops::Sub<IT> for &Tensor<B> {
1488    type Output = Tensor<B>;
1489    fn sub(self, rhs: IT) -> Self::Output {
1490        self.clone().binary_op(rhs, BOp::Sub)
1491    }
1492}
1493
1494impl<B: Backend, IT: IntoTensor<B>> core::ops::Sub<IT> for Tensor<B> {
1495    type Output = Tensor<B>;
1496    fn sub(self, rhs: IT) -> Self::Output {
1497        self.binary_op(rhs, BOp::Sub)
1498    }
1499}
1500
1501impl<B: Backend, IT: IntoTensor<B>> core::ops::Mul<IT> for &Tensor<B> {
1502    type Output = Tensor<B>;
1503    fn mul(self, rhs: IT) -> Self::Output {
1504        self.clone().binary_op(rhs, BOp::Mul)
1505    }
1506}
1507
1508impl<B: Backend> core::ops::Mul<Tensor<B>> for f32 {
1509    type Output = Tensor<B>;
1510    fn mul(self, rhs: Tensor<B>) -> Self::Output {
1511        rhs * self
1512    }
1513}
1514
1515impl<B: Backend> core::ops::Mul<&Tensor<B>> for f32 {
1516    type Output = Tensor<B>;
1517    fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1518        rhs * self
1519    }
1520}
1521
1522impl<B: Backend> core::ops::Mul<Tensor<B>> for f64 {
1523    type Output = Tensor<B>;
1524    fn mul(self, rhs: Tensor<B>) -> Self::Output {
1525        rhs * self
1526    }
1527}
1528
1529impl<B: Backend> core::ops::Mul<&Tensor<B>> for f64 {
1530    type Output = Tensor<B>;
1531    fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1532        rhs * self
1533    }
1534}
1535
1536impl<B: Backend> core::ops::Mul<Tensor<B>> for i32 {
1537    type Output = Tensor<B>;
1538    fn mul(self, rhs: Tensor<B>) -> Self::Output {
1539        rhs * self
1540    }
1541}
1542
1543impl<B: Backend> core::ops::Mul<&Tensor<B>> for i32 {
1544    type Output = Tensor<B>;
1545    fn mul(self, rhs: &Tensor<B>) -> Self::Output {
1546        rhs * self
1547    }
1548}
1549
1550impl<B: Backend, IT: IntoTensor<B>> core::ops::Mul<IT> for Tensor<B> {
1551    type Output = Tensor<B>;
1552    fn mul(self, rhs: IT) -> Self::Output {
1553        self.binary_op(rhs, BOp::Mul)
1554    }
1555}
1556
1557impl<B: Backend, IT: IntoTensor<B>> core::ops::Div<IT> for &Tensor<B> {
1558    type Output = Tensor<B>;
1559    fn div(self, rhs: IT) -> Self::Output {
1560        self.clone().binary_op(rhs, BOp::Div)
1561    }
1562}
1563
1564impl<B: Backend, IT: IntoTensor<B>> core::ops::Div<IT> for Tensor<B> {
1565    type Output = Tensor<B>;
1566    fn div(self, rhs: IT) -> Self::Output {
1567        self.binary_op(rhs, BOp::Div)
1568    }
1569}
1570
1571impl<B: Backend> core::ops::Div<Tensor<B>> for f32 {
1572    type Output = Tensor<B>;
1573    fn div(self, rhs: Tensor<B>) -> Self::Output {
1574        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1575    }
1576}
1577
1578impl<B: Backend> core::ops::Div<&Tensor<B>> for f32 {
1579    type Output = Tensor<B>;
1580    fn div(self, rhs: &Tensor<B>) -> Self::Output {
1581        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1582    }
1583}
1584
1585impl<B: Backend> core::ops::Div<Tensor<B>> for f64 {
1586    type Output = Tensor<B>;
1587    fn div(self, rhs: Tensor<B>) -> Self::Output {
1588        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1589    }
1590}
1591
1592impl<B: Backend> core::ops::Div<&Tensor<B>> for f64 {
1593    type Output = Tensor<B>;
1594    fn div(self, rhs: &Tensor<B>) -> Self::Output {
1595        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1596    }
1597}
1598
1599impl<B: Backend> core::ops::Div<Tensor<B>> for i32 {
1600    type Output = Tensor<B>;
1601    fn div(self, rhs: Tensor<B>) -> Self::Output {
1602        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1603    }
1604}
1605
1606impl<B: Backend> core::ops::Div<&Tensor<B>> for i32 {
1607    type Output = Tensor<B>;
1608    fn div(self, rhs: &Tensor<B>) -> Self::Output {
1609        rhs.backend.tensor(self).unwrap().binary_op(rhs, BOp::Div)
1610    }
1611}
1612
/// Objects must implement this to be convertible into a tensor
1614pub trait IntoTensor<B: Backend> {
1615    /// Convert self into tensor
1616    fn into_tensor(self, backend: B) -> Tensor<B>;
1617}
1618
1619impl<B: Backend> IntoTensor<B> for Tensor<B> {
1620    fn into_tensor(self, _backend: B) -> Tensor<B> {
1621        // TODO assert self.backend == backend
1622        self
1623    }
1624}
1625
1626impl<B: Backend> IntoTensor<B> for &Tensor<B> {
1627    fn into_tensor(self, _backend: B) -> Tensor<B> {
1628        // TODO assert self.backend == backend
1629        self.clone()
1630    }
1631}
1632
1633impl<B: Backend, T: Scalar> IntoTensor<B> for Range<T>
1634where
1635    Range<T>: Iterator<Item = T> + ExactSizeIterator,
1636{
1637    fn into_tensor(self, backend: B) -> Tensor<B> {
1638        tensor(backend.store(self).unwrap(), backend)
1639    }
1640}
1641
1642impl<B: Backend, T: Scalar> IntoTensor<B> for Vec<T> {
1643    fn into_tensor(self, backend: B) -> Tensor<B> {
1644        tensor(backend.store(self).unwrap(), backend)
1645    }
1646}
1647
1648impl<B: Backend, T: Scalar> IntoTensor<B> for &'static [T] {
1649    fn into_tensor(self, backend: B) -> Tensor<B> {
1650        tensor(backend.store(self.iter().cloned()).unwrap(), backend)
1651    }
1652}
1653
1654impl<B: Backend, T: Scalar> IntoTensor<B> for T {
1655    fn into_tensor(self, backend: B) -> Tensor<B> {
1656        tensor(backend.store([self]).unwrap(), backend)
1657    }
1658}
1659
1660impl<B: Backend, T: Scalar, const D0: usize> IntoTensor<B> for [T; D0] {
1661    fn into_tensor(self, backend: B) -> Tensor<B> {
1662        tensor(backend.store(self).unwrap(), backend)
1663    }
1664}
1665
1666impl<B: Backend, T: Scalar, const D0: usize, const D1: usize> IntoTensor<B> for [[T; D1]; D0] {
1667    fn into_tensor(self, backend: B) -> Tensor<B> {
1668        tensor(
1669            backend
1670                .store(self.into_iter().flatten().make_sized(D0 * D1))
1671                .unwrap(),
1672            backend,
1673        )
1674        .reshape([D0, D1])
1675    }
1676}
1677
1678impl<B: Backend, T: Scalar, const D0: usize, const D1: usize, const D2: usize> IntoTensor<B>
1679    for [[[T; D2]; D1]; D0]
1680{
1681    fn into_tensor(self, backend: B) -> Tensor<B> {
1682        tensor(
1683            backend
1684                .store(
1685                    self.into_iter()
1686                        .flatten()
1687                        .flatten()
1688                        .make_sized(D0 * D1 * D2),
1689                )
1690                .unwrap(),
1691            backend,
1692        )
1693        .reshape([D0, D1, D2])
1694    }
1695}
1696
1697impl<B: Backend, IT: IntoTensor<B> + Clone> PartialEq<IT> for Tensor<B> {
1698    fn eq(&self, other: &IT) -> bool {
1699        let other = self.backend.tensor(other.clone()).unwrap();
1700        let dtype = self.dtype();
1701        self.shape() == other.shape()
1702            && dtype == other.dtype()
1703            && match dtype {
1704                DType::F32 => self
1705                    .to_vec::<f32>()
1706                    .unwrap()
1707                    .into_iter()
1708                    .zip(other.to_vec::<f32>().unwrap())
1709                    .all(|(x, y)| x.is_equal(y)),
1710                DType::F64 => self
1711                    .to_vec::<f64>()
1712                    .unwrap()
1713                    .into_iter()
1714                    .zip(other.to_vec::<f64>().unwrap())
1715                    .all(|(x, y)| x.is_equal(y)),
1716                DType::I32 => self
1717                    .to_vec::<i32>()
1718                    .unwrap()
1719                    .into_iter()
1720                    .zip(other.to_vec::<i32>().unwrap())
1721                    .all(|(x, y)| x.is_equal(y)),
1722            }
1723    }
1724}