1use crate::error::{ProxyError, Result};
2use rcgen::{
3 BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyPair, KeyUsagePurpose,
4 SanType,
5};
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::{debug, info};
12
13pub struct CertifiedKey {
15 pub cert_der: CertificateDer<'static>,
16 pub key_der: PrivatePkcs8KeyDer<'static>,
17}
18
19pub struct CertificateAuthority {
21 ca_cert: rcgen::Certificate,
22 ca_key: KeyPair,
23 cache: Mutex<HashMap<String, Arc<CertifiedKey>>>,
24}
25
26impl CertificateAuthority {
27 pub async fn new() -> Result<Self> {
29 Self::with_dir(Self::ca_dir()?).await
30 }
31
32 pub async fn with_dir(dir: PathBuf) -> Result<Self> {
34 tokio::fs::create_dir_all(&dir).await?;
35
36 let cert_path = dir.join("ca.pem");
37 let key_path = dir.join("ca-key.pem");
38
39 let cert_exists = cert_path.exists();
40 let key_exists = key_path.exists();
41
42 if cert_exists != key_exists {
44 return Err(ProxyError::Other(format!(
45 "Partial CA state in {}: {} exists but {} is missing. \
46 Restore the missing file or remove both to reinitialize.",
47 dir.display(),
48 if cert_exists { "ca.pem" } else { "ca-key.pem" },
49 if cert_exists { "ca-key.pem" } else { "ca.pem" },
50 )));
51 }
52
53 let (ca_cert, ca_key) = if cert_exists {
54 info!("Loading existing CA certificate from {}", dir.display());
55 Self::load_ca(&cert_path, &key_path).await?
56 } else {
57 info!("Generating new CA certificate in {}", dir.display());
58 let (cert, key) = Self::generate_ca()?;
59 Self::save_ca(&cert, &key, &cert_path, &key_path).await?;
60 (cert, key)
61 };
62
63 Ok(Self {
64 ca_cert,
65 ca_key,
66 cache: Mutex::new(HashMap::new()),
67 })
68 }
69
70 pub fn ca_cert_path() -> Result<PathBuf> {
72 Ok(Self::ca_dir()?.join("ca.pem"))
73 }
74
75 pub async fn get_or_create_cert(&self, domain: &str) -> Result<Arc<CertifiedKey>> {
77 {
78 let cache = self.cache.lock().await;
79 if let Some(ck) = cache.get(domain) {
80 debug!("Using cached certificate for {domain}");
81 return Ok(ck.clone());
82 }
83 }
84
85 debug!("Generating certificate for {domain}");
86 let ck = self.generate_domain_cert(domain)?;
87 let ck = Arc::new(ck);
88
89 {
90 let mut cache = self.cache.lock().await;
91 cache.insert(domain.to_string(), ck.clone());
92 }
93
94 Ok(ck)
95 }
96
97 fn ca_dir() -> Result<PathBuf> {
98 let home = std::env::var("HOME")
99 .map_err(|_| ProxyError::Other("HOME environment variable not set".into()))?;
100 Ok(PathBuf::from(home).join(".rustgate"))
101 }
102
103 fn generate_ca() -> Result<(rcgen::Certificate, KeyPair)> {
104 let mut params = CertificateParams::default();
105 let mut dn = DistinguishedName::new();
106 dn.push(DnType::CommonName, "RustGate CA");
107 dn.push(DnType::OrganizationName, "RustGate");
108 params.distinguished_name = dn;
109 params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
110 params.key_usages = vec![
111 KeyUsagePurpose::KeyCertSign,
112 KeyUsagePurpose::CrlSign,
113 ];
114
115 let key = KeyPair::generate()?;
116 let cert = params.self_signed(&key)?;
117 Ok((cert, key))
118 }
119
120 async fn save_ca(
121 cert: &rcgen::Certificate,
122 key: &KeyPair,
123 cert_path: &PathBuf,
124 key_path: &PathBuf,
125 ) -> Result<()> {
126 tokio::fs::write(cert_path, cert.pem()).await?;
127 tokio::fs::write(key_path, key.serialize_pem()).await?;
128
129 #[cfg(unix)]
131 {
132 use std::os::unix::fs::PermissionsExt;
133 let perms = std::fs::Permissions::from_mode(0o600);
134 tokio::fs::set_permissions(key_path, perms).await?;
135 }
136
137 Ok(())
138 }
139
140 async fn load_ca(
141 cert_path: &PathBuf,
142 key_path: &PathBuf,
143 ) -> Result<(rcgen::Certificate, KeyPair)> {
144 let key_pem = tokio::fs::read_to_string(key_path).await?;
145 let key = KeyPair::from_pem(&key_pem)?;
146
147 let cert_pem = tokio::fs::read_to_string(cert_path).await?;
148 let params = CertificateParams::from_ca_cert_pem(&cert_pem)?;
149
150 let cert = params.self_signed(&key)?;
154
155 let original_der = Self::pem_to_der(&cert_pem)?;
156 let regenerated_der = cert.der().to_vec();
157 let original_spki = Self::extract_spki(&original_der)?;
158 let regenerated_spki = Self::extract_spki(®enerated_der)?;
159 if original_spki != regenerated_spki {
160 return Err(ProxyError::Other(
161 "CA certificate and private key do not match: \
162 public key in ca.pem differs from ca-key.pem"
163 .into(),
164 ));
165 }
166
167 Ok((cert, key))
168 }
169
170 fn pem_to_der(pem_str: &str) -> Result<Vec<u8>> {
172 let mut reader = std::io::BufReader::new(pem_str.as_bytes());
173 let certs = rustls_pemfile::certs(&mut reader)
174 .collect::<std::result::Result<Vec<_>, _>>()?;
175 certs
176 .into_iter()
177 .next()
178 .map(|c| c.to_vec())
179 .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
180 }
181
182 fn extract_spki(der: &[u8]) -> Result<Vec<u8>> {
185 let tbs = Self::asn1_sequence_contents(der)?;
187 let tbs_inner = Self::asn1_sequence_contents(tbs)?;
190
191 let mut pos = 0;
192 for i in 0..6 {
194 if pos >= tbs_inner.len() {
195 return Err(ProxyError::Other(
196 format!("Unexpected end of TBSCertificate at field {i}"),
197 ));
198 }
199 let (_, field_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
200 pos += field_len;
201 }
202
203 if pos >= tbs_inner.len() {
205 return Err(ProxyError::Other(
206 "SubjectPublicKeyInfo not found in certificate".into(),
207 ));
208 }
209 let (_, spki_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
210 Ok(tbs_inner[pos..pos + spki_len].to_vec())
211 }
212
213 fn asn1_sequence_contents(data: &[u8]) -> Result<&[u8]> {
215 if data.is_empty() || (data[0] & 0x1f) != 0x10 {
216 return Err(ProxyError::Other("Expected ASN.1 SEQUENCE".into()));
217 }
218 let (header_len, total_len) = Self::asn1_read_tag_and_length(data)?;
219 let content_len = total_len - header_len;
220 Ok(&data[header_len..header_len + content_len])
221 }
222
223 fn asn1_read_tag_and_length(data: &[u8]) -> Result<(usize, usize)> {
225 if data.len() < 2 {
226 return Err(ProxyError::Other("ASN.1 data too short".into()));
227 }
228 let mut pos = 1; let length_byte = data[pos];
230 pos += 1;
231
232 let content_len = if length_byte & 0x80 == 0 {
233 length_byte as usize
234 } else {
235 let num_bytes = (length_byte & 0x7f) as usize;
236 if pos + num_bytes > data.len() {
237 return Err(ProxyError::Other("ASN.1 length overflow".into()));
238 }
239 let mut len = 0usize;
240 for &b in &data[pos..pos + num_bytes] {
241 len = (len << 8) | b as usize;
242 }
243 pos += num_bytes;
244 len
245 };
246
247 let total_len = pos + content_len;
248 if total_len > data.len() {
249 return Err(ProxyError::Other(
250 "ASN.1 element extends beyond input data".into(),
251 ));
252 }
253
254 Ok((pos, total_len))
255 }
256
257 pub fn generate_client_cert(&self, cn: &str) -> Result<(String, String)> {
260 let mut params = CertificateParams::default();
261 let mut dn = DistinguishedName::new();
262 dn.push(DnType::CommonName, cn);
263 dn.push(DnType::OrganizationName, "RustGate");
264 params.distinguished_name = dn;
265 params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ClientAuth];
266
267 let key = KeyPair::generate()?;
268 let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
269
270 Ok((cert.pem(), key.serialize_pem()))
271 }
272
273 pub fn generate_server_cert(&self, host: &str) -> Result<CertifiedKey> {
275 let mut params = CertificateParams::new(vec![host.to_string()])?;
276 let mut dn = DistinguishedName::new();
277 dn.push(DnType::CommonName, host);
278 params.distinguished_name = dn;
279 params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
280
281 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
282 params.subject_alt_names = vec![SanType::IpAddress(ip)];
283 }
284
285 let key = KeyPair::generate()?;
286 let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
287
288 let cert_der = CertificateDer::from(cert.der().to_vec());
289 let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
290 Ok(CertifiedKey { cert_der, key_der })
291 }
292
293 pub fn ca_cert_der(&self) -> CertificateDer<'static> {
295 CertificateDer::from(self.ca_cert.der().to_vec())
296 }
297
298 fn generate_domain_cert(&self, domain: &str) -> Result<CertifiedKey> {
299 let mut params = CertificateParams::new(vec![domain.to_string()])?;
300 let mut dn = DistinguishedName::new();
301 dn.push(DnType::CommonName, domain);
302 params.distinguished_name = dn;
303
304 if let Ok(ip) = domain.parse::<std::net::IpAddr>() {
307 params.subject_alt_names = vec![SanType::IpAddress(ip)];
308 }
309
310 let key = KeyPair::generate()?;
311 let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
312
313 let cert_der = CertificateDer::from(cert.der().to_vec());
314 let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
315
316 Ok(CertifiedKey { cert_der, key_der })
317 }
318}