sunscreen/types/bfv/
fractional.rs

1use seal_fhe::Plaintext as SealPlaintext;
2
3use crate::{
4    fhe::{with_fhe_ctx, FheContextOps},
5    types::{
6        ops::{
7            GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
8            GraphCipherConstSub, GraphCipherInsert, GraphCipherMul, GraphCipherNeg,
9            GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub,
10            GraphConstCipherSub, GraphPlainCipherSub,
11        },
12        Cipher,
13    },
14};
15use crate::{
16    types::{intern::FheProgramNode, BfvType, FheType, Type, Version},
17    FheProgramInputTrait, Params, WithContext,
18};
19
20use sunscreen_runtime::{
21    InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName,
22    TypeNameInstance,
23};
24
25use std::ops::*;
26
/**
 * A quasi fixed-point representation capable of storing values with
 * both integer and fractional components.
 *
 * # Remarks
 * This type is capable of addition, subtraction, and multiplication with no
 * more overhead than the [`Signed`](crate::types::bfv::Signed) type.
 * That is, addition and multiplication each take exactly one operation.
 *
 * ## Representation
 * Recall that in BFV, the plaintext consists of a polynomial with
 * `poly_degree` terms. `poly_degree` is a BFV scheme parameter that (by
 * default) Sunscreen assigns for you depending on your FHE program's noise
 * requirements.
 *
 * This type represents values with both an integer and fractional component.
 * Semantically, you can think of this as a fixed-point value, but the
 * implementation is somewhat different. The generic argument `INT_BITS`
 * defines how many bits are reserved for the integer portion and the
 * remaining `poly_degree - INT_BITS` bits store the fraction.
 *
 * Internally, this has a fairly funky representation that differs from
 * traditional fixed-point. These variations allow the type to function
 * properly under addition and multiplication in the absence of carries
 * without needing to shift the binary point after multiply operations.
 *
 * Each binary digit of the number maps to a single coefficient in the
 * polynomial. The integer digits map to the low order plaintext polynomial
 * coefficients with the following relation:
 *
 * ```text
 * int(x) = sum_{i=0..INT_BITS}(c_i * 2^i)
 * ```
 *
 * where `c_i` is the coefficient for the `x^i` term of the polynomial.
 *
 * Then, the fractional parts follow:
 *
 * ```text
 * frac(x) = sum_{i=INT_BITS..N}(-c_i * 2^(i-N))
 * ```
 *
 * where `N` is the `poly_degree`. That is, coefficient `N-1` has weight
 * `2^-1`, coefficient `N-2` has weight `2^-2`, and so on.
 *
 * Note that the signs of the polynomial coefficients for fractional terms are
 * inverted. The entire value is simply `int(x) + frac(x)`.
 *
 * For example:
 * `5.8125 =`
 *
 * | Coefficient index | 0   | 1   | 2   | ... | N-4  | N-3  | N-2  | N-1  |
 * |-------------------|-----|-----|-----|-----|------|------|------|------|
 * | Weight            | 2^0 | 2^1 | 2^2 | ... | 2^-4 | 2^-3 | 2^-2 | 2^-1 |
 * | Value             | 1   | 0   | 1   | ... | -1   | 0    | -1   | -1   |
 *
 * Negative values encode every digit as negative, where a negative
 * coefficient is any value from `(plain_modulus + 1) / 2` (rounded down) up
 * to `plain_modulus - 1`. The former is the most negative value, while the
 * latter is the value `-1`. This is analogous to how 2's complement
 * defines values at or above `0x80..00` to be negative with `0x80..00`
 * being `INT_MIN` and `0xFF..FF` being `-1`.
 *
 * For example, if plain modulus is `14`, the value `-1` encodes as the
 * unsigned value `13`, `-6` encodes as `8`, `-7` encodes as `7`, and the
 * values `0..=6` are simply `0..=6` respectively.
 *
 * A full example of encoding a negative value:
 * `-5.8125 =`
 *
 * | Coefficient index | 0   | 1   | 2   | ... | N-4  | N-3  | N-2  | N-1  |
 * |-------------------|-----|-----|-----|-----|------|------|------|------|
 * | Weight            | 2^0 | 2^1 | 2^2 | ... | 2^-4 | 2^-3 | 2^-2 | 2^-1 |
 * | Value             | -1  | 0   | -1  | ... | 1    | 0    |  1   | 1    |
 *
 * See [SEAL v2.1 documentation](https://eprint.iacr.org/2017/224.pdf) for
 * full details.
 *
 * ## Limitations
 * When encrypting a Fractional type, encoding will fail if:
 * * The underlying [`f64`] is infinite.
 * * The underlying [`f64`] is NaN.
 * * The integer portion of the underlying [`f64`] exceeds the precision for
 * `INT_BITS`.
 *
 * Subnormals flush to 0, while normals are represented without precision loss.
 *
 * While the numbers are binary, addition and multiplication are carryless.
 * That is, carries don't propagate but instead increase the digit (i.e.
 * polynomial coefficients) beyond radix 2. However, they're still subject to
 * the scheme's `plain_modulus` specified during FHE program compilation.
 * Repeated operations on an encrypted Fractional value will result in garbled
 * values if *any* digit overflows the `plain_modulus`.
 *
 * Additionally, numbers can experience more traditional overflow if the
 * integer portion reaches `2^INT_BITS`. Finally, repeated multiplications of
 * numbers with fractional components introduce new fractional digits. If more
 * than `poly_degree - INT_BITS` fractional digits appear, they will overflow
 * into the integer portion and garble the number.
 *
 * To mitigate these issues, you should do some mix of the following:
 * * Ensure inputs never result in any of these scenarios. Inputs to an
 * FHE program need to have small enough digits to avoid digit overflow,
 * small enough values to avoid integer overflow, and few enough fractional
 * places to avoid fractional overflow.
 * * Alice can periodically decrypt values, turn the [`Fractional`] into
 * an [`f64`], turn that back into a [`Fractional`], and re-encrypt. This will
 * propagate carries and truncate the fractional portion to at most 53
 * places (radix 2).
 *
 * ```rust
 * # use sunscreen::types::bfv::Fractional;
 * # use sunscreen::{Ciphertext, PublicKey, PrivateKey, FheRuntime, Result};
 *
 * fn normalize(
 *   runtime: &FheRuntime,
 *   ciphertext: &Ciphertext,
 *   private_key: &PrivateKey,
 *   public_key: &PublicKey
 * ) -> Result<Ciphertext> {
 *   let val: Fractional::<64> = runtime.decrypt(&ciphertext, &private_key)?;
 *   let val: f64 = val.into();
 *   let val = Fractional::<64>::from(val);
 *
 *   Ok(runtime.encrypt(val, &public_key)?)
 * }
 * ```
 *
 * Overflow aside, decryption can also lose precision in more acceptable and
 * expected ways:
 * * If `INT_BITS > 1024`, the [`Fractional`]'s int can exceed [`f64::MAX`],
 * resulting in [`f64::INFINITY`].
 * * Decryption will truncate precision beyond the 53 floating point mantissa
 * bits (52 for subnormals). As previously mentioned, encrypting a subnormal
 * flushes to 0.
 */
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Fractional<const INT_BITS: usize> {
    /// The wrapped floating-point value.
    val: f64,
}
164
165impl<const INT_BITS: usize> std::ops::Deref for Fractional<INT_BITS> {
166    type Target = f64;
167
168    fn deref(&self) -> &Self::Target {
169        &self.val
170    }
171}
172
173impl<const INT_BITS: usize> NumCiphertexts for Fractional<INT_BITS> {
174    const NUM_CIPHERTEXTS: usize = 1;
175}
176
177impl<const INT_BITS: usize> FheProgramInputTrait for Fractional<INT_BITS> {}
178
179impl<const INT_BITS: usize> Default for Fractional<INT_BITS> {
180    fn default() -> Self {
181        Self::from(0.0)
182    }
183}
184
185impl<const INT_BITS: usize> TypeName for Fractional<INT_BITS> {
186    fn type_name() -> Type {
187        let version = env!("CARGO_PKG_VERSION");
188
189        Type {
190            name: format!("sunscreen::types::Fractional<{}>", INT_BITS),
191            version: Version::parse(version).expect("Crate version is not a valid semver"),
192            is_encrypted: false,
193        }
194    }
195}
196impl<const INT_BITS: usize> TypeNameInstance for Fractional<INT_BITS> {
197    fn type_name_instance(&self) -> Type {
198        Self::type_name()
199    }
200}
201
202impl<const INT_BITS: usize> FheType for Fractional<INT_BITS> {}
203impl<const INT_BITS: usize> BfvType for Fractional<INT_BITS> {}
204
205impl<const INT_BITS: usize> Fractional<INT_BITS> {}
206
207impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
208    type Left = Fractional<INT_BITS>;
209    type Right = Fractional<INT_BITS>;
210
211    fn graph_cipher_add(
212        a: FheProgramNode<Cipher<Self::Left>>,
213        b: FheProgramNode<Cipher<Self::Right>>,
214    ) -> FheProgramNode<Cipher<Self::Left>> {
215        with_fhe_ctx(|ctx| {
216            let n = ctx.add_addition(a.ids[0], b.ids[0]);
217
218            FheProgramNode::new(&[n])
219        })
220    }
221}
222
223impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
224    type Left = Fractional<INT_BITS>;
225    type Right = Fractional<INT_BITS>;
226
227    fn graph_cipher_plain_add(
228        a: FheProgramNode<Cipher<Self::Left>>,
229        b: FheProgramNode<Self::Right>,
230    ) -> FheProgramNode<Cipher<Self::Left>> {
231        with_fhe_ctx(|ctx| {
232            let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
233
234            FheProgramNode::new(&[n])
235        })
236    }
237}
238
239impl<const INT_BITS: usize> GraphCipherInsert for Fractional<INT_BITS> {
240    type Lit = f64;
241    type Val = Self;
242
243    fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
244        with_fhe_ctx(|ctx| {
245            let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap();
246            let lit = ctx.add_plaintext_literal(lit.inner);
247
248            FheProgramNode::new(&[lit])
249        })
250    }
251}
252
253impl<const INT_BITS: usize> GraphCipherConstAdd for Fractional<INT_BITS> {
254    type Left = Fractional<INT_BITS>;
255    type Right = f64;
256
257    fn graph_cipher_const_add(
258        a: FheProgramNode<Cipher<Self::Left>>,
259        b: Self::Right,
260    ) -> FheProgramNode<Cipher<Self::Left>> {
261        let lit = Self::graph_cipher_insert(b);
262        with_fhe_ctx(|ctx| {
263            let n = ctx.add_addition_plaintext(a.ids[0], lit.ids[0]);
264            FheProgramNode::new(&[n])
265        })
266    }
267}
268
269impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
270    type Left = Fractional<INT_BITS>;
271    type Right = Fractional<INT_BITS>;
272
273    fn graph_cipher_sub(
274        a: FheProgramNode<Cipher<Self::Left>>,
275        b: FheProgramNode<Cipher<Self::Right>>,
276    ) -> FheProgramNode<Cipher<Self::Left>> {
277        with_fhe_ctx(|ctx| {
278            let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
279
280            FheProgramNode::new(&[n])
281        })
282    }
283}
284
285impl<const INT_BITS: usize> GraphCipherPlainSub for Fractional<INT_BITS> {
286    type Left = Fractional<INT_BITS>;
287    type Right = Fractional<INT_BITS>;
288
289    fn graph_cipher_plain_sub(
290        a: FheProgramNode<Cipher<Self::Left>>,
291        b: FheProgramNode<Self::Right>,
292    ) -> FheProgramNode<Cipher<Self::Left>> {
293        with_fhe_ctx(|ctx| {
294            let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
295
296            FheProgramNode::new(&[n])
297        })
298    }
299}
300
301impl<const INT_BITS: usize> GraphPlainCipherSub for Fractional<INT_BITS> {
302    type Left = Fractional<INT_BITS>;
303    type Right = Fractional<INT_BITS>;
304
305    fn graph_plain_cipher_sub(
306        a: FheProgramNode<Self::Left>,
307        b: FheProgramNode<Cipher<Self::Right>>,
308    ) -> FheProgramNode<Cipher<Self::Left>> {
309        with_fhe_ctx(|ctx| {
310            let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
311            let n = ctx.add_negate(n);
312
313            FheProgramNode::new(&[n])
314        })
315    }
316}
317
318impl<const INT_BITS: usize> GraphCipherConstSub for Fractional<INT_BITS> {
319    type Left = Fractional<INT_BITS>;
320    type Right = f64;
321
322    fn graph_cipher_const_sub(
323        a: FheProgramNode<Cipher<Self::Left>>,
324        b: Self::Right,
325    ) -> FheProgramNode<Cipher<Self::Left>> {
326        let lit = Self::graph_cipher_insert(b);
327        with_fhe_ctx(|ctx| {
328            let n = ctx.add_subtraction_plaintext(a.ids[0], lit.ids[0]);
329            FheProgramNode::new(&[n])
330        })
331    }
332}
333
334impl<const INT_BITS: usize> GraphConstCipherSub for Fractional<INT_BITS> {
335    type Left = f64;
336    type Right = Fractional<INT_BITS>;
337
338    fn graph_const_cipher_sub(
339        a: Self::Left,
340        b: FheProgramNode<Cipher<Self::Right>>,
341    ) -> FheProgramNode<Cipher<Self::Right>> {
342        let lit = Self::graph_cipher_insert(a);
343        with_fhe_ctx(|ctx| {
344            let n = ctx.add_subtraction_plaintext(b.ids[0], lit.ids[0]);
345            let n = ctx.add_negate(n);
346
347            FheProgramNode::new(&[n])
348        })
349    }
350}
351
352impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
353    type Left = Fractional<INT_BITS>;
354    type Right = Fractional<INT_BITS>;
355
356    fn graph_cipher_mul(
357        a: FheProgramNode<Cipher<Self::Left>>,
358        b: FheProgramNode<Cipher<Self::Right>>,
359    ) -> FheProgramNode<Cipher<Self::Left>> {
360        with_fhe_ctx(|ctx| {
361            let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
362
363            FheProgramNode::new(&[n])
364        })
365    }
366}
367
368impl<const INT_BITS: usize> GraphCipherPlainMul for Fractional<INT_BITS> {
369    type Left = Fractional<INT_BITS>;
370    type Right = Fractional<INT_BITS>;
371
372    fn graph_cipher_plain_mul(
373        a: FheProgramNode<Cipher<Self::Left>>,
374        b: FheProgramNode<Self::Right>,
375    ) -> FheProgramNode<Cipher<Self::Left>> {
376        with_fhe_ctx(|ctx| {
377            let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
378
379            FheProgramNode::new(&[n])
380        })
381    }
382}
383
384impl<const INT_BITS: usize> GraphCipherConstMul for Fractional<INT_BITS> {
385    type Left = Fractional<INT_BITS>;
386    type Right = f64;
387
388    fn graph_cipher_const_mul(
389        a: FheProgramNode<Cipher<Self::Left>>,
390        b: Self::Right,
391    ) -> FheProgramNode<Cipher<Self::Left>> {
392        let lit = Self::graph_cipher_insert(b);
393        with_fhe_ctx(|ctx| {
394            let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
395            FheProgramNode::new(&[n])
396        })
397    }
398}
399
400impl<const INT_BITS: usize> GraphCipherConstDiv for Fractional<INT_BITS> {
401    type Left = Fractional<INT_BITS>;
402    type Right = f64;
403
404    fn graph_cipher_const_div(
405        a: FheProgramNode<Cipher<Self::Left>>,
406        b: f64,
407    ) -> FheProgramNode<Cipher<Self::Left>> {
408        let lit = Self::graph_cipher_insert(1. / b);
409        with_fhe_ctx(|ctx| {
410            let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
411            FheProgramNode::new(&[n])
412        })
413    }
414}
415
416impl<const INT_BITS: usize> GraphCipherNeg for Fractional<INT_BITS> {
417    type Val = Fractional<INT_BITS>;
418
419    fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
420        with_fhe_ctx(|ctx| {
421            let n = ctx.add_negate(a.ids[0]);
422
423            FheProgramNode::new(&[n])
424        })
425    }
426}
427
428impl<const INT_BITS: usize> TryIntoPlaintext for Fractional<INT_BITS> {
429    fn try_into_plaintext(
430        &self,
431        params: &Params,
432    ) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
433        if self.val.is_nan() {
434            return Err(sunscreen_runtime::Error::fhe_type_error("Value is NaN."));
435        }
436
437        if self.val.is_infinite() {
438            return Err(sunscreen_runtime::Error::fhe_type_error(
439                "Value is infinite.",
440            ));
441        }
442
443        let mut seal_plaintext = SealPlaintext::new()?;
444        let n = params.lattice_dimension as usize;
445        seal_plaintext.resize(n);
446
447        // Just flush subnormals, as they're tiny and annoying.
448        if self.val.is_subnormal() || self.val == 0.0 {
449            return Ok(Plaintext {
450                data_type: self.type_name_instance(),
451                inner: InnerPlaintext::Seal(vec![WithContext {
452                    params: params.clone(),
453                    data: seal_plaintext,
454                }]),
455            });
456        }
457
458        // If we made it this far, the float value is of normal form.
459        // Recall 64-bit IEEE 754-2008 floats have 52 mantissa, 11 exp, and 1
460        // sign bit from LSB to MSB order. They are represented by the form
461        // -1^sign * 2^(exp - 1023) * 1.mantissa
462
463        // Coerce the f64 into a u64 so we can extract out the
464        // sign, mantissa, and exponent.
465        let as_u64: u64 = self.val.to_bits();
466
467        let sign_mask = 0x1 << 63;
468        let mantissa_mask = 0xFFFFFFFFFFFFF;
469        let exp_mask = !mantissa_mask & !sign_mask;
470
471        // Mask of the mantissa and add the implicit 1
472        let mantissa = as_u64 & mantissa_mask | (mantissa_mask + 1);
473        let exp = as_u64 & exp_mask;
474        let power = (exp >> (f64::MANTISSA_DIGITS - 1)) as i64 - 1023;
475        let sign = (as_u64 & sign_mask) >> 63;
476
477        if power + 1 > INT_BITS as i64 {
478            return Err(sunscreen_runtime::Error::fhe_type_error("Out of range"));
479        }
480
481        for i in 0..f64::MANTISSA_DIGITS {
482            let bit_value = (mantissa & 0x1 << i) >> i;
483            let bit_power = power - (f64::MANTISSA_DIGITS - i - 1) as i64;
484
485            let coeff_index = if bit_power >= 0 {
486                bit_power as usize
487            } else {
488                (n as i64 + bit_power) as usize
489            };
490
491            // For powers less than 0, we invert the sign.
492            let sign = if bit_power >= 0 { sign } else { !sign & 0x1 };
493
494            let coeff = if sign == 0 {
495                bit_value
496            } else if bit_value > 0 {
497                params.plain_modulus - bit_value
498            } else {
499                0
500            };
501
502            seal_plaintext.set_coefficient(coeff_index, coeff);
503        }
504
505        Ok(Plaintext {
506            data_type: self.type_name_instance(),
507            inner: InnerPlaintext::Seal(vec![WithContext {
508                params: params.clone(),
509                data: seal_plaintext,
510            }]),
511        })
512    }
513}
514
515impl<const INT_BITS: usize> TryFromPlaintext for Fractional<INT_BITS> {
516    fn try_from_plaintext(
517        plaintext: &Plaintext,
518        params: &Params,
519    ) -> std::result::Result<Self, sunscreen_runtime::Error> {
520        let val = match &plaintext.inner {
521            InnerPlaintext::Seal(p) => {
522                if p.len() != 1 {
523                    return Err(sunscreen_runtime::Error::IncorrectCiphertextCount);
524                }
525
526                let mut val = 0.0f64;
527                let n = params.lattice_dimension as usize;
528
529                let len = p[0].len();
530
531                let negative_cutoff = (params.plain_modulus + 1) / 2;
532
533                for i in 0..usize::min(n, len) {
534                    let power = if i < INT_BITS {
535                        i as i64
536                    } else {
537                        i as i64 - n as i64
538                    };
539
540                    let coeff = p[0].get_coefficient(i);
541
542                    // Reverse the sign of negative powers.
543                    let sign = if power >= 0 { 1f64 } else { -1f64 };
544
545                    if coeff < negative_cutoff {
546                        val += sign * coeff as f64 * (power as f64).exp2();
547                    } else {
548                        val -= sign * (params.plain_modulus - coeff) as f64 * (power as f64).exp2();
549                    };
550                }
551
552                Self { val }
553            }
554        };
555
556        Ok(val)
557    }
558}
559
560impl<const INT_BITS: usize> From<f64> for Fractional<INT_BITS> {
561    fn from(val: f64) -> Self {
562        Self { val }
563    }
564}
565
566impl<const INT_BITS: usize> From<Fractional<INT_BITS>> for f64 {
567    fn from(frac: Fractional<INT_BITS>) -> Self {
568        frac.val
569    }
570}
571
572impl<const INT_BITS: usize> Add for Fractional<INT_BITS> {
573    type Output = Self;
574
575    fn add(self, rhs: Self) -> Self {
576        Self {
577            val: self.val + rhs.val,
578        }
579    }
580}
581
582impl<const INT_BITS: usize> Add<f64> for Fractional<INT_BITS> {
583    type Output = Self;
584
585    fn add(self, rhs: f64) -> Self {
586        Self {
587            val: self.val + rhs,
588        }
589    }
590}
591
592impl<const INT_BITS: usize> Add<Fractional<INT_BITS>> for f64 {
593    type Output = Fractional<INT_BITS>;
594
595    fn add(self, rhs: Fractional<INT_BITS>) -> Self::Output {
596        Fractional {
597            val: self + rhs.val,
598        }
599    }
600}
601
602impl<const INT_BITS: usize> Mul for Fractional<INT_BITS> {
603    type Output = Self;
604
605    fn mul(self, rhs: Self) -> Self {
606        Self {
607            val: self.val * rhs.val,
608        }
609    }
610}
611
612impl<const INT_BITS: usize> Mul<f64> for Fractional<INT_BITS> {
613    type Output = Self;
614
615    fn mul(self, rhs: f64) -> Self {
616        Self {
617            val: self.val * rhs,
618        }
619    }
620}
621
622impl<const INT_BITS: usize> Mul<Fractional<INT_BITS>> for f64 {
623    type Output = Fractional<INT_BITS>;
624
625    fn mul(self, rhs: Fractional<INT_BITS>) -> Self::Output {
626        Fractional {
627            val: self * rhs.val,
628        }
629    }
630}
631
632impl<const INT_BITS: usize> Sub for Fractional<INT_BITS> {
633    type Output = Self;
634
635    fn sub(self, rhs: Self) -> Self {
636        Self {
637            val: self.val - rhs.val,
638        }
639    }
640}
641
642impl<const INT_BITS: usize> Sub<f64> for Fractional<INT_BITS> {
643    type Output = Self;
644
645    fn sub(self, rhs: f64) -> Self {
646        Self {
647            val: self.val - rhs,
648        }
649    }
650}
651
652impl<const INT_BITS: usize> Sub<Fractional<INT_BITS>> for f64 {
653    type Output = Fractional<INT_BITS>;
654
655    fn sub(self, rhs: Fractional<INT_BITS>) -> Self::Output {
656        Fractional {
657            val: self - rhs.val,
658        }
659    }
660}
661
662impl<const INT_BITS: usize> Div<f64> for Fractional<INT_BITS> {
663    type Output = Self;
664
665    fn div(self, rhs: f64) -> Self {
666        Self {
667            val: self.val / rhs,
668        }
669    }
670}
671
672impl<const INT_BITS: usize> Neg for Fractional<INT_BITS> {
673    type Output = Self;
674
675    fn neg(self) -> Self {
676        Self { val: -self.val }
677    }
678}
679
#[cfg(test)]
mod tests {
    // Literals such as `3.14` are arbitrary sample values, not approximations
    // of well-known constants.
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::{SchemeType, SecurityLevel};
    use float_cmp::ApproxEq;

