rust_hdl_lib_core/
bitvec.rs1use crate::bits::LiteralType;
2
/// A fixed-width vector of `N` bits stored as an array of `bool`s.
///
/// Bit `0` is the least significant bit: the integer conversions in this file map
/// `bits[i]` to the `2^i` place.
///
/// Equality is total (plain array comparison), so the type is `Eq` as well as
/// `PartialEq`, which also makes it usable as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitVec<const N: usize> {
    // `bits[i]` carries weight `2^i`; index 0 is the LSB.
    bits: [bool; N],
}
7
8impl<const N: usize> From<[bool; N]> for BitVec<N> {
9 fn from(x: [bool; N]) -> Self {
10 BitVec { bits: x }
11 }
12}
13
14impl<const N: usize> BitVec<N> {
15 pub fn to_u128(&self) -> u128 {
16 assert!(N <= 128);
17 let mut ret = 0_u128;
18 for i in 0..N {
19 if self.bits[N - 1 - i] {
20 ret |= 1 << (N - 1 - i);
21 }
22 }
23 ret
24 }
25
26 pub fn all(&self) -> bool {
27 for i in 0..N {
28 if !self.bits[i] {
29 return false;
30 }
31 }
32 true
33 }
34
35 pub fn any(&self) -> bool {
36 for i in 0..N {
37 if self.bits[i] {
38 return true;
39 }
40 }
41 false
42 }
43
44 pub fn xor(&self) -> bool {
45 let mut ret = false;
46 for i in 0..N {
47 ret ^= self.bits[i];
48 }
49 ret
50 }
51
52 pub fn get_bit(&self, ndx: usize) -> bool {
53 assert!(ndx < N);
54 self.bits[ndx]
55 }
56
57 pub fn replace_bit(&self, ndx: usize, val: bool) -> BitVec<N> {
58 let mut t = self.bits;
59 t[ndx] = val;
60 BitVec { bits: t }
61 }
62
63 pub fn resize<const M: usize>(&self) -> BitVec<M> {
64 let mut t = [false; M];
65 (0..M.min(N)).for_each(|i| {
66 t[i] = self.bits[i];
67 });
68 BitVec { bits: t }
69 }
70}
71
72impl<const N: usize> std::ops::Shr<LiteralType> for BitVec<N> {
73 type Output = BitVec<N>;
74
75 fn shr(self, rhs: LiteralType) -> Self::Output {
76 let rhs = rhs as usize;
77 let mut bits = [false; N];
78 (rhs..N).for_each(|i| {
79 bits[i - rhs] = self.bits[i];
80 });
81 Self { bits }
82 }
83}
84
85impl<const N: usize, const M: usize> std::ops::Shr<BitVec<M>> for BitVec<N> {
86 type Output = BitVec<N>;
87
88 fn shr(self, rhs: BitVec<M>) -> Self::Output {
89 let rhs: LiteralType = rhs.into();
90 self >> rhs
91 }
92}
93
94impl<const N: usize> std::ops::Shl<LiteralType> for BitVec<N> {
95 type Output = BitVec<N>;
96
97 fn shl(self, rhs: LiteralType) -> Self::Output {
98 let rhs = rhs as usize;
99 let mut bits = [false; N];
100 (rhs..N).for_each(|i| {
101 bits[i] = self.bits[i - rhs];
102 });
103 Self { bits }
104 }
105}
106
107impl<const N: usize, const M: usize> std::ops::Shl<BitVec<M>> for BitVec<N> {
108 type Output = BitVec<N>;
109
110 fn shl(self, rhs: BitVec<M>) -> Self::Output {
111 let rhs: LiteralType = rhs.into();
112 self << rhs
113 }
114}
115
116impl<const N: usize> std::ops::BitOr for BitVec<N> {
117 type Output = BitVec<N>;
118
119 fn bitor(self, rhs: Self) -> Self::Output {
120 self.binop(&rhs, |a, b| a | b)
121 }
122}
123
124impl<const N: usize> std::ops::BitAnd for BitVec<N> {
125 type Output = BitVec<N>;
126
127 fn bitand(self, rhs: Self) -> Self::Output {
128 self.binop(&rhs, |a, b| a & b)
129 }
130}
131
132impl<const N: usize> std::ops::BitXor for BitVec<N> {
133 type Output = BitVec<N>;
134
135 fn bitxor(self, rhs: Self) -> Self::Output {
136 self.binop(&rhs, |a, b| a ^ b)
137 }
138}
139
140impl<const N: usize> std::ops::Not for BitVec<N> {
141 type Output = BitVec<N>;
142
143 fn not(self) -> Self::Output {
144 let mut bits = [false; N];
145 (0..N).for_each(|i| {
146 bits[i] = !self.bits[i];
147 });
148 Self { bits }
149 }
150}
151
152impl<const N: usize> std::ops::Add<BitVec<N>> for BitVec<N> {
163 type Output = BitVec<N>;
164
165 fn add(self, rhs: BitVec<N>) -> Self::Output {
166 let mut carry = false;
167 let mut bits = [false; N];
168 (0..N).for_each(|i| {
169 let a = self.bits[i];
170 let b = rhs.bits[i];
171 let c_i = carry;
172 bits[i] = a ^ b ^ c_i;
173 carry = (a & b) | (b & c_i) | (a & c_i);
174 });
175 Self { bits }
176 }
177}
178
179impl<const N: usize> std::ops::Sub<BitVec<N>> for BitVec<N> {
180 type Output = BitVec<N>;
181
182 fn sub(self, rhs: BitVec<N>) -> Self::Output {
183 self + !rhs + 1_u32.into()
184 }
185}
186
187impl<const N: usize> std::cmp::PartialOrd for BitVec<N> {
188 fn partial_cmp(&self, other: &BitVec<N>) -> Option<std::cmp::Ordering> {
189 for i in 0..N {
190 let a = self.bits[N - 1 - i];
191 let b = other.bits[N - 1 - i];
192 if a & !b {
193 return Some(std::cmp::Ordering::Greater);
194 }
195 if !a & b {
196 return Some(std::cmp::Ordering::Less);
197 }
198 }
199 Some(std::cmp::Ordering::Equal)
200 }
201}
202
203impl<const N: usize> BitVec<N> {
204 fn binop<T>(&self, rhs: &Self, op: T) -> Self
205 where
206 T: Fn(&bool, &bool) -> bool,
207 {
208 let mut bits = [false; N];
209 (0..N).for_each(|i| {
210 bits[i] = op(&self.bits[i], &rhs.bits[i]);
211 });
212 Self { bits }
213 }
214}
215
/// Implements `From<$name> for BitVec<N>` for an integer type `$name`.
macro_rules! define_vec_from_uint {
    ($name:ident) => {
        impl<const N: usize> From<$name> for BitVec<N> {
            /// Takes the `N` low-order bits of `x`; bit 0 of `x` becomes index 0.
            ///
            /// If `N` exceeds the width of the source type, the extra high bits
            /// come from repeatedly shifting `x` right. That gives zeros for
            /// unsigned sources and copies of the sign bit for signed ones, since
            /// `>>` is arithmetic on signed integers.
            fn from(x: $name) -> Self {
                let mut bits = [false; N];
                let mut remaining = x;
                for slot in bits.iter_mut() {
                    *slot = (remaining & 1) == 1;
                    remaining >>= 1;
                }
                Self { bits }
            }
        }
    };
}
230
231define_vec_from_uint!(u8);
232define_vec_from_uint!(u16);
233define_vec_from_uint!(u32);
234define_vec_from_uint!(u64);
235define_vec_from_uint!(u128);
236define_vec_from_uint!(usize);
237define_vec_from_uint!(i8);
238define_vec_from_uint!(i16);
239define_vec_from_uint!(i32);
240define_vec_from_uint!(i64);
241define_vec_from_uint!(i128);
242
/// Implements `From<BitVec<N>> for $name` for an unsigned integer type `$name`.
///
/// `$width` is accepted for symmetry with `define_int_from_vec` but is not used.
/// No width check is performed.
macro_rules! define_uint_from_vec {
    ($name:ident, $width: expr) => {
        impl<const N: usize> From<BitVec<N>> for $name {
            /// Reads the vector as an unsigned integer with `bits[i]` in the `2^i`
            /// place.
            ///
            /// Bits are folded in from the MSB down. If `N` exceeds the target
            /// width, the high-order bits are shifted out, so the result is the
            /// value modulo `2^width`.
            fn from(t: BitVec<N>) -> Self {
                t.bits
                    .iter()
                    .rev()
                    .fold(0, |acc: $name, &bit| (acc << 1) | $name::from(bit))
            }
        }
    };
}
257
/// Implements `From<BitVec<N>> for $name` for a signed integer type `$name` that
/// is `$width` bits wide.
macro_rules! define_int_from_vec {
    ($name: ident, $width: expr) => {
        impl<const N: usize> From<BitVec<N>> for $name {
            /// Reads the vector as an `N`-bit two's-complement integer, with bit
            /// `N - 1` as the sign bit, and sign-extends it into the target type.
            /// An empty vector (`N == 0`) converts to `0`.
            ///
            /// # Panics
            /// Panics if `N` exceeds the width of the target type.
            fn from(t: BitVec<N>) -> Self {
                assert!(N <= $width);
                if N == 0 {
                    // Without this guard `t.bits[N - 1]` would underflow.
                    return 0;
                }
                let negative = t.bits[N - 1];
                // Accumulate MSB-first. For a negative pattern `v`, accumulate the
                // one's complement `!v` instead (`bit != negative` flips each bit).
                // Its top bit is clear, so it always fits in the signed target
                // without overflow.
                let mut x: $name = 0;
                for &bit in t.bits.iter().rev() {
                    x = (x << 1) | $name::from(bit != negative);
                }
                if negative {
                    // Two's complement: value(v) = -(!v + 1) = -!v - 1.
                    // Written as `-x - 1` (not `-(x + 1)`) so the most negative
                    // value, e.g. 0x80 for i8, does not overflow.
                    -x - 1
                } else {
                    x
                }
            }
        }
    };
}
281
282define_uint_from_vec!(u8, 8);
283define_uint_from_vec!(u16, 16);
284define_uint_from_vec!(u32, 32);
285define_uint_from_vec!(u64, 64);
286define_uint_from_vec!(u128, 128);
287#[cfg(target_pointer_width = "64")]
288define_uint_from_vec!(usize, 64);
289#[cfg(target_pointer_width = "32")]
290define_uint_from_vec!(usize, 32);
291
292define_int_from_vec!(i8, 8);
293define_int_from_vec!(i16, 16);
294define_int_from_vec!(i32, 32);
295define_int_from_vec!(i64, 64);
296define_int_from_vec!(i128, 128);
297
#[cfg(test)]
mod tests {
    use std::num::Wrapping;

