1use crate::FieldElement;
2use embed_doc_image::embed_doc_image;
3use ruint::aliases::U256;
4use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as _};
5
6#[expect(unused_imports, reason = "used in doc comments")]
7use crate::circuit_inputs::QueryProofCircuitInput;
8
/// Most significant (big-endian) byte carried by every field element that is valid
/// for a session (e.g. the OPRF seed of a [`SessionId`]).
///
/// Fixing the top byte to `0x01` keeps the value below the field modulus, so any
/// 32-byte big-endian string starting with this byte is a valid field element.
const SESSION_FIELD_ELEMENT_PREFIX: u8 = 0x01;
10
11pub trait SessionFieldElement {
13 fn random_for_session<R: rand::CryptoRng + rand::RngCore>(rng: &mut R) -> FieldElement;
20 fn is_valid_for_session(&self) -> bool;
23}
24
25impl SessionFieldElement for FieldElement {
26 fn random_for_session<R: rand::CryptoRng + rand::RngCore>(rng: &mut R) -> FieldElement {
27 let mut bytes = [0u8; 32];
28 rng.fill_bytes(&mut bytes);
29 bytes[0] = SESSION_FIELD_ELEMENT_PREFIX;
30 let seed = U256::from_be_bytes(bytes);
31 Self::try_from(seed).expect(
32 "should always fit in the field because with 0x01 as the MSB, the field element < babyjubjub modulus",
33 )
34 }
35
36 fn is_valid_for_session(&self) -> bool {
37 self.to_be_bytes()[0] == SESSION_FIELD_ELEMENT_PREFIX
38 }
39}
40
#[embed_doc_image("session-proofs.png", "assets/session-proofs.png")]
/// Identifier of a session: a commitment paired with an OPRF seed.
///
/// ![Session proofs][session-proofs.png]
///
/// Every constructor in this module guarantees that `oprf_seed` carries the session
/// prefix byte `0x01`. In serialized form the two parts are laid out as a 64-byte
/// big-endian blob (commitment first), hex-encoded behind a `session_` prefix in
/// human-readable formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId {
    /// Commitment part of the id. When built via `SessionId::from_r_seed`, it is
    /// derived with Poseidon2 from the leaf index and the session `r` seed.
    commitment: FieldElement,
    /// OPRF seed; its most significant big-endian byte is always `0x01`.
    oprf_seed: FieldElement,
}
85
86impl SessionId {
87 const JSON_PREFIX: &str = "session_";
88
89 #[must_use]
94 pub fn new(commitment: FieldElement, oprf_seed: FieldElement) -> Self {
95 assert!(oprf_seed.is_valid_for_session());
99 Self {
100 commitment,
101 oprf_seed,
102 }
103 }
104
105 pub fn from_r_seed<R: rand::CryptoRng + rand::RngCore>(
118 leaf_index: u64,
119 session_id_r_seed: FieldElement,
120 oprf_seed: Option<FieldElement>,
121 rng: &mut R,
122 ) -> Result<Self, &str> {
123 let sub_ds = FieldElement::from_be_bytes_mod_order(b"H(id, r)");
124
125 let oprf_seed = if let Some(seed) = oprf_seed {
126 if !seed.is_valid_for_session() {
127 return Err("oprf_seed is not valid for session");
128 }
129 seed
130 } else {
131 FieldElement::random_for_session(rng)
132 };
133
134 let mut input = [*sub_ds, leaf_index.into(), *session_id_r_seed];
135 poseidon2::bn254::t3::permutation_in_place(&mut input);
136 let commitment = input[1].into();
137 Ok(Self {
138 commitment,
139 oprf_seed,
140 })
141 }
142
143 #[must_use]
145 pub const fn commitment(&self) -> FieldElement {
146 self.commitment
147 }
148
149 #[must_use]
151 pub const fn oprf_seed(&self) -> FieldElement {
152 self.oprf_seed
153 }
154
155 #[must_use]
157 pub fn to_compressed_bytes(&self) -> [u8; 64] {
158 let mut bytes = [0u8; 64];
159 bytes[..32].copy_from_slice(&self.commitment.to_be_bytes());
160 bytes[32..].copy_from_slice(&self.oprf_seed.to_be_bytes());
161 bytes
162 }
163
164 pub fn from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
169 if bytes.len() != 64 {
170 return Err(format!(
171 "Invalid length: expected 64 bytes, got {}",
172 bytes.len()
173 ));
174 }
175
176 let commitment = FieldElement::from_be_bytes(bytes[..32].try_into().unwrap())
177 .map_err(|e| format!("invalid commitment: {e}"))?;
178 let oprf_seed = FieldElement::from_be_bytes(bytes[32..].try_into().unwrap())
179 .map_err(|e| format!("invalid oprf_seed: {e}"))?;
180
181 if bytes[32] != SESSION_FIELD_ELEMENT_PREFIX {
182 return Err("invalid prefix for oprf_seed".to_string());
183 }
184
185 Ok(Self {
186 commitment,
187 oprf_seed,
188 })
189 }
190}
191
192impl Default for SessionId {
193 fn default() -> Self {
194 let mut oprf_seed = [0u8; 32];
195 oprf_seed[0] = SESSION_FIELD_ELEMENT_PREFIX;
196 let oprf_seed = U256::from_be_bytes(oprf_seed)
197 .try_into()
198 .expect("always fits in the field");
199 Self {
200 commitment: FieldElement::ZERO,
201 oprf_seed,
202 }
203 }
204}
205
206impl Serialize for SessionId {
207 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208 where
209 S: Serializer,
210 {
211 let bytes = self.to_compressed_bytes();
212 if serializer.is_human_readable() {
213 serializer.serialize_str(&format!("{}{}", Self::JSON_PREFIX, hex::encode(bytes)))
215 } else {
216 serializer.serialize_bytes(&bytes)
218 }
219 }
220}
221
222impl<'de> Deserialize<'de> for SessionId {
223 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224 where
225 D: Deserializer<'de>,
226 {
227 let bytes = if deserializer.is_human_readable() {
228 let value = String::deserialize(deserializer)?;
229 let hex_str = value.strip_prefix(Self::JSON_PREFIX).ok_or_else(|| {
230 D::Error::custom(format!(
231 "session id must start with '{}'",
232 Self::JSON_PREFIX
233 ))
234 })?;
235 hex::decode(hex_str).map_err(D::Error::custom)?
236 } else {
237 Vec::deserialize(deserializer)?
238 };
239
240 Self::from_compressed_bytes(&bytes).map_err(D::Error::custom)
241 }
242}
243
/// A nullifier paired with an action, both stored as field elements.
///
/// In serialized form the two parts are laid out as a 64-byte big-endian blob
/// (nullifier first), hex-encoded behind a `snil_` prefix in human-readable
/// formats. They can also be converted to and from a `[U256; 2]` Ethereum
/// representation in the same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionNullifier {
    /// The nullifier value.
    nullifier: FieldElement,
    /// The action field element paired with the nullifier.
    action: FieldElement,
}
270
271impl SessionNullifier {
272 const JSON_PREFIX: &str = "snil_";
273
274 #[must_use]
276 pub const fn new(nullifier: FieldElement, action: FieldElement) -> Self {
277 Self { nullifier, action }
278 }
279
280 #[must_use]
282 pub const fn nullifier(&self) -> FieldElement {
283 self.nullifier
284 }
285
286 #[must_use]
288 pub const fn action(&self) -> FieldElement {
289 self.action
290 }
291
292 #[must_use]
296 pub fn as_ethereum_representation(&self) -> [U256; 2] {
297 [self.nullifier.into(), self.action.into()]
298 }
299
300 pub fn from_ethereum_representation(value: [U256; 2]) -> Result<Self, String> {
305 let nullifier =
306 FieldElement::try_from(value[0]).map_err(|e| format!("invalid nullifier: {e}"))?;
307 let action =
308 FieldElement::try_from(value[1]).map_err(|e| format!("invalid action: {e}"))?;
309 Ok(Self { nullifier, action })
310 }
311
312 #[must_use]
314 pub fn to_compressed_bytes(&self) -> [u8; 64] {
315 let mut bytes = [0u8; 64];
316 bytes[..32].copy_from_slice(&self.nullifier.to_be_bytes());
317 bytes[32..].copy_from_slice(&self.action.to_be_bytes());
318 bytes
319 }
320
321 pub fn from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
326 if bytes.len() != 64 {
327 return Err(format!(
328 "Invalid length: expected 64 bytes, got {}",
329 bytes.len()
330 ));
331 }
332
333 let nullifier = FieldElement::from_be_bytes(bytes[..32].try_into().unwrap())
334 .map_err(|e| format!("invalid nullifier: {e}"))?;
335 let action = FieldElement::from_be_bytes(bytes[32..].try_into().unwrap())
336 .map_err(|e| format!("invalid action: {e}"))?;
337
338 Ok(Self { nullifier, action })
339 }
340}
341
342impl Default for SessionNullifier {
343 fn default() -> Self {
344 Self {
345 nullifier: FieldElement::ZERO,
346 action: FieldElement::ZERO,
347 }
348 }
349}
350
351impl Serialize for SessionNullifier {
352 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
353 where
354 S: Serializer,
355 {
356 let bytes = self.to_compressed_bytes();
357 if serializer.is_human_readable() {
358 serializer.serialize_str(&format!("{}{}", Self::JSON_PREFIX, hex::encode(bytes)))
360 } else {
361 serializer.serialize_bytes(&bytes)
363 }
364 }
365}
366
367impl<'de> Deserialize<'de> for SessionNullifier {
368 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
369 where
370 D: Deserializer<'de>,
371 {
372 let bytes = if deserializer.is_human_readable() {
373 let value = String::deserialize(deserializer)?;
374 let hex_str = value.strip_prefix(Self::JSON_PREFIX).ok_or_else(|| {
375 D::Error::custom(format!(
376 "session nullifier must start with '{}'",
377 Self::JSON_PREFIX
378 ))
379 })?;
380 hex::decode(hex_str).map_err(D::Error::custom)?
381 } else {
382 Vec::deserialize(deserializer)?
383 };
384
385 Self::from_compressed_bytes(&bytes).map_err(D::Error::custom)
386 }
387}
388
389impl From<SessionNullifier> for [U256; 2] {
390 fn from(value: SessionNullifier) -> Self {
391 value.as_ethereum_representation()
392 }
393}
394
395impl From<(FieldElement, FieldElement)> for SessionNullifier {
396 fn from((nullifier, action): (FieldElement, FieldElement)) -> Self {
397 Self::new(nullifier, action)
398 }
399}
400
#[cfg(test)]
mod session_id_tests {
    use super::*;
    use ruint::uint;

