1use seal_fhe::Plaintext as SealPlaintext;
2
3use crate::{
4 fhe::{with_fhe_ctx, FheContextOps},
5 types::{
6 ops::{
7 GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
8 GraphCipherConstSub, GraphCipherInsert, GraphCipherMul, GraphCipherNeg,
9 GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub,
10 GraphConstCipherSub, GraphPlainCipherSub,
11 },
12 Cipher,
13 },
14};
15use crate::{
16 types::{intern::FheProgramNode, BfvType, FheType, Type, Version},
17 FheProgramInputTrait, Params, WithContext,
18};
19
20use sunscreen_runtime::{
21 InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName,
22 TypeNameInstance,
23};
24
25use std::ops::*;
26
/// A real number held as an `f64` that is encoded into FHE plaintexts with the
/// first `INT_BITS` polynomial coefficients reserved for the integer part and the
/// remaining coefficients used for fractional bits.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Fractional<const INT_BITS: usize> {
    /// The plain (unencrypted) value.
    val: f64,
}
164
165impl<const INT_BITS: usize> std::ops::Deref for Fractional<INT_BITS> {
166 type Target = f64;
167
168 fn deref(&self) -> &Self::Target {
169 &self.val
170 }
171}
172
173impl<const INT_BITS: usize> NumCiphertexts for Fractional<INT_BITS> {
174 const NUM_CIPHERTEXTS: usize = 1;
175}
176
177impl<const INT_BITS: usize> FheProgramInputTrait for Fractional<INT_BITS> {}
178
179impl<const INT_BITS: usize> Default for Fractional<INT_BITS> {
180 fn default() -> Self {
181 Self::from(0.0)
182 }
183}
184
185impl<const INT_BITS: usize> TypeName for Fractional<INT_BITS> {
186 fn type_name() -> Type {
187 let version = env!("CARGO_PKG_VERSION");
188
189 Type {
190 name: format!("sunscreen::types::Fractional<{}>", INT_BITS),
191 version: Version::parse(version).expect("Crate version is not a valid semver"),
192 is_encrypted: false,
193 }
194 }
195}
196impl<const INT_BITS: usize> TypeNameInstance for Fractional<INT_BITS> {
197 fn type_name_instance(&self) -> Type {
198 Self::type_name()
199 }
200}
201
202impl<const INT_BITS: usize> FheType for Fractional<INT_BITS> {}
203impl<const INT_BITS: usize> BfvType for Fractional<INT_BITS> {}
204
205impl<const INT_BITS: usize> Fractional<INT_BITS> {}
206
207impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
208 type Left = Fractional<INT_BITS>;
209 type Right = Fractional<INT_BITS>;
210
211 fn graph_cipher_add(
212 a: FheProgramNode<Cipher<Self::Left>>,
213 b: FheProgramNode<Cipher<Self::Right>>,
214 ) -> FheProgramNode<Cipher<Self::Left>> {
215 with_fhe_ctx(|ctx| {
216 let n = ctx.add_addition(a.ids[0], b.ids[0]);
217
218 FheProgramNode::new(&[n])
219 })
220 }
221}
222
223impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
224 type Left = Fractional<INT_BITS>;
225 type Right = Fractional<INT_BITS>;
226
227 fn graph_cipher_plain_add(
228 a: FheProgramNode<Cipher<Self::Left>>,
229 b: FheProgramNode<Self::Right>,
230 ) -> FheProgramNode<Cipher<Self::Left>> {
231 with_fhe_ctx(|ctx| {
232 let n = ctx.add_addition_plaintext(a.ids[0], b.ids[0]);
233
234 FheProgramNode::new(&[n])
235 })
236 }
237}
238
239impl<const INT_BITS: usize> GraphCipherInsert for Fractional<INT_BITS> {
240 type Lit = f64;
241 type Val = Self;
242
243 fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
244 with_fhe_ctx(|ctx| {
245 let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap();
246 let lit = ctx.add_plaintext_literal(lit.inner);
247
248 FheProgramNode::new(&[lit])
249 })
250 }
251}
252
253impl<const INT_BITS: usize> GraphCipherConstAdd for Fractional<INT_BITS> {
254 type Left = Fractional<INT_BITS>;
255 type Right = f64;
256
257 fn graph_cipher_const_add(
258 a: FheProgramNode<Cipher<Self::Left>>,
259 b: Self::Right,
260 ) -> FheProgramNode<Cipher<Self::Left>> {
261 let lit = Self::graph_cipher_insert(b);
262 with_fhe_ctx(|ctx| {
263 let n = ctx.add_addition_plaintext(a.ids[0], lit.ids[0]);
264 FheProgramNode::new(&[n])
265 })
266 }
267}
268
269impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
270 type Left = Fractional<INT_BITS>;
271 type Right = Fractional<INT_BITS>;
272
273 fn graph_cipher_sub(
274 a: FheProgramNode<Cipher<Self::Left>>,
275 b: FheProgramNode<Cipher<Self::Right>>,
276 ) -> FheProgramNode<Cipher<Self::Left>> {
277 with_fhe_ctx(|ctx| {
278 let n = ctx.add_subtraction(a.ids[0], b.ids[0]);
279
280 FheProgramNode::new(&[n])
281 })
282 }
283}
284
285impl<const INT_BITS: usize> GraphCipherPlainSub for Fractional<INT_BITS> {
286 type Left = Fractional<INT_BITS>;
287 type Right = Fractional<INT_BITS>;
288
289 fn graph_cipher_plain_sub(
290 a: FheProgramNode<Cipher<Self::Left>>,
291 b: FheProgramNode<Self::Right>,
292 ) -> FheProgramNode<Cipher<Self::Left>> {
293 with_fhe_ctx(|ctx| {
294 let n = ctx.add_subtraction_plaintext(a.ids[0], b.ids[0]);
295
296 FheProgramNode::new(&[n])
297 })
298 }
299}
300
301impl<const INT_BITS: usize> GraphPlainCipherSub for Fractional<INT_BITS> {
302 type Left = Fractional<INT_BITS>;
303 type Right = Fractional<INT_BITS>;
304
305 fn graph_plain_cipher_sub(
306 a: FheProgramNode<Self::Left>,
307 b: FheProgramNode<Cipher<Self::Right>>,
308 ) -> FheProgramNode<Cipher<Self::Left>> {
309 with_fhe_ctx(|ctx| {
310 let n = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
311 let n = ctx.add_negate(n);
312
313 FheProgramNode::new(&[n])
314 })
315 }
316}
317
318impl<const INT_BITS: usize> GraphCipherConstSub for Fractional<INT_BITS> {
319 type Left = Fractional<INT_BITS>;
320 type Right = f64;
321
322 fn graph_cipher_const_sub(
323 a: FheProgramNode<Cipher<Self::Left>>,
324 b: Self::Right,
325 ) -> FheProgramNode<Cipher<Self::Left>> {
326 let lit = Self::graph_cipher_insert(b);
327 with_fhe_ctx(|ctx| {
328 let n = ctx.add_subtraction_plaintext(a.ids[0], lit.ids[0]);
329 FheProgramNode::new(&[n])
330 })
331 }
332}
333
334impl<const INT_BITS: usize> GraphConstCipherSub for Fractional<INT_BITS> {
335 type Left = f64;
336 type Right = Fractional<INT_BITS>;
337
338 fn graph_const_cipher_sub(
339 a: Self::Left,
340 b: FheProgramNode<Cipher<Self::Right>>,
341 ) -> FheProgramNode<Cipher<Self::Right>> {
342 let lit = Self::graph_cipher_insert(a);
343 with_fhe_ctx(|ctx| {
344 let n = ctx.add_subtraction_plaintext(b.ids[0], lit.ids[0]);
345 let n = ctx.add_negate(n);
346
347 FheProgramNode::new(&[n])
348 })
349 }
350}
351
352impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
353 type Left = Fractional<INT_BITS>;
354 type Right = Fractional<INT_BITS>;
355
356 fn graph_cipher_mul(
357 a: FheProgramNode<Cipher<Self::Left>>,
358 b: FheProgramNode<Cipher<Self::Right>>,
359 ) -> FheProgramNode<Cipher<Self::Left>> {
360 with_fhe_ctx(|ctx| {
361 let n = ctx.add_multiplication(a.ids[0], b.ids[0]);
362
363 FheProgramNode::new(&[n])
364 })
365 }
366}
367
368impl<const INT_BITS: usize> GraphCipherPlainMul for Fractional<INT_BITS> {
369 type Left = Fractional<INT_BITS>;
370 type Right = Fractional<INT_BITS>;
371
372 fn graph_cipher_plain_mul(
373 a: FheProgramNode<Cipher<Self::Left>>,
374 b: FheProgramNode<Self::Right>,
375 ) -> FheProgramNode<Cipher<Self::Left>> {
376 with_fhe_ctx(|ctx| {
377 let n = ctx.add_multiplication_plaintext(a.ids[0], b.ids[0]);
378
379 FheProgramNode::new(&[n])
380 })
381 }
382}
383
384impl<const INT_BITS: usize> GraphCipherConstMul for Fractional<INT_BITS> {
385 type Left = Fractional<INT_BITS>;
386 type Right = f64;
387
388 fn graph_cipher_const_mul(
389 a: FheProgramNode<Cipher<Self::Left>>,
390 b: Self::Right,
391 ) -> FheProgramNode<Cipher<Self::Left>> {
392 let lit = Self::graph_cipher_insert(b);
393 with_fhe_ctx(|ctx| {
394 let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
395 FheProgramNode::new(&[n])
396 })
397 }
398}
399
400impl<const INT_BITS: usize> GraphCipherConstDiv for Fractional<INT_BITS> {
401 type Left = Fractional<INT_BITS>;
402 type Right = f64;
403
404 fn graph_cipher_const_div(
405 a: FheProgramNode<Cipher<Self::Left>>,
406 b: f64,
407 ) -> FheProgramNode<Cipher<Self::Left>> {
408 let lit = Self::graph_cipher_insert(1. / b);
409 with_fhe_ctx(|ctx| {
410 let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
411 FheProgramNode::new(&[n])
412 })
413 }
414}
415
416impl<const INT_BITS: usize> GraphCipherNeg for Fractional<INT_BITS> {
417 type Val = Fractional<INT_BITS>;
418
419 fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
420 with_fhe_ctx(|ctx| {
421 let n = ctx.add_negate(a.ids[0]);
422
423 FheProgramNode::new(&[n])
424 })
425 }
426}
427
428impl<const INT_BITS: usize> TryIntoPlaintext for Fractional<INT_BITS> {
429 fn try_into_plaintext(
430 &self,
431 params: &Params,
432 ) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
433 if self.val.is_nan() {
434 return Err(sunscreen_runtime::Error::fhe_type_error("Value is NaN."));
435 }
436
437 if self.val.is_infinite() {
438 return Err(sunscreen_runtime::Error::fhe_type_error(
439 "Value is infinite.",
440 ));
441 }
442
443 let mut seal_plaintext = SealPlaintext::new()?;
444 let n = params.lattice_dimension as usize;
445 seal_plaintext.resize(n);
446
447 if self.val.is_subnormal() || self.val == 0.0 {
449 return Ok(Plaintext {
450 data_type: self.type_name_instance(),
451 inner: InnerPlaintext::Seal(vec![WithContext {
452 params: params.clone(),
453 data: seal_plaintext,
454 }]),
455 });
456 }
457
458 let as_u64: u64 = self.val.to_bits();
466
467 let sign_mask = 0x1 << 63;
468 let mantissa_mask = 0xFFFFFFFFFFFFF;
469 let exp_mask = !mantissa_mask & !sign_mask;
470
471 let mantissa = as_u64 & mantissa_mask | (mantissa_mask + 1);
473 let exp = as_u64 & exp_mask;
474 let power = (exp >> (f64::MANTISSA_DIGITS - 1)) as i64 - 1023;
475 let sign = (as_u64 & sign_mask) >> 63;
476
477 if power + 1 > INT_BITS as i64 {
478 return Err(sunscreen_runtime::Error::fhe_type_error("Out of range"));
479 }
480
481 for i in 0..f64::MANTISSA_DIGITS {
482 let bit_value = (mantissa & 0x1 << i) >> i;
483 let bit_power = power - (f64::MANTISSA_DIGITS - i - 1) as i64;
484
485 let coeff_index = if bit_power >= 0 {
486 bit_power as usize
487 } else {
488 (n as i64 + bit_power) as usize
489 };
490
491 let sign = if bit_power >= 0 { sign } else { !sign & 0x1 };
493
494 let coeff = if sign == 0 {
495 bit_value
496 } else if bit_value > 0 {
497 params.plain_modulus - bit_value
498 } else {
499 0
500 };
501
502 seal_plaintext.set_coefficient(coeff_index, coeff);
503 }
504
505 Ok(Plaintext {
506 data_type: self.type_name_instance(),
507 inner: InnerPlaintext::Seal(vec![WithContext {
508 params: params.clone(),
509 data: seal_plaintext,
510 }]),
511 })
512 }
513}
514
515impl<const INT_BITS: usize> TryFromPlaintext for Fractional<INT_BITS> {
516 fn try_from_plaintext(
517 plaintext: &Plaintext,
518 params: &Params,
519 ) -> std::result::Result<Self, sunscreen_runtime::Error> {
520 let val = match &plaintext.inner {
521 InnerPlaintext::Seal(p) => {
522 if p.len() != 1 {
523 return Err(sunscreen_runtime::Error::IncorrectCiphertextCount);
524 }
525
526 let mut val = 0.0f64;
527 let n = params.lattice_dimension as usize;
528
529 let len = p[0].len();
530
531 let negative_cutoff = (params.plain_modulus + 1) / 2;
532
533 for i in 0..usize::min(n, len) {
534 let power = if i < INT_BITS {
535 i as i64
536 } else {
537 i as i64 - n as i64
538 };
539
540 let coeff = p[0].get_coefficient(i);
541
542 let sign = if power >= 0 { 1f64 } else { -1f64 };
544
545 if coeff < negative_cutoff {
546 val += sign * coeff as f64 * (power as f64).exp2();
547 } else {
548 val -= sign * (params.plain_modulus - coeff) as f64 * (power as f64).exp2();
549 };
550 }
551
552 Self { val }
553 }
554 };
555
556 Ok(val)
557 }
558}
559
560impl<const INT_BITS: usize> From<f64> for Fractional<INT_BITS> {
561 fn from(val: f64) -> Self {
562 Self { val }
563 }
564}
565
566impl<const INT_BITS: usize> From<Fractional<INT_BITS>> for f64 {
567 fn from(frac: Fractional<INT_BITS>) -> Self {
568 frac.val
569 }
570}
571
572impl<const INT_BITS: usize> Add for Fractional<INT_BITS> {
573 type Output = Self;
574
575 fn add(self, rhs: Self) -> Self {
576 Self {
577 val: self.val + rhs.val,
578 }
579 }
580}
581
582impl<const INT_BITS: usize> Add<f64> for Fractional<INT_BITS> {
583 type Output = Self;
584
585 fn add(self, rhs: f64) -> Self {
586 Self {
587 val: self.val + rhs,
588 }
589 }
590}
591
592impl<const INT_BITS: usize> Add<Fractional<INT_BITS>> for f64 {
593 type Output = Fractional<INT_BITS>;
594
595 fn add(self, rhs: Fractional<INT_BITS>) -> Self::Output {
596 Fractional {
597 val: self + rhs.val,
598 }
599 }
600}
601
602impl<const INT_BITS: usize> Mul for Fractional<INT_BITS> {
603 type Output = Self;
604
605 fn mul(self, rhs: Self) -> Self {
606 Self {
607 val: self.val * rhs.val,
608 }
609 }
610}
611
612impl<const INT_BITS: usize> Mul<f64> for Fractional<INT_BITS> {
613 type Output = Self;
614
615 fn mul(self, rhs: f64) -> Self {
616 Self {
617 val: self.val * rhs,
618 }
619 }
620}
621
622impl<const INT_BITS: usize> Mul<Fractional<INT_BITS>> for f64 {
623 type Output = Fractional<INT_BITS>;
624
625 fn mul(self, rhs: Fractional<INT_BITS>) -> Self::Output {
626 Fractional {
627 val: self * rhs.val,
628 }
629 }
630}
631
632impl<const INT_BITS: usize> Sub for Fractional<INT_BITS> {
633 type Output = Self;
634
635 fn sub(self, rhs: Self) -> Self {
636 Self {
637 val: self.val - rhs.val,
638 }
639 }
640}
641
642impl<const INT_BITS: usize> Sub<f64> for Fractional<INT_BITS> {
643 type Output = Self;
644
645 fn sub(self, rhs: f64) -> Self {
646 Self {
647 val: self.val - rhs,
648 }
649 }
650}
651
652impl<const INT_BITS: usize> Sub<Fractional<INT_BITS>> for f64 {
653 type Output = Fractional<INT_BITS>;
654
655 fn sub(self, rhs: Fractional<INT_BITS>) -> Self::Output {
656 Fractional {
657 val: self - rhs.val,
658 }
659 }
660}
661
662impl<const INT_BITS: usize> Div<f64> for Fractional<INT_BITS> {
663 type Output = Self;
664
665 fn div(self, rhs: f64) -> Self {
666 Self {
667 val: self.val / rhs,
668 }
669 }
670}
671
672impl<const INT_BITS: usize> Neg for Fractional<INT_BITS> {
673 type Output = Self;
674
675 fn neg(self) -> Self {
676 Self { val: -self.val }
677 }
678}
679
#[cfg(test)]
mod tests {

    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::{SchemeType, SecurityLevel};
    use float_cmp::ApproxEq;

