rustls_graviola/
ticketer.rs1use core::sync::atomic::{AtomicUsize, Ordering};
2use std::fmt;
3use std::sync::Arc;
4
5use graviola::{aead, random};
6use rustls::crypto::GetRandomFailed;
7use rustls::server::ProducesTickets;
8use rustls::{Error, TicketRotator};
9
10pub struct Ticketer;
12
13impl Ticketer {
14 #[allow(clippy::new_ret_no_self)]
19 pub fn new() -> Result<Arc<dyn ProducesTickets>, Error> {
20 Ok(Arc::new(TicketRotator::new(
21 ONE_TICKET_LIFETIME_SECS,
22 make_ticket_generator,
23 )?))
24 }
25}
26
27fn make_ticket_generator() -> Result<Box<dyn ProducesTickets>, GetRandomFailed> {
28 Ok(Box::new(XChaCha20Ticketer::new()?))
29}
30
31struct XChaCha20Ticketer {
32 key: aead::XChaCha20Poly1305,
33 key_name: [u8; 16],
34 lifetime: u32,
35 maximum_ciphertext_len: AtomicUsize,
36}
37
38impl XChaCha20Ticketer {
39 fn new() -> Result<Self, GetRandomFailed> {
40 let mut key = [0u8; 32];
41 let mut key_name = [0u8; 16];
42
43 random::fill(&mut key).map_err(|_| GetRandomFailed)?;
44 random::fill(&mut key_name).map_err(|_| GetRandomFailed)?;
45
46 let key = aead::XChaCha20Poly1305::new(key);
47
48 Ok(Self {
49 key,
50 key_name,
51 lifetime: ONE_TICKET_LIFETIME_SECS,
52 maximum_ciphertext_len: AtomicUsize::new(0),
53 })
54 }
55}
56
57impl ProducesTickets for XChaCha20Ticketer {
58 fn enabled(&self) -> bool {
59 true
60 }
61
62 fn lifetime(&self) -> u32 {
63 self.lifetime
64 }
65
66 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
67 let mut nonce = [0u8; 24];
68 random::fill(&mut nonce).ok()?;
69
70 let mut tag = [0u8; 16];
79 let mut res =
80 Vec::with_capacity(self.key_name.len() + nonce.len() + message.len() + tag.len());
81 res.extend(&self.key_name);
82 res.extend(&nonce);
83 res.extend(message);
84
85 self.key.encrypt(
86 &nonce,
87 &self.key_name,
88 &mut res[self.key_name.len() + nonce.len()..],
89 &mut tag,
90 );
91 res.extend(tag);
92
93 self.maximum_ciphertext_len
94 .fetch_max(res.len(), Ordering::SeqCst);
95
96 Some(res)
97 }
98
99 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
100 if ciphertext.len() > self.maximum_ciphertext_len.load(Ordering::SeqCst) {
101 return None;
102 }
103
104 let plain_len = ciphertext
105 .len()
106 .saturating_sub(self.key_name.len() + 24 + 16);
107
108 if plain_len == 0 {
109 return None;
110 }
111
112 let (alleged_key_name, rest) = ciphertext.split_at(self.key_name.len());
113
114 if alleged_key_name != self.key_name {
116 return None;
117 }
118
119 let (nonce, rest) = rest.split_at(24);
120 let nonce = nonce.try_into().unwrap();
121 let (plain, alleged_tag) = rest.split_at(plain_len);
122 let mut plain = plain.to_vec();
123
124 self.key
125 .decrypt(&nonce, alleged_key_name, &mut plain, alleged_tag)
126 .ok()?;
127 Some(plain)
128 }
129}
130
131impl fmt::Debug for XChaCha20Ticketer {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 f.debug_struct("XChaCha20Ticketer")
134 .field("lifetime", &self.lifetime)
135 .finish_non_exhaustive()
136 }
137}
138
/// Lifetime of one ticket key, in seconds: six hours.
const ONE_TICKET_LIFETIME_SECS: u32 = {
    const SECS_PER_HOUR: u32 = 60 * 60;
    6 * SECS_PER_HOUR
};
140
#[cfg(test)]
mod tests {
    use super::*;

    /// A ticket decrypts back to its plaintext, and the rotating ticketer is
    /// enabled and reports twice the single-ticket lifetime.
    #[test]
    fn roundtrip() {
        let t = Ticketer::new().unwrap();

        let sealed = t.encrypt(b"hello").unwrap();
        let opened = t.decrypt(&sealed).unwrap();
        assert_eq!(opened, b"hello");

        assert!(t.enabled());
        assert_eq!(t.lifetime(), ONE_TICKET_LIFETIME_SECS * 2);
        println!("{t:?}");
    }

    /// The bare generator is enabled and reports the single-ticket lifetime.
    #[test]
    fn make_generator() {
        let g = make_ticket_generator().unwrap();

        assert!(g.enabled());
        assert_eq!(g.lifetime(), ONE_TICKET_LIFETIME_SECS);
        println!("{g:?}");
    }

    /// Empty, one-byte and truncated inputs are all rejected.
    #[test]
    fn length_checks() {
        let t = Ticketer::new().unwrap();

        for too_short in [&b""[..], &b"a"[..]] {
            assert_eq!(t.decrypt(too_short), None);
        }

        let sealed = t.encrypt(b"a").unwrap();
        assert_eq!(t.decrypt(&sealed).unwrap(), b"a");

        let truncated = &sealed[..sealed.len() - 1];
        assert_eq!(t.decrypt(truncated), None);
    }

    /// Flipping the low bit of any single byte makes decryption fail.
    #[test]
    fn non_malleable() {
        let t = Ticketer::new().unwrap();
        let sealed = t.encrypt(b"hello").unwrap();

        for (position, _) in sealed.iter().enumerate() {
            let mut tampered = sealed.clone();
            tampered[position] ^= 1;
            assert_eq!(None, t.decrypt(&tampered));
        }
    }
}