rust_hdl_lib_core/
bitvec.rs

1use crate::bits::LiteralType;
2
/// A fixed-width vector of `N` bits.
///
/// Index `0` holds the least-significant bit: the integer conversions in this
/// module store bit `i` of the source value at index `i`.
// `Eq` is derived alongside `Hash`: hashed collections (`HashMap`, `HashSet`)
// require `Eq + Hash`, so deriving `Hash` alone left the type unusable as a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitVec<const N: usize> {
    bits: [bool; N],
}
7
8impl<const N: usize> From<[bool; N]> for BitVec<N> {
9    fn from(x: [bool; N]) -> Self {
10        BitVec { bits: x }
11    }
12}
13
14impl<const N: usize> BitVec<N> {
15    pub fn to_u128(&self) -> u128 {
16        assert!(N <= 128);
17        let mut ret = 0_u128;
18        for i in 0..N {
19            if self.bits[N - 1 - i] {
20                ret |= 1 << (N - 1 - i);
21            }
22        }
23        ret
24    }
25
26    pub fn all(&self) -> bool {
27        for i in 0..N {
28            if !self.bits[i] {
29                return false;
30            }
31        }
32        true
33    }
34
35    pub fn any(&self) -> bool {
36        for i in 0..N {
37            if self.bits[i] {
38                return true;
39            }
40        }
41        false
42    }
43
44    pub fn xor(&self) -> bool {
45        let mut ret = false;
46        for i in 0..N {
47            ret ^= self.bits[i];
48        }
49        ret
50    }
51
52    pub fn get_bit(&self, ndx: usize) -> bool {
53        assert!(ndx < N);
54        self.bits[ndx]
55    }
56
57    pub fn replace_bit(&self, ndx: usize, val: bool) -> BitVec<N> {
58        let mut t = self.bits;
59        t[ndx] = val;
60        BitVec { bits: t }
61    }
62
63    pub fn resize<const M: usize>(&self) -> BitVec<M> {
64        let mut t = [false; M];
65        (0..M.min(N)).for_each(|i| {
66            t[i] = self.bits[i];
67        });
68        BitVec { bits: t }
69    }
70}
71
72impl<const N: usize> std::ops::Shr<LiteralType> for BitVec<N> {
73    type Output = BitVec<N>;
74
75    fn shr(self, rhs: LiteralType) -> Self::Output {
76        let rhs = rhs as usize;
77        let mut bits = [false; N];
78        (rhs..N).for_each(|i| {
79            bits[i - rhs] = self.bits[i];
80        });
81        Self { bits }
82    }
83}
84
85impl<const N: usize, const M: usize> std::ops::Shr<BitVec<M>> for BitVec<N> {
86    type Output = BitVec<N>;
87
88    fn shr(self, rhs: BitVec<M>) -> Self::Output {
89        let rhs: LiteralType = rhs.into();
90        self >> rhs
91    }
92}
93
94impl<const N: usize> std::ops::Shl<LiteralType> for BitVec<N> {
95    type Output = BitVec<N>;
96
97    fn shl(self, rhs: LiteralType) -> Self::Output {
98        let rhs = rhs as usize;
99        let mut bits = [false; N];
100        (rhs..N).for_each(|i| {
101            bits[i] = self.bits[i - rhs];
102        });
103        Self { bits }
104    }
105}
106
107impl<const N: usize, const M: usize> std::ops::Shl<BitVec<M>> for BitVec<N> {
108    type Output = BitVec<N>;
109
110    fn shl(self, rhs: BitVec<M>) -> Self::Output {
111        let rhs: LiteralType = rhs.into();
112        self << rhs
113    }
114}
115
116impl<const N: usize> std::ops::BitOr for BitVec<N> {
117    type Output = BitVec<N>;
118
119    fn bitor(self, rhs: Self) -> Self::Output {
120        self.binop(&rhs, |a, b| a | b)
121    }
122}
123
124impl<const N: usize> std::ops::BitAnd for BitVec<N> {
125    type Output = BitVec<N>;
126
127    fn bitand(self, rhs: Self) -> Self::Output {
128        self.binop(&rhs, |a, b| a & b)
129    }
130}
131
132impl<const N: usize> std::ops::BitXor for BitVec<N> {
133    type Output = BitVec<N>;
134
135    fn bitxor(self, rhs: Self) -> Self::Output {
136        self.binop(&rhs, |a, b| a ^ b)
137    }
138}
139
140impl<const N: usize> std::ops::Not for BitVec<N> {
141    type Output = BitVec<N>;
142
143    fn not(self) -> Self::Output {
144        let mut bits = [false; N];
145        (0..N).for_each(|i| {
146            bits[i] = !self.bits[i];
147        });
148        Self { bits }
149    }
150}
151
152// Add with cary
153// A  B  C  | X  Co | A ^ (B ^ C) | A B + B C-IN + A C-IN
154// 0  0  0  | 0  0  | 0           | 0
155// 0  0  1  | 1  0  | 1           | 0
156// 0  1  0  | 1  0  | 1           | 0
157// 0  1  1  | 0  1  | 0           | 1
158// 1  0  0  | 1  0  | 1           | 0
159// 1  0  1  | 0  1  | 0           | 1
160// 1  1  0  | 0  1  | 0           | 1
161// 1  1  1  | 1  1  | 1           | 1
162impl<const N: usize> std::ops::Add<BitVec<N>> for BitVec<N> {
163    type Output = BitVec<N>;
164
165    fn add(self, rhs: BitVec<N>) -> Self::Output {
166        let mut carry = false;
167        let mut bits = [false; N];
168        (0..N).for_each(|i| {
169            let a = self.bits[i];
170            let b = rhs.bits[i];
171            let c_i = carry;
172            bits[i] = a ^ b ^ c_i;
173            carry = (a & b) | (b & c_i) | (a & c_i);
174        });
175        Self { bits }
176    }
177}
178
179impl<const N: usize> std::ops::Sub<BitVec<N>> for BitVec<N> {
180    type Output = BitVec<N>;
181
182    fn sub(self, rhs: BitVec<N>) -> Self::Output {
183        self + !rhs + 1_u32.into()
184    }
185}
186
187impl<const N: usize> std::cmp::PartialOrd for BitVec<N> {
188    fn partial_cmp(&self, other: &BitVec<N>) -> Option<std::cmp::Ordering> {
189        for i in 0..N {
190            let a = self.bits[N - 1 - i];
191            let b = other.bits[N - 1 - i];
192            if a & !b {
193                return Some(std::cmp::Ordering::Greater);
194            }
195            if !a & b {
196                return Some(std::cmp::Ordering::Less);
197            }
198        }
199        Some(std::cmp::Ordering::Equal)
200    }
201}
202
203impl<const N: usize> BitVec<N> {
204    fn binop<T>(&self, rhs: &Self, op: T) -> Self
205    where
206        T: Fn(&bool, &bool) -> bool,
207    {
208        let mut bits = [false; N];
209        (0..N).for_each(|i| {
210            bits[i] = op(&self.bits[i], &rhs.bits[i]);
211        });
212        Self { bits }
213    }
214}
215
macro_rules! define_vec_from_uint {
    ($name:ident) => {
        impl<const N: usize> From<$name> for BitVec<N> {
            /// Takes the low `N` bits of `value`, with bit `i` of the integer
            /// stored at index `i`.
            ///
            /// Because `>>` on a signed integer is arithmetic, a negative source
            /// keeps supplying ones once its own bits run out. This sign-extends
            /// when `N` exceeds the integer's width.
            fn from(value: $name) -> Self {
                let mut remaining = value;
                let mut bits = [false; N];
                for bit in bits.iter_mut() {
                    *bit = (remaining & 1) == 1;
                    remaining >>= 1;
                }
                Self { bits }
            }
        }
    };
}
230
231define_vec_from_uint!(u8);
232define_vec_from_uint!(u16);
233define_vec_from_uint!(u32);
234define_vec_from_uint!(u64);
235define_vec_from_uint!(u128);
236define_vec_from_uint!(usize);
237define_vec_from_uint!(i8);
238define_vec_from_uint!(i16);
239define_vec_from_uint!(i32);
240define_vec_from_uint!(i64);
241define_vec_from_uint!(i128);
242
macro_rules! define_uint_from_vec {
    // `$width` is accepted for symmetry with `define_int_from_vec!` but is not
    // used here: oversized vectors are truncated rather than rejected.
    ($name:ident, $width: expr) => {
        impl<const N: usize> From<BitVec<N>> for $name {
            /// Reads the vector as an unsigned value, with index `0` as the
            /// least-significant bit.
            ///
            /// If `N` exceeds the width of the target type, only the low bits
            /// survive; the rest are shifted out.
            fn from(t: BitVec<N>) -> Self {
                // Walk from the most significant bit down, shifting each one in on the right.
                t.bits
                    .iter()
                    .rev()
                    .fold(0, |acc: $name, &bit| (acc << 1) | $name::from(bit))
            }
        }
    };
}
257
macro_rules! define_int_from_vec {
    ($name: ident, $width: expr) => {
        impl<const N: usize> From<BitVec<N>> for $name {
            /// Interprets the vector as an `N`-bit two's-complement value
            /// (index `N - 1` is the sign bit) and sign-extends it to the full
            /// width of the target type. An empty vector (`N == 0`) converts to `0`.
            ///
            /// # Panics
            ///
            /// Panics if `N` exceeds `$width`, the bit width of the target type.
            fn from(t: BitVec<N>) -> Self {
                assert!(N <= $width);
                // Seed the accumulator with the sign bit replicated into every
                // position (`-1` or `0`). Shifting the N payload bits in from the
                // right leaves copies of the sign in the upper bits, which is sign
                // extension. The `N > 0` guard avoids indexing `bits[N - 1]` on an
                // empty vector.
                let negative = N > 0 && t.bits[N - 1];
                let mut x: $name = if negative { -1 } else { 0 };
                for &bit in t.bits.iter().rev() {
                    // `<< 1` never panics; bits pushed past the top are discarded.
                    x = (x << 1) | $name::from(bit);
                }
                x
            }
        }
    };
}
281
282define_uint_from_vec!(u8, 8);
283define_uint_from_vec!(u16, 16);
284define_uint_from_vec!(u32, 32);
285define_uint_from_vec!(u64, 64);
286define_uint_from_vec!(u128, 128);
287#[cfg(target_pointer_width = "64")]
288define_uint_from_vec!(usize, 64);
289#[cfg(target_pointer_width = "32")]
290define_uint_from_vec!(usize, 32);
291
292define_int_from_vec!(i8, 8);
293define_int_from_vec!(i16, 16);
294define_int_from_vec!(i32, 32);
295define_int_from_vec!(i64, 64);
296define_int_from_vec!(i128, 128);
297
#[cfg(test)]
mod tests {
    use super::BitVec;

