1use crate::spec_ai_api::api::graph_handlers::{
3 bootstrap_graph, create_edge, create_node, delete_edge, delete_node, get_edge, get_node,
4 list_edges, list_nodes, stream_changelog, update_node,
5};
6use crate::spec_ai_api::api::handlers::{
7 AppState, generate_token, hash_password, health_check, list_agents, query, search, stream_query,
8};
9use crate::spec_ai_api::api::mesh::{
10 MeshClient, acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
11 register_instance, send_message,
12};
13use crate::spec_ai_api::api::middleware::auth_middleware;
14use crate::spec_ai_api::api::sync_handlers::{
15 bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
16 list_conflicts, list_sync_configs, toggle_sync,
17};
18use crate::spec_ai_api::api::tls::TlsConfig;
19use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
20use crate::spec_ai_api::persistence::Persistence;
21use crate::spec_ai_api::sync::{SyncCoordinatorConfig, start_sync_coordinator};
22use crate::spec_ai_api::tools::ToolRegistry;
23use anyhow::{Context, Result};
24use axum::{
25 Json, Router, middleware,
26 routing::{delete, get, post, put},
27};
28use axum_server::tls_rustls::RustlsConfig;
29use std::net::SocketAddr;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tower_http::cors::{Any, CorsLayer};
33use tower_http::trace::TraceLayer;
34
35fn install_crypto_provider() {
37 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
40}
41
/// Listener and TLS settings for the API server.
///
/// Build one with [`ApiConfig::new`] (or [`Default`]) and refine it with the
/// chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    /// Host name or IP address to listen on.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Optional API key.
    ///
    /// NOTE(review): nothing in this file reads this field; presumably it is
    /// consumed by the caller or the auth layer — confirm.
    pub api_key: Option<String>,
    /// Whether the router gets a permissive CORS layer.
    pub enable_cors: bool,
    /// Path of an existing TLS certificate. Only used together with
    /// `tls_key_path`.
    pub tls_cert_path: Option<PathBuf>,
    /// Path of the private key matching `tls_cert_path`.
    pub tls_key_path: Option<PathBuf>,
    /// Extra subject alternative names handed to certificate generation.
    pub tls_san: Vec<String>,
    /// Validity period, in days, handed to certificate generation.
    pub tls_validity_days: u32,
}

impl Default for ApiConfig {
    /// Loopback listener on port 3000 with CORS enabled, no API key, no
    /// certificate files, no extra SANs and a 365-day certificate validity.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 3000,
            api_key: None,
            enable_cors: true,
            tls_cert_path: None,
            tls_key_path: None,
            tls_san: Vec::new(),
            tls_validity_days: 365,
        }
    }
}

impl ApiConfig {
    /// Returns the default configuration; identical to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the host to listen on.
    pub fn with_host(self, host: impl Into<String>) -> Self {
        Self {
            host: host.into(),
            ..self
        }
    }