    /// Small BFV parameter set used by the encode/decode tests.
    fn test_params() -> Params {
        Params {
            lattice_dimension: 4096,
            plain_modulus: 1_000_000,
            coeff_modulus: vec![],
            scheme_type: SchemeType::Bfv,
            security_level: SecurityLevel::TC128,
        }
    }

    #[test]
    fn can_encode_decode_fractional() {
        let params = test_params();

        // Encoding followed by decoding must reproduce the exact same f64.
        let round_trip = |x: f64| {
            let f_1 = Fractional::<64>::from(x);
            let pt = f_1.try_into_plaintext(&params).unwrap();
            let f_2 = Fractional::<64>::try_from_plaintext(&pt, &params).unwrap();

            assert_eq!(f_1, f_2);
        };

        round_trip(3.14);
        round_trip(0.0);
        round_trip(-0.0);
        round_trip(1.0);
        round_trip(5.8125);
        round_trip(6.0);
        round_trip(6.6);
        round_trip(1.2);
        round_trip(1e13);
        round_trip(0.0000000005);
        round_trip(-1.0);
        round_trip(-5.875);
        round_trip(-6.0);
        round_trip(-6.6);
        round_trip(-1.2);
        round_trip(-1e13);
        round_trip(-0.0000000005);
    }

