1use std::collections::HashMap;
37use std::sync::{Arc, Mutex};
38use std::time::{Duration, Instant};
39
40use bytes::Bytes;
41
/// Lookup key for a cached native HTTP/3 TLS session.
///
/// All four fields participate in the derived `Eq`/`Hash`, so a cached session
/// is only shared between connection attempts that agree on every dimension.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct NativeH3SessionCacheKey {
    /// Server name the session belongs to.
    pub sni: String,
    /// ALPN protocol identifiers; compared element-by-element and in order.
    pub alpn: Vec<Vec<u8>>,
    /// Peer-verification flag; keeps sessions from differently-verified handshakes apart.
    pub verify_peer: bool,
    /// Optional fingerprint profile identifier (e.g. `"chrome"`); `None` is its own key.
    pub fingerprint_pin: Option<String>,
}

impl NativeH3SessionCacheKey {
    /// Builds a key, converting `sni` into an owned `String` and gathering the
    /// ALPN protocols in the order `alpn` yields them.
    pub fn new(
        sni: impl Into<String>,
        alpn: impl IntoIterator<Item = Vec<u8>>,
        verify_peer: bool,
        fingerprint_pin: Option<String>,
    ) -> Self {
        let server_name: String = sni.into();
        let mut protocols: Vec<Vec<u8>> = Vec::new();
        protocols.extend(alpn);
        NativeH3SessionCacheKey {
            sni: server_name,
            alpn: protocols,
            verify_peer,
            fingerprint_pin,
        }
    }
}
81
82#[derive(Debug, Clone)]
89pub struct NativeH3SessionEntry {
90 pub der: Bytes,
91 pub max_early_data: u32,
92 pub received_at: Instant,
93 pub lifetime: Duration,
94}
95
96impl NativeH3SessionEntry {
97 pub fn new(der: Bytes, max_early_data: u32, lifetime: Duration) -> Self {
98 Self {
99 der,
100 max_early_data,
101 received_at: Instant::now(),
102 lifetime,
103 }
104 }
105
106 pub fn is_expired(&self) -> bool {
107 self.received_at.elapsed() >= self.lifetime
108 }
109
110 pub fn supports_zero_rtt(&self) -> bool {
111 self.max_early_data > 0 && !self.is_expired()
112 }
113}
114
115#[derive(Debug, Clone)]
117pub struct NativeH3SessionCache {
118 inner: Arc<Mutex<NativeH3SessionCacheInner>>,
119}
120
121#[derive(Debug)]
122struct NativeH3SessionCacheInner {
123 entries: HashMap<NativeH3SessionCacheKey, NativeH3SessionEntry>,
124 default_lifetime: Duration,
125 max_entries: usize,
126}
127
/// Default entry lifetime used by `NativeH3SessionCache::new`, in seconds (six hours).
const DEFAULT_LIFETIME_SECS: u64 = 6 * 60 * 60;
/// Default entry bound used by `NativeH3SessionCache::new`.
const DEFAULT_MAX_ENTRIES: usize = 256;
130
131impl NativeH3SessionCache {
132 pub fn new() -> Self {
133 Self::with_capacity(
134 DEFAULT_MAX_ENTRIES,
135 Duration::from_secs(DEFAULT_LIFETIME_SECS),
136 )
137 }
138
139 pub fn with_capacity(max_entries: usize, default_lifetime: Duration) -> Self {
140 Self {
141 inner: Arc::new(Mutex::new(NativeH3SessionCacheInner {
142 entries: HashMap::new(),
143 default_lifetime,
144 max_entries: max_entries.max(1),
145 })),
146 }
147 }
148
149 pub fn insert(
156 &self,
157 key: NativeH3SessionCacheKey,
158 der: impl Into<Bytes>,
159 max_early_data: u32,
160 lifetime: Option<Duration>,
161 ) {
162 let mut inner = self.inner.lock().expect("native H3 session cache poisoned");
163 let lifetime = lifetime.unwrap_or(inner.default_lifetime);
164 if inner.entries.len() >= inner.max_entries && !inner.entries.contains_key(&key) {
165 let oldest_expired = inner
168 .entries
169 .iter()
170 .filter(|(_, entry)| entry.is_expired())
171 .min_by_key(|(_, entry)| entry.received_at)
172 .map(|(k, _)| k.clone());
173 if let Some(stale) = oldest_expired {
174 inner.entries.remove(&stale);
175 } else if let Some(oldest) = inner
176 .entries
177 .iter()
178 .min_by_key(|(_, entry)| entry.received_at)
179 .map(|(k, _)| k.clone())
180 {
181 inner.entries.remove(&oldest);
182 }
183 }
184 inner.entries.insert(
185 key,
186 NativeH3SessionEntry::new(der.into(), max_early_data, lifetime),
187 );
188 }
189
190 pub fn get(&self, key: &NativeH3SessionCacheKey) -> Option<NativeH3SessionEntry> {
193 let mut inner = self.inner.lock().expect("native H3 session cache poisoned");
194 match inner.entries.get(key) {
195 Some(entry) if !entry.is_expired() => Some(entry.clone()),
196 Some(_) => {
197 inner.entries.remove(key);
198 None
199 }
200 None => None,
201 }
202 }
203
204 pub fn evict(&self, key: &NativeH3SessionCacheKey) {
206 let mut inner = self.inner.lock().expect("native H3 session cache poisoned");
207 inner.entries.remove(key);
208 }
209
210 pub fn purge_expired(&self) {
212 let mut inner = self.inner.lock().expect("native H3 session cache poisoned");
213 inner.entries.retain(|_, entry| !entry.is_expired());
214 }
215
216 pub fn clear(&self) {
218 let mut inner = self.inner.lock().expect("native H3 session cache poisoned");
219 inner.entries.clear();
220 }
221
222 pub fn len(&self) -> usize {
223 self.inner
224 .lock()
225 .expect("native H3 session cache poisoned")
226 .entries
227 .len()
228 }
229
230 pub fn is_empty(&self) -> bool {
231 self.len() == 0
232 }
233}
234
235impl Default for NativeH3SessionCache {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache key from borrowed fixtures.
    fn key(
        sni: &str,
        alpn: &[&[u8]],
        verify_peer: bool,
        pin: Option<&str>,
    ) -> NativeH3SessionCacheKey {
        NativeH3SessionCacheKey::new(
            sni,
            alpn.iter().map(|p| p.to_vec()),
            verify_peer,
            pin.map(|s| s.to_string()),
        )
    }

    #[test]
    fn insert_get_round_trip() {
        let cache = NativeH3SessionCache::new();
        let k = key("example.com", &[b"h3"], true, Some("chrome"));
        cache.insert(k.clone(), Bytes::from_static(b"der-bytes"), 16_384, None);

        let entry = cache.get(&k).expect("entry present");
        assert_eq!(entry.der.as_ref(), b"der-bytes");
        assert_eq!(entry.max_early_data, 16_384);
        assert!(entry.supports_zero_rtt());
    }

    #[test]
    fn zero_early_data_disables_zero_rtt() {
        let cache = NativeH3SessionCache::new();
        let k = key("example.com", &[b"h3"], true, None);
        cache.insert(k.clone(), Bytes::from_static(b"der"), 0, None);
        assert!(!cache.get(&k).expect("entry present").supports_zero_rtt());
    }

    #[test]
    fn fingerprint_pin_isolates_entries() {
        let cache = NativeH3SessionCache::new();
        let chrome_key = key("example.com", &[b"h3"], true, Some("chrome"));
        let firefox_key = key("example.com", &[b"h3"], true, Some("firefox"));
        cache.insert(
            chrome_key.clone(),
            Bytes::from_static(b"chrome-der"),
            0,
            None,
        );
        assert!(cache.get(&firefox_key).is_none());
        assert!(cache.get(&chrome_key).is_some());
    }

    #[test]
    fn verify_peer_dimension_isolates_entries() {
        let cache = NativeH3SessionCache::new();
        let strict = key("example.com", &[b"h3"], true, None);
        let relaxed = key("example.com", &[b"h3"], false, None);
        cache.insert(strict.clone(), Bytes::from_static(b"strict"), 0, None);
        cache.insert(relaxed.clone(), Bytes::from_static(b"relaxed"), 0, None);
        assert_eq!(cache.get(&strict).unwrap().der.as_ref(), b"strict");
        assert_eq!(cache.get(&relaxed).unwrap().der.as_ref(), b"relaxed");
    }

    #[test]
    fn alpn_dimension_isolates_entries() {
        let cache = NativeH3SessionCache::new();
        let h3 = key("example.com", &[b"h3"], true, None);
        let h2 = key("example.com", &[b"h2"], true, None);
        cache.insert(h3.clone(), Bytes::from_static(b"h3"), 0, None);
        assert!(cache.get(&h2).is_none());
        assert_eq!(cache.get(&h3).unwrap().der.as_ref(), b"h3");
    }

    #[test]
    fn expired_entries_are_evicted_on_lookup() {
        // A zero lifetime expires immediately, so no sleep/timing race is needed.
        let cache = NativeH3SessionCache::with_capacity(8, Duration::ZERO);
        let k = key("example.com", &[b"h3"], true, None);
        cache.insert(k.clone(), Bytes::from_static(b"short-lived"), 0, None);
        assert_eq!(cache.len(), 1, "expired entries stay until touched");
        assert!(cache.get(&k).is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn capacity_bound_evicts_oldest_entry() {
        let cache = NativeH3SessionCache::with_capacity(2, Duration::from_secs(60));
        let a = key("a", &[b"h3"], true, None);
        let b = key("b", &[b"h3"], true, None);
        let c = key("c", &[b"h3"], true, None);
        cache.insert(a.clone(), Bytes::from_static(b"a"), 0, None);
        std::thread::sleep(Duration::from_millis(5));
        cache.insert(b.clone(), Bytes::from_static(b"b"), 0, None);
        std::thread::sleep(Duration::from_millis(5));
        cache.insert(c.clone(), Bytes::from_static(b"c"), 0, None);
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&a).is_none(), "oldest entry should be evicted");
        assert!(cache.get(&b).is_some());
        assert!(cache.get(&c).is_some());
    }

    #[test]
    fn expired_entries_are_evicted_before_live_ones() {
        let cache = NativeH3SessionCache::with_capacity(2, Duration::from_secs(60));
        let live = key("live", &[b"h3"], true, None);
        let stale = key("stale", &[b"h3"], true, None);
        let fresh = key("fresh", &[b"h3"], true, None);
        // `live` is the oldest entry, but `stale` is expired and must go first.
        cache.insert(live.clone(), Bytes::from_static(b"live"), 0, None);
        cache.insert(stale.clone(), Bytes::from_static(b"stale"), 0, Some(Duration::ZERO));
        cache.insert(fresh.clone(), Bytes::from_static(b"fresh"), 0, None);
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&live).is_some());
        assert!(cache.get(&fresh).is_some());
        assert!(cache.get(&stale).is_none());
    }

    #[test]
    fn replacing_existing_key_at_capacity_does_not_evict() {
        let cache = NativeH3SessionCache::with_capacity(2, Duration::from_secs(60));
        let a = key("a", &[b"h3"], true, None);
        let b = key("b", &[b"h3"], true, None);
        cache.insert(a.clone(), Bytes::from_static(b"a1"), 0, None);
        cache.insert(b.clone(), Bytes::from_static(b"b"), 0, None);
        cache.insert(a.clone(), Bytes::from_static(b"a2"), 0, None);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&a).unwrap().der.as_ref(), b"a2");
        assert!(cache.get(&b).is_some());
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let cache = NativeH3SessionCache::with_capacity(0, Duration::from_secs(60));
        let a = key("a", &[b"h3"], true, None);
        let b = key("b", &[b"h3"], true, None);
        cache.insert(a.clone(), Bytes::from_static(b"a"), 0, None);
        cache.insert(b.clone(), Bytes::from_static(b"b"), 0, None);
        assert_eq!(cache.len(), 1);
        assert!(cache.get(&b).is_some());
    }

    #[test]
    fn purge_evict_and_clear() {
        let cache = NativeH3SessionCache::new();
        let live = key("live", &[b"h3"], true, None);
        let stale = key("stale", &[b"h3"], true, None);
        cache.insert(live.clone(), Bytes::from_static(b"live"), 0, None);
        cache.insert(stale.clone(), Bytes::from_static(b"stale"), 0, Some(Duration::ZERO));
        assert_eq!(cache.len(), 2);

        cache.purge_expired();
        assert_eq!(cache.len(), 1);

        cache.evict(&live);
        assert!(cache.is_empty());

        cache.insert(live.clone(), Bytes::from_static(b"live"), 0, None);
        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.get(&live).is_none());
    }

    #[test]
    fn clones_share_state() {
        let cache = NativeH3SessionCache::default();
        let handle = cache.clone();
        let k = key("example.com", &[b"h3"], true, None);
        handle.insert(k.clone(), Bytes::from_static(b"shared"), 0, None);
        assert_eq!(cache.get(&k).unwrap().der.as_ref(), b"shared");
    }
}