    /// Replaces the port to listen on.
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Sets the API key.
    pub fn with_api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            ..self
        }
    }

    /// Turns the CORS layer on or off.
    pub fn with_cors(self, enable: bool) -> Self {
        Self {
            enable_cors: enable,
            ..self
        }
    }

    /// Sets the certificate and private-key file paths together.
    pub fn with_tls_cert(
        self,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> Self {
        Self {
            tls_cert_path: Some(cert_path.into()),
            tls_key_path: Some(key_path.into()),
            ..self
        }
    }

    /// Replaces the list of extra subject alternative names.
    pub fn with_tls_san(self, san: Vec<String>) -> Self {
        Self {
            tls_san: san,
            ..self
        }
    }

    /// Sets the validity period (in days) for a generated certificate.
    pub fn with_tls_validity(self, days: u32) -> Self {
        Self {
            tls_validity_days: days,
            ..self
        }
    }

    /// Returns the `host:port` string for this configuration.
    ///
    /// A bare IPv6 literal host (for example `::1`) is wrapped in square
    /// brackets, giving `[::1]:3000`. Without the brackets the colons of the
    /// address and the port separator are ambiguous and the result cannot be
    /// parsed as a socket address. Every other host, including one that is
    /// already bracketed, is emitted verbatim.
    pub fn bind_address(&self) -> String {
        match self.host.parse::<std::net::Ipv6Addr>() {
            Ok(_) => format!("[{}]:{}", self.host, self.port),
            Err(_) => format!("{}:{}", self.host, self.port),
        }
    }
}
128
129pub struct ApiServer {
131 config: ApiConfig,
132 state: AppState,
133 tls_config: TlsConfig,
134}
135
136impl ApiServer {
137 pub fn new(
141 config: ApiConfig,
142 persistence: Persistence,
143 agent_registry: Arc<AgentRegistry>,
144 tool_registry: Arc<ToolRegistry>,
145 app_config: AppConfig,
146 ) -> Result<Self> {
147 install_crypto_provider();
149
150 let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
151
152 let tls_config = if let (Some(cert_path), Some(key_path)) =
154 (&config.tls_cert_path, &config.tls_key_path)
155 {
156 TlsConfig::load_from_files(cert_path, key_path)
157 .context("Failed to load TLS certificate")?
158 } else {
159 let tls = TlsConfig::generate(
160 &config.host,
161 &config.tls_san,
162 Some(config.tls_validity_days),
163 )
164 .context("Failed to generate TLS certificate")?;
165
166 let cert_dir = dirs_next::home_dir()
168 .unwrap_or_else(|| PathBuf::from("."))
169 .join(".spec-ai")
170 .join("tls");
171 let cert_path = cert_dir.join("server.crt");
172 let key_path = cert_dir.join("server.key");
173
174 if let Err(e) = tls.save_to_files(&cert_path, &key_path) {
175 tracing::warn!("Could not save generated TLS certificate: {}", e);
176 } else {
177 tracing::info!(
178 "Saved TLS certificate to {} (fingerprint: {})",
179 cert_path.display(),
180 tls.fingerprint
181 );
182 }
183
184 tls
185 };
186
187 tracing::info!(
188 "TLS initialized with certificate fingerprint: {}",
189 tls_config.fingerprint
190 );
191
192 Ok(Self {
193 config,
194 state,
195 tls_config,
196 })
197 }
198
199 pub fn mesh_registry(&self) -> &crate::spec_ai_api::api::mesh::MeshRegistry {
201 &self.state.mesh_registry
202 }
203
204 pub fn tls_config(&self) -> &TlsConfig {
206 &self.tls_config
207 }
208
209 pub fn certificate_fingerprint(&self) -> &str {
211 &self.tls_config.fingerprint
212 }
213
214 fn build_router(&self) -> Router {
216 let cert_info = self.tls_config.get_certificate_info(&self.config.host);
218
219 let public_routes = Router::new()
221 .route("/health", get(health_check))
223 .route("/cert", get(move || async move { Json(cert_info.clone()) }))
225 .route("/auth/token", post(generate_token))
227 .route("/auth/hash", post(hash_password));
228
229 let protected_routes = Router::new()
231 .route("/agents", get(list_agents))
233 .route("/query", post(query))
235 .route("/stream", post(stream_query))
236 .route("/api/search", post(search))
238 .route("/registry/register", post(register_instance::<AppState>))
240 .route("/registry/agents", get(list_instances::<AppState>))
241 .route(
242 "/registry/heartbeat/{instance_id}",
243 post(heartbeat::<AppState>),
244 )
245 .route(
246 "/registry/deregister/{instance_id}",
247 delete(deregister_instance::<AppState>),
248 )
249 .route(
251 "/messages/send/{source_instance}",
252 post(send_message::<AppState>),
253 )
254 .route("/messages/{instance_id}", get(get_messages::<AppState>))
255 .route(
256 "/messages/ack/{instance_id}",
257 post(acknowledge_messages::<AppState>),
258 )
259 .route("/sync/request", post(handle_sync_request))
261 .route("/sync/apply", post(handle_sync_apply))
262 .route(
263 "/sync/status/{session_id}/{graph_name}",
264 get(get_sync_status),
265 )
266 .route("/sync/enable/{session_id}/{graph_name}", post(toggle_sync))
267 .route("/sync/configs/{session_id}", get(list_sync_configs))
268 .route("/sync/bulk/{session_id}", post(bulk_toggle_sync))
269 .route(
270 "/sync/configure/{session_id}/{graph_name}",
271 post(configure_sync),
272 )
273 .route("/sync/conflicts", get(list_conflicts))
274 .route("/graph/nodes", get(list_nodes))
276 .route("/graph/nodes", post(create_node))
277 .route("/graph/nodes/{node_id}", get(get_node))
278 .route("/graph/nodes/{node_id}", put(update_node))
279 .route("/graph/nodes/{node_id}", delete(delete_node))
280 .route("/graph/edges", get(list_edges))
281 .route("/graph/edges", post(create_edge))
282 .route("/graph/edges/{edge_id}", get(get_edge))
283 .route("/graph/edges/{edge_id}", delete(delete_edge))
284 .route("/graph/stream", get(stream_changelog))
285 .route("/bootstrap", post(bootstrap_graph))
287 .layer(middleware::from_fn_with_state(
289 self.state.auth_service.clone(),
290 auth_middleware,
291 ));
292
293 let mut router = Router::new()
295 .merge(public_routes)
296 .merge(protected_routes)
297 .with_state(self.state.clone());
298
299 if self.config.enable_cors {
301 let cors = CorsLayer::new()
302 .allow_origin(Any)
303 .allow_methods(Any)
304 .allow_headers(Any);
305 router = router.layer(cors);
306 }
307
308 router = router.layer(TraceLayer::new_for_http());
310
311 router
312 }
313
314 pub async fn run(self) -> Result<()> {
316 if self.state.config.sync.enabled {
318 self.start_sync_coordinator_background();
319 }
320
321 let app = self.build_router();
322 let bind_addr: SocketAddr = self
323 .config
324 .bind_address()
325 .parse()
326 .context("Invalid bind address")?;
327
328 let rustls_config = RustlsConfig::from_der(
330 vec![self.tls_config.certificate.clone()],
331 self.tls_config.private_key.clone(),
332 )
333 .await
334 .context("Failed to create TLS config")?;
335
336 tracing::info!(
337 "Starting HTTPS server on {} (fingerprint: {})",
338 bind_addr,
339 self.tls_config.fingerprint
340 );
341
342 axum_server::bind_rustls(bind_addr, rustls_config)
343 .serve(app.into_make_service())
344 .await
345 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
346
347 Ok(())
348 }
349
350 fn start_sync_coordinator_background(&self) {
352 let persistence = Arc::new(self.state.persistence.clone());
353 let mesh_registry = Arc::new(self.state.mesh_registry.clone());
354 let mesh_client = Arc::new(MeshClient::new("localhost", self.config.port));
355 let sync_config = SyncCoordinatorConfig::from(&self.state.config.sync);
356
357 for ns in &self.state.config.sync.namespaces {
359 if let Err(e) =
360 self.state
361 .persistence
362 .graph_set_sync_enabled(&ns.session_id, &ns.graph_name, true)
363 {
364 tracing::warn!(
365 "Failed to enable sync for {}/{}: {}",
366 ns.session_id,
367 ns.graph_name,
368 e
369 );
370 }
371 }
372
373 tokio::spawn(async move {
375 let _handle =
376 start_sync_coordinator(persistence, mesh_registry, mesh_client, sync_config).await;
377 });
379
380 tracing::info!(
381 "Started sync coordinator with {} configured namespaces",
382 self.state.config.sync.namespaces.len()
383 );
384 }
385
386 pub async fn run_with_shutdown(
388 self,
389 shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
390 ) -> Result<()> {
391 if self.state.config.sync.enabled {
393 self.start_sync_coordinator_background();
394 }
395
396 let app = self.build_router();
397 let bind_addr: SocketAddr = self
398 .config
399 .bind_address()
400 .parse()
401 .context("Invalid bind address")?;
402
403 let rustls_config = RustlsConfig::from_der(
405 vec![self.tls_config.certificate.clone()],
406 self.tls_config.private_key.clone(),
407 )
408 .await
409 .context("Failed to create TLS config")?;
410
411 tracing::info!(
412 "Starting HTTPS server on {} (fingerprint: {})",
413 bind_addr,
414 self.tls_config.fingerprint
415 );
416
417 let handle = axum_server::Handle::new();
419 let handle_clone = handle.clone();
420
421 tokio::spawn(async move {
423 shutdown_signal.await;
424 handle_clone.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
425 });
426
427 axum_server::bind_rustls(bind_addr, rustls_config)
428 .handle(handle)
429 .serve(app.into_make_service())
430 .await
431 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
432
433 Ok(())
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on loopback port 3000 with CORS on
    /// and no API key.
    #[test]
    fn test_api_config_default() {
        let ApiConfig {
            host,
            port,
            api_key,
            enable_cors,
            ..
        } = ApiConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 3000);
        assert_eq!(api_key, None);
        assert!(enable_cors);
    }

    /// Each builder method overrides exactly the field it is named after.
    #[test]
    fn test_api_config_builder() {
        let built = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!(built.host.as_str(), "0.0.0.0");
        assert_eq!(built.port, 8080);
        assert_eq!(built.api_key.as_deref(), Some("secret123"));
        assert!(!built.enable_cors);
    }

    /// `bind_address` joins host and port with a colon.
    #[test]
    fn test_bind_address() {
        let address = ApiConfig::new()
            .with_host("localhost")
            .with_port(5000)
            .bind_address();

        assert_eq!(address, "localhost:5000");
    }
}