sphinx_packet/surb/
mod.rs

1use crate::constants::{NODE_ADDRESS_LENGTH, PAYLOAD_KEY_SEED_SIZE, PAYLOAD_KEY_SIZE};
2use crate::header::delays::Delay;
3use crate::payload::key::{PayloadKey, PayloadKeySeed};
4use crate::payload::Payload;
5use crate::route::{Destination, Node, NodeAddressBytes};
6use crate::version::Version;
7use crate::{header, SphinxPacket};
8use crate::{Error, ErrorKind, Result};
9use header::{SphinxHeader, HEADER_SIZE};
10use std::fmt;
11use x25519_dalek::StaticSecret;
12
13// legacy compatibility wrapper
14#[derive(Debug)]
15enum PayloadKeysMaterial {
16    DerivedKeys(Vec<PayloadKey>),
17    KeySeeds(Vec<PayloadKeySeed>),
18}
19
20impl PayloadKeysMaterial {
21    fn from_bytes(bytes: &[u8]) -> Result<PayloadKeysMaterial> {
22        // given that our maximum path length is 5, payload key is 192 and key seed is 16,
23        // the maximum possible size of 'updated' surb seeds is 5*16 = 80, which is smaller than
24        // a single key, and thus we can use this information in order to determine which variant we should attempt to parse
25        if bytes.len() < PAYLOAD_KEY_SIZE {
26            // seeds
27            if bytes.len() % PAYLOAD_KEY_SEED_SIZE != 0 {
28                return Err(Error::new(
29                    ErrorKind::InvalidSURB,
30                    "bytes of invalid length provided",
31                ));
32            }
33            let seeds_count = bytes.len() / PAYLOAD_KEY_SEED_SIZE;
34            let mut payload_key_seeds = Vec::with_capacity(seeds_count);
35            for i in 0..seeds_count {
36                let mut payload_key = [0u8; PAYLOAD_KEY_SEED_SIZE];
37                payload_key.copy_from_slice(
38                    &bytes[i * PAYLOAD_KEY_SEED_SIZE..(i + 1) * PAYLOAD_KEY_SEED_SIZE],
39                );
40                payload_key_seeds.push(payload_key);
41            }
42            Ok(PayloadKeysMaterial::KeySeeds(payload_key_seeds))
43        } else {
44            // full keys
45            if bytes.len() % PAYLOAD_KEY_SIZE != 0 {
46                return Err(Error::new(
47                    ErrorKind::InvalidSURB,
48                    "bytes of invalid length provided",
49                ));
50            }
51            let key_count = bytes.len() / PAYLOAD_KEY_SIZE;
52            let mut payload_keys = Vec::with_capacity(key_count);
53            for i in 0..key_count {
54                let mut payload_key = [0u8; PAYLOAD_KEY_SIZE];
55                payload_key
56                    .copy_from_slice(&bytes[i * PAYLOAD_KEY_SIZE..(i + 1) * PAYLOAD_KEY_SIZE]);
57                payload_keys.push(payload_key);
58            }
59            Ok(PayloadKeysMaterial::DerivedKeys(payload_keys))
60        }
61    }
62}
63
64/// A Single Use Reply Block (SURB) must have a pre-aggregated Sphinx header,
65/// the address of the first hop in the route of the SURB, and the key material
66/// used to layer encrypt the payload.
67#[allow(non_snake_case)]
68pub struct SURB {
69    SURB_header: header::SphinxHeader,
70    first_hop_address: NodeAddressBytes,
71    payload_keys_material: PayloadKeysMaterial,
72}
73
74impl fmt::Debug for SURB {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("SURB")
77            .field("SURB_header", &self.SURB_header)
78            .field("first_hop_address", &self.first_hop_address)
79            .field("payload_keys_material", &self.payload_keys_material)
80            .finish()
81    }
82}
83
84pub struct SURBMaterial {
85    surb_route: Vec<Node>,
86    surb_delays: Vec<Delay>,
87    surb_destination: Destination,
88    version: Version,
89}
90
91impl SURBMaterial {
92    pub fn new(route: Vec<Node>, delays: Vec<Delay>, destination: Destination) -> Self {
93        SURBMaterial {
94            surb_route: route,
95            surb_delays: delays,
96            surb_destination: destination,
97            version: Default::default(),
98        }
99    }
100
101    #[allow(non_snake_case)]
102    pub fn construct_SURB(self) -> Result<SURB> {
103        let surb_initial_secret = StaticSecret::random();
104        SURB::new(surb_initial_secret, self)
105    }
106
107    #[must_use]
108    pub fn with_version(mut self, version: Version) -> Self {
109        self.version = version;
110        self
111    }
112}
113
114#[allow(non_snake_case)]
115impl SURB {
116    pub fn new(surb_initial_secret: StaticSecret, surb_material: SURBMaterial) -> Result<Self> {
117        let surb_route = surb_material.surb_route;
118        let surb_delays = surb_material.surb_delays;
119        let surb_destination = surb_material.surb_destination;
120
121        /* Pre-computes the header of the Sphinx packet which will be used as SURB
122        and encapsulates it into struct together with the address of the first hop in the route of the SURB, and the key material
123        which should be used to layer encrypt the payload. */
124        let Some(first_hop) = surb_route.first() else {
125            return Err(Error::new(
126                ErrorKind::InvalidSURB,
127                "tried to create SURB for an empty route",
128            ));
129        };
130
131        if surb_route.len() != surb_delays.len() {
132            return Err(Error::new(ErrorKind::InvalidSURB, format!("creating SURB for contradictory data: route has len {} while there are {} delays generated", surb_route.len(), surb_delays.len())));
133        }
134
135        #[allow(deprecated)]
136        let built_header = header::SphinxHeader::new_versioned(
137            &surb_initial_secret,
138            &surb_route,
139            &surb_delays,
140            &surb_destination,
141            surb_material.version,
142        );
143
144        if surb_material.version.expects_legacy_full_payload_keys() {
145            Ok(SURB {
146                first_hop_address: first_hop.address,
147                payload_keys_material: PayloadKeysMaterial::DerivedKeys(
148                    built_header.legacy_full_payload_keys(),
149                ),
150                SURB_header: built_header.into_header(),
151            })
152        } else {
153            Ok(SURB {
154                first_hop_address: first_hop.address,
155                payload_keys_material: PayloadKeysMaterial::KeySeeds(
156                    built_header.payload_key_seeds(),
157                ),
158                SURB_header: built_header.into_header(),
159            })
160        }
161    }
162
163    /// Function takes the precomputed surb header, layer encrypts the plaintext payload content
164    /// using the precomputed payload key material and returns the full Sphinx packet
165    /// together with the address of first hop to which it should be forwarded.
166    pub fn use_surb(
167        self,
168        plaintext_message: &[u8],
169        payload_size: usize,
170    ) -> Result<(SphinxPacket, NodeAddressBytes)> {
171        let header = self.SURB_header;
172
173        // Note that Payload::encapsulate_message performs checks to verify whether the plaintext
174        // is going to fit in the packet.
175        let payload = match self.payload_keys_material {
176            PayloadKeysMaterial::DerivedKeys(keys) => {
177                Payload::encapsulate_message(plaintext_message, keys.as_slice(), payload_size)?
178            }
179            PayloadKeysMaterial::KeySeeds(seeds) => {
180                Payload::encapsulate_message(plaintext_message, &seeds, payload_size)?
181            }
182        };
183
184        Ok((SphinxPacket { header, payload }, self.first_hop_address))
185    }
186
187    pub fn to_bytes(&self) -> Vec<u8> {
188        let initial_bytes = self
189            .SURB_header
190            .to_bytes()
191            .into_iter()
192            .chain(self.first_hop_address.to_bytes());
193
194        match &self.payload_keys_material {
195            PayloadKeysMaterial::DerivedKeys(keys) => initial_bytes
196                .chain(keys.iter().flat_map(|k| k.iter().copied()))
197                .collect(),
198            PayloadKeysMaterial::KeySeeds(seeds) => initial_bytes
199                .chain(seeds.iter().flat_map(|s| s.iter().copied()))
200                .collect(),
201        }
202    }
203
204    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
205        // SURB needs to contain AT LEAST a single payload key (or seed)
206        if bytes.len() < HEADER_SIZE + NODE_ADDRESS_LENGTH + PAYLOAD_KEY_SEED_SIZE {
207            return Err(Error::new(
208                ErrorKind::InvalidSURB,
209                "not enough bytes provided to try to recover a SURB",
210            ));
211        }
212
213        let header_bytes = &bytes[..HEADER_SIZE];
214        let first_hop_bytes = &bytes[HEADER_SIZE..HEADER_SIZE + NODE_ADDRESS_LENGTH];
215        let payload_keys_material_bytes = &bytes[HEADER_SIZE + NODE_ADDRESS_LENGTH..];
216
217        let SURB_header = SphinxHeader::from_bytes(header_bytes)?;
218        let first_hop_address = NodeAddressBytes::try_from_byte_slice(first_hop_bytes)?;
219        let payload_keys_material = PayloadKeysMaterial::from_bytes(payload_keys_material_bytes)?;
220
221        Ok(SURB {
222            SURB_header,
223            first_hop_address,
224            payload_keys_material,
225        })
226    }
227
228    pub fn first_hop(&self) -> NodeAddressBytes {
229        self.first_hop_address
230    }
231
232    pub fn materials_count(&self) -> usize {
233        match &self.payload_keys_material {
234            PayloadKeysMaterial::DerivedKeys(keys) => keys.len(),
235            PayloadKeysMaterial::KeySeeds(seeds) => seeds.len(),
236        }
237    }
238
239    pub fn uses_key_seeds(&self) -> bool {
240        matches!(self.payload_keys_material, PayloadKeysMaterial::KeySeeds(_))
241    }
242}
243
#[cfg(test)]
mod prepare_and_use_process_surb {
    use super::*;
    use crate::constants::NODE_ADDRESS_LENGTH;
    use crate::header::{delays, HEADER_SIZE};
    use crate::version::{PAYLOAD_KEYS_SEEDS_VERSION, X25519_WITH_EXPLICIT_PAYLOAD_KEYS_VERSION};
    use crate::{
        packet::builder::DEFAULT_PAYLOAD_SIZE,
        test_utils::fixtures::{destination_fixture, keygen},
    };
    use std::time::Duration;

