spenso/iterators/fiber_iterators.rs

1//! High-level fiber iterators for tensor traversal
2//!
3//! This module contains iterators that build on the core iterators to provide
4//! high-level fiber iteration capabilities for various tensor types.
5
6use gat_lending_iterator::LendingIterator;
7use linnet::permutation::Permutation;
8use std::fmt::Debug;
9
10use crate::{
11    structure::{representation::LibraryRep, slot::IsAbstractSlot, TensorStructure},
12    tensors::data::{DenseTensor, GetTensorData, SparseTensor},
13};
14
15use super::{
16    core_iterators::CoreFlatFiberIterator,
17    fiber::{Fiber, FiberClass, FiberMut},
18    traits::ResetableIterator,
19    FiberIteratorItem, IteratesAlongFibers, IteratesAlongPermutedFibers,
20};
21
22/// Iterator for traversing tensor fibers
23///
24/// This high-level iterator uses a core iterator to traverse fibers in tensors,
25/// returning references to tensor elements at each position.
26#[derive(Debug)]
27pub struct FiberIterator<
28    'a,
29    S: TensorStructure,
30    I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R>,
31> {
32    /// The fiber being iterated
33    pub fiber: Fiber<'a, S>,
34    /// The underlying core iterator
35    pub iter: I,
36    /// Number of indices skipped
37    pub skipped: usize,
38}
39
40impl<S: TensorStructure, I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R> + Clone> Clone
41    for FiberIterator<'_, S, I>
42{
43    fn clone(&self) -> Self {
44        FiberIterator {
45            fiber: self.fiber.clone(),
46            iter: self.iter.clone(),
47            skipped: self.skipped,
48        }
49    }
50}
51
52impl<'a, S: TensorStructure, I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R>>
53    FiberIterator<'a, S, I>
54{
55    /// Creates a new fiber iterator
56    ///
57    /// # Arguments
58    ///
59    /// * `fiber` - The fiber to iterate over
60    /// * `conj` - Whether to use conjugate iteration
61    pub fn new(fiber: Fiber<'a, S>, conj: bool) -> Self {
62        FiberIterator {
63            iter: I::new(&fiber, conj),
64            fiber,
65            skipped: 0,
66        }
67    }
68
69    /// Resets the iterator to its initial state
70    pub fn reset(&mut self) {
71        self.iter.reset();
72        self.skipped = 0;
73    }
74
75    /// Shifts the iterator by the given amount
76    ///
77    /// # Arguments
78    ///
79    /// * `shift` - The amount to shift by
80    pub fn shift(&mut self, shift: usize) {
81        self.reset();
82        self.iter.shift(shift);
83    }
84}
85
86impl<'a, S: TensorStructure, I: IteratesAlongPermutedFibers<<S::Slot as IsAbstractSlot>::R>>
87    FiberIterator<'a, S, I>
88{
89    /// Creates a new fiber iterator with a permutation
90    ///
91    /// # Arguments
92    ///
93    /// * `fiber` - The fiber to iterate over
94    /// * `permutation` - The permutation to apply
95    /// * `conj` - Whether to use conjugate iteration
96    pub fn new_permuted(fiber: Fiber<'a, S>, permutation: Permutation, conj: bool) -> Self {
97        FiberIterator {
98            iter: I::new_permuted(&fiber, conj, permutation),
99            fiber,
100            skipped: 0,
101        }
102    }
103}
104
105impl<I: IteratesAlongFibers<LibraryRep>> Iterator
106    for FiberIterator<'_, crate::structure::OrderedStructure, I>
107{
108    type Item = I::Item;
109    fn next(&mut self) -> Option<Self::Item> {
110        self.iter.next()
111    }
112}
113
114impl<
115        'a,
116        I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R, Item = It>,
117        S: TensorStructure,
118        T,
119        It,
120    > Iterator for FiberIterator<'a, DenseTensor<T, S>, I>
121where
122    It: FiberIteratorItem,
123{
124    type Item = (&'a T, It::OtherData);
125    fn next(&mut self) -> Option<Self::Item> {
126        self.iter.next().map(|x| {
127            // println!(
128            //     "DenseTensor: flat_idx: {}, size: {:?}",
129            //     x.flat_idx(),
130            //     self.fiber.structure.size()
131            // );
132            if let Some(t) = self.fiber.structure.get_ref_linear(x.flat_idx()) {
133                (t, x.other_data())
134            } else {
135                panic!(
136                    "DenseTensor: Out of bounds {} {}",
137                    x.flat_idx(),
138                    self.fiber.structure.size().unwrap()
139                )
140            }
141        })
142    }
143}
144
145impl<
146        'a,
147        I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R, Item = It>,
148        S: TensorStructure,
149        T,
150        It,
151    > Iterator for FiberIterator<'a, SparseTensor<T, S>, I>
152where
153    It: FiberIteratorItem,
154{
155    type Item = (&'a T, usize, It::OtherData);
156    fn next(&mut self) -> Option<Self::Item> {
157        if let Some(i) = self.iter.next() {
158            if let Some(t) = self.fiber.structure.get_ref_linear(i.flat_idx()) {
159                let skipped = self.skipped;
160                self.skipped = 0;
161                return Some((t, skipped, i.other_data()));
162            } else {
163                self.skipped += 1;
164                return self.next();
165            }
166        }
167        None
168    }
169}
170
171/// Mutable iterator for traversing tensor fibers
172///
173/// Similar to `FiberIterator` but returns mutable references to tensor elements.
174pub struct MutFiberIterator<
175    'a,
176    S: TensorStructure,
177    I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R>,
178> {
179    /// The underlying core iterator
180    iter: I,
181    /// The fiber being iterated
182    fiber: FiberMut<'a, S>,
183    /// Number of indices skipped
184    skipped: usize,
185}
186
187impl<
188        I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R, Item = It>,
189        S: TensorStructure,
190        T,
191        It,
192    > LendingIterator for MutFiberIterator<'_, SparseTensor<T, S>, I>
193where
194    It: FiberIteratorItem,
195{
196    type Item<'r>
197        = (&'r mut T, usize, It::OtherData)
198    where
199        Self: 'r;
200    fn next(&mut self) -> Option<Self::Item<'_>> {
201        let flat = self.iter.next()?;
202        if self.fiber.structure.is_empty_at_flat(flat.flat_idx()) {
203            let skipped = self.skipped;
204            self.skipped = 0;
205            Some((
206                self.fiber
207                    .structure
208                    .get_mut_linear(flat.flat_idx())
209                    .unwrap(),
210                skipped,
211                flat.other_data(),
212            ))
213        } else {
214            self.skipped += 1;
215            self.next()
216        }
217    }
218}
219
220impl<
221        I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R, Item = It>,
222        S: TensorStructure,
223        T,
224        It,
225    > LendingIterator for MutFiberIterator<'_, DenseTensor<T, S>, I>
226where
227    It: FiberIteratorItem,
228{
229    type Item<'r>
230        = (&'r mut T, It::OtherData)
231    where
232        Self: 'r;
233    fn next(&mut self) -> Option<Self::Item<'_>> {
234        self.iter.next().map(|x| {
235            (
236                self.fiber.structure.get_mut_linear(x.flat_idx()).unwrap(),
237                x.other_data(),
238            )
239        })
240    }
241}
242
243impl<'a, S: TensorStructure, I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R>>
244    MutFiberIterator<'a, S, I>
245{
246    /// Creates a new mutable fiber iterator
247    ///
248    /// # Arguments
249    ///
250    /// * `fiber` - The fiber to iterate over
251    /// * `conj` - Whether to use conjugate iteration
252    pub fn new(fiber: FiberMut<'a, S>, conj: bool) -> Self {
253        MutFiberIterator {
254            iter: I::new(&fiber, conj),
255            fiber,
256            skipped: 0,
257        }
258    }
259
260    /// Resets the iterator to its initial state
261    pub fn reset(&mut self) {
262        self.iter.reset();
263        self.skipped = 0;
264    }
265
266    /// Shifts the iterator by the given amount
267    ///
268    /// # Arguments
269    ///
270    /// * `shift` - The amount to shift by
271    pub fn shift(&mut self, shift: usize) {
272        self.iter.shift(shift);
273    }
274}
275
276impl<'a, S: TensorStructure, I: IteratesAlongPermutedFibers<<S::Slot as IsAbstractSlot>::R>>
277    MutFiberIterator<'a, S, I>
278{
279    /// Creates a new mutable fiber iterator with a permutation
280    ///
281    /// # Arguments
282    ///
283    /// * `fiber` - The fiber to iterate over
284    /// * `permutation` - The permutation to apply
285    /// * `conj` - Whether to use conjugate iteration
286    pub fn new_permuted(fiber: FiberMut<'a, S>, permutation: Permutation, conj: bool) -> Self {
287        MutFiberIterator {
288            iter: I::new_permuted(&fiber, conj, permutation),
289            fiber,
290            skipped: 0,
291        }
292    }
293}
294
295/// Iterator for traversing fiber classes
296///
297/// Iterates over all fibers in a fiber class, returning an iterator for each fiber.
298pub struct FiberClassIterator<
299    'b,
300    S: TensorStructure,
301    I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R> = CoreFlatFiberIterator,
302> {
303    /// Iterator over fibers within the class
304    pub fiber_iter: FiberIterator<'b, S, I>,
305    /// Iterator over indices of fibers in the class
306    pub class_iter: CoreFlatFiberIterator,
307}
308
309impl<'b, N: TensorStructure> FiberClassIterator<'b, N, CoreFlatFiberIterator> {
310    /// Creates a new fiber class iterator
311    ///
312    /// # Arguments
313    ///
314    /// * `class` - The fiber class to iterate over
315    pub fn new(class: FiberClass<'b, N>) -> Self {
316        let (iter, iter_conj) = CoreFlatFiberIterator::new_paired_conjugates(&class);
317
318        let fiber = FiberIterator {
319            fiber: class.into(),
320            iter,
321            skipped: 0,
322        };
323
324        FiberClassIterator {
325            fiber_iter: fiber,
326            class_iter: iter_conj,
327        }
328    }
329}
330
331impl<N: TensorStructure, I: IteratesAlongFibers<<N::Slot as IsAbstractSlot>::R>>
332    FiberClassIterator<'_, N, I>
333{
334    /// Resets the iterator to its initial state
335    pub fn reset(&mut self) {
336        self.class_iter.reset();
337        self.fiber_iter.reset();
338        self.fiber_iter.shift(0);
339    }
340}
341
342impl<'b, N: TensorStructure, I: IteratesAlongPermutedFibers<<N::Slot as IsAbstractSlot>::R>>
343    FiberClassIterator<'b, N, I>
344{
345    /// Creates a new fiber class iterator with a permutation
346    ///
347    /// # Arguments
348    ///
349    /// * `class` - The fiber class to iterate over
350    /// * `permutation` - The permutation to apply
351    pub fn new_permuted(class: FiberClass<'b, N>, permutation: Permutation) -> Self {
352        let iter = CoreFlatFiberIterator::new(&class, false);
353
354        let fiber = FiberIterator::new_permuted(class.into(), permutation, false);
355
356        FiberClassIterator {
357            fiber_iter: fiber,
358            class_iter: iter,
359        }
360    }
361}
362
363impl<
364        'a,
365        S: TensorStructure + 'a,
366        I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R> + Clone + Debug,
367    > Iterator for FiberClassIterator<'a, S, I>
368{
369    type Item = FiberIterator<'a, S, I>;
370
371    fn next(&mut self) -> Option<Self::Item> {
372        let shift = self.class_iter.next()?;
373        self.fiber_iter.reset();
374        self.fiber_iter.shift(shift.into());
375        Some(self.fiber_iter.clone())
376    }
377}
378
379impl<'a, S: TensorStructure + 'a, I: IteratesAlongFibers<<S::Slot as IsAbstractSlot>::R>>
380    LendingIterator for FiberClassIterator<'a, S, I>
381{
382    type Item<'r>
383        = &'r mut FiberIterator<'a, S, I>
384    where
385        Self: 'r;
386
387    fn next(&mut self) -> Option<Self::Item<'_>> {
388        let shift = self.class_iter.next()?;
389        self.fiber_iter.reset();
390        self.fiber_iter.shift(shift.into());
391        Some(&mut self.fiber_iter)
392    }
393}
394
#[cfg(test)]
mod tests {

    use crate::structure::{
        representation::{Euclidean, RepName},
        OrderedStructure, PermutedStructure,
    };

    use super::*;

    /// Walks the flat indices produced for a half/half partition of a rank-4
    /// structure whose slots all have dimension 4.
    #[test]
    fn weaved_iterator() {
        let strct: DenseTensor<u32, OrderedStructure<Euclidean>> = DenseTensor::zero(
            PermutedStructure::from_iter([
                Euclidean {}.new_slot(4, 1),
                Euclidean {}.new_slot(4, 2),
                Euclidean {}.new_slot(4, 3),
                Euclidean {}.new_slot(4, 4),
            ])
            .structure,
        );

        let fiber_spec = [true, false, true, false];
        // The partition is used as a filter: `true` marks indices belonging to
        // self, `false` those belonging to other.
        let self_fiber_class = Fiber::from(fiber_spec.as_slice(), &strct.structure);
        // Iterators over the open indices of self and other, expressed as flat
        // indices of the full structure.
        let (self_fiber_class_iter, _other_fiber_class_iter) =
            CoreFlatFiberIterator::new_paired_conjugates(&self_fiber_class);

        let mut visited = 0;
        for i in self_fiber_class_iter {
            println!("{}-> {:?}", i, strct.expanded_index(i));
            visited += 1;
        }
        // Each side of the partition covers two slots of dimension 4, so either
        // iterator of the pair should visit 4 * 4 positions.
        assert_eq!(visited, 4 * 4);
    }
}
426}