Skip to main content

spec_ai/spec_ai_api/api/
tls.rs

1//! TLS certificate generation and management
2//!
3//! Provides self-signed certificate generation using rcgen for the API server.
//! Because the certificate is self-signed, clients (like the visionOS app) cannot
//! validate it through a CA chain; instead they verify they're connecting to a
//! legitimate spec-ai server by checking the certificate's SHA-256 fingerprint
//! against a known value.
7
8use anyhow::{Context, Result};
9use base64::{engine::general_purpose::STANDARD, Engine};
10use rcgen::{
11    CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, SanType,
12};
13use ring::digest::{digest, SHA256};
14use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
15use serde::{Deserialize, Serialize};
16use std::path::Path;
17use std::sync::Arc;
18use std::time::Duration;
19
/// Organization (`O`) attribute written into the subject of generated certificates.
const CERT_ORG_NAME: &str = "spec-ai";

/// Leading part of the subject common name; generated certificates use
/// `<prefix>-<hostname>`.
const CERT_CN_PREFIX: &str = "spec-ai-server";

/// Lifetime, in days, of a freshly generated certificate when the caller does not
/// request a specific validity period.
const DEFAULT_CERT_VALIDITY_DAYS: u32 = 365;
28
/// TLS configuration and certificate info
///
/// Holds the DER certificate and private key used to build a rustls server config,
/// plus display metadata. `Debug` is implemented by hand so that the private key is
/// never written out when the config is logged or debug-printed.
#[derive(Clone)]
pub struct TlsConfig {
    /// The generated or loaded certificate (DER format)
    pub certificate: Vec<u8>,
    /// The private key (DER format). `build_server_config` and `save_to_files`
    /// treat these bytes as PKCS#8.
    pub private_key: Vec<u8>,
    /// SHA-256 fingerprint of the certificate (colon-separated hex)
    pub fingerprint: String,
    /// Certificate in PEM format (for export/display)
    pub certificate_pem: String,
    /// When the certificate expires (RFC 3339), or `"unknown"` when not determined
    pub not_after: String,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("certificate", &self.certificate)
            // Key material is secret: expose only its length, never its bytes.
            .field(
                "private_key",
                &format_args!("<redacted, {} bytes>", self.private_key.len()),
            )
            .field("fingerprint", &self.fingerprint)
            .field("certificate_pem", &self.certificate_pem)
            .field("not_after", &self.not_after)
            .finish()
    }
}
43
44/// Certificate metadata returned to clients
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CertificateInfo {
47    /// SHA-256 fingerprint of the certificate (hex encoded)
48    pub fingerprint: String,
49    /// Certificate in PEM format
50    pub certificate_pem: String,
51    /// When the certificate was issued
52    pub not_before: String,
53    /// When the certificate expires
54    pub not_after: String,
55    /// Subject common name
56    pub subject: String,
57    /// Subject alternative names
58    pub san: Vec<String>,
59}
60
61impl TlsConfig {
62    /// Generate a new self-signed certificate
63    ///
64    /// # Arguments
65    /// * `hostname` - Primary hostname for the certificate
66    /// * `additional_sans` - Additional Subject Alternative Names (IPs, hostnames)
67    /// * `validity_days` - Certificate validity period in days
68    pub fn generate(
69        hostname: &str,
70        additional_sans: &[String],
71        validity_days: Option<u32>,
72    ) -> Result<Self> {
73        let validity = validity_days.unwrap_or(DEFAULT_CERT_VALIDITY_DAYS);
74
75        // Generate key pair
76        let key_pair = KeyPair::generate().context("Failed to generate key pair")?;
77
78        // Build certificate parameters
79        let mut params = CertificateParams::default();
80
81        // Set distinguished name
82        params
83            .distinguished_name
84            .push(DnType::OrganizationName, CERT_ORG_NAME);
85        params.distinguished_name.push(
86            DnType::CommonName,
87            format!("{}-{}", CERT_CN_PREFIX, hostname),
88        );
89
90        // Set validity period
91        let now = time::OffsetDateTime::now_utc();
92        params.not_before = now;
93        params.not_after = now + Duration::from_secs(validity as u64 * 24 * 60 * 60);
94
95        // Set key usages for TLS server
96        params.key_usages = vec![
97            KeyUsagePurpose::DigitalSignature,
98            KeyUsagePurpose::KeyEncipherment,
99        ];
100        params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
101
102        // Not a CA certificate
103        params.is_ca = IsCa::NoCa;
104
105        // Add Subject Alternative Names
106        let mut sans = vec![SanType::DnsName(
107            hostname.try_into().context("Invalid hostname")?,
108        )];
109
110        // Always add localhost variants
111        if hostname != "localhost" {
112            if let Ok(localhost) = "localhost".try_into() {
113                sans.push(SanType::DnsName(localhost));
114            }
115        }
116
117        // Add 127.0.0.1 as IP SAN
118        sans.push(SanType::IpAddress(std::net::IpAddr::V4(
119            std::net::Ipv4Addr::new(127, 0, 0, 1),
120        )));
121
122        // Add additional SANs
123        for san in additional_sans {
124            if let Ok(ip) = san.parse::<std::net::IpAddr>() {
125                sans.push(SanType::IpAddress(ip));
126            } else if let Ok(dns) = san.as_str().try_into() {
127                sans.push(SanType::DnsName(dns));
128            }
129        }
130
131        params.subject_alt_names = sans;
132
133        // Save not_after before consuming params
134        let not_after_time = params.not_after;
135
136        // Generate the certificate
137        let cert = params
138            .self_signed(&key_pair)
139            .context("Failed to generate self-signed certificate")?;
140
141        // Get DER-encoded certificate and key
142        let cert_der = cert.der().to_vec();
143        let key_der = key_pair.serialize_der();
144
145        // Calculate fingerprint
146        let fingerprint = Self::calculate_fingerprint(&cert_der);
147
148        // Get PEM format
149        let cert_pem = cert.pem();
150
151        // Format expiry date
152        let not_after = not_after_time
153            .format(&time::format_description::well_known::Rfc3339)
154            .unwrap_or_else(|_| "unknown".to_string());
155
156        tracing::info!(
157            "Generated self-signed TLS certificate for {} (fingerprint: {})",
158            hostname,
159            fingerprint
160        );
161
162        Ok(Self {
163            certificate: cert_der,
164            private_key: key_der,
165            fingerprint,
166            certificate_pem: cert_pem,
167            not_after,
168        })
169    }
170
171    /// Load certificate and key from PEM files
172    pub fn load_from_files(cert_path: &Path, key_path: &Path) -> Result<Self> {
173        let cert_pem = std::fs::read_to_string(cert_path)
174            .with_context(|| format!("Failed to read certificate file: {}", cert_path.display()))?;
175
176        let key_pem = std::fs::read_to_string(key_path)
177            .with_context(|| format!("Failed to read key file: {}", key_path.display()))?;
178
179        Self::load_from_pem(&cert_pem, &key_pem)
180    }
181
182    /// Load certificate and key from PEM strings
183    pub fn load_from_pem(cert_pem: &str, key_pem: &str) -> Result<Self> {
184        // Parse certificate
185        let mut cert_reader = std::io::BufReader::new(cert_pem.as_bytes());
186        let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
187            .collect::<Result<Vec<_>, _>>()
188            .context("Failed to parse certificate PEM")?;
189
190        let cert_der = certs
191            .into_iter()
192            .next()
193            .context("No certificate found in PEM")?;
194
195        // Parse private key
196        let mut key_reader = std::io::BufReader::new(key_pem.as_bytes());
197        let key_der = rustls_pemfile::private_key(&mut key_reader)
198            .context("Failed to parse private key PEM")?
199            .context("No private key found in PEM")?;
200
201        let fingerprint = Self::calculate_fingerprint(cert_der.as_ref());
202
203        Ok(Self {
204            certificate: cert_der.to_vec(),
205            private_key: match key_der {
206                PrivateKeyDer::Pkcs8(k) => k.secret_pkcs8_der().to_vec(),
207                PrivateKeyDer::Pkcs1(k) => k.secret_pkcs1_der().to_vec(),
208                PrivateKeyDer::Sec1(k) => k.secret_sec1_der().to_vec(),
209                _ => anyhow::bail!("Unsupported private key format"),
210            },
211            fingerprint,
212            certificate_pem: cert_pem.to_string(),
213            not_after: "unknown".to_string(), // Would need to parse cert to get this
214        })
215    }
216
217    /// Save certificate and key to PEM files
218    pub fn save_to_files(&self, cert_path: &Path, key_path: &Path) -> Result<()> {
219        // Ensure parent directories exist
220        if let Some(parent) = cert_path.parent() {
221            std::fs::create_dir_all(parent)?;
222        }
223        if let Some(parent) = key_path.parent() {
224            std::fs::create_dir_all(parent)?;
225        }
226
227        // Save certificate PEM
228        std::fs::write(cert_path, &self.certificate_pem)
229            .with_context(|| format!("Failed to write certificate to {}", cert_path.display()))?;
230
231        // Convert private key to PEM and save
232        let key_pem = format!(
233            "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----\n",
234            STANDARD.encode(&self.private_key)
235        );
236        std::fs::write(key_path, &key_pem)
237            .with_context(|| format!("Failed to write private key to {}", key_path.display()))?;
238
239        // Set restrictive permissions on key file
240        #[cfg(unix)]
241        {
242            use std::os::unix::fs::PermissionsExt;
243            std::fs::set_permissions(key_path, std::fs::Permissions::from_mode(0o600))?;
244        }
245
246        tracing::info!(
247            "Saved TLS certificate to {} and key to {}",
248            cert_path.display(),
249            key_path.display()
250        );
251
252        Ok(())
253    }
254
255    /// Calculate SHA-256 fingerprint of a certificate (DER format)
256    pub fn calculate_fingerprint(cert_der: &[u8]) -> String {
257        let hash = digest(&SHA256, cert_der);
258        hash.as_ref()
259            .iter()
260            .map(|b| format!("{:02X}", b))
261            .collect::<Vec<_>>()
262            .join(":")
263    }
264
265    /// Get certificate info for clients
266    pub fn get_certificate_info(&self, hostname: &str) -> CertificateInfo {
267        CertificateInfo {
268            fingerprint: self.fingerprint.clone(),
269            certificate_pem: self.certificate_pem.clone(),
270            not_before: "see certificate".to_string(),
271            not_after: self.not_after.clone(),
272            subject: format!("CN={}-{}, O={}", CERT_CN_PREFIX, hostname, CERT_ORG_NAME),
273            san: vec![
274                hostname.to_string(),
275                "localhost".to_string(),
276                "127.0.0.1".to_string(),
277            ],
278        }
279    }
280
281    /// Build rustls ServerConfig from this TLS config
282    pub fn build_server_config(&self) -> Result<Arc<rustls::ServerConfig>> {
283        let cert = CertificateDer::from(self.certificate.clone());
284        let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(self.private_key.clone()));
285
286        // Use aws-lc-rs as the crypto provider (installed by default via axum-server)
287        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
288            rustls::crypto::aws_lc_rs::default_provider(),
289        ))
290        .with_safe_default_protocol_versions()
291        .context("Failed to set protocol versions")?
292        .with_no_client_auth()
293        .with_single_cert(vec![cert], key)
294        .context("Failed to build TLS server config")?;
295
296        Ok(Arc::new(config))
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// True if `needle` occurs as a contiguous run of bytes inside `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    /// DER encoding of an IPv4 `iPAddress` GeneralName (`[7] IMPLICIT OCTET STRING`).
    fn ipv4_san_der(octets: [u8; 4]) -> Vec<u8> {
        let mut der = vec![0x87, 4];
        der.extend_from_slice(&octets);
        der
    }

    /// DER encoding of a `dNSName` GeneralName (`[2] IMPLICIT IA5String`).
    /// Only short names are supported (single-byte DER length).
    fn dns_san_der(name: &str) -> Vec<u8> {
        let len = u8::try_from(name.len()).expect("test name too long");
        assert!(len < 0x80, "long-form DER lengths not supported by this helper");
        let mut der = vec![0x82, len];
        der.extend_from_slice(name.as_bytes());
        der
    }

    #[test]
    fn test_generate_certificate() {
        let config = TlsConfig::generate("test.local", &[], Some(30)).unwrap();

        assert!(!config.certificate.is_empty());
        assert!(!config.private_key.is_empty());
        assert!(config.certificate_pem.contains("BEGIN CERTIFICATE"));

        // The stored fingerprint must describe the stored certificate.
        assert_eq!(
            config.fingerprint,
            TlsConfig::calculate_fingerprint(&config.certificate)
        );
        assert_eq!(config.fingerprint.split(':').count(), 32); // SHA-256 = 32 bytes

        // Primary and default SANs are always embedded.
        assert!(contains_bytes(&config.certificate, &dns_san_der("test.local")));
        assert!(contains_bytes(&config.certificate, &dns_san_der("localhost")));
        assert!(contains_bytes(
            &config.certificate,
            &ipv4_san_der([127, 0, 0, 1])
        ));
    }

    #[test]
    fn test_fingerprint_calculation() {
        let data = b"test certificate data";
        let fingerprint = TlsConfig::calculate_fingerprint(data);

        // Should be 32 hex pairs separated by colons
        let parts: Vec<&str> = fingerprint.split(':').collect();
        assert_eq!(parts.len(), 32);

        // Each part should be 2 hex chars
        for part in parts {
            assert_eq!(part.len(), 2);
            assert!(part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_fingerprint_known_answer() {
        // SHA-256 of the empty input, in the uppercase colon-separated format.
        assert_eq!(
            TlsConfig::calculate_fingerprint(b""),
            "E3:B0:C4:42:98:FC:1C:14:9A:FB:F4:C8:99:6F:B9:24:\
             27:AE:41:E4:64:9B:93:4C:A4:95:99:1B:78:52:B8:55"
        );
    }

    #[test]
    fn test_build_server_config() {
        let tls = TlsConfig::generate("localhost", &[], None).unwrap();
        tls.build_server_config()
            .expect("generated certificate should build a ServerConfig");
    }

    #[test]
    fn test_additional_sans() {
        let additional = vec!["192.168.1.100".to_string(), "myserver.local".to_string()];
        let config = TlsConfig::generate("primary.local", &additional, None).unwrap();

        assert!(contains_bytes(&config.certificate, &dns_san_der("primary.local")));
        assert!(contains_bytes(
            &config.certificate,
            &ipv4_san_der([192, 168, 1, 100])
        ));
        assert!(contains_bytes(&config.certificate, &dns_san_der("myserver.local")));
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = tempfile::tempdir().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        // Generate and save
        let original = TlsConfig::generate("test.local", &[], None).unwrap();
        original.save_to_files(&cert_path, &key_path).unwrap();

        // Load and verify
        let loaded = TlsConfig::load_from_files(&cert_path, &key_path).unwrap();

        assert_eq!(original.certificate, loaded.certificate);
        assert_eq!(original.private_key, loaded.private_key);
        assert_eq!(original.fingerprint, loaded.fingerprint);
        // The reloaded key must still pair with the certificate.
        assert!(loaded.build_server_config().is_ok());

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&key_path).unwrap().permissions().mode();
            assert_eq!(
                mode & 0o777,
                0o600,
                "private key must not be group/world readable"
            );
        }
    }
}