simplicity/types/
final_data.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Finalized (Complete) Type Data
4//!
5//! Once a type is complete (has no free variables), it can be represented as
6//! a much simpler data structure than [`super::Type`], which we call [`Final`].
7//! This contains a recursively-defined [`CompleteBound`] which specifies what
8//! the type is, as well as a cached Merkle root (the TMR) and bit-width.
9//!
10//! We refer to types as "finalized" when they are represented by this data
11//! structure, since this structure is immutable.
12//!
13
14use crate::dag::{Dag, DagLike, NoSharing};
15use crate::Tmr;
16
17use std::sync::Arc;
18use std::{cmp, fmt, hash};
19
20/// A finalized type bound, whose tree is accessible without any mutex locking
21#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
22pub enum CompleteBound {
23    /// The unit type
24    Unit,
25    /// A sum of two other types
26    Sum(Arc<Final>, Arc<Final>),
27    /// A product of two other types
28    Product(Arc<Final>, Arc<Final>),
29}
30
31/// Data related to a finalized type, which can be extracted from a [`super::Type`]
32/// if (and only if) it is finalized.
33#[derive(Clone)]
34pub struct Final {
35    /// Underlying type
36    bound: CompleteBound,
37    /// Width of the type, in bits, in the bit machine
38    bit_width: usize,
39    /// Whether the type's bit representation has any padding. If this is true,
40    /// then its "compact" witness-encoded bit-width may be lower than its "padded"
41    /// bit-machine bit-width.
42    has_padding: bool,
43    /// TMR of the type
44    tmr: Tmr,
45}
46
47impl PartialEq for Final {
48    fn eq(&self, other: &Self) -> bool {
49        self.tmr == other.tmr
50    }
51}
52impl Eq for Final {}
53
54impl PartialOrd for Final {
55    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
56        Some(self.cmp(other))
57    }
58}
59impl Ord for Final {
60    fn cmp(&self, other: &Self) -> cmp::Ordering {
61        self.tmr.cmp(&other.tmr)
62    }
63}
64impl hash::Hash for Final {
65    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
66        self.tmr.hash(hasher)
67    }
68}
69
70impl fmt::Debug for Final {
71    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72        write!(
73            f,
74            "{{ tmr: {}, bit_width: {}, bound: {} }}",
75            self.tmr, self.bit_width, self
76        )
77    }
78}
79
80impl fmt::Display for Final {
81    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82        let mut skipping: Option<Tmr> = None;
83        for data in self.verbose_pre_order_iter::<NoSharing>(None) {
84            if let Some(skip) = skipping {
85                if data.is_complete && data.node.tmr == skip {
86                    skipping = None;
87                }
88                continue;
89            } else {
90                if data.node.tmr == Tmr::TWO_TWO_N[0] {
91                    f.write_str("2")?;
92                    skipping = Some(data.node.tmr);
93                }
94                for (n, tmr) in Tmr::TWO_TWO_N.iter().enumerate().skip(1) {
95                    if data.node.tmr == *tmr {
96                        write!(f, "2^{}", 1 << n)?;
97                        skipping = Some(data.node.tmr);
98                    }
99                }
100            }
101            if skipping.is_some() {
102                continue;
103            }
104
105            match (&data.node.bound, data.n_children_yielded) {
106                (CompleteBound::Unit, _) => {
107                    f.write_str("1")?;
108                }
109                // special-case 1 + A as A?
110                (CompleteBound::Sum(ref left, _), 0)
111                    if matches!(left.bound, CompleteBound::Unit) =>
112                {
113                    skipping = Some(Tmr::unit());
114                }
115                (CompleteBound::Sum(ref left, _), 1)
116                    if matches!(left.bound, CompleteBound::Unit) => {}
117                (CompleteBound::Sum(ref left, _), 2)
118                    if matches!(left.bound, CompleteBound::Unit) =>
119                {
120                    f.write_str("?")?;
121                }
122                // other sums and products
123                (CompleteBound::Sum(..), 0) | (CompleteBound::Product(..), 0) => {
124                    if data.index > 0 {
125                        f.write_str("(")?;
126                    }
127                }
128                (CompleteBound::Sum(..), 2) | (CompleteBound::Product(..), 2) => {
129                    if data.index > 0 {
130                        f.write_str(")")?;
131                    }
132                }
133                (CompleteBound::Sum(..), _) => f.write_str(" + ")?,
134                (CompleteBound::Product(..), _) => f.write_str(" × ")?,
135            }
136        }
137        Ok(())
138    }
139}
140
141impl DagLike for &'_ Final {
142    type Node = Final;
143    fn data(&self) -> &Final {
144        self
145    }
146    fn as_dag_node(&self) -> Dag<Self> {
147        match self.bound {
148            CompleteBound::Unit => Dag::Nullary,
149            CompleteBound::Sum(ref left, ref right)
150            | CompleteBound::Product(ref left, ref right) => Dag::Binary(left, right),
151        }
152    }
153}
154
// Generates a documented constructor `$fn_name()` returning the precomputed
// word type `2^(2^$exponent)`. Intended for use inside `impl Final`.
macro_rules! construct_final_two_two_n {
    ($fn_name: ident, $exponent: expr, $width_text: expr) => {
        #[doc = "Create the type of"]
        #[doc = $width_text]
        #[doc = "words.\n\nThe type is precomputed and fast to access."]
        pub fn $fn_name() -> Arc<Self> {
            Self::two_two_n($exponent)
        }
    };
}
165
166impl Final {
167    /// Create the unit type.
168    pub fn unit() -> Arc<Self> {
169        Arc::new(Final {
170            bound: CompleteBound::Unit,
171            bit_width: 0,
172            has_padding: false,
173            tmr: Tmr::unit(),
174        })
175    }
176
177    /// Create the type `2^(2^n)` for the given `n`.
178    ///
179    /// The type is precomputed and fast to access.
180    pub fn two_two_n(n: usize) -> Arc<Self> {
181        super::precomputed::nth_power_of_2(n)
182    }
183
184    construct_final_two_two_n!(u1, 0, "1-bit");
185    construct_final_two_two_n!(u2, 1, "2-bit");
186    construct_final_two_two_n!(u4, 2, "4-bit");
187    construct_final_two_two_n!(u8, 3, "8-bit");
188    construct_final_two_two_n!(u16, 4, "16-bit");
189    construct_final_two_two_n!(u32, 5, "32-bit");
190    construct_final_two_two_n!(u64, 6, "64-bit");
191    construct_final_two_two_n!(u128, 7, "128-bit");
192    construct_final_two_two_n!(u256, 8, "256-bit");
193    construct_final_two_two_n!(u512, 9, "512-bit");
194
195    /// Create the sum of the given `left` and `right` types.
196    pub fn sum(left: Arc<Self>, right: Arc<Self>) -> Arc<Self> {
197        // Use saturating_add for bitwidths. If the user has overflowed usize, even on a 32-bit
198        // system this means that they have a 4-gigabit type and their program should be rejected
199        // by a sanity check somewhere. However, if we panic here, the user cannot finalize their
200        // program and cannot even tell that this resource limit has been hit before panicking.
201        Arc::new(Final {
202            tmr: Tmr::sum(left.tmr, right.tmr),
203            bit_width: cmp::max(left.bit_width, right.bit_width).saturating_add(1),
204            has_padding: left.has_padding || right.has_padding || left.bit_width != right.bit_width,
205            bound: CompleteBound::Sum(left, right),
206        })
207    }
208
209    /// Create the product of the given `left` and `right` types.
210    pub fn product(left: Arc<Self>, right: Arc<Self>) -> Arc<Self> {
211        // See comment in `sum` about use of saturating add.
212        Arc::new(Final {
213            tmr: Tmr::product(left.tmr, right.tmr),
214            bit_width: left.bit_width.saturating_add(right.bit_width),
215            has_padding: left.has_padding || right.has_padding,
216            bound: CompleteBound::Product(left, right),
217        })
218    }
219
220    /// Accessor for the TMR
221    pub fn tmr(&self) -> Tmr {
222        self.tmr
223    }
224
225    /// Accessor for the Bit Machine bit-width of the type
226    pub fn bit_width(&self) -> usize {
227        self.bit_width
228    }
229
230    /// Whether the type's bit representation has any padding.
231    ///
232    /// If this is true, then its "compact" witness-encoded bit-width may be lower
233    /// than its "padded" bit-machine bit-width.
234    pub fn has_padding(&self) -> bool {
235        self.has_padding
236    }
237
238    /// Check if the type is a nested product of units.
239    /// In this case, values contain no information.
240    pub fn is_empty(&self) -> bool {
241        self.bit_width() == 0
242    }
243
244    /// Accessor for the type bound
245    pub fn bound(&self) -> &CompleteBound {
246        &self.bound
247    }
248
249    /// Check if the type is a unit.
250    pub fn is_unit(&self) -> bool {
251        self.bound == CompleteBound::Unit
252    }
253
254    /// Access the inner types of a sum type.
255    pub fn as_sum(&self) -> Option<(&Arc<Self>, &Arc<Self>)> {
256        match &self.bound {
257            CompleteBound::Sum(left, right) => Some((left, right)),
258            _ => None,
259        }
260    }
261
262    /// Access the inner types of a product type.
263    pub fn as_product(&self) -> Option<(&Arc<Self>, &Arc<Self>)> {
264        match &self.bound {
265            CompleteBound::Product(left, right) => Some((left, right)),
266            _ => None,
267        }
268    }
269
270    /// If the type is of the form `TWO^(2^n)`, then return `n`.
271    ///
272    /// ## Post condition
273    ///
274    /// 0 ≤ n < 32.
275    pub fn as_word(&self) -> Option<u32> {
276        (0..32u32).find(|&n| {
277            self.tmr == Tmr::TWO_TWO_N[n as usize] // cast safety: 32-bit machine or higher
278        })
279    }
280
281    /// Compute the padding of left values of the sum type `Self + Other`.
282    pub fn pad_left(&self, other: &Self) -> usize {
283        cmp::max(self.bit_width, other.bit_width) - self.bit_width
284    }
285
286    /// Compute the padding of right values of the sum type `Self + Other`.
287    pub fn pad_right(&self, other: &Self) -> usize {
288        cmp::max(self.bit_width, other.bit_width) - other.bit_width
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Display` rendering of words, sums, products and `1 + A`.
    #[test]
    fn final_stringify() {
        let cases = [
            (Final::two_two_n(10), "2^1024"),
            (
                Final::sum(Final::two_two_n(5), Final::two_two_n(10)),
                "2^32 + 2^1024",
            ),
            (
                Final::product(Final::two_two_n(5), Final::two_two_n(10)),
                "2^32 × 2^1024",
            ),
            (Final::two_two_n(0), "2"),
            (Final::sum(Final::unit(), Final::two_two_n(2)), "2^4?"),
        ];

        for (ty, expected) in cases.iter() {
            assert_eq!(ty.to_string(), *expected);
        }
    }
}