1use crate::spec_ai_api::agent::builder::AgentBuilder;
3use crate::spec_ai_api::agent::core::AgentCore;
4use crate::spec_ai_api::api::auth::{AuthService, TokenRequest, TokenResponse};
5use crate::spec_ai_api::api::mesh::{MeshRegistry, MeshState};
6use crate::spec_ai_api::api::models::*;
7use crate::spec_ai_api::config::{AgentRegistry, AppConfig};
8use crate::spec_ai_api::persistence::Persistence;
9use crate::spec_ai_api::tools::ToolRegistry;
10use async_stream::stream;
11use axum::{
12 extract::{Json, State},
13 http::StatusCode,
14 response::{
15 IntoResponse, Response,
16 sse::{Event, Sse},
17 },
18};
19use futures::StreamExt;
20use serde_json::json;
21use std::convert::Infallible;
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24use std::time::{Instant, SystemTime, UNIX_EPOCH};
25use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
26use tokio::sync::RwLock;
27
/// Search results returned per page when a request omits `page_size`.
const DEFAULT_PAGE_SIZE: usize = 10;
/// Largest `page_size` a search request may ask for; bigger values are clamped down.
const MAX_PAGE_SIZE: usize = 25;
/// Upper bound on the number of hits requested from the semantic index per search.
const MAX_TOTAL_RESULTS: usize = 100;
31
32#[derive(Clone)]
34pub struct AppState {
35 pub persistence: Persistence,
36 pub agent_registry: Arc<AgentRegistry>,
37 pub tool_registry: Arc<ToolRegistry>,
38 pub config: AppConfig,
39 pub start_time: Instant,
40 pub mesh_registry: MeshRegistry,
41 pub auth_service: Arc<AuthService>,
42}
43
44impl AppState {
45 pub fn new(
46 persistence: Persistence,
47 agent_registry: Arc<AgentRegistry>,
48 tool_registry: Arc<ToolRegistry>,
49 config: AppConfig,
50 ) -> Self {
51 let auth_service = AuthService::new(
53 config.auth.credentials_file.as_deref(),
54 config.auth.token_secret.as_deref(),
55 Some(config.auth.token_expiry_secs),
56 config.auth.enabled,
57 )
58 .unwrap_or_else(|e| {
59 tracing::warn!("Failed to initialize auth service: {}. Auth disabled.", e);
60 AuthService::disabled()
61 });
62
63 Self {
64 persistence: persistence.clone(),
65 agent_registry,
66 tool_registry,
67 config,
68 start_time: Instant::now(),
69 mesh_registry: MeshRegistry::with_persistence(persistence),
70 auth_service: Arc::new(auth_service),
71 }
72 }
73}
74
75impl MeshState for AppState {
76 fn mesh_registry(&self) -> &MeshRegistry {
77 &self.mesh_registry
78 }
79}
80
81pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
83 let uptime = state.start_time.elapsed().as_secs();
84 let active_sessions = match state.persistence.list_sessions() {
85 Ok(sessions) => sessions.len(),
86 Err(e) => {
87 tracing::warn!("Failed to count active sessions: {}", e);
88 0
89 }
90 };
91
92 let response = HealthResponse {
93 status: "healthy".to_string(),
94 version: env!("CARGO_PKG_VERSION").to_string(),
95 uptime_seconds: uptime,
96 active_sessions,
97 };
98
99 Json(response)
100}
101
102pub async fn list_agents(State(state): State<AppState>) -> impl IntoResponse {
104 let agent_names = state.agent_registry.list();
105 let mut agent_infos = Vec::new();
106
107 for name in agent_names {
108 if let Some(profile) = state.agent_registry.get(&name) {
109 agent_infos.push(AgentInfo {
110 id: name,
111 description: profile.prompt.unwrap_or_default(),
112 allowed_tools: profile.allowed_tools.unwrap_or_default(),
113 denied_tools: profile.denied_tools.unwrap_or_default(),
114 });
115 }
116 }
117
118 Json(AgentListResponse {
119 agents: agent_infos,
120 })
121 .into_response()
122}
123
124pub async fn query(State(state): State<AppState>, Json(request): Json<QueryRequest>) -> Response {
126 if request.stream {
128 return (
129 StatusCode::BAD_REQUEST,
130 Json(ErrorResponse::new(
131 "invalid_request",
132 "Streaming not supported on /query endpoint. Use /stream instead.",
133 )),
134 )
135 .into_response();
136 }
137
138 let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
140
141 let session_id = request
143 .session_id
144 .unwrap_or_else(|| format!("api_{}", uuid_v4()));
145
146 let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
148
149 let mut agent = match agent_result {
150 Ok(agent) => agent,
151 Err(e) => {
152 return (
153 StatusCode::BAD_REQUEST,
154 Json(ErrorResponse::new("agent_error", e.to_string())),
155 )
156 .into_response();
157 }
158 };
159
160 let start = Instant::now();
162
163 match agent.run_step(&request.message).await {
164 Ok(output) => {
165 let processing_time = start.elapsed().as_millis() as u64;
166 let tool_calls: Vec<ToolCallInfo> = output
167 .tool_invocations
168 .iter()
169 .map(|inv| ToolCallInfo {
170 name: inv.name.clone(),
171 arguments: inv.arguments.clone(),
172 success: inv.success,
173 output: inv.output.clone(),
174 error: inv.error.clone(),
175 })
176 .collect();
177
178 let response = QueryResponse {
179 response: output.response,
180 session_id,
181 agent: agent_name,
182 tool_calls,
183 metadata: ResponseMetadata {
184 timestamp: current_timestamp(),
185 model: state.config.model.provider.clone(),
186 processing_time_ms: processing_time,
187 run_id: output.run_id,
188 },
189 };
190
191 Json(response).into_response()
192 }
193 Err(e) => (
194 StatusCode::INTERNAL_SERVER_ERROR,
195 Json(ErrorResponse::new("execution_error", e.to_string())),
196 )
197 .into_response(),
198 }
199}
200
201pub async fn stream_query(
203 State(state): State<AppState>,
204 Json(request): Json<QueryRequest>,
205) -> Response {
206 let agent_name = request.agent.unwrap_or_else(|| "default".to_string());
207 let session_id = request
208 .session_id
209 .unwrap_or_else(|| format!("api_{}", uuid_v4()));
210
211 let agent_result = create_agent(&state, &agent_name, &session_id, request.temperature).await;
213
214 let agent = match agent_result {
215 Ok(agent) => agent,
216 Err(e) => {
217 return (
218 StatusCode::BAD_REQUEST,
219 Json(ErrorResponse::new("agent_error", e.to_string())),
220 )
221 .into_response();
222 }
223 };
224
225 let agent = Arc::new(RwLock::new(agent));
227 let message = request.message.clone();
228 let session_id_clone = session_id.clone();
229 let agent_name_clone = agent_name.clone();
230 let model_id = state.config.model.provider.clone();
231
232 let sse_stream = stream! {
233 yield StreamChunk::Start {
234 session_id: session_id_clone.clone(),
235 agent: agent_name_clone.clone(),
236 };
237
238 let start = Instant::now();
239 let mut agent_lock = agent.write().await;
240
241 match agent_lock.run_step(&message).await {
242 Ok(output) => {
243 yield StreamChunk::Content { text: output.response.clone() };
244
245 for invocation in output.tool_invocations {
246 yield StreamChunk::ToolCall {
247 name: invocation.name.clone(),
248 arguments: invocation.arguments.clone(),
249 };
250 yield StreamChunk::ToolResult {
251 name: invocation.name.clone(),
252 result: json!({
253 "success": invocation.success,
254 "output": invocation.output,
255 "error": invocation.error,
256 }),
257 };
258 }
259
260 yield StreamChunk::End {
261 metadata: ResponseMetadata {
262 timestamp: current_timestamp(),
263 model: model_id.clone(),
264 processing_time_ms: start.elapsed().as_millis() as u64,
265 run_id: output.run_id,
266 },
267 };
268 }
269 Err(e) => {
270 yield StreamChunk::Error {
271 message: e.to_string(),
272 };
273 }
274 }
275 };
276
277 Sse::new(sse_stream.map(|chunk| {
278 let json = serde_json::to_string(&chunk).unwrap();
279 Ok::<_, Infallible>(Event::default().data(json))
280 }))
281 .into_response()
282}
283
284async fn create_agent(
286 state: &AppState,
287 agent_name: &str,
288 session_id: &str,
289 _temperature: Option<f32>,
290) -> anyhow::Result<AgentCore> {
291 let profile = state
293 .agent_registry
294 .get(agent_name)
295 .ok_or_else(|| anyhow::anyhow!("Agent '{}' not found", agent_name))?;
296
297 let agent = AgentBuilder::new()
299 .with_profile(profile)
300 .with_config(state.config.clone())
301 .with_session_id(session_id)
302 .with_agent_name(agent_name.to_string())
303 .with_tool_registry(state.tool_registry.clone())
304 .with_persistence(state.persistence.clone())
305 .build()?;
306
307 Ok(agent)
308}
309
/// Generates a random identifier laid out as an RFC 4122 / RFC 9562 version-4
/// UUID: `xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx`, lowercase hex, with `y` in `8..=b`.
///
/// The 128 bits come from two std `RandomState` hashers applied to the current
/// time. This avoids pulling in a `uuid` crate. The randomness is adequate for
/// unique session ids but is not guaranteed to be cryptographically secure.
fn uuid_v4() -> String {
    let now = SystemTime::now();
    // Each `RandomState::new()` is initialized with different random keys, so
    // the two halves differ even though they hash the same timestamp.
    let high = std::hash::BuildHasher::hash_one(
        &std::collections::hash_map::RandomState::new(),
        now,
    );
    let low = std::hash::BuildHasher::hash_one(
        &std::collections::hash_map::RandomState::new(),
        now,
    );

    let mut bits = (u128::from(high) << 64) | u128::from(low);
    // Version nibble (high nibble of byte 6, bits 76..80) = 0b0100.
    bits = (bits & !(0xF_u128 << 76)) | (0x4_u128 << 76);
    // Variant (top two bits of byte 8, bits 62..64) = 0b10.
    bits = (bits & !(0x3_u128 << 62)) | (0x2_u128 << 62);

    let hex = format!("{:032x}", bits);
    format!(
        "{}-{}-{}-{}-{}",
        &hex[..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..]
    )
}
316
317fn current_timestamp() -> String {
319 let now = SystemTime::now()
320 .duration_since(UNIX_EPOCH)
321 .unwrap()
322 .as_secs();
323 chrono::DateTime::from_timestamp(now as i64, 0)
324 .unwrap()
325 .to_rfc3339()
326}
327
328pub async fn generate_token(
330 State(state): State<AppState>,
331 Json(request): Json<TokenRequest>,
332) -> Response {
333 if !state.auth_service.is_enabled() {
335 return (
336 StatusCode::SERVICE_UNAVAILABLE,
337 Json(ErrorResponse::new(
338 "auth_disabled",
339 "Authentication is not enabled on this server",
340 )),
341 )
342 .into_response();
343 }
344
345 if !state
347 .auth_service
348 .verify_password(&request.username, &request.password)
349 {
350 return (
351 StatusCode::UNAUTHORIZED,
352 Json(ErrorResponse::new(
353 "invalid_credentials",
354 "Invalid username or password",
355 )),
356 )
357 .into_response();
358 }
359
360 match state.auth_service.generate_token(&request.username) {
362 Ok(token) => {
363 let response = TokenResponse {
364 token,
365 token_type: "Bearer".to_string(),
366 expires_in: state.config.auth.token_expiry_secs,
367 };
368 Json(response).into_response()
369 }
370 Err(e) => {
371 tracing::error!("Failed to generate token: {}", e);
372 (
373 StatusCode::INTERNAL_SERVER_ERROR,
374 Json(ErrorResponse::new(
375 "token_error",
376 "Failed to generate token",
377 )),
378 )
379 .into_response()
380 }
381 }
382}
383
384pub async fn hash_password(
387 State(state): State<AppState>,
388 Json(body): Json<serde_json::Value>,
389) -> Response {
390 if !state.auth_service.is_enabled() {
392 return (
393 StatusCode::SERVICE_UNAVAILABLE,
394 Json(ErrorResponse::new(
395 "auth_disabled",
396 "Authentication is not enabled on this server",
397 )),
398 )
399 .into_response();
400 }
401
402 let Some(password) = body.get("password").and_then(|v| v.as_str()) else {
403 return (
404 StatusCode::BAD_REQUEST,
405 Json(ErrorResponse::new(
406 "invalid_request",
407 "Missing 'password' field in request body",
408 )),
409 )
410 .into_response();
411 };
412
413 match AuthService::hash_password(password) {
414 Ok(hash) => Json(json!({ "password_hash": hash })).into_response(),
415 Err(e) => {
416 tracing::error!("Failed to hash password: {}", e);
417 (
418 StatusCode::INTERNAL_SERVER_ERROR,
419 Json(ErrorResponse::new("hash_error", "Failed to hash password")),
420 )
421 .into_response()
422 }
423 }
424}
425
426pub async fn search(Json(request): Json<SearchRequest>) -> Response {
428 if request.query.trim().is_empty() {
430 return (
431 StatusCode::BAD_REQUEST,
432 Json(ErrorResponse::new(
433 "invalid_request",
434 "Query cannot be empty",
435 )),
436 )
437 .into_response();
438 }
439
440 let root = request
442 .root
443 .as_ref()
444 .map(PathBuf::from)
445 .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
446
447 if !root.exists() {
448 return (
449 StatusCode::BAD_REQUEST,
450 Json(ErrorResponse::new(
451 "invalid_root",
452 format!("Search root {} does not exist", root.display()),
453 )),
454 )
455 .into_response();
456 }
457
458 let page_size = request
460 .page_size
461 .unwrap_or(DEFAULT_PAGE_SIZE)
462 .clamp(1, MAX_PAGE_SIZE);
463 let page = request.page;
464 let offset = page * page_size;
465
466 let embeddings_path = match ensure_embeddings(&root, request.refresh).await {
468 Ok(path) => path,
469 Err(e) => {
470 tracing::error!("Failed to generate embeddings: {}", e);
471 return (
472 StatusCode::INTERNAL_SERVER_ERROR,
473 Json(ErrorResponse::new(
474 "embeddings_error",
475 format!("Failed to generate embeddings: {}", e),
476 )),
477 )
478 .into_response();
479 }
480 };
481
482 let mut searcher = match SemanticSearch::new(&embeddings_path) {
484 Ok(s) => s,
485 Err(e) => {
486 tracing::error!("Failed to load embeddings: {}", e);
487 return (
488 StatusCode::INTERNAL_SERVER_ERROR,
489 Json(ErrorResponse::new(
490 "search_error",
491 format!("Failed to load embeddings database: {}", e),
492 )),
493 )
494 .into_response();
495 }
496 };
497
498 let fetch_count = MAX_TOTAL_RESULTS.min(offset + page_size);
500 let hits = match searcher.search(&request.query, fetch_count) {
501 Ok(h) => h,
502 Err(e) => {
503 tracing::error!("Search failed: {}", e);
504 return (
505 StatusCode::INTERNAL_SERVER_ERROR,
506 Json(ErrorResponse::new(
507 "search_error",
508 format!("Search failed: {}", e),
509 )),
510 )
511 .into_response();
512 }
513 };
514
515 let total_results = hits.len();
516 let total_pages = total_results.div_ceil(page_size);
517
518 let page_results: Vec<SearchResult> = hits
520 .into_iter()
521 .skip(offset)
522 .take(page_size)
523 .map(|hit| {
524 let mut snippet = hit.content;
525 if snippet.len() > 480 {
526 snippet.truncate(480);
527 snippet.push_str("...[truncated]");
528 }
529 SearchResult {
530 path: hit.file_path,
531 similarity: hit.similarity,
532 snippet,
533 }
534 })
535 .collect();
536
537 let response = SearchResponse {
538 query: request.query,
539 root: root.display().to_string(),
540 page,
541 page_size,
542 total_results,
543 total_pages,
544 results: page_results,
545 };
546
547 Json(response).into_response()
548}
549
550async fn ensure_embeddings(root: &Path, refresh: bool) -> anyhow::Result<PathBuf> {
552 let embeddings_path = root.join(".spec-ai").join("code_search_embeddings.json");
553
554 if embeddings_path.exists() && !refresh {
555 return Ok(embeddings_path);
556 }
557
558 if let Some(parent) = embeddings_path.parent() {
559 std::fs::create_dir_all(parent)?;
560 }
561
562 let options = JsonDatabaseOptions {
563 dir: root.to_path_buf(),
564 output_file_path: embeddings_path.clone(),
565 verbose: false,
566 chunker_config: Default::default(),
567 max_concurrent_files: 4,
568 ..Default::default()
569 };
570
571 let generator = JsonDatabaseGenerator::new(options)?;
572 generator.generate_database().await?;
573
574 Ok(embeddings_path)
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated ids must be non-empty and distinct; they seed new session ids.
    #[test]
    fn test_uuid_generation() {
        let uuid1 = uuid_v4();
        let uuid2 = uuid_v4();

        assert!(!uuid1.is_empty());
        assert!(!uuid2.is_empty());
        assert_ne!(uuid1, uuid2);
    }

    /// Timestamps must look like RFC 3339 and actually parse as RFC 3339.
    #[test]
    fn test_timestamp_format() {
        let ts = current_timestamp();
        assert!(ts.contains('T'));
        assert!(ts.contains('Z') || ts.contains('+'));
        assert!(chrono::DateTime::parse_from_rfc3339(&ts).is_ok());
    }
}