Skip to main content

rustgate/
cert.rs

1use crate::error::{ProxyError, Result};
2use rcgen::{
3    BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyPair, KeyUsagePurpose,
4    SanType,
5};
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::{debug, info};
12
13/// Holds a certificate and its private key for TLS.
14pub struct CertifiedKey {
15    pub cert_der: CertificateDer<'static>,
16    pub key_der: PrivatePkcs8KeyDer<'static>,
17}
18
19/// Manages the root CA and generates per-domain certificates.
20pub struct CertificateAuthority {
21    ca_cert: rcgen::Certificate,
22    ca_key: KeyPair,
23    cache: Mutex<HashMap<String, Arc<CertifiedKey>>>,
24}
25
26impl CertificateAuthority {
27    /// Load or create a CA certificate. Stores files under `~/.rustgate/`.
28    pub async fn new() -> Result<Self> {
29        Self::with_dir(Self::ca_dir()?).await
30    }
31
32    /// Load or create a CA certificate in the specified directory.
33    pub async fn with_dir(dir: PathBuf) -> Result<Self> {
34        tokio::fs::create_dir_all(&dir).await?;
35
36        let cert_path = dir.join("ca.pem");
37        let key_path = dir.join("ca-key.pem");
38
39        let (ca_cert, ca_key) = if cert_path.exists() && key_path.exists() {
40            info!("Loading existing CA certificate from {}", dir.display());
41            Self::load_ca(&cert_path, &key_path).await?
42        } else {
43            info!("Generating new CA certificate in {}", dir.display());
44            let (cert, key) = Self::generate_ca()?;
45            Self::save_ca(&cert, &key, &cert_path, &key_path).await?;
46            (cert, key)
47        };
48
49        Ok(Self {
50            ca_cert,
51            ca_key,
52            cache: Mutex::new(HashMap::new()),
53        })
54    }
55
56    /// Return the path to the CA PEM file for users to install.
57    pub fn ca_cert_path() -> Result<PathBuf> {
58        Ok(Self::ca_dir()?.join("ca.pem"))
59    }
60
61    /// Generate a fake certificate for the given domain, signed by the CA.
62    pub async fn get_or_create_cert(&self, domain: &str) -> Result<Arc<CertifiedKey>> {
63        {
64            let cache = self.cache.lock().await;
65            if let Some(ck) = cache.get(domain) {
66                debug!("Using cached certificate for {domain}");
67                return Ok(ck.clone());
68            }
69        }
70
71        debug!("Generating certificate for {domain}");
72        let ck = self.generate_domain_cert(domain)?;
73        let ck = Arc::new(ck);
74
75        {
76            let mut cache = self.cache.lock().await;
77            cache.insert(domain.to_string(), ck.clone());
78        }
79
80        Ok(ck)
81    }
82
83    fn ca_dir() -> Result<PathBuf> {
84        let home = std::env::var("HOME")
85            .map_err(|_| ProxyError::Other("HOME environment variable not set".into()))?;
86        Ok(PathBuf::from(home).join(".rustgate"))
87    }
88
89    fn generate_ca() -> Result<(rcgen::Certificate, KeyPair)> {
90        let mut params = CertificateParams::default();
91        let mut dn = DistinguishedName::new();
92        dn.push(DnType::CommonName, "RustGate CA");
93        dn.push(DnType::OrganizationName, "RustGate");
94        params.distinguished_name = dn;
95        params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
96        params.key_usages = vec![
97            KeyUsagePurpose::KeyCertSign,
98            KeyUsagePurpose::CrlSign,
99        ];
100
101        let key = KeyPair::generate()?;
102        let cert = params.self_signed(&key)?;
103        Ok((cert, key))
104    }
105
106    async fn save_ca(
107        cert: &rcgen::Certificate,
108        key: &KeyPair,
109        cert_path: &PathBuf,
110        key_path: &PathBuf,
111    ) -> Result<()> {
112        tokio::fs::write(cert_path, cert.pem()).await?;
113        tokio::fs::write(key_path, key.serialize_pem()).await?;
114
115        // Restrict private key to owner-only access (0600)
116        #[cfg(unix)]
117        {
118            use std::os::unix::fs::PermissionsExt;
119            let perms = std::fs::Permissions::from_mode(0o600);
120            tokio::fs::set_permissions(key_path, perms).await?;
121        }
122
123        Ok(())
124    }
125
126    async fn load_ca(
127        cert_path: &PathBuf,
128        key_path: &PathBuf,
129    ) -> Result<(rcgen::Certificate, KeyPair)> {
130        let key_pem = tokio::fs::read_to_string(key_path).await?;
131        let key = KeyPair::from_pem(&key_pem)?;
132
133        let cert_pem = tokio::fs::read_to_string(cert_path).await?;
134        let params = CertificateParams::from_ca_cert_pem(&cert_pem)?;
135
136        // Verify the private key matches the certificate's public key.
137        // Re-sign with the loaded key and check that the public key in
138        // the resulting cert matches the original.
139        let cert = params.self_signed(&key)?;
140
141        let original_der = Self::pem_to_der(&cert_pem)?;
142        let regenerated_der = cert.der().to_vec();
143        let original_spki = Self::extract_spki(&original_der)?;
144        let regenerated_spki = Self::extract_spki(&regenerated_der)?;
145        if original_spki != regenerated_spki {
146            return Err(ProxyError::Other(
147                "CA certificate and private key do not match: \
148                 public key in ca.pem differs from ca-key.pem"
149                    .into(),
150            ));
151        }
152
153        Ok((cert, key))
154    }
155
156    /// Extract the raw PEM body into DER bytes.
157    fn pem_to_der(pem_str: &str) -> Result<Vec<u8>> {
158        let mut reader = std::io::BufReader::new(pem_str.as_bytes());
159        let certs = rustls_pemfile::certs(&mut reader)
160            .collect::<std::result::Result<Vec<_>, _>>()?;
161        certs
162            .into_iter()
163            .next()
164            .map(|c| c.to_vec())
165            .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
166    }
167
168    /// Extract SubjectPublicKeyInfo bytes from a DER-encoded X.509 certificate.
169    /// Uses minimal ASN.1 parsing: Certificate -> TBSCertificate -> SPKI (7th field).
170    fn extract_spki(der: &[u8]) -> Result<Vec<u8>> {
171        // Certificate is a SEQUENCE containing TBSCertificate, signatureAlgorithm, signature
172        let tbs = Self::asn1_sequence_contents(der)?;
173        // TBSCertificate is a SEQUENCE: version, serialNumber, signature, issuer,
174        //   validity, subject, subjectPublicKeyInfo, ...
175        let tbs_inner = Self::asn1_sequence_contents(tbs)?;
176
177        let mut pos = 0;
178        // Skip 6 fields: version (explicit tag [0]), serial, sigAlg, issuer, validity, subject
179        for i in 0..6 {
180            if pos >= tbs_inner.len() {
181                return Err(ProxyError::Other(
182                    format!("Unexpected end of TBSCertificate at field {i}"),
183                ));
184            }
185            let (_, field_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
186            pos += field_len;
187        }
188
189        // The 7th field is SubjectPublicKeyInfo
190        if pos >= tbs_inner.len() {
191            return Err(ProxyError::Other(
192                "SubjectPublicKeyInfo not found in certificate".into(),
193            ));
194        }
195        let (_, spki_len) = Self::asn1_read_tag_and_length(&tbs_inner[pos..])?;
196        Ok(tbs_inner[pos..pos + spki_len].to_vec())
197    }
198
199    /// Parse the contents (value bytes) of an ASN.1 SEQUENCE.
200    fn asn1_sequence_contents(data: &[u8]) -> Result<&[u8]> {
201        if data.is_empty() || (data[0] & 0x1f) != 0x10 {
202            return Err(ProxyError::Other("Expected ASN.1 SEQUENCE".into()));
203        }
204        let (header_len, total_len) = Self::asn1_read_tag_and_length(data)?;
205        let content_len = total_len - header_len;
206        Ok(&data[header_len..header_len + content_len])
207    }
208
209    /// Read ASN.1 tag and length, returning (header_size, total_element_size).
210    fn asn1_read_tag_and_length(data: &[u8]) -> Result<(usize, usize)> {
211        if data.len() < 2 {
212            return Err(ProxyError::Other("ASN.1 data too short".into()));
213        }
214        let mut pos = 1; // skip tag byte
215        let length_byte = data[pos];
216        pos += 1;
217
218        let content_len = if length_byte & 0x80 == 0 {
219            length_byte as usize
220        } else {
221            let num_bytes = (length_byte & 0x7f) as usize;
222            if pos + num_bytes > data.len() {
223                return Err(ProxyError::Other("ASN.1 length overflow".into()));
224            }
225            let mut len = 0usize;
226            for &b in &data[pos..pos + num_bytes] {
227                len = (len << 8) | b as usize;
228            }
229            pos += num_bytes;
230            len
231        };
232
233        let total_len = pos + content_len;
234        if total_len > data.len() {
235            return Err(ProxyError::Other(
236                "ASN.1 element extends beyond input data".into(),
237            ));
238        }
239
240        Ok((pos, total_len))
241    }
242
243    fn generate_domain_cert(&self, domain: &str) -> Result<CertifiedKey> {
244        let mut params = CertificateParams::new(vec![domain.to_string()])?;
245        let mut dn = DistinguishedName::new();
246        dn.push(DnType::CommonName, domain);
247        params.distinguished_name = dn;
248
249        // SAN is already set by CertificateParams::new
250        // Override for IP addresses
251        if let Ok(ip) = domain.parse::<std::net::IpAddr>() {
252            params.subject_alt_names = vec![SanType::IpAddress(ip)];
253        }
254
255        let key = KeyPair::generate()?;
256        let cert = params.signed_by(&key, &self.ca_cert, &self.ca_key)?;
257
258        let cert_der = CertificateDer::from(cert.der().to_vec());
259        let key_der = PrivatePkcs8KeyDer::from(key.serialize_der());
260
261        Ok(CertifiedKey { cert_der, key_der })
262    }
263}