    /// Builds a 32-bit vector from a `u32`.
    fn bv32(value: u32) -> BitVec<32> {
        BitVec::from(value)
    }

    /// Reads a 32-bit vector back as a `u32`.
    fn to_u32(value: BitVec<32>) -> u32 {
        u32::from(value)
    }

    #[test]
    fn or_test() {
        assert_eq!(to_u32(bv32(45) | bv32(10395)), 45_u32 | 10395_u32)
    }
    #[test]
    fn and_test() {
        assert_eq!(to_u32(bv32(45) & bv32(10395)), 45_u32 & 10395_u32)
    }
    #[test]
    fn xor_test() {
        assert_eq!(to_u32(bv32(45) ^ bv32(10395)), 45_u32 ^ 10395_u32)
    }
    #[test]
    fn not_test() {
        assert_eq!(to_u32(!bv32(45)), !45_u32);
    }
    #[test]
    fn shr_test() {
        assert_eq!(to_u32(bv32(10395) >> 4), 10395_u32 >> 4);
    }
    #[test]
    fn shr_test_pair() {
        let amount: BitVec<4> = 4_u32.into();
        assert_eq!(to_u32(bv32(10395) >> amount), 10395_u32 >> 4);
    }
    #[test]
    fn shl_test() {
        assert_eq!(to_u32(bv32(10395) << 24), 10395_u32 << 24);
    }
    #[test]
    fn shl_test_pair() {
        let amount: BitVec<4> = 4_u32.into();
        assert_eq!(to_u32(bv32(10395) << amount), 10395_u32 << 4);
    }
    #[test]
    fn add_works() {
        assert_eq!(to_u32(bv32(10234) + bv32(19423)), 10234_u32 + 19423_u32);
    }
    #[test]
    fn add_works_with_overflow() {
        let (x, y) = (2_042_102_334_u32, 2_942_142_512_u32);
        assert_eq!(to_u32(bv32(x) + bv32(y)), x.wrapping_add(y));
    }
    #[test]
    fn sub_works() {
        let (x, y) = (2_042_102_334_u32, 2_942_142_512_u32);
        assert_eq!(to_u32(bv32(x) - bv32(y)), x.wrapping_sub(y));
    }
    #[test]
    fn eq_works() {
        let (x, y) = (2_032_142_351_u32, 2_942_142_512_u32);
        assert_eq!(bv32(x), bv32(x));
        assert_ne!(bv32(x), bv32(y))
    }
    #[test]
    fn all_works() {
        let ones: BitVec<48> = 0xFFFF_FFFF_FFFF_u64.into();
        assert!(ones.all());
        assert!(ones.any());
    }
}