1use core::ops::{BitAnd, Shl, Shr};
2use ruint_1::Uint;
3
4use crate::packable::Packable;
5
6fn pack_uint<
8 const LBITS: usize,
9 const LLIMBS: usize,
10 const RBITS: usize,
11 const RLIMBS: usize,
12 const PBITS: usize,
13 const PLIMBS: usize,
14>(
15 lhs: &Uint<LBITS, LLIMBS>,
16 rhs: &Uint<RBITS, RLIMBS>,
17) -> Uint<PBITS, PLIMBS> {
18 if LBITS == 0 && RBITS == 0 {
19 return Uint::<PBITS, PLIMBS>::ZERO;
20 }
21
22 assert_consts::<LBITS, LLIMBS, RBITS, RLIMBS, PBITS, PLIMBS>();
23
24 let (high_bits, high_value, low_value, low_bits) = if LBITS > RBITS {
26 let high = Uint::<PBITS, PLIMBS>::from_limbs_slice(lhs.as_limbs());
28 let low = Uint::<PBITS, PLIMBS>::from_limbs_slice(rhs.as_limbs());
29 (LBITS, high, low, RBITS)
30 } else {
31 let high = Uint::<PBITS, PLIMBS>::from_limbs_slice(rhs.as_limbs());
33 let low = Uint::<PBITS, PLIMBS>::from_limbs_slice(lhs.as_limbs());
34 (RBITS, high, low, LBITS)
35 };
36
37 let low_mask = if low_bits == PBITS {
39 Uint::<PBITS, PLIMBS>::MAX
40 } else {
41 (Uint::<PBITS, PLIMBS>::from(1u64) << low_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
42 };
43
44 let masked_low = low_value.bitand(low_mask);
46
47 let high_mask = if high_bits == PBITS {
49 Uint::<PBITS, PLIMBS>::MAX
50 } else {
51 (Uint::<PBITS, PLIMBS>::from(1u64) << high_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
52 };
53
54 let masked_high = high_value.bitand(high_mask);
56
57 masked_high.shl(low_bits as u32).bitor(masked_low)
58}
59
60fn unpack_uint<
62 const LBITS: usize,
63 const LLIMBS: usize,
64 const RBITS: usize,
65 const RLIMBS: usize,
66 const PBITS: usize,
67 const PLIMBS: usize,
68>(
69 packed: &Uint<PBITS, PLIMBS>,
70) -> (Uint<LBITS, LLIMBS>, Uint<RBITS, RLIMBS>) {
71 if LBITS == 0 && RBITS == 0 {
72 return (Uint::<LBITS, LLIMBS>::ZERO, Uint::<RBITS, RLIMBS>::ZERO);
73 }
74
75 assert_consts::<LBITS, LLIMBS, RBITS, RLIMBS, PBITS, PLIMBS>();
76
77 let low_bits = if LBITS > RBITS { RBITS } else { LBITS };
79
80 let low_mask = if low_bits == PBITS {
82 Uint::<PBITS, PLIMBS>::MAX
83 } else {
84 (Uint::<PBITS, PLIMBS>::from(1u64) << low_bits) - Uint::<PBITS, PLIMBS>::from(1u64)
85 };
86
87 let low_value = packed.bitand(&low_mask);
89
90 let high_value = packed.shr(low_bits as u32);
92
93 if LBITS > RBITS {
95 let lhs = Uint::<LBITS, LLIMBS>::from_limbs_slice(&high_value.as_limbs()[..LLIMBS]);
97 let rhs = Uint::<RBITS, RLIMBS>::from_limbs_slice(&low_value.as_limbs()[..RLIMBS]);
98 (lhs, rhs)
99 } else {
100 let lhs = Uint::<LBITS, LLIMBS>::from_limbs_slice(&low_value.as_limbs()[..LLIMBS]);
102 let rhs = Uint::<RBITS, RLIMBS>::from_limbs_slice(&high_value.as_limbs()[..RLIMBS]);
103 (lhs, rhs)
104 }
105}
106
107impl<
108 const LBITS: usize,
109 const LLIMBS: usize,
110 const RBITS: usize,
111 const RLIMBS: usize,
112 const PBITS: usize,
113 const PLIMBS: usize,
114 > Packable<Uint<LBITS, LLIMBS>, Uint<RBITS, RLIMBS>> for Uint<PBITS, PLIMBS>
115{
116 fn pack(&self, rhs: &Uint<LBITS, LLIMBS>) -> Uint<RBITS, RLIMBS> {
117 pack_uint(self, rhs)
118 }
119
120 fn unpack(packed: Uint<RBITS, RLIMBS>) -> (Self, Uint<LBITS, LLIMBS>)
121 where
122 Self: Sized,
123 Uint<LBITS, LLIMBS>: Sized,
124 {
125 unpack_uint(&packed)
126 }
127}
128
/// Checks that `Uint<LBITS, LLIMBS>` and `Uint<RBITS, RLIMBS>` can be packed into
/// `Uint<PBITS, PLIMBS>`.
///
/// The bit widths must add up to at most `PBITS`. Each operand's limb count must
/// individually fit in `PLIMBS`, because the packed value's limbs are sliced or copied
/// into each operand.
///
/// The limb counts are deliberately NOT summed. Limb counts round up, so
/// `LLIMBS + RLIMBS` can exceed `PLIMBS` even when the bits fit. For example, two 32-bit
/// operands (one limb each) fit in a single 64-bit limb.
///
/// # Panics
///
/// - if `LBITS + RBITS > PBITS`;
/// - if `LLIMBS > PLIMBS` or `RLIMBS > PLIMBS`.
const fn assert_consts<
    const LBITS: usize,
    const LLIMBS: usize,
    const RBITS: usize,
    const RLIMBS: usize,
    const PBITS: usize,
    const PLIMBS: usize,
>() {
    assert!(
        LBITS + RBITS <= PBITS,
        "The sum of LBITS and RBITS must be less than or equal to PBITS"
    );
    assert!(
        LLIMBS <= PLIMBS && RLIMBS <= PLIMBS,
        "LLIMBS and RLIMBS must each be less than or equal to PLIMBS"
    );
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs `lhs` with `rhs` through the `Packable` impl, unpacks the result, and
    /// reports whether both operands came back unchanged.
    fn roundtrip<
        const LBITS: usize,
        const LLIMBS: usize,
        const RBITS: usize,
        const RLIMBS: usize,
        const PBITS: usize,
        const PLIMBS: usize,
    >(
        lhs: &Uint<LBITS, LLIMBS>,
        rhs: &Uint<RBITS, RLIMBS>,
    ) -> bool {
        // The packed type is inferred from the explicit `unpack` call below.
        let packed = lhs.pack(rhs);
        let (lhs_unpacked, rhs_unpacked) =
            <Uint<LBITS, LLIMBS> as Packable<Uint<RBITS, RLIMBS>, Uint<PBITS, PLIMBS>>>::unpack(packed);
        lhs == &lhs_unpacked && rhs == &rhs_unpacked
    }

    // For each `(bits, limbs)` pair, generates a quickcheck property (named via `paste`
    // from `fuzzy_u` and the bit count). The property round-trips two random `U<bits>`
    // values through a packed type with twice the bits and twice the limbs.
    macro_rules! fuzzy_packable {
        ($(($bits:literal, $limbs:literal)),+$(,)?) => {
            paste::paste! {
                quickcheck::quickcheck! {
                    $(
                        fn [<fuzzy_u $bits:snake>](a: [<U $bits>], b: [<U $bits>]) -> bool {
                            roundtrip::<$bits, $limbs, $bits, $limbs, {$bits * 2}, {$limbs * 2}>(&a, &b)
                        }
                    )*
                }
            }
        };
    }

    use ruint_1::aliases::*;

    fuzzy_packable!(
        (64, 1),
        (128, 2),
        (256, 4),
        (512, 8),
        (1024, 16),
        (2048, 32),
    );

    // Zero-width operands take the early-return path in both `pack_uint` and
    // `unpack_uint`, so `assert_consts` is never reached.
    #[test]
    fn zero() {
        let output = pack_uint(&Uint::<0, 0>::ZERO, &Uint::<0, 0>::ZERO);
        assert_eq!(output, Uint::<0, 0>::ZERO);
        let (lhs, rhs) = unpack_uint::<0, 0, 0, 0, 0, 0>(&output);
        assert_eq!(lhs, Uint::<0, 0>::ZERO);
        assert_eq!(rhs, Uint::<0, 0>::ZERO);
    }

    // PLIMBS = 0 cannot hold operands that each need one limb.
    #[test]
    #[should_panic]
    fn assert_consts_panic() {
        assert_consts::<1, 1, 1, 1, 2, 0>();
    }

    // PBITS = 0 is smaller than LBITS + RBITS = 2.
    #[test]
    #[should_panic]
    fn assert_consts_panic_2() {
        assert_consts::<1, 1, 1, 1, 0, 2>();
    }
}