1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::time::Instant;
5
6use bytes::Bytes;
7use parking_lot::Mutex;
8use rustls_pki_types::CertificateDer;
9
10#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConnId(pub u64);
12
13impl std::fmt::Display for ConnId {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 write!(f, "{:016x}", self.0)
16 }
17}
18
19#[derive(
20 Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, serde::Serialize, serde::Deserialize,
21)]
22pub enum Transport {
23 Tcp,
24 Udp,
25}
26
27#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
28pub enum HttpVersion {
29 Http1_0,
30 Http1_1,
31 Http2,
32 Http3,
33}
34
35#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
36pub enum TlsVersion {
37 Tls12,
38 Tls13,
39}
40
41#[derive(Clone, Debug, Default)]
42pub struct TlsInfo {
43 pub sni: Option<String>,
44 pub alpn: Option<Vec<u8>>,
45 pub version: Option<TlsVersion>,
46 pub peer_cert: Option<Arc<PeerCertificate>>,
47}
48
49#[derive(Clone, Debug, Default)]
65pub struct PeerCertificate {
66 pub leaf_der: Bytes,
70 pub subject_cn: Option<String>,
71 pub san_dns: Vec<String>,
72 pub fingerprint_sha256: String,
73 pub spki_sha256: String,
74 pub issuer_cn: Option<String>,
75 pub serial: String,
76}
77
78impl PeerCertificate {
79 #[must_use]
84 pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
85 use sha2::{Digest, Sha256};
86 use x509_parser::prelude::*;
87
88 let bytes = leaf_der.as_ref();
89 let (_, cert) = X509Certificate::from_der(bytes).ok()?;
90 let tbs = &cert.tbs_certificate;
91
92 let subject_cn = tbs
93 .subject()
94 .iter_common_name()
95 .next()
96 .and_then(|attr| attr.as_str().ok().map(ToString::to_string));
97 let issuer_cn = tbs
98 .issuer()
99 .iter_common_name()
100 .next()
101 .and_then(|attr| attr.as_str().ok().map(ToString::to_string));
102
103 let mut san_dns: Vec<String> = Vec::new();
107 if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
108 for name in &san_ext.value.general_names {
109 if let GeneralName::DNSName(d) = name {
110 san_dns.push((*d).to_string());
111 }
112 }
113 }
114
115 let mut hasher = Sha256::new();
116 hasher.update(bytes);
117 let fingerprint_sha256 = hex_lower(&hasher.finalize());
118
119 let spki_sha256 = {
120 let spki_der = tbs.subject_pki.raw;
121 let mut h = Sha256::new();
122 h.update(spki_der);
123 hex_lower(&h.finalize())
124 };
125
126 let serial = hex_lower(&tbs.serial.to_bytes_be());
132
133 Some(Self {
134 leaf_der: Bytes::copy_from_slice(bytes),
135 subject_cn,
136 san_dns,
137 fingerprint_sha256,
138 spki_sha256,
139 issuer_cn,
140 serial,
141 })
142 }
143}
144
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte, with no separators.
///
/// An empty slice yields an empty string.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // Both indices are nibbles (0..=15), so the table lookups cannot go out of bounds.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
153
154pub struct ConnContext {
155 pub id: ConnId,
156 pub remote: SocketAddr,
157 pub local: SocketAddr,
158 pub transport: Transport,
159 pub entered_at: Instant,
160
161 pub tls: Mutex<Option<TlsInfo>>,
162 pub http_version: OnceLock<HttpVersion>,
163
164 pub user: Mutex<http::Extensions>,
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
        let rendered = format!("{}", ConnId(0));
        assert_eq!(rendered, "0000000000000000");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_display_is_lowercase_hex() {
        let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
        assert_eq!(rendered, "0badf00ddeadbeef");
        assert!(rendered.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
    }

    #[test]
    fn conn_id_display_zero_pads_small_values() {
        // A one-digit value must still occupy the full 16-character width.
        let rendered = format!("{}", ConnId(1));
        assert_eq!(rendered, "0000000000000001");
    }

    #[test]
    fn conn_id_display_renders_u64_max() {
        let rendered = format!("{}", ConnId(u64::MAX));
        assert_eq!(rendered, "ffffffffffffffff");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_serde_round_trip() {
        let id = ConnId(0x1234_5678_9abc_def0);
        let encoded = serde_json::to_string(&id).expect("serialize");
        let decoded: ConnId = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
    }

    #[test]
    fn tls_version_variants_are_exhaustive_at_two() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            // Deliberately no wildcard arm: adding a variant makes this match fail to
            // compile, forcing the test (and its callers' assumptions) to be revisited.
            let matched = match v {
                TlsVersion::Tls12 => "1.2",
                TlsVersion::Tls13 => "1.3",
            };
            assert!(!matched.is_empty());
        }
    }

    #[test]
    fn tls_version_serde_round_trip_per_variant() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: TlsVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn transport_serde_round_trip_per_variant() {
        for t in [Transport::Tcp, Transport::Udp] {
            let encoded = serde_json::to_string(&t).expect("serialize");
            let decoded: Transport = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, t);
        }
    }

    #[test]
    fn http_version_serde_round_trip_per_variant() {
        for v in [
            HttpVersion::Http1_0,
            HttpVersion::Http1_1,
            HttpVersion::Http2,
            HttpVersion::Http3,
        ] {
            let encoded = serde_json::to_string(&v).expect("serialize");
            let decoded: HttpVersion = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, v);
        }
    }

    #[test]
    fn hex_lower_of_empty_slice_is_empty() {
        assert_eq!(hex_lower(&[]), "");
    }

    #[test]
    fn hex_lower_emits_two_lowercase_digits_per_byte() {
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xa0, 0xff]), "000fa0ff");
    }
}