smallint/
ops.rs

1use crate::smallint::{SmallIntType, SmallUintType};
2use crate::{SmallInt, SmallUint};
3use core::mem::ManuallyDrop;
4use core::ops::{Add, Mul, Neg, Sub};
5
6impl Neg for SmallInt {
7    type Output = Self;
8
9    fn neg(self) -> Self::Output {
10        match self.0 {
11            SmallIntType::Inline(i) => SmallInt(SmallIntType::Inline(-i)),
12            SmallIntType::Heap((r, s)) => {
13                let size = usize::try_from(s.abs()).unwrap();
14                let slice = unsafe { core::slice::from_raw_parts(r, size) };
15                let mut ret = vec![0; size];
16                ret.clone_from_slice(slice);
17                let mut val = ManuallyDrop::new(ret.into_boxed_slice());
18                SmallInt(SmallIntType::Heap((val.as_mut_ptr(), -s)))
19            }
20        }
21    }
22}
23
/// Implements the binary operator trait `$imp` (method `$lower`) for all four
/// owned/borrowed combinations of `$typ` operands.
///
/// Every combination forwards to the free function
/// `$fun(&$typ, &$typ) -> $typ`, borrowing whichever operands are owned.
macro_rules! basic_op {
    ($imp:ident, $lower:ident, $typ:ty, $fun:ident) => {
        // &lhs op &rhs
        impl<'l, 'r> $imp<&'r $typ> for &'l $typ {
            type Output = $typ;

            fn $lower(self, rhs: &'r $typ) -> $typ {
                $fun(self, rhs)
            }
        }

        // &lhs op rhs
        impl<'l> $imp<$typ> for &'l $typ {
            type Output = $typ;

            fn $lower(self, rhs: $typ) -> $typ {
                $fun(self, &rhs)
            }
        }

        // lhs op &rhs
        impl<'r> $imp<&'r $typ> for $typ {
            type Output = $typ;

            fn $lower(self, rhs: &'r $typ) -> $typ {
                $fun(&self, rhs)
            }
        }

        // lhs op rhs
        impl $imp for $typ {
            type Output = $typ;

            fn $lower(self, rhs: $typ) -> $typ {
                $fun(&self, &rhs)
            }
        }
    };
}
59
/// Adds two little-endian base-2^32 magnitudes and returns their sum.
///
/// The operands may differ in length; missing high limbs count as zero.
/// High zero limbs are trimmed from the result, but at least one limb is kept
/// (zero is `[0]`), except when both inputs are empty, which yields an empty
/// vector instead of panicking.
fn add_two_slices(slice1: &[u32], slice2: &[u32]) -> Vec<u32> {
    let len = slice1.len().max(slice2.len());
    // One extra slot for a possible carry out of the top limb.
    let mut res = Vec::with_capacity(len + 1);
    let mut carry = false;

    for i in 0..len {
        let x = slice1.get(i).copied().unwrap_or(0);
        let y = slice2.get(i).copied().unwrap_or(0);
        let (partial, c1) = x.overflowing_add(y);
        let (sum, c2) = partial.overflowing_add(u32::from(carry));
        res.push(sum);
        carry = c1 || c2;
    }

    if carry {
        res.push(1);
    }

    // Drop high zero limbs but never the last one; `len() > 1` also keeps
    // this from indexing into an empty vector.
    while res.len() > 1 && res.last() == Some(&0) {
        res.pop();
    }
    res
}
87
88fn add(a: &SmallUint, b: &SmallUint) -> SmallUint {
89    match (&a.0, &b.0) {
90        (&SmallUintType::Inline(i), &SmallUintType::Inline(j)) => match i.overflowing_add(j) {
91            (t, false) => SmallUint(SmallUintType::Inline(t)),
92            (t, true) => {
93                let mut res = [0, 0, 0, 0, 1];
94
95                let mut v = t;
96                #[allow(clippy::needless_range_loop)]
97                for r in 0..4 {
98                    res[r] = v as u32;
99
100                    v >>= 32;
101                }
102
103                let mut slice = ManuallyDrop::new(<Box<[u32]>>::from(res));
104
105                SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), 5)))
106            }
107        },
108        (&SmallUintType::Heap((r, s)), &SmallUintType::Inline(i))
109        | (&SmallUintType::Inline(i), &SmallUintType::Heap((r, s))) => {
110            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
111
112            let mut res = [0, 0, 0, 0];
113
114            let mut v = i;
115            #[allow(clippy::needless_range_loop)]
116            for r in 0..4 {
117                res[r] = v as u32;
118
119                v >>= 32;
120            }
121
122            let result = add_two_slices(slice1, &res[..]);
123            let size = result.len();
124
125            let mut slice = ManuallyDrop::new(result.into_boxed_slice());
126
127            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
128        }
129        (&SmallUintType::Heap((r, s)), &SmallUintType::Heap((i, j))) => {
130            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
131            let slice2 = unsafe { core::slice::from_raw_parts(i, j) };
132
133            let res = add_two_slices(slice1, slice2);
134            let size = res.len();
135
136            let mut slice = ManuallyDrop::new(res.into_boxed_slice());
137
138            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
139        }
140    }
141}
142
143basic_op!(Add, add, SmallUint, add);
144
145fn add_signed(a: &SmallInt, b: &SmallInt) -> SmallInt {
146    let a_sign;
147    match &a.0 {
148        SmallIntType::Inline(i) => a_sign = i.signum() as i8,
149        SmallIntType::Heap((_, s)) => a_sign = s.signum() as i8,
150    }
151
152    let b_sign;
153    match &b.0 {
154        SmallIntType::Inline(i) => b_sign = i.signum() as i8,
155        SmallIntType::Heap((_, s)) => b_sign = s.signum() as i8,
156    }
157
158    match (a_sign, b_sign) {
159        x if (x.0 >= 0 && x.1 >= 0) => SmallInt::from(
160            SmallUint::from_smallint_unsigned(a.clone())
161                + SmallUint::from_smallint_unsigned(b.clone()),
162        ),
163        x if (x.0 < 0 && x.1 < 0) => -SmallInt::from(
164            SmallUint::from_smallint_unsigned(a.clone())
165                + SmallUint::from_smallint_unsigned(b.clone()),
166        ),
167
168        x if (x.0 >= 0 && x.1 < 0) => {
169            let s = SmallUint::from_smallint_unsigned(a.clone());
170            let b = SmallUint::from_smallint_unsigned(b.clone());
171            if b <= s {
172                SmallInt::from(s - b)
173            } else {
174                -SmallInt::from(b - s)
175            }
176        }
177
178        x if (x.0 < 0 && x.1 >= 0) => {
179            let s = SmallUint::from_smallint_unsigned(a.clone());
180            let b = SmallUint::from_smallint_unsigned(b.clone());
181            if s <= b {
182                SmallInt::from(b - s)
183            } else {
184                -SmallInt::from(s - b)
185            }
186        }
187        (_, _) => {
188            panic!("This shouldn't happen. ");
189        }
190    }
191}
192
193basic_op!(Add, add, SmallInt, add_signed);
194
/// Subtracts the little-endian base-2^32 magnitude `slice2` from `slice1`.
///
/// The operands may differ in length; missing limbs count as zero, so high zero
/// limbs on `slice2` do not cause a spurious failure. The result always has
/// exactly `slice1.len()` limbs and is not trimmed.
///
/// # Panics
/// Panics if `slice2` is numerically greater than `slice1`.
fn sub_two_slices(slice1: &[u32], slice2: &[u32]) -> Vec<u32> {
    let len = slice1.len().max(slice2.len());
    let mut res = Vec::with_capacity(len);
    let mut borrow = false;

    for i in 0..len {
        let minuend = slice1.get(i).copied().unwrap_or(0);
        let subtrahend = slice2.get(i).copied().unwrap_or(0);
        // At most one of these can underflow, so the borrow is 0 or 1.
        let (diff, under1) = minuend.overflowing_sub(subtrahend);
        let (diff, under2) = diff.overflowing_sub(u32::from(borrow));
        res.push(diff);
        borrow = under1 || under2;
    }

    // A borrow out of the top limb means slice2 > slice1.
    if borrow {
        panic!("First number is smaller than second.");
    }

    // Without a final borrow the difference is < 2^(32 * slice1.len()), so any
    // limbs past slice1's length are zero. Dropping them keeps the result the
    // same length as slice1.
    res.truncate(slice1.len());
    res
}
231
232fn sub(a: &SmallUint, b: &SmallUint) -> SmallUint {
233    match (&a.0, &b.0) {
234        (&SmallUintType::Inline(i), &SmallUintType::Inline(j)) => {
235            if let (t, false) = i.overflowing_sub(j) {
236                SmallUint(SmallUintType::Inline(t))
237            } else {
238                panic!("First number is smaller than second. ");
239            }
240        }
241        (&SmallUintType::Heap((r, s)), &SmallUintType::Inline(i)) => {
242            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
243
244            let mut res = [0, 0, 0, 0];
245
246            let mut v = i;
247            #[allow(clippy::needless_range_loop)]
248            for r in 0..4 {
249                res[r] = v as u32;
250
251                v >>= 32;
252            }
253
254            let result = sub_two_slices(slice1, &res[..]);
255            let size = result.len();
256
257            let mut slice = ManuallyDrop::new(result.into_boxed_slice());
258
259            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
260        }
261        (&SmallUintType::Inline(_), &SmallUintType::Heap((_, _))) => {
262            panic!("First number is smaller than second. ");
263        }
264        (&SmallUintType::Heap((r, s)), &SmallUintType::Heap((i, j))) => {
265            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
266            let slice2 = unsafe { core::slice::from_raw_parts(i, j) };
267
268            let res = sub_two_slices(slice1, slice2);
269            let size = res.len();
270
271            let mut slice = ManuallyDrop::new(res.into_boxed_slice());
272
273            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
274        }
275    }
276}
277
278basic_op!(Sub, sub, SmallUint, sub);
279
280fn sub_signed(a: &SmallInt, b: &SmallInt) -> SmallInt {
281    a + (-b.clone())
282}
283
284basic_op!(Sub, sub, SmallInt, sub_signed);
285
// Taken from https://github.com/rust-lang/rust/issues/85532#issuecomment-916309635. Credit to
// AaronKutch.
/// Computes `lhs * rhs + carry` as a 256-bit value and returns it as
/// `(low 128 bits, high 128 bits)`.
///
/// Each operand is split into 64-bit halves. The four partial products are
/// exact in `u128`, and the two cross terms straddle the 128-bit boundary.
/// The full result never exceeds 2^256 - 2^128, so the high half cannot
/// overflow.
const fn carrying_mul_u128(lhs: u128, rhs: u128, carry: u128) -> (u128, u128) {
    const LOW_64: u128 = u64::MAX as u128;

    let (a_lo, a_hi) = (lhs & LOW_64, lhs >> 64);
    let (b_lo, b_hi) = (rhs & LOW_64, rhs >> 64);

    // 64x64 -> 128-bit partial products; none of these can overflow.
    let lo_lo = a_lo * b_lo;
    let lo_hi = a_lo * b_hi;
    let hi_lo = a_hi * b_lo;
    let hi_hi = a_hi * b_hi;

    // Low half: the low parts of the cross terms plus the incoming carry,
    // tracking the three possible carries into the high half.
    let (low, c0) = lo_lo.overflowing_add(lo_hi << 64);
    let (low, c1) = low.overflowing_add(hi_lo << 64);
    let (low, c2) = low.overflowing_add(carry);

    let high = hi_hi
        + (lo_hi >> 64)
        + (hi_lo >> 64)
        + c0 as u128
        + c1 as u128
        + c2 as u128;

    (low, high)
}
321
322fn mul_two_slices(slice1: &[u32], slice2: &[u32]) -> Vec<u32> {
323    // https://en.wikipedia.org/wiki/Karatsuba_algorithm
324
325    let l1 = slice1.len();
326    let l2 = slice2.len();
327
328    if l1 == 0 || l2 == 0 {
329        return vec![];
330    } else if l1 == 1 {
331        let mut overflow = 0;
332        let mut res: Vec<u32> = Vec::with_capacity(l2 + 1);
333
334        #[allow(clippy::needless_range_loop)]
335        for i in 0..l2 {
336            let mut r = (slice2[i] as u64) * (slice1[0] as u64);
337            r += overflow as u64;
338            let m = r as u32;
339            overflow = (r >> 32) as u32;
340            res.push(m);
341        }
342
343        if overflow != 0 {
344            res.push(overflow);
345        }
346
347        return res;
348    } else if l2 == 1 {
349        let mut overflow = 0;
350        let mut res: Vec<u32> = Vec::with_capacity(l2 + 1);
351
352        #[allow(clippy::needless_range_loop)]
353        for i in 0..l1 {
354            let mut r = (slice1[i] as u64) * (slice2[0] as u64);
355            r += overflow as u64;
356            let m = r as u32;
357            overflow = (r >> 32) as u32;
358            res.push(m);
359        }
360
361        if overflow != 0 {
362            res.push(overflow);
363        }
364
365        return res;
366    }
367
368    let m = std::cmp::min(l1, l2);
369    let m2 = (m as u32) / 2;
370
371    let (low1, high1) = slice1.split_at(m2 as usize);
372    let (low2, high2) = slice2.split_at(m2 as usize);
373
374    let z0 = mul_two_slices(low1, low2);
375    let z1 = mul_two_slices(&add_two_slices(low1, high1), &add_two_slices(low2, high2));
376    let z2 = mul_two_slices(high1, high2);
377
378    let mut op0 = z2.clone();
379
380    op0.reverse();
381
382    op0.resize(op0.len() + (m2 as usize * 2), 0);
383
384    op0.reverse();
385
386    let mut op1 = sub_two_slices(&sub_two_slices(&z1, &z2), &z0);
387
388    op1.reverse();
389
390    op1.resize(op1.len() + (m2 as usize), 0);
391
392    op1.reverse();
393
394    add_two_slices(&add_two_slices(&op0, &op1), &z0)
395}
396
397fn mul(a: &SmallUint, b: &SmallUint) -> SmallUint {
398    match (&a.0, &b.0) {
399        (&SmallUintType::Inline(i), &SmallUintType::Inline(j)) => {
400            match carrying_mul_u128(i, j, 0) {
401                (t, 0) => SmallUint(SmallUintType::Inline(t)),
402                (t, o) => {
403                    let mut res = Vec::with_capacity(8);
404
405                    let mut v = t;
406                    #[allow(clippy::needless_range_loop)]
407                    for _ in 0..4 {
408                        res.push(v as u32);
409
410                        v >>= 32;
411                    }
412
413                    let mut v = o;
414                    for _ in 4..8 {
415                        res.push(v as u32);
416
417                        v >>= 32;
418                    }
419
420                    while res[res.len() - 1] == 0 {
421                        res.pop();
422                    }
423
424                    let size = res.len();
425
426                    let mut slice = ManuallyDrop::new(res.into_boxed_slice());
427
428                    SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
429                }
430            }
431        }
432
433        (&SmallUintType::Heap((r, s)), &SmallUintType::Inline(i))
434        | (&SmallUintType::Inline(i), &SmallUintType::Heap((r, s))) => {
435            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
436
437            let mut res = [0, 0, 0, 0];
438
439            let mut v = i;
440            #[allow(clippy::needless_range_loop)]
441            for r in 0..4 {
442                res[r] = v as u32;
443
444                v >>= 32;
445            }
446
447            let result = mul_two_slices(slice1, &res[..]);
448            let size = result.len();
449
450            let mut slice = ManuallyDrop::new(result.into_boxed_slice());
451
452            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
453        }
454
455        (&SmallUintType::Heap((r, s)), &SmallUintType::Heap((i, j))) => {
456            let slice1 = unsafe { core::slice::from_raw_parts(r, s) };
457            let slice2 = unsafe { core::slice::from_raw_parts(i, j) };
458
459            let res = mul_two_slices(slice1, slice2);
460            let size = res.len();
461
462            let mut slice = ManuallyDrop::new(res.into_boxed_slice());
463
464            SmallUint(SmallUintType::Heap((slice.as_mut_ptr(), size)))
465        }
466    }
467}
468
469basic_op!(Mul, mul, SmallUint, mul);
470
471fn mul_signed(a: &SmallInt, b: &SmallInt) -> SmallInt {
472    let a_sign;
473    match &a.0 {
474        SmallIntType::Inline(i) => a_sign = i.signum() as i8,
475        SmallIntType::Heap((_, s)) => a_sign = s.signum() as i8,
476    }
477
478    let b_sign;
479    match &b.0 {
480        SmallIntType::Inline(i) => b_sign = i.signum() as i8,
481        SmallIntType::Heap((_, s)) => b_sign = s.signum() as i8,
482    }
483
484    match (a_sign, b_sign) {
485        x if (x.0 >= 0 && x.1 >= 0) => SmallInt::from(
486            SmallUint::from_smallint_unsigned(a.clone())
487                * SmallUint::from_smallint_unsigned(b.clone()),
488        ),
489        x if (x.0 < 0 && x.1 < 0) => SmallInt::from(
490            SmallUint::from_smallint_unsigned(a.clone())
491                * SmallUint::from_smallint_unsigned(b.clone()),
492        ),
493
494        x if (x.0 >= 0 && x.1 < 0) => -SmallInt::from(
495            SmallUint::from_smallint_unsigned(a.clone())
496                * SmallUint::from_smallint_unsigned(b.clone()),
497        ),
498
499        x if (x.0 < 0 && x.1 >= 0) => -SmallInt::from(
500            SmallUint::from_smallint_unsigned(a.clone())
501                * SmallUint::from_smallint_unsigned(b.clone()),
502        ),
503
504        (_, _) => {
505            panic!("This shouldn't happen. ");
506        }
507    }
508}
509
510basic_op!(Mul, mul, SmallInt, mul_signed);