Skip to main content

vane_core/
conn_context.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::time::Instant;
5
6use bytes::Bytes;
7use parking_lot::Mutex;
8use rustls_pki_types::CertificateDer;
9
10#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConnId(pub u64);
12
13impl std::fmt::Display for ConnId {
14	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15		write!(f, "{:016x}", self.0)
16	}
17}
18
19#[derive(
20	Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, serde::Serialize, serde::Deserialize,
21)]
22pub enum Transport {
23	Tcp,
24	Udp,
25}
26
27#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
28pub enum HttpVersion {
29	Http1_0,
30	Http1_1,
31	Http2,
32	Http3,
33}
34
35#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, serde::Serialize, serde::Deserialize)]
36pub enum TlsVersion {
37	Tls12,
38	Tls13,
39}
40
41#[derive(Clone, Debug, Default)]
42pub struct TlsInfo {
43	pub sni: Option<String>,
44	pub alpn: Option<Vec<u8>>,
45	pub version: Option<TlsVersion>,
46	pub peer_cert: Option<Arc<PeerCertificate>>,
47}
48
49/// Verified client certificate captured at TLS handshake time, with
50/// every predicate-readable field pre-extracted so the per-Check
51/// dispatch is allocation-light. Built once by the engine's
52/// post-handshake population (`run_tls`); the seven
53/// `tls.peer_cert.*` predicates read pre-computed strings off this
54/// struct rather than re-parsing the DER on every test.
55///
56/// `leaf_der` retains the raw DER bytes so future predicates (or a
57/// post-MVP debug surface) can re-derive any field x509-parser
58/// exposes; the seven currently-spec'd fields are pre-extracted.
59///
60/// All `String`-typed fields are byte-for-byte canonical: hex digests
61/// are ASCII-lowercase; `serial` is hex (lowercase, no leading-zero
62/// stripping). See `spec/architecture/18-predicate-schema.md` §
63/// _Authoritative field paths_ for the canonical formats.
64#[derive(Clone, Debug, Default)]
65pub struct PeerCertificate {
66	/// Raw leaf cert DER. Retained for future predicates that need
67	/// fields not pre-extracted; current readers should use the
68	/// pre-extracted scalar fields below.
69	pub leaf_der: Bytes,
70	pub subject_cn: Option<String>,
71	pub san_dns: Vec<String>,
72	pub fingerprint_sha256: String,
73	pub spki_sha256: String,
74	pub issuer_cn: Option<String>,
75	pub serial: String,
76}
77
78impl PeerCertificate {
79	/// Pre-extract every `tls.peer_cert.*` predicate-readable field
80	/// from a raw leaf cert DER. Returns `None` when the bytes are
81	/// not a parseable X.509v3 certificate; the caller treats that as
82	/// "no verified peer cert" (sound-by-default per spec).
83	#[must_use]
84	pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
85		use sha2::{Digest, Sha256};
86		use x509_parser::prelude::*;
87
88		let bytes = leaf_der.as_ref();
89		let (_, cert) = X509Certificate::from_der(bytes).ok()?;
90		let tbs = &cert.tbs_certificate;
91
92		let subject_cn = tbs
93			.subject()
94			.iter_common_name()
95			.next()
96			.and_then(|attr| attr.as_str().ok().map(ToString::to_string));
97		let issuer_cn = tbs
98			.issuer()
99			.iter_common_name()
100			.next()
101			.and_then(|attr| attr.as_str().ok().map(ToString::to_string));
102
103		// SAN dNSName entries — RFC 5280 §4.2.1.6. Other GeneralName
104		// variants (URI, RFC822, etc.) are not exposed via this path
105		// per the predicate-schema table.
106		let mut san_dns: Vec<String> = Vec::new();
107		if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
108			for name in &san_ext.value.general_names {
109				if let GeneralName::DNSName(d) = name {
110					san_dns.push((*d).to_string());
111				}
112			}
113		}
114
115		let mut hasher = Sha256::new();
116		hasher.update(bytes);
117		let fingerprint_sha256 = hex_lower(&hasher.finalize());
118
119		let spki_sha256 = {
120			let spki_der = tbs.subject_pki.raw;
121			let mut h = Sha256::new();
122			h.update(spki_der);
123			hex_lower(&h.finalize())
124		};
125
126		// Serial: x509-parser gives BigUint; canonicalise as
127		// lowercase hex, big-endian, no leading-zero stripping (per
128		// spec). `to_bytes_be` returns the minimal-length
129		// representation; pad nothing — operators copy the value out
130		// verbatim from `openssl x509 -serial` when matching.
131		let serial = hex_lower(&tbs.serial.to_bytes_be());
132
133		Some(Self {
134			leaf_der: Bytes::copy_from_slice(bytes),
135			subject_cn,
136			san_dns,
137			fingerprint_sha256,
138			spki_sha256,
139			issuer_cn,
140			serial,
141		})
142	}
143}
144
/// Encodes `bytes` as lowercase hexadecimal: two digits per byte, high
/// nibble first, no separators. An empty slice encodes to `""`.
fn hex_lower(bytes: &[u8]) -> String {
	const DIGITS: &[u8; 16] = b"0123456789abcdef";
	let mut encoded = String::with_capacity(bytes.len() * 2);
	for &byte in bytes {
		for nibble in [byte >> 4, byte & 0x0f] {
			encoded.push(char::from(DIGITS[usize::from(nibble)]));
		}
	}
	encoded
}
153
154pub struct ConnContext {
155	pub id: ConnId,
156	pub remote: SocketAddr,
157	pub local: SocketAddr,
158	pub transport: Transport,
159	pub entered_at: Instant,
160
161	pub tls: Mutex<Option<TlsInfo>>,
162	pub http_version: OnceLock<HttpVersion>,
163
164	pub user: Mutex<http::Extensions>,
165}
166
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
		let rendered = format!("{}", ConnId(0));
		assert_eq!(rendered, "0000000000000000");
		assert_eq!(rendered.len(), 16);
	}

	#[test]
	fn conn_id_display_is_lowercase_hex() {
		let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
		assert_eq!(rendered, "0badf00ddeadbeef");
		assert!(rendered.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
	}

	#[test]
	fn conn_id_display_zero_pads_small_values() {
		// non-zero top nibble would mean no left padding; a small value exercises
		// the {:016x} pad path explicitly.
		let rendered = format!("{}", ConnId(1));
		assert_eq!(rendered, "0000000000000001");
	}

	#[test]
	fn conn_id_display_renders_u64_max() {
		let rendered = format!("{}", ConnId(u64::MAX));
		assert_eq!(rendered, "ffffffffffffffff");
		assert_eq!(rendered.len(), 16);
	}

	#[test]
	fn conn_id_serde_round_trip() {
		let id = ConnId(0x1234_5678_9abc_def0);
		let encoded = serde_json::to_string(&id).expect("serialize");
		let decoded: ConnId = serde_json::from_str(&encoded).expect("deserialize");
		assert_eq!(decoded, id);
	}

	#[test]
	fn conn_id_serializes_as_bare_integer() {
		// Pins the wire format: the newtype is transparent on the wire,
		// so changing it would break consumers of serialized ids.
		let encoded = serde_json::to_string(&ConnId(42)).expect("serialize");
		assert_eq!(encoded, "42");
	}

	#[test]
	fn tls_version_variants_are_exhaustive_at_two() {
		// Adding a TlsVersion variant without updating this arm would be a
		// compile error — the spec (08-tls.md) constrains accepted versions
		// to 1.2 and 1.3 only.
		for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
			let matched = match v {
				TlsVersion::Tls12 => "1.2",
				TlsVersion::Tls13 => "1.3",
			};
			assert!(!matched.is_empty());
		}
	}

	#[test]
	fn tls_version_serde_round_trip_per_variant() {
		for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
			let encoded = serde_json::to_string(&v).expect("serialize");
			let decoded: TlsVersion = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, v);
		}
	}

	#[test]
	fn transport_serde_round_trip_per_variant() {
		for t in [Transport::Tcp, Transport::Udp] {
			let encoded = serde_json::to_string(&t).expect("serialize");
			let decoded: Transport = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, t);
		}
	}

	#[test]
	fn http_version_serde_round_trip_per_variant() {
		for v in [HttpVersion::Http1_0, HttpVersion::Http1_1, HttpVersion::Http2, HttpVersion::Http3] {
			let encoded = serde_json::to_string(&v).expect("serialize");
			let decoded: HttpVersion = serde_json::from_str(&encoded).expect("deserialize");
			assert_eq!(decoded, v);
		}
	}

	#[test]
	fn hex_lower_renders_two_lowercase_digits_per_byte() {
		// hex_lower is the canonicaliser behind every fingerprint and
		// serial string, so pin its exact output.
		assert_eq!(hex_lower(&[]), "");
		assert_eq!(hex_lower(&[0x00]), "00");
		assert_eq!(hex_lower(&[0x00, 0x0a, 0xab, 0xff]), "000aabff");
	}

	#[test]
	fn peer_certificate_from_der_rejects_non_certificate_bytes() {
		let empty = rustls_pki_types::CertificateDer::from(Vec::new());
		assert!(PeerCertificate::from_der(&empty).is_none());

		let garbage = rustls_pki_types::CertificateDer::from(b"not a certificate".to_vec());
		assert!(PeerCertificate::from_der(&garbage).is_none());
	}
}