1use anyhow::{Context, Result};
9use base64::{engine::general_purpose::STANDARD, Engine};
10use rcgen::{
11 CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, SanType,
12};
13use ring::digest::{digest, SHA256};
14use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
15use serde::{Deserialize, Serialize};
16use std::path::Path;
17use std::sync::Arc;
18use std::time::Duration;
19
/// Lifetime, in days, given to generated certificates when the caller passes no explicit value.
const DEFAULT_CERT_VALIDITY_DAYS: u32 = 365;

/// Value written as the organization (`O=`) attribute of every generated certificate subject.
const CERT_ORG_NAME: &str = "spec-ai";

/// Leading part of the common name (`CN=`); the hostname is appended after a `-` separator.
const CERT_CN_PREFIX: &str = "spec-ai-server";
28
/// Certificate and private key material used to terminate TLS.
///
/// `Debug` is implemented by hand so the private key never ends up in logs.
#[derive(Clone)]
pub struct TlsConfig {
    /// DER encoding of the (leaf) certificate.
    pub certificate: Vec<u8>,
    /// DER encoding of the private key. Generated keys are PKCS#8; keys loaded from PEM keep
    /// whatever format the PEM used (PKCS#8, PKCS#1 or SEC1).
    pub private_key: Vec<u8>,
    /// SHA-256 fingerprint of `certificate`, as colon-separated uppercase hex pairs.
    pub fingerprint: String,
    /// PEM text of the certificate as generated or loaded.
    pub certificate_pem: String,
    /// Expiry timestamp in RFC 3339 form, or `"unknown"` when it could not be determined.
    pub not_after: String,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("certificate", &self.certificate)
            // Never print key material: only reveal how large it is.
            .field(
                "private_key",
                &format_args!("<redacted, {} bytes>", self.private_key.len()),
            )
            .field("fingerprint", &self.fingerprint)
            .field("certificate_pem", &self.certificate_pem)
            .field("not_after", &self.not_after)
            .finish()
    }
}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CertificateInfo {
47 pub fingerprint: String,
49 pub certificate_pem: String,
51 pub not_before: String,
53 pub not_after: String,
55 pub subject: String,
57 pub san: Vec<String>,
59}
60
61impl TlsConfig {
62 pub fn generate(
69 hostname: &str,
70 additional_sans: &[String],
71 validity_days: Option<u32>,
72 ) -> Result<Self> {
73 let validity = validity_days.unwrap_or(DEFAULT_CERT_VALIDITY_DAYS);
74
75 let key_pair = KeyPair::generate().context("Failed to generate key pair")?;
77
78 let mut params = CertificateParams::default();
80
81 params
83 .distinguished_name
84 .push(DnType::OrganizationName, CERT_ORG_NAME);
85 params.distinguished_name.push(
86 DnType::CommonName,
87 format!("{}-{}", CERT_CN_PREFIX, hostname),
88 );
89
90 let now = time::OffsetDateTime::now_utc();
92 params.not_before = now;
93 params.not_after = now + Duration::from_secs(validity as u64 * 24 * 60 * 60);
94
95 params.key_usages = vec![
97 KeyUsagePurpose::DigitalSignature,
98 KeyUsagePurpose::KeyEncipherment,
99 ];
100 params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
101
102 params.is_ca = IsCa::NoCa;
104
105 let mut sans = vec![SanType::DnsName(
107 hostname.try_into().context("Invalid hostname")?,
108 )];
109
110 if hostname != "localhost" {
112 if let Ok(localhost) = "localhost".try_into() {
113 sans.push(SanType::DnsName(localhost));
114 }
115 }
116
117 sans.push(SanType::IpAddress(std::net::IpAddr::V4(
119 std::net::Ipv4Addr::new(127, 0, 0, 1),
120 )));
121
122 for san in additional_sans {
124 if let Ok(ip) = san.parse::<std::net::IpAddr>() {
125 sans.push(SanType::IpAddress(ip));
126 } else if let Ok(dns) = san.as_str().try_into() {
127 sans.push(SanType::DnsName(dns));
128 }
129 }
130
131 params.subject_alt_names = sans;
132
133 let not_after_time = params.not_after;
135
136 let cert = params
138 .self_signed(&key_pair)
139 .context("Failed to generate self-signed certificate")?;
140
141 let cert_der = cert.der().to_vec();
143 let key_der = key_pair.serialize_der();
144
145 let fingerprint = Self::calculate_fingerprint(&cert_der);
147
148 let cert_pem = cert.pem();
150
151 let not_after = not_after_time
153 .format(&time::format_description::well_known::Rfc3339)
154 .unwrap_or_else(|_| "unknown".to_string());
155
156 tracing::info!(
157 "Generated self-signed TLS certificate for {} (fingerprint: {})",
158 hostname,
159 fingerprint
160 );
161
162 Ok(Self {
163 certificate: cert_der,
164 private_key: key_der,
165 fingerprint,
166 certificate_pem: cert_pem,
167 not_after,
168 })
169 }
170
171 pub fn load_from_files(cert_path: &Path, key_path: &Path) -> Result<Self> {
173 let cert_pem = std::fs::read_to_string(cert_path)
174 .with_context(|| format!("Failed to read certificate file: {}", cert_path.display()))?;
175
176 let key_pem = std::fs::read_to_string(key_path)
177 .with_context(|| format!("Failed to read key file: {}", key_path.display()))?;
178
179 Self::load_from_pem(&cert_pem, &key_pem)
180 }
181
182 pub fn load_from_pem(cert_pem: &str, key_pem: &str) -> Result<Self> {
184 let mut cert_reader = std::io::BufReader::new(cert_pem.as_bytes());
186 let certs: Vec<CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
187 .collect::<Result<Vec<_>, _>>()
188 .context("Failed to parse certificate PEM")?;
189
190 let cert_der = certs
191 .into_iter()
192 .next()
193 .context("No certificate found in PEM")?;
194
195 let mut key_reader = std::io::BufReader::new(key_pem.as_bytes());
197 let key_der = rustls_pemfile::private_key(&mut key_reader)
198 .context("Failed to parse private key PEM")?
199 .context("No private key found in PEM")?;
200
201 let fingerprint = Self::calculate_fingerprint(cert_der.as_ref());
202
203 Ok(Self {
204 certificate: cert_der.to_vec(),
205 private_key: match key_der {
206 PrivateKeyDer::Pkcs8(k) => k.secret_pkcs8_der().to_vec(),
207 PrivateKeyDer::Pkcs1(k) => k.secret_pkcs1_der().to_vec(),
208 PrivateKeyDer::Sec1(k) => k.secret_sec1_der().to_vec(),
209 _ => anyhow::bail!("Unsupported private key format"),
210 },
211 fingerprint,
212 certificate_pem: cert_pem.to_string(),
213 not_after: "unknown".to_string(), })
215 }
216
217 pub fn save_to_files(&self, cert_path: &Path, key_path: &Path) -> Result<()> {
219 if let Some(parent) = cert_path.parent() {
221 std::fs::create_dir_all(parent)?;
222 }
223 if let Some(parent) = key_path.parent() {
224 std::fs::create_dir_all(parent)?;
225 }
226
227 std::fs::write(cert_path, &self.certificate_pem)
229 .with_context(|| format!("Failed to write certificate to {}", cert_path.display()))?;
230
231 let key_pem = format!(
233 "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----\n",
234 STANDARD.encode(&self.private_key)
235 );
236 std::fs::write(key_path, &key_pem)
237 .with_context(|| format!("Failed to write private key to {}", key_path.display()))?;
238
239 #[cfg(unix)]
241 {
242 use std::os::unix::fs::PermissionsExt;
243 std::fs::set_permissions(key_path, std::fs::Permissions::from_mode(0o600))?;
244 }
245
246 tracing::info!(
247 "Saved TLS certificate to {} and key to {}",
248 cert_path.display(),
249 key_path.display()
250 );
251
252 Ok(())
253 }
254
255 pub fn calculate_fingerprint(cert_der: &[u8]) -> String {
257 let hash = digest(&SHA256, cert_der);
258 hash.as_ref()
259 .iter()
260 .map(|b| format!("{:02X}", b))
261 .collect::<Vec<_>>()
262 .join(":")
263 }
264
265 pub fn get_certificate_info(&self, hostname: &str) -> CertificateInfo {
267 CertificateInfo {
268 fingerprint: self.fingerprint.clone(),
269 certificate_pem: self.certificate_pem.clone(),
270 not_before: "see certificate".to_string(),
271 not_after: self.not_after.clone(),
272 subject: format!("CN={}-{}, O={}", CERT_CN_PREFIX, hostname, CERT_ORG_NAME),
273 san: vec![
274 hostname.to_string(),
275 "localhost".to_string(),
276 "127.0.0.1".to_string(),
277 ],
278 }
279 }
280
281 pub fn build_server_config(&self) -> Result<Arc<rustls::ServerConfig>> {
283 let cert = CertificateDer::from(self.certificate.clone());
284 let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(self.private_key.clone()));
285
286 let config = rustls::ServerConfig::builder_with_provider(Arc::new(
288 rustls::crypto::aws_lc_rs::default_provider(),
289 ))
290 .with_safe_default_protocol_versions()
291 .context("Failed to set protocol versions")?
292 .with_no_client_auth()
293 .with_single_cert(vec![cert], key)
294 .context("Failed to build TLS server config")?;
295
296 Ok(Arc::new(config))
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `needle` appears as a contiguous byte run inside `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn test_generate_certificate() {
        let config = TlsConfig::generate("test.local", &[], Some(30)).unwrap();

        assert!(!config.certificate.is_empty());
        assert!(!config.private_key.is_empty());
        assert!(!config.fingerprint.is_empty());
        assert!(config.certificate_pem.contains("BEGIN CERTIFICATE"));

        // SHA-256 yields 32 bytes, rendered as colon-separated hex pairs.
        assert!(config.fingerprint.contains(':'));
        let parts: Vec<&str> = config.fingerprint.split(':').collect();
        assert_eq!(parts.len(), 32);

        // The stored fingerprint must describe the stored certificate.
        assert_eq!(
            config.fingerprint,
            TlsConfig::calculate_fingerprint(&config.certificate)
        );
        assert_ne!(config.not_after, "unknown");
    }

    #[test]
    fn test_fingerprint_calculation() {
        let data = b"test certificate data";
        let fingerprint = TlsConfig::calculate_fingerprint(data);

        let parts: Vec<&str> = fingerprint.split(':').collect();
        assert_eq!(parts.len(), 32);

        for part in parts {
            assert_eq!(part.len(), 2);
            assert!(part.chars().all(|c| c.is_ascii_hexdigit()));
        }

        // Known-answer test: SHA-256("abc") from FIPS 180-2.
        assert_eq!(
            TlsConfig::calculate_fingerprint(b"abc"),
            "BA:78:16:BF:8F:01:CF:EA:41:41:40:DE:5D:AE:22:23:B0:03:61:A3:96:17:7A:9C:B4:10:FF:61:F2:00:15:AD"
        );
    }

    #[test]
    fn test_build_server_config() {
        let tls = TlsConfig::generate("localhost", &[], None).unwrap();
        let server_config = tls.build_server_config();
        assert!(server_config.is_ok());
    }

    #[test]
    fn test_additional_sans() {
        let additional = vec!["192.168.1.100".to_string(), "myserver.local".to_string()];
        let config = TlsConfig::generate("primary.local", &additional, None).unwrap();

        assert!(!config.certificate.is_empty());
        // dNSName SANs are stored verbatim as IA5 strings; an iPAddress SAN is the context tag
        // [7] (0x87), length 4, followed by the raw address octets.
        assert!(contains_bytes(&config.certificate, b"myserver.local"));
        assert!(contains_bytes(
            &config.certificate,
            &[0x87, 0x04, 192, 168, 1, 100]
        ));
        assert!(config.build_server_config().is_ok());
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = tempfile::tempdir().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        let original = TlsConfig::generate("test.local", &[], None).unwrap();
        original.save_to_files(&cert_path, &key_path).unwrap();

        let loaded = TlsConfig::load_from_files(&cert_path, &key_path).unwrap();

        assert_eq!(original.certificate, loaded.certificate);
        assert_eq!(original.fingerprint, loaded.fingerprint);
        assert_eq!(original.private_key, loaded.private_key);
        assert!(loaded.build_server_config().is_ok());

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&key_path).unwrap().permissions().mode();
            assert_eq!(mode & 0o777, 0o600, "private key must be owner-only");
        }
    }

    #[test]
    fn test_load_from_pem_rejects_missing_material() {
        assert!(TlsConfig::load_from_pem("", "").is_err());

        let generated = TlsConfig::generate("test.local", &[], None).unwrap();
        assert!(TlsConfig::load_from_pem(&generated.certificate_pem, "").is_err());
    }
}