1use crate::dag::{Dag, DagLike, NoSharing};
15use crate::Tmr;
16
17use std::sync::Arc;
18use std::{cmp, fmt, hash};
19
20#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
22pub enum CompleteBound {
23 Unit,
25 Sum(Arc<Final>, Arc<Final>),
27 Product(Arc<Final>, Arc<Final>),
29}
30
31#[derive(Clone)]
34pub struct Final {
35 bound: CompleteBound,
37 bit_width: usize,
39 has_padding: bool,
43 tmr: Tmr,
45}
46
47impl PartialEq for Final {
48 fn eq(&self, other: &Self) -> bool {
49 self.tmr == other.tmr
50 }
51}
52impl Eq for Final {}
53
54impl PartialOrd for Final {
55 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
56 Some(self.cmp(other))
57 }
58}
59impl Ord for Final {
60 fn cmp(&self, other: &Self) -> cmp::Ordering {
61 self.tmr.cmp(&other.tmr)
62 }
63}
64impl hash::Hash for Final {
65 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
66 self.tmr.hash(hasher)
67 }
68}
69
70impl fmt::Debug for Final {
71 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72 write!(
73 f,
74 "{{ tmr: {}, bit_width: {}, bound: {} }}",
75 self.tmr, self.bit_width, self
76 )
77 }
78}
79
80impl fmt::Display for Final {
81 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82 let mut skipping: Option<Tmr> = None;
83 for data in self.verbose_pre_order_iter::<NoSharing>(None) {
84 if let Some(skip) = skipping {
85 if data.is_complete && data.node.tmr == skip {
86 skipping = None;
87 }
88 continue;
89 } else {
90 if data.node.tmr == Tmr::TWO_TWO_N[0] {
91 f.write_str("2")?;
92 skipping = Some(data.node.tmr);
93 }
94 for (n, tmr) in Tmr::TWO_TWO_N.iter().enumerate().skip(1) {
95 if data.node.tmr == *tmr {
96 write!(f, "2^{}", 1 << n)?;
97 skipping = Some(data.node.tmr);
98 }
99 }
100 }
101 if skipping.is_some() {
102 continue;
103 }
104
105 match (&data.node.bound, data.n_children_yielded) {
106 (CompleteBound::Unit, _) => {
107 f.write_str("1")?;
108 }
109 (CompleteBound::Sum(ref left, _), 0)
111 if matches!(left.bound, CompleteBound::Unit) =>
112 {
113 skipping = Some(Tmr::unit());
114 }
115 (CompleteBound::Sum(ref left, _), 1)
116 if matches!(left.bound, CompleteBound::Unit) => {}
117 (CompleteBound::Sum(ref left, _), 2)
118 if matches!(left.bound, CompleteBound::Unit) =>
119 {
120 f.write_str("?")?;
121 }
122 (CompleteBound::Sum(..), 0) | (CompleteBound::Product(..), 0) => {
124 if data.index > 0 {
125 f.write_str("(")?;
126 }
127 }
128 (CompleteBound::Sum(..), 2) | (CompleteBound::Product(..), 2) => {
129 if data.index > 0 {
130 f.write_str(")")?;
131 }
132 }
133 (CompleteBound::Sum(..), _) => f.write_str(" + ")?,
134 (CompleteBound::Product(..), _) => f.write_str(" × ")?,
135 }
136 }
137 Ok(())
138 }
139}
140
141impl DagLike for &'_ Final {
142 type Node = Final;
143 fn data(&self) -> &Final {
144 self
145 }
146 fn as_dag_node(&self) -> Dag<Self> {
147 match self.bound {
148 CompleteBound::Unit => Dag::Nullary,
149 CompleteBound::Sum(ref left, ref right)
150 | CompleteBound::Product(ref left, ref right) => Dag::Binary(left, right),
151 }
152 }
153}
154
// Expands, inside an `impl Final` block, to a documented, zero-argument
// constructor `$name()` that returns precomputed power-of-two entry `$n`.
// `$text` is a human-readable width (e.g. "8-bit") that is spliced into the
// generated doc comment.
macro_rules! construct_final_two_two_n {
    ($name:ident, $n:expr, $text:expr) => {
        #[doc = "Create the type of"]
        #[doc = $text]
        #[doc = "words.\n\nThe type is precomputed and fast to access."]
        pub fn $name() -> Arc<Self> {
            let n = $n;
            super::precomputed::nth_power_of_2(n)
        }
    };
}
165
166impl Final {
167 pub fn unit() -> Arc<Self> {
169 Arc::new(Final {
170 bound: CompleteBound::Unit,
171 bit_width: 0,
172 has_padding: false,
173 tmr: Tmr::unit(),
174 })
175 }
176
177 pub fn two_two_n(n: usize) -> Arc<Self> {
181 super::precomputed::nth_power_of_2(n)
182 }
183
184 construct_final_two_two_n!(u1, 0, "1-bit");
185 construct_final_two_two_n!(u2, 1, "2-bit");
186 construct_final_two_two_n!(u4, 2, "4-bit");
187 construct_final_two_two_n!(u8, 3, "8-bit");
188 construct_final_two_two_n!(u16, 4, "16-bit");
189 construct_final_two_two_n!(u32, 5, "32-bit");
190 construct_final_two_two_n!(u64, 6, "64-bit");
191 construct_final_two_two_n!(u128, 7, "128-bit");
192 construct_final_two_two_n!(u256, 8, "256-bit");
193 construct_final_two_two_n!(u512, 9, "512-bit");
194
195 pub fn sum(left: Arc<Self>, right: Arc<Self>) -> Arc<Self> {
197 Arc::new(Final {
202 tmr: Tmr::sum(left.tmr, right.tmr),
203 bit_width: cmp::max(left.bit_width, right.bit_width).saturating_add(1),
204 has_padding: left.has_padding || right.has_padding || left.bit_width != right.bit_width,
205 bound: CompleteBound::Sum(left, right),
206 })
207 }
208
209 pub fn product(left: Arc<Self>, right: Arc<Self>) -> Arc<Self> {
211 Arc::new(Final {
213 tmr: Tmr::product(left.tmr, right.tmr),
214 bit_width: left.bit_width.saturating_add(right.bit_width),
215 has_padding: left.has_padding || right.has_padding,
216 bound: CompleteBound::Product(left, right),
217 })
218 }
219
220 pub fn tmr(&self) -> Tmr {
222 self.tmr
223 }
224
225 pub fn bit_width(&self) -> usize {
227 self.bit_width
228 }
229
230 pub fn has_padding(&self) -> bool {
235 self.has_padding
236 }
237
238 pub fn is_empty(&self) -> bool {
241 self.bit_width() == 0
242 }
243
244 pub fn bound(&self) -> &CompleteBound {
246 &self.bound
247 }
248
249 pub fn is_unit(&self) -> bool {
251 self.bound == CompleteBound::Unit
252 }
253
254 pub fn as_sum(&self) -> Option<(&Arc<Self>, &Arc<Self>)> {
256 match &self.bound {
257 CompleteBound::Sum(left, right) => Some((left, right)),
258 _ => None,
259 }
260 }
261
262 pub fn as_product(&self) -> Option<(&Arc<Self>, &Arc<Self>)> {
264 match &self.bound {
265 CompleteBound::Product(left, right) => Some((left, right)),
266 _ => None,
267 }
268 }
269
270 pub fn as_word(&self) -> Option<u32> {
276 (0..32u32).find(|&n| {
277 self.tmr == Tmr::TWO_TWO_N[n as usize] })
279 }
280
281 pub fn pad_left(&self, other: &Self) -> usize {
283 cmp::max(self.bit_width, other.bit_width) - self.bit_width
284 }
285
286 pub fn pad_right(&self, other: &Self) -> usize {
288 cmp::max(self.bit_width, other.bit_width) - other.bit_width
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Display` notation for words, sums, products and `1 + A`.
    #[test]
    fn final_stringify() {
        let cases = [
            (Final::two_two_n(10), "2^1024"),
            (
                Final::sum(Final::two_two_n(5), Final::two_two_n(10)),
                "2^32 + 2^1024",
            ),
            (
                Final::product(Final::two_two_n(5), Final::two_two_n(10)),
                "2^32 × 2^1024",
            ),
            (Final::two_two_n(0), "2"),
            (Final::sum(Final::unit(), Final::two_two_n(2)), "2^4?"),
        ];

        for (ty, expected) in cases.iter() {
            assert_eq!(ty.to_string(), *expected);
        }
    }
}