    fn test_field_element(value: u64) -> FieldElement {
        FieldElement::from(value)
    }

    /// Builds a session-valid seed: `value` in the low bits, `0x01` in the top byte.
    fn test_oprf_seed(value: u64) -> FieldElement {
        let n = U256::from(value)
            // OR in the session prefix as the most significant big-endian byte.
            | uint!(0x0100000000000000000000000000000000000000000000000000000000000000_U256);
        FieldElement::try_from(n).expect("test value fits in field")
    }

    #[test]
    fn test_new_and_accessors() {
        let commitment = test_field_element(1001);
        let seed = test_oprf_seed(42);
        let id = SessionId::new(commitment, seed);

        assert_eq!(id.commitment(), commitment);
        assert_eq!(id.oprf_seed(), seed);
    }

    #[test]
    fn test_default() {
        let id = SessionId::default();
        assert_eq!(id.commitment(), FieldElement::ZERO);
        assert_eq!(
            id.oprf_seed(),
            uint!(0x0100000000000000000000000000000000000000000000000000000000000000_U256)
                .try_into()
                .unwrap()
        );
    }

    #[test]
    fn test_bytes_roundtrip() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42));
        let bytes = id.to_compressed_bytes();

        assert_eq!(bytes.len(), 64);

        let decoded = SessionId::from_compressed_bytes(&bytes).unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_bytes_use_field_element_encoding() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42));
        let bytes = id.to_compressed_bytes();

        let mut expected = [0u8; 64];
        expected[..32].copy_from_slice(&id.commitment().to_be_bytes());
        expected[32..].copy_from_slice(&id.oprf_seed().to_be_bytes());
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_invalid_bytes_length() {
        let too_short = vec![0u8; 63];
        let result = SessionId::from_compressed_bytes(&too_short);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));

        let too_long = vec![0u8; 65];
        let result = SessionId::from_compressed_bytes(&too_long);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
    }

    #[test]
    fn test_from_compressed_bytes_rejects_wrong_oprf_seed_prefix() {
        let mut bytes = [0u8; 64];
        // A zero seed is a valid field element but lacks the 0x01 session prefix.
        bytes[32] = 0x00;
        let result = SessionId::from_compressed_bytes(&bytes);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("invalid prefix"),
            "should reject oprf_seed without 0x01 prefix"
        );
    }

    #[test]
    fn test_json_roundtrip() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42));
        let json = serde_json::to_string(&id).unwrap();

        assert!(json.starts_with("\"session_"));
        assert!(json.ends_with('"'));

        let decoded: SessionId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_json_format() {
        let id = SessionId::new(test_field_element(1), test_oprf_seed(2));
        let json = serde_json::to_string(&id).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_string());
        let value = parsed.as_str().unwrap();
        assert!(value.starts_with("session_"));
    }

    #[test]
    fn test_json_wrong_prefix_rejected() {
        let result = serde_json::from_str::<SessionId>("\"snil_00\"");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_r_seed_generates_random_seed() {
        let mut rng = rand::rngs::OsRng;
        let r_seed = test_field_element(999);

        let id1 = SessionId::from_r_seed(0, r_seed, None, &mut rng).unwrap();
        let id2 = SessionId::from_r_seed(0, r_seed, None, &mut rng).unwrap();

        assert_ne!(id1.oprf_seed(), id2.oprf_seed());
    }

    #[test]
    fn test_from_r_seed_generated_seed_has_session_prefix() {
        let mut rng = rand::rngs::OsRng;
        let r_seed = test_field_element(999);

        for _ in 0..50 {
            let id = SessionId::from_r_seed(0, r_seed, None, &mut rng).unwrap();
            // The top byte (bits 248..256) must be exactly the 0x01 session prefix.
            assert_eq!(id.oprf_seed().to_u256() >> 248, U256::from(1));
        }
    }

    #[test]
    fn test_from_r_seed_commitment_snapshot() {
        let leaf_index = 42u64;
        let r_seed = test_field_element(123);
        let oprf_seed = test_oprf_seed(456);

        let session_id =
            SessionId::from_r_seed(leaf_index, r_seed, Some(oprf_seed), &mut rand::rngs::OsRng)
                .unwrap();

        let expected = "0x1e7853ebd4fc9d9f0232fdcfae116023610bdf66a22e2700445d7a2e0e7e6152"
            .parse::<U256>()
            .unwrap();
        assert_eq!(
            session_id.commitment().to_u256(),
            expected,
            "commitment snapshot for session commitment changed"
        );
    }
}
561
#[cfg(test)]
mod session_nullifier_tests {
    use super::*;

