1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::time::Instant;
5
6use bytes::Bytes;
7use parking_lot::Mutex;
8use rustls_pki_types::CertificateDer;
9
/// Numeric identifier of a connection.
///
/// Its `Display` form is exactly 16 zero-padded lowercase hexadecimal digits
/// (e.g. `ConnId(1)` renders as `0000000000000001`).
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub struct ConnId(pub u64);
12
13impl std::fmt::Display for ConnId {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 write!(f, "{:016x}", self.0)
16 }
17}
18
/// Transport-layer protocol of a connection.
///
/// Derives `Ord`, so declaration order defines comparison (`Tcp < Udp`);
/// reordering the variants changes ordering results.
#[derive(
    Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, serde::Serialize, serde::Deserialize,
)]
pub enum Transport {
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
}
26
/// HTTP protocol version used on a connection.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub enum HttpVersion {
    /// HTTP/1.0.
    Http1_0,
    /// HTTP/1.1.
    Http1_1,
    /// HTTP/2.
    Http2,
    /// HTTP/3.
    Http3,
}
34
/// TLS protocol version; only TLS 1.2 and TLS 1.3 are representable.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}
40
/// TLS details recorded for a connection.
#[derive(Clone, Debug, Default)]
pub struct TlsInfo {
    /// Server Name Indication hostname, if one is known.
    pub sni: Option<String>,
    /// ALPN protocol identifier as raw bytes, if any.
    pub alpn: Option<Vec<u8>>,
    /// TLS protocol version, if known.
    pub version: Option<TlsVersion>,
    /// Parsed peer certificate, if available. Held in an `Arc` so cloning
    /// `TlsInfo` does not copy the certificate data.
    pub peer_cert: Option<Arc<PeerCertificate>>,
    /// Whether 0-RTT early data was used (presumably only ever `true` for
    /// TLS 1.3 — confirm with whoever populates this struct).
    pub zero_rtt_used: bool,
}
56
/// Identity fields extracted from a peer's X.509 certificate.
///
/// Normally built by [`PeerCertificate::from_der`]. A `Default` value has
/// empty bytes/strings/vectors and `None` common names.
#[derive(Clone, Debug, Default)]
pub struct PeerCertificate {
    /// DER encoding of the certificate, copied from the `from_der` input.
    pub leaf_der: Bytes,
    /// First subject common name (CN), if present and decodable as a string.
    pub subject_cn: Option<String>,
    /// DNS names from the subjectAltName extension; empty when the extension
    /// is absent or could not be parsed.
    pub san_dns: Vec<String>,
    /// Lowercase hex SHA-256 of the certificate DER bytes.
    pub fingerprint_sha256: String,
    /// Lowercase hex SHA-256 of the DER-encoded SubjectPublicKeyInfo.
    pub spki_sha256: String,
    /// First issuer common name (CN), if present and decodable as a string.
    pub issuer_cn: Option<String>,
    /// Serial number as lowercase hex of its big-endian magnitude bytes.
    pub serial: String,
}
85
86impl PeerCertificate {
87 #[must_use]
92 pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
93 use sha2::{Digest, Sha256};
94 use x509_parser::prelude::*;
95
96 let bytes = leaf_der.as_ref();
97 let (_, cert) = X509Certificate::from_der(bytes).ok()?;
98 let tbs = &cert.tbs_certificate;
99
100 let subject_cn = tbs
101 .subject()
102 .iter_common_name()
103 .next()
104 .and_then(|attr| attr.as_str().ok().map(ToString::to_string));
105 let issuer_cn = tbs
106 .issuer()
107 .iter_common_name()
108 .next()
109 .and_then(|attr| attr.as_str().ok().map(ToString::to_string));
110
111 let mut san_dns: Vec<String> = Vec::new();
115 if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
116 for name in &san_ext.value.general_names {
117 if let GeneralName::DNSName(d) = name {
118 san_dns.push((*d).to_string());
119 }
120 }
121 }
122
123 let mut hasher = Sha256::new();
124 hasher.update(bytes);
125 let fingerprint_sha256 = hex_lower(&hasher.finalize());
126
127 let spki_sha256 = {
128 let spki_der = tbs.subject_pki.raw;
129 let mut h = Sha256::new();
130 h.update(spki_der);
131 hex_lower(&h.finalize())
132 };
133
134 let serial = hex_lower(&tbs.serial.to_bytes_be());
140
141 Some(Self {
142 leaf_der: Bytes::copy_from_slice(bytes),
143 subject_cn,
144 san_dns,
145 fingerprint_sha256,
146 spki_sha256,
147 issuer_cn,
148 serial,
149 })
150 }
151}
152
/// Renders `bytes` as lowercase hexadecimal, two digits per byte with no
/// separators, so the result is always `2 * bytes.len()` characters long.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, matching `{:02x}` formatting.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
161
162pub struct ConnContext {
163 pub id: ConnId,
164 pub remote: SocketAddr,
165 pub local: SocketAddr,
166 pub transport: Transport,
167 pub entered_at: Instant,
168
169 pub tls: Mutex<Option<TlsInfo>>,
170 pub http_version: OnceLock<HttpVersion>,
171
172 pub user: Mutex<http::Extensions>,
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
        let rendered = format!("{}", ConnId(0));
        assert_eq!(rendered, "0000000000000000");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_display_is_lowercase_hex() {
        let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
        assert_eq!(rendered, "0badf00ddeadbeef");
        assert!(rendered.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn conn_id_display_zero_pads_small_values() {
        let rendered = format!("{}", ConnId(1));
        assert_eq!(rendered, "0000000000000001");
    }

    #[test]
    fn conn_id_display_renders_u64_max() {
        let rendered = format!("{}", ConnId(u64::MAX));
        assert_eq!(rendered, "ffffffffffffffff");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_serde_round_trip() {
        let id = ConnId(0x1234_5678_9abc_def0);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: ConnId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn tls_version_variants_are_exhaustive_at_two() {
        // The `match` has no wildcard arm, so adding a variant fails to compile
        // here; the assertion pins the mapping itself instead of a tautology.
        let labels: Vec<&str> = [TlsVersion::Tls12, TlsVersion::Tls13]
            .into_iter()
            .map(|v| match v {
                TlsVersion::Tls12 => "1.2",
                TlsVersion::Tls13 => "1.3",
            })
            .collect();
        assert_eq!(labels, ["1.2", "1.3"]);
    }

    #[test]
    fn tls_version_serde_round_trip_per_variant() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: TlsVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn transport_serde_round_trip_per_variant() {
        for t in [Transport::Tcp, Transport::Udp] {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Transport = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn http_version_serde_round_trip_per_variant() {
        for v in [HttpVersion::Http1_0, HttpVersion::Http1_1, HttpVersion::Http2, HttpVersion::Http3] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: HttpVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn hex_lower_renders_two_lowercase_digits_per_byte() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xa0, 0xff]), "000fa0ff");
    }

    #[test]
    fn peer_certificate_from_der_rejects_non_certificate_input() {
        assert!(PeerCertificate::from_der(&CertificateDer::from(Vec::<u8>::new())).is_none());
        // An empty DER SEQUENCE is well-formed ASN.1 but not a certificate.
        assert!(PeerCertificate::from_der(&CertificateDer::from(vec![0x30u8, 0x00])).is_none());
    }
}