1use super::Tensor;
2use crate::TchError;
3
4pub struct Iter<T> {
5 index: i64,
6 len: i64,
7 content: Tensor,
8 phantom: std::marker::PhantomData<T>,
9}
10
11impl Tensor {
12 pub fn iter<T>(&self) -> Result<Iter<T>, TchError> {
13 Ok(Iter {
14 index: 0,
15 len: self.size1()?,
16 content: self.shallow_clone(),
17 phantom: std::marker::PhantomData,
18 })
19 }
20}
21
22impl std::iter::Iterator for Iter<i64> {
23 type Item = i64;
24 fn next(&mut self) -> Option<Self::Item> {
25 if self.index >= self.len {
26 return None;
27 }
28 let v = self.content.int64_value(&[self.index]);
29 self.index += 1;
30 Some(v)
31 }
32}
33
34impl std::iter::Iterator for Iter<f64> {
35 type Item = f64;
36 fn next(&mut self) -> Option<Self::Item> {
37 if self.index >= self.len {
38 return None;
39 }
40 let v = self.content.double_value(&[self.index]);
41 self.index += 1;
42 Some(v)
43 }
44}
45
46impl std::iter::Sum for Tensor {
47 fn sum<I: Iterator<Item = Tensor>>(mut iter: I) -> Tensor {
48 match iter.next() {
49 None => Tensor::from(0.),
50 Some(t) => iter.fold(t, |acc, x| x + acc),
51 }
52 }
53}
54
55impl<'a> std::iter::Sum<&'a Tensor> for Tensor {
56 fn sum<I: Iterator<Item = &'a Tensor>>(mut iter: I) -> Tensor {
57 match iter.next() {
58 None => Tensor::from(0.),
59 Some(t) => iter.fold(t.shallow_clone(), |acc, x| x + acc),
60 }
61 }
62}