1use crate::api::handlers::{health_check, list_agents, query, stream_query, AppState};
3use crate::api::mesh::{
4 acknowledge_messages, deregister_instance, get_messages, heartbeat, list_instances,
5 register_instance, send_message,
6};
7use crate::api::sync_handlers::{
8 bulk_toggle_sync, configure_sync, get_sync_status, handle_sync_apply, handle_sync_request,
9 list_conflicts, list_sync_configs, toggle_sync,
10};
11use crate::config::{AgentRegistry, AppConfig};
12use crate::persistence::Persistence;
13use crate::tools::ToolRegistry;
14use anyhow::Result;
15use axum::{
16 routing::{delete, get, post},
17 Router,
18};
19use std::sync::Arc;
20use tower_http::cors::{Any, CorsLayer};
21use tower_http::trace::TraceLayer;
22
/// Network-facing settings for the HTTP API server.
///
/// Start from [`ApiConfig::new`] (or [`Default`]) and refine the value with
/// the chainable `with_*` methods.
///
/// Note: the derived `Debug` output prints `api_key` verbatim, so avoid
/// logging this struct where secrets must not appear.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    /// Host name or IP address used to build the bind address.
    pub host: String,
    /// TCP port used to build the bind address.
    pub port: u16,
    /// Optional API key; `None` means no key is configured.
    // NOTE(review): nothing in this file reads this field when the router is
    // built — confirm where (or whether) the key is actually enforced.
    pub api_key: Option<String>,
    /// Whether the server attaches a CORS layer to its router.
    pub enable_cors: bool,
}

impl Default for ApiConfig {
    /// Loopback host `127.0.0.1`, port `3000`, no API key, CORS enabled.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 3000,
            api_key: None,
            enable_cors: true,
        }
    }
}

impl ApiConfig {
    /// Creates a configuration with the default values (see [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the host and returns the updated configuration.
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Replaces the port and returns the updated configuration.
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Stores an API key and returns the updated configuration.
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Enables or disables CORS and returns the updated configuration.
    pub fn with_cors(mut self, enable: bool) -> Self {
        self.enable_cors = enable;
        self
    }

    /// Returns the `host:port` string used to bind the listener.
    ///
    /// A bare IPv6 literal (for example `::1`) is wrapped in brackets,
    /// giving `[::1]:3000`; without them the result could not be parsed as a
    /// socket address. IPv4 addresses, host names and already-bracketed
    /// hosts are emitted unchanged.
    pub fn bind_address(&self) -> String {
        // Only an unbracketed, well-formed IPv6 literal parses here, so every
        // other kind of host keeps the plain `host:port` form.
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
76
77pub struct ApiServer {
79 config: ApiConfig,
80 state: AppState,
81}
82
83impl ApiServer {
84 pub fn new(
86 config: ApiConfig,
87 persistence: Persistence,
88 agent_registry: Arc<AgentRegistry>,
89 tool_registry: Arc<ToolRegistry>,
90 app_config: AppConfig,
91 ) -> Self {
92 let state = AppState::new(persistence, agent_registry, tool_registry, app_config);
93
94 Self { config, state }
95 }
96
97 pub fn mesh_registry(&self) -> &crate::api::mesh::MeshRegistry {
99 &self.state.mesh_registry
100 }
101
102 fn build_router(&self) -> Router {
104 let mut router = Router::new()
105 .route("/health", get(health_check))
107 .route("/agents", get(list_agents))
108 .route("/query", post(query))
110 .route("/stream", post(stream_query))
111 .route("/registry/register", post(register_instance::<AppState>))
113 .route("/registry/agents", get(list_instances::<AppState>))
114 .route(
115 "/registry/heartbeat/:instance_id",
116 post(heartbeat::<AppState>),
117 )
118 .route(
119 "/registry/deregister/:instance_id",
120 delete(deregister_instance::<AppState>),
121 )
122 .route(
124 "/messages/send/:source_instance",
125 post(send_message::<AppState>),
126 )
127 .route("/messages/:instance_id", get(get_messages::<AppState>))
128 .route(
129 "/messages/ack/:instance_id",
130 post(acknowledge_messages::<AppState>),
131 )
132 .route("/sync/request", post(handle_sync_request))
134 .route("/sync/apply", post(handle_sync_apply))
135 .route("/sync/status/:session_id/:graph_name", get(get_sync_status))
136 .route("/sync/enable/:session_id/:graph_name", post(toggle_sync))
137 .route("/sync/configs/:session_id", get(list_sync_configs))
138 .route("/sync/bulk/:session_id", post(bulk_toggle_sync))
139 .route(
140 "/sync/configure/:session_id/:graph_name",
141 post(configure_sync),
142 )
143 .route("/sync/conflicts", get(list_conflicts))
144 .with_state(self.state.clone());
146
147 if self.config.enable_cors {
149 let cors = CorsLayer::new()
150 .allow_origin(Any)
151 .allow_methods(Any)
152 .allow_headers(Any);
153 router = router.layer(cors);
154 }
155
156 router = router.layer(TraceLayer::new_for_http());
158
159 router
160 }
161
162 pub async fn run(self) -> Result<()> {
164 let app = self.build_router();
165 let bind_addr = self.config.bind_address();
166
167 tracing::debug!("Starting API server on {}", bind_addr);
168
169 let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
170
171 axum::serve(listener, app)
172 .await
173 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
174
175 Ok(())
176 }
177
178 pub async fn run_with_shutdown(
180 self,
181 shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
182 ) -> Result<()> {
183 let app = self.build_router();
184 let bind_addr = self.config.bind_address();
185
186 tracing::debug!("Starting API server on {}", bind_addr);
187
188 let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
189
190 axum::serve(listener, app)
191 .with_graceful_shutdown(shutdown_signal)
192 .await
193 .map_err(|e| anyhow::anyhow!("Server error: {}", e))?;
194
195 Ok(())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is loopback:3000 with no key and CORS on.
    #[test]
    fn test_api_config_default() {
        let ApiConfig {
            host,
            port,
            api_key,
            enable_cors,
        } = ApiConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 3000);
        assert!(api_key.is_none());
        assert!(enable_cors);
    }

    /// Every `with_*` builder method overrides its corresponding field.
    #[test]
    fn test_api_config_builder() {
        let built = ApiConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_api_key("secret123")
            .with_cors(false);

        assert_eq!(built.host.as_str(), "0.0.0.0");
        assert_eq!(built.port, 8080);
        assert_eq!(built.api_key.as_deref(), Some("secret123"));
        assert!(!built.enable_cors);
    }

    /// The bind address joins host and port with a colon.
    #[test]
    fn test_bind_address() {
        let address = ApiConfig::new()
            .with_host("localhost")
            .with_port(5000)
            .bind_address();

        assert_eq!(address, "localhost:5000");
    }
}