    fn test_field_element(value: u64) -> FieldElement {
        FieldElement::from(value)
    }

    #[test]
    fn test_new_and_accessors() {
        let nullifier = test_field_element(1001);
        let action = test_field_element(42);
        let session = SessionNullifier::new(nullifier, action);

        assert_eq!(session.nullifier(), nullifier);
        assert_eq!(session.action(), action);
    }

    #[test]
    fn test_as_ethereum_representation() {
        let nullifier = test_field_element(100);
        let action = test_field_element(200);
        let session = SessionNullifier::new(nullifier, action);

        let repr = session.as_ethereum_representation();
        assert_eq!(repr[0], U256::from(100));
        assert_eq!(repr[1], U256::from(200));
    }

    #[test]
    fn test_from_ethereum_representation() {
        let repr = [U256::from(100), U256::from(200)];
        let session = SessionNullifier::from_ethereum_representation(repr).unwrap();

        assert_eq!(session.nullifier(), test_field_element(100));
        assert_eq!(session.action(), test_field_element(200));
    }

    #[test]
    fn test_from_ethereum_representation_rejects_out_of_field_values() {
        // U256::MAX exceeds the field modulus, so each component must be rejected
        // with an error naming that component.
        let result = SessionNullifier::from_ethereum_representation([U256::MAX, U256::ZERO]);
        assert!(result.unwrap_err().contains("invalid nullifier"));

        let result = SessionNullifier::from_ethereum_representation([U256::ZERO, U256::MAX]);
        assert!(result.unwrap_err().contains("invalid action"));
    }

