sp1_recursion_compiler/ir/iter.rs

1use std::mem;
2
3use super::{Builder, Config, DslIr, DslIrBlock};
4
5pub trait IrIter<C: Config, Item>: Sized {
6    fn ir_par_map_collect<B, F, S>(self, builder: &mut Builder<C>, map_op: F) -> B
7    where
8        F: FnMut(&mut Builder<C>, Item) -> S,
9        B: Default + Extend<S>;
10}
11
12impl<C, I, Item> IrIter<C, Item> for I
13where
14    C: Config,
15    I: Iterator<Item = Item>,
16{
17    fn ir_par_map_collect<B, F, S>(self, builder: &mut Builder<C>, mut map_op: F) -> B
18    where
19        F: FnMut(&mut Builder<C>, I::Item) -> S,
20        B: Default + Extend<S>,
21    {
22        let prev_ops = mem::take(builder.get_mut_operations());
23        let (blocks, coll): (Vec<_>, B) = self
24            .map(|r| {
25                let next_addr = builder.variable_count();
26                let s = map_op(builder, r);
27                let block = DslIrBlock {
28                    ops: mem::take(builder.get_mut_operations()),
29                    addrs_written: next_addr..builder.variable_count(),
30                };
31                (block, s)
32            })
33            .unzip();
34        *builder.get_mut_operations() = prev_ops;
35        builder.push_op(DslIr::Parallel(blocks));
36        coll
37    }
38}