    /// BFV parameters shared by the encoding tests.
    fn test_params() -> Params {
        Params {
            lattice_dimension: 4096,
            plain_modulus: 1_000_000,
            coeff_modulus: vec![],
            scheme_type: SchemeType::Bfv,
            security_level: SecurityLevel::TC128,
        }
    }

    #[test]
    fn can_encode_decode_fractional() {
        let round_trip = |x: f64| {
            let params = test_params();

            let f_1 = Fractional::<64>::from(x);
            let pt = f_1.try_into_plaintext(&params).unwrap();
            let f_2 = Fractional::<64>::try_from_plaintext(&pt, &params).unwrap();

            assert_eq!(f_1, f_2);
        };

        round_trip(3.14);
        round_trip(0.0);
        round_trip(1.0);
        round_trip(5.8125);
        round_trip(6.0);
        round_trip(6.6);
        round_trip(1.2);
        round_trip(1e13);
        round_trip(0.0000000005);
        round_trip(-1.0);
        round_trip(-5.875);
        round_trip(-6.0);
        round_trip(-6.6);
        round_trip(-1.2);
        round_trip(-1e13);
        round_trip(-0.0000000005);

        // Boundaries: the highest integer digit INT_BITS = 64 allows, and the
        // smallest positive normal f64.
        round_trip(2f64.powi(63));
        round_trip(-(2f64.powi(63)));
        round_trip(f64::MIN_POSITIVE);
        round_trip(-f64::MIN_POSITIVE);
    }

