Skip to main content

vortex_utils/
iter.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Iterator extension traits.

/// An extension trait for iterators that provides balanced binary tree reduction.
///
/// Unlike [`Iterator::reduce`], which builds a left-leaning linear chain of depth N,
/// `reduce_balanced` builds a balanced binary tree of depth log(N). This avoids deep
/// nesting that can cause stack overflows on drop or suboptimal evaluation.
///
/// ```text
/// reduce:          reduce_balanced:
///     f                  f
///    / \                / \
///   f   d              f   f
///  / \                / \ / \
/// f   c              a  b c  d
/// |\
/// a b
/// ```
pub trait ReduceBalancedIterExt: Iterator {
    /// Like [`Iterator::reduce`], but builds a balanced binary tree instead of a linear chain.
    ///
    /// `[a, b, c, d]` becomes `combine(combine(a, b), combine(c, d))`. When a level of the
    /// tree has an odd number of elements, the trailing element is merged into the last pair
    /// of that level, so `[a, b, c]` becomes `combine(combine(a, b), c)`.
    ///
    /// As with [`Iterator::reduce`], items are moved into `combine` rather than cloned, and
    /// `combine` is called exactly `n - 1` times for `n` items.
    ///
    /// Returns `None` if the iterator is empty.
    fn reduce_balanced<F>(self, combine: F) -> Option<Self::Item>
    where
        F: FnMut(Self::Item, Self::Item) -> Self::Item;

    /// Fallible version of [`reduce_balanced`](ReduceBalancedIterExt::reduce_balanced).
    ///
    /// Builds the same tree shape and calls `combine` in the same order. Returns `Ok(None)`
    /// if the iterator is empty.
    ///
    /// # Errors
    ///
    /// Short-circuits on the first error: as soon as `combine` returns `Err`, that error is
    /// returned and `combine` is not called again.
    fn try_reduce_balanced<F, E>(self, combine: F) -> Result<Option<Self::Item>, E>
    where
        F: FnMut(Self::Item, Self::Item) -> Result<Self::Item, E>;
}

impl<I: Iterator> ReduceBalancedIterExt for I {
    fn reduce_balanced<F>(self, mut combine: F) -> Option<Self::Item>
    where
        F: FnMut(Self::Item, Self::Item) -> Self::Item,
    {
        let items: Vec<Self::Item> = self.collect();
        // Route through the fallible core with an error type that cannot be constructed.
        let reduced: Result<_, std::convert::Infallible> =
            reduce_balanced_vec(items, |lhs, rhs| Ok(combine(lhs, rhs)));
        match reduced {
            Ok(result) => result,
            Err(never) => match never {},
        }
    }

    fn try_reduce_balanced<F, E>(self, combine: F) -> Result<Option<Self::Item>, E>
    where
        F: FnMut(Self::Item, Self::Item) -> Result<Self::Item, E>,
    {
        let items: Vec<Self::Item> = self.collect();
        reduce_balanced_vec(items, combine)
    }
}

