1use crate::spec_ai_api::api::graph_handlers::{
3 bootstrap_graph, create_edge, create_node, delete_edge, delete_node, get_edge, get_node,
4 list_edges, list_nodes, stream_changelog, update_node,
5};
6use crate::spec_ai_api::api::handlers::{
7 generate_token, hash_password, health_check, list_agents, query, search, stream_query, AppState,
8};
9use crate::spec_ai_api::api::mesh::{
10 acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
11 register_instance, send_message, MeshClient,
12};
13use crate::spec_ai_api::api::middleware::auth_middleware;
14use crate::spec_ai_api::api::sync_handlers::{
15 bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
16 list_conflicts, list_sync_configs, toggle_sync,
17};
18use crate::spec_ai_api::api::tls::TlsConfig;
19use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
20use crate::spec_ai_api::persistence::Persistence;
21use crate::spec_ai_api::sync::{start_sync_coordinator, SyncCoordinatorConfig};
22use crate::spec_ai_api::tools::ToolRegistry;
23use anyhow::{Context, Result};
24use axum::{
25 middleware,
26 routing::{delete, get, post, put},
27 Json, Router,
28};
29use axum_server::tls_rustls::RustlsConfig;
30use std::net::SocketAddr;
31use std::path::PathBuf;
32use std::sync::Arc;
33use tower_http::cors::{Any, CorsLayer};
34use tower_http::trace::TraceLayer;
35
36fn install_crypto_provider() {
38 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
41}
42
/// Listener, CORS, and TLS settings for [`ApiServer`].
///
/// Construct with [`ApiConfig::new`] (equivalently [`Default::default`]) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    /// Host name or IP address to bind to; also passed to TLS certificate generation.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Optional API key.
    ///
    /// NOTE(review): nothing in this module reads this field (requests are authenticated by
    /// the auth middleware installed on the router) — confirm whether it is consumed elsewhere.
    pub api_key: Option<String>,
    /// Install a permissive CORS layer (any origin, method, and header) when `true`.
    pub enable_cors: bool,
    /// Certificate file to load; only used when `tls_key_path` is also set.
    pub tls_cert_path: Option<PathBuf>,
    /// Private-key file to load; only used when `tls_cert_path` is also set.
    pub tls_key_path: Option<PathBuf>,
    /// Extra names passed to certificate generation (presumably subject alternative names).
    pub tls_san: Vec<String>,
    /// Validity, in days, requested for a generated certificate.
    pub tls_validity_days: u32,
}

impl Default for ApiConfig {
    /// Binds `127.0.0.1:3000` with CORS enabled, no API key, no certificate files (so a
    /// certificate is generated at startup), no extra names, and 365-day validity.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 3000,
            api_key: None,
            enable_cors: true,
            tls_cert_path: None,
            tls_key_path: None,
            tls_san: Vec::new(),
            tls_validity_days: 365,
        }
    }
}

impl ApiConfig {
    /// Returns the default configuration; identical to [`ApiConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the host name or IP address to bind to.
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the TCP port to listen on.
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the API key.
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Enables or disables the permissive CORS layer.
    pub fn with_cors(mut self, enable: bool) -> Self {
        self.enable_cors = enable;
        self
    }

