Skip to main content

spec_ai/spec_ai_api/api/
server.rs

//! HTTP server implementation with mandatory TLS.
2use crate::spec_ai_api::api::graph_handlers::{
3    bootstrap_graph, create_edge, create_node, delete_edge, delete_node, get_edge, get_node,
4    list_edges, list_nodes, stream_changelog, update_node,
5};
6use crate::spec_ai_api::api::handlers::{
7    AppState, generate_token, hash_password, health_check, list_agents, query, search, stream_query,
8};
9use crate::spec_ai_api::api::mesh::{
10    MeshClient, acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
11    register_instance, send_message,
12};
13use crate::spec_ai_api::api::middleware::auth_middleware;
14use crate::spec_ai_api::api::sync_handlers::{
15    bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
16    list_conflicts, list_sync_configs, toggle_sync,
17};
18use crate::spec_ai_api::api::tls::TlsConfig;
19use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
20use crate::spec_ai_api::persistence::Persistence;
21use crate::spec_ai_api::sync::{SyncCoordinatorConfig, start_sync_coordinator};
22use crate::spec_ai_api::tools::ToolRegistry;
23use anyhow::{Context, Result};
24use axum::{
25    Json, Router, middleware,
26    routing::{delete, get, post, put},
27};
28use axum_server::tls_rustls::RustlsConfig;
29use std::net::SocketAddr;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tower_http::cors::{Any, CorsLayer};
33use tower_http::trace::TraceLayer;
34
35/// Install the rustls crypto provider (call once at startup)
36fn install_crypto_provider() {
37    // Install aws-lc-rs as the default crypto provider for rustls
38    // This is required because rustls 0.23+ doesn't auto-select a provider
39    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
40}
41
/// API server configuration.
///
/// Build one with [`ApiConfig::new`] (or [`Default::default`]) and the `with_*` builder
/// methods. `Debug` is implemented by hand so that the optional `api_key` secret is never
/// printed, e.g. when the configuration is logged with `{:?}`.
#[derive(Clone)]
pub struct ApiConfig {
    /// Server host address
    pub host: String,
    /// Server port
    pub port: u16,
    /// Optional API key for authentication (legacy, prefer token auth)
    pub api_key: Option<String>,
    /// Enable CORS
    pub enable_cors: bool,
    /// Path to TLS certificate file (PEM format).
    /// If neither this nor `tls_key_path` is provided, a self-signed certificate is generated.
    pub tls_cert_path: Option<PathBuf>,
    /// Path to TLS private key file (PEM format)
    pub tls_key_path: Option<PathBuf>,
    /// Additional Subject Alternative Names for generated certificate
    pub tls_san: Vec<String>,
    /// Certificate validity in days (for generated certs)
    pub tls_validity_days: u32,
}

