1use crate as sunscreen;
2use crate::fhe::{with_fhe_ctx, FheContextOps};
3use crate::types::{
4 bfv::Signed, intern::FheProgramNode, ops::*, BfvType, Cipher, FheType, GraphCipherAdd,
5 GraphCipherDiv, GraphCipherMul, GraphCipherSub, NumCiphertexts, TryFromPlaintext,
6 TryIntoPlaintext, TypeName,
7};
8use crate::{FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName};
9use std::cmp::Eq;
10use std::ops::*;
11use sunscreen_runtime::Error;
12
13use num::Rational64;
14
/// A rational number stored as a numerator/denominator pair of [`Signed`]
/// values.
///
/// `PartialEq` is implemented by hand elsewhere in this file rather than
/// derived, which is why only `Eq` appears in the derive list.
#[derive(Debug, Clone, Copy, TypeName, Eq)]
pub struct Rational {
    /// Numerator of the fraction.
    num: Signed,
    /// Denominator of the fraction.
    den: Signed,
}
23
24impl PartialEq for Rational {
25 fn eq(&self, other: &Self) -> bool {
26 let num_a: i64 = self.num.into();
27 let num_b: i64 = other.num.into();
28 let den_a: i64 = self.den.into();
29 let den_b: i64 = other.den.into();
30
31 num_a * den_b == num_b * den_a
32 }
33}
34
35impl Default for Rational {
36 fn default() -> Self {
37 Self::try_from(0.0).unwrap()
38 }
39}
40
41impl NumCiphertexts for Rational {
42 const NUM_CIPHERTEXTS: usize = Signed::NUM_CIPHERTEXTS + Signed::NUM_CIPHERTEXTS;
43}
44
45impl TryFromPlaintext for Rational {
46 fn try_from_plaintext(plaintext: &Plaintext, params: &Params) -> Result<Self, Error> {
47 let (num, den) = match &plaintext.inner {
48 InnerPlaintext::Seal(p) => {
49 let num = Plaintext {
50 data_type: Self::type_name(),
51 inner: InnerPlaintext::Seal(vec![p[0].clone()]),
52 };
53 let den = Plaintext {
54 data_type: Self::type_name(),
55 inner: InnerPlaintext::Seal(vec![p[1].clone()]),
56 };
57
58 (
59 Signed::try_from_plaintext(&num, params)?,
60 Signed::try_from_plaintext(&den, params)?,
61 )
62 }
63 };
64
65 Ok(Self { num, den })
66 }
67}
68
69impl TryIntoPlaintext for Rational {
70 fn try_into_plaintext(&self, params: &Params) -> Result<Plaintext, Error> {
71 let num = self.num.try_into_plaintext(params)?;
72 let den = self.den.try_into_plaintext(params)?;
73
74 let (num, den) = match (num.inner, den.inner) {
75 (InnerPlaintext::Seal(n), InnerPlaintext::Seal(d)) => (n[0].clone(), d[0].clone()),
76 };
77
78 Ok(Plaintext {
79 data_type: Self::type_name(),
80 inner: InnerPlaintext::Seal(vec![num, den]),
81 })
82 }
83}
84
// Empty trait implementations: going by the trait names, these opt `Rational`
// in as an FHE program input, an FHE type, and a BFV type. No methods are
// overridden here.
impl FheProgramInputTrait for Rational {}
impl FheType for Rational {}
impl BfvType for Rational {}
88
89impl TryFrom<f64> for Rational {
90 type Error = Error;
91
92 fn try_from(val: f64) -> Result<Self, Self::Error> {
93 let val = Rational64::approximate_float(val)
94 .ok_or_else(|| Error::fhe_type_error("Failed to parse float into rational"))?;
95
96 Ok(Self {
97 num: Signed::from(*val.numer()),
98 den: Signed::from(*val.denom()),
99 })
100 }
101}
102
103impl From<Rational> for f64 {
104 fn from(val: Rational) -> Self {
105 let num: i64 = val.num.into();
106 let den: i64 = val.den.into();
107
108 num as f64 / den as f64
109 }
110}
111
112impl Add for Rational {
113 type Output = Self;
114
115 fn add(self, rhs: Self) -> Self::Output {
116 Self::Output {
117 num: self.num * rhs.den + rhs.num * self.den,
118 den: self.den * rhs.den,
119 }
120 }
121}
122
123impl Add<f64> for Rational {
124 type Output = Self;
125
126 fn add(self, rhs: f64) -> Self::Output {
127 let rhs = Rational::try_from(rhs).unwrap();
128
129 Self::Output {
130 num: self.num * rhs.den + rhs.num * self.den,
131 den: self.den * rhs.den,
132 }
133 }
134}
135
136impl Add<Rational> for f64 {
137 type Output = Rational;
138
139 fn add(self, rhs: Rational) -> Self::Output {
140 let lhs = Rational::try_from(self).unwrap();
141
142 Self::Output {
143 num: lhs.num * rhs.den + rhs.num * lhs.den,
144 den: lhs.den * rhs.den,
145 }
146 }
147}
148
149impl Mul for Rational {
150 type Output = Self;
151
152 fn mul(self, rhs: Self) -> Self::Output {
153 Self::Output {
154 num: self.num * rhs.num,
155 den: self.den * rhs.den,
156 }
157 }
158}
159
160impl Mul<f64> for Rational {
161 type Output = Self;
162
163 fn mul(self, rhs: f64) -> Self::Output {
164 let rhs = Rational::try_from(rhs).unwrap();
165
166 Self {
167 num: self.num * rhs.num,
168 den: self.den * rhs.den,
169 }
170 }
171}
172
173impl Mul<Rational> for f64 {
174 type Output = Rational;
175
176 fn mul(self, rhs: Rational) -> Self::Output {
177 let lhs = Rational::try_from(self).unwrap();
178
179 Self::Output {
180 num: lhs.num * rhs.num,
181 den: lhs.den * rhs.den,
182 }
183 }
184}
185
186impl Sub for Rational {
187 type Output = Self;
188
189 fn sub(self, rhs: Self) -> Self::Output {
190 Self::Output {
191 num: self.num * rhs.den - rhs.num * self.den,
192 den: self.den * rhs.den,
193 }
194 }
195}
196
197impl Sub<f64> for Rational {
198 type Output = Self;
199
200 fn sub(self, rhs: f64) -> Self::Output {
201 let rhs = Rational::try_from(rhs).unwrap();
202
203 Self::Output {
204 num: self.num * rhs.den - rhs.num * self.den,
205 den: self.den * rhs.den,
206 }
207 }
208}
209
210impl Sub<Rational> for f64 {
211 type Output = Rational;
212
213 fn sub(self, rhs: Rational) -> Self::Output {
214 let lhs = Rational::try_from(self).unwrap();
215
216 Self::Output {
217 num: lhs.num * rhs.den - rhs.num * lhs.den,
218 den: lhs.den * rhs.den,
219 }
220 }
221}
222
223impl Div for Rational {
224 type Output = Self;
225
226 fn div(self, rhs: Self) -> Self::Output {
227 Self::Output {
228 num: self.num * rhs.den,
229 den: self.den * rhs.num,
230 }
231 }
232}
233
234impl Div<f64> for Rational {
235 type Output = Self;
236
237 fn div(self, rhs: f64) -> Self::Output {
238 let rhs = Rational::try_from(rhs).unwrap();
239
240 Self::Output {
241 num: self.num * rhs.den,
242 den: self.den * rhs.num,
243 }
244 }
245}
246
247impl Div<Rational> for f64 {
248 type Output = Rational;
249
250 fn div(self, rhs: Rational) -> Self::Output {
251 let lhs = Rational::try_from(self).unwrap();
252
253 Self::Output {
254 num: lhs.num * rhs.den,
255 den: lhs.den * rhs.num,
256 }
257 }
258}
259
260impl Neg for Rational {
261 type Output = Self;
262
263 fn neg(self) -> Self::Output {
264 Self::Output {
265 num: -self.num,
266 den: self.den,
267 }
268 }
269}
270
271impl GraphCipherAdd for Rational {
272 type Left = Self;
273 type Right = Self;
274
275 fn graph_cipher_add(
276 a: FheProgramNode<Cipher<Self::Left>>,
277 b: FheProgramNode<Cipher<Self::Right>>,
278 ) -> FheProgramNode<Cipher<Self::Left>> {
279 with_fhe_ctx(|ctx| {
280 let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
282 let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
283
284 let den_2 = ctx.add_multiplication(a.ids[1], b.ids[1]);
286
287 let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
288
289 FheProgramNode::new(&ids)
290 })
291 }
292}
293
294impl GraphCipherPlainAdd for Rational {
295 type Left = Self;
296 type Right = Self;
297
298 fn graph_cipher_plain_add(
299 a: FheProgramNode<Cipher<Self::Left>>,
300 b: FheProgramNode<Self::Right>,
301 ) -> FheProgramNode<Cipher<Self::Left>> {
302 with_fhe_ctx(|ctx| {
303 let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
305 let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
306
307 let den_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
309
310 let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
311
312 FheProgramNode::new(&ids)
313 })
314 }
315}
316
317impl GraphCipherInsert for Rational {
318 type Lit = f64;
319 type Val = Self;
320
321 fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
322 with_fhe_ctx(|ctx| {
323 let lit = Self::try_from(lit).unwrap();
324
325 let lit_num =
326 ctx.add_plaintext_literal(lit.num.try_into_plaintext(&ctx.data).unwrap().inner);
327
328 let lit_den =
329 ctx.add_plaintext_literal(lit.den.try_into_plaintext(&ctx.data).unwrap().inner);
330
331 FheProgramNode::new(&[lit_num, lit_den])
332 })
333 }
334}
335
336impl GraphCipherConstAdd for Rational {
337 type Left = Self;
338 type Right = f64;
339
340 fn graph_cipher_const_add(
341 a: FheProgramNode<Cipher<Self::Left>>,
342 b: Self::Right,
343 ) -> FheProgramNode<Cipher<Self::Left>> {
344 let lit = Self::graph_cipher_insert(b);
345 with_fhe_ctx(|ctx| {
346 let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
348 let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
349
350 let den_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
352
353 let ids = [ctx.add_addition(num_a_2, num_b_2), den_2];
354
355 FheProgramNode::new(&ids)
356 })
357 }
358}
359
360impl GraphCipherSub for Rational {
361 type Left = Self;
362 type Right = Self;
363
364 fn graph_cipher_sub(
365 a: FheProgramNode<Cipher<Self::Left>>,
366 b: FheProgramNode<Cipher<Self::Right>>,
367 ) -> FheProgramNode<Cipher<Self::Left>> {
368 with_fhe_ctx(|ctx| {
369 let num_a_2 = ctx.add_multiplication(a.ids[0], b.ids[1]);
371 let num_b_2 = ctx.add_multiplication(a.ids[1], b.ids[0]);
372
373 let den_2 = ctx.add_multiplication(a.ids[1], b.ids[1]);
375
376 let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
377
378 FheProgramNode::new(&ids)
379 })
380 }
381}
382
383impl GraphCipherPlainSub for Rational {
384 type Left = Self;
385 type Right = Self;
386
387 fn graph_cipher_plain_sub(
388 a: FheProgramNode<Cipher<Self::Left>>,
389 b: FheProgramNode<Self::Right>,
390 ) -> FheProgramNode<Cipher<Self::Left>> {
391 with_fhe_ctx(|ctx| {
392 let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
394 let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
395
396 let den_2 = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
398
399 let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
400
401 FheProgramNode::new(&ids)
402 })
403 }
404}
405
406impl GraphPlainCipherSub for Rational {
407 type Left = Self;
408 type Right = Self;
409
410 fn graph_plain_cipher_sub(
411 a: FheProgramNode<Self::Left>,
412 b: FheProgramNode<Cipher<Self::Right>>,
413 ) -> FheProgramNode<Cipher<Self::Left>> {
414 with_fhe_ctx(|ctx| {
415 let num_a_2 = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
417 let num_b_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
418
419 let den_2 = ctx.add_multiplication_plaintext(b.ids[1], a.ids[1]);
421
422 let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
423
424 FheProgramNode::new(&ids)
425 })
426 }
427}
428
429impl GraphCipherConstSub for Rational {
430 type Left = Self;
431 type Right = f64;
432
433 fn graph_cipher_const_sub(
434 a: FheProgramNode<Cipher<Self::Left>>,
435 b: Self::Right,
436 ) -> FheProgramNode<Cipher<Self::Left>> {
437 let lit = Self::graph_cipher_insert(b);
438 with_fhe_ctx(|ctx| {
439 let num_a_2 = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
441 let num_b_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
442
443 let den_2 = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
445
446 let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
447
448 FheProgramNode::new(&ids)
449 })
450 }
451}
452
453impl GraphConstCipherSub for Rational {
454 type Left = f64;
455 type Right = Self;
456
457 fn graph_const_cipher_sub(
458 a: Self::Left,
459 b: FheProgramNode<Cipher<Self::Right>>,
460 ) -> FheProgramNode<Cipher<Self::Right>> {
461 let lit = Self::graph_cipher_insert(a);
462 with_fhe_ctx(|ctx| {
463 let num_b_2 = ctx.add_multiplication_plaintext(b.ids[0], lit.ids[1]);
465 let num_a_2 = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[0]);
466
467 let den_2 = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[1]);
469
470 let ids = [ctx.add_subtraction(num_a_2, num_b_2), den_2];
471
472 FheProgramNode::new(&ids)
473 })
474 }
475}
476
477impl GraphCipherMul for Rational {
478 type Left = Self;
479 type Right = Self;
480
481 fn graph_cipher_mul(
482 a: FheProgramNode<Cipher<Self::Left>>,
483 b: FheProgramNode<Cipher<Self::Right>>,
484 ) -> FheProgramNode<Cipher<Self::Left>> {
485 with_fhe_ctx(|ctx| {
486 let mul_num = ctx.add_multiplication(a.ids[0], b.ids[0]);
487 let mul_den = ctx.add_multiplication(a.ids[1], b.ids[1]);
488
489 let ids = [mul_num, mul_den];
490
491 FheProgramNode::new(&ids)
492 })
493 }
494}
495
496impl GraphCipherPlainMul for Rational {
497 type Left = Self;
498 type Right = Self;
499
500 fn graph_cipher_plain_mul(
501 a: FheProgramNode<Cipher<Self::Left>>,
502 b: FheProgramNode<Self::Right>,
503 ) -> FheProgramNode<Cipher<Self::Left>> {
504 with_fhe_ctx(|ctx| {
505 let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
506 let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[1]);
507
508 let ids = [mul_num, mul_den];
509
510 FheProgramNode::new(&ids)
511 })
512 }
513}
514
515impl GraphCipherConstMul for Rational {
516 type Left = Self;
517 type Right = f64;
518
519 fn graph_cipher_const_mul(
520 a: FheProgramNode<Cipher<Self::Left>>,
521 b: Self::Right,
522 ) -> FheProgramNode<Cipher<Self::Left>> {
523 let lit = Self::graph_cipher_insert(b);
524 with_fhe_ctx(|ctx| {
525 let mul_num = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
526 let mul_den = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[1]);
527
528 let ids = [mul_num, mul_den];
529
530 FheProgramNode::new(&ids)
531 })
532 }
533}
534
535impl GraphCipherDiv for Rational {
536 type Left = Self;
537 type Right = Self;
538
539 fn graph_cipher_div(
540 a: FheProgramNode<Cipher<Self::Left>>,
541 b: FheProgramNode<Cipher<Self::Right>>,
542 ) -> FheProgramNode<Cipher<Self::Left>> {
543 with_fhe_ctx(|ctx| {
544 let mul_num = ctx.add_multiplication(a.ids[0], b.ids[1]);
545 let mul_den = ctx.add_multiplication(a.ids[1], b.ids[0]);
546
547 let ids = [mul_num, mul_den];
548
549 FheProgramNode::new(&ids)
550 })
551 }
552}
553
554impl GraphCipherPlainDiv for Rational {
555 type Left = Self;
556 type Right = Self;
557
558 fn graph_cipher_plain_div(
559 a: FheProgramNode<Cipher<Self::Left>>,
560 b: FheProgramNode<Self::Right>,
561 ) -> FheProgramNode<Cipher<Self::Left>> {
562 with_fhe_ctx(|ctx| {
563 let mul_num = ctx.add_multiplication_plaintext(a.ids[0], b.ids[1]);
564 let mul_den = ctx.add_multiplication_plaintext(a.ids[1], b.ids[0]);
565
566 let ids = [mul_num, mul_den];
567
568 FheProgramNode::new(&ids)
569 })
570 }
571}
572
573impl GraphPlainCipherDiv for Rational {
574 type Left = Self;
575 type Right = Self;
576
577 fn graph_plain_cipher_div(
578 a: FheProgramNode<Self::Left>,
579 b: FheProgramNode<Cipher<Self::Right>>,
580 ) -> FheProgramNode<Cipher<Self::Left>> {
581 with_fhe_ctx(|ctx| {
582 let mul_num = ctx.add_multiplication_plaintext(b.ids[1], a.ids[0]);
583 let mul_den = ctx.add_multiplication_plaintext(b.ids[0], a.ids[1]);
584
585 let ids = [mul_num, mul_den];
586
587 FheProgramNode::new(&ids)
588 })
589 }
590}
591
592impl GraphCipherConstDiv for Rational {
593 type Left = Self;
594 type Right = f64;
595
596 fn graph_cipher_const_div(
597 a: FheProgramNode<Cipher<Self::Left>>,
598 b: Self::Right,
599 ) -> FheProgramNode<Cipher<Self::Left>> {
600 let lit = Self::graph_cipher_insert(b);
601 with_fhe_ctx(|ctx| {
602 let mul_num = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[1]);
603 let mul_den = ctx.add_multiplication_plaintext(a.ids[1], lit.ids[0]);
604
605 let ids = [mul_num, mul_den];
606
607 FheProgramNode::new(&ids)
608 })
609 }
610}
611
612impl GraphConstCipherDiv for Rational {
613 type Left = f64;
614 type Right = Self;
615
616 fn graph_const_cipher_div(
617 a: Self::Left,
618 b: FheProgramNode<Cipher<Self::Right>>,
619 ) -> FheProgramNode<Cipher<Self::Right>> {
620 let lit = Self::graph_cipher_insert(a);
621 with_fhe_ctx(|ctx| {
622 let mul_num = ctx.add_multiplication_plaintext(b.ids[1], lit.ids[0]);
623 let mul_den = ctx.add_multiplication_plaintext(b.ids[0], lit.ids[1]);
624
625 let ids = [mul_num, mul_den];
626
627 FheProgramNode::new(&ids)
628 })
629 }
630}
631
632impl GraphCipherNeg for Rational {
633 type Val = Self;
634
635 fn graph_cipher_neg(a: FheProgramNode<Cipher<Self::Val>>) -> FheProgramNode<Cipher<Self::Val>> {
636 with_fhe_ctx(|ctx| {
637 let neg = ctx.add_negate(a.ids[0]);
638 let ids = [neg, a.ids[1]];
639
640 FheProgramNode::new(&ids)
641 })
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Rational` from a float, panicking if the conversion fails.
    fn rational(value: f64) -> Rational {
        Rational::try_from(value).unwrap()
    }

    /// Addition works for rational/rational and both mixed float orders.
    #[test]
    fn can_add_non_fhe() {
        let (a, b) = (rational(5.), rational(10.));
        let expected = rational(15.);

        assert_eq!(a + b, expected);
        assert_eq!(a + 10., expected);
        assert_eq!(10. + a, expected);
    }

    /// Multiplication works for rational/rational and both mixed float orders.
    #[test]
    fn can_mul_non_fhe() {
        let (a, b) = (rational(5.), rational(10.));
        let expected = rational(50.);

        assert_eq!(a * b, expected);
        assert_eq!(a * 10., expected);
        assert_eq!(10. * a, expected);
    }

    /// Subtraction respects operand order in all three operand combinations.
    #[test]
    fn can_sub_non_fhe() {
        let (a, b) = (rational(5.), rational(10.));

        assert_eq!(a - b, rational(-5.));
        assert_eq!(a - 10., rational(-5.));
        assert_eq!(10. - a, rational(5.));
    }

    /// Division respects operand order in all three operand combinations.
    #[test]
    fn can_div_non_fhe() {
        let (a, b) = (rational(5.), rational(10.));

        assert_eq!(a / b, rational(0.5));
        assert_eq!(a / 10., rational(0.5));
        assert_eq!(10. / a, rational(2.));
    }

    /// Negation flips the sign of the value.
    #[test]
    fn can_neg_non_fhe() {
        let a = rational(5.);

        assert_eq!(-a, rational(-5.));
    }
}