// spec_ai/spec_ai_api/api/handlers.rs

1/// API request handlers
2use crate::spec_ai_api::agent::builder::AgentBuilder;
3use crate::spec_ai_api::agent::core::AgentCore;
4use crate::spec_ai_api::api::auth::{AuthService, TokenRequest, TokenResponse};
5use crate::spec_ai_api::api::mesh::{MeshRegistry, MeshState};
6use crate::spec_ai_api::api::models::*;
7use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
8use crate::spec_ai_api::persistence::Persistence;
9use crate::spec_ai_api::tools::ToolRegistry;
10use async_stream::stream;
11use axum::{
12    extract::{Json, State},
13    http::StatusCode,
14    response::{
15        IntoResponse, Response,
16        sse::{Event, Sse},
17    },
18};
19use futures::StreamExt;
20use serde_json::json;
21use std::convert::Infallible;
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24use std::time::{Instant, SystemTime, UNIX_EPOCH};
25use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
26use tokio::sync::RwLock;
27
/// Page size used by the `search` endpoint when the request does not specify one.
const DEFAULT_PAGE_SIZE: usize = 10;
/// Upper bound applied to a client-requested `search` page size.
const MAX_PAGE_SIZE: usize = 25;
/// Cap on the number of search hits retrieved per query; pages past this cap are empty.
const MAX_TOTAL_RESULTS: usize = 100;
31
32/// Shared application state
33#[derive(Clone)]
34pub struct AppState {
35    pub persistence: Persistence,
36    pub agent_registry: Arc<AgentRegistry>,
37    pub tool_registry: Arc<ToolRegistry>,
38    pub config: AppConfig,
39    pub start_time: Instant,
40    pub mesh_registry: MeshRegistry,
41    pub auth_service: Arc<AuthService>,
42}
43
44impl AppState {
45    pub fn new(
46        persistence: Persistence,
47        agent_registry: Arc<AgentRegistry>,
48        tool_registry: Arc<ToolRegistry>,
49        config: AppConfig,
50    ) -> Self {
51        // Initialize auth service from config
52        let auth_service = AuthService::new(
53            config.auth.credentials_file.as_deref(),
54            config.auth.token_secret.as_deref(),
55            Some(config.auth.token_expiry_secs),
56            config.auth.enabled,
57        )
58        .unwrap_or_else(|e| {
59            tracing::warn!("Failed to initialize auth service: {}. Auth disabled.", e);
60            AuthService::disabled()
61        });
62
63        Self {
64            persistence: persistence.clone(),
65            agent_registry,
66            tool_registry,
67            config,
68            start_time: Instant::now(),
69            mesh_registry: MeshRegistry::with_persistence(persistence),
70            auth_service: Arc::new(auth_service),
71        }
72    }
73}
74
75impl MeshState for AppState {
76    fn mesh_registry(&self) -> &MeshRegistry {
77        &self.mesh_registry
78    }
79}
80
81/// Health check endpoint
82pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
83    let uptime = state.start_time.elapsed().as_secs();
84    let active_sessions = match state.persistence.list_sessions() {
85        Ok(sessions) => sessions.len(),
86        Err(e) => {
87            tracing::warn!("Failed to count active sessions: {}", e);
88            0
89        }
90    };
91
92    let response = HealthResponse {
93        status: "healthy".to_string(),
94        version: env!("CARGO_PKG_VERSION").to_string(),
95        uptime_seconds: uptime,
96        active_sessions,
97    };
98
99    Json(response)
100}
101
102/// List available agents
103pub async fn list_agents(State(state): State<AppState>) -> impl IntoResponse {
104    let agent_names = state.agent_registry.list();
105    let mut agent_infos = Vec::new();
106
107    for name in agent_names {
108        if let Some(profile) = state.agent_registry.get(&name) {
109            agent_infos.push(AgentInfo {
110                id: name,
111                description: profile.prompt.unwrap_or_default(),
112                allowed_tools: profile.allowed_tools.unwrap_or_default(),
113                denied_tools: profile.denied_tools.unwrap_or_default(),
114            });
115        }
116    }
117
118    Json(AgentListResponse {
119        agents: agent_infos,
120    })
121    .into_response()
122}
123
124/// Query endpoint - process a message and return response
125pub async fn query(State(state): State<AppState>, Json(request): Json<QueryRequest>) -> Response {
126    // If streaming requested, delegate to streaming handler
127    if request.stream {
128        return (
129            StatusCode::BAD_REQUEST,
130            Json(ErrorResponse::new(
131                "invalid_request",
132                "Streaming not supported on /query endpoint. Use /stream instead.",
133            )),
134        )
135            .into_response();
136    }
137
138    // Determine which agent to use
139    let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
140
141    // Get or create session ID
142    let session_id = request
143        .session_id
144        .unwrap_or_else(|| format!("api_{}", uuid_v4()));
145
146    // Create agent instance
147    let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
148
149    let mut agent = match agent_result {
150        Ok(agent) => agent,
151        Err(e) => {
152            return (
153                StatusCode::BAD_REQUEST,
154                Json(ErrorResponse::new("agent_error", e.to_string())),
155            )
156                .into_response();
157        }
158    };
159
160    // Process the message
161    let start = Instant::now();
162
163    match agent.run_step(&request.message).await {
164        Ok(output) => {
165            let processing_time = start.elapsed().as_millis() as u64;
166            let tool_calls: Vec<ToolCallInfo> = output
167                .tool_invocations
168                .iter()
169                .map(|inv| ToolCallInfo {
170                    name: inv.name.clone(),
171                    arguments: inv.arguments.clone(),
172                    success: inv.success,
173                    output: inv.output.clone(),
174                    error: inv.error.clone(),
175                })
176                .collect();
177
178            let response = QueryResponse {
179                response: output.response,
180                session_id,
181                agent: agent_name,
182                tool_calls,
183                metadata: ResponseMetadata {
184                    timestamp: current_timestamp(),
185                    model: state.config.model.provider.clone(),
186                    processing_time_ms: processing_time,
187                    run_id: output.run_id,
188                },
189            };
190
191            Json(response).into_response()
192        }
193        Err(e) => (
194            StatusCode::INTERNAL_SERVER_ERROR,
195            Json(ErrorResponse::new("execution_error", e.to_string())),
196        )
197            .into_response(),
198    }
199}
200
201/// Streaming query endpoint
202pub async fn stream_query(
203    State(state): State<AppState>,
204    Json(request): Json<QueryRequest>,
205) -> Response {
206    let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
207    let session_id = request
208        .session_id
209        .unwrap_or_else(|| format!("api_{}", uuid_v4()));
210
211    // Create agent
212    let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
213
214    let agent = match agent_result {
215        Ok(agent) => agent,
216        Err(e) => {
217            return (
218                StatusCode::BAD_REQUEST,
219                Json(ErrorResponse::new("agent_error", e.to_string())),
220            )
221                .into_response();
222        }
223    };
224
225    // Create SSE stream
226    let agent = Arc::new(RwLock::new(agent));
227    let message = request.message.clone();
228    let session_id_clone = session_id.clone();
229    let agent_name_clone = agent_name.clone();
230    let model_id = state.config.model.provider.clone();
231
232    let sse_stream = stream! {
233        yield StreamChunk::Start {
234            session_id: session_id_clone.clone(),
235            agent: agent_name_clone.clone(),
236        };
237
238        let start = Instant::now();
239        let mut agent_lock = agent.write().await;
240
241        match agent_lock.run_step(&message).await {
242            Ok(output) => {
243                yield StreamChunk::Content { text: output.response.clone() };
244
245                for invocation in output.tool_invocations {
246                    yield StreamChunk::ToolCall {
247                        name: invocation.name.clone(),
248                        arguments: invocation.arguments.clone(),
249                    };
250                    yield StreamChunk::ToolResult {
251                        name: invocation.name.clone(),
252                        result: json!({
253                            "success": invocation.success,
254                            "output": invocation.output,
255                            "error": invocation.error,
256                        }),
257                    };
258                }
259
260                yield StreamChunk::End {
261                    metadata: ResponseMetadata {
262                        timestamp: current_timestamp(),
263                        model: model_id.clone(),
264                        processing_time_ms: start.elapsed().as_millis() as u64,
265                        run_id: output.run_id,
266                    },
267                };
268            }
269            Err(e) => {
270                yield StreamChunk::Error {
271                    message: e.to_string(),
272                };
273            }
274        }
275    };
276
277    Sse::new(sse_stream.map(|chunk| {
278        let json = serde_json::to_string(&chunk).unwrap();
279        Ok::<_, Infallible>(Event::default().data(json))
280    }))
281    .into_response()
282}
283
284/// Helper: Create agent instance
285async fn create_agent(
286    state: &AppState,
287    agent_name: &str,
288    session_id: &str,
289    _temperature: Option<f32>,
290) -> anyhow::Result<AgentCore> {
291    // Get the agent profile
292    let profile = state
293        .agent_registry
294        .get(agent_name)
295        .ok_or_else(|| anyhow::anyhow!("Agent '{}' not found", agent_name))?;
296
297    // Build the agent using the builder with config
298    let agent = AgentBuilder::new()
299        .with_profile(profile)
300        .with_config(state.config.clone())
301        .with_session_id(session_id)
302        .with_agent_name(agent_name.to_string())
303        .with_tool_registry(state.tool_registry.clone())
304        .with_persistence(state.persistence.clone())
305        .build()?;
306
307    Ok(agent)
308}
309
/// Helper: generate a random identifier formatted as an RFC 4122 version-4 UUID.
///
/// The format is `xxxxxxxx-xxxx-4xxx-Nxxx-xxxxxxxxxxxx` in lowercase hex, where `N` is
/// one of `8`, `9`, `a`, `b`.
///
/// The 128 bits come from two independently keyed `RandomState` hashers applied to the
/// current time and a process-wide counter. The counter ensures that calls within the same
/// clock tick still hash different inputs. This is not intended as a cryptographically
/// secure generator; do not use the result as a secret.
fn uuid_v4() -> String {
    static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let seq = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let now = SystemTime::now();
    let hi = std::hash::BuildHasher::hash_one(
        &std::collections::hash_map::RandomState::new(),
        (now, seq, 0u8),
    );
    let lo = std::hash::BuildHasher::hash_one(
        &std::collections::hash_map::RandomState::new(),
        (now, seq, 1u8),
    );

    let mut bits = (u128::from(hi) << 64) | u128::from(lo);
    // Version: the high nibble of byte 6 (bits 76..80) must be 0b0100.
    bits = (bits & !(0xF_u128 << 76)) | (0x4_u128 << 76);
    // Variant: the top two bits of byte 8 (bits 62..64) must be 0b10.
    bits = (bits & !(0x3_u128 << 62)) | (0x2_u128 << 62);

    let hex = format!("{:032x}", bits);
    format!(
        "{}-{}-{}-{}-{}",
        &hex[0..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..32]
    )
}
316
317/// Helper: Get current timestamp
318fn current_timestamp() -> String {
319    let now = SystemTime::now()
320        .duration_since(UNIX_EPOCH)
321        .unwrap()
322        .as_secs();
323    chrono::DateTime::from_timestamp(now as i64, 0)
324        .unwrap()
325        .to_rfc3339()
326}
327
328/// Token generation endpoint - exchange username/password for bearer token
329pub async fn generate_token(
330    State(state): State<AppState>,
331    Json(request): Json<TokenRequest>,
332) -> Response {
333    // Check if auth is enabled
334    if !state.auth_service.is_enabled() {
335        return (
336            StatusCode::SERVICE_UNAVAILABLE,
337            Json(ErrorResponse::new(
338                "auth_disabled",
339                "Authentication is not enabled on this server",
340            )),
341        )
342            .into_response();
343    }
344
345    // Verify credentials
346    if !state
347        .auth_service
348        .verify_password(&request.username, &request.password)
349    {
350        return (
351            StatusCode::UNAUTHORIZED,
352            Json(ErrorResponse::new(
353                "invalid_credentials",
354                "Invalid username or password",
355            )),
356        )
357            .into_response();
358    }
359
360    // Generate token
361    match state.auth_service.generate_token(&request.username) {
362        Ok(token) => {
363            let response = TokenResponse {
364                token,
365                token_type: "Bearer".to_string(),
366                expires_in: state.config.auth.token_expiry_secs,
367            };
368            Json(response).into_response()
369        }
370        Err(e) => {
371            tracing::error!("Failed to generate token: {}", e);
372            (
373                StatusCode::INTERNAL_SERVER_ERROR,
374                Json(ErrorResponse::new(
375                    "token_error",
376                    "Failed to generate token",
377                )),
378            )
379                .into_response()
380        }
381    }
382}
383
384/// Password hash generation endpoint - utility for creating password hashes
385/// This endpoint can be used to generate hashes for the credentials file
386pub async fn hash_password(
387    State(state): State<AppState>,
388    Json(body): Json<serde_json::Value>,
389) -> Response {
390    // Only allow if auth is enabled (to prevent abuse)
391    if !state.auth_service.is_enabled() {
392        return (
393            StatusCode::SERVICE_UNAVAILABLE,
394            Json(ErrorResponse::new(
395                "auth_disabled",
396                "Authentication is not enabled on this server",
397            )),
398        )
399            .into_response();
400    }
401
402    let Some(password) = body.get("password").and_then(|v| v.as_str()) else {
403        return (
404            StatusCode::BAD_REQUEST,
405            Json(ErrorResponse::new(
406                "invalid_request",
407                "Missing 'password' field in request body",
408            )),
409        )
410            .into_response();
411    };
412
413    match AuthService::hash_password(password) {
414        Ok(hash) => Json(json!({ "password_hash": hash })).into_response(),
415        Err(e) => {
416            tracing::error!("Failed to hash password: {}", e);
417            (
418                StatusCode::INTERNAL_SERVER_ERROR,
419                Json(ErrorResponse::new("hash_error", "Failed to hash password")),
420            )
421                .into_response()
422        }
423    }
424}
425
426/// Semantic code search endpoint
427pub async fn search(Json(request): Json<SearchRequest>) -> Response {
428    // Validate query
429    if request.query.trim().is_empty() {
430        return (
431            StatusCode::BAD_REQUEST,
432            Json(ErrorResponse::new(
433                "invalid_request",
434                "Query cannot be empty",
435            )),
436        )
437            .into_response();
438    }
439
440    // Resolve root path
441    let root = request
442        .root
443        .as_ref()
444        .map(PathBuf::from)
445        .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
446
447    if !root.exists() {
448        return (
449            StatusCode::BAD_REQUEST,
450            Json(ErrorResponse::new(
451                "invalid_root",
452                format!("Search root {} does not exist", root.display()),
453            )),
454        )
455            .into_response();
456    }
457
458    // Calculate pagination parameters
459    let page_size = request
460        .page_size
461        .unwrap_or(DEFAULT_PAGE_SIZE)
462        .clamp(1, MAX_PAGE_SIZE);
463    let page = request.page;
464    let offset = page * page_size;
465
466    // Ensure embeddings exist
467    let embeddings_path = match ensure_embeddings(&root, request.refresh).await {
468        Ok(path) => path,
469        Err(e) => {
470            tracing::error!("Failed to generate embeddings: {}", e);
471            return (
472                StatusCode::INTERNAL_SERVER_ERROR,
473                Json(ErrorResponse::new(
474                    "embeddings_error",
475                    format!("Failed to generate embeddings: {}", e),
476                )),
477            )
478                .into_response();
479        }
480    };
481
482    // Load searcher and run search
483    let mut searcher = match SemanticSearch::new(&embeddings_path) {
484        Ok(s) => s,
485        Err(e) => {
486            tracing::error!("Failed to load embeddings: {}", e);
487            return (
488                StatusCode::INTERNAL_SERVER_ERROR,
489                Json(ErrorResponse::new(
490                    "search_error",
491                    format!("Failed to load embeddings database: {}", e),
492                )),
493            )
494                .into_response();
495        }
496    };
497
498    // Fetch enough results to determine total and extract current page
499    let fetch_count = MAX_TOTAL_RESULTS.min(offset + page_size);
500    let hits = match searcher.search(&request.query, fetch_count) {
501        Ok(h) => h,
502        Err(e) => {
503            tracing::error!("Search failed: {}", e);
504            return (
505                StatusCode::INTERNAL_SERVER_ERROR,
506                Json(ErrorResponse::new(
507                    "search_error",
508                    format!("Search failed: {}", e),
509                )),
510            )
511                .into_response();
512        }
513    };
514
515    let total_results = hits.len();
516    let total_pages = total_results.div_ceil(page_size);
517
518    // Extract current page results
519    let page_results: Vec<SearchResult> = hits
520        .into_iter()
521        .skip(offset)
522        .take(page_size)
523        .map(|hit| {
524            let mut snippet = hit.content;
525            if snippet.len() > 480 {
526                snippet.truncate(480);
527                snippet.push_str("...[truncated]");
528            }
529            SearchResult {
530                path: hit.file_path,
531                similarity: hit.similarity,
532                snippet,
533            }
534        })
535        .collect();
536
537    let response = SearchResponse {
538        query: request.query,
539        root: root.display().to_string(),
540        page,
541        page_size,
542        total_results,
543        total_pages,
544        results: page_results,
545    };
546
547    Json(response).into_response()
548}
549
550/// Helper: Ensure embeddings database exists
551async fn ensure_embeddings(root: &Path, refresh: bool) -> anyhow::Result<PathBuf> {
552    let embeddings_path = root.join(".spec-ai").join("code_search_embeddings.json");
553
554    if embeddings_path.exists() && !refresh {
555        return Ok(embeddings_path);
556    }
557
558    if let Some(parent) = embeddings_path.parent() {
559        std::fs::create_dir_all(parent)?;
560    }
561
562    let options = JsonDatabaseOptions {
563        dir: root.to_path_buf(),
564        output_file_path: embeddings_path.clone(),
565        verbose: false,
566        chunker_config: Default::default(),
567        max_concurrent_files: 4,
568        ..Default::default()
569    };
570
571    let generator = JsonDatabaseGenerator::new(options)?;
572    generator.generate_database().await?;
573
574    Ok(embeddings_path)
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// `uuid_v4` yields non-empty IDs made only of hex digits (and dashes, if any), so they
    /// can be embedded safely in `api_<id>` session IDs.
    #[test]
    fn test_uuid_generation() {
        let uuid1 = uuid_v4();
        let uuid2 = uuid_v4();

        for id in [&uuid1, &uuid2] {
            assert!(!id.is_empty());
            assert!(id.chars().all(|c| c.is_ascii_hexdigit() || c == '-'));
        }
        // Distinctness is probabilistic, so it is deliberately not asserted.
    }

    /// `current_timestamp` yields a parseable RFC 3339 string with whole-second precision.
    #[test]
    fn test_timestamp_format() {
        let ts = current_timestamp();
        assert!(ts.contains('T'));
        assert!(ts.contains('Z') || ts.contains('+'));
        assert!(chrono::DateTime::parse_from_rfc3339(&ts).is_ok());
        assert!(!ts.contains('.'), "expected no fractional seconds in {ts}");
    }
}