Skip to main content

sp1_hypercube/
word.rs

1use std::{
2    fmt::Display,
3    ops::{Index, IndexMut},
4};
5
6use crate::air::SP1AirBuilder;
7use arrayref::array_ref;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use slop_algebra::{AbstractField, Field};
11use sp1_derive::AlignedBorrow;
12use sp1_primitives::consts::WORD_SIZE;
13use std::array::IntoIter;
14
/// An array of four u16 limbs to represent a 64-bit value.
///
/// Limbs are stored least-significant first: the `From<u64>` conversion in this file places
/// bits 0..16 of the value at index 0 and bits 48..64 at index 3.
///
/// We use the generic type `T` to represent the different representations of a u16 limb, ranging
/// from a `u16` to an `AB::Var` or `AB::Expr`.
#[derive(
    AlignedBorrow, Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize,
)]
#[repr(C)]
pub struct Word<T>(pub [T; WORD_SIZE]);
24
25impl<T> Word<T> {
26    /// Applies `f` to each element of the word.
27    pub fn map<F, S>(self, f: F) -> Word<S>
28    where
29        F: FnMut(T) -> S,
30    {
31        Word(self.0.map(f))
32    }
33
34    /// Extends a variable to a word.
35    pub fn extend_var<AB: SP1AirBuilder<Var = T>>(var: T) -> Word<AB::Expr> {
36        Word([AB::Expr::zero() + var, AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
37    }
38
39    /// Extends a half word to a word.
40    pub fn extend_half<AB: SP1AirBuilder<Var = T>>(var: &[T; 2]) -> Word<AB::Expr>
41    where
42        T: Clone,
43    {
44        Word([
45            AB::Expr::zero() + var[0].clone(),
46            AB::Expr::zero() + var[1].clone(),
47            AB::Expr::zero(),
48            AB::Expr::zero(),
49        ])
50    }
51}
52
53impl<T: AbstractField + Clone> Word<T> {
54    /// Extends a variable to a word.
55    pub fn extend_expr<AB: SP1AirBuilder<Expr = T>>(expr: T) -> Word<AB::Expr> {
56        Word([AB::Expr::zero() + expr, AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
57    }
58
59    /// Returns a word with all zero expressions.
60    #[must_use]
61    pub fn zero<AB: SP1AirBuilder<Expr = T>>() -> Word<T> {
62        Word([AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero()])
63    }
64
65    /// Reduces a word to a single expression.
66    pub fn reduce<AB: SP1AirBuilder<Expr = T>>(&self) -> AB::Expr {
67        let base = [1, 1 << 16, 1 << 32, 1 << 48].map(AB::Expr::from_wrapped_u64);
68        self.0.iter().enumerate().map(|(i, x)| base[i].clone() * x.clone()).sum()
69    }
70
71    /// Creates a word from `le_bits`.
72    /// Safety: This assumes that the `le_bits` are already checked to be boolean.
73    pub fn from_le_bits<AB: SP1AirBuilder<Expr = T>>(
74        le_bits: &[impl Into<T> + Clone],
75        sign_extend: bool,
76    ) -> Word<AB::Expr> {
77        assert!(le_bits.len() <= WORD_SIZE * 16);
78
79        let mut limbs = le_bits
80            .chunks(16)
81            .map(|chunk| {
82                chunk.iter().enumerate().fold(AB::Expr::zero(), |a, (i, b)| {
83                    a + AB::Expr::from_canonical_u16(1 << i) * (*b).clone().into()
84                })
85            })
86            .collect_vec();
87
88        let sign_bit = (*le_bits.last().unwrap()).clone().into();
89
90        if sign_extend {
91            // Sign extend the most significant limb.
92            let most_sig_limb = limbs.last_mut().unwrap();
93            let most_sig_num_bits = le_bits.len() % 16;
94            if most_sig_num_bits > 0 {
95                *most_sig_limb = (*most_sig_limb).clone()
96                    + (AB::Expr::from_canonical_u32((1 << 16) - (1 << most_sig_num_bits)))
97                        * sign_bit.clone();
98            }
99        }
100
101        let extend_limb = if sign_extend {
102            AB::Expr::from_canonical_u16(u16::MAX) * sign_bit.clone()
103        } else {
104            AB::Expr::zero()
105        };
106
107        limbs.resize(WORD_SIZE, extend_limb);
108
109        Word::from_iter(limbs)
110    }
111}
112
113impl<F: Field> Word<F> {
114    /// Converts a word to a u32.
115    pub fn to_u32(&self) -> u32 {
116        let low = self.0[0].to_string().parse::<u16>().unwrap();
117        let high = self.0[1].to_string().parse::<u16>().unwrap();
118        ((high as u32) << 16) | (low as u32)
119    }
120
121    /// Converts a word to a u64.
122    pub fn to_u64(&self) -> u64 {
123        let low = self.0[0].to_string().parse::<u16>().unwrap();
124        let mid_low = self.0[1].to_string().parse::<u16>().unwrap();
125        let mid_high = self.0[2].to_string().parse::<u16>().unwrap();
126        let high = self.0[3].to_string().parse::<u16>().unwrap();
127        ((high as u64) << 48) | ((mid_high as u64) << 32) | ((mid_low as u64) << 16) | (low as u64)
128    }
129}
130
131impl<T> Index<usize> for Word<T> {
132    type Output = T;
133
134    fn index(&self, index: usize) -> &Self::Output {
135        &self.0[index]
136    }
137}
138
139impl<T> IndexMut<usize> for Word<T> {
140    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
141        &mut self.0[index]
142    }
143}
144
145impl<F: AbstractField> From<u32> for Word<F> {
146    fn from(value: u32) -> Self {
147        Word([
148            F::from_canonical_u16((value & 0xFFFF) as u16),
149            F::from_canonical_u16((value >> 16) as u16),
150            F::zero(),
151            F::zero(),
152        ])
153    }
154}
155
156impl<F: AbstractField> From<u64> for Word<F> {
157    fn from(value: u64) -> Self {
158        Word([
159            F::from_canonical_u16((value & 0xFFFF) as u16),
160            F::from_canonical_u16((value >> 16) as u16),
161            F::from_canonical_u16((value >> 32) as u16),
162            F::from_canonical_u16((value >> 48) as u16),
163        ])
164    }
165}
166
167impl<T> IntoIterator for Word<T> {
168    type Item = T;
169    type IntoIter = IntoIter<T, WORD_SIZE>;
170
171    fn into_iter(self) -> Self::IntoIter {
172        self.0.into_iter()
173    }
174}
175
176impl<T: Clone> FromIterator<T> for Word<T> {
177    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
178        let elements = iter.into_iter().take(WORD_SIZE).collect_vec();
179
180        Word(array_ref![elements, 0, WORD_SIZE].clone())
181    }
182}
183
184impl<T: Display> Display for Word<T> {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        write!(f, "Word(")?;
187        for (i, value) in self.0.iter().enumerate() {
188            write!(f, "{value}")?;
189            if i < self.0.len() - 1 {
190                write!(f, ", ")?;
191            }
192        }
193        write!(f, ")")?;
194        Ok(())
195    }
196}