    /// Builds SURB material for a three-node route whose first hop address is
    /// `[5u8; NODE_ADDRESS_LENGTH]`, with one generated delay per node.
    fn surb_material_fixture() -> SURBMaterial {
        let (_, node1_pk) = keygen();
        let node1 = Node {
            address: NodeAddressBytes::from_bytes([5u8; NODE_ADDRESS_LENGTH]),
            pub_key: node1_pk,
        };
        let (_, node2_pk) = keygen();
        let node2 = Node {
            address: NodeAddressBytes::from_bytes([4u8; NODE_ADDRESS_LENGTH]),
            pub_key: node2_pk,
        };
        let (_, node3_pk) = keygen();
        let node3 = Node {
            address: NodeAddressBytes::from_bytes([2u8; NODE_ADDRESS_LENGTH]),
            pub_key: node3_pk,
        };

        let surb_route = vec![node1, node2, node3];
        let surb_destination = destination_fixture();
        let surb_delays =
            delays::generate_from_average_duration(surb_route.len(), Duration::from_secs(3));

        SURBMaterial::new(surb_route, surb_delays, surb_destination)
    }

    /// SURB built with the version that stores full (legacy) payload keys.
    #[allow(non_snake_case)]
    fn legacy_SURB_fixture() -> SURB {
        let surb_initial_secret = StaticSecret::random();
        let surb_material =
            surb_material_fixture().with_version(X25519_WITH_EXPLICIT_PAYLOAD_KEYS_VERSION);

        SURB::new(surb_initial_secret, surb_material).unwrap()
    }

