1use crate::agent::builder::AgentBuilder;
3use crate::agent::core::AgentCore;
4use crate::api::mesh::{MeshRegistry, MeshState};
5use crate::api::models::*;
6use crate::config::{AgentRegistry, AppConfig};
7use crate::persistence::Persistence;
8use crate::tools::ToolRegistry;
9use async_stream::stream;
10use axum::{
11 extract::{Json, State},
12 http::StatusCode,
13 response::{
14 sse::{Event, Sse},
15 IntoResponse, Response,
16 },
17};
18use futures::StreamExt;
19use serde_json::json;
20use std::convert::Infallible;
21use std::sync::Arc;
22use std::time::{Instant, SystemTime, UNIX_EPOCH};
23use tokio::sync::RwLock;
24
25#[derive(Clone)]
27pub struct AppState {
28 pub persistence: Persistence,
29 pub agent_registry: Arc<AgentRegistry>,
30 pub tool_registry: Arc<ToolRegistry>,
31 pub config: AppConfig,
32 pub start_time: Instant,
33 pub mesh_registry: MeshRegistry,
34}
35
36impl AppState {
37 pub fn new(
38 persistence: Persistence,
39 agent_registry: Arc<AgentRegistry>,
40 tool_registry: Arc<ToolRegistry>,
41 config: AppConfig,
42 ) -> Self {
43 Self {
44 persistence: persistence.clone(),
45 agent_registry,
46 tool_registry,
47 config,
48 start_time: Instant::now(),
49 mesh_registry: MeshRegistry::with_persistence(persistence),
50 }
51 }
52}
53
54impl MeshState for AppState {
55 fn mesh_registry(&self) -> &MeshRegistry {
56 &self.mesh_registry
57 }
58}
59
60pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
62 let uptime = state.start_time.elapsed().as_secs();
63
64 let response = HealthResponse {
65 status: "healthy".to_string(),
66 version: env!("CARGO_PKG_VERSION").to_string(),
67 uptime_seconds: uptime,
68 active_sessions: 0, };
70
71 Json(response)
72}
73
74pub async fn list_agents(State(state): State<AppState>) -> impl IntoResponse {
76 let agent_names = state.agent_registry.list();
77 let mut agent_infos = Vec::new();
78
79 for name in agent_names {
80 if let Some(profile) = state.agent_registry.get(&name) {
81 agent_infos.push(AgentInfo {
82 id: name,
83 description: profile.prompt.unwrap_or_default(),
84 allowed_tools: profile.allowed_tools.unwrap_or_default(),
85 denied_tools: profile.denied_tools.unwrap_or_default(),
86 });
87 }
88 }
89
90 Json(AgentListResponse {
91 agents: agent_infos,
92 })
93 .into_response()
94}
95
96pub async fn query(State(state): State<AppState>, Json(request): Json<QueryRequest>) -> Response {
98 if request.stream {
100 return (
101 StatusCode::BAD_REQUEST,
102 Json(ErrorResponse::new(
103 "invalid_request",
104 "Streaming not supported on /query endpoint. Use /stream instead.",
105 )),
106 )
107 .into_response();
108 }
109
110 let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
112
113 let session_id = request
115 .session_id
116 .unwrap_or_else(|| format!("api_{}", uuid_v4()));
117
118 let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
120
121 let mut agent = match agent_result {
122 Ok(agent) => agent,
123 Err(e) => {
124 return (
125 StatusCode::BAD_REQUEST,
126 Json(ErrorResponse::new("agent_error", e.to_string())),
127 )
128 .into_response();
129 }
130 };
131
132 let start = Instant::now();
134
135 match agent.run_step(&request.message).await {
136 Ok(output) => {
137 let processing_time = start.elapsed().as_millis() as u64;
138 let tool_calls: Vec<ToolCallInfo> = output
139 .tool_invocations
140 .iter()
141 .map(|inv| ToolCallInfo {
142 name: inv.name.clone(),
143 arguments: inv.arguments.clone(),
144 success: inv.success,
145 output: inv.output.clone(),
146 error: inv.error.clone(),
147 })
148 .collect();
149
150 let response = QueryResponse {
151 response: output.response,
152 session_id,
153 agent: agent_name,
154 tool_calls,
155 metadata: ResponseMetadata {
156 timestamp: current_timestamp(),
157 model: state.config.model.provider.clone(),
158 processing_time_ms: processing_time,
159 run_id: output.run_id,
160 },
161 };
162
163 Json(response).into_response()
164 }
165 Err(e) => (
166 StatusCode::INTERNAL_SERVER_ERROR,
167 Json(ErrorResponse::new("execution_error", e.to_string())),
168 )
169 .into_response(),
170 }
171}
172
173pub async fn stream_query(
175 State(state): State<AppState>,
176 Json(request): Json<QueryRequest>,
177) -> Response {
178 let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
179 let session_id = request
180 .session_id
181 .unwrap_or_else(|| format!("api_{}", uuid_v4()));
182
183 let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
185
186 let agent = match agent_result {
187 Ok(agent) => agent,
188 Err(e) => {
189 return (
190 StatusCode::BAD_REQUEST,
191 Json(ErrorResponse::new("agent_error", e.to_string())),
192 )
193 .into_response();
194 }
195 };
196
197 let agent = Arc::new(RwLock::new(agent));
199 let message = request.message.clone();
200 let session_id_clone = session_id.clone();
201 let agent_name_clone = agent_name.clone();
202 let model_id = state.config.model.provider.clone();
203
204 let sse_stream = stream! {
205 yield StreamChunk::Start {
206 session_id: session_id_clone.clone(),
207 agent: agent_name_clone.clone(),
208 };
209
210 let start = Instant::now();
211 let mut agent_lock = agent.write().await;
212
213 match agent_lock.run_step(&message).await {
214 Ok(output) => {
215 yield StreamChunk::Content { text: output.response.clone() };
216
217 for invocation in output.tool_invocations {
218 yield StreamChunk::ToolCall {
219 name: invocation.name.clone(),
220 arguments: invocation.arguments.clone(),
221 };
222 yield StreamChunk::ToolResult {
223 name: invocation.name.clone(),
224 result: json!({
225 "success": invocation.success,
226 "output": invocation.output,
227 "error": invocation.error,
228 }),
229 };
230 }
231
232 yield StreamChunk::End {
233 metadata: ResponseMetadata {
234 timestamp: current_timestamp(),
235 model: model_id.clone(),
236 processing_time_ms: start.elapsed().as_millis() as u64,
237 run_id: output.run_id,
238 },
239 };
240 }
241 Err(e) => {
242 yield StreamChunk::Error {
243 message: e.to_string(),
244 };
245 }
246 }
247 };
248
249 Sse::new(sse_stream.map(|chunk| {
250 let json = serde_json::to_string(&chunk).unwrap();
251 Ok::<_, Infallible>(Event::default().data(json))
252 }))
253 .into_response()
254}
255
256async fn create_agent(
258 state: &AppState,
259 agent_name: &str,
260 session_id: &str,
261 _temperature: Option<f32>,
262) -> anyhow::Result<AgentCore> {
263 let profile = state
265 .agent_registry
266 .get(agent_name)
267 .ok_or_else(|| anyhow::anyhow!("Agent '{}' not found", agent_name))?;
268
269 let agent = AgentBuilder::new()
271 .with_profile(profile)
272 .with_config(state.config.clone())
273 .with_session_id(session_id)
274 .with_agent_name(agent_name.to_string())
275 .with_tool_registry(state.tool_registry.clone())
276 .with_persistence(state.persistence.clone())
277 .build()?;
278
279 Ok(agent)
280}
281
/// Generates a random identifier formatted like an RFC 4122 version-4 UUID.
///
/// The 128 bits come from hashing the current time with two separately created
/// `RandomState` hashers, so no external crate is needed. The result is always 36
/// characters long (`8-4-4-4-12` lowercase hex groups) with the version nibble set
/// to `4` and the variant bits set to `10`.
///
/// The bits do not come from a cryptographically secure generator; treat the value
/// as a unique label, not as a secret.
fn uuid_v4() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::BuildHasher;

    let now = SystemTime::now();
    // Each `RandomState::new()` is initialized with its own random keys, so the two
    // hashes of the same timestamp provide two independent 64-bit halves.
    let high = RandomState::new().hash_one(now);
    let low = RandomState::new().hash_one(now);

    let mut bits = (u128::from(high) << 64) | u128::from(low);
    // Version field: the 13th hex digit (bits 76..=79) must be 4.
    bits = (bits & !(0xF_u128 << 76)) | (0x4_u128 << 76);
    // Variant field: the top two bits of the 17th hex digit (bits 62..=63) must be 10.
    bits = (bits & !(0x3_u128 << 62)) | (0x2_u128 << 62);

    // Zero-padded so the identifier has a fixed width regardless of leading zeros.
    let hex = format!("{bits:032x}");
    format!(
        "{}-{}-{}-{}-{}",
        &hex[0..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..32]
    )
}
288
289fn current_timestamp() -> String {
291 let now = SystemTime::now()
292 .duration_since(UNIX_EPOCH)
293 .unwrap()
294 .as_secs();
295 chrono::DateTime::from_timestamp(now as i64, 0)
296 .unwrap()
297 .to_rfc3339()
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated ids are non-empty and two consecutive ids differ.
    #[test]
    fn test_uuid_generation() {
        let uuid1 = uuid_v4();
        let uuid2 = uuid_v4();

        assert!(!uuid1.is_empty());
        assert!(!uuid2.is_empty());
        // A collision is possible in principle but vanishingly unlikely.
        assert_ne!(uuid1, uuid2);
    }

    /// The timestamp has a date/time separator and an explicit UTC offset marker.
    #[test]
    fn test_timestamp_format() {
        let ts = current_timestamp();
        assert!(ts.contains('T'));
        assert!(ts.contains('Z') || ts.contains('+'));
    }
}