1#![allow(dead_code)]
3
4use rand::prelude::*;
39use rand::rngs::StdRng;
40use serde::{Deserialize, Serialize};
41use sha2::{Digest, Sha256};
42use zeroize::{Zeroize, ZeroizeOnDrop};
43use spine_neural::{
44 Activation, DenseLayer, MirasNeuralEncoder, MirasVariant, MultiHeadAttention,
45 NeuralEncoderConfig, TitansMemory,
46};
47use std::collections::VecDeque;
48use subtle::ConstantTimeEq;
49
50use ml_kem::kem::{Decapsulate, Encapsulate, EncapsulationKey, DecapsulationKey};
52use ml_kem::{MlKem512, MlKem768, MlKem1024, KemCore, EncodedSizeUser, Encoded,
53 MlKem512Params, MlKem768Params, MlKem1024Params};
54
55use aes_gcm::aead::Aead;
57use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
58use hkdf::Hkdf;
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PositionalEncoding {
68 max_len: usize,
69 embed_dim: usize,
70 encodings: Vec<Vec<f32>>,
71}
72
73impl PositionalEncoding {
74 pub fn new(max_len: usize, embed_dim: usize) -> Self {
75 let encodings: Vec<Vec<f32>> = (0..max_len)
77 .map(|pos| {
78 (0..embed_dim)
79 .map(|i| {
80 let angle = pos as f32
81 / (10000.0_f32).powf(2.0 * (i / 2) as f32 / embed_dim as f32);
82 if i % 2 == 0 {
83 angle.sin()
84 } else {
85 angle.cos()
86 }
87 })
88 .collect()
89 })
90 .collect();
91
92 Self {
93 max_len,
94 embed_dim,
95 encodings,
96 }
97 }
98
99 pub fn get(&self, position: usize) -> &[f32] {
100 if self.encodings.is_empty() {
102 return &[];
103 }
104 &self.encodings[position.min(self.max_len.saturating_sub(1))]
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct LayerNorm {
111 dim: usize,
112 gamma: Vec<f32>,
113 beta: Vec<f32>,
114 eps: f32,
115}
116
117impl LayerNorm {
118 pub fn new(dim: usize) -> Self {
119 Self {
120 dim,
121 gamma: vec![1.0; dim],
122 beta: vec![0.0; dim],
123 eps: 1e-5,
124 }
125 }
126
127 pub fn forward(&self, x: &[f32]) -> Vec<f32> {
128 if x.is_empty() {
130 return Vec::new();
131 }
132 let n = x.len() as f32;
133 let mean: f32 = x.iter().sum::<f32>() / n;
134 let var: f32 = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
135 let std = (var + self.eps).sqrt();
136
137 x.iter()
138 .enumerate()
139 .map(|(i, &v)| self.gamma[i % self.dim] * (v - mean) / std + self.beta[i % self.dim])
140 .collect()
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct FeedForward {
147 linear1: DenseLayer,
148 linear2: DenseLayer,
149}
150
151impl FeedForward {
152 pub fn new(embed_dim: usize, ff_dim: usize, rng: &mut StdRng) -> Self {
153 Self {
154 linear1: DenseLayer::new(embed_dim, ff_dim, Activation::GELU, rng),
155 linear2: DenseLayer::new(ff_dim, embed_dim, Activation::None, rng),
156 }
157 }
158
159 pub fn forward(&mut self, x: &[f32]) -> Vec<f32> {
160 let hidden = self.linear1.forward(x);
161 self.linear2.forward(&hidden)
162 }
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct TitansBlock {
168 memory: TitansMemory,
170 attention: MultiHeadAttention,
172 ff: FeedForward,
173 norm1: LayerNorm,
174 norm2: LayerNorm,
175 norm3: LayerNorm,
176 embed_dim: usize,
177}
178
179impl TitansBlock {
180 pub fn new(
181 embed_dim: usize,
182 num_heads: usize,
183 ff_dim: usize,
184 memory_size: usize,
185 rng: &mut StdRng,
186 ) -> Self {
187 Self {
188 memory: TitansMemory::new(embed_dim, embed_dim, memory_size, rng),
189 attention: MultiHeadAttention::new(embed_dim, num_heads, rng),
190 ff: FeedForward::new(embed_dim, ff_dim, rng),
191 norm1: LayerNorm::new(embed_dim),
192 norm2: LayerNorm::new(embed_dim),
193 norm3: LayerNorm::new(embed_dim),
194 embed_dim,
195 }
196 }
197
198 pub fn forward(&mut self, sequence: &[Vec<f32>]) -> Vec<f32> {
199 if sequence.is_empty() {
200 return vec![0.0; self.embed_dim];
201 }
202
203 let last = &sequence[sequence.len() - 1];
204
205 let memory_out = self.memory.forward(last);
207 let residual1: Vec<f32> = memory_out
208 .iter()
209 .zip(last.iter())
210 .map(|(m, l)| m + l)
211 .collect();
212 let normed1 = self.norm1.forward(&residual1);
213
214 let attended = self.attention.forward(sequence);
216 let residual2: Vec<f32> = attended
217 .iter()
218 .zip(normed1.iter())
219 .map(|(a, n)| a + n)
220 .collect();
221 let normed2 = self.norm2.forward(&residual2);
222
223 let ff_out = self.ff.forward(&normed2);
225 let residual3: Vec<f32> = ff_out
226 .iter()
227 .zip(normed2.iter())
228 .map(|(f, n)| f + n)
229 .collect();
230 self.norm3.forward(&residual3)
231 }
232
233 pub fn get_surprise(&self) -> f32 {
235 self.memory.get_surprise()
236 }
237
238 pub fn reset_memory(&mut self) {
240 self.memory.reset_state();
241 }
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct ByteTokenizer {
247 embed_dim: usize,
248 embeddings: Vec<Vec<f32>>, }
250
251impl ByteTokenizer {
252 pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
253 let scale = (1.0 / embed_dim as f32).sqrt();
254 let embeddings: Vec<Vec<f32>> = (0..256)
255 .map(|_| {
256 (0..embed_dim)
257 .map(|_| rng.gen::<f32>() * 2.0 * scale - scale)
258 .collect()
259 })
260 .collect();
261
262 Self {
263 embed_dim,
264 embeddings,
265 }
266 }
267
268 pub fn encode(&self, byte: u8) -> &[f32] {
269 &self.embeddings[byte as usize]
270 }
271
272 pub fn encode_sequence(&self, bytes: &[u8]) -> Vec<Vec<f32>> {
273 bytes
274 .iter()
275 .map(|&b| self.embeddings[b as usize].clone())
276 .collect()
277 }
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct OutputProjection {
283 weights: Vec<Vec<f32>>, temperature: f32,
285}
286
287impl OutputProjection {
288 pub fn new(embed_dim: usize, rng: &mut StdRng) -> Self {
289 let scale = (1.0 / embed_dim as f32).sqrt();
290 let weights: Vec<Vec<f32>> = (0..256)
291 .map(|_| {
292 (0..embed_dim)
293 .map(|_| rng.gen::<f32>() * 2.0 * scale - scale)
294 .collect()
295 })
296 .collect();
297
298 Self {
299 weights,
300 temperature: 1.0,
301 }
302 }
303
304 pub fn set_temperature(&mut self, temp: f32) {
305 self.temperature = temp.max(0.01);
306 }
307
308 pub fn forward(&self, hidden: &[f32]) -> Vec<f32> {
309 let mut logits = vec![0.0; 256];
310 for (i, w) in self.weights.iter().enumerate() {
311 for (j, &h) in hidden.iter().enumerate() {
312 logits[i] += w[j] * h;
313 }
314 logits[i] /= self.temperature;
315 }
316
317 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
319 let mut sum = 0.0;
320 for l in &mut logits {
321 *l = (*l - max).exp();
322 sum += *l;
323 }
324 for l in &mut logits {
325 *l /= sum;
326 }
327
328 logits
329 }
330
331 pub fn sample(&self, probs: &[f32], rng: &mut StdRng) -> u8 {
332 let mut cumsum = 0.0;
333 let r: f32 = rng.gen();
334 for (i, &p) in probs.iter().enumerate() {
335 cumsum += p;
336 if r < cumsum {
337 return i as u8;
338 }
339 }
340 255
341 }
342
343 pub fn argmax(&self, probs: &[f32]) -> u8 {
344 probs
345 .iter()
346 .enumerate()
347 .max_by(|(_, a), (_, b)| {
348 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less)
350 })
351 .map(|(i, _)| i as u8)
352 .unwrap_or(0)
353 }
354}
355
/// Byte-level next-byte predictor built from stacked [`TitansBlock`]s.
///
/// Observed bytes are embedded by `tokenizer`, offset by `positional`, and kept in a
/// sliding `context_window` holding at most `max_seq_len` entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TitansPredictor {
    tokenizer: ByteTokenizer,
    positional: PositionalEncoding,
    blocks: Vec<TitansBlock>,
    output: OutputProjection,
    embed_dim: usize,
    max_seq_len: usize,
    memory_size: usize,
    /// Embedded (token + position) vectors of the most recent bytes, oldest first.
    context_window: VecDeque<Vec<f32>>,
    /// Mean memory surprise across `blocks`, recomputed by `observe`.
    total_surprise: f32,
    /// Sampling RNG. Not serialised: a deserialised predictor restarts from `default_rng()`.
    #[serde(skip, default = "default_rng")]
    rng: StdRng,
}
379
380#[derive(Debug, Clone)]
388pub struct MirasTitansPredictor {
389 base: TitansPredictor,
391 miras_encoder: Option<MirasNeuralEncoder>,
393 active_variant: MirasVariant,
395 surprise_history: VecDeque<f32>,
397 anomaly_threshold: f32,
399 message_count: u64,
401 miras_enhanced_predictions: u64,
403 latent_dim: usize,
405}
406
407impl MirasTitansPredictor {
408 pub fn new(config: TitansConfig) -> Self {
410 let base = TitansPredictor::new(config.clone());
411
412 let encoder_config = NeuralEncoderConfig {
414 input_dim: config.embed_dim,
415 latent_dim: config.embed_dim,
416 hidden_dims: vec![config.ff_dim, config.embed_dim],
417 attention_heads: config.num_heads,
418 seed: config.seed + 1,
419 miras_variant: MirasVariant::Titans,
420 memory_tokens: config.memory_size,
421 };
422
423 let miras_encoder = Some(MirasNeuralEncoder::new(&encoder_config));
424
425 Self {
426 base,
427 miras_encoder,
428 active_variant: MirasVariant::Titans,
429 surprise_history: VecDeque::with_capacity(100),
430 anomaly_threshold: 0.5,
431 message_count: 0,
432 miras_enhanced_predictions: 0,
433 latent_dim: config.embed_dim,
434 }
435 }
436
437 pub fn new_with_variant(config: TitansConfig, variant: MirasVariant) -> Self {
439 let base = TitansPredictor::new(config.clone());
440
441 let encoder_config = NeuralEncoderConfig {
443 input_dim: config.embed_dim,
444 latent_dim: config.embed_dim,
445 hidden_dims: vec![config.ff_dim, config.embed_dim],
446 attention_heads: config.num_heads,
447 seed: config.seed + 1,
448 miras_variant: variant,
449 memory_tokens: config.memory_size,
450 };
451
452 Self {
453 base,
454 miras_encoder: Some(MirasNeuralEncoder::new(&encoder_config)),
455 active_variant: variant,
456 surprise_history: VecDeque::with_capacity(100),
457 anomaly_threshold: 0.5,
458 message_count: 0,
459 miras_enhanced_predictions: 0,
460 latent_dim: config.embed_dim,
461 }
462 }
463
464 pub fn set_anomaly_threshold(&mut self, threshold: f32) {
466 self.anomaly_threshold = threshold;
467 }
468
469 pub fn variant(&self) -> &str {
471 match self.active_variant {
472 MirasVariant::Titans => "titans",
473 MirasVariant::Yaad => "yaad",
474 MirasVariant::Moneta { .. } => "moneta",
475 MirasVariant::Memora => "memora",
476 }
477 }
478
479 pub fn anomaly_level(&self) -> f32 {
481 if self.surprise_history.is_empty() {
482 0.0
483 } else {
484 self.surprise_history.iter().sum::<f32>() / self.surprise_history.len() as f32
485 }
486 }
487
488 fn maybe_switch_variant(&mut self) {
490 let anomaly = self.anomaly_level();
491
492 let new_variant = if anomaly > self.anomaly_threshold * 2.0 {
493 MirasVariant::Yaad
495 } else if anomaly > self.anomaly_threshold {
496 MirasVariant::Memora
498 } else if self.message_count > 10000 {
499 MirasVariant::Moneta { p: 2.0 }
501 } else {
502 MirasVariant::Titans
504 };
505
506 let variant_changed = !matches!(
508 (&new_variant, &self.active_variant),
509 (MirasVariant::Titans, MirasVariant::Titans)
510 | (MirasVariant::Yaad, MirasVariant::Yaad)
511 | (MirasVariant::Moneta { .. }, MirasVariant::Moneta { .. })
512 | (MirasVariant::Memora, MirasVariant::Memora)
513 );
514
515 if variant_changed {
516 self.active_variant = new_variant;
517 }
520 }
521
522 pub fn observe(&mut self, message: &[u8]) {
524 self.base.observe(message);
526
527 let surprise = self.base.get_surprise();
529 self.surprise_history.push_back(surprise);
530 if self.surprise_history.len() > 100 {
531 self.surprise_history.pop_front();
532 }
533
534 if let Some(ref mut encoder) = self.miras_encoder {
536 let _latent = encoder.encode(message);
538 self.miras_enhanced_predictions += 1;
539 }
540
541 self.message_count += 1;
542
543 self.maybe_switch_variant();
545 }
546
547 pub fn predict_next(&mut self) -> (u8, f32) {
549 self.base.predict_next()
550 }
551
552 pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
554 self.base.predict_sequence(length, greedy)
555 }
556
557 pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
559 self.base.verify_prediction(message)
560 }
561
562 pub fn get_surprise(&self) -> f32 {
564 self.base.get_surprise()
565 }
566
567 pub fn is_anomalous(&self, threshold: f32) -> bool {
569 self.base.is_anomalous(threshold)
570 }
571
572 pub fn get_miras_surprise(&self) -> Option<f32> {
574 self.miras_encoder.as_ref().map(|e| e.get_surprise())
575 }
576
577 pub fn get_combined_surprise(&self) -> f32 {
579 let titans = self.base.get_surprise();
580 let miras = self.get_miras_surprise().unwrap_or(0.0);
581 (titans + miras) / 2.0
582 }
583
584 pub fn reset(&mut self) {
586 self.base.reset();
587 }
588
589 pub fn reset_all(&mut self) {
591 self.base.reset_all();
592 self.surprise_history.clear();
593 self.message_count = 0;
594 if let Some(ref mut encoder) = self.miras_encoder {
595 encoder.reset();
596 }
597 }
598
599 pub fn stats(&self) -> MirasPredictorStats {
601 MirasPredictorStats {
602 message_count: self.message_count,
603 miras_enhanced_predictions: self.miras_enhanced_predictions,
604 current_variant: self.variant().to_string(),
605 anomaly_level: self.anomaly_level(),
606 titans_surprise: self.base.get_surprise(),
607 miras_surprise: self.get_miras_surprise(),
608 }
609 }
610}
611
/// Snapshot of a [`MirasTitansPredictor`]'s counters and surprise signals, as returned by
/// `MirasTitansPredictor::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirasPredictorStats {
    /// Messages passed to `observe` since construction or the last `reset_all`.
    pub message_count: u64,
    /// Observed messages that were also fed through the MIRAS encoder.
    pub miras_enhanced_predictions: u64,
    /// Name of the active variant: `"titans"`, `"yaad"`, `"moneta"` or `"memora"`.
    pub current_variant: String,
    /// Mean of the recently recorded base-predictor surprise samples.
    pub anomaly_level: f32,
    /// Current surprise of the base Titans predictor.
    pub titans_surprise: f32,
    /// Surprise reported by the MIRAS encoder, if one is present.
    pub miras_surprise: Option<f32>,
}
622
623fn default_rng() -> StdRng {
624 StdRng::seed_from_u64(42)
625}
626
627impl TitansPredictor {
628 pub fn new(config: TitansConfig) -> Self {
629 let mut rng = StdRng::seed_from_u64(config.seed);
630
631 let tokenizer = ByteTokenizer::new(config.embed_dim, &mut rng);
632 let positional = PositionalEncoding::new(config.max_seq_len, config.embed_dim);
633
634 let blocks: Vec<TitansBlock> = (0..config.num_layers)
635 .map(|_| {
636 TitansBlock::new(
637 config.embed_dim,
638 config.num_heads,
639 config.ff_dim,
640 config.memory_size,
641 &mut rng,
642 )
643 })
644 .collect();
645
646 let output = OutputProjection::new(config.embed_dim, &mut rng);
647
648 Self {
649 tokenizer,
650 positional,
651 blocks,
652 output,
653 embed_dim: config.embed_dim,
654 max_seq_len: config.max_seq_len,
655 memory_size: config.memory_size,
656 context_window: VecDeque::with_capacity(config.max_seq_len),
657 total_surprise: 0.0,
658 rng,
659 }
660 }
661
662 pub fn observe(&mut self, message: &[u8]) {
664 for &byte in message {
665 let mut embedding = self.tokenizer.encode(byte).to_vec();
666 let pos = self.context_window.len();
667 let pos_enc = self.positional.get(pos);
668 for (e, p) in embedding.iter_mut().zip(pos_enc.iter()) {
669 *e += *p;
670 }
671
672 self.context_window.push_back(embedding);
673 if self.context_window.len() > self.max_seq_len {
674 self.context_window.pop_front();
675 }
676 }
677
678 self.total_surprise = self.blocks.iter().map(|b| b.get_surprise()).sum::<f32>()
680 / self.blocks.len().max(1) as f32;
681 }
682
683 pub fn predict_next(&mut self) -> (u8, f32) {
685 let sequence: Vec<Vec<f32>> = self.context_window.iter().cloned().collect();
686
687 if sequence.is_empty() {
688 return (0, 1.0 / 256.0);
689 }
690
691 let mut hidden = self.blocks[0].forward(&sequence);
693 for block in &mut self.blocks[1..] {
694 let seq_with_hidden = vec![hidden.clone()];
695 hidden = block.forward(&seq_with_hidden);
696 }
697
698 let probs = self.output.forward(&hidden);
700 let predicted = self.output.argmax(&probs);
701 let confidence = probs[predicted as usize];
702
703 (predicted, confidence)
704 }
705
706 pub fn predict_sequence(&mut self, length: usize, greedy: bool) -> Vec<u8> {
708 let mut result = Vec::with_capacity(length);
709
710 for _ in 0..length {
711 let sequence: Vec<Vec<f32>> = self.context_window.iter().cloned().collect();
712
713 if sequence.is_empty() {
714 let byte = if greedy { 0 } else { self.rng.gen() };
715 result.push(byte);
716 continue;
717 }
718
719 let mut hidden = self.blocks[0].forward(&sequence);
721 for block in &mut self.blocks[1..] {
722 let seq_with_hidden = vec![hidden.clone()];
723 hidden = block.forward(&seq_with_hidden);
724 }
725
726 let probs = self.output.forward(&hidden);
727 let byte = if greedy {
728 self.output.argmax(&probs)
729 } else {
730 self.output.sample(&probs, &mut self.rng)
731 };
732
733 result.push(byte);
734
735 let mut embedding = self.tokenizer.encode(byte).to_vec();
737 let pos = self.context_window.len();
738 let pos_enc = self.positional.get(pos);
739 for (e, p) in embedding.iter_mut().zip(pos_enc.iter()) {
740 *e += *p;
741 }
742 self.context_window.push_back(embedding);
743 if self.context_window.len() > self.max_seq_len {
744 self.context_window.pop_front();
745 }
746 }
747
748 result
749 }
750
751 pub fn verify_prediction(&mut self, message: &[u8]) -> (bool, f32) {
753 let predicted = self.predict_sequence(message.len(), true);
754 let matches = predicted == message;
755
756 let similarity = predicted
757 .iter()
758 .zip(message.iter())
759 .filter(|(p, m)| p == m)
760 .count() as f32
761 / message.len().max(1) as f32;
762
763 (matches, similarity)
764 }
765
766 pub fn get_surprise(&self) -> f32 {
769 self.total_surprise
770 }
771
772 pub fn is_anomalous(&self, threshold: f32) -> bool {
774 self.total_surprise > threshold
775 }
776
777 pub fn reset(&mut self) {
779 self.context_window.clear();
780 self.total_surprise = 0.0;
781 }
782
783 pub fn reset_all(&mut self) {
785 self.context_window.clear();
786 self.total_surprise = 0.0;
787 for block in &mut self.blocks {
788 block.reset_memory();
789 }
790 }
791
792 pub fn set_temperature(&mut self, temp: f32) {
794 self.output.set_temperature(temp);
795 }
796}
797
/// Alternate name for [`TitansPredictor`].
pub type TransformerPredictor = TitansPredictor;
/// Alternate name for [`TitansConfig`].
pub type TransformerConfig = TitansConfig;
801
802#[derive(Debug, Clone, Serialize, Deserialize)]
804pub struct TitansConfig {
805 pub embed_dim: usize,
806 pub num_heads: usize,
807 pub num_layers: usize,
808 pub ff_dim: usize,
809 pub max_seq_len: usize,
810 pub memory_size: usize,
812 pub seed: u64,
813}
814
815impl Default for TitansConfig {
816 fn default() -> Self {
817 Self {
818 embed_dim: 64,
819 num_heads: 4,
820 num_layers: 2,
821 ff_dim: 128,
822 max_seq_len: 256,
823 memory_size: 64, seed: 42,
825 }
826 }
827}
828
829#[derive(Debug, Clone, Serialize, Deserialize)]
835pub struct LatticeParams {
836 pub n: usize, pub q: u64, pub p: u64, pub sigma: f64, }
841
842impl Default for LatticeParams {
843 fn default() -> Self {
844 Self {
845 n: 1024, q: 12289, p: 3, sigma: 3.2, }
850 }
851}
852
853#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
859pub struct RingElement {
860 coeffs: Vec<i64>,
861 n: usize,
862 q: u64,
863}
864
865impl RingElement {
866 pub fn new(n: usize, q: u64) -> Self {
867 Self {
868 coeffs: vec![0; n],
869 n,
870 q,
871 }
872 }
873
874 pub fn random(n: usize, q: u64, rng: &mut StdRng) -> Self {
875 let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(0..q as i64)).collect();
876 Self { coeffs, n, q }
877 }
878
879 pub fn random_ternary(n: usize, q: u64, rng: &mut StdRng) -> Self {
880 let coeffs: Vec<i64> = (0..n).map(|_| rng.gen_range(-1..=1)).collect();
881 Self { coeffs, n, q }
882 }
883
884 pub fn random_gaussian(n: usize, q: u64, sigma: f64, rng: &mut StdRng) -> Self {
885 let coeffs: Vec<i64> = (0..n)
887 .map(|_| {
888 let u1: f64 = rng.gen::<f64>().max(1e-10);
889 let u2: f64 = rng.gen();
890 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
891 (z * sigma).round() as i64
892 })
893 .collect();
894 Self { coeffs, n, q }
895 }
896
897 pub fn from_bytes(bytes: &[u8], n: usize, q: u64) -> Self {
898 let mut coeffs = vec![0i64; n];
899 for (i, chunk) in bytes.chunks(2).enumerate() {
900 if i >= n {
901 break;
902 }
903 let val = if chunk.len() == 2 {
904 ((chunk[0] as u16) | ((chunk[1] as u16) << 8)) as i64
905 } else {
906 chunk[0] as i64
907 };
908 coeffs[i] = val % q as i64;
909 }
910 Self { coeffs, n, q }
911 }
912
913 pub fn to_bytes(&self) -> Vec<u8> {
914 let mut bytes = Vec::with_capacity(self.n * 2);
915 for &c in &self.coeffs {
916 let val = ((c % self.q as i64 + self.q as i64) % self.q as i64) as u16;
917 bytes.push(val as u8);
918 bytes.push((val >> 8) as u8);
919 }
920 bytes
921 }
922
923 fn reduce(&mut self) {
924 for c in &mut self.coeffs {
925 *c = ((*c % self.q as i64) + self.q as i64) % self.q as i64;
926 }
927 }
928
929 pub fn mul(&self, other: &RingElement) -> RingElement {
931 assert_eq!(self.n, other.n);
932 let mut result = vec![0i64; self.n];
933
934 for i in 0..self.n {
935 for j in 0..self.n {
936 let idx = i + j;
937 let coeff = self.coeffs[i] * other.coeffs[j];
938 if idx < self.n {
939 result[idx] += coeff;
940 } else {
941 result[idx - self.n] -= coeff;
943 }
944 }
945 }
946
947 let mut elem = RingElement {
948 coeffs: result,
949 n: self.n,
950 q: self.q,
951 };
952 elem.reduce();
953 elem
954 }
955
956 pub fn add(&self, other: &RingElement) -> RingElement {
958 assert_eq!(self.n, other.n);
959 let coeffs: Vec<i64> = self
960 .coeffs
961 .iter()
962 .zip(other.coeffs.iter())
963 .map(|(a, b)| (a + b) % self.q as i64)
964 .collect();
965 let mut elem = RingElement {
966 coeffs,
967 n: self.n,
968 q: self.q,
969 };
970 elem.reduce();
971 elem
972 }
973
974 pub fn sub(&self, other: &RingElement) -> RingElement {
976 assert_eq!(self.n, other.n);
977 let coeffs: Vec<i64> = self
978 .coeffs
979 .iter()
980 .zip(other.coeffs.iter())
981 .map(|(a, b)| (a - b) % self.q as i64)
982 .collect();
983 let mut elem = RingElement {
984 coeffs,
985 n: self.n,
986 q: self.q,
987 };
988 elem.reduce();
989 elem
990 }
991
992 pub fn scale(&self, scalar: i64) -> RingElement {
994 let coeffs: Vec<i64> = self
995 .coeffs
996 .iter()
997 .map(|&c| (c * scalar) % self.q as i64)
998 .collect();
999 let mut elem = RingElement {
1000 coeffs,
1001 n: self.n,
1002 q: self.q,
1003 };
1004 elem.reduce();
1005 elem
1006 }
1007}
1008
/// Key-encapsulation mechanism used by [`QuantumKeyEvolution`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
pub enum KemAlgorithm {
    /// The in-file Ring-LWE construction built on [`RingElement`].
    Rlwe,
    /// ML-KEM-512 via the `ml_kem` crate.
    MlKem512,
    /// ML-KEM-768; this is the `Default` variant.
    #[default]
    MlKem768,
    /// ML-KEM-1024.
    MlKem1024,
    /// Ring-LWE plus ML-KEM-768, with both shared secrets combined through HKDF-SHA256.
    Hybrid,
}
1024
1025#[derive(Debug, Clone, Zeroize, ZeroizeOnDrop)]
1031struct MlKemKeyPair {
1032 dk_bytes: Vec<u8>, ek_bytes: Vec<u8>, #[zeroize(skip)]
1035 algorithm: KemAlgorithm,
1036}
1037
1038mod mlkem_ops {
1040 use super::*;
1041
1042 pub fn generate_512(rng: &mut StdRng) -> MlKemKeyPair {
1043 let (dk, ek) = MlKem512::generate(rng);
1044 MlKemKeyPair {
1045 dk_bytes: dk.as_bytes().to_vec(),
1046 ek_bytes: ek.as_bytes().to_vec(),
1047 algorithm: KemAlgorithm::MlKem512,
1048 }
1049 }
1050
1051 pub fn generate_768(rng: &mut StdRng) -> MlKemKeyPair {
1052 let (dk, ek) = MlKem768::generate(rng);
1053 MlKemKeyPair {
1054 dk_bytes: dk.as_bytes().to_vec(),
1055 ek_bytes: ek.as_bytes().to_vec(),
1056 algorithm: KemAlgorithm::MlKem768,
1057 }
1058 }
1059
1060 pub fn generate_1024(rng: &mut StdRng) -> MlKemKeyPair {
1061 let (dk, ek) = MlKem1024::generate(rng);
1062 MlKemKeyPair {
1063 dk_bytes: dk.as_bytes().to_vec(),
1064 ek_bytes: ek.as_bytes().to_vec(),
1065 algorithm: KemAlgorithm::MlKem1024,
1066 }
1067 }
1068
1069 pub fn encapsulate_512(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1070 let ek_encoded = <Encoded<EncapsulationKey<MlKem512Params>>>::try_from(ek_bytes).ok()?;
1071 let ek = EncapsulationKey::<MlKem512Params>::from_bytes(&ek_encoded);
1072 let (ct, ss) = ek.encapsulate(rng).ok()?;
1073 let mut shared = [0u8; 32];
1074 shared.copy_from_slice(ss.as_slice());
1075 Some((ct.to_vec(), shared))
1076 }
1077
1078 pub fn encapsulate_768(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1079 let ek_encoded = <Encoded<EncapsulationKey<MlKem768Params>>>::try_from(ek_bytes).ok()?;
1080 let ek = EncapsulationKey::<MlKem768Params>::from_bytes(&ek_encoded);
1081 let (ct, ss) = ek.encapsulate(rng).ok()?;
1082 let mut shared = [0u8; 32];
1083 shared.copy_from_slice(ss.as_slice());
1084 Some((ct.to_vec(), shared))
1085 }
1086
1087 pub fn encapsulate_1024(ek_bytes: &[u8], rng: &mut StdRng) -> Option<(Vec<u8>, [u8; 32])> {
1088 let ek_encoded = <Encoded<EncapsulationKey<MlKem1024Params>>>::try_from(ek_bytes).ok()?;
1089 let ek = EncapsulationKey::<MlKem1024Params>::from_bytes(&ek_encoded);
1090 let (ct, ss) = ek.encapsulate(rng).ok()?;
1091 let mut shared = [0u8; 32];
1092 shared.copy_from_slice(ss.as_slice());
1093 Some((ct.to_vec(), shared))
1094 }
1095
1096 pub fn decapsulate_512(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1097 let dk_encoded = <Encoded<DecapsulationKey<MlKem512Params>>>::try_from(dk_bytes).ok()?;
1098 let dk = DecapsulationKey::<MlKem512Params>::from_bytes(&dk_encoded);
1099 let ct = <ml_kem::Ciphertext<MlKem512>>::try_from(ct_bytes).ok()?;
1100 let ss = dk.decapsulate(&ct).ok()?;
1101 let mut shared = [0u8; 32];
1102 shared.copy_from_slice(ss.as_slice());
1103 Some(shared)
1104 }
1105
1106 pub fn decapsulate_768(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1107 let dk_encoded = <Encoded<DecapsulationKey<MlKem768Params>>>::try_from(dk_bytes).ok()?;
1108 let dk = DecapsulationKey::<MlKem768Params>::from_bytes(&dk_encoded);
1109 let ct = <ml_kem::Ciphertext<MlKem768>>::try_from(ct_bytes).ok()?;
1110 let ss = dk.decapsulate(&ct).ok()?;
1111 let mut shared = [0u8; 32];
1112 shared.copy_from_slice(ss.as_slice());
1113 Some(shared)
1114 }
1115
1116 pub fn decapsulate_1024(dk_bytes: &[u8], ct_bytes: &[u8]) -> Option<[u8; 32]> {
1117 let dk_encoded = <Encoded<DecapsulationKey<MlKem1024Params>>>::try_from(dk_bytes).ok()?;
1118 let dk = DecapsulationKey::<MlKem1024Params>::from_bytes(&dk_encoded);
1119 let ct = <ml_kem::Ciphertext<MlKem1024>>::try_from(ct_bytes).ok()?;
1120 let ss = dk.decapsulate(&ct).ok()?;
1121 let mut shared = [0u8; 32];
1122 shared.copy_from_slice(ss.as_slice());
1123 Some(shared)
1124 }
1125}
1126
1127#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
1134pub struct QuantumKeyPair {
1135 pub a: RingElement,
1137 pub public_key: RingElement,
1138 secret_key: RingElement,
1139 #[zeroize(skip)]
1140 params: LatticeParams,
1141}
1142
1143#[derive(Debug, Clone, Serialize, Deserialize)]
1150pub struct QuantumKeyEvolution {
1151 params: LatticeParams,
1152 current_key: QuantumKeyPair,
1153 evolution_counter: u64,
1154 key_history: VecDeque<[u8; 32]>, max_history: usize,
1156 #[serde(skip, default = "default_rng")]
1157 rng: StdRng,
1158 algorithm: KemAlgorithm,
1160 #[serde(skip)]
1162 mlkem_keypair: Option<MlKemKeyPair>,
1163}
1164
1165impl Drop for QuantumKeyEvolution {
1166 fn drop(&mut self) {
1167 for h in self.key_history.iter_mut() {
1172 h.zeroize();
1173 }
1174 self.key_history.clear();
1175 }
1176}
1177
1178impl QuantumKeyEvolution {
1179 pub fn new(params: LatticeParams, seed: u64) -> Self {
1180 let mut rng = StdRng::seed_from_u64(seed);
1181 let current_key = Self::generate_keypair(¶ms, &mut rng);
1182
1183 Self {
1184 params,
1185 current_key,
1186 evolution_counter: 0,
1187 key_history: VecDeque::new(),
1188 max_history: 100,
1189 rng,
1190 algorithm: KemAlgorithm::Rlwe,
1191 mlkem_keypair: None,
1192 }
1193 }
1194
1195 pub fn new_with_algorithm(params: LatticeParams, seed: u64, algorithm: KemAlgorithm) -> Self {
1197 let mut rng = StdRng::seed_from_u64(seed);
1198 let current_key = Self::generate_keypair(¶ms, &mut rng);
1199 let mlkem_keypair = match algorithm {
1200 KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut rng)),
1201 KemAlgorithm::MlKem768 => Some(mlkem_ops::generate_768(&mut rng)),
1202 KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut rng)),
1203 KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut rng)),
1204 KemAlgorithm::Rlwe => None,
1205 };
1206
1207 Self {
1208 params,
1209 current_key,
1210 evolution_counter: 0,
1211 key_history: VecDeque::new(),
1212 max_history: 100,
1213 rng,
1214 algorithm,
1215 mlkem_keypair,
1216 }
1217 }
1218
1219 fn generate_keypair(params: &LatticeParams, rng: &mut StdRng) -> QuantumKeyPair {
1220 let a = RingElement::random(params.n, params.q, rng);
1222 let s = RingElement::random_ternary(params.n, params.q, rng);
1223 let e = RingElement::random_gaussian(params.n, params.q, params.sigma, rng);
1224
1225 let b = a.mul(&s).add(&e);
1227
1228 QuantumKeyPair {
1229 a, public_key: b,
1231 secret_key: s,
1232 params: params.clone(),
1233 }
1234 }
1235
1236 pub fn evolve(&mut self) -> [u8; 32] {
1241 let mut hasher = Sha256::new();
1243 hasher.update(self.current_key.public_key.to_bytes());
1244 hasher.update(self.current_key.secret_key.to_bytes());
1245 hasher.update(self.evolution_counter.to_le_bytes());
1246 let hash: [u8; 32] = hasher.finalize().into();
1247
1248 self.key_history.push_back(hash);
1250 if self.key_history.len() > self.max_history {
1251 self.key_history.pop_front();
1252 }
1253
1254 let hk = Hkdf::<Sha256>::new(Some(&self.evolution_counter.to_le_bytes()), &hash);
1256 let mut okm = [0u8; 32];
1257 hk.expand(b"spine-key-evolution", &mut okm)
1258 .expect("HKDF expand failed");
1259 let new_seed = u64::from_le_bytes(okm[0..8].try_into().unwrap());
1260 let mut new_rng = StdRng::seed_from_u64(new_seed);
1261
1262 self.current_key = Self::generate_keypair(&self.params, &mut new_rng);
1264
1265 if self.algorithm != KemAlgorithm::Rlwe {
1267 self.mlkem_keypair = match self.algorithm {
1268 KemAlgorithm::MlKem512 => Some(mlkem_ops::generate_512(&mut new_rng)),
1269 KemAlgorithm::MlKem768 | KemAlgorithm::Hybrid => Some(mlkem_ops::generate_768(&mut new_rng)),
1270 KemAlgorithm::MlKem1024 => Some(mlkem_ops::generate_1024(&mut new_rng)),
1271 KemAlgorithm::Rlwe => None,
1272 };
1273 }
1274
1275 self.evolution_counter += 1;
1276
1277 hash
1278 }
1279
1280 pub fn encapsulate(&mut self) -> (Vec<u8>, [u8; 32]) {
1282 match self.algorithm {
1283 KemAlgorithm::Rlwe => self.encapsulate_rlwe(),
1284 KemAlgorithm::MlKem512 => self.encapsulate_mlkem(KemAlgorithm::MlKem512),
1285 KemAlgorithm::MlKem768 => self.encapsulate_mlkem(KemAlgorithm::MlKem768),
1286 KemAlgorithm::MlKem1024 => self.encapsulate_mlkem(KemAlgorithm::MlKem1024),
1287 KemAlgorithm::Hybrid => self.encapsulate_hybrid(),
1288 }
1289 }
1290
1291 fn encapsulate_mlkem(&mut self, alg: KemAlgorithm) -> (Vec<u8>, [u8; 32]) {
1293 let kp = self.mlkem_keypair.as_ref().expect("ML-KEM keypair required");
1294 let result = match alg {
1295 KemAlgorithm::MlKem512 => mlkem_ops::encapsulate_512(&kp.ek_bytes, &mut self.rng),
1296 KemAlgorithm::MlKem768 => mlkem_ops::encapsulate_768(&kp.ek_bytes, &mut self.rng),
1297 KemAlgorithm::MlKem1024 => mlkem_ops::encapsulate_1024(&kp.ek_bytes, &mut self.rng),
1298 _ => unreachable!(),
1299 };
1300 result.unwrap_or_else(|| {
1301 self.encapsulate_rlwe()
1303 })
1304 }
1305
1306 fn encapsulate_hybrid(&mut self) -> (Vec<u8>, [u8; 32]) {
1308 let (rlwe_ct, rlwe_ss) = self.encapsulate_rlwe();
1310 let (mlkem_ct, mlkem_ss) = self.encapsulate_mlkem(KemAlgorithm::MlKem768);
1311
1312 let mut combined_ikm = [0u8; 64];
1314 combined_ikm[..32].copy_from_slice(&rlwe_ss);
1315 combined_ikm[32..].copy_from_slice(&mlkem_ss);
1316 let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
1317 let mut hybrid_ss = [0u8; 32];
1318 hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
1319
1320 let rlwe_len = (rlwe_ct.len() as u32).to_le_bytes();
1322 let mut hybrid_ct = Vec::with_capacity(4 + rlwe_ct.len() + mlkem_ct.len());
1323 hybrid_ct.extend_from_slice(&rlwe_len);
1324 hybrid_ct.extend_from_slice(&rlwe_ct);
1325 hybrid_ct.extend_from_slice(&mlkem_ct);
1326
1327 (hybrid_ct, hybrid_ss)
1328 }
1329
1330 fn encapsulate_rlwe(&mut self) -> (Vec<u8>, [u8; 32]) {
1336 let a = &self.current_key.a;
1338 let r = RingElement::random_ternary(self.params.n, self.params.q, &mut self.rng);
1339 let e1 = RingElement::random_gaussian(
1340 self.params.n,
1341 self.params.q,
1342 self.params.sigma,
1343 &mut self.rng,
1344 );
1345 let e2 = RingElement::random_gaussian(
1346 self.params.n,
1347 self.params.q,
1348 self.params.sigma,
1349 &mut self.rng,
1350 );
1351
1352 let m: Vec<i64> = (0..self.params.n)
1354 .map(|_| self.rng.gen_range(0..2i64))
1355 .collect();
1356
1357 let u = a.mul(&r).add(&e1);
1359
1360 let half_q = (self.params.q / 2) as i64;
1362 let encoded_m = RingElement {
1363 coeffs: m.iter().map(|&mi| mi * half_q).collect(),
1364 n: self.params.n,
1365 q: self.params.q,
1366 };
1367 let v = self.current_key.public_key.mul(&r).add(&e2).add(&encoded_m);
1368
1369 let mut ciphertext = u.to_bytes();
1371 ciphertext.extend(v.to_bytes());
1372
1373 let mut hasher = Sha256::new();
1375 for &mi in &m {
1376 hasher.update(mi.to_le_bytes());
1377 }
1378 let shared_secret: [u8; 32] = hasher.finalize().into();
1379
1380 (ciphertext, shared_secret)
1381 }
1382
1383 pub fn decapsulate(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1385 match self.algorithm {
1386 KemAlgorithm::Rlwe => self.decapsulate_rlwe(ciphertext),
1387 KemAlgorithm::MlKem512 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem512),
1388 KemAlgorithm::MlKem768 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem768),
1389 KemAlgorithm::MlKem1024 => self.decapsulate_mlkem(ciphertext, KemAlgorithm::MlKem1024),
1390 KemAlgorithm::Hybrid => self.decapsulate_hybrid(ciphertext),
1391 }
1392 }
1393
1394 fn decapsulate_mlkem(&self, ciphertext: &[u8], alg: KemAlgorithm) -> Option<[u8; 32]> {
1396 let kp = self.mlkem_keypair.as_ref()?;
1397 match alg {
1398 KemAlgorithm::MlKem512 => mlkem_ops::decapsulate_512(&kp.dk_bytes, ciphertext),
1399 KemAlgorithm::MlKem768 => mlkem_ops::decapsulate_768(&kp.dk_bytes, ciphertext),
1400 KemAlgorithm::MlKem1024 => mlkem_ops::decapsulate_1024(&kp.dk_bytes, ciphertext),
1401 _ => None,
1402 }
1403 }
1404
1405 fn decapsulate_hybrid(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1407 if ciphertext.len() < 4 { return None; }
1408 let rlwe_len = u32::from_le_bytes(ciphertext[..4].try_into().ok()?) as usize;
1409 if ciphertext.len() < 4 + rlwe_len { return None; }
1410
1411 let rlwe_ct = &ciphertext[4..4+rlwe_len];
1412 let mlkem_ct = &ciphertext[4+rlwe_len..];
1413
1414 let rlwe_ss = self.decapsulate_rlwe(rlwe_ct)?;
1415 let mlkem_ss = self.decapsulate_mlkem(mlkem_ct, KemAlgorithm::MlKem768)?;
1416
1417 let mut combined_ikm = [0u8; 64];
1418 combined_ikm[..32].copy_from_slice(&rlwe_ss);
1419 combined_ikm[32..].copy_from_slice(&mlkem_ss);
1420 let hk = Hkdf::<Sha256>::new(None, &combined_ikm);
1421 let mut hybrid_ss = [0u8; 32];
1422 hk.expand(b"spine-hybrid-kem", &mut hybrid_ss).expect("HKDF expand");
1423
1424 Some(hybrid_ss)
1425 }
1426
1427 fn decapsulate_rlwe(&self, ciphertext: &[u8]) -> Option<[u8; 32]> {
1433 let half = ciphertext.len() / 2;
1434 if half < self.params.n * 2 {
1435 return None;
1436 }
1437
1438 let u = RingElement::from_bytes(&ciphertext[..half], self.params.n, self.params.q);
1439 let v = RingElement::from_bytes(&ciphertext[half..], self.params.n, self.params.q);
1440
1441 let recovered = v.sub(&u.mul(&self.current_key.secret_key));
1443
1444 let half_q = self.params.q as i64 / 2;
1446 let quarter_q = self.params.q as i64 / 4;
1447 let m: Vec<i64> = recovered
1448 .coeffs
1449 .iter()
1450 .map(|&c| {
1451 let c_pos =
1453 ((c % self.params.q as i64) + self.params.q as i64) % self.params.q as i64;
1454 if (c_pos - half_q).abs() < quarter_q {
1456 1i64
1457 } else {
1458 0i64
1459 }
1460 })
1461 .collect();
1462
1463 let mut hasher = Sha256::new();
1465 for &mi in &m {
1466 hasher.update(mi.to_le_bytes());
1467 }
1468 Some(hasher.finalize().into())
1469 }
1470
1471 pub fn get_key_hash(&self) -> [u8; 32] {
1473 let mut hasher = Sha256::new();
1474 hasher.update(self.current_key.public_key.to_bytes());
1475 hasher.finalize().into()
1476 }
1477
1478 pub fn verify_evolution(&self, expected_hash: &[u8; 32]) -> bool {
1480 self.key_history
1481 .iter()
1482 .any(|h| h.ct_eq(expected_hash).into())
1483 }
1484
1485 pub fn get_evolution_counter(&self) -> u64 {
1487 self.evolution_counter
1488 }
1489
1490 pub fn export_public_key(&self) -> Vec<u8> {
1492 self.current_key.public_key.to_bytes()
1493 }
1494}
1495
/// One endpoint of the speculative messaging protocol.
///
/// Messages the predictor anticipates are sent as a hash confirmation. All
/// other messages are encrypted with AES-256-GCM under a key derived from a
/// fresh KEM encapsulation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumSpeculativeProtocol {
    /// Sequence model used to anticipate (and, on receive, learn) messages.
    predictor: TransformerPredictor,
    /// KEM key schedule; evolved every `evolution_interval` sent messages.
    key_evolution: QuantumKeyEvolution,
    /// Minimum prediction similarity required to send a confirmation.
    prediction_threshold: f32,
    /// Number of sent messages between key evolutions (kept >= 1 by the setter).
    evolution_interval: u64,
    /// Messages sent so far; its little-endian bytes prefix the AES-GCM nonce.
    message_count: u64,
}
1505
1506impl QuantumSpeculativeProtocol {
1507 pub fn new(
1508 transformer_config: TransformerConfig,
1509 lattice_params: LatticeParams,
1510 seed: u64,
1511 ) -> Self {
1512 Self {
1513 predictor: TransformerPredictor::new(transformer_config),
1514 key_evolution: QuantumKeyEvolution::new(lattice_params, seed),
1515 prediction_threshold: 0.8,
1516 evolution_interval: 10,
1517 message_count: 0,
1518 }
1519 }
1520
1521 pub fn new_with_algorithm(
1523 transformer_config: TransformerConfig,
1524 lattice_params: LatticeParams,
1525 seed: u64,
1526 algorithm: KemAlgorithm,
1527 ) -> Self {
1528 Self {
1529 predictor: TransformerPredictor::new(transformer_config),
1530 key_evolution: QuantumKeyEvolution::new_with_algorithm(lattice_params, seed, algorithm),
1531 prediction_threshold: 0.8,
1532 evolution_interval: 10,
1533 message_count: 0,
1534 }
1535 }
1536
1537 pub fn algorithm(&self) -> KemAlgorithm {
1539 self.key_evolution.algorithm
1540 }
1541
1542 pub fn send(&mut self, message: &[u8]) -> QuantumMessage {
1544 let (matches, similarity) = self.predictor.verify_prediction(message);
1546
1547 let payload = if matches && similarity >= self.prediction_threshold {
1548 MessagePayload::Confirmation {
1550 hash: Self::hash_message(message),
1551 length: message.len(),
1552 }
1553 } else {
1554 let (ciphertext, shared_secret) = self.key_evolution.encapsulate();
1556
1557 let hk = Hkdf::<Sha256>::new(None, &shared_secret);
1559 let mut aes_key = [0u8; 32];
1560 hk.expand(b"spine-aead-key", &mut aes_key)
1561 .expect("HKDF expand failed");
1562
1563 let mut nonce_bytes = [0u8; 12];
1565 nonce_bytes[..8].copy_from_slice(&self.message_count.to_le_bytes());
1566 let nonce = Nonce::from_slice(&nonce_bytes);
1567
1568 let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
1570 let encrypted = cipher.encrypt(nonce, message).expect("AES-GCM encrypt");
1571
1572 let mut encrypted_message = nonce_bytes.to_vec();
1574 encrypted_message.extend(encrypted);
1575
1576 MessagePayload::Full {
1577 ciphertext,
1578 encrypted_message,
1579 }
1580 };
1581
1582 self.message_count += 1;
1584 let key_evolution = if self.message_count.is_multiple_of(self.evolution_interval) {
1585 Some(self.key_evolution.evolve())
1586 } else {
1587 None
1588 };
1589
1590 QuantumMessage {
1591 payload,
1592 evolution_counter: self.key_evolution.get_evolution_counter(),
1593 key_evolution,
1594 }
1595 }
1596
1597 pub fn get_morph_seed(&self) -> u64 {
1599 let key_hash = self.key_evolution.get_key_hash();
1600 u64::from_le_bytes(key_hash[0..8].try_into().unwrap())
1601 }
1602
1603 pub fn receive(&mut self, quantum_msg: &QuantumMessage) -> Option<Vec<u8>> {
1605 while self.key_evolution.get_evolution_counter() < quantum_msg.evolution_counter {
1607 self.key_evolution.evolve();
1608 }
1609
1610 let message = match &quantum_msg.payload {
1611 MessagePayload::Confirmation { hash, length } => {
1612 let predicted = self.predictor.predict_sequence(*length, true);
1614
1615 let predicted_hash = Self::hash_message(&predicted);
1617 if &predicted_hash == hash {
1618 Some(predicted)
1619 } else {
1620 None }
1622 }
1623 MessagePayload::Full {
1624 ciphertext,
1625 encrypted_message,
1626 } => {
1627 let shared_secret = self.key_evolution.decapsulate(ciphertext)?;
1629
1630 let hk = Hkdf::<Sha256>::new(None, &shared_secret);
1632 let mut aes_key = [0u8; 32];
1633 hk.expand(b"spine-aead-key", &mut aes_key)
1634 .expect("HKDF expand failed");
1635
1636 if encrypted_message.len() < 12 {
1638 return None;
1639 }
1640 let nonce = Nonce::from_slice(&encrypted_message[..12]);
1641 let ciphertext_data = &encrypted_message[12..];
1642
1643 let cipher = Aes256Gcm::new_from_slice(&aes_key).expect("AES key length");
1645 cipher.decrypt(nonce, ciphertext_data).ok()
1646 }
1647 };
1648
1649 if let Some(ref msg) = message {
1651 self.predictor.observe(msg);
1652 }
1653
1654 message
1655 }
1656
1657 fn hash_message(message: &[u8]) -> [u8; 32] {
1658 let mut hasher = Sha256::new();
1659 hasher.update(message);
1660 hasher.finalize().into()
1661 }
1662
1663 pub fn set_threshold(&mut self, threshold: f32) {
1665 self.prediction_threshold = threshold.clamp(0.0, 1.0);
1666 }
1667
1668 pub fn set_evolution_interval(&mut self, interval: u64) {
1670 self.evolution_interval = interval.max(1);
1671 }
1672
1673 pub fn reset(&mut self) {
1675 self.predictor.reset();
1676 self.message_count = 0;
1677 }
1678}
1679
/// Wire message produced by `QuantumSpeculativeProtocol::send` and consumed
/// by `QuantumSpeculativeProtocol::receive`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumMessage {
    /// Either a prediction confirmation or a KEM-protected ciphertext.
    pub payload: MessagePayload,
    /// Sender's key-evolution counter. The receiver ratchets its own key
    /// schedule up to this value when handling `payload`. NOTE(review): this
    /// field is not covered by the AEAD tag.
    pub evolution_counter: u64,
    /// Value returned by the sender's `evolve()` when this send triggered a
    /// key evolution; `None` otherwise.
    pub key_evolution: Option<[u8; 32]>,
}
1687
/// Body of a `QuantumMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessagePayload {
    /// The sender's predictor anticipated the message. Only its SHA-256 hash
    /// and length are sent; the receiver regenerates the bytes with its own
    /// predictor and checks the hash. NOTE(review): the hash is unkeyed and
    /// unencrypted, so an observer can test guesses of the plaintext.
    Confirmation { hash: [u8; 32], length: usize },
    /// The message was not predicted and is sent encrypted.
    Full {
        /// KEM ciphertext encapsulating the per-message shared secret.
        ciphertext: Vec<u8>,
        /// `nonce (12 bytes) || AES-256-GCM ciphertext and tag`, keyed by
        /// HKDF-SHA256 of the encapsulated secret.
        encrypted_message: Vec<u8>,
    },
}
1698
1699#[cfg(test)]
1700mod tests {
1701 use super::*;
1702
1703 fn assert_zeroize_on_drop<T: ZeroizeOnDrop>() {}
1711
1712 #[test]
1713 fn ringelement_implements_zeroize_on_drop() {
1714 assert_zeroize_on_drop::<RingElement>();
1715 }
1716
1717 #[test]
1718 fn mlkemkeypair_implements_zeroize_on_drop() {
1719 assert_zeroize_on_drop::<MlKemKeyPair>();
1720 }
1721
1722 #[test]
1723 fn quantumkeypair_implements_zeroize_on_drop() {
1724 assert_zeroize_on_drop::<QuantumKeyPair>();
1725 }
1726
1727 #[test]
1733 fn ringelement_zeroize_clears_all_coefficients() {
1734 let mut rng = StdRng::seed_from_u64(0xAB_CD);
1735 let mut r = RingElement::random(64, 8_192, &mut rng);
1736 assert!(
1737 r.coeffs.iter().any(|&c| c != 0),
1738 "test precondition: random RingElement should have non-zero coeffs"
1739 );
1740 r.zeroize();
1741 assert!(
1742 r.coeffs.iter().all(|&c| c == 0),
1743 "RingElement::zeroize did not clear every coefficient"
1744 );
1745 }
1746
1747 #[test]
1748 fn mlkemkeypair_zeroize_clears_dk_bytes() {
1749 let mut rng = StdRng::seed_from_u64(0x12_34);
1750 let mut kp = mlkem_ops::generate_768(&mut rng);
1751 assert!(
1752 kp.dk_bytes.iter().any(|&b| b != 0),
1753 "test precondition: fresh ML-KEM dk should be non-zero"
1754 );
1755 kp.zeroize();
1756 assert!(
1761 kp.dk_bytes.iter().all(|&b| b == 0),
1762 "MlKemKeyPair::zeroize left non-zero bytes in dk_bytes"
1763 );
1764 }
1765
1766 #[test]
1767 fn quantumkeyevolution_drop_clears_key_history() {
1768 let mut ev = QuantumKeyEvolution::new(LatticeParams::default(), 0xCAFE);
1769 ev.key_history.push_back([0x11u8; 32]);
1771 ev.key_history.push_back([0x22u8; 32]);
1772 assert_eq!(ev.key_history.len(), 2);
1773 ev.key_history.iter_mut().for_each(|h| h.zeroize());
1779 assert!(
1780 ev.key_history.iter().all(|h| h.iter().all(|&b| b == 0)),
1781 "QuantumKeyEvolution key_history not zeroed"
1782 );
1783 }
1784
1785 #[test]
1786 fn test_positional_encoding() {
1787 let pe = PositionalEncoding::new(100, 64);
1788 let enc0 = pe.get(0);
1789 let enc50 = pe.get(50);
1790 assert_eq!(enc0.len(), 64);
1791 assert_ne!(enc0, enc50);
1792 }
1793
1794 #[test]
1795 fn test_layer_norm() {
1796 let ln = LayerNorm::new(8);
1797 let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1798 let output = ln.forward(&input);
1799 assert_eq!(output.len(), 8);
1800
1801 let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
1803 assert!(mean.abs() < 0.01);
1804 }
1805
1806 #[test]
1807 fn test_titans_predictor() {
1808 let config = TitansConfig {
1809 embed_dim: 32,
1810 num_heads: 2,
1811 num_layers: 1,
1812 ff_dim: 64,
1813 max_seq_len: 64,
1814 memory_size: 16,
1815 seed: 42,
1816 };
1817 let mut predictor = TitansPredictor::new(config);
1818
1819 predictor.observe(b"Hello ");
1821 predictor.observe(b"World");
1822
1823 let (next, conf) = predictor.predict_next();
1825 assert!(conf > 0.0 && conf <= 1.0);
1826 let _ = next;
1828
1829 let surprise = predictor.get_surprise();
1831 assert!(surprise >= 0.0);
1832 }
1833
1834 #[test]
1835 fn test_titans_anomaly_detection() {
1836 let config = TitansConfig {
1837 embed_dim: 32,
1838 num_heads: 2,
1839 num_layers: 1,
1840 ff_dim: 64,
1841 max_seq_len: 64,
1842 memory_size: 16,
1843 seed: 42,
1844 };
1845 let mut predictor = TitansPredictor::new(config);
1846
1847 for _ in 0..10 {
1849 predictor.observe(b"GET /api/status\n");
1850 }
1851 let _normal_surprise = predictor.get_surprise();
1852
1853 predictor.observe(b"MALICIOUS_PAYLOAD_XYZ!!!");
1855 let anomaly_surprise = predictor.get_surprise();
1856
1857 assert!(anomaly_surprise >= 0.0);
1859 }
1860
1861 #[test]
1862 fn test_ring_operations() {
1863 let mut rng = StdRng::seed_from_u64(42);
1864 let params = LatticeParams {
1865 n: 16,
1866 q: 97,
1867 p: 3,
1868 sigma: 2.0,
1869 };
1870
1871 let a = RingElement::random(params.n, params.q, &mut rng);
1872 let b = RingElement::random(params.n, params.q, &mut rng);
1873
1874 let sum = a.add(&b);
1875 let product = a.mul(&b);
1876
1877 assert_eq!(sum.coeffs.len(), params.n);
1878 assert_eq!(product.coeffs.len(), params.n);
1879
1880 for &c in &sum.coeffs {
1882 assert!(c >= 0 && c < params.q as i64);
1883 }
1884 }
1885
1886 #[test]
1887 fn test_key_evolution() {
1888 let params = LatticeParams {
1889 n: 32,
1890 q: 257,
1891 p: 3,
1892 sigma: 2.0,
1893 };
1894 let mut ke = QuantumKeyEvolution::new(params, 42);
1895
1896 let hash1 = ke.get_key_hash();
1897 ke.evolve();
1898 let hash2 = ke.get_key_hash();
1899
1900 assert_ne!(hash1, hash2);
1902
1903 assert_eq!(ke.get_evolution_counter(), 1);
1905 }
1906
1907 #[test]
1908 fn test_encapsulation() {
1909 let params = LatticeParams {
1910 n: 32,
1911 q: 257,
1912 p: 3,
1913 sigma: 2.0,
1914 };
1915 let mut ke = QuantumKeyEvolution::new(params, 42);
1916
1917 let (ciphertext, _shared_secret1) = ke.encapsulate();
1918 assert!(!ciphertext.is_empty());
1919
1920 let shared_secret2 = ke.decapsulate(&ciphertext);
1921 assert!(shared_secret2.is_some());
1922
1923 }
1926
1927 #[test]
1928 fn test_quantum_speculative_protocol() {
1929 let config = TitansConfig {
1930 embed_dim: 16,
1931 num_heads: 2,
1932 num_layers: 1,
1933 ff_dim: 32,
1934 max_seq_len: 32,
1935 memory_size: 8,
1936 seed: 42,
1937 };
1938 let params = LatticeParams {
1939 n: 16,
1940 q: 97,
1941 p: 3,
1942 sigma: 2.0,
1943 };
1944
1945 let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
1946 let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);
1947
1948 let msg = b"Hello Bob!";
1950 let quantum_msg = alice.send(msg);
1951
1952 let received = bob.receive(&quantum_msg);
1954 assert!(received.is_some());
1955 assert_eq!(received.unwrap(), msg.to_vec());
1956 }
1957
1958 #[test]
1959 fn test_prediction_efficiency() {
1960 let config = TransformerConfig::default();
1961 let params = LatticeParams::default();
1962
1963 let mut sender = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
1964 let mut receiver = QuantumSpeculativeProtocol::new(config, params, 42);
1965
1966 for _ in 0..5 {
1968 let msg1 = sender.send(b"GET /api/status");
1969 receiver.receive(&msg1);
1970
1971 let msg2 = sender.send(b"200 OK");
1972 receiver.receive(&msg2);
1973 }
1974
1975 let msg = sender.send(b"GET /api/status");
1977
1978 let received = receiver.receive(&msg);
1980 assert!(received.is_some());
1981 }
1982
1983 #[test]
1988 fn test_miras_predictor_basic() {
1989 let config = TitansConfig {
1990 embed_dim: 32,
1991 num_heads: 2,
1992 num_layers: 1,
1993 ff_dim: 64,
1994 max_seq_len: 64,
1995 memory_size: 16,
1996 seed: 42,
1997 };
1998 let mut predictor = MirasTitansPredictor::new(config);
1999
2000 predictor.observe(b"Hello World");
2002
2003 assert_eq!(predictor.variant(), "titans");
2005
2006 let (next, conf) = predictor.predict_next();
2008 assert!(conf > 0.0 && conf <= 1.0);
2009 let _ = next;
2011
2012 let stats = predictor.stats();
2014 assert_eq!(stats.message_count, 1);
2015 assert!(stats.miras_enhanced_predictions > 0);
2016 }
2017
2018 #[test]
2019 fn test_miras_predictor_variants() {
2020 let config = TitansConfig {
2021 embed_dim: 32,
2022 num_heads: 2,
2023 num_layers: 1,
2024 ff_dim: 64,
2025 max_seq_len: 64,
2026 memory_size: 16,
2027 seed: 42,
2028 };
2029
2030 for variant in [
2032 MirasVariant::Titans,
2033 MirasVariant::Yaad,
2034 MirasVariant::Moneta { p: 2.0 },
2035 MirasVariant::Memora,
2036 ] {
2037 let predictor = MirasTitansPredictor::new_with_variant(config.clone(), variant);
2038
2039 match variant {
2041 MirasVariant::Titans => assert_eq!(predictor.variant(), "titans"),
2042 MirasVariant::Yaad => assert_eq!(predictor.variant(), "yaad"),
2043 MirasVariant::Moneta { .. } => assert_eq!(predictor.variant(), "moneta"),
2044 MirasVariant::Memora => assert_eq!(predictor.variant(), "memora"),
2045 }
2046 }
2047
2048 let mut predictor =
2050 MirasTitansPredictor::new_with_variant(config.clone(), MirasVariant::Yaad);
2051 assert_eq!(predictor.variant(), "yaad");
2052
2053 predictor.observe(b"test");
2055 }
2057
2058 #[test]
2059 fn test_miras_predictor_combined_surprise() {
2060 let config = TitansConfig {
2061 embed_dim: 32,
2062 num_heads: 2,
2063 num_layers: 1,
2064 ff_dim: 64,
2065 max_seq_len: 64,
2066 memory_size: 16,
2067 seed: 42,
2068 };
2069 let mut predictor = MirasTitansPredictor::new(config);
2070
2071 for _ in 0..5 {
2073 predictor.observe(b"normal message pattern");
2074 }
2075
2076 let combined = predictor.get_combined_surprise();
2078 assert!(combined >= 0.0);
2079
2080 let titans_surprise = predictor.get_surprise();
2082 let miras_surprise = predictor.get_miras_surprise();
2083
2084 assert!(titans_surprise >= 0.0);
2085 assert!(miras_surprise.is_some());
2086 }
2087
2088 #[test]
2089 fn test_miras_predictor_anomaly_level() {
2090 let config = TitansConfig {
2091 embed_dim: 32,
2092 num_heads: 2,
2093 num_layers: 1,
2094 ff_dim: 64,
2095 max_seq_len: 64,
2096 memory_size: 16,
2097 seed: 42,
2098 };
2099 let mut predictor = MirasTitansPredictor::new(config);
2100
2101 assert_eq!(predictor.anomaly_level(), 0.0);
2103
2104 predictor.observe(b"test");
2106 let level = predictor.anomaly_level();
2107 assert!(level >= 0.0); }
2109
2110 #[test]
2111 fn test_miras_predictor_reset() {
2112 let config = TitansConfig {
2113 embed_dim: 32,
2114 num_heads: 2,
2115 num_layers: 1,
2116 ff_dim: 64,
2117 max_seq_len: 64,
2118 memory_size: 16,
2119 seed: 42,
2120 };
2121 let mut predictor = MirasTitansPredictor::new(config);
2122
2123 for _ in 0..10 {
2125 predictor.observe(b"data");
2126 }
2127 assert!(predictor.stats().message_count > 0);
2128
2129 predictor.reset_all();
2131 let stats = predictor.stats();
2132 assert_eq!(stats.message_count, 0);
2133 }
2134
2135 #[test]
2140 fn test_rlwe_ring_arithmetic_correctness() {
2141 let mut rng = StdRng::seed_from_u64(12345);
2143 let params = LatticeParams {
2144 n: 32,
2145 q: 257,
2146 p: 3,
2147 sigma: 2.0,
2148 };
2149
2150 let a = RingElement::random(params.n, params.q, &mut rng);
2151 let b = RingElement::random(params.n, params.q, &mut rng);
2152 let c = RingElement::random(params.n, params.q, &mut rng);
2153
2154 let ab = a.add(&b);
2156 let ba = b.add(&a);
2157 assert_eq!(ab.coeffs, ba.coeffs, "Addition should be commutative");
2158
2159 let ab_c = a.add(&b).add(&c);
2161 let a_bc = a.add(&b.add(&c));
2162 assert_eq!(ab_c.coeffs, a_bc.coeffs, "Addition should be associative");
2163
2164 let a_times_bplusc = a.mul(&b.add(&c));
2166 let ab_plus_ac = a.mul(&b).add(&a.mul(&c));
2167 assert_eq!(
2168 a_times_bplusc.coeffs, ab_plus_ac.coeffs,
2169 "Multiplication should distribute over addition"
2170 );
2171 }
2172
2173 #[test]
2174 fn test_rlwe_gaussian_distribution() {
2175 let mut rng = StdRng::seed_from_u64(54321);
2177 let params = LatticeParams {
2178 n: 1024,
2179 q: 12289, p: 3,
2181 sigma: 3.2,
2182 };
2183
2184 let e = RingElement::random_gaussian(params.n, params.q, params.sigma, &mut rng);
2185
2186 let mean: f64 = e.coeffs.iter().map(|&c| c as f64).sum::<f64>() / params.n as f64;
2188 let variance: f64 = e
2189 .coeffs
2190 .iter()
2191 .map(|&c| (c as f64 - mean).powi(2))
2192 .sum::<f64>()
2193 / params.n as f64;
2194
2195 assert!(
2197 mean.abs() < params.sigma,
2198 "Gaussian mean should be near 0, got {}",
2199 mean
2200 );
2201
2202 let expected_variance = params.sigma * params.sigma;
2204 assert!(
2205 (variance - expected_variance).abs() < expected_variance * 0.5,
2206 "Variance {} should be close to sigma^2 = {}",
2207 variance,
2208 expected_variance
2209 );
2210 }
2211
2212 #[test]
2213 fn test_rlwe_ternary_distribution() {
2214 let mut rng = StdRng::seed_from_u64(98765);
2216 let params = LatticeParams {
2217 n: 256,
2218 q: 257,
2219 p: 3,
2220 sigma: 2.0,
2221 };
2222
2223 let s = RingElement::random_ternary(params.n, params.q, &mut rng);
2224
2225 for &coeff in &s.coeffs {
2227 assert!(
2228 coeff == 0 || coeff == 1 || coeff == -1,
2229 "Ternary coefficient should be -1, 0, or 1, got {}",
2230 coeff
2231 );
2232 }
2233
2234 let count_zero = s.coeffs.iter().filter(|&&c| c == 0).count();
2236 let count_one = s.coeffs.iter().filter(|&&c| c == 1).count();
2237 let count_neg = s.coeffs.iter().filter(|&&c| c == -1).count();
2238
2239 let expected = params.n / 3;
2241 let tolerance = params.n / 4; assert!(
2243 (count_zero as isize - expected as isize).unsigned_abs() < tolerance,
2244 "Ternary distribution unbalanced: zeros={}, ones={}, neg={}",
2245 count_zero,
2246 count_one,
2247 count_neg
2248 );
2249 }
2250
2251 #[test]
2252 fn test_key_evolution_forward_secrecy() {
2253 let params = LatticeParams {
2255 n: 64,
2256 q: 257,
2257 p: 3,
2258 sigma: 2.0,
2259 };
2260
2261 let mut ke1 = QuantumKeyEvolution::new(params.clone(), 42);
2262 let mut ke2 = QuantumKeyEvolution::new(params, 42);
2263
2264 assert_eq!(ke1.get_key_hash(), ke2.get_key_hash());
2266
2267 for _ in 0..5 {
2269 ke1.evolve();
2270 }
2271
2272 assert_ne!(ke1.get_key_hash(), ke2.get_key_hash());
2274
2275 assert_eq!(ke1.get_evolution_counter(), 5);
2277 assert_eq!(ke2.get_evolution_counter(), 0);
2278
2279 for _ in 0..5 {
2281 ke2.evolve();
2282 }
2283
2284 assert_eq!(ke1.get_key_hash(), ke2.get_key_hash());
2286 }
2287
2288 #[test]
2289 fn test_key_evolution_history_integrity() {
2290 let params = LatticeParams {
2291 n: 32,
2292 q: 257,
2293 p: 3,
2294 sigma: 2.0,
2295 };
2296
2297 let mut ke = QuantumKeyEvolution::new(params, 42);
2298
2299 let mut hashes = Vec::new();
2301 for _ in 0..10 {
2302 let hash = ke.evolve();
2303 hashes.push(hash);
2304 }
2305
2306 let unique_count = hashes
2308 .iter()
2309 .collect::<std::collections::HashSet<_>>()
2310 .len();
2311 assert_eq!(unique_count, 10, "All evolution hashes should be unique");
2312
2313 for hash in &hashes {
2315 assert!(
2316 ke.verify_evolution(hash),
2317 "Recent evolution should be verifiable"
2318 );
2319 }
2320 }
2321
2322 #[test]
2323 fn test_quantum_protocol_message_integrity() {
2324 let config = TitansConfig {
2326 embed_dim: 16,
2327 num_heads: 2,
2328 num_layers: 1,
2329 ff_dim: 32,
2330 max_seq_len: 32,
2331 memory_size: 8,
2332 seed: 42,
2333 };
2334 let params = LatticeParams {
2335 n: 32,
2336 q: 257,
2337 p: 3,
2338 sigma: 2.0,
2339 };
2340
2341 let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
2342 let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);
2343
2344 let test_messages = [
2346 b"A".to_vec(),
2347 b"Short".to_vec(),
2348 b"Medium length message".to_vec(),
2349 b"This is a longer message to test variable length handling properly".to_vec(),
2350 ];
2351
2352 for msg in &test_messages {
2353 let quantum_msg = alice.send(msg);
2354 let received = bob.receive(&quantum_msg);
2355 assert!(received.is_some(), "Should receive message");
2356 assert_eq!(
2357 &received.unwrap(),
2358 msg,
2359 "Received message should match original"
2360 );
2361 }
2362 }
2363
2364 #[test]
2365 fn test_tampered_ciphertext_detection() {
2366 let params = LatticeParams {
2367 n: 32,
2368 q: 257,
2369 p: 3,
2370 sigma: 2.0,
2371 };
2372
2373 let mut ke = QuantumKeyEvolution::new(params, 42);
2374
2375 let (mut ciphertext, original_secret) = ke.encapsulate();
2376
2377 if !ciphertext.is_empty() {
2379 ciphertext[0] ^= 0xFF;
2380 }
2381
2382 let tampered_secret = ke.decapsulate(&ciphertext);
2384
2385 if let Some(tampered) = tampered_secret {
2386 assert_ne!(
2389 tampered, original_secret,
2390 "Tampered ciphertext should produce different secret"
2391 );
2392 }
2393 }
2394
2395 #[test]
2396 fn test_lattice_params_security_levels() {
2397 let toy_params = LatticeParams {
2399 n: 16,
2400 q: 97,
2401 p: 3,
2402 sigma: 2.0,
2403 };
2404 let medium_params = LatticeParams {
2405 n: 256,
2406 q: 7681,
2407 p: 3,
2408 sigma: 3.19,
2409 };
2410 let _high_params = LatticeParams {
2411 n: 1024,
2412 q: 12289,
2413 p: 3,
2414 sigma: 3.19,
2415 };
2416
2417 assert!(
2419 toy_params.n.is_power_of_two(),
2420 "n should be power of 2 for NTT"
2421 );
2422 assert!(
2423 medium_params.n.is_power_of_two(),
2424 "n should be power of 2 for NTT"
2425 );
2426
2427 let mut ke_toy = QuantumKeyEvolution::new(toy_params, 1);
2429 let mut ke_med = QuantumKeyEvolution::new(medium_params, 1);
2430
2431 let (ct_toy, _) = ke_toy.encapsulate();
2433 let (ct_med, _) = ke_med.encapsulate();
2434
2435 assert!(!ct_toy.is_empty());
2436 assert!(!ct_med.is_empty());
2437
2438 assert!(
2440 ct_med.len() > ct_toy.len(),
2441 "Higher security params should produce larger ciphertext"
2442 );
2443 }
2444
2445 #[test]
2446 fn test_titans_predictor_statistical_properties() {
2447 let config = TitansConfig {
2448 embed_dim: 32,
2449 num_heads: 2,
2450 num_layers: 1,
2451 ff_dim: 64,
2452 max_seq_len: 64,
2453 memory_size: 16,
2454 seed: 42,
2455 };
2456 let mut predictor = TitansPredictor::new(config);
2457
2458 let pattern = b"ABCABC";
2460 for _ in 0..20 {
2461 predictor.observe(pattern);
2462 }
2463
2464 let (_, confidence) = predictor.predict_next();
2466 assert!(
2467 (0.0..=1.0).contains(&confidence),
2468 "Confidence should be normalized"
2469 );
2470
2471 let surprise = predictor.get_surprise();
2473 assert!(surprise >= 0.0, "Surprise should be non-negative");
2474 }
2475
2476 #[test]
2477 fn test_kem_shared_secret_match() {
2478 let params = LatticeParams {
2480 n: 64,
2481 q: 257,
2482 p: 3,
2483 sigma: 1.5,
2484 };
2485 let mut ke = QuantumKeyEvolution::new(params, 12345);
2486
2487 let (ciphertext, shared_secret_enc) = ke.encapsulate();
2488 let shared_secret_dec = ke.decapsulate(&ciphertext).unwrap();
2489
2490 assert_eq!(
2491 shared_secret_enc, shared_secret_dec,
2492 "KEM shared secrets must match between encapsulate and decapsulate"
2493 );
2494 }
2495
2496 #[test]
2497 fn test_aead_tampered_ciphertext_rejected() {
2498 let config = TransformerConfig::default();
2500 let params = LatticeParams {
2501 n: 32,
2502 q: 257,
2503 p: 3,
2504 sigma: 2.0,
2505 };
2506
2507 let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 42);
2508 let mut bob = QuantumSpeculativeProtocol::new(config, params, 42);
2509
2510 let msg = b"Secret message";
2511 let mut quantum_msg = alice.send(msg);
2512
2513 if let MessagePayload::Full {
2515 ref mut encrypted_message,
2516 ..
2517 } = quantum_msg.payload
2518 {
2519 if let Some(byte) = encrypted_message.last_mut() {
2520 *byte ^= 0xFF; }
2522 }
2523
2524 let received = bob.receive(&quantum_msg);
2526 assert!(
2527 received.is_none(),
2528 "Tampered ciphertext must be rejected by AEAD"
2529 );
2530 }
2531
2532 #[test]
2533 fn test_key_evolution_maintains_kem_invariant() {
2534 let params = LatticeParams {
2536 n: 32,
2537 q: 257,
2538 p: 3,
2539 sigma: 2.0,
2540 };
2541 let mut ke = QuantumKeyEvolution::new(params, 99);
2542
2543 for _ in 0..5 {
2544 ke.evolve();
2545 let (ct, ss_enc) = ke.encapsulate();
2547 let ss_dec = ke.decapsulate(&ct).unwrap();
2548 assert_eq!(ss_enc, ss_dec, "KEM must work after key evolution");
2549 }
2550 }
2551
2552 #[test]
2553 fn test_key_evolution_deterministic_hkdf() {
2554 let params = LatticeParams::default();
2556 let mut ke1 = QuantumKeyEvolution::new(params.clone(), 7777);
2557 let mut ke2 = QuantumKeyEvolution::new(params, 7777);
2558
2559 for _ in 0..5 {
2560 let h1 = ke1.evolve();
2561 let h2 = ke2.evolve();
2562 assert_eq!(
2563 h1, h2,
2564 "Deterministic evolution must produce identical hashes"
2565 );
2566 }
2567 assert_eq!(ke1.get_key_hash(), ke2.get_key_hash());
2568 }
2569
2570 #[test]
2571 fn test_aes_gcm_round_trip() {
2572 let config = TransformerConfig::default();
2574 let params = LatticeParams {
2575 n: 64,
2576 q: 257,
2577 p: 3,
2578 sigma: 1.5,
2579 };
2580
2581 let mut alice = QuantumSpeculativeProtocol::new(config.clone(), params.clone(), 100);
2582 let mut bob = QuantumSpeculativeProtocol::new(config, params, 100);
2583
2584 for i in 0..5 {
2586 let msg = format!("Message number {}", i);
2587 let quantum_msg = alice.send(msg.as_bytes());
2588 let received = bob.receive(&quantum_msg);
2589 assert!(received.is_some(), "Message {} should decrypt", i);
2590 assert_eq!(
2591 received.unwrap(),
2592 msg.as_bytes(),
2593 "Message {} content mismatch",
2594 i
2595 );
2596 }
2597 }
2598
2599 #[test]
2602 fn test_mlkem_512_round_trip() {
2603 let mut rng = StdRng::seed_from_u64(1);
2604 let kp = mlkem_ops::generate_512(&mut rng);
2605 assert_eq!(kp.algorithm, KemAlgorithm::MlKem512);
2606
2607 let (ct, ss_enc) = mlkem_ops::encapsulate_512(&kp.ek_bytes, &mut rng).unwrap();
2608 let ss_dec = mlkem_ops::decapsulate_512(&kp.dk_bytes, &ct).unwrap();
2609 assert_eq!(ss_enc.len(), 32);
2610 assert_eq!(ss_enc, ss_dec, "ML-KEM-512 shared secret mismatch");
2611 }
2612
2613 #[test]
2614 fn test_mlkem_768_round_trip() {
2615 let mut rng = StdRng::seed_from_u64(2);
2616 let kp = mlkem_ops::generate_768(&mut rng);
2617 assert_eq!(kp.algorithm, KemAlgorithm::MlKem768);
2618
2619 let (ct, ss_enc) = mlkem_ops::encapsulate_768(&kp.ek_bytes, &mut rng).unwrap();
2620 let ss_dec = mlkem_ops::decapsulate_768(&kp.dk_bytes, &ct).unwrap();
2621 assert_eq!(ss_enc.len(), 32);
2622 assert_eq!(ss_enc, ss_dec, "ML-KEM-768 shared secret mismatch");
2623 }
2624
2625 #[test]
2626 fn test_mlkem_1024_round_trip() {
2627 let mut rng = StdRng::seed_from_u64(3);
2628 let kp = mlkem_ops::generate_1024(&mut rng);
2629 assert_eq!(kp.algorithm, KemAlgorithm::MlKem1024);
2630
2631 let (ct, ss_enc) = mlkem_ops::encapsulate_1024(&kp.ek_bytes, &mut rng).unwrap();
2632 let ss_dec = mlkem_ops::decapsulate_1024(&kp.dk_bytes, &ct).unwrap();
2633 assert_eq!(ss_enc.len(), 32);
2634 assert_eq!(ss_enc, ss_dec, "ML-KEM-1024 shared secret mismatch");
2635 }
2636
2637 #[test]
2638 fn test_mlkem_different_keypairs_produce_different_secrets() {
2639 let mut rng = StdRng::seed_from_u64(4);
2640 let kp1 = mlkem_ops::generate_768(&mut rng);
2641 let kp2 = mlkem_ops::generate_768(&mut rng);
2642
2643 let (_, ss1) = mlkem_ops::encapsulate_768(&kp1.ek_bytes, &mut rng).unwrap();
2644 let (_, ss2) = mlkem_ops::encapsulate_768(&kp2.ek_bytes, &mut rng).unwrap();
2645
2646 assert_ne!(ss1, ss2, "Different keypairs should yield different secrets");
2648 }
2649
2650 #[test]
2651 fn test_mlkem_wrong_key_decapsulation_fails() {
2652 let mut rng = StdRng::seed_from_u64(5);
2653 let kp1 = mlkem_ops::generate_768(&mut rng);
2654 let kp2 = mlkem_ops::generate_768(&mut rng);
2655
2656 let (ct, ss_enc) = mlkem_ops::encapsulate_768(&kp1.ek_bytes, &mut rng).unwrap();
2657 let ss_wrong = mlkem_ops::decapsulate_768(&kp2.dk_bytes, &ct).unwrap();
2659 assert_ne!(
2660 ss_enc, ss_wrong,
2661 "Wrong DK must produce different shared secret (implicit reject)"
2662 );
2663 }
2664
2665 #[test]
2666 fn test_kem_algorithm_default() {
2667 assert_eq!(KemAlgorithm::default(), KemAlgorithm::MlKem768);
2668 }
2669
2670 #[test]
2671 fn test_quantum_key_evolution_with_mlkem() {
2672 let params = LatticeParams {
2673 n: 32,
2674 q: 257,
2675 p: 3,
2676 sigma: 2.0,
2677 };
2678 let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 42, KemAlgorithm::MlKem768);
2679
2680 let (ct, ss_enc) = ke.encapsulate();
2682 let ss_dec = ke.decapsulate(&ct).unwrap();
2683 assert_eq!(ss_enc, ss_dec, "ML-KEM encaps/decaps via QuantumKeyEvolution");
2684 assert!(!ct.is_empty());
2685 }
2686
2687 #[test]
2688 fn test_quantum_key_evolution_hybrid_kem() {
2689 let params = LatticeParams {
2690 n: 32,
2691 q: 257,
2692 p: 3,
2693 sigma: 2.0,
2694 };
2695 let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 42, KemAlgorithm::Hybrid);
2696
2697 let (ct, ss_enc) = ke.encapsulate();
2698 let ss_dec = ke.decapsulate(&ct).unwrap();
2699 assert_eq!(ss_enc, ss_dec, "Hybrid RLWE+ML-KEM shared secret mismatch");
2700 assert_eq!(ss_enc.len(), 32, "Hybrid shared secret should be 32 bytes");
2701 assert!(ct.len() > 100, "Hybrid ciphertext should be large");
2703 }
2704
2705 #[test]
2706 fn test_mlkem_key_evolution_maintains_invariant() {
2707 let params = LatticeParams {
2708 n: 32,
2709 q: 257,
2710 p: 3,
2711 sigma: 2.0,
2712 };
2713 let mut ke = QuantumKeyEvolution::new_with_algorithm(params, 55, KemAlgorithm::MlKem768);
2714
2715 for i in 0..5 {
2716 ke.evolve();
2717 let (ct, ss_enc) = ke.encapsulate();
2718 let ss_dec = ke.decapsulate(&ct).unwrap();
2719 assert_eq!(ss_enc, ss_dec, "ML-KEM must work after evolution step {}", i);
2720 }
2721 }
2722
2723 #[test]
2724 fn test_quantum_speculative_protocol_with_mlkem() {
2725 let config = TransformerConfig::default();
2726 let params = LatticeParams {
2727 n: 32,
2728 q: 257,
2729 p: 3,
2730 sigma: 2.0,
2731 };
2732
2733 let mut alice = QuantumSpeculativeProtocol::new_with_algorithm(
2734 config.clone(),
2735 params.clone(),
2736 42,
2737 KemAlgorithm::MlKem768,
2738 );
2739 let mut bob = QuantumSpeculativeProtocol::new_with_algorithm(
2740 config,
2741 params,
2742 42,
2743 KemAlgorithm::MlKem768,
2744 );
2745
2746 assert_eq!(alice.algorithm(), KemAlgorithm::MlKem768);
2747 assert_eq!(bob.algorithm(), KemAlgorithm::MlKem768);
2748
2749 let msg = b"ML-KEM secured message";
2750 let quantum_msg = alice.send(msg);
2751 let received = bob.receive(&quantum_msg);
2752 assert!(received.is_some());
2753 assert_eq!(received.unwrap(), msg);
2754 }
2755
2756 #[test]
2757 fn test_mlkem_ciphertext_sizes() {
2758 let mut rng = StdRng::seed_from_u64(6);
2759 let kp512 = mlkem_ops::generate_512(&mut rng);
2760 let kp768 = mlkem_ops::generate_768(&mut rng);
2761 let kp1024 = mlkem_ops::generate_1024(&mut rng);
2762
2763 let (ct512, _) = mlkem_ops::encapsulate_512(&kp512.ek_bytes, &mut rng).unwrap();
2764 let (ct768, _) = mlkem_ops::encapsulate_768(&kp768.ek_bytes, &mut rng).unwrap();
2765 let (ct1024, _) = mlkem_ops::encapsulate_1024(&kp1024.ek_bytes, &mut rng).unwrap();
2766
2767 assert_eq!(ct512.len(), 768, "ML-KEM-512 ciphertext should be 768 bytes");
2768 assert_eq!(ct768.len(), 1088, "ML-KEM-768 ciphertext should be 1088 bytes");
2769 assert_eq!(ct1024.len(), 1568, "ML-KEM-1024 ciphertext should be 1568 bytes");
2770
2771 assert!(ct512.len() < ct768.len());
2773 assert!(ct768.len() < ct1024.len());
2774 }
2775}