Skip to main content

spec_ai/spec_ai_api/api/
server.rs

//! HTTP server implementation with mandatory TLS.
2use crate::spec_ai_api::api::graph_handlers::{
3    bootstrap_graph, create_edge, create_node, delete_edge, delete_node, get_edge, get_node,
4    list_edges, list_nodes, stream_changelog, update_node,
5};
6use crate::spec_ai_api::api::handlers::{
7    generate_token, hash_password, health_check, list_agents, query, search, stream_query, AppState,
8};
9use crate::spec_ai_api::api::mesh::{
10    acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
11    register_instance, send_message, MeshClient,
12};
13use crate::spec_ai_api::api::middleware::auth_middleware;
14use crate::spec_ai_api::api::sync_handlers::{
15    bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
16    list_conflicts, list_sync_configs, toggle_sync,
17};
18use crate::spec_ai_api::api::tls::TlsConfig;
19use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
20use crate::spec_ai_api::persistence::Persistence;
21use crate::spec_ai_api::sync::{start_sync_coordinator, SyncCoordinatorConfig};
22use crate::spec_ai_api::tools::ToolRegistry;
23use anyhow::{Context, Result};
24use axum::{
25    middleware,
26    routing::{delete, get, post, put},
27    Json, Router,
28};
29use axum_server::tls_rustls::RustlsConfig;
30use std::net::SocketAddr;
31use std::path::PathBuf;
32use std::sync::Arc;
33use tower_http::cors::{Any, CorsLayer};
34use tower_http::trace::TraceLayer;
35
36/// Install the rustls crypto provider (call once at startup)
37fn install_crypto_provider() {
38    // Install aws-lc-rs as the default crypto provider for rustls
39    // This is required because rustls 0.23+ doesn't auto-select a provider
40    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
41}
42
43/// API server configuration
44#[derive(Debug, Clone)]
45pub struct ApiConfig {
46    /// Server host address
47    pub host: String,
48    /// Server port
49    pub port: u16,
50    /// Optional API key for authentication (legacy, prefer token auth)
51    pub api_key: Option<String>,
52    /// Enable CORS
53    pub enable_cors: bool,
54    /// Path to TLS certificate file (PEM format)
55    /// If not provided, a self-signed certificate is generated
56    pub tls_cert_path: Option<PathBuf>,
57    /// Path to TLS private key file (PEM format)
58    pub tls_key_path: Option<PathBuf>,
59    /// Additional Subject Alternative Names for generated certificate
60    pub tls_san: Vec<String>,
61    /// Certificate validity in days (for generated certs)
62    pub tls_validity_days: u32,
63}
64
65impl Default for ApiConfig {
66    fn default() -> Self {
67        Self {
68            host: "127.0.0.1".to_string(),
69            port: 3000,
70            api_key: None,
71            enable_cors: true,
72            tls_cert_path: None,
73            tls_key_path: None,
74            tls_san: Vec::new(),
75            tls_validity_days: 365,
76        }
77    }
78}
79
80impl ApiConfig {
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    pub fn with_host(mut self, host: impl Into<String>) -> Self {
86        self.host = host.into();
87        self
88    }
89
90    pub fn with_port(mut self, port: u16) -> Self {
91        self.port = port;
92        self
93    }
94
95    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
96        self.api_key = Some(api_key.into());
97        self
98    }
99
100    pub fn with_cors(mut self, enable: bool) -> Self {
101        self.enable_cors = enable;
102        self
103    }
104
105    pub fn with_tls_cert(
106        mut self,
107        cert_path: impl Into<PathBuf>,
108        key_path: impl Into<PathBuf>,
109    ) -> Self {
110        self.tls_cert_path = Some(cert_path.into());
111        self.tls_key_path = Some(key_path.into());
112        self
113    }
114
115    pub fn with_tls_san(mut self, san: Vec<String>) -> Self {
116        self.tls_san = san;
117        self
118    }
119
120    pub fn with_tls_validity(mut self, days: u32) -> Self {
121        self.tls_validity_days = days;
122        self
123    }
124
125    pub fn bind_address(&self) -> String {
126        format!("{}:{}", self.host, self.port)
127    }
128}
129
130/// API server with mandatory TLS
131pub struct ApiServer {
132    config: ApiConfig,
133    state: AppState,
134    tls_config: TlsConfig,
135}
136
137impl ApiServer {
138    /// Create a new API server with TLS
139    ///
140    /// If no certificate is provided in config, a self-signed certificate is generated.
141    pub fn new(
142        config: ApiConfig,
143        persistence: Persistence,
144        agent_registry: Arc<AgentRegistry>,
145        tool_registry: Arc<ToolRegistry>,
146        app_config: AppConfig,
147    ) -> Result<Self> {
148        // Install crypto provider for rustls (idempotent, safe to call multiple times)
149        install_crypto_provider();
150
151        let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
152
153        // Initialize TLS - either load from files or generate self-signed
154        let tls_config = if let (Some(cert_path), Some(key_path)) =
155            (&config.tls_cert_path, &config.tls_key_path)
156        {
157            TlsConfig::load_from_files(cert_path, key_path)
158                .context("Failed to load TLS certificate")?
159        } else {
160            let tls = TlsConfig::generate(
161                &config.host,
162                &config.tls_san,
163                Some(config.tls_validity_days),
164            )
165            .context("Failed to generate TLS certificate")?;
166
167            // Save generated cert for potential reuse
168            let cert_dir = dirs_next::home_dir()
169                .unwrap_or_else(|| PathBuf::from("."))
170                .join(".spec-ai")
171                .join("tls");
172            let cert_path = cert_dir.join("server.crt");
173            let key_path = cert_dir.join("server.key");
174
175            if let Err(e) = tls.save_to_files(&cert_path, &key_path) {
176                tracing::warn!("Could not save generated TLS certificate: {}", e);
177            } else {
178                tracing::info!(
179                    "Saved TLS certificate to {} (fingerprint: {})",
180                    cert_path.display(),
181                    tls.fingerprint
182                );
183            }
184
185            tls
186        };
187
188        tracing::info!(
189            "TLS initialized with certificate fingerprint: {}",
190            tls_config.fingerprint
191        );
192
193        Ok(Self {
194            config,
195            state,
196            tls_config,
197        })
198    }
199
200    /// Get the mesh registry for self-registration
201    pub fn mesh_registry(&self) -> &crate::spec_ai_api::api::mesh::MeshRegistry {
202        &self.state.mesh_registry
203    }
204
205    /// Get the TLS configuration (for certificate info)
206    pub fn tls_config(&self) -> &TlsConfig {
207        &self.tls_config
208    }
209
210    /// Get the certificate fingerprint
211    pub fn certificate_fingerprint(&self) -> &str {
212        &self.tls_config.fingerprint
213    }
214
215    /// Build the router with all routes
216    fn build_router(&self) -> Router {
217        // Create certificate info for the endpoint
218        let cert_info = self.tls_config.get_certificate_info(&self.config.host);
219
220        // Public routes that don't require authentication
221        let public_routes = Router::new()
222            // Health endpoint is always public
223            .route("/health", get(health_check))
224            // Certificate info endpoint - clients can use this to get/verify the fingerprint
225            .route("/cert", get(move || async move { Json(cert_info.clone()) }))
226            // Auth endpoints are public (needed to get tokens)
227            .route("/auth/token", post(generate_token))
228            .route("/auth/hash", post(hash_password));
229
230        // Protected routes that require authentication when enabled
231        let protected_routes = Router::new()
232            // Info endpoints
233            .route("/agents", get(list_agents))
234            // Query endpoints
235            .route("/query", post(query))
236            .route("/stream", post(stream_query))
237            // Search endpoint
238            .route("/api/search", post(search))
239            // Mesh registry endpoints
240            .route("/registry/register", post(register_instance::<AppState>))
241            .route("/registry/agents", get(list_instances::<AppState>))
242            .route(
243                "/registry/heartbeat/{instance_id}",
244                post(heartbeat::<AppState>),
245            )
246            .route(
247                "/registry/deregister/{instance_id}",
248                delete(deregister_instance::<AppState>),
249            )
250            // Message routing endpoints
251            .route(
252                "/messages/send/{source_instance}",
253                post(send_message::<AppState>),
254            )
255            .route("/messages/{instance_id}", get(get_messages::<AppState>))
256            .route(
257                "/messages/ack/{instance_id}",
258                post(acknowledge_messages::<AppState>),
259            )
260            // Graph sync endpoints
261            .route("/sync/request", post(handle_sync_request))
262            .route("/sync/apply", post(handle_sync_apply))
263            .route(
264                "/sync/status/{session_id}/{graph_name}",
265                get(get_sync_status),
266            )
267            .route("/sync/enable/{session_id}/{graph_name}", post(toggle_sync))
268            .route("/sync/configs/{session_id}", get(list_sync_configs))
269            .route("/sync/bulk/{session_id}", post(bulk_toggle_sync))
270            .route(
271                "/sync/configure/{session_id}/{graph_name}",
272                post(configure_sync),
273            )
274            .route("/sync/conflicts", get(list_conflicts))
275            // Graph CRUD endpoints
276            .route("/graph/nodes", get(list_nodes))
277            .route("/graph/nodes", post(create_node))
278            .route("/graph/nodes/{node_id}", get(get_node))
279            .route("/graph/nodes/{node_id}", put(update_node))
280            .route("/graph/nodes/{node_id}", delete(delete_node))
281            .route("/graph/edges", get(list_edges))
282            .route("/graph/edges", post(create_edge))
283            .route("/graph/edges/{edge_id}", get(get_edge))
284            .route("/graph/edges/{edge_id}", delete(delete_edge))
285            .route("/graph/stream", get(stream_changelog))
286            // Bootstrap endpoint
287            .route("/bootstrap", post(bootstrap_graph))
288            // Apply auth middleware to protected routes
289            .layer(middleware::from_fn_with_state(
290                self.state.auth_service.clone(),
291                auth_middleware,
292            ));
293
294        // Merge public and protected routes
295        let mut router = Router::new()
296            .merge(public_routes)
297            .merge(protected_routes)
298            .with_state(self.state.clone());
299
300        // Add CORS if enabled
301        if self.config.enable_cors {
302            let cors = CorsLayer::new()
303                .allow_origin(Any)
304                .allow_methods(Any)
305                .allow_headers(Any);
306            router = router.layer(cors);
307        }
308
309        // Add tracing
310        router = router.layer(TraceLayer::new_for_http());
311
312        router
313    }
314
315    /// Run the server with TLS
316    pub async fn run(self) -> Result<()> {
317        // Start sync coordinator if sync is enabled
318        if self.state.config.sync.enabled {
319            self.start_sync_coordinator_background();
320        }
321
322        let app = self.build_router();
323        let bind_addr: SocketAddr = self
324            .config
325            .bind_address()
326            .parse()
327            .context("Invalid bind address")?;
328
329        // Build rustls config
330        let rustls_config = RustlsConfig::from_der(
331            vec![self.tls_config.certificate.clone()],
332            self.tls_config.private_key.clone(),
333        )
334        .await
335        .context("Failed to create TLS config")?;
336
337        tracing::info!(
338            "Starting HTTPS server on {} (fingerprint: {})",
339            bind_addr,
340            self.tls_config.fingerprint
341        );
342
343        axum_server::bind_rustls(bind_addr, rustls_config)
344            .serve(app.into_make_service())
345            .await
346            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
347
348        Ok(())
349    }
350
351    /// Start the sync coordinator as a background task
352    fn start_sync_coordinator_background(&self) {
353        let persistence = Arc::new(self.state.persistence.clone());
354        let mesh_registry = Arc::new(self.state.mesh_registry.clone());
355        let mesh_client = Arc::new(MeshClient::new("localhost", self.config.port));
356        let sync_config = SyncCoordinatorConfig::from(&self.state.config.sync);
357
358        // Apply configured namespaces
359        for ns in &self.state.config.sync.namespaces {
360            if let Err(e) =
361                self.state
362                    .persistence
363                    .graph_set_sync_enabled(&ns.session_id, &ns.graph_name, true)
364            {
365                tracing::warn!(
366                    "Failed to enable sync for {}/{}: {}",
367                    ns.session_id,
368                    ns.graph_name,
369                    e
370                );
371            }
372        }
373
374        // Spawn the sync coordinator
375        tokio::spawn(async move {
376            let _handle =
377                start_sync_coordinator(persistence, mesh_registry, mesh_client, sync_config).await;
378            // The coordinator runs indefinitely
379        });
380
381        tracing::info!(
382            "Started sync coordinator with {} configured namespaces",
383            self.state.config.sync.namespaces.len()
384        );
385    }
386
387    /// Run the server with TLS and graceful shutdown
388    pub async fn run_with_shutdown(
389        self,
390        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
391    ) -> Result<()> {
392        // Start sync coordinator if sync is enabled
393        if self.state.config.sync.enabled {
394            self.start_sync_coordinator_background();
395        }
396
397        let app = self.build_router();
398        let bind_addr: SocketAddr = self
399            .config
400            .bind_address()
401            .parse()
402            .context("Invalid bind address")?;
403
404        // Build rustls config
405        let rustls_config = RustlsConfig::from_der(
406            vec![self.tls_config.certificate.clone()],
407            self.tls_config.private_key.clone(),
408        )
409        .await
410        .context("Failed to create TLS config")?;
411
412        tracing::info!(
413            "Starting HTTPS server on {} (fingerprint: {})",
414            bind_addr,
415            self.tls_config.fingerprint
416        );
417
418        // Create handle for graceful shutdown
419        let handle = axum_server::Handle::new();
420        let handle_clone = handle.clone();
421
422        // Spawn shutdown listener
423        tokio::spawn(async move {
424            shutdown_signal.await;
425            handle_clone.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
426        });
427
428        axum_server::bind_rustls(bind_addr, rustls_config)
429            .handle(handle)
430            .serve(app.into_make_service())
431            .await
432            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
433
434        Ok(())
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults cover every field, including the TLS-related ones.
    #[test]
    fn test_api_config_default() {
        let config = ApiConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert!(config.api_key.is_none());
        assert!(config.enable_cors);
        assert!(config.tls_cert_path.is_none());
        assert!(config.tls_key_path.is_none());
        assert!(config.tls_san.is_empty());
        assert_eq!(config.tls_validity_days, 365);
    }

    /// `ApiConfig::new` is documented as the default configuration.
    #[test]
    fn test_api_config_new_matches_default() {
        assert_eq!(
            format!("{:?}", ApiConfig::new()),
            format!("{:?}", ApiConfig::default())
        );
    }

    #[test]
    fn test_api_config_builder() {
        let config = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.api_key, Some("secret123".to_string()));
        assert!(!config.enable_cors);
    }

    /// The TLS builder methods set the certificate/key pair, SANs and validity.
    #[test]
    fn test_api_config_tls_builder() {
        let config = ApiConfig::new()
            .with_tls_cert("/etc/spec-ai/server.crt", "/etc/spec-ai/server.key")
            .with_tls_san(vec!["api.example.com".to_string(), "10.0.0.5".to_string()])
            .with_tls_validity(30);

        assert_eq!(
            config.tls_cert_path,
            Some(std::path::PathBuf::from("/etc/spec-ai/server.crt"))
        );
        assert_eq!(
            config.tls_key_path,
            Some(std::path::PathBuf::from("/etc/spec-ai/server.key"))
        );
        assert_eq!(
            config.tls_san,
            vec!["api.example.com".to_string(), "10.0.0.5".to_string()]
        );
        assert_eq!(config.tls_validity_days, 30);
    }

    #[test]
    fn test_bind_address() {
        let config = ApiConfig::new().with_host("localhost").with_port(5000);
        assert_eq!(config.bind_address(), "localhost:5000");

        let ipv4 = ApiConfig::new().with_host("192.168.1.10").with_port(443);
        assert_eq!(ipv4.bind_address(), "192.168.1.10:443");
    }
}