1use std::{
2 fmt::Display,
3 ops::{Index, IndexMut},
4};
5
6use crate::air::SP1AirBuilder;
7use arrayref::array_ref;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use slop_algebra::{AbstractField, Field};
11use sp1_derive::AlignedBorrow;
12use sp1_primitives::consts::WORD_SIZE;
13use std::array::IntoIter;
14use struct_reflection::{StructReflection, StructReflectionHelper};
15
/// A 64-bit value stored as `WORD_SIZE` 16-bit limbs in little-endian order: limb 0 holds the
/// least-significant 16 bits and the last limb the most-significant (see the `From<u64>` impl).
///
/// `T` is the limb representation. The impls in this file instantiate it with field elements,
/// AIR builder variables (`SP1AirBuilder::Var`) and AIR builder expressions (`SP1AirBuilder::Expr`).
#[derive(
    AlignedBorrow,
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    StructReflection,
)]
// NOTE(review): `repr(C)` pins the in-memory layout to the plain limb array; presumably the
// `AlignedBorrow` derive relies on this to reinterpret rows as `Word`s — confirm.
#[repr(C)]
pub struct Word<T>(pub [T; WORD_SIZE]);
35
36impl<T> Word<T> {
37 pub fn map<F, S>(self, f: F) -> Word<S>
39 where
40 F: FnMut(T) -> S,
41 {
42 Word(self.0.map(f))
43 }
44
45 pub fn extend_var<AB: SP1AirBuilder<Var = T>>(var: T) -> Word<AB::Expr> {
47 Word([AB::Expr::zero() + var, AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
48 }
49
50 pub fn extend_half<AB: SP1AirBuilder<Var = T>>(var: &[T; 2]) -> Word<AB::Expr>
52 where
53 T: Clone,
54 {
55 Word([
56 AB::Expr::zero() + var[0].clone(),
57 AB::Expr::zero() + var[1].clone(),
58 AB::Expr::zero(),
59 AB::Expr::zero(),
60 ])
61 }
62}
63
64impl<T: AbstractField + Clone> Word<T> {
65 pub fn extend_expr<AB: SP1AirBuilder<Expr = T>>(expr: T) -> Word<AB::Expr> {
67 Word([AB::Expr::zero() + expr, AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
68 }
69
70 #[must_use]
72 pub fn zero<AB: SP1AirBuilder<Expr = T>>() -> Word<T> {
73 Word([AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
74 }
75
76 pub fn reduce<AB: SP1AirBuilder<Expr = T>>(&self) -> AB::Expr {
78 let base = [1, 1 << 16, 1 << 32, 1 << 48].map(AB::Expr::from_wrapped_u64);
79 self.0.iter().enumerate().map(|(i, x)| base[i].clone() * x.clone()).sum()
80 }
81
82 pub fn from_le_bits<AB: SP1AirBuilder<Expr = T>>(
85 le_bits: &[impl Into<T> + Clone],
86 sign_extend: bool,
87 ) -> Word<AB::Expr> {
88 assert!(le_bits.len() <= WORD_SIZE * 16);
89
90 let mut limbs = le_bits
91 .chunks(16)
92 .map(|chunk| {
93 chunk.iter().enumerate().fold(AB::Expr::zero(), |a, (i, b)| {
94 a + AB::Expr::from_canonical_u16(1 << i) * (*b).clone().into()
95 })
96 })
97 .collect_vec();
98
99 let sign_bit = (*le_bits.last().unwrap()).clone().into();
100
101 if sign_extend {
102 let most_sig_limb = limbs.last_mut().unwrap();
104 let most_sig_num_bits = le_bits.len() % 16;
105 if most_sig_num_bits > 0 {
106 *most_sig_limb = (*most_sig_limb).clone()
107 + (AB::Expr::from_canonical_u32((1 << 16) - (1 << most_sig_num_bits)))
108 * sign_bit.clone();
109 }
110 }
111
112 let extend_limb = if sign_extend {
113 AB::Expr::from_canonical_u16(u16::MAX) * sign_bit.clone()
114 } else {
115 AB::Expr::zero()
116 };
117
118 limbs.resize(WORD_SIZE, extend_limb);
119
120 Word::from_iter(limbs)
121 }
122}
123
124impl<F: Field> Word<F> {
125 pub fn to_u32(&self) -> u32 {
127 let low = self.0[0].to_string().parse::<u16>().unwrap();
128 let high = self.0[1].to_string().parse::<u16>().unwrap();
129 ((high as u32) << 16) | (low as u32)
130 }
131
132 pub fn to_u64(&self) -> u64 {
134 let low = self.0[0].to_string().parse::<u16>().unwrap();
135 let mid_low = self.0[1].to_string().parse::<u16>().unwrap();
136 let mid_high = self.0[2].to_string().parse::<u16>().unwrap();
137 let high = self.0[3].to_string().parse::<u16>().unwrap();
138 ((high as u64) << 48) | ((mid_high as u64) << 32) | ((mid_low as u64) << 16) | (low as u64)
139 }
140}
141
142impl<T> Index<usize> for Word<T> {
143 type Output = T;
144
145 fn index(&self, index: usize) -> &Self::Output {
146 &self.0[index]
147 }
148}
149
150impl<T> IndexMut<usize> for Word<T> {
151 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
152 &mut self.0[index]
153 }
154}
155
156impl<F: AbstractField> From<u32> for Word<F> {
157 fn from(value: u32) -> Self {
158 Word([
159 F::from_canonical_u16((value & 0xFFFF) as u16),
160 F::from_canonical_u16((value >> 16) as u16),
161 F::zero(),
162 F::zero(),
163 ])
164 }
165}
166
167impl<F: AbstractField> From<u64> for Word<F> {
168 fn from(value: u64) -> Self {
169 Word([
170 F::from_canonical_u16((value & 0xFFFF) as u16),
171 F::from_canonical_u16((value >> 16) as u16),
172 F::from_canonical_u16((value >> 32) as u16),
173 F::from_canonical_u16((value >> 48) as u16),
174 ])
175 }
176}
177
178impl<T> IntoIterator for Word<T> {
179 type Item = T;
180 type IntoIter = IntoIter<T, WORD_SIZE>;
181
182 fn into_iter(self) -> Self::IntoIter {
183 self.0.into_iter()
184 }
185}
186
187impl<T: Clone> FromIterator<T> for Word<T> {
188 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
189 let elements = iter.into_iter().take(WORD_SIZE).collect_vec();
190
191 Word(array_ref![elements, 0, WORD_SIZE].clone())
192 }
193}
194
195impl<T: Display> Display for Word<T> {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 write!(f, "Word(")?;
198 for (i, value) in self.0.iter().enumerate() {
199 write!(f, "{value}")?;
200 if i < self.0.len() - 1 {
201 write!(f, ", ")?;
202 }
203 }
204 write!(f, ")")?;
205 Ok(())
206 }
207}