sp1_stark/shape/cluster.rs

1use std::{fmt::Debug, hash::Hash, str::FromStr};
2
3use hashbrown::HashMap;
4use serde::{Deserialize, Serialize};
5
6use super::Shape;
7
8/// A cluster of shapes.
9///
10/// We represent a cluster of shapes as a cartesian product of heights per chip.
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12pub struct ShapeCluster<K: Eq + Hash + FromStr> {
13    inner: HashMap<K, Vec<Option<usize>>>,
14}
15
16impl<K: Debug + Clone + Eq + Hash + FromStr> ShapeCluster<K> {
17    /// Create a new shape cluster.
18    #[must_use]
19    pub fn new(inner: HashMap<K, Vec<Option<usize>>>) -> Self {
20        Self { inner }
21    }
22
23    /// Find the shape that is larger or equal to the given heights.
24    pub fn find_shape(&self, heights: &[(K, usize)]) -> Option<Shape<K>> {
25        let shape: Option<HashMap<K, Option<usize>>> = heights
26            .iter()
27            .map(|(air, height)| {
28                for maybe_log2_height in self.inner.get(air).into_iter().flatten() {
29                    let allowed_height =
30                        maybe_log2_height.map(|log_height| 1 << log_height).unwrap_or_default();
31                    if *height <= allowed_height {
32                        return Some((air.clone(), *maybe_log2_height));
33                    }
34                }
35                None
36            })
37            .collect();
38
39        let mut inner = shape?;
40        inner.retain(|_, &mut value| value.is_some());
41
42        let shape = inner
43            .into_iter()
44            .map(|(air, maybe_log_height)| (air, maybe_log_height.unwrap()))
45            .collect::<Shape<K>>();
46
47        Some(shape)
48    }
49
50    /// Iterate over the inner map.
51    pub fn iter(&self) -> impl Iterator<Item = (&K, &Vec<Option<usize>>)> {
52        self.inner.iter()
53    }
54}