spec_ai_api/api/server.rs

//! HTTP server implementation.
2use crate::api::handlers::{health_check, list_agents, query, stream_query, AppState};
3use crate::api::mesh::{
4    acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
5    register_instance, send_message,
6};
7use crate::api::sync_handlers::{
8    bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
9    list_conflicts, list_sync_configs, toggle_sync,
10};
11use crate::config::{AgentRegistry, AppConfig};
12use crate::persistence::Persistence;
13use crate::tools::ToolRegistry;
14use anyhow::Result;
15use axum::{
16    routing::{delete, get, post},
17    Router,
18};
19use std::sync::Arc;
20use tower_http::cors::{Any, CorsLayer};
21use tower_http::trace::TraceLayer;
22
/// API server configuration: listen address, optional API key and CORS toggle.
///
/// Build one with [`ApiConfig::new`] (or `Default::default()`) and chain the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    /// Host name or IP literal to bind to (default `"127.0.0.1"`).
    pub host: String,
    /// TCP port to bind to (default `3000`).
    pub port: u16,
    /// Optional API key for authentication (default `None`).
    ///
    /// NOTE(review): `ApiServer::build_router` in this module never reads this
    /// field, so setting it does not by itself enforce authentication — confirm
    /// whether auth is handled elsewhere.
    pub api_key: Option<String>,
    /// Whether to attach a permissive CORS layer (any origin, method and
    /// header). Default `true`.
    pub enable_cors: bool,
}

impl Default for ApiConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 3000,
            api_key: None,
            enable_cors: true,
        }
    }
}

impl ApiConfig {
    /// Create a configuration populated with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the host name or IP literal to bind to.
    #[must_use]
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Set the TCP port to bind to.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Set the API key.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Enable or disable the CORS layer.
    #[must_use]
    pub fn with_cors(mut self, enable: bool) -> Self {
        self.enable_cors = enable;
        self
    }

    /// Return the `host:port` string used to bind the listener.
    ///
    /// A bare IPv6 literal is wrapped in brackets (`[::1]:3000`), so the result
    /// is a canonical, unambiguous socket-address string. Host names, IPv4
    /// literals and already-bracketed hosts are passed through unchanged.
    pub fn bind_address(&self) -> String {
        // Without brackets, `::1` + `3000` would render as the ambiguous `::1:3000`.
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
76
77/// API server
78pub struct ApiServer {
79    config: ApiConfig,
80    state: AppState,
81}
82
83impl ApiServer {
84    /// Create a new API server
85    pub fn new(
86        config: ApiConfig,
87        persistence: Persistence,
88        agent_registry: Arc<AgentRegistry>,
89        tool_registry: Arc<ToolRegistry>,
90        app_config: AppConfig,
91    ) -> Self {
92        let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
93
94        Self { config, state }
95    }
96
97    /// Get the mesh registry for self-registration
98    pub fn mesh_registry(&self) -> &crate::api::mesh::MeshRegistry {
99        &self.state.mesh_registry
100    }
101
102    /// Build the router with all routes
103    fn build_router(&self) -> Router {
104        let mut router = Router::new()
105            // Health and info endpoints
106            .route("/health", get(health_check))
107            .route("/agents", get(list_agents))
108            // Query endpoints
109            .route("/query", post(query))
110            .route("/stream", post(stream_query))
111            // Mesh registry endpoints
112            .route("/registry/register", post(register_instance::<AppState>))
113            .route("/registry/agents", get(list_instances::<AppState>))
114            .route(
115                "/registry/heartbeat/:instance_id",
116                post(heartbeat::<AppState>),
117            )
118            .route(
119                "/registry/deregister/:instance_id",
120                delete(deregister_instance::<AppState>),
121            )
122            // Message routing endpoints
123            .route(
124                "/messages/send/:source_instance",
125                post(send_message::<AppState>),
126            )
127            .route("/messages/:instance_id", get(get_messages::<AppState>))
128            .route(
129                "/messages/ack/:instance_id",
130                post(acknowledge_messages::<AppState>),
131            )
132            // Graph sync endpoints
133            .route("/sync/request", post(handle_sync_request))
134            .route("/sync/apply", post(handle_sync_apply))
135            .route("/sync/status/:session_id/:graph_name", get(get_sync_status))
136            .route("/sync/enable/:session_id/:graph_name", post(toggle_sync))
137            .route("/sync/configs/:session_id", get(list_sync_configs))
138            .route("/sync/bulk/:session_id", post(bulk_toggle_sync))
139            .route(
140                "/sync/configure/:session_id/:graph_name",
141                post(configure_sync),
142            )
143            .route("/sync/conflicts", get(list_conflicts))
144            // Add state
145            .with_state(self.state.clone());
146
147        // Add CORS if enabled
148        if self.config.enable_cors {
149            let cors = CorsLayer::new()
150                .allow_origin(Any)
151                .allow_methods(Any)
152                .allow_headers(Any);
153            router = router.layer(cors);
154        }
155
156        // Add tracing
157        router = router.layer(TraceLayer::new_for_http());
158
159        router
160    }
161
162    /// Run the server
163    pub async fn run(self) -> Result<()> {
164        let app = self.build_router();
165        let bind_addr = self.config.bind_address();
166
167        tracing::debug!("Starting API server on {}", bind_addr);
168
169        let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
170
171        axum::serve(listener, app)
172            .await
173            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
174
175        Ok(())
176    }
177
178    /// Run the server with graceful shutdown
179    pub async fn run_with_shutdown(
180        self,
181        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
182    ) -> Result<()> {
183        let app = self.build_router();
184        let bind_addr = self.config.bind_address();
185
186        tracing::debug!("Starting API server on {}", bind_addr);
187
188        let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
189
190        axum::serve(listener, app)
191            .with_graceful_shutdown(shutdown_signal)
192            .await
193            .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
194
195        Ok(())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config binds to 127.0.0.1:3000 with no key and CORS on.
    #[test]
    fn test_api_config_default() {
        let ApiConfig {
            host,
            port,
            api_key,
            enable_cors,
        } = ApiConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 3000);
        assert!(api_key.is_none());
        assert!(enable_cors);
    }

    /// Each builder method overrides exactly its own field.
    #[test]
    fn test_api_config_builder() {
        let built = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!(built.host, "0.0.0.0");
        assert_eq!(built.port, 8080);
        assert_eq!(built.api_key.as_deref(), Some("secret123"));
        assert!(!built.enable_cors);
    }

    /// A host name and port are joined with a single colon.
    #[test]
    fn test_bind_address() {
        let cfg = ApiConfig::new().with_host("localhost").with_port(5000);
        let addr = cfg.bind_address();

        assert_eq!(addr, "localhost:5000");
    }
}