    /// SURB built with the version that stores payload key seeds.
    #[allow(non_snake_case)]
    fn seeded_SURB_fixture() -> SURB {
        let surb_initial_secret = StaticSecret::random();
        let surb_material = surb_material_fixture().with_version(PAYLOAD_KEYS_SEEDS_VERSION);

        SURB::new(surb_initial_secret, surb_material).unwrap()
    }

    #[test]
    fn returns_error_if_surb_route_empty() {
        let surb_route = Vec::new();
        let surb_destination = destination_fixture();
        let surb_initial_secret = StaticSecret::random();
        let surb_delays =
            delays::generate_from_average_duration(surb_route.len(), Duration::from_secs(3));
        let expected = ErrorKind::InvalidSURB;

        match SURB::new(
            surb_initial_secret,
            SURBMaterial::new(surb_route, surb_delays, surb_destination),
        ) {
            Err(err) => assert_eq!(expected, err.kind()),
            _ => panic!("Should have returned an error when route empty"),
        };
    }

    #[test]
    fn surb_header_has_correct_length() {
        let pre_surb = legacy_SURB_fixture();
        assert_eq!(pre_surb.SURB_header.to_bytes().len(), HEADER_SIZE);
    }

    #[test]
    fn to_bytes_returns_correct_value() {
        let pre_surb = legacy_SURB_fixture();
        let PayloadKeysMaterial::DerivedKeys(keys) = &pre_surb.payload_keys_material else {
            unreachable!()
        };

        let pre_surb_bytes = pre_surb.to_bytes();
        let expected = [
            pre_surb.SURB_header.to_bytes(),
            [5u8; NODE_ADDRESS_LENGTH].to_vec(),
            keys[0].to_vec(),
            keys[1].to_vec(),
            keys[2].to_vec(),
        ]
        .concat();
        assert_eq!(pre_surb_bytes, expected);

        let pre_surb = seeded_SURB_fixture();
        let PayloadKeysMaterial::KeySeeds(seeds) = &pre_surb.payload_keys_material else {
            unreachable!()
        };

        let pre_surb_bytes = pre_surb.to_bytes();
        let expected = [
            pre_surb.SURB_header.to_bytes(),
            [5u8; NODE_ADDRESS_LENGTH].to_vec(),
            seeds[0].to_vec(),
            seeds[1].to_vec(),
            seeds[2].to_vec(),
        ]
        .concat();
        assert_eq!(pre_surb_bytes, expected);
    }