    /// Loads the TLS certificate and private key from these files instead of generating one.
    pub fn with_tls_cert(
        mut self,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> Self {
        self.tls_cert_path = Some(cert_path.into());
        self.tls_key_path = Some(key_path.into());
        self
    }

    /// Replaces the extra names used when a certificate is generated.
    pub fn with_tls_san(mut self, san: Vec<String>) -> Self {
        self.tls_san = san;
        self
    }

    /// Sets the validity, in days, of a generated certificate.
    pub fn with_tls_validity(mut self, days: u32) -> Self {
        self.tls_validity_days = days;
        self
    }

    /// Returns `host:port`, ready to be parsed or resolved into a socket address.
    ///
    /// A bare IPv6 literal host such as `::1` is wrapped in brackets (`[::1]:3000`). Without
    /// the brackets the port separator is ambiguous and the string cannot be parsed as a
    /// `SocketAddr`. Host names, IPv4 addresses, and already-bracketed hosts are formatted
    /// unchanged (e.g. `localhost:5000`).
    pub fn bind_address(&self) -> String {
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
129
/// HTTPS API server: its listener/TLS configuration, the shared handler state, and the
/// certificate material it serves with.
pub struct ApiServer {
    /// Bind address, CORS, and TLS settings supplied by the caller.
    config: ApiConfig,
    /// Shared state handed to the route handlers through `Router::with_state`.
    state: AppState,
    /// Certificate and private key used for HTTPS, loaded from files or generated in `new`.
    tls_config: TlsConfig,
}
136
137impl ApiServer {
138 pub fn new(
142 config: ApiConfig,
143 persistence: Persistence,
144 agent_registry: Arc<AgentRegistry>,
145 tool_registry: Arc<ToolRegistry>,
146 app_config: AppConfig,
147 ) -> Result<Self> {
148 install_crypto_provider();
150
151 let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
152
153 let tls_config = if let (Some(cert_path), Some(key_path)) =
155 (&config.tls_cert_path, &config.tls_key_path)
156 {
157 TlsConfig::load_from_files(cert_path, key_path)
158 .context("Failed to load TLS certificate")?
159 } else {
160 let tls = TlsConfig::generate(
161 &config.host,
162 &config.tls_san,
163 Some(config.tls_validity_days),
164 )
165 .context("Failed to generate TLS certificate")?;
166
167 let cert_dir = dirs_next::home_dir()
169 .unwrap_or_else(|| PathBuf::from("."))
170 .join(".spec-ai")
171 .join("tls");
172 let cert_path = cert_dir.join("server.crt");
173 let key_path = cert_dir.join("server.key");
174
175 if let Err(e) = tls.save_to_files(&cert_path, &key_path) {
176 tracing::warn!("Could not save generated TLS certificate: {}", e);
177 } else {
178 tracing::info!(
179 "Saved TLS certificate to {} (fingerprint: {})",
180 cert_path.display(),
181 tls.fingerprint
182 );
183 }
184
185 tls
186 };
187
188 tracing::info!(
189 "TLS initialized with certificate fingerprint: {}",
190 tls_config.fingerprint
191 );
192
193 Ok(Self {
194 config,
195 state,
196 tls_config,
197 })
198 }
199
200 pub fn mesh_registry(&self) -> &crate::spec_ai_api::api::mesh::MeshRegistry {
202 &self.state.mesh_registry
203 }
204
205 pub fn tls_config(&self) -> &TlsConfig {
207 &self.tls_config
208 }
209
210 pub fn certificate_fingerprint(&self) -> &str {
212 &self.tls_config.fingerprint
213 }
214
215 fn build_router(&self) -> Router {
217 let cert_info = self.tls_config.get_certificate_info(&self.config.host);
219
220 let public_routes = Router::new()
222 .route("/health", get(health_check))
224 .route("/cert", get(move || async move { Json(cert_info.clone()) }))
226 .route("/auth/token", post(generate_token))
228 .route("/auth/hash", post(hash_password));
229
230 let protected_routes = Router::new()
232 .route("/agents", get(list_agents))
234 .route("/query", post(query))
236 .route("/stream", post(stream_query))
237 .route("/api/search", post(search))
239 .route("/registry/register", post(register_instance::<AppState>))
241 .route("/registry/agents", get(list_instances::<AppState>))
242 .route(
243 "/registry/heartbeat/{instance_id}",
244 post(heartbeat::<AppState>),
245 )
246 .route(
247 "/registry/deregister/{instance_id}",
248 delete(deregister_instance::<AppState>),
249 )
250 .route(
252 "/messages/send/{source_instance}",
253 post(send_message::<AppState>),
254 )
255 .route("/messages/{instance_id}", get(get_messages::<AppState>))
256 .route(
257 "/messages/ack/{instance_id}",
258 post(acknowledge_messages::<AppState>),
259 )
260 .route("/sync/request", post(handle_sync_request))
262 .route("/sync/apply", post(handle_sync_apply))
263 .route(
264 "/sync/status/{session_id}/{graph_name}",
265 get(get_sync_status),
266 )
267 .route("/sync/enable/{session_id}/{graph_name}", post(toggle_sync))
268 .route("/sync/configs/{session_id}", get(list_sync_configs))
269 .route("/sync/bulk/{session_id}", post(bulk_toggle_sync))
270 .route(
271 "/sync/configure/{session_id}/{graph_name}",
272 post(configure_sync),
273 )
274 .route("/sync/conflicts", get(list_conflicts))
275 .route("/graph/nodes", get(list_nodes))
277 .route("/graph/nodes", post(create_node))
278 .route("/graph/nodes/{node_id}", get(get_node))
279 .route("/graph/nodes/{node_id}", put(update_node))
280 .route("/graph/nodes/{node_id}", delete(delete_node))
281 .route("/graph/edges", get(list_edges))
282 .route("/graph/edges", post(create_edge))
283 .route("/graph/edges/{edge_id}", get(get_edge))
284 .route("/graph/edges/{edge_id}", delete(delete_edge))
285 .route("/graph/stream", get(stream_changelog))
286 .route("/bootstrap", post(bootstrap_graph))
288 .layer(middleware::from_fn_with_state(
290 self.state.auth_service.clone(),
291 auth_middleware,
292 ));
293
294 let mut router = Router::new()
296 .merge(public_routes)
297 .merge(protected_routes)
298 .with_state(self.state.clone());
299
300 if self.config.enable_cors {
302 let cors = CorsLayer::new()
303 .allow_origin(Any)
304 .allow_methods(Any)
305 .allow_headers(Any);
306 router = router.layer(cors);
307 }
308
309 router = router.layer(TraceLayer::new_for_http());
311
312 router
313 }
314
315 pub async fn run(self) -> Result<()> {
317 if self.state.config.sync.enabled {
319 self.start_sync_coordinator_background();
320 }
321
322 let app = self.build_router();
323 let bind_addr: SocketAddr = self
324 .config
325 .bind_address()
326 .parse()
327 .context("Invalid bind address")?;
328
329 let rustls_config = RustlsConfig::from_der(
331 vec![self.tls_config.certificate.clone()],
332 self.tls_config.private_key.clone(),
333 )
334 .await
335 .context("Failed to create TLS config")?;
336
337 tracing::info!(
338 "Starting HTTPS server on {} (fingerprint: {})",
339 bind_addr,
340 self.tls_config.fingerprint
341 );
342
343 axum_server::bind_rustls(bind_addr, rustls_config)
344 .serve(app.into_make_service())
345 .await
346 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
347
348 Ok(())
349 }
350
351 fn start_sync_coordinator_background(&self) {
353 let persistence = Arc::new(self.state.persistence.clone());
354 let mesh_registry = Arc::new(self.state.mesh_registry.clone());
355 let mesh_client = Arc::new(MeshClient::new("localhost", self.config.port));
356 let sync_config = SyncCoordinatorConfig::from(&self.state.config.sync);
357
358 for ns in &self.state.config.sync.namespaces {
360 if let Err(e) =
361 self.state
362 .persistence
363 .graph_set_sync_enabled(&ns.session_id, &ns.graph_name, true)
364 {
365 tracing::warn!(
366 "Failed to enable sync for {}/{}: {}",
367 ns.session_id,
368 ns.graph_name,
369 e
370 );
371 }
372 }
373
374 tokio::spawn(async move {
376 let _handle =
377 start_sync_coordinator(persistence, mesh_registry, mesh_client, sync_config).await;
378 });
380
381 tracing::info!(
382 "Started sync coordinator with {} configured namespaces",
383 self.state.config.sync.namespaces.len()
384 );
385 }
386
387 pub async fn run_with_shutdown(
389 self,
390 shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
391 ) -> Result<()> {
392 if self.state.config.sync.enabled {
394 self.start_sync_coordinator_background();
395 }
396
397 let app = self.build_router();
398 let bind_addr: SocketAddr = self
399 .config
400 .bind_address()
401 .parse()
402 .context("Invalid bind address")?;
403
404 let rustls_config = RustlsConfig::from_der(
406 vec![self.tls_config.certificate.clone()],
407 self.tls_config.private_key.clone(),
408 )
409 .await
410 .context("Failed to create TLS config")?;
411
412 tracing::info!(
413 "Starting HTTPS server on {} (fingerprint: {})",
414 bind_addr,
415 self.tls_config.fingerprint
416 );
417
418 let handle = axum_server::Handle::new();
420 let handle_clone = handle.clone();
421
422 tokio::spawn(async move {
424 shutdown_signal.await;
425 handle_clone.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
426 });
427
428 axum_server::bind_rustls(bind_addr, rustls_config)
429 .handle(handle)
430 .serve(app.into_make_service())
431 .await
432 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
433
434 Ok(())
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_config_default() {
        let config = ApiConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert!(config.api_key.is_none());
        assert!(config.enable_cors);
        // No certificate files by default, so ApiServer::new generates a certificate.
        assert!(config.tls_cert_path.is_none());
        assert!(config.tls_key_path.is_none());
        assert!(config.tls_san.is_empty());
        assert_eq!(config.tls_validity_days, 365);
    }

    #[test]
    fn test_api_config_builder() {
        let config = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.api_key, Some("secret123".to_string()));
        assert!(!config.enable_cors);
    }

    #[test]
    fn test_api_config_tls_builders() {
        let config = ApiConfig::new()
            .with_tls_cert("certs/server.crt", "certs/server.key")
            .with_tls_san(vec!["api.internal".to_string(), "10.0.0.5".to_string()])
            .with_tls_validity(30);

        assert_eq!(
            config.tls_cert_path.as_deref(),
            Some(std::path::Path::new("certs/server.crt"))
        );
        assert_eq!(
            config.tls_key_path.as_deref(),
            Some(std::path::Path::new("certs/server.key"))
        );
        assert_eq!(
            config.tls_san,
            vec!["api.internal".to_string(), "10.0.0.5".to_string()]
        );
        assert_eq!(config.tls_validity_days, 30);
    }

    #[test]
    fn test_bind_address() {
        let config = ApiConfig::new().with_host("localhost").with_port(5000);

        assert_eq!(config.bind_address(), "localhost:5000");
    }
}