    #[test]
    fn encoding_rejects_unrepresentable_values() {
        let params = test_params();

        let rejected = [
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
            2f64.powi(64),
            -(2f64.powi(64)),
        ];

        for &x in rejected.iter() {
            assert!(
                Fractional::<64>::from(x)
                    .try_into_plaintext(&params)
                    .is_err(),
                "{} should not be encodable as Fractional<64>",
                x
            );
        }
    }

    #[test]
    fn can_convert_to_and_from_f64() {
        let f: f64 = Fractional::<64>::from(2.5).into();
        assert_eq!(f, 2.5);
        assert_eq!(*Fractional::<64>::from(2.5), 2.5);
        assert_eq!(Fractional::<64>::default(), Fractional::<64>::from(0.0));
    }

    #[test]
    fn can_add_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        // Allow 1 ULP of error
        assert!((a + b).approx_eq(4.64, (0.0, 1)));
        assert!((3.14 + b).approx_eq(4.64, (0.0, 1)));
        assert!((a + 1.5).approx_eq(4.64, (0.0, 1)));
    }

    #[test]
    fn can_mul_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        // Allow 1 ULP of error
        assert!((a * b).approx_eq(4.71, (0.0, 1)));
        assert!((3.14 * b).approx_eq(4.71, (0.0, 1)));
        assert!((a * 1.5).approx_eq(4.71, (0.0, 1)));
    }

    #[test]
    fn can_sub_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        // Allow 1 ULP of error
        assert!((a - b).approx_eq(1.64, (0.0, 1)));
        assert!((3.14 - b).approx_eq(1.64, (0.0, 1)));
        assert!((a - 1.5).approx_eq(1.64, (0.0, 1)));
    }

    #[test]
    fn can_div_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        // Allow 1 ULP of error
        assert!((a / 1.5).approx_eq(3.14 / 1.5, (0.0, 1)));
    }

    #[test]
    fn can_neg_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        // Negation only flips the sign bit, so the result must be exact.
        assert_eq!(-a, (-3.14).into());
    }
}