varing/packable/ruint/
v1.rs

1use core::ops::{BitAnd, Shl, Shr};
2use ruint_1::Uint;
3
4use crate::packable::Packable;
5
6/// Packs `Uint<LBITS, LLIMBS>` and `Uint<RBITS, RLIMBS>` into `Uint<PBITS, PLIMBS>`.
7fn pack_uint<
8  const LBITS: usize,
9  const LLIMBS: usize,
10  const RBITS: usize,
11  const RLIMBS: usize,
12  const PBITS: usize,
13  const PLIMBS: usize,
14>(
15  lhs: &Uint<LBITS, LLIMBS>,
16  rhs: &Uint<RBITS, RLIMBS>,
17) -> Uint<PBITS, PLIMBS> {
18  if LBITS == 0 && RBITS == 0 {
19    return Uint::<PBITS, PLIMBS>::ZERO;
20  }
21
22  assert_consts::<LBITS, LLIMBS, RBITS, RLIMBS, PBITS, PLIMBS>();
23
24  // Decide which value goes to the high bits and which to the low bits
25  let (high_bits, high_value, low_value, low_bits) = if LBITS > RBITS {
26    // lhs goes to high bits, rhs goes to low bits
27    let high = Uint::<PBITS, PLIMBS>::from_limbs_slice(lhs.as_limbs());
28    let low = Uint::<PBITS, PLIMBS>::from_limbs_slice(rhs.as_limbs());
29    (LBITS, high, low, RBITS)
30  } else {
31    // rhs goes to high bits, lhs goes to low bits
32    let high = Uint::<PBITS, PLIMBS>::from_limbs_slice(rhs.as_limbs());
33    let low = Uint::<PBITS, PLIMBS>::from_limbs_slice(lhs.as_limbs());
34    (RBITS, high, low, LBITS)
35  };
36
37  // Create mask for the low value to ensure it doesn't exceed its bit width
38  let low_mask = if low_bits == PBITS {
39    Uint::<PBITS, PLIMBS>::MAX
40  } else {
41    (Uint::<PBITS, PLIMBS>::from(1u64) << low_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
42  };
43
44  // Apply mask to low value
45  let masked_low = low_value.bitand(low_mask);
46
47  // Create mask for the high value
48  let high_mask = if high_bits == PBITS {
49    Uint::<PBITS, PLIMBS>::MAX
50  } else {
51    (Uint::<PBITS, PLIMBS>::from(1u64) << high_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
52  };
53
54  // Apply mask to high value, shift it to proper position, and combine with low value
55  let masked_high = high_value.bitand(high_mask);
56
57  masked_high.shl(low_bits as u32).bitor(masked_low)
58}
59
60/// Unpacks `Uint<PBITS, PLIMBS>` into `Uint<LBITS, LLIMBS>` and `Uint<RBITS, RLIMBS>`.
61fn unpack_uint<
62  const LBITS: usize,
63  const LLIMBS: usize,
64  const RBITS: usize,
65  const RLIMBS: usize,
66  const PBITS: usize,
67  const PLIMBS: usize,
68>(
69  packed: &Uint<PBITS, PLIMBS>,
70) -> (Uint<LBITS, LLIMBS>, Uint<RBITS, RLIMBS>) {
71  if LBITS == 0 && RBITS == 0 {
72    return (Uint::<LBITS, LLIMBS>::ZERO, Uint::<RBITS, RLIMBS>::ZERO);
73  }
74
75  assert_consts::<LBITS, LLIMBS, RBITS, RLIMBS, PBITS, PLIMBS>();
76
77  // Determine which value was placed in high bits vs low bits
78  let low_bits = if LBITS > RBITS { RBITS } else { LBITS };
79
80  // Create masks for extracting each value
81  let low_mask = if low_bits == PBITS {
82    Uint::<PBITS, PLIMBS>::MAX
83  } else {
84    (Uint::<PBITS, PLIMBS>::from(1u64) << low_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
85  };
86
87  // Extract the low bits part
88  let low_value = packed.bitand(&low_mask);
89
90  // Extract the high bits part
91  let high_value = packed.shr(low_bits as u32);
92
93  // Create properly sized results based on whether lhs or rhs was the larger value
94  if LBITS > RBITS {
95    // lhs was in high bits, rhs was in low bits
96    let lhs = Uint::<LBITS, LLIMBS>::from_limbs_slice(&high_value.as_limbs()[..LLIMBS]);
97    let rhs = Uint::<RBITS, RLIMBS>::from_limbs_slice(&low_value.as_limbs()[..RLIMBS]);
98    (lhs, rhs)
99  } else {
100    // rhs was in high bits, lhs was in low bits
101    let lhs = Uint::<LBITS, LLIMBS>::from_limbs_slice(&low_value.as_limbs()[..LLIMBS]);
102    let rhs = Uint::<RBITS, RLIMBS>::from_limbs_slice(&high_value.as_limbs()[..RLIMBS]);
103    (lhs, rhs)
104  }
105}
106
107impl<
108    const LBITS: usize,
109    const LLIMBS: usize,
110    const RBITS: usize,
111    const RLIMBS: usize,
112    const PBITS: usize,
113    const PLIMBS: usize,
114  > Packable<Uint<LBITS, LLIMBS>, Uint<RBITS, RLIMBS>> for Uint<PBITS, PLIMBS>
115{
116  fn pack(&self, rhs: &Uint<LBITS, LLIMBS>) -> Uint<RBITS, RLIMBS> {
117    pack_uint(self, rhs)
118  }
119
120  fn unpack(packed: Uint<RBITS, RLIMBS>) -> (Self, Uint<LBITS, LLIMBS>)
121  where
122    Self: Sized,
123    Uint<LBITS, LLIMBS>: Sized,
124  {
125    unpack_uint(&packed)
126  }
127}
128
/// Checks that a packed type described by `PBITS`/`PLIMBS` is large enough to
/// hold a left operand (`LBITS`/`LLIMBS`) and a right operand
/// (`RBITS`/`RLIMBS`).
///
/// Two conditions are verified:
///
/// - `LBITS + RBITS <= PBITS`, so the two values can sit side by side without
///   overlapping. A sum that overflows `usize` is treated as "does not fit".
/// - `LLIMBS <= PLIMBS` and `RLIMBS <= PLIMBS`, so the limbs of either operand
///   can be copied into, and sliced back out of, the limbs of the packed
///   value.
///
/// The limb counts are checked individually rather than summed: operands that
/// are narrower than a limb share limbs once packed (two 8-bit values fit in
/// a single-limb 16-bit value), so requiring `LLIMBS + RLIMBS <= PLIMBS`
/// would reject layouts that fit.
///
/// # Panics
///
/// Panics if either condition does not hold.
const fn assert_consts<
  const LBITS: usize,
  const LLIMBS: usize,
  const RBITS: usize,
  const RLIMBS: usize,
  const PBITS: usize,
  const PLIMBS: usize,
>() {
  // Check if there's enough space in the packed value for both integers
  let bits_fit = match LBITS.checked_add(RBITS) {
    Some(total_bits) => total_bits <= PBITS,
    None => false,
  };
  assert!(
    bits_fit,
    "The sum of LBITS and RBITS must be less than or equal to PBITS"
  );

  // Each operand on its own must fit in the packed value's limbs
  assert!(
    LLIMBS <= PLIMBS && RLIMBS <= PLIMBS,
    "LLIMBS and RLIMBS must each be less than or equal to PLIMBS"
  );
}
147
#[cfg(test)]
mod tests {
  use super::*;

  /// Packs `lhs` and `rhs` through the `Packable` impl, unpacks the result,
  /// and reports whether both operands came back unchanged.
  fn roundtrip<
    const LBITS: usize,
    const LLIMBS: usize,
    const RBITS: usize,
    const RLIMBS: usize,
    const PBITS: usize,
    const PLIMBS: usize,
  >(
    lhs: &Uint<LBITS, LLIMBS>,
    rhs: &Uint<RBITS, RLIMBS>,
  ) -> bool {
    let packed = lhs.pack(rhs);
    let (lhs_unpacked, rhs_unpacked) =
      <Uint<LBITS, LLIMBS> as Packable<Uint<RBITS, RLIMBS>, Uint<PBITS, PLIMBS>>>::unpack(packed);
    lhs == &lhs_unpacked && rhs == &rhs_unpacked
  }

  // Generates one quickcheck property per `(bits, limbs)` pair: two values of
  // the `U<bits>` alias are packed into a `Uint` with twice the bits and twice
  // the limbs, and must survive the round trip.
  macro_rules! fuzzy_packable {
    ($(($bits:literal, $limbs:literal)),+$(,)?) => {
      paste::paste! {
        quickcheck::quickcheck! {
          $(
            fn [<fuzzy_u $bits:snake>](a: [<U $bits>], b: [<U $bits>]) -> bool {
              roundtrip::<$bits, $limbs, $bits, $limbs, {$bits * 2}, {$limbs * 2}>(&a, &b)
            }
          )*
        }
      }
    };
  }

  use ruint_1::aliases::*;

  fuzzy_packable!(
    (64, 1),
    (128, 2),
    (256, 4),
    (512, 8),
    (1024, 16),
    (2048, 32),
  );

  // Zero-width operands take the early-return path of both helpers.
  #[test]
  fn zero() {
    let output = pack_uint(&Uint::<0, 0>::ZERO, &Uint::<0, 0>::ZERO);
    assert_eq!(output, Uint::<0, 0>::ZERO);
    let (lhs, rhs) = unpack_uint::<0, 0, 0, 0, 0, 0>(&output);
    assert_eq!(lhs, Uint::<0, 0>::ZERO);
    assert_eq!(rhs, Uint::<0, 0>::ZERO);
  }

  // The fuzz properties above only use operands of equal width, which always
  // places `rhs` in the high bits. The two tests below use unequal widths so
  // that both placements are exercised.
  #[test]
  fn wider_lhs_goes_to_high_bits() {
    let lhs = Uint::<128, 2>::MAX;
    let rhs = Uint::<64, 1>::from(0xDEAD_BEEF_u64);

    // lhs is wider, so it is shifted above the 64 bits of rhs.
    let packed = pack_uint::<128, 2, 64, 1, 192, 3>(&lhs, &rhs);
    assert_eq!(
      packed,
      Uint::<192, 3>::from_limbs([0xDEAD_BEEF, u64::MAX, u64::MAX])
    );

    assert!(roundtrip::<128, 2, 64, 1, 192, 3>(&lhs, &rhs));
  }

  #[test]
  fn wider_rhs_goes_to_high_bits() {
    let lhs = Uint::<64, 1>::from(0xDEAD_BEEF_u64);
    let rhs = Uint::<128, 2>::MAX;

    // rhs is wider, so it is shifted above the 64 bits of lhs.
    let packed = pack_uint::<64, 1, 128, 2, 192, 3>(&lhs, &rhs);
    assert_eq!(
      packed,
      Uint::<192, 3>::from_limbs([0xDEAD_BEEF, u64::MAX, u64::MAX])
    );

    assert!(roundtrip::<64, 1, 128, 2, 192, 3>(&lhs, &rhs));
  }

  // The bit budget is satisfied (1 + 1 <= 2) but the packed type has no limbs.
  #[test]
  #[should_panic]
  fn assert_consts_panic() {
    assert_consts::<1, 1, 1, 1, 2, 0>();
  }

  // The packed type has limbs to spare but zero bits.
  #[test]
  #[should_panic]
  fn assert_consts_panic_2() {
    assert_consts::<1, 1, 1, 1, 0, 2>();
  }
}