    use super::BitVec;

    #[test]
    fn or_test() {
        let a: BitVec<32> = 45_u32.into();
        let b: BitVec<32> = 10395_u32.into();
        let c = a | b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 45_u32 | 10395_u32)
    }

    #[test]
    fn and_test() {
        let a: BitVec<32> = 45_u32.into();
        let b: BitVec<32> = 10395_u32.into();
        let c = a & b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 45_u32 & 10395_u32)
    }

    #[test]
    fn xor_test() {
        let a: BitVec<32> = 45_u32.into();
        let b: BitVec<32> = 10395_u32.into();
        let c = a ^ b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 45_u32 ^ 10395_u32)
    }

    #[test]
    fn not_test() {
        let a: BitVec<32> = 45_u32.into();
        let c = !a;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, !45_u32);
    }

    #[test]
    fn shr_test() {
        let a: BitVec<32> = 10395_u32.into();
        let c = a >> 4;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 10395_u32 >> 4);
    }

    #[test]
    fn shr_test_pair() {
        let a: BitVec<32> = 10395_u32.into();
        let b: BitVec<4> = 4_u32.into();
        let c = a >> b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 10395_u32 >> 4);
    }

    #[test]
    fn shl_test() {
        let a: BitVec<32> = 10395_u32.into();
        let c = a << 24;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 10395_u32 << 24);
    }

    #[test]
    fn shl_test_pair() {
        let a: BitVec<32> = 10395_u32.into();
        let b: BitVec<4> = 4_u32.into();
        let c = a << b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 10395_u32 << 4);
    }

    #[test]
    fn shift_by_width_or_more_clears() {
        let a: BitVec<32> = 10395_u32.into();
        let right: u32 = (a >> 32).into();
        let left: u32 = (a << 40).into();
        assert_eq!(right, 0);
        assert_eq!(left, 0);
    }

    #[test]
    fn add_works() {
        let a: BitVec<32> = 10234_u32.into();
        let b: BitVec<32> = 19423_u32.into();
        let c = a + b;
        let c_u32: u32 = c.into();
        assert_eq!(c_u32, 10234_u32 + 19423_u32);
    }

    #[test]
    fn add_works_with_overflow() {
        let x = 2_042_102_334_u32;
        let y = 2_942_142_512_u32;
        let a: BitVec<32> = x.into();
        let b: BitVec<32> = y.into();
        let c = a + b;
        let c_u32: u32 = c.into();
        assert_eq!(Wrapping(c_u32), Wrapping(x) + Wrapping(y));
    }

    #[test]
    fn sub_works() {
        let x = 2_042_102_334_u32;
        let y = 2_942_142_512_u32;
        let a: BitVec<32> = x.into();
        let b: BitVec<32> = y.into();
        let c = a - b;
        let c_u32: u32 = c.into();
        assert_eq!(Wrapping(c_u32), Wrapping(x) - Wrapping(y));
    }

    #[test]
    fn eq_works() {
        let x = 2_032_142_351_u32;
        let y = 2_942_142_512_u32;
        let a: BitVec<32> = x.into();
        let b: BitVec<32> = x.into();
        let c: BitVec<32> = y.into();
        assert_eq!(a, b);
        assert_ne!(a, c)
    }

    #[test]
    fn ordering_is_unsigned() {
        let small: BitVec<8> = 3_u8.into();
        let big: BitVec<8> = 200_u8.into();
        assert!(small < big);
        assert!(big > small);
        assert_eq!(small.partial_cmp(&small), Some(std::cmp::Ordering::Equal));
    }

    #[test]
    fn all_works() {
        let a: BitVec<48> = 0xFFFF_FFFF_FFFF_u64.into();
        assert!(a.all());
        assert!(a.any());
    }

    #[test]
    fn any_and_all_on_zero() {
        let z: BitVec<8> = 0_u8.into();
        assert!(!z.any());
        assert!(!z.all());
    }

    #[test]
    fn xor_is_parity() {
        let odd: BitVec<8> = 0b0000_0111_u8.into();
        let even: BitVec<8> = 0b0000_0101_u8.into();
        assert!(odd.xor());
        assert!(!even.xor());
    }

    #[test]
    fn replace_bit_and_get_bit() {
        let a: BitVec<8> = 0_u8.into();
        let b = a.replace_bit(3, true);
        assert!(!a.get_bit(3));
        assert!(b.get_bit(3));
        assert_eq!(b.to_u128(), 8);
    }

    #[test]
    fn resize_truncates_and_zero_extends() {
        let a: BitVec<8> = 0xA5_u8.into();
        let narrow: BitVec<4> = a.resize();
        let wide: BitVec<16> = a.resize();
        assert_eq!(u8::from(narrow), 0x5);
        assert_eq!(u16::from(wide), 0xA5);
    }

    #[test]
    fn signed_source_sign_extends() {
        let a: BitVec<16> = (-2_i8).into();
        let bits: u16 = a.into();
        assert_eq!(bits, 0xFFFE);
    }

    #[test]
    fn positive_signed_round_trip() {
        let a: BitVec<8> = 37_i8.into();
        let back: i8 = a.into();
        assert_eq!(back, 37);
    }
}