Skip to main content

world_id_primitives/
rp.rs

1#![allow(clippy::unreadable_literal)]
2
3use std::{fmt, str::FromStr};
4
5use ark_ff::{BigInteger as _, PrimeField as _};
6use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as _};
7
8use crate::FieldElement;
9
/// Leading version byte of every message built by `compute_rp_signature_msg`.
const RP_SIGNATURE_MSG_VERSION: u8 = 1;
11
/// Identifier of a relying party (RP), stored as a raw `u64`.
///
/// Its textual form (via `Display`/`FromStr`) is `rp_` followed by hex digits.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RpId(u64);
15
16impl RpId {
17    /// Converts the RP id to an u64
18    #[must_use]
19    pub const fn into_inner(self) -> u64 {
20        self.0
21    }
22
23    /// Creates a new `RpId` by wrapping a `u64`
24    #[must_use]
25    pub const fn new(value: u64) -> Self {
26        Self(value)
27    }
28}
29
30impl fmt::Display for RpId {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        write!(f, "rp_{:016x}", self.0)
33    }
34}
35
36impl FromStr for RpId {
37    type Err = String;
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        if let Some(id) = s.strip_prefix("rp_") {
41            Ok(Self(u64::from_str_radix(id, 16).map_err(|_| {
42                "Invalid RP ID format: expected hex string".to_string()
43            })?))
44        } else {
45            Err("A valid RP ID must start with 'rp_'".to_string())
46        }
47    }
48}
49
50impl From<u64> for RpId {
51    fn from(value: u64) -> Self {
52        Self(value)
53    }
54}
55
56impl From<RpId> for FieldElement {
57    fn from(value: RpId) -> Self {
58        Self::from(value.0)
59    }
60}
61
62impl Serialize for RpId {
63    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
64    where
65        S: Serializer,
66    {
67        if serializer.is_human_readable() {
68            serializer.serialize_str(&self.to_string())
69        } else {
70            u64::serialize(&self.0, serializer)
71        }
72    }
73}
74
75impl<'de> Deserialize<'de> for RpId {
76    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77    where
78        D: Deserializer<'de>,
79    {
80        if deserializer.is_human_readable() {
81            let s = String::deserialize(deserializer)?;
82            Self::from_str(&s).map_err(D::Error::custom)
83        } else {
84            let value = u64::deserialize(deserializer)?;
85            Ok(Self(value))
86        }
87    }
88}
89
90/// Computes the message to be signed for the RP signature.
91///
92/// The message format is: `version || nonce || created_at || expires_at` (49 bytes total).
93/// - `version`: 1 byte (currently hardcoded to `0x01`)
94/// - `nonce`: 32 bytes (big-endian)
95/// - `created_at`: 8 bytes (big-endian)
96/// - `expires_at`: 8 bytes (big-endian)
97#[must_use]
98pub fn compute_rp_signature_msg(
99    nonce: ark_babyjubjub::Fq,
100    created_at: u64,
101    expires_at: u64,
102) -> Vec<u8> {
103    let mut msg = Vec::with_capacity(49);
104    msg.push(RP_SIGNATURE_MSG_VERSION);
105    msg.extend(nonce.into_bigint().to_bytes_be());
106    msg.extend(created_at.to_be_bytes());
107    msg.extend(expires_at.to_be_bytes());
108    msg
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpid_display() {
        let rp_id = RpId::new(0x123456789abcdef0);
        assert_eq!(rp_id.to_string(), "rp_123456789abcdef0");

        let rp_id = RpId::new(u64::MAX);
        assert_eq!(rp_id.to_string(), "rp_ffffffffffffffff");

        let rp_id = RpId::new(0);
        assert_eq!(rp_id.to_string(), "rp_0000000000000000");
    }

    #[test]
    fn test_rpid_from_str() {
        let rp_id = "rp_123456789abcdef0".parse::<RpId>().unwrap();
        assert_eq!(rp_id.0, 0x123456789abcdef0);

        let rp_id = "rp_ffffffffffffffff".parse::<RpId>().unwrap();
        assert_eq!(rp_id.0, u64::MAX);

        let rp_id = "rp_0000000000000000".parse::<RpId>().unwrap();
        assert_eq!(rp_id.0, 0);

        let rp_id = "rp_123456789ABCDEF0".parse::<RpId>().unwrap();
        assert_eq!(rp_id.0, 0x123456789abcdef0);
    }

    #[test]
    fn test_rpid_from_str_errors() {
        assert!("123456789abcdef0".parse::<RpId>().is_err());
        assert!("rp_invalid".parse::<RpId>().is_err());
        assert!("rp_".parse::<RpId>().is_err());
        // The prefix match is case-sensitive.
        assert!("RP_123456789abcdef0".parse::<RpId>().is_err());
        // 17 hex digits do not fit in a u64.
        assert!("rp_10000000000000000".parse::<RpId>().is_err());
    }

    #[test]
    fn test_rpid_roundtrip() {
        let original = RpId::new(0x123456789abcdef0);
        let s = original.to_string();
        let parsed = s.parse::<RpId>().unwrap();
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_rpid_json_serialization() {
        let rp_id = RpId::new(0x123456789abcdef0);
        let json = serde_json::to_string(&rp_id).unwrap();
        assert_eq!(json, "\"rp_123456789abcdef0\"");

        let deserialized: RpId = serde_json::from_str(&json).unwrap();
        assert_eq!(rp_id, deserialized);

        // Human-readable formats only accept the prefixed string form.
        assert!(serde_json::from_str::<RpId>("\"123\"").is_err());
        assert!(serde_json::from_str::<RpId>("123").is_err());
    }

    #[test]
    fn test_rpid_binary_serialization() {
        let rp_id = RpId::new(0x123456789abcdef0);

        let mut buffer = Vec::new();
        ciborium::into_writer(&rp_id, &mut buffer).unwrap();

        let decoded: RpId = ciborium::from_reader(&buffer[..]).unwrap();

        assert_eq!(rp_id, decoded);
    }

    #[test]
    fn test_compute_rp_signature_msg_fixed_length() {
        // Test with small values that would have leading zeros in variable-length encoding
        // to ensure we always get fixed 32-byte field elements
        let nonce = ark_babyjubjub::Fq::from(1u64);
        let created_at = 1000u64;
        let expires_at = 2000u64;

        let msg = compute_rp_signature_msg(nonce, created_at, expires_at);

        // Message must always be exactly 49 bytes:
        // 1 (version) + 32 (nonce) + 8 (created_at) + 8 (expires_at)
        assert_eq!(
            msg.len(),
            49,
            "RP signature message must be exactly 49 bytes"
        );
        assert_eq!(
            msg[0], RP_SIGNATURE_MSG_VERSION,
            "RP signature message version must be 0x01"
        );

        // The nonce must be left-padded to 32 big-endian bytes.
        let mut expected_nonce = [0u8; 32];
        expected_nonce[31] = 1;
        assert_eq!(&msg[1..33], &expected_nonce, "nonce must be 32-byte big-endian");
        assert_eq!(
            &msg[33..41],
            &created_at.to_be_bytes(),
            "created_at must be 8-byte big-endian"
        );
        assert_eq!(
            &msg[41..49],
            &expires_at.to_be_bytes(),
            "expires_at must be 8-byte big-endian"
        );
    }
}