sunscreen/types/bfv/rational.rs

1use crate as sunscreen;
2use crate::fhe::{with_fhe_ctx, FheContextOps};
3use crate::types::{
4    bfv::Signed, intern::FheProgramNode, ops::*, BfvType, Cipher, FheType, GraphCipherAdd,
5    GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext,
6    TryIntoPlaintext, TypeName,
7};
8use crate::{FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
9use std::cmp::Eq;
10use std::ops::*;
11use sunscreen_runtime::Error;
12
13use num::Rational64;
14
15#[derive(Debug, Clone, Copy, TypeName, Eq)]
16/**
17 * Represents the ratio of two integers. Allows for fractional values and division.
18 */
19pub struct Rational {
20    num: Signed,
21    den: Signed,
22}
23
24impl PartialEq for Rational {
25    fn eq(&self, other: &Self) -> bool {
26        let num_a: i64 = self.num.into();
27        let num_b: i64 = other.num.into();
28        let den_a: i64 = self.den.into();
29        let den_b: i64 = other.den.into();
30
31        num_a * den_b == num_b * den_a
32    }
33}
34
35impl Default for Rational {
36    fn default() -> Self {
37        Self::try_from(0.0).unwrap()
38    }
39}
40
41impl NumCiphertexts for Rational {
42    const NUM_CIPHERTEXTS: usize = Signed::NUM_CIPHERTEXTS + Signed::NUM_CIPHERTEXTS;
43}
44
45impl TryFromPlaintext for Rational {
46    fn try_from_plaintext(plaintext: &Plaintext, params: &Params) -> Result<Self, Error> {
47        let (num, den) = match &plaintext.inner {
48            InnerPlaintext::Seal(p) => {
49                let num = Plaintext {
50                    data_type: Self::type_name(),
51                    inner: InnerPlaintext::Seal(vec![p[0].clone()]),
52                };
53                let den = Plaintext {
54                    data_type: Self::type_name(),
55                    inner: InnerPlaintext::Seal(vec![p[1].clone()]),
56                };
57
58                (
59                    Signed::try_from_plaintext(&num, params)?,
60                    Signed::try_from_plaintext(&den, params)?,
61                )
62            }
63        };
64
65        Ok(Self { num, den })
66    }
67}
68
69impl TryIntoPlaintext for Rational {
70    fn try_into_plaintext(&self, params: &Params) -> Result<Plaintext, Error> {
71        let num = self.num.try_into_plaintext(params)?;
72        let den = self.den.try_into_plaintext(params)?;
73
74        let (num, den) = match (num.inner, den.inner) {
75            (InnerPlaintext::Seal(n), InnerPlaintext::Seal(d)) => (n[0].clone(), d[0].clone()),
76        };
77
78        Ok(Plaintext {
79            data_type: Self::type_name(),
80            inner: InnerPlaintext::Seal(vec![num, den]),
81        })
82    }
83}
84
85impl FheProgramInputTrait for Rational {}
86impl FheType for Rational {}
87impl BfvType for Rational {}
88
89impl TryFrom<f64> for Rational {
90    type Error = Error;
91
92    fn try_from(val: f64) -> Result<Self, Self::Error> {
93        let val = Rational64::approximate_float(val)
94            .ok_or_else(|| Error::fhe_type_error("Failed to parse float into rational"))?;
95
96        Ok(Self {
97            num: Signed::from(*val.numer()),
98            den: Signed::from(*val.denom()),
99        })
100    }
101}
102
103impl From<Rational> for f64 {
104    fn from(val: Rational) -> Self {
105        let num: i64 = val.num.into();
106        let den: i64 = val.den.into();
107
108        num as f64 / den as f64
109    }
110}
111
112impl Add for Rational {
113    type Output = Self;
114
115    fn add(self, rhs: Self) -> Self::Output {
116        Self::Output {
117            num: self.num * rhs.den + rhs.num * self.den,
118            den: self.den * rhs.den,
119        }
120    }
121}
122
123impl Add<f64> for Rational {
124    type Output = Self;
125
126    fn add(self, rhs: f64) -> Self::Output {
127        let rhs = Rational::try_from(rhs).unwrap();
128
129        Self::Output {
130            num: self.num * rhs.den + rhs.num * self.den,
131            den: self.den * rhs.den,
132        }
133    }
134}
135
136impl Add<Rational> for f64 {
137    type Output = Rational;
138
139    fn add(self, rhs: Rational) -> Self::Output {
140        let lhs = Rational::try_from(self).unwrap();
141
142        Self::Output {
143            num: lhs.num * rhs.den + rhs.num * lhs.den,
144            den: lhs.den * rhs.den,
145        }
146    }
147}
148
149impl Mul for Rational {
150    type Output = Self;
151
152    fn mul(self, rhs: Self) -> Self::Output {
153        Self::Output {
154            num: self.num * rhs.num,
155            den: self.den * rhs.den,
156        }
157    }
158}
159
160impl Mul<f64> for Rational {
161    type Output = Self;
162
163    fn mul(self, rhs: f64) -> Self::Output {
164        let rhs = Rational::try_from(rhs).unwrap();
165
166        Self {
167            num: self.num * rhs.num,
168            den: self.den * rhs.den,
169        }
170    }
171}
172
173impl Mul<Rational> for f64 {
174    type Output = Rational;
175
176    fn mul(self, rhs: Rational) -> Self::Output {
177        let lhs = Rational::try_from(self).unwrap();
178
179        Self::Output {
180            num: lhs.num * rhs.num,
181            den: lhs.den * rhs.den,
182        }
183    }
184}
185
186impl Sub for Rational {
187    type Output = Self;
188
189    fn sub(self, rhs: Self) -> Self::Output {
190        Self::Output {
191            num: self.num * rhs.den - rhs.num * self.den,
192            den: self.den * rhs.den,
193        }
194    }
195}
196
197impl Sub<f64> for Rational {
198    type Output = Self;
199
200    fn sub(self, rhs: f64) -> Self::Output {
201        let rhs = Rational::try_from(rhs).unwrap();
202
203        Self::Output {
204            num: self.num * rhs.den - rhs.num * self.den,
205            den: self.den * rhs.den,
206        }
207    }
208}
209
210impl Sub<Rational> for f64 {
211    type Output = Rational;
212
213    fn sub(self, rhs: Rational) -> Self::Output {
214        let lhs = Rational::try_from(self).unwrap();
215
216        Self::Output {
217            num: lhs.num * rhs.den - rhs.num * lhs.den,
218            den: lhs.den * rhs.den,
219        }
220    }
221}
222
223impl Div for Rational {
224    type Output = Self;
225
226    fn div(self, rhs: Self) -> Self::Output {
227        Self::Output {
228            num: self.num * rhs.den,
229            den: self.den * rhs.num,
230        }
231    }
232}
233
234impl Div<f64> for Rational {
235    type Output = Self;
236
237    fn div(self, rhs: f64) -> Self::Output {
238        let rhs = Rational::try_from(rhs).unwrap();
239
240        Self::Output {
241            num: self.num * rhs.den,
242            den: self.den * rhs.num,
243        }
244    }
245}
246
247impl Div<Rational> for f64 {
248    type Output = Rational;
249
250    fn div(self, rhs: Rational) -> Self::Output {
251        let lhs = Rational::try_from(self).unwrap();
252
253        Self::Output {
254            num: lhs.num * rhs.den,
255            den: lhs.den * rhs.num,
256        }
257    }
258}
259
260impl Neg for Rational {
261    type Output = Self;
262
263    fn neg(self) -> Self::Output {
264        Self::Output {
265            num: -self.num,
266            den: self.den,
267        }
268    }
269}
270
271impl GraphCipherAdd for Rational {
272    type Left = Self;
273    type Right = Self;
274
275    fn graph_cipher_add(
276        a: FheProgramNode<Cipher<Self::Left>>,
277        b: FheProgramNode<Cipher<Self::Right>>,
278    ) -> FheProgramNode<Cipher<Self::Left>> {
279        with_fhe_ctx(|ctx| {
280            // Scale each numinator by the other's denominator.
281            let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
282            let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
283
284            // Get denominators to have the same scale
285            let den_2 = ctx.add_multiplication(a.ids[1], b.ids[1]);
286
287            let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
288
289            FheProgramNode::new(&ids)
290        })
291    }
292}
293
294impl GraphCipherPlainAdd for Rational {
295    type Left = Self;
296    type Right = Self;
297
298    fn graph_cipher_plain_add(
299        a: FheProgramNode<Cipher<Self::Left>>,
300        b: FheProgramNode<Self::Right>,
301    ) -> FheProgramNode<Cipher<Self::Left>> {
302        with_fhe_ctx(|ctx| {
303            // Scale each numinator by the other's denominator.
304            let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
305            let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
306
307            // Get denominators to have the same scale
308            let den_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
309
310            let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
311
312            FheProgramNode::new(&ids)
313        })
314    }
315}
316
317impl GraphCipherInsert for Rational {
318    type Lit = f64;
319    type Val = Self;
320
321    fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
322        with_fhe_ctx(|ctx| {
323            let lit = Self::try_from(lit).unwrap();
324
325            let lit_num =
326                ctx.add_plaintext_literal(lit.num.try_into_plaintext(&ctx.data).unwrap().inner);
327
328            let lit_den =
329                ctx.add_plaintext_literal(lit.den.try_into_plaintext(&ctx.data).unwrap().inner);
330
331            FheProgramNode::new(&[lit_num, lit_den])
332        })
333    }
334}
335
336impl GraphCipherConstAdd for Rational {
337    type Left = Self;
338    type Right = f64;
339
340    fn graph_cipher_const_add(
341        a: FheProgramNode<Cipher<Self::Left>>,
342        b: Self::Right,
343    ) -> FheProgramNode<Cipher<Self::Left>> {
344        let lit = Self::graph_cipher_insert(b);
345        with_fhe_ctx(|ctx| {
346            // Scale each numinator by the other's denominator.
347            let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
348            let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
349
350            // Get denominators to have the same scale
351            let den_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
352
353            let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
354
355            FheProgramNode::new(&ids)
356        })
357    }
358}
359
360impl GraphCipherSub for Rational {
361    type Left = Self;
362    type Right = Self;
363
364    fn graph_cipher_sub(
365        a: FheProgramNode<Cipher<Self::Left>>,
366        b: FheProgramNode<Cipher<Self::Right>>,
367    ) -> FheProgramNode<Cipher<Self::Left>> {
368        with_fhe_ctx(|ctx| {
369            // Scale each numinator by the other's denominator.
370            let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
371            let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
372
373            // Get denominators to have the same scale
374            let den_2 = ctx.add_multiplication(a.ids[1], b.ids[1]);
375
376            let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
377
378            FheProgramNode::new(&ids)
379        })
380    }
381}
382
383impl GraphCipherPlainSub for Rational {
384    type Left = Self;
385    type Right = Self;
386
387    fn graph_cipher_plain_sub(
388        a: FheProgramNode<Cipher<Self::Left>>,
389        b: FheProgramNode<Self::Right>,
390    ) -> FheProgramNode<Cipher<Self::Left>> {
391        with_fhe_ctx(|ctx| {
392            // Scale each numinator by the other's denominator.
393            let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
394            let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
395
396            // Get denominators to have the same scale
397            let den_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
398
399            let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
400
401            FheProgramNode::new(&ids)
402        })
403    }
404}
405
406impl GraphPlainCipherSub for Rational {
407    type Left = Self;
408    type Right = Self;
409
410    fn graph_plain_cipher_sub(
411        a: FheProgramNode<Self::Left>,
412        b: FheProgramNode<Cipher<Self::Right>>,
413    ) -> FheProgramNode<Cipher<Self::Left>> {
414        with_fhe_ctx(|ctx| {
415            // Scale each numinator by the other's denominator.
416            let num_a_2 = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
417            let num_b_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
418
419            // Get denominators to have the same scale
420            let den_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[1]);
421
422            let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
423
424            FheProgramNode::new(&ids)
425        })
426    }
427}
428
429impl GraphCipherConstSub for Rational {
430    type Left = Self;
431    type Right = f64;
432
433    fn graph_cipher_const_sub(
434        a: FheProgramNode<Cipher<Self::Left>>,
435        b: Self::Right,
436    ) -> FheProgramNode<Cipher<Self::Left>> {
437        let lit = Self::graph_cipher_insert(b);
438        with_fhe_ctx(|ctx| {
439            // Scale each numinator by the other's denominator.
440            let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
441            let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
442
443            // Get denominators to have the same scale
444            let den_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
445
446            let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
447
448            FheProgramNode::new(&ids)
449        })
450    }
451}
452
453impl GraphConstCipherSub for Rational {
454    type Left = f64;
455    type Right = Self;
456
457    fn graph_const_cipher_sub(
458        a: Self::Left,
459        b: FheProgramNode<Cipher<Self::Right>>,
460    ) -> FheProgramNode<Cipher<Self::Right>> {
461        let lit = Self::graph_cipher_insert(a);
462        with_fhe_ctx(|ctx| {
463            // Scale each numinator by the other's denominator.
464            let num_b_2 = ctx.add_multiplication_plaintext(b.ids[0], lit.ids[1]);
465            let num_a_2 = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[0]);
466
467            // Get denominators to have the same scale
468            let den_2 = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[1]);
469
470            let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
471
472            FheProgramNode::new(&ids)
473        })
474    }
475}
476
477impl GraphCipherMul for Rational {
478    type Left = Self;
479    type Right = Self;
480
481    fn graph_cipher_mul(
482        a: FheProgramNode<Cipher<Self::Left>>,
483        b: FheProgramNode<Cipher<Self::Right>>,
484    ) -> FheProgramNode<Cipher<Self::Left>> {
485        with_fhe_ctx(|ctx| {
486            let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
487            let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]);
488
489            let ids = [mul_num, mul_den];
490
491            FheProgramNode::new(&ids)
492        })
493    }
494}
495
496impl GraphCipherPlainMul for Rational {
497    type Left = Self;
498    type Right = Self;
499
500    fn graph_cipher_plain_mul(
501        a: FheProgramNode<Cipher<Self::Left>>,
502        b: FheProgramNode<Self::Right>,
503    ) -> FheProgramNode<Cipher<Self::Left>> {
504        with_fhe_ctx(|ctx| {
505            let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
506            let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
507
508            let ids = [mul_num, mul_den];
509
510            FheProgramNode::new(&ids)
511        })
512    }
513}
514
515impl GraphCipherConstMul for Rational {
516    type Left = Self;
517    type Right = f64;
518
519    fn graph_cipher_const_mul(
520        a: FheProgramNode<Cipher<Self::Left>>,
521        b: Self::Right,
522    ) -> FheProgramNode<Cipher<Self::Left>> {
523        let lit = Self::graph_cipher_insert(b);
524        with_fhe_ctx(|ctx| {
525            let mul_num = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
526            let mul_den = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
527
528            let ids = [mul_num, mul_den];
529
530            FheProgramNode::new(&ids)
531        })
532    }
533}
534
535impl GraphCipherDiv for Rational {
536    type Left = Self;
537    type Right = Self;
538
539    fn graph_cipher_div(
540        a: FheProgramNode<Cipher<Self::Left>>,
541        b: FheProgramNode<Cipher<Self::Right>>,
542    ) -> FheProgramNode<Cipher<Self::Left>> {
543        with_fhe_ctx(|ctx| {
544            let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);
545            let mul_den = ctx.add_multiplication(a.ids[1], b.ids[0]);
546
547            let ids = [mul_num, mul_den];
548
549            FheProgramNode::new(&ids)
550        })
551    }
552}
553
554impl GraphCipherPlainDiv for Rational {
555    type Left = Self;
556    type Right = Self;
557
558    fn graph_cipher_plain_div(
559        a: FheProgramNode<Cipher<Self::Left>>,
560        b: FheProgramNode<Self::Right>,
561    ) -> FheProgramNode<Cipher<Self::Left>> {
562        with_fhe_ctx(|ctx| {
563            let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
564            let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
565
566            let ids = [mul_num, mul_den];
567
568            FheProgramNode::new(&ids)
569        })
570    }
571}
572
573impl GraphPlainCipherDiv for Rational {
574    type Left = Self;
575    type Right = Self;
576
577    fn graph_plain_cipher_div(
578        a: FheProgramNode<Self::Left>,
579        b: FheProgramNode<Cipher<Self::Right>>,
580    ) -> FheProgramNode<Cipher<Self::Left>> {
581        with_fhe_ctx(|ctx| {
582            let mul_num = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
583            let mul_den = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
584
585            let ids = [mul_num, mul_den];
586
587            FheProgramNode::new(&ids)
588        })
589    }
590}
591
592impl GraphCipherConstDiv for Rational {
593    type Left = Self;
594    type Right = f64;
595
596    fn graph_cipher_const_div(
597        a: FheProgramNode<Cipher<Self::Left>>,
598        b: Self::Right,
599    ) -> FheProgramNode<Cipher<Self::Left>> {
600        let lit = Self::graph_cipher_insert(b);
601        with_fhe_ctx(|ctx| {
602            let mul_num = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
603            let mul_den = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
604
605            let ids = [mul_num, mul_den];
606
607            FheProgramNode::new(&ids)
608        })
609    }
610}
611
612impl GraphConstCipherDiv for Rational {
613    type Left = f64;
614    type Right = Self;
615
616    fn graph_const_cipher_div(
617        a: Self::Left,
618        b: FheProgramNode<Cipher<Self::Right>>,
619    ) -> FheProgramNode<Cipher<Self::Right>> {
620        let lit = Self::graph_cipher_insert(a);
621        with_fhe_ctx(|ctx| {
622            let mul_num = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[0]);
623            let mul_den = ctx.add_multiplication_plaintext(b.ids[0], lit.ids[1]);
624
625            let ids = [mul_num, mul_den];
626
627            FheProgramNode::new(&ids)
628        })
629    }
630}
631
632impl GraphCipherNeg for Rational {
633    type Val = Self;
634
635    fn graph_cipher_neg(a: FheProgramNode<Cipher<Self::Val>>) -> FheProgramNode<Cipher<Self::Val>> {
636        with_fhe_ctx(|ctx| {
637            let neg = ctx.add_negate(a.ids[0]);
638            let ids = [neg, a.ids[1]];
639
640            FheProgramNode::new(&ids)
641        })
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Rational` from a float, panicking if the conversion fails.
    fn rat(val: f64) -> Rational {
        Rational::try_from(val).unwrap()
    }

    #[test]
    fn can_add_non_fhe() {
        let (a, b) = (rat(5.), rat(10.));
        let expected = rat(15.);

        assert_eq!(a + b, expected);
        assert_eq!(a + 10., expected);
        assert_eq!(10. + a, expected);
    }

    #[test]
    fn can_mul_non_fhe() {
        let (a, b) = (rat(5.), rat(10.));
        let expected = rat(50.);

        assert_eq!(a * b, expected);
        assert_eq!(a * 10., expected);
        assert_eq!(10. * a, expected);
    }

    #[test]
    fn can_sub_non_fhe() {
        let (a, b) = (rat(5.), rat(10.));

        assert_eq!(a - b, rat(-5.));
        assert_eq!(a - 10., rat(-5.));
        assert_eq!(10. - a, rat(5.));
    }

    #[test]
    fn can_div_non_fhe() {
        let (a, b) = (rat(5.), rat(10.));

        assert_eq!(a / b, rat(0.5));
        assert_eq!(a / 10., rat(0.5));
        assert_eq!(10. / a, rat(2.));
    }

    #[test]
    fn can_neg_non_fhe() {
        assert_eq!(-rat(5.), rat(-5.));
    }
}