1use super::bits::Bits;
2use crate::bits::{bit_cast, LiteralType, LITERAL_BITS};
3use num_bigint::{BigInt, Sign};
4use num_traits::cast::ToPrimitive;
5use std::fmt::{Debug, Formatter, LowerHex, UpperHex};
6use std::num::Wrapping;
7
/// Native signed integer type accepted by the literal conversions into [`Signed`].
pub type SignedLiteralType = i64;
/// Width in bits of [`SignedLiteralType`], derived from the type itself.
pub const SIGNED_LITERAL_BITS: usize = 8 * std::mem::size_of::<SignedLiteralType>();
10
11#[derive(Clone, Debug, Copy, PartialEq, Default)]
12pub struct Signed<const N: usize>(Bits<N>);
13
14pub trait ToSignedBits {
15 fn to_signed_bits<const N: usize>(self) -> Signed<N>;
16}
17
18impl ToSignedBits for i8 {
19 fn to_signed_bits<const N: usize>(self) -> Signed<N> {
20 assert!(N <= 8);
21 (self as SignedLiteralType).into()
22 }
23}
24
25impl ToSignedBits for i16 {
26 fn to_signed_bits<const N: usize>(self) -> Signed<N> {
27 assert!(N <= 16);
28 (self as SignedLiteralType).into()
29 }
30}
31
32impl ToSignedBits for i32 {
33 fn to_signed_bits<const N: usize>(self) -> Signed<N> {
34 assert!(N <= 32);
35 (self as SignedLiteralType).into()
36 }
37}
38
39impl ToSignedBits for i64 {
40 fn to_signed_bits<const N: usize>(self) -> Signed<N> {
41 assert!(N <= 64);
42 (self as SignedLiteralType).into()
43 }
44}
45
46impl<const N: usize> LowerHex for Signed<N> {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 LowerHex::fmt(&self.bigint(), f)
49 }
50}
51
52impl<const N: usize> UpperHex for Signed<N> {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 UpperHex::fmt(&self.bigint(), f)
55 }
56}
57
58impl<const N: usize> Signed<N> {
59 pub fn min() -> BigInt {
60 let ret = -Self::max() - 1;
62 ret
63 }
64 pub fn max() -> BigInt {
65 BigInt::from(2).pow((N - 1) as u32) - 1
68 }
69 pub fn sign_bit(&self) -> bool {
70 self.0.get_bit(N - 1)
71 }
72 pub fn bigint(&self) -> BigInt {
73 let mut ret = BigInt::default();
74 if !self.sign_bit() {
75 for i in 0..N {
76 ret.set_bit(i as u64, self.get_bit(i))
77 }
78 ret
79 } else {
80 for i in 0..N {
81 ret.set_bit(i as u64, !self.get_bit(i))
82 }
83 -ret - 1
84 }
85 }
86 pub fn get_bit(&self, ndx: usize) -> bool {
87 self.0.get_bit(ndx)
88 }
89 pub fn get_bits<const M: usize>(&self, index: usize) -> Signed<M> {
90 Signed(self.0.get_bits::<M>(index))
91 }
92 pub fn inner(&self) -> Bits<N> {
93 self.0
94 }
95}
96
97impl<const N: usize> From<BigInt> for Signed<N> {
98 fn from(x: BigInt) -> Self {
99 assert!(x.bits() <= N as u64);
100 if N <= LITERAL_BITS {
101 if x.sign() == Sign::Minus {
102 -Signed(Bits::from((-x).to_u64().unwrap()))
103 } else {
104 Signed(Bits::from(x.to_u64().unwrap()))
105 }
106 } else {
107 if x.sign() == Sign::Minus {
108 -Signed(Bits::from((-x).to_biguint().unwrap()))
109 } else {
110 Signed(Bits::from(x.to_biguint().unwrap()))
111 }
112 }
113 }
114}
115
116impl<const N: usize> std::ops::Neg for Signed<N> {
117 type Output = Signed<N>;
118
119 fn neg(self) -> Self::Output {
120 Signed(match self.0 {
121 Bits::Short(x) => Bits::Short((Wrapping(0) - Wrapping(x.short())).0.into()),
122 Bits::Long(x) => {
123 let mut val = [false; N];
124 for ndx in 0..N {
125 val[ndx] = !x.get_bit(ndx);
126 }
127 Bits::Long(val.into()) + 1
128 }
129 })
130 }
131}
132
133impl<const N: usize> std::ops::Add<Signed<N>> for Signed<N> {
134 type Output = Signed<N>;
135
136 fn add(self, rhs: Signed<N>) -> Self::Output {
137 Self(self.0 + rhs.0)
138 }
139}
140
141impl<const N: usize> std::ops::Sub<Signed<N>> for Signed<N> {
142 type Output = Signed<N>;
143
144 fn sub(self, rhs: Signed<N>) -> Self::Output {
145 Self(self.0 - rhs.0)
146 }
147}
148
149impl std::ops::Mul<Signed<16>> for Signed<16> {
150 type Output = Signed<32>;
151
152 fn mul(self, rhs: Signed<16>) -> Self::Output {
153 Self::Output::from(self.bigint() * rhs.bigint())
154 }
155}
156
157impl<const N: usize> std::cmp::PartialOrd for Signed<N> {
158 fn partial_cmp(&self, other: &Signed<N>) -> Option<std::cmp::Ordering> {
159 self.bigint().partial_cmp(&other.bigint())
160 }
161}
162
163impl<const N: usize> From<SignedLiteralType> for Signed<N> {
164 fn from(x: SignedLiteralType) -> Self {
165 if x > 0 {
166 Self(Bits::from(x as LiteralType))
167 } else {
168 -Self(Bits::from((-x) as LiteralType))
169 }
170 }
171}
172
173pub fn signed<const N: usize>(x: SignedLiteralType) -> Signed<N> {
174 let t: Signed<N> = x.into();
175 t
176}
177
178pub fn signed_bit_cast<const M: usize, const N: usize>(x: Signed<N>) -> Signed<M> {
179 if x.sign_bit() {
180 -signed_bit_cast(-x)
181 } else {
182 Signed(bit_cast(x.0))
183 }
184}
185
186pub fn signed_cast<const N: usize>(x: Bits<N>) -> Signed<N> {
187 Signed(x)
188}
189
190pub fn unsigned_cast<const N: usize>(x: Signed<N>) -> Bits<N> {
191 x.0
192}
193
194pub fn unsigned_bit_cast<const M: usize, const N: usize>(x: Signed<N>) -> Bits<M> {
195 bit_cast(x.0)
196}
197
#[cfg(test)]
mod tests {
    use crate::bits::Bits;
    use crate::signed::{signed_bit_cast, unsigned_bit_cast, Signed};
    use num_bigint::BigInt;

