spec_ai_api/api/
handlers.rs

1//! API request handlers
2use crate::agent::builder::AgentBuilder;
3use crate::agent::core::AgentCore;
4use crate::api::mesh::{MeshRegistry, MeshState};
5use crate::api::models::*;
6use crate::config::{AgentRegistry, AppConfig};
7use crate::persistence::Persistence;
8use crate::tools::ToolRegistry;
9use async_stream::stream;
10use axum::{
11    extract::{Json, State},
12    http::StatusCode,
13    response::{
14        sse::{Event, Sse},
15        IntoResponse, Response,
16    },
17};
18use futures::StreamExt;
19use serde_json::json;
20use std::convert::Infallible;
21use std::sync::Arc;
22use std::time::{Instant, SystemTime, UNIX_EPOCH};
23use tokio::sync::RwLock;
24
25/// Shared application state
26#[derive(Clone)]
27pub struct AppState {
28    pub persistence: Persistence,
29    pub agent_registry: Arc<AgentRegistry>,
30    pub tool_registry: Arc<ToolRegistry>,
31    pub config: AppConfig,
32    pub start_time: Instant,
33    pub mesh_registry: MeshRegistry,
34}
35
36impl AppState {
37    pub fn new(
38        persistence: Persistence,
39        agent_registry: Arc<AgentRegistry>,
40        tool_registry: Arc<ToolRegistry>,
41        config: AppConfig,
42    ) -> Self {
43        Self {
44            persistence: persistence.clone(),
45            agent_registry,
46            tool_registry,
47            config,
48            start_time: Instant::now(),
49            mesh_registry: MeshRegistry::with_persistence(persistence),
50        }
51    }
52}
53
54impl MeshState for AppState {
55    fn mesh_registry(&self) -> &MeshRegistry {
56        &self.mesh_registry
57    }
58}
59
60/// Health check endpoint
61pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
62    let uptime = state.start_time.elapsed().as_secs();
63
64    let response = HealthResponse {
65        status: "healthy".to_string(),
66        version: env!("CARGO_PKG_VERSION").to_string(),
67        uptime_seconds: uptime,
68        active_sessions: 0, // TODO: Track active sessions
69    };
70
71    Json(response)
72}
73
74/// List available agents
75pub async fn list_agents(State(state): State<AppState>) -> impl IntoResponse {
76    let agent_names = state.agent_registry.list();
77    let mut agent_infos = Vec::new();
78
79    for name in agent_names {
80        if let Some(profile) = state.agent_registry.get(&name) {
81            agent_infos.push(AgentInfo {
82                id: name,
83                description: profile.prompt.unwrap_or_default(),
84                allowed_tools: profile.allowed_tools.unwrap_or_default(),
85                denied_tools: profile.denied_tools.unwrap_or_default(),
86            });
87        }
88    }
89
90    Json(AgentListResponse {
91        agents: agent_infos,
92    })
93    .into_response()
94}
95
96/// Query endpoint - process a message and return response
97pub async fn query(State(state): State<AppState>, Json(request): Json<QueryRequest>) -> Response {
98    // If streaming requested, delegate to streaming handler
99    if request.stream {
100        return (
101            StatusCode::BAD_REQUEST,
102            Json(ErrorResponse::new(
103                "invalid_request",
104                "Streaming not supported on /query endpoint. Use /stream instead.",
105            )),
106        )
107            .into_response();
108    }
109
110    // Determine which agent to use
111    let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
112
113    // Get or create session ID
114    let session_id = request
115        .session_id
116        .unwrap_or_else(|| format!("api_{}", uuid_v4()));
117
118    // Create agent instance
119    let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
120
121    let mut agent = match agent_result {
122        Ok(agent) => agent,
123        Err(e) => {
124            return (
125                StatusCode::BAD_REQUEST,
126                Json(ErrorResponse::new("agent_error", e.to_string())),
127            )
128                .into_response();
129        }
130    };
131
132    // Process the message
133    let start = Instant::now();
134
135    match agent.run_step(&request.message).await {
136        Ok(output) => {
137            let processing_time = start.elapsed().as_millis() as u64;
138            let tool_calls: Vec<ToolCallInfo> = output
139                .tool_invocations
140                .iter()
141                .map(|inv| ToolCallInfo {
142                    name: inv.name.clone(),
143                    arguments: inv.arguments.clone(),
144                    success: inv.success,
145                    output: inv.output.clone(),
146                    error: inv.error.clone(),
147                })
148                .collect();
149
150            let response = QueryResponse {
151                response: output.response,
152                session_id,
153                agent: agent_name,
154                tool_calls,
155                metadata: ResponseMetadata {
156                    timestamp: current_timestamp(),
157                    model: state.config.model.provider.clone(),
158                    processing_time_ms: processing_time,
159                    run_id: output.run_id,
160                },
161            };
162
163            Json(response).into_response()
164        }
165        Err(e) => (
166            StatusCode::INTERNAL_SERVER_ERROR,
167            Json(ErrorResponse::new("execution_error", e.to_string())),
168        )
169            .into_response(),
170    }
171}
172
173/// Streaming query endpoint
174pub async fn stream_query(
175    State(state): State<AppState>,
176    Json(request): Json<QueryRequest>,
177) -> Response {
178    let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
179    let session_id = request
180        .session_id
181        .unwrap_or_else(|| format!("api_{}", uuid_v4()));
182
183    // Create agent
184    let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
185
186    let agent = match agent_result {
187        Ok(agent) => agent,
188        Err(e) => {
189            return (
190                StatusCode::BAD_REQUEST,
191                Json(ErrorResponse::new("agent_error", e.to_string())),
192            )
193                .into_response();
194        }
195    };
196
197    // Create SSE stream
198    let agent = Arc::new(RwLock::new(agent));
199    let message = request.message.clone();
200    let session_id_clone = session_id.clone();
201    let agent_name_clone = agent_name.clone();
202    let model_id = state.config.model.provider.clone();
203
204    let sse_stream = stream! {
205        yield StreamChunk::Start {
206            session_id: session_id_clone.clone(),
207            agent: agent_name_clone.clone(),
208        };
209
210        let start = Instant::now();
211        let mut agent_lock = agent.write().await;
212
213        match agent_lock.run_step(&message).await {
214            Ok(output) => {
215                yield StreamChunk::Content { text: output.response.clone() };
216
217                for invocation in output.tool_invocations {
218                    yield StreamChunk::ToolCall {
219                        name: invocation.name.clone(),
220                        arguments: invocation.arguments.clone(),
221                    };
222                    yield StreamChunk::ToolResult {
223                        name: invocation.name.clone(),
224                        result: json!({
225                            "success": invocation.success,
226                            "output": invocation.output,
227                            "error": invocation.error,
228                        }),
229                    };
230                }
231
232                yield StreamChunk::End {
233                    metadata: ResponseMetadata {
234                        timestamp: current_timestamp(),
235                        model: model_id.clone(),
236                        processing_time_ms: start.elapsed().as_millis() as u64,
237                        run_id: output.run_id,
238                    },
239                };
240            }
241            Err(e) => {
242                yield StreamChunk::Error {
243                    message: e.to_string(),
244                };
245            }
246        }
247    };
248
249    Sse::new(sse_stream.map(|chunk| {
250        let json = serde_json::to_string(&chunk).unwrap();
251        Ok::<_, Infallible>(Event::default().data(json))
252    }))
253    .into_response()
254}
255
256/// Helper: Create agent instance
257async fn create_agent(
258    state: &AppState,
259    agent_name: &str,
260    session_id: &str,
261    _temperature: Option<f32>,
262) -> anyhow::Result<AgentCore> {
263    // Get the agent profile
264    let profile = state
265        .agent_registry
266        .get(agent_name)
267        .ok_or_else(|| anyhow::anyhow!("Agent '{}' not found", agent_name))?;
268
269    // Build the agent using the builder with config
270    let agent = AgentBuilder::new()
271        .with_profile(profile)
272        .with_config(state.config.clone())
273        .with_session_id(session_id)
274        .with_agent_name(agent_name.to_string())
275        .with_tool_registry(state.tool_registry.clone())
276        .with_persistence(state.persistence.clone())
277        .build()?;
278
279    Ok(agent)
280}
281
/// Helper: Generate UUID v4
///
/// Returns a random UUID, version 4 / RFC 4122 variant, in the canonical
/// lowercase hyphenated 8-4-4-4-12 form (36 characters), e.g.
/// `0f8c3a6e-5b1d-4c2e-9a7b-3d4e5f607182`.
///
/// The 128 random bits come from the standard library's randomly keyed
/// `RandomState` hasher, applied to the current time plus a process-wide
/// counter. The counter guarantees distinct hash inputs even within one clock
/// tick.
///
/// This is NOT cryptographically secure, so do not use the result as a secret
/// or an auth token.
fn uuid_v4() -> String {
    static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let seq = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let now = SystemTime::now();
    let hasher = std::collections::hash_map::RandomState::new();
    let hi = std::hash::BuildHasher::hash_one(&hasher, (now, seq, 0u8));
    let lo = std::hash::BuildHasher::hash_one(&hasher, (now, seq, 1u8));

    let mut bits = (u128::from(hi) << 64) | u128::from(lo);
    // Version field: hex digit 12 (bits 76..80) must be 0b0100.
    bits = (bits & !(0xF_u128 << 76)) | (0x4_u128 << 76);
    // Variant field: top two bits of hex digit 16 (bits 62..64) must be 0b10.
    bits = (bits & !(0x3_u128 << 62)) | (0x2_u128 << 62);

    // Zero-padded so the string is always exactly 32 hex digits before grouping.
    let hex = format!("{bits:032x}");
    format!(
        "{}-{}-{}-{}-{}",
        &hex[..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..]
    )
}
288
289/// Helper: Get current timestamp
290fn current_timestamp() -> String {
291    let now = SystemTime::now()
292        .duration_since(UNIX_EPOCH)
293        .unwrap()
294        .as_secs();
295    chrono::DateTime::from_timestamp(now as i64, 0)
296        .unwrap()
297        .to_rfc3339()
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uuid_generation() {
        // Only non-emptiness is checked. Distinctness is probabilistic, so it is
        // deliberately not asserted here.
        for id in [uuid_v4(), uuid_v4()] {
            assert!(!id.is_empty());
        }
    }

    #[test]
    fn test_timestamp_format() {
        let ts = current_timestamp();
        // RFC 3339 requires a date/time separator and a `Z` or numeric UTC offset.
        assert!(ts.contains('T'));
        assert!(ts.contains(['Z', '+']));
    }
}