    #[test]
    fn test_json_roundtrip() {
        let session = SessionNullifier::new(test_field_element(1001), test_field_element(42));
        let json = serde_json::to_string(&session).unwrap();

        assert!(json.starts_with("\"snil_"));
        assert!(json.ends_with('"'));

        let decoded: SessionNullifier = serde_json::from_str(&json).unwrap();
        assert_eq!(session, decoded);
    }

    #[test]
    fn test_json_format() {
        let session = SessionNullifier::new(test_field_element(1), test_field_element(2));
        let json = serde_json::to_string(&session).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_string());
        let value = parsed.as_str().unwrap();
        assert!(value.starts_with("snil_"));
    }

    #[test]
    fn test_bytes_roundtrip() {
        let session = SessionNullifier::new(test_field_element(1001), test_field_element(42));
        let bytes = session.to_compressed_bytes();

        assert_eq!(bytes.len(), 64);

        let decoded = SessionNullifier::from_compressed_bytes(&bytes).unwrap();
        assert_eq!(session, decoded);
    }

    #[test]
    fn test_bytes_use_field_element_encoding() {
        let session = SessionNullifier::new(test_field_element(1001), test_field_element(42));
        let bytes = session.to_compressed_bytes();

        let mut expected = [0u8; 64];
        expected[..32].copy_from_slice(&session.nullifier().to_be_bytes());
        expected[32..].copy_from_slice(&session.action().to_be_bytes());
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_invalid_bytes_length() {
        let too_short = vec![0u8; 63];
        let result = SessionNullifier::from_compressed_bytes(&too_short);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));

        let too_long = vec![0u8; 65];
        let result = SessionNullifier::from_compressed_bytes(&too_long);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
    }

    #[test]
    fn test_from_compressed_bytes_rejects_non_canonical_field_elements() {
        // All-0xff halves are larger than the field modulus; the nullifier is
        // decoded first, so its error is the one reported.
        let bytes = [0xffu8; 64];
        let result = SessionNullifier::from_compressed_bytes(&bytes);
        assert!(result.unwrap_err().contains("invalid nullifier"));
    }

    #[test]
    fn test_default() {
        let session = SessionNullifier::default();
        assert_eq!(session.nullifier(), FieldElement::ZERO);
        assert_eq!(session.action(), FieldElement::ZERO);
    }

    #[test]
    fn test_from_tuple() {
        let nullifier = test_field_element(100);
        let action = test_field_element(200);
        let session: SessionNullifier = (nullifier, action).into();

        assert_eq!(session.nullifier(), nullifier);
        assert_eq!(session.action(), action);
    }

    #[test]
    fn test_into_u256_array() {
        let session = SessionNullifier::new(test_field_element(100), test_field_element(200));
        let arr: [U256; 2] = session.into();

        assert_eq!(arr[0], U256::from(100));
        assert_eq!(arr[1], U256::from(200));
    }
}