impl std::fmt::Debug for ApiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            // Reveal only whether a key is configured, never the key itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("enable_cors", &self.enable_cors)
            .field("tls_cert_path", &self.tls_cert_path)
            .field("tls_key_path", &self.tls_key_path)
            .field("tls_san", &self.tls_san)
            .field("tls_validity_days", &self.tls_validity_days)
            .finish()
    }
}
63
64impl Default for ApiConfig {
65    fn default() -> Self {
66        Self {
67            host: "127.0.0.1".to_string(),
68            port: 3000,
69            api_key: None,
70            enable_cors: true,
71            tls_cert_path: None,
72            tls_key_path: None,
73            tls_san: Vec::new(),
74            tls_validity_days: 365,
75        }
76    }
77}
78
79impl ApiConfig {
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    pub fn with_host(mut self, host: impl Into<String>) -> Self {
85        self.host = host.into();
86        self
87    }
88
89    pub fn with_port(mut self, port: u16) -> Self {
90        self.port = port;
91        self
92    }
93
94    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
95        self.api_key = Some(api_key.into());
96        self
97    }
98
99    pub fn with_cors(mut self, enable: bool) -> Self {
100        self.enable_cors = enable;
101        self
102    }
103
104    pub fn with_tls_cert(
105        mut self,
106        cert_path: impl Into<PathBuf>,
107        key_path: impl Into<PathBuf>,
108    ) -> Self {
109        self.tls_cert_path = Some(cert_path.into());
110        self.tls_key_path = Some(key_path.into());
111        self
112    }
113
114    pub fn with_tls_san(mut self, san: Vec<String>) -> Self {
115        self.tls_san = san;
116        self
117    }
118
119    pub fn with_tls_validity(mut self, days: u32) -> Self {
120        self.tls_validity_days = days;
121        self
122    }
123
124    pub fn bind_address(&self) -> String {
125        format!("{}:{}", self.host, self.port)
126    }
127}
128
129/// API server with mandatory TLS
130pub struct ApiServer {
131    config: ApiConfig,
132    state: AppState,
133    tls_config: TlsConfig,
134}
135
136impl ApiServer {
137    /// Create a new API server with TLS
138    ///
139    /// If no certificate is provided in config, a self-signed certificate is generated.
140    pub fn new(
141        config: ApiConfig,
142        persistence: Persistence,
143        agent_registry: Arc<AgentRegistry>,
144        tool_registry: Arc<ToolRegistry>,
145        app_config: AppConfig,
146    ) -> Result<Self> {
147        // Install crypto provider for rustls (idempotent, safe to call multiple times)
148        install_crypto_provider();
149
150        let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
151
152        // Initialize TLS - either load from files or generate self-signed
153        let tls_config = if let (Some(cert_path), Some(key_path)) =
154            (&config.tls_cert_path, &config.tls_key_path)
155        {
156            TlsConfig::load_from_files(cert_path, key_path)
157                .context("Failed to load TLS certificate")?
158        } else {
159            let tls = TlsConfig::generate(
160                &config.host,
161                &config.tls_san,
162                Some(config.tls_validity_days),
163            )
164            .context("Failed to generate TLS certificate")?;
165
166            // Save generated cert for potential reuse
167            let cert_dir = dirs_next::home_dir()
168                .unwrap_or_else(|| PathBuf::from("."))
169                .join(".spec-ai")
170                .join("tls");
171            let cert_path = cert_dir.join("server.crt");
172            let key_path = cert_dir.join("server.key");
173
174            if let Err(e) = tls.save_to_files(&cert_path, &key_path) {
175                tracing::warn!("Could not save generated TLS certificate: {}", e);
176            } else {
177                tracing::info!(
178                    "Saved TLS certificate to {} (fingerprint: {})",
179                    cert_path.display(),
180                    tls.fingerprint
181                );
182            }
183
184            tls
185        };
186
187        tracing::info!(
188            "TLS initialized with certificate fingerprint: {}",
189            tls_config.fingerprint
190        );
191
192        Ok(Self {
193            config,
194            state,
195            tls_config,
196        })
197    }
198
199    /// Get the mesh registry for self-registration
200    pub fn mesh_registry(&self) -> &crate::spec_ai_api::api::mesh::MeshRegistry {
201        &self.state.mesh_registry
202    }
203
204    /// Get the TLS configuration (for certificate info)
205    pub fn tls_config(&self) -> &TlsConfig {
206        &self.tls_config
207    }
208
209    /// Get the certificate fingerprint
210    pub fn certificate_fingerprint(&self) -> &str {
211        &self.tls_config.fingerprint
212    }
213
214    /// Build the router with all routes
215    fn build_router(&self) -> Router {
216        // Create certificate info for the endpoint
217        let cert_info = self.tls_config.get_certificate_info(&self.config.host);
218
219        // Public routes that don't require authentication
220        let public_routes = Router::new()
221            // Health endpoint is always public
222            .route("/health", get(health_check))
223            // Certificate info endpoint - clients can use this to get/verify the fingerprint
224            .route("/cert", get(move || async move { Json(cert_info.clone()) }))
225            // Auth endpoints are public (needed to get tokens)
226            .route("/auth/token", post(generate_token))
227            .route("/auth/hash", post(hash_password));
228
229        // Protected routes that require authentication when enabled
230        let protected_routes = Router::new()
231            // Info endpoints
232            .route("/agents", get(list_agents))
233            // Query endpoints
234            .route("/query", post(query))
235            .route("/stream", post(stream_query))
236            // Search endpoint
237            .route("/api/search", post(search))
238            // Mesh registry endpoints
239            .route("/registry/register", post(register_instance::<AppState>))
240            .route("/registry/agents", get(list_instances::<AppState>))
241            .route(
242                "/registry/heartbeat/{instance_id}",
243                post(heartbeat::<AppState>),
244            )
245            .route(
246                "/registry/deregister/{instance_id}",
247                delete(deregister_instance::<AppState>),
248            )
249            // Message routing endpoints
250            .route(
251                "/messages/send/{source_instance}",
252                post(send_message::<AppState>),
253            )
254            .route("/messages/{instance_id}", get(get_messages::<AppState>))
255            .route(
256                "/messages/ack/{instance_id}",
257                post(acknowledge_messages::<AppState>),
258            )
259            // Graph sync endpoints
260            .route("/sync/request", post(handle_sync_request))
261            .route("/sync/apply", post(handle_sync_apply))
262            .route(
263                "/sync/status/{session_id}/{graph_name}",
264                get(get_sync_status),
265            )
266            .route("/sync/enable/{session_id}/{graph_name}", post(toggle_sync))
267            .route("/sync/configs/{session_id}", get(list_sync_configs))
268            .route("/sync/bulk/{session_id}", post(bulk_toggle_sync))
269            .route(
270                "/sync/configure/{session_id}/{graph_name}",
271                post(configure_sync),
272            )
273            .route("/sync/conflicts", get(list_conflicts))
274            // Graph CRUD endpoints
275            .route("/graph/nodes", get(list_nodes))
276            .route("/graph/nodes", post(create_node))
277            .route("/graph/nodes/{node_id}", get(get_node))
278            .route("/graph/nodes/{node_id}", put(update_node))
279            .route("/graph/nodes/{node_id}", delete(delete_node))
280            .route("/graph/edges", get(list_edges))
281            .route("/graph/edges", post(create_edge))
282            .route("/graph/edges/{edge_id}", get(get_edge))
283            .route("/graph/edges/{edge_id}", delete(delete_edge))
284            .route("/graph/stream", get(stream_changelog))
285            // Bootstrap endpoint
286            .route("/bootstrap", post(bootstrap_graph))
287            // Apply auth middleware to protected routes
288            .layer(middleware::from_fn_with_state(
289                self.state.auth_service.clone(),
290                auth_middleware,
291            ));
292
293        // Merge public and protected routes
294        let mut router = Router::new()
295            .merge(public_routes)
296            .merge(protected_routes)
297            .with_state(self.state.clone());
298
299        // Add CORS if enabled
300        if self.config.enable_cors {
301            let cors = CorsLayer::new()
302                .allow_origin(Any)
303                .allow_methods(Any)
304                .allow_headers(Any);
305            router = router.layer(cors);
306        }
307
308        // Add tracing
309        router = router.layer(TraceLayer::new_for_http());
310
311        router
312    }
313
314    /// Run the server with TLS
315    pub async fn run(self) -> Result<()> {
316        // Start sync coordinator if sync is enabled
317        if self.state.config.sync.enabled {
318            self.start_sync_coordinator_background();
319        }
320
321        let app = self.build_router();
322        let bind_addr: SocketAddr = self
323            .config
324            .bind_address()
325            .parse()
326            .context("Invalid bind address")?;
327
328        // Build rustls config
329        let rustls_config = RustlsConfig::from_der(
330            vec![self.tls_config.certificate.clone()],
331            self.tls_config.private_key.clone(),
332        )
333        .await
334        .context("Failed to create TLS config")?;
335
336        tracing::info!(
337            "Starting HTTPS server on {} (fingerprint: {})",
338            bind_addr,
339            self.tls_config.fingerprint
340        );
341
342        axum_server::bind_rustls(bind_addr, rustls_config)
343            .serve(app.into_make_service())
344            .await
345            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
346
347        Ok(())
348    }
349
350    /// Start the sync coordinator as a background task
351    fn start_sync_coordinator_background(&self) {
352        let persistence = Arc::new(self.state.persistence.clone());
353        let mesh_registry = Arc::new(self.state.mesh_registry.clone());
354        let mesh_client = Arc::new(MeshClient::new("localhost", self.config.port));
355        let sync_config = SyncCoordinatorConfig::from(&self.state.config.sync);
356
357        // Apply configured namespaces
358        for ns in &self.state.config.sync.namespaces {
359            if let Err(e) =
360                self.state
361                    .persistence
362                    .graph_set_sync_enabled(&ns.session_id, &ns.graph_name, true)
363            {
364                tracing::warn!(
365                    "Failed to enable sync for {}/{}: {}",
366                    ns.session_id,
367                    ns.graph_name,
368                    e
369                );
370            }
371        }
372
373        // Spawn the sync coordinator
374        tokio::spawn(async move {
375            let _handle =
376                start_sync_coordinator(persistence, mesh_registry, mesh_client, sync_config).await;
377            // The coordinator runs indefinitely
378        });
379
380        tracing::info!(
381            "Started sync coordinator with {} configured namespaces",
382            self.state.config.sync.namespaces.len()
383        );
384    }
385
386    /// Run the server with TLS and graceful shutdown
387    pub async fn run_with_shutdown(
388        self,
389        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
390    ) -> Result<()> {
391        // Start sync coordinator if sync is enabled
392        if self.state.config.sync.enabled {
393            self.start_sync_coordinator_background();
394        }
395
396        let app = self.build_router();
397        let bind_addr: SocketAddr = self
398            .config
399            .bind_address()
400            .parse()
401            .context("Invalid bind address")?;
402
403        // Build rustls config
404        let rustls_config = RustlsConfig::from_der(
405            vec![self.tls_config.certificate.clone()],
406            self.tls_config.private_key.clone(),
407        )
408        .await
409        .context("Failed to create TLS config")?;
410
411        tracing::info!(
412            "Starting HTTPS server on {} (fingerprint: {})",
413            bind_addr,
414            self.tls_config.fingerprint
415        );
416
417        // Create handle for graceful shutdown
418        let handle = axum_server::Handle::new();
419        let handle_clone = handle.clone();
420
421        // Spawn shutdown listener
422        tokio::spawn(async move {
423            shutdown_signal.await;
424            handle_clone.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
425        });
426
427        axum_server::bind_rustls(bind_addr, rustls_config)
428            .handle(handle)
429            .serve(app.into_make_service())
430            .await
431            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
432
433        Ok(())
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults bind to loopback:3000 with no API key and CORS switched on.
    #[test]
    fn test_api_config_default() {
        let ApiConfig {
            host,
            port,
            api_key,
            enable_cors,
            ..
        } = ApiConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 3000);
        assert_eq!(api_key, None);
        assert!(enable_cors);
    }

    /// Every builder call is reflected in the resulting configuration.
    #[test]
    fn test_api_config_builder() {
        let built = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!((built.host.as_str(), built.port), ("0.0.0.0", 8080));
        assert_eq!(built.api_key.as_deref(), Some("secret123"));
        assert!(!built.enable_cors);
    }

    /// Host and port are joined with a colon.
    #[test]
    fn test_bind_address() {
        let address = ApiConfig::new()
            .with_port(5000)
            .with_host("localhost")
            .bind_address();

        assert_eq!(address, "localhost:5000");
    }
}