use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Instant;

use bytes::Bytes;
use parking_lot::Mutex;
use rustls_pki_types::CertificateDer;

10#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConnId(pub u64);
12
13impl std::fmt::Display for ConnId {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 write!(f, "{:016x}", self.0)
16 }
17}
18
19#[derive(
20 Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, serde::Serialize, serde::Deserialize,
21)]
22pub enum Transport {
23 Tcp,
24 Udp,
25}
26
27#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
28pub enum HttpVersion {
29 Http1_0,
30 Http1_1,
31 Http2,
32 Http3,
33}
34
35#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
36pub enum TlsVersion {
37 Tls12,
38 Tls13,
39}
40
41#[derive(Clone, Debug, Default)]
42pub struct TlsInfo {
43 pub sni: Option<Arc<str>>,
48 pub alpn: Option<Arc<[u8]>>,
51 pub version: Option<TlsVersion>,
52 pub peer_cert: Option<Arc<PeerCertificate>>,
53 pub zero_rtt_used: bool,
61}
62
63#[derive(Clone, Debug, Default)]
79pub struct PeerCertificate {
80 pub leaf_der: Bytes,
84 pub subject_cn: Option<Arc<str>>,
90 pub san_dns: Arc<[Arc<str>]>,
91 pub fingerprint_sha256: Arc<str>,
92 pub spki_sha256: Arc<str>,
93 pub issuer_cn: Option<Arc<str>>,
94 pub serial: Arc<str>,
95}
96
97impl PeerCertificate {
98 #[must_use]
103 pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
104 use sha2::{Digest, Sha256};
105 use x509_parser::prelude::*;
106
107 let bytes = leaf_der.as_ref();
108 let (_, cert) = X509Certificate::from_der(bytes).ok()?;
109 let tbs = &cert.tbs_certificate;
110
111 let subject_cn =
112 tbs.subject().iter_common_name().next().and_then(|attr| attr.as_str().ok().map(Arc::from));
113 let issuer_cn =
114 tbs.issuer().iter_common_name().next().and_then(|attr| attr.as_str().ok().map(Arc::from));
115
116 let mut san_dns: Vec<Arc<str>> = Vec::new();
120 if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
121 for name in &san_ext.value.general_names {
122 if let GeneralName::DNSName(d) = name {
123 san_dns.push(Arc::from(*d));
124 }
125 }
126 }
127 let san_dns: Arc<[Arc<str>]> = san_dns.into();
128
129 let mut hasher = Sha256::new();
130 hasher.update(bytes);
131 let fingerprint_sha256: Arc<str> = Arc::from(hex_lower(&hasher.finalize()));
132
133 let spki_sha256: Arc<str> = {
134 let spki_der = tbs.subject_pki.raw;
135 let mut h = Sha256::new();
136 h.update(spki_der);
137 Arc::from(hex_lower(&h.finalize()))
138 };
139
140 let serial: Arc<str> = Arc::from(hex_lower(&tbs.serial.to_bytes_be()));
146
147 Some(Self {
148 leaf_der: Bytes::copy_from_slice(bytes),
149 subject_cn,
150 san_dns,
151 fingerprint_sha256,
152 spki_sha256,
153 issuer_cn,
154 serial,
155 })
156 }
157}
158
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte, no separators.
///
/// An empty slice yields an empty string.
fn hex_lower(bytes: &[u8]) -> String {
    // Nibble lookup table: avoids the formatting machinery and has no
    // `fmt::Result` to discard.
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // Both indices are in 0..16, so the lookups cannot go out of bounds.
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}

168#[non_exhaustive]
177pub struct ConnContext {
178 pub id: ConnId,
179 pub remote: SocketAddr,
180 pub local: SocketAddr,
181 pub transport: Transport,
182 pub entered_at: Instant,
183
184 pub tls: Mutex<Option<TlsInfo>>,
185 pub http_version: OnceLock<HttpVersion>,
186
187 pub user: Mutex<http::Extensions>,
188}
189
190impl ConnContext {
191 #[must_use]
199 pub fn new(
200 id: ConnId,
201 remote: SocketAddr,
202 local: SocketAddr,
203 transport: Transport,
204 entered_at: Instant,
205 ) -> Self {
206 Self {
207 id,
208 remote,
209 local,
210 transport,
211 entered_at,
212 tls: Mutex::new(None),
213 http_version: OnceLock::new(),
214 user: Mutex::new(http::Extensions::new()),
215 }
216 }
217
218 pub fn tls(&self) -> parking_lot::MutexGuard<'_, Option<TlsInfo>> {
222 self.tls.lock()
223 }
224
225 pub fn with_user<R>(&self, f: impl FnOnce(&mut http::Extensions) -> R) -> R {
231 let mut guard = self.user.lock();
232 f(&mut guard)
233 }
234}
235
#[cfg(test)]
mod tests {
    //! Unit tests for `ConnId` rendering and for the serde round-trips of the
    //! small enums declared in this file.

    use super::*;

    #[test]
    fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
        let rendered = format!("{}", ConnId(0));
        assert_eq!(rendered, "0000000000000000");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_display_is_lowercase_hex() {
        let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
        assert_eq!(rendered, "0badf00ddeadbeef");
        assert!(rendered.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn conn_id_display_zero_pads_small_values() {
        let rendered = format!("{}", ConnId(1));
        assert_eq!(rendered, "0000000000000001");
    }

    #[test]
    fn conn_id_display_renders_u64_max() {
        let rendered = format!("{}", ConnId(u64::MAX));
        assert_eq!(rendered, "ffffffffffffffff");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_serde_round_trip() {
        let id = ConnId(0x1234_5678_9abc_def0);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: ConnId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn tls_version_variants_are_exhaustive_at_two() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            // No wildcard arm: adding a variant makes this match fail to
            // compile, which is the point of the test.
            let matched = match v {
                TlsVersion::Tls12 => "1.2",
                TlsVersion::Tls13 => "1.3",
            };
            assert!(!matched.is_empty());
        }
    }

    #[test]
    fn tls_version_serde_round_trip_per_variant() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: TlsVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn transport_serde_round_trip_per_variant() {
        for t in [Transport::Tcp, Transport::Udp] {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Transport = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn http_version_serde_round_trip_per_variant() {
        for v in [HttpVersion::Http1_0, HttpVersion::Http1_1, HttpVersion::Http2, HttpVersion::Http3] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: HttpVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }
}