/// Shared core of [`ReduceBalancedIterExt`]: reduces `items` level by level until at most one
/// value remains.
///
/// Each level combines adjacent pairs left to right; a trailing odd element is merged into the
/// last pair produced on that level. Items are moved, never cloned. Returns `Ok(None)` for an
/// empty vector and returns the first `Err` produced by `combine` without calling it again.
fn reduce_balanced_vec<T, E, F>(mut items: Vec<T>, mut combine: F) -> Result<Option<T>, E>
where
    F: FnMut(T, T) -> Result<T, E>,
{
    // Output buffer for the current level. After each level the buffers are swapped, so the
    // (now fully drained) input allocation is reused as the next level's output.
    let mut next: Vec<T> = Vec::with_capacity(items.len() / 2);

    while items.len() > 1 {
        {
            let mut level = items.drain(..);
            while let Some(lhs) = level.next() {
                match level.next() {
                    Some(rhs) => next.push(combine(lhs, rhs)?),
                    None => {
                        // Odd trailing element: fold it into the last pair so it stays inside
                        // the tree. A level of >= 3 items always produced a pair before this.
                        let last = next.pop().expect("odd level has a preceding pair");
                        next.push(combine(last, lhs)?);
                    }
                }
            }
        }
        std::mem::swap(&mut items, &mut next);
    }

    // Zero or one item remains: `None` for empty input, otherwise the root of the tree.
    Ok(items.pop())
}
114
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    #[test]
    fn test_empty() {
        let result = std::iter::empty::<i32>().reduce_balanced(|a, b| a + b);
        assert_eq!(result, None);
    }

    #[test]
    fn test_single() {
        let result = [42].into_iter().reduce_balanced(|a, b| a + b);
        assert_eq!(result, Some(42));
    }

    #[test]
    fn test_two() {
        let result = [1, 2].into_iter().reduce_balanced(|a, b| a + b);
        assert_eq!(result, Some(3));
    }

    #[test]
    fn test_power_of_two() {
        let result = [1, 2, 3, 4].into_iter().reduce_balanced(|a, b| a + b);
        assert_eq!(result, Some(10));
    }

    #[test]
    fn test_odd_count() {
        let result = [1, 2, 3, 4, 5].into_iter().reduce_balanced(|a, b| a + b);
        assert_eq!(result, Some(15));
    }

    #[test]
    fn test_large_input() {
        let result = (1..=1000).reduce_balanced(|a, b| a + b);
        assert_eq!(result, Some(500_500));
    }

    #[test]
    fn test_depth_is_logarithmic() {
        // Leaves have depth 0 and each combine sits one level above its deeper child.
        // A linear `reduce` over 1000 items would reach depth 999; the balanced tree
        // reaches ceil(log2(1000)) = 10.
        let depth = (0..1000).map(|_| 0u32).reduce_balanced(|a, b| a.max(b) + 1);
        assert_eq!(depth, Some(10));
    }

    #[test]
    fn test_balanced_structure() {
        // Use string concatenation to verify the tree shape.
        // [a, b, c, d] should produce ((a+b)+(c+d)), not (((a+b)+c)+d).
        let result = ["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .reduce_balanced(|a, b| format!("({a}+{b})"));
        assert_eq!(result, Some("((a+b)+(c+d))".to_string()));
    }

    #[test]
    fn test_balanced_structure_odd() {
        // [a, b, c] should produce ((a+b)+c) — odd element merges into last pair.
        let result = ["a", "b", "c"]
            .into_iter()
            .map(String::from)
            .reduce_balanced(|a, b| format!("({a}+{b})"));
        assert_eq!(result, Some("((a+b)+c)".to_string()));
    }

    #[test]
    fn test_balanced_structure_five() {
        // [a, b, c, d, e] => ((a+b)+((c+d)+e))
        let result = ["a", "b", "c", "d", "e"]
            .into_iter()
            .map(String::from)
            .reduce_balanced(|a, b| format!("({a}+{b})"));
        assert_eq!(result, Some("((a+b)+((c+d)+e))".to_string()));
    }

    #[test]
    fn test_try_reduce_balanced_ok() {
        let result: Result<_, &str> = [1, 2, 3, 4]
            .into_iter()
            .try_reduce_balanced(|a, b| Ok(a + b));
        assert_eq!(result, Ok(Some(10)));
    }

    #[test]
    fn test_try_reduce_balanced_structure_odd() {
        // The fallible variant must build the same tree as `reduce_balanced`.
        let result: Result<_, &str> = ["a", "b", "c", "d", "e"]
            .into_iter()
            .map(String::from)
            .try_reduce_balanced(|a, b| Ok(format!("({a}+{b})")));
        assert_eq!(result, Ok(Some("((a+b)+((c+d)+e))".to_string())));
    }

    #[test]
    fn test_try_reduce_balanced_err() {
        let result: Result<Option<i32>, &str> = [1, 2, 3, 4]
            .into_iter()
            .try_reduce_balanced(|a, b| if a + b > 4 { Err("too big") } else { Ok(a + b) });
        assert_eq!(result, Err("too big"));
    }

    #[test]
    fn test_try_reduce_balanced_err_in_odd_merge() {
        // The pair (1, 2) succeeds; only merging the odd trailing 3 fails.
        let result: Result<Option<i32>, &str> = [1, 2, 3]
            .into_iter()
            .try_reduce_balanced(|a, b| if b == 3 { Err("odd") } else { Ok(a + b) });
        assert_eq!(result, Err("odd"));
    }

    #[test]
    fn test_try_reduce_balanced_short_circuits() {
        // After the first `Err`, `combine` must not be invoked again.
        let calls = Cell::new(0);
        let result: Result<Option<i32>, &str> =
            [1, 2, 3, 4].into_iter().try_reduce_balanced(|_, _| {
                calls.set(calls.get() + 1);
                Err("boom")
            });
        assert_eq!(result, Err("boom"));
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn test_try_reduce_balanced_empty() {
        let result: Result<_, &str> =
            std::iter::empty::<i32>().try_reduce_balanced(|a, b| Ok(a + b));
        assert_eq!(result, Ok(None));
    }

    #[test]
    fn test_try_reduce_balanced_single() {
        let result: Result<_, &str> = [42].into_iter().try_reduce_balanced(|a, b| Ok(a + b));
        assert_eq!(result, Ok(Some(42)));
    }
}