1use crate::error::{ProxyError, Result};
2use rcgen::{
3 BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyPair, KeyUsagePurpose,
4 SanType,
5};
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::{debug, info};
12
/// A certificate together with its private key, both DER-encoded, as issued
/// by [`CertificateAuthority`] for a single domain.
pub struct CertifiedKey {
    /// The signed certificate in DER form.
    pub cert_der: CertificateDer<'static>,
    /// The certificate's private key, serialized as PKCS#8 DER.
    pub key_der: PrivatePkcs8KeyDer<'static>,
}
18
/// A local certificate authority that issues per-domain certificates and keeps
/// them in an in-memory cache.
pub struct CertificateAuthority {
    /// The CA certificate used as the issuer of every domain certificate.
    ca_cert: rcgen::Certificate,
    /// Private key belonging to `ca_cert`; signs domain certificates.
    ca_key: KeyPair,
    /// Issued certificates keyed by the exact domain string requested.
    /// Entries are never evicted.
    cache: Mutex<HashMap<String, Arc<CertifiedKey>>>,
}
25
26impl CertificateAuthority {
27 pub async fn new() -> Result<Self> {
29 Self::with_dir(Self::ca_dir()?).await
30 }
31
32 pub async fn with_dir(dir: PathBuf) -> Result<Self> {
34 tokio::fs::create_dir_all(&dir).await?;
35
36 let cert_path = dir.join("ca.pem");
37 let key_path = dir.join("ca-key.pem");
38
39 let (ca_cert, ca_key) = if cert_path.exists() && key_path.exists() {
40 info!("Loading existing CA certificate from {}", dir.display());
41 Self::load_ca(&cert_path, &key_path).await?
42 } else {
43 info!("Generating new CA certificate in {}", dir.display());
44 let (cert, key) = Self::generate_ca()?;
45 Self::save_ca(&cert, &key, &cert_path, &key_path).await?;
46 (cert, key)
47 };
48
49 Ok(Self {
50 ca_cert,
51 ca_key,
52 cache: Mutex::new(HashMap::new()),
53 })
54 }
55
56 pub fn ca_cert_path() -> Result<PathBuf> {
58 Ok(Self::ca_dir()?.join("ca.pem"))
59 }
60
61 pub async fn get_or_create_cert(&self, domain: &str) -> Result<Arc<CertifiedKey>> {
63 {
64 let cache = self.cache.lock().await;
65 if let Some(ck) = cache.get(domain) {
66 debug!("Using cached certificate for {domain}");
67 return Ok(ck.clone());
68 }
69 }
70
71 debug!("Generating certificate for {domain}");
72 let ck = self.generate_domain_cert(domain)?;
73 let ck = Arc::new(ck);
74
75 {
76 let mut cache = self.cache.lock().await;
77 cache.insert(domain.to_string(), ck.clone());
78 }
79
80 Ok(ck)
81 }
82
83 fn ca_dir() -> Result<PathBuf> {
84 let home = std::env::var("HOME")
85 .map_err(|_| ProxyError::Other("HOME environment variable not set".into()))?;
86 Ok(PathBuf::from(home).join(".rustgate"))
87 }
88
89 fn generate_ca() -> Result<(rcgen::Certificate, KeyPair)> {
90 let mut params = CertificateParams::default();
91 let mut dn = DistinguishedName::new();
92 dn.push(DnType::CommonName, "RustGate CA");
93 dn.push(DnType::OrganizationName, "RustGate");
94 params.distinguished_name = dn;
95 params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
96 params.key_usages = vec![
97 KeyUsagePurpose::KeyCertSign,
98 KeyUsagePurpose::CrlSign,
99 ];
100
101 let key = KeyPair::generate()?;
102 let cert = params.self_signed(&key)?;
103 Ok((cert, key))
104 }
105
106 async fn save_ca(
107 cert: &rcgen::Certificate,
108 key: &KeyPair,
109 cert_path: &PathBuf,
110 key_path: &PathBuf,
111 ) -> Result<()> {
112 tokio::fs::write(cert_path, cert.pem()).await?;
113 tokio::fs::write(key_path, key.serialize_pem()).await?;
114
115 #[cfg(unix)]
117 {
118 use std::os::unix::fs::PermissionsExt;
119 let perms = std::fs::Permissions::from_mode(0o600);
120 tokio::fs::set_permissions(key_path, perms).await?;
121 }
122
123 Ok(())
124 }
125
126 async fn load_ca(
127 cert_path: &PathBuf,
128 key_path: &PathBuf,
129 ) -> Result<(rcgen::Certificate, KeyPair)> {
130 let key_pem = tokio::fs::read_to_string(key_path).await?;
131 let key = KeyPair::from_pem(&key_pem)?;
132
133 let cert_pem = tokio::fs::read_to_string(cert_path).await?;
134 let params = CertificateParams::from_ca_cert_pem(&cert_pem)?;
135
136 let cert = params.self_signed(&key)?;
140
141 let original_der = Self::pem_to_der(&cert_pem)?;
142 let regenerated_der = cert.der().to_vec();
143 let original_spki = Self::extract_spki(&original_der)?;
144 let regenerated_spki = Self::extract_spki(®enerated_der)?;
145 if original_spki != regenerated_spki {
146 return Err(ProxyError::Other(
147 "CA certificate and private key do not match: \
148 public key in ca.pem differs from ca-key.pem"
149 .into(),
150 ));
151 }
152
153 Ok((cert, key))
154 }
155
156 fn pem_to_der(pem_str: &str) -> Result<Vec<u8>> {
158 let mut reader = std::io::BufReader::new(pem_str.as_bytes());
159 let certs = rustls_pemfile::certs(&mut reader)
160 .collect::<std::result::Result<Vec<_>, _>>()?;
161 certs
162 .into_iter()
163 .next()
164 .map(|c| c.to_vec())
165 .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
166 }
167
168 fn extract_spki(der: &[u8]) -> Result<Vec<u8>> {
171 let tbs = Self::asn1_sequence_contents(der)?;
173 let tbs_inner = Self::asn1_sequence_contents(tbs)?;
176
177 let mut pos = 0;
178 for i in 0..6 {
180 if pos >= tbs_inner.len() {
181 return Err(ProxyError::Other(
182 format!("Unexpected end of TBSCertificate at field {i}"),
183 ));
184 }
185 let (_, field_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
186 pos += field_len;
187 }
188
189 if pos >= tbs_inner.len() {
191 return Err(ProxyError::Other(
192 "SubjectPublicKeyInfo not found in certificate".into(),
193 ));
194 }
195 let (_, spki_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
196 Ok(tbs_inner[pos..pos + spki_len].to_vec())
197 }
198
199 fn asn1_sequence_contents(data: &[u8]) -> Result<&[u8]> {
201 if data.is_empty() || (data[0] & 0x1f) != 0x10 {
202 return Err(ProxyError::Other("Expected ASN.1 SEQUENCE".into()));
203 }
204 let (header_len, total_len) = Self::asn1_read_tag_and_length(data)?;
205 let content_len = total_len - header_len;
206 Ok(&data[header_len..header_len + content_len])
207 }
208
209 fn asn1_read_tag_and_length(data: &[u8]) -> Result<(usize, usize)> {
211 if data.len() < 2 {
212 return Err(ProxyError::Other("ASN.1 data too short".into()));
213 }
214 let mut pos = 1; let length_byte = data[pos];
216 pos += 1;
217
218 let content_len = if length_byte & 0x80 == 0 {
219 length_byte as usize
220 } else {
221 let num_bytes = (length_byte & 0x7f) as usize;
222 if pos + num_bytes > data.len() {
223 return Err(ProxyError::Other("ASN.1 length overflow".into()));
224 }
225 let mut len = 0usize;
226 for &b in &data[pos..pos + num_bytes] {
227 len = (len << 8) | b as usize;
228 }
229 pos += num_bytes;
230 len
231 };
232
233 let total_len = pos + content_len;
234 if total_len > data.len() {
235 return Err(ProxyError::Other(
236 "ASN.1 element extends beyond input data".into(),
237 ));
238 }
239
240 Ok((pos, total_len))
241 }
242
243 fn generate_domain_cert(&self, domain: &str) -> Result<CertifiedKey> {
244 let mut params = CertificateParams::new(vec![domain.to_string()])?;
245 let mut dn = DistinguishedName::new();
246 dn.push(DnType::CommonName, domain);
247 params.distinguished_name = dn;
248
249 if let Ok(ip) = domain.parse::<std::net::IpAddr>() {
252 params.subject_alt_names = vec![SanType::IpAddress(ip)];
253 }
254
255 let key = KeyPair::generate()?;
256 let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
257
258 let cert_der = CertificateDer::from(cert.der().to_vec());
259 let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
260
261 Ok(CertifiedKey { cert_der, key_der })
262 }
263}