    #[test]
    fn rejects_unencodable_values() {
        let params = test_params();

        // NaN, infinities, and values whose integer part needs more than 64 bits.
        for &x in &[f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 1e20, -1e20] {
            assert!(Fractional::<64>::from(x).try_into_plaintext(&params).is_err());
        }
    }

    #[test]
    fn can_add_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a + b).approx_eq(4.64, (0.0, 1)));
        assert!((3.14 + b).approx_eq(4.64, (0.0, 1)));
        assert!((a + 1.5).approx_eq(4.64, (0.0, 1)));
    }

    #[test]
    fn can_mul_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a * b).approx_eq(4.71, (0.0, 1)));
        assert!((3.14 * b).approx_eq(4.71, (0.0, 1)));
        assert!((a * 1.5).approx_eq(4.71, (0.0, 1)));
    }

    #[test]
    fn can_sub_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a - b).approx_eq(1.64, (0.0, 1)));
        assert!((3.14 - b).approx_eq(1.64, (0.0, 1)));
        assert!((a - 1.5).approx_eq(1.64, (0.0, 1)));
    }

    #[test]
    fn can_div_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        assert!((a / 1.5).approx_eq(3.14 / 1.5, (0.0, 1)));
    }

    #[test]
    fn can_neg_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        assert_eq!(-a, (-3.14).into());
    }
}