Skip to main content

rustgate/
cert.rs

1use crate::error::{ProxyError, Result};
2use rcgen::{
3    BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyPair, KeyUsagePurpose,
4    SanType,
5};
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::{debug, info};
12
13/// Holds a certificate and its private key for TLS.
14pub struct CertifiedKey {
15    pub cert_der: CertificateDer<'static>,
16    pub key_der: PrivatePkcs8KeyDer<'static>,
17}
18
19/// Manages the root CA and generates per-domain certificates.
20pub struct CertificateAuthority {
21    ca_cert: rcgen::Certificate,
22    ca_key: KeyPair,
23    cache: Mutex<HashMap<String, Arc<CertifiedKey>>>,
24}
25
26impl CertificateAuthority {
27    /// Load or create a CA certificate. Stores files under `~/.rustgate/`.
28    pub async fn new() -> Result<Self> {
29        Self::with_dir(Self::ca_dir()?).await
30    }
31
32    /// Load or create a CA certificate in the specified directory.
33    pub async fn with_dir(dir: PathBuf) -> Result<Self> {
34        tokio::fs::create_dir_all(&dir).await?;
35
36        let cert_path = dir.join("ca.pem");
37        let key_path = dir.join("ca-key.pem");
38
39        let cert_exists = cert_path.exists();
40        let key_exists = key_path.exists();
41
42        // Partial CA state is fatal — prevent silent rekey
43        if cert_exists != key_exists {
44            return Err(ProxyError::Other(format!(
45                "Partial CA state in {}: {} exists but {} is missing. \
46                 Restore the missing file or remove both to reinitialize.",
47                dir.display(),
48                if cert_exists { "ca.pem" } else { "ca-key.pem" },
49                if cert_exists { "ca-key.pem" } else { "ca.pem" },
50            )));
51        }
52
53        let (ca_cert, ca_key) = if cert_exists {
54            info!("Loading existing CA certificate from {}", dir.display());
55            Self::load_ca(&cert_path, &key_path).await?
56        } else {
57            info!("Generating new CA certificate in {}", dir.display());
58            let (cert, key) = Self::generate_ca()?;
59            Self::save_ca(&cert, &key, &cert_path, &key_path).await?;
60            (cert, key)
61        };
62
63        Ok(Self {
64            ca_cert,
65            ca_key,
66            cache: Mutex::new(HashMap::new()),
67        })
68    }
69
70    /// Return the path to the CA PEM file for users to install.
71    pub fn ca_cert_path() -> Result<PathBuf> {
72        Ok(Self::ca_dir()?.join("ca.pem"))
73    }
74
75    /// Generate a fake certificate for the given domain, signed by the CA.
76    pub async fn get_or_create_cert(&self, domain: &str) -> Result<Arc<CertifiedKey>> {
77        {
78            let cache = self.cache.lock().await;
79            if let Some(ck) = cache.get(domain) {
80                debug!("Using cached certificate for {domain}");
81                return Ok(ck.clone());
82            }
83        }
84
85        debug!("Generating certificate for {domain}");
86        let ck = self.generate_domain_cert(domain)?;
87        let ck = Arc::new(ck);
88
89        {
90            let mut cache = self.cache.lock().await;
91            cache.insert(domain.to_string(), ck.clone());
92        }
93
94        Ok(ck)
95    }
96
97    fn ca_dir() -> Result<PathBuf> {
98        let home = std::env::var("HOME")
99            .map_err(|_| ProxyError::Other("HOME environment variable not set".into()))?;
100        Ok(PathBuf::from(home).join(".rustgate"))
101    }
102
103    fn generate_ca() -> Result<(rcgen::Certificate, KeyPair)> {
104        let mut params = CertificateParams::default();
105        let mut dn = DistinguishedName::new();
106        dn.push(DnType::CommonName, "RustGate CA");
107        dn.push(DnType::OrganizationName, "RustGate");
108        params.distinguished_name = dn;
109        params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
110        params.key_usages = vec![
111            KeyUsagePurpose::KeyCertSign,
112            KeyUsagePurpose::CrlSign,
113        ];
114
115        let key = KeyPair::generate()?;
116        let cert = params.self_signed(&key)?;
117        Ok((cert, key))
118    }
119
120    async fn save_ca(
121        cert: &rcgen::Certificate,
122        key: &KeyPair,
123        cert_path: &PathBuf,
124        key_path: &PathBuf,
125    ) -> Result<()> {
126        tokio::fs::write(cert_path, cert.pem()).await?;
127        tokio::fs::write(key_path, key.serialize_pem()).await?;
128
129        // Restrict private key to owner-only access (0600)
130        #[cfg(unix)]
131        {
132            use std::os::unix::fs::PermissionsExt;
133            let perms = std::fs::Permissions::from_mode(0o600);
134            tokio::fs::set_permissions(key_path, perms).await?;
135        }
136
137        Ok(())
138    }
139
140    async fn load_ca(
141        cert_path: &PathBuf,
142        key_path: &PathBuf,
143    ) -> Result<(rcgen::Certificate, KeyPair)> {
144        let key_pem = tokio::fs::read_to_string(key_path).await?;
145        let key = KeyPair::from_pem(&key_pem)?;
146
147        let cert_pem = tokio::fs::read_to_string(cert_path).await?;
148        let params = CertificateParams::from_ca_cert_pem(&cert_pem)?;
149
150        // Verify the private key matches the certificate's public key.
151        // Re-sign with the loaded key and check that the public key in
152        // the resulting cert matches the original.
153        let cert = params.self_signed(&key)?;
154
155        let original_der = Self::pem_to_der(&cert_pem)?;
156        let regenerated_der = cert.der().to_vec();
157        let original_spki = Self::extract_spki(&original_der)?;
158        let regenerated_spki = Self::extract_spki(&regenerated_der)?;
159        if original_spki != regenerated_spki {
160            return Err(ProxyError::Other(
161                "CA certificate and private key do not match: \
162                 public key in ca.pem differs from ca-key.pem"
163                    .into(),
164            ));
165        }
166
167        Ok((cert, key))
168    }
169
170    /// Extract the raw PEM body into DER bytes.
171    fn pem_to_der(pem_str: &str) -> Result<Vec<u8>> {
172        let mut reader = std::io::BufReader::new(pem_str.as_bytes());
173        let certs = rustls_pemfile::certs(&mut reader)
174            .collect::<std::result::Result<Vec<_>, _>>()?;
175        certs
176            .into_iter()
177            .next()
178            .map(|c| c.to_vec())
179            .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
180    }
181
182    /// Extract SubjectPublicKeyInfo bytes from a DER-encoded X.509 certificate.
183    /// Uses minimal ASN.1 parsing: Certificate -> TBSCertificate -> SPKI (7th field).
184    fn extract_spki(der: &[u8]) -> Result<Vec<u8>> {
185        // Certificate is a SEQUENCE containing TBSCertificate, signatureAlgorithm, signature
186        let tbs = Self::asn1_sequence_contents(der)?;
187        // TBSCertificate is a SEQUENCE: version, serialNumber, signature, issuer,
188        //   validity, subject, subjectPublicKeyInfo, ...
189        let tbs_inner = Self::asn1_sequence_contents(tbs)?;
190
191        let mut pos = 0;
192        // Skip 6 fields: version (explicit tag [0]), serial, sigAlg, issuer, validity, subject
193        for i in 0..6 {
194            if pos >= tbs_inner.len() {
195                return Err(ProxyError::Other(
196                    format!("Unexpected end of TBSCertificate at field {i}"),
197                ));
198            }
199            let (_, field_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
200            pos += field_len;
201        }
202
203        // The 7th field is SubjectPublicKeyInfo
204        if pos >= tbs_inner.len() {
205            return Err(ProxyError::Other(
206                "SubjectPublicKeyInfo not found in certificate".into(),
207            ));
208        }
209        let (_, spki_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
210        Ok(tbs_inner[pos..pos + spki_len].to_vec())
211    }
212
213    /// Parse the contents (value bytes) of an ASN.1 SEQUENCE.
214    fn asn1_sequence_contents(data: &[u8]) -> Result<&[u8]> {
215        if data.is_empty() || (data[0] & 0x1f) != 0x10 {
216            return Err(ProxyError::Other("Expected ASN.1 SEQUENCE".into()));
217        }
218        let (header_len, total_len) = Self::asn1_read_tag_and_length(data)?;
219        let content_len = total_len - header_len;
220        Ok(&data[header_len..header_len + content_len])
221    }
222
223    /// Read ASN.1 tag and length, returning (header_size, total_element_size).
224    fn asn1_read_tag_and_length(data: &[u8]) -> Result<(usize, usize)> {
225        if data.len() < 2 {
226            return Err(ProxyError::Other("ASN.1 data too short".into()));
227        }
228        let mut pos = 1; // skip tag byte
229        let length_byte = data[pos];
230        pos += 1;
231
232        let content_len = if length_byte & 0x80 == 0 {
233            length_byte as usize
234        } else {
235            let num_bytes = (length_byte & 0x7f) as usize;
236            if pos + num_bytes > data.len() {
237                return Err(ProxyError::Other("ASN.1 length overflow".into()));
238            }
239            let mut len = 0usize;
240            for &b in &data[pos..pos + num_bytes] {
241                len = (len << 8) | b as usize;
242            }
243            pos += num_bytes;
244            len
245        };
246
247        let total_len = pos + content_len;
248        if total_len > data.len() {
249            return Err(ProxyError::Other(
250                "ASN.1 element extends beyond input data".into(),
251            ));
252        }
253
254        Ok((pos, total_len))
255    }
256
257    /// Generate a client certificate signed by this CA (EKU: ClientAuth).
258    /// Returns (cert_pem, key_pem) as Strings.
259    pub fn generate_client_cert(&self, cn: &str) -> Result<(String, String)> {
260        let mut params = CertificateParams::default();
261        let mut dn = DistinguishedName::new();
262        dn.push(DnType::CommonName, cn);
263        dn.push(DnType::OrganizationName, "RustGate");
264        params.distinguished_name = dn;
265        params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ClientAuth];
266
267        let key = KeyPair::generate()?;
268        let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
269
270        Ok((cert.pem(), key.serialize_pem()))
271    }
272
273    /// Generate a server certificate signed by this CA (EKU: ServerAuth).
274    pub fn generate_server_cert(&self, host: &str) -> Result<CertifiedKey> {
275        let mut params = CertificateParams::new(vec![host.to_string()])?;
276        let mut dn = DistinguishedName::new();
277        dn.push(DnType::CommonName, host);
278        params.distinguished_name = dn;
279        params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
280
281        if let Ok(ip) = host.parse::<std::net::IpAddr>() {
282            params.subject_alt_names = vec![SanType::IpAddress(ip)];
283        }
284
285        let key = KeyPair::generate()?;
286        let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
287
288        let cert_der = CertificateDer::from(cert.der().to_vec());
289        let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
290        Ok(CertifiedKey { cert_der, key_der })
291    }
292
293    /// Return the CA certificate in DER format (for building RootCertStore).
294    pub fn ca_cert_der(&self) -> CertificateDer<'static> {
295        CertificateDer::from(self.ca_cert.der().to_vec())
296    }
297
298    fn generate_domain_cert(&self, domain: &str) -> Result<CertifiedKey> {
299        let mut params = CertificateParams::new(vec![domain.to_string()])?;
300        let mut dn = DistinguishedName::new();
301        dn.push(DnType::CommonName, domain);
302        params.distinguished_name = dn;
303
304        // SAN is already set by CertificateParams::new
305        // Override for IP addresses
306        if let Ok(ip) = domain.parse::<std::net::IpAddr>() {
307            params.subject_alt_names = vec![SanType::IpAddress(ip)];
308        }
309
310        let key = KeyPair::generate()?;
311        let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
312
313        let cert_der = CertificateDer::from(cert.der().to_vec());
314        let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
315
316        Ok(CertifiedKey { cert_der, key_der })
317    }
318}