// tch_plus/tensor/iter.rs

1use super::Tensor;
2use crate::TchError;
3
4pub struct Iter<T> {
5    index: i64,
6    len: i64,
7    content: Tensor,
8    phantom: std::marker::PhantomData<T>,
9}
10
11impl Tensor {
12    pub fn iter<T>(&self) -> Result<Iter<T>, TchError> {
13        Ok(Iter {
14            index: 0,
15            len: self.size1()?,
16            content: self.shallow_clone(),
17            phantom: std::marker::PhantomData,
18        })
19    }
20}
21
22impl std::iter::Iterator for Iter<i64> {
23    type Item = i64;
24    fn next(&mut self) -> Option<Self::Item> {
25        if self.index >= self.len {
26            return None;
27        }
28        let v = self.content.int64_value(&[self.index]);
29        self.index += 1;
30        Some(v)
31    }
32}
33
34impl std::iter::Iterator for Iter<f64> {
35    type Item = f64;
36    fn next(&mut self) -> Option<Self::Item> {
37        if self.index >= self.len {
38            return None;
39        }
40        let v = self.content.double_value(&[self.index]);
41        self.index += 1;
42        Some(v)
43    }
44}
45
46impl std::iter::Sum for Tensor {
47    fn sum<I: Iterator<Item = Tensor>>(mut iter: I) -> Tensor {
48        match iter.next() {
49            None => Tensor::from(0.),
50            Some(t) => iter.fold(t, |acc, x| x + acc),
51        }
52    }
53}
54
55impl<'a> std::iter::Sum<&'a Tensor> for Tensor {
56    fn sum<I: Iterator<Item = &'a Tensor>>(mut iter: I) -> Tensor {
57        match iter.next() {
58            None => Tensor::from(0.),
59            Some(t) => iter.fold(t.shallow_clone(), |acc, x| x + acc),
60        }
61    }
62}