Skip to main content

vane_core/
conn_context.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::time::Instant;
5
6use bytes::Bytes;
7use parking_lot::Mutex;
8use rustls_pki_types::CertificateDer;
9
10#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConnId(pub u64);
12
13impl std::fmt::Display for ConnId {
14	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15		write!(f, "{:016x}", self.0)
16	}
17}
18
19#[derive(
20	Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, serde::Serialize, serde::Deserialize,
21)]
22pub enum Transport {
23	Tcp,
24	Udp,
25}
26
27#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
28pub enum HttpVersion {
29	Http1_0,
30	Http1_1,
31	Http2,
32	Http3,
33}
34
35#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
36pub enum TlsVersion {
37	Tls12,
38	Tls13,
39}
40
41#[derive(Clone, Debug, Default)]
42pub struct TlsInfo {
43	pub sni: Option<String>,
44	pub alpn: Option<Vec<u8>>,
45	pub version: Option<TlsVersion>,
46	pub peer_cert: Option<Arc<PeerCertificate>>,
47	/// Whether the client's request arrived (in part or wholly) as
48	/// TLS 1.3 0-RTT (early data). Set at handshake completion in the
49	/// engine's `run_tls` from rustls's `is_early_data_accepted()`.
50	/// The L7 executor consults this together with the matched rule's
51	/// `allow_zero_rtt` to decide whether to short-circuit the request
52	/// with a synthetic 425 Too Early. See
53	/// `spec/crates/engine-tls.md` § _TLS 1.3 0-RTT (early data)_.
54	pub zero_rtt_used: bool,
55}
56
57/// Verified client certificate captured at TLS handshake time, with
58/// every predicate-readable field pre-extracted so the per-Check
59/// dispatch is allocation-light. Built once by the engine's
60/// post-handshake population (`run_tls`); the seven
61/// `tls.peer_cert.*` predicates read pre-computed strings off this
62/// struct rather than re-parsing the DER on every test.
63///
64/// `leaf_der` retains the raw DER bytes so future predicates (or a
65/// post-MVP debug surface) can re-derive any field x509-parser
66/// exposes; the seven currently-spec'd fields are pre-extracted.
67///
68/// All `String`-typed fields are byte-for-byte canonical: hex digests
69/// are ASCII-lowercase; `serial` is hex (lowercase, no leading-zero
70/// stripping). See `spec/crates/core.md` §
71/// _Predicate_ for the canonical formats.
72#[derive(Clone, Debug, Default)]
73pub struct PeerCertificate {
74	/// Raw leaf cert DER. Retained for future predicates that need
75	/// fields not pre-extracted; current readers should use the
76	/// pre-extracted scalar fields below.
77	pub leaf_der: Bytes,
78	pub subject_cn: Option<String>,
79	pub san_dns: Vec<String>,
80	pub fingerprint_sha256: String,
81	pub spki_sha256: String,
82	pub issuer_cn: Option<String>,
83	pub serial: String,
84}
85
86impl PeerCertificate {
87	/// Pre-extract every `tls.peer_cert.*` predicate-readable field
88	/// from a raw leaf cert DER. Returns `None` when the bytes are
89	/// not a parseable X.509v3 certificate; the caller treats that as
90	/// "no verified peer cert" (sound-by-default per spec).
91	#[must_use]
92	pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
93		use sha2::{Digest, Sha256};
94		use x509_parser::prelude::*;
95
96		let bytes = leaf_der.as_ref();
97		let (_, cert) = X509Certificate::from_der(bytes).ok()?;
98		let tbs = &cert.tbs_certificate;
99
100		let subject_cn = tbs
101			.subject()
102			.iter_common_name()
103			.next()
104			.and_then(|attr| attr.as_str().ok().map(ToString::to_string));
105		let issuer_cn = tbs
106			.issuer()
107			.iter_common_name()
108			.next()
109			.and_then(|attr| attr.as_str().ok().map(ToString::to_string));
110
111		// SAN dNSName entries — RFC 5280 §4.2.1.6. Other GeneralName
112		// variants (URI, RFC822, etc.) are not exposed via this path
113		// per the predicate-schema table.
114		let mut san_dns: Vec<String> = Vec::new();
115		if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
116			for name in &san_ext.value.general_names {
117				if let GeneralName::DNSName(d) = name {
118					san_dns.push((*d).to_string());
119				}
120			}
121		}
122
123		let mut hasher = Sha256::new();
124		hasher.update(bytes);
125		let fingerprint_sha256 = hex_lower(&hasher.finalize());
126
127		let spki_sha256 = {
128			let spki_der = tbs.subject_pki.raw;
129			let mut h = Sha256::new();
130			h.update(spki_der);
131			hex_lower(&h.finalize())
132		};
133
134		// Serial: x509-parser gives BigUint; canonicalise as
135		// lowercase hex, big-endian, no leading-zero stripping (per
136		// spec). `to_bytes_be` returns the minimal-length
137		// representation; pad nothing — operators copy the value out
138		// verbatim from `openssl x509 -serial` when matching.
139		let serial = hex_lower(&tbs.serial.to_bytes_be());
140
141		Some(Self {
142			leaf_der: Bytes::copy_from_slice(bytes),
143			subject_cn,
144			san_dns,
145			fingerprint_sha256,
146			spki_sha256,
147			issuer_cn,
148			serial,
149		})
150	}
151}
152
/// Render `bytes` as ASCII-lowercase hexadecimal, two digits per byte,
/// high nibble first (`[0x0a, 0xff]` → `"0aff"`). An empty slice yields
/// an empty string.
fn hex_lower(bytes: &[u8]) -> String {
	const DIGITS: &[u8; 16] = b"0123456789abcdef";

	let mut encoded = String::with_capacity(bytes.len() * 2);
	for &byte in bytes {
		let high = DIGITS[usize::from(byte >> 4)];
		let low = DIGITS[usize::from(byte & 0x0f)];
		encoded.push(char::from(high));
		encoded.push(char::from(low));
	}
	encoded
}
161
162pub struct ConnContext {
163	pub id: ConnId,
164	pub remote: SocketAddr,
165	pub local: SocketAddr,
166	pub transport: Transport,
167	pub entered_at: Instant,
168
169	pub tls: Mutex<Option<TlsInfo>>,
170	pub http_version: OnceLock<HttpVersion>,
171
172	pub user: Mutex<http::Extensions>,
173}
174
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
		let rendered = format!("{}", ConnId(0));
		assert_eq!(rendered, "0000000000000000");
		assert_eq!(rendered.len(), 16);
	}

	#[test]
	fn conn_id_display_is_lowercase_hex() {
		let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
		assert_eq!(rendered, "0badf00ddeadbeef");
		assert!(rendered.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
	}

	#[test]
	fn conn_id_display_zero_pads_small_values() {
		// A non-zero top nibble would mean no left padding; a small value
		// exercises the {:016x} pad path explicitly.
		let rendered = format!("{}", ConnId(1));
		assert_eq!(rendered, "0000000000000001");
	}

	#[test]
	fn conn_id_display_renders_u64_max() {
		let rendered = format!("{}", ConnId(u64::MAX));
		assert_eq!(rendered, "ffffffffffffffff");
		assert_eq!(rendered.len(), 16);
	}

	#[test]
	fn conn_id_serde_round_trip() {
		let id = ConnId(0x1234_5678_9abc_def0);
		let encoded = serde_json::to_string(&id).expect("serialize");
		let decoded: ConnId = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded, id);
	}

	#[test]
	fn tls_version_variants_are_exhaustive_at_two() {
		// Adding a TlsVersion variant without updating this arm would be a
		// compile error — the spec (spec/crates/engine-tls.md) constrains accepted versions
		// to 1.2 and 1.3 only.
		for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
			let matched = match v {
				TlsVersion::Tls12 => "1.2",
				TlsVersion::Tls13 => "1.3",
			};
			assert!(!matched.is_empty());
		}
	}

	#[test]
	fn tls_version_serde_round_trip_per_variant() {
		for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
			let encoded = serde_json::to_string(&v).expect("serialize");
			let decoded: TlsVersion = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, v);
		}
	}

	#[test]
	fn transport_serde_round_trip_per_variant() {
		for t in [Transport::Tcp, Transport::Udp] {
			let encoded = serde_json::to_string(&t).expect("serialize");
			let decoded: Transport = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, t);
		}
	}

	#[test]
	fn http_version_serde_round_trip_per_variant() {
		for v in [HttpVersion::Http1_0, HttpVersion::Http1_1, HttpVersion::Http2, HttpVersion::Http3] {
			let encoded = serde_json::to_string(&v).expect("serialize");
			let decoded: HttpVersion = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, v);
		}
	}

	#[test]
	fn hex_lower_renders_two_lowercase_digits_per_byte() {
		// Canonical-format guarantee for every digest/serial field:
		// lowercase, and a leading zero nibble is never stripped.
		assert_eq!(hex_lower(&[]), "");
		assert_eq!(hex_lower(&[0x00]), "00");
		assert_eq!(hex_lower(&[0x0a, 0xab, 0xff, 0x10]), "0aabff10");
	}

	#[test]
	fn peer_certificate_from_der_rejects_non_certificate_bytes() {
		// Unparseable input must map to `None` ("no verified peer cert"),
		// never to a partially-filled certificate or a panic.
		let empty = CertificateDer::from(Vec::<u8>::new());
		assert!(PeerCertificate::from_der(&empty).is_none());

		let wrong_tag = CertificateDer::from(vec![0x00_u8, 0x01, 0x02, 0x03]);
		assert!(PeerCertificate::from_der(&wrong_tag).is_none());

		// SEQUENCE header claiming more content than is present.
		let truncated = CertificateDer::from(vec![0x30_u8, 0x03, 0x02, 0x01]);
		assert!(PeerCertificate::from_der(&truncated).is_none());
	}

	#[test]
	fn tls_info_default_records_nothing() {
		let info = TlsInfo::default();
		assert!(info.sni.is_none());
		assert!(info.alpn.is_none());
		assert!(info.version.is_none());
		assert!(info.peer_cert.is_none());
		assert!(!info.zero_rtt_used);
	}
}