    #[test]
    fn test_min_range_correct() {
        assert_eq!(Signed::<8>::min(), i8::MIN.into());
        assert_eq!(Signed::<16>::min(), i16::MIN.into());
        assert_eq!(Signed::<32>::min(), i32::MIN.into());
        assert_eq!(Signed::<64>::min(), i64::MIN.into());
        assert_eq!(Signed::<128>::min(), i128::MIN.into());
    }

    #[test]
    fn test_max_range_correct() {
        assert_eq!(Signed::<8>::max(), i8::MAX.into());
        assert_eq!(Signed::<16>::max(), i16::MAX.into());
        assert_eq!(Signed::<32>::max(), i32::MAX.into());
        assert_eq!(Signed::<64>::max(), i64::MAX.into());
        assert_eq!(Signed::<128>::max(), i128::MAX.into());
    }

    /// Walks every `skip`-th value of the `N`-bit signed range. For each value,
    /// it checks that the `BigInt -> Signed<N>` conversion yields the expected bits
    /// and that `bigint()` round-trips to the original value.
    fn run_import_tests<const N: usize>(skip: u32) {
        let mut q = Signed::<N>::min();
        while q <= Signed::<N>::max() {
            let x: Signed<N> = q.clone().into();
            for i in 0..N {
                // `BigInt::bit` reports two's-complement bits for negative values.
                assert_eq!(x.get_bit(i), q.bit(i as u64))
            }
            assert_eq!(x.bigint(), q);
            q += skip;
        }
    }

    #[test]
    fn test_signed_import_small() {
        run_import_tests::<5>(1);
    }

    #[test]
    fn test_signed_import_large() {
        run_import_tests::<34>(1 << 16);
    }

    // Timing benchmarks: they assert nothing and run 10M iterations each, so they
    // are excluded from the default test run. Use `cargo test -- --ignored` to run them.
    #[test]
    #[ignore]
    fn time_adds_bigint() {
        let start = std::time::Instant::now();
        for _iter in 0..10 {
            let mut q = BigInt::from(0_u32);
            for _i in 0..1_000_000 {
                q = q + 1;
            }
        }
        println!("Duration: {}", start.elapsed().as_micros());
    }

    #[test]
    #[ignore]
    fn time_adds_bitvec() {
        let start = std::time::Instant::now();
        for _iter in 0..10 {
            let mut q = Bits::<40>::from(0);
            for _i in 0..1_000_000 {
                q = q + 1;
            }
        }
        println!("Duration: {}", start.elapsed().as_micros());
    }

    #[test]
    #[ignore]
    fn time_adds_bitvec_small() {
        let start = std::time::Instant::now();
        for _iter in 0..10 {
            let mut q = Bits::<16>::from(0);
            for _i in 0..1_000_000 {
                q = q + 1;
            }
        }
        println!("Duration: {}", start.elapsed().as_micros());
    }

    #[test]
    fn signed_displays_correctly() {
        let x = Signed::<16>::from(-23);
        // Hex formatting goes through `BigInt`, so the sign is printed explicitly
        // instead of the raw 16-bit pattern (0xffe9).
        assert_eq!(format!("{:x}", x), "-17");
        assert_eq!(format!("{:X}", x), "-17");
    }

    #[test]
    fn test_signed_cast() {
        let x = Signed::<16>::from(-23);
        let y: Signed<40> = signed_bit_cast(x);
        assert_eq!(y, Signed::<40>::from(-23));
    }

    #[test]
    fn test_unsigned_cast() {
        let x = Signed::<16>::from(-23);
        let y: Bits<16> = unsigned_bit_cast(x);
        assert_eq!(y, Bits::<16>::from(0xFFe9))
    }

    #[test]
    fn test_neg_operator() {
        let x = Signed::<16>::from(23);
        assert_eq!(x.bigint(), BigInt::from(23));
        let x = -x;
        assert_eq!(x.bigint(), BigInt::from(-23));
        let x = -x;
        assert_eq!(x.bigint(), BigInt::from(23));
    }

    #[test]
    fn test_neg_operator_larger() {
        let t: BigInt = BigInt::from(23) << 32;
        let x = Signed::<48>::from(t.clone());
        assert_eq!(x.bigint(), t.clone());
        let x = -x;
        assert_eq!(x.bigint(), -t.clone());
        let x = -x;
        assert_eq!(x.bigint(), t.clone());
    }
}
323}