    #[test]
    fn returns_error_is_payload_too_large() {
        let pre_surb = legacy_SURB_fixture();
        let plaintext_message = vec![42u8; 5000];
        let expected = ErrorKind::InvalidPayload;

        match SURB::use_surb(pre_surb, &plaintext_message, DEFAULT_PAYLOAD_SIZE) {
            Err(err) => assert_eq!(expected, err.kind()),
            _ => panic!("Should have returned an error when payload bytes too long"),
        };
    }

    #[test]
    #[allow(non_snake_case)]
    fn can_be_converted_to_and_from_bytes_with_legacy_keys() {
        let dummy_SURB = legacy_SURB_fixture();
        let bytes = dummy_SURB.to_bytes();
        let recovered_SURB = SURB::from_bytes(&bytes).unwrap();

        assert_eq!(
            dummy_SURB.first_hop_address,
            recovered_SURB.first_hop_address
        );

        let PayloadKeysMaterial::DerivedKeys(original_keys) = &dummy_SURB.payload_keys_material
        else {
            unreachable!()
        };

        let PayloadKeysMaterial::DerivedKeys(recovered_keys) =
            &recovered_SURB.payload_keys_material
        else {
            unreachable!()
        };

        // Compare the whole vectors so that a differing number of keys fails
        // the test instead of going unnoticed.
        assert_eq!(original_keys, recovered_keys);

        // TODO: saner way of comparing headers...
        // The recovered header must serialise to the same bytes as the original.
        assert_eq!(
            dummy_SURB.SURB_header.to_bytes(),
            recovered_SURB.SURB_header.to_bytes()
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn can_be_converted_to_and_from_bytes_with_key_seeds() {
        let dummy_SURB = seeded_SURB_fixture();
        let bytes = dummy_SURB.to_bytes();
        let recovered_SURB = SURB::from_bytes(&bytes).unwrap();

        assert_eq!(
            dummy_SURB.first_hop_address,
            recovered_SURB.first_hop_address
        );

        let PayloadKeysMaterial::KeySeeds(original_seeds) = &dummy_SURB.payload_keys_material
        else {
            unreachable!()
        };

        let PayloadKeysMaterial::KeySeeds(recovered_seeds) = &recovered_SURB.payload_keys_material
        else {
            unreachable!()
        };

        // Compare the whole vectors so that a differing number of seeds fails
        // the test instead of going unnoticed.
        assert_eq!(original_seeds, recovered_seeds);

        // TODO: saner way of comparing headers...
        // The recovered header must serialise to the same bytes as the original.
        assert_eq!(
            dummy_SURB.SURB_header.to_bytes(),
            recovered_SURB.SURB_header.to_bytes()
        );
    }
}
435}