1use super::{state::ServerState, task::TaskResult, types::*};
6use crate::observability::{get_global_cost_tracker, LLMCallContext};
7use crate::planner::cost::CostCalculator;
8use crate::planner::tokens::TokenCounter;
9use crate::streaming::{create_text_stream, StreamHandler, TextChunk};
10use crate::types::database::IDatabaseAdapter;
11use crate::types::memory::MemoryQuery;
12use crate::{
13 types::{ChannelType, Content, Memory, Room},
14 AgentRuntime, ZoeyError, MessageProcessor, Result,
15};
16use axum::response::sse::{Event, Sse};
17use axum::{
18 extract::State,
19 http::StatusCode,
20 response::{IntoResponse, Response},
21 Json,
22};
23use futures_util::stream::{BoxStream, StreamExt};
24use reqwest::Client as HttpClient;
25use serde::Deserialize;
26use serde_json::Value as JsonValue;
27use std::sync::OnceLock;
28use std::time::{Duration, Instant};
29use std::{
30 collections::HashMap,
31 sync::{Arc, RwLock},
32};
33use tokio_stream::wrappers::ReceiverStream;
34use tracing::{debug, error, info, warn};
35use uuid::Uuid;
36
/// Run a single streaming chat job end-to-end and push the response to the
/// client through `stream_handler`.
///
/// Provider routing, driven by the `model_provider` runtime setting:
/// * `"openai"` — stream tokens from the OpenAI chat-completions SSE API;
/// * `ollama` / `local` / `local-llm` / `llama` / `llamacpp` — stream from a
///   local Ollama-compatible `/api/generate` endpoint;
/// * anything else — run the non-streaming `process_chat_task` and replay
///   its final text to the client in fixed-size chunks.
///
/// All failures are reported to the client via `send_error`; the function
/// itself never returns an error. Lock discipline: every `runtime`
/// read/write guard is confined to its own block so no std lock is held
/// across an `.await`.
///
/// NOTE(review): the OpenAI and Ollama branches duplicate most of the
/// prompt-building and post-stream bookkeeping; extracting shared helpers
/// would shrink this function considerably but must preserve the exact
/// guard-scoping shown here.
async fn run_chat_stream_job(
    runtime: Arc<RwLock<AgentRuntime>>,
    req_clone: ChatRequest,
    stream_handler: StreamHandler,
) {
    // Snapshot provider preference + registered provider names under a short
    // read guard (dropped before any await).
    let (provider, available_providers) = {
        let rt_guard = runtime.read().unwrap();
        let pref = rt_guard
            .get_setting("model_provider")
            .and_then(|v| v.as_str().map(|s| s.to_string()));
        let providers: Vec<String> = rt_guard.get_providers().iter().map(|p| p.name().to_string()).collect();
        (pref, providers)
    };
    info!(
        "INTERACTION_PROVIDER provider_pref={} available=[{}]",
        provider.clone().unwrap_or_else(|| "<none>".to_string()),
        available_providers.join(", ")
    );

    // ------------------------- OpenAI branch -------------------------
    if provider
        .as_deref()
        .map(|s| s.eq_ignore_ascii_case("openai"))
        .unwrap_or(false)
    {
        let entity_id = req_clone.entity_id.unwrap_or_else(Uuid::new_v4);
        let (agent_id, adapter) = {
            let rt = runtime.read().unwrap();
            let adapter = rt.adapter.read().unwrap().clone();
            (rt.agent_id, adapter)
        };
        // Best-effort transcript of the last few messages in this room.
        let recent_conversation = if let Some(ref adapter) = adapter {
            fetch_recent_conversation(
                adapter.as_ref(),
                req_clone.room_id,
                agent_id,
                &{
                    let rt = runtime.read().unwrap();
                    rt.character.name.clone()
                },
                5, )
            .await
        } else {
            String::new()
        };
        // Snapshot character identity, UI prefs and the two most recent
        // prompts recorded for this room.
        let (character_name, character_bio, ui_tone, ui_verbosity, last, prev) = {
            let rt = runtime.read().unwrap();
            let name = rt.character.name.clone();
            let bio = rt.character.bio.clone().join(" ");
            let tone = rt
                .get_setting("ui:tone")
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            let verbosity = rt.get_setting("ui:verbosity").map(|v| v.to_string());
            let last_key = format!("ui:lastPrompt:{}:last", req_clone.room_id);
            let prev_key = format!("ui:lastPrompt:{}:prev", req_clone.room_id);
            let last = rt
                .get_setting(&last_key)
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            let prev = rt
                .get_setting(&prev_key)
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            (name, bio, tone, verbosity, last, prev)
        };
        // Populate the template state consumed by the prompt composer.
        let mut state = crate::types::State::new();
        state.set_value(
            "CHARACTER",
            format!("Name: {}\nBio: {}", character_name, character_bio),
        );
        if let Some(t) = ui_tone {
            state.set_value("UI_TONE", t);
        }
        if let Some(v) = ui_verbosity {
            state.set_value("UI_VERBOSITY", v);
        }
        if let Some(p) = prev.clone() {
            state.set_value("PREV_PROMPT", p);
        }
        if let Some(l) = last.clone() {
            state.set_value("LAST_PROMPT", l);
        }
        state.set_value("ENTITY_NAME", "User");
        state.set_value("MESSAGE_TEXT", req_clone.text.clone());
        // Prefer the fetched transcript; fall back to the prev/last prompt
        // pair when no history is available.
        let recent = if !recent_conversation.is_empty() {
            format!("{}\nUser: {}", recent_conversation, req_clone.text)
        } else {
            format!(
                "{}\n{}\nUser: {}",
                prev.map(|p| format!("User: {}", p)).unwrap_or_default(),
                last.map(|l| format!("User: {}", l)).unwrap_or_default(),
                req_clone.text
            )
        };
        state.set_value("RECENT_MESSAGES", recent);

        // Synthetic Memory representing the incoming user message, used only
        // to feed the provider pipeline below.
        let message = crate::types::Memory {
            id: uuid::Uuid::new_v4(),
            entity_id,
            agent_id,
            room_id: req_clone.room_id,
            content: crate::types::Content {
                text: req_clone.text.clone(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: chrono::Utc::now().timestamp(),
            unique: None,
            similarity: None,
        };
        let providers = runtime.read().unwrap().providers.read().unwrap().clone();
        // Providers receive an opaque Any handle; a unit value is passed in
        // place of the real runtime.
        // NOTE(review): confirm providers tolerate a non-runtime Any payload.
        let runtime_ref: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(());
        for provider in &providers {
            let name = provider.name().to_lowercase();
            // Planner/recall providers are skipped in the streaming path.
            if name.contains("planner") || name.contains("recall") {
                continue;
            }
            // Provider failures are silently ignored; successful results are
            // merged into the template state.
            if let Ok(result) = provider.get(runtime_ref.clone(), &message, &state).await {
                if let Some(text) = result.text {
                    state.set_value(provider.name().to_uppercase(), text);
                }
                if let Some(values) = result.values {
                    for (k, v) in values {
                        state.set_value(k, v);
                    }
                }
            }
        }

        // Compose the final prompt; on template failure fall back to the raw
        // user text.
        let template = crate::templates::MESSAGE_HANDLER_TEMPLATE;
        let prompt = crate::templates::compose_prompt_from_state(&state, template)
            .unwrap_or_else(|_| req_clone.text.clone());

        let prompt_preview: String = prompt.chars().take(500).collect();
        let prompt_len = prompt.len();
        debug!(
            "INTERACTION_PROMPT room_id={} prompt_len={} preview={}...",
            req_clone.room_id, prompt_len, prompt_preview
        );
        // Rotate per-room prompt history: last -> prev, new text -> last.
        {
            let mut rt = runtime.write().unwrap();
            rt.set_setting("ui:lastPrompt", serde_json::json!(req_clone.text.clone()), false);
            let last_key = format!("ui:lastPrompt:{}:prev", req_clone.room_id);
            if let Some(old_last) = rt.get_setting(&format!("ui:lastPrompt:{}:last", req_clone.room_id)) {
                rt.set_setting(&last_key, old_last, false);
            }
            rt.set_setting(&format!("ui:lastPrompt:{}:last", req_clone.room_id), serde_json::json!(req_clone.text.clone()), false);
        }

        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
        if api_key.is_empty() {
            let _ = stream_handler
                .send_error(ZoeyError::other("OPENAI_API_KEY is not set"))
                .await;
            return;
        }
        // Process-wide, lazily-initialized HTTP client (connection pooling).
        static OPENAI_CLIENT: OnceLock<HttpClient> = OnceLock::new();
        let model = {
            let rt = runtime.read().unwrap();
            rt.get_setting("OPENAI_MODEL")
                .and_then(|v| v.as_str().map(|s| s.to_string()))
                .unwrap_or_else(|| "gpt-4o-mini".to_string())
        };
        // Estimate a max_tokens budget: context window minus estimated prompt
        // tokens, capped by max output tokens and a small safety margin.
        let dynamic_max = {
            let calc = CostCalculator::new();
            let mk = if model.contains("gpt-4o") {
                "gpt-4o".to_string()
            } else if model.contains("gpt-4") {
                "gpt-4".to_string()
            } else {
                "gpt-4o-mini".to_string()
            };
            if let Some(pricing) = calc.get_pricing(&mk) {
                let est_in = TokenCounter::estimate_tokens(&prompt);
                let mut avail = if pricing.context_window > est_in {
                    pricing.context_window - est_in
                } else {
                    0
                };
                avail = avail.min(pricing.max_output_tokens);
                let safety = 64usize;
                if avail > safety {
                    avail.saturating_sub(safety)
                } else {
                    256
                }
            } else {
                768
            }
        };
        let client = OPENAI_CLIENT
            .get_or_init(|| {
                reqwest::Client::builder()
                    .pool_max_idle_per_host(50)
                    .pool_idle_timeout(std::time::Duration::from_secs(300))
                    .tcp_keepalive(std::time::Duration::from_secs(60))
                    .timeout(std::time::Duration::from_secs(120))
                    .build()
                    .unwrap_or_else(|_| reqwest::Client::new())
            })
            .clone();
        // NOTE(review): max(dynamic_max, 2048) floors the budget at 2048 and
        // can exceed the window computed above — confirm this is intentional.
        let req_body = serde_json::json!({
            "model": model,
            "stream": true,
            "max_tokens": std::cmp::max(dynamic_max, 2048),
            "messages": [
                {"role": "user", "content": prompt}
            ]
        });
        let stream_timeout = std::env::var("OPENAI_STREAM_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(45);
        let stream_start = Instant::now();
        let prompt_tokens = TokenCounter::estimate_tokens(&prompt);
        let resp = tokio::time::timeout(
            Duration::from_secs(stream_timeout),
            client
                .post("https://api.openai.com/v1/chat/completions")
                .bearer_auth(api_key)
                .json(&req_body)
                .send(),
        )
        .await;
        match resp {
            Err(_) => {
                let _ = stream_handler
                    .send_error(ZoeyError::other("OpenAI streaming request timed out"))
                    .await;
            }
            Ok(Err(e)) => {
                let _ = stream_handler
                    .send_error(ZoeyError::other(format!(
                        "OpenAI streaming request failed: {}",
                        e
                    )))
                    .await;
            }
            Ok(Ok(mut r)) => {
                // Incrementally parse the SSE body: accumulate bytes, split on
                // newlines, keep the trailing partial line in `buffer`.
                let mut buffer = String::new();
                let mut full_text = String::new();
                // NOTE(review): last_chunk_at is written but never read —
                // dead state, presumably leftover from an inactivity watchdog.
                let mut last_chunk_at = Instant::now();
                // If the per-chunk timeout elapses, this loop exits silently
                // without emitting a final chunk to the client.
                while let Ok(chunk_result) = tokio::time::timeout(
                    Duration::from_secs(stream_timeout),
                    r.chunk(),
                )
                .await
                {
                    last_chunk_at = Instant::now();
                    let chunk = match chunk_result {
                        Ok(opt) => match opt {
                            Some(c) => c,
                            None => break,
                        },
                        Err(e) => {
                            let _ = stream_handler
                                .send_error(ZoeyError::other(format!(
                                    "OpenAI streaming chunk failed: {}",
                                    e
                                )))
                                .await;
                            break;
                        }
                    };
                    let s = String::from_utf8_lossy(&chunk);
                    buffer.push_str(&s);
                    let mut parts: Vec<&str> = buffer.split('\n').collect();
                    let tail = parts.pop().unwrap_or("");
                    for line in parts {
                        let l = line.trim();
                        if !l.starts_with("data:") {
                            continue;
                        }
                        let payload = l.trim_start_matches("data:").trim();
                        if payload == "[DONE]" {
                            // End of stream: flush final chunk, then do the
                            // post-stream bookkeeping (persist reply, track
                            // cost, stamp lastAddressed, record a training
                            // sample).
                            let _ = stream_handler.send_chunk(String::new(), true).await;
                            let latency_ms = stream_start.elapsed().as_millis() as u64;
                            let completion_tokens = TokenCounter::estimate_tokens(&full_text);
                            let adapter = {
                                let rt = runtime.read().unwrap();
                                let x = rt.adapter.read().unwrap().clone();
                                x
                            };
                            if let Some(adapter) = adapter.as_ref() {
                                let agent_id = {
                                    let rt = runtime.read().unwrap();
                                    rt.agent_id
                                };
                                // Persist the assistant reply (best effort —
                                // failures are ignored).
                                let response = Memory {
                                    id: Uuid::new_v4(),
                                    entity_id: agent_id,
                                    agent_id,
                                    room_id: req_clone.room_id,
                                    content: Content {
                                        text: full_text.clone(),
                                        source: Some(req_clone.source.clone()),
                                        ..Default::default()
                                    },
                                    embedding: None,
                                    metadata: None,
                                    created_at: chrono::Utc::now().timestamp(),
                                    unique: Some(false),
                                    similarity: None,
                                };
                                let _ = adapter.create_memory(&response, "messages").await;
                            }
                            // Record token usage/cost with the global tracker.
                            if let Some(tracker) = get_global_cost_tracker() {
                                let context = LLMCallContext {
                                    agent_id,
                                    user_id: req_clone.entity_id.map(|u| u.to_string()),
                                    conversation_id: Some(req_clone.room_id),
                                    action_name: None,
                                    evaluator_name: None,
                                    temperature: Some(0.7),
                                    cached_tokens: None,
                                    ttft_ms: None,
                                    prompt_hash: None,
                                    prompt_preview: Some(req_clone.text.chars().take(100).collect()),
                                };
                                match tracker.record_llm_call(
                                    "openai",
                                    &model,
                                    prompt_tokens,
                                    completion_tokens,
                                    latency_ms,
                                    agent_id,
                                    context,
                                ).await {
                                    Ok(record) => {
                                        info!("COST_TRACKED provider=openai model={} prompt_tokens={} completion_tokens={} cost_usd={:.6} latency_ms={}",
                                            model, prompt_tokens, completion_tokens, record.total_cost_usd, latency_ms);
                                    }
                                    Err(e) => {
                                        error!("Failed to track cost: {}", e);
                                    }
                                }
                            }
                            // Stamp the room as addressed just now.
                            {
                                let mut rt = runtime.write().unwrap();
                                let key = format!("ui:lastAddressed:{}", req_clone.room_id);
                                rt.set_setting(
                                    &key,
                                    serde_json::json!(chrono::Utc::now().timestamp()),
                                    false,
                                );
                            }

                            // Optionally record the prompt/response pair as a
                            // training sample.
                            let sample_id = {
                                let collector = {
                                    let rt = runtime.read().unwrap();
                                    rt.get_training_collector()
                                };
                                if let Some(collector) = collector {
                                    match collector.record_interaction(
                                        req_clone.text.clone(),
                                        full_text.clone(),
                                        None, 0.7, ).await {
                                        Ok(id) => {
                                            info!("TRAINING_SAMPLE_RECORDED sample_id={} prompt_len={} response_len={}",
                                                id, req_clone.text.len(), full_text.len());
                                            Some(id)
                                        }
                                        Err(e) => {
                                            debug!("Training sample not recorded: {}", e);
                                            None
                                        }
                                    }
                                } else {
                                    None
                                }
                            };

                            // Surface the sample id to the client as metadata
                            // on an extra final chunk.
                            if let Some(sid) = sample_id {
                                let _ = stream_handler.send_chunk_with_meta(
                                    String::new(),
                                    true,
                                    Some(serde_json::json!({ "sampleId": sid.to_string() }))
                                ).await;
                            }

                            break;
                        }
                        // Normal delta frame: forward the content fragment and
                        // keep accumulating the full response text.
                        if let Ok(json) = serde_json::from_str::<JsonValue>(payload) {
                            if let Some(choices) = json.get("choices").and_then(|v| v.as_array()) {
                                if let Some(delta) = choices.get(0).and_then(|c| c.get("delta")) {
                                    if let Some(content) =
                                        delta.get("content").and_then(|v| v.as_str())
                                    {
                                        let _ = stream_handler
                                            .send_chunk(content.to_string(), false)
                                            .await;
                                        full_text.push_str(content);
                                    }
                                }
                            }
                        }
                    }
                    buffer = tail.to_string();
                }
            }
            // NOTE(review): unreachable — the Err(_) arm above already
            // matches every timeout error; the compiler flags this as a dead
            // pattern. Candidate for removal.
            Err(e) => {
                let _ = stream_handler
                    .send_error(ZoeyError::other(format!(
                        "OpenAI streaming request failed: {}",
                        e
                    )))
                    .await;
            }
        }
        return;
    }

    // ------------------------- Ollama / local branch -------------------------
    let is_local = provider
        .as_deref()
        .map(|s| {
            let lc = s.to_lowercase();
            lc == "ollama" || lc == "local" || lc == "local-llm" || lc == "llama" || lc == "llamacpp"
        })
        .unwrap_or(false);
    info!("OLLAMA_CHECK is_local={} provider={:?}", is_local, provider);
    if is_local {
        // NOTE(review): entity_id is computed but never used in this branch —
        // presumably leftover from mirroring the OpenAI path above.
        let entity_id = req_clone.entity_id.unwrap_or_else(Uuid::new_v4);
        let (agent_id, adapter) = {
            let rt = runtime.read().unwrap();
            let adapter = rt.adapter.read().unwrap().clone();
            (rt.agent_id, adapter)
        };
        let recent_conversation = if let Some(ref adapter) = adapter {
            fetch_recent_conversation(
                adapter.as_ref(),
                req_clone.room_id,
                agent_id,
                &{
                    let rt = runtime.read().unwrap();
                    rt.character.name.clone()
                },
                5, )
            .await
        } else {
            String::new()
        };
        // Same prompt-state construction as the OpenAI branch.
        let (character_name, character_bio, ui_tone, ui_verbosity, last, prev) = {
            let rt = runtime.read().unwrap();
            let name = rt.character.name.clone();
            let bio = rt.character.bio.clone().join(" ");
            let tone = rt
                .get_setting("ui:tone")
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            let verbosity = rt.get_setting("ui:verbosity").map(|v| v.to_string());
            let last_key = format!("ui:lastPrompt:{}:last", req_clone.room_id);
            let prev_key = format!("ui:lastPrompt:{}:prev", req_clone.room_id);
            let last = rt
                .get_setting(&last_key)
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            let prev = rt
                .get_setting(&prev_key)
                .and_then(|v| v.as_str().map(|s| s.to_string()));
            (name, bio, tone, verbosity, last, prev)
        };
        let mut state = crate::types::State::new();
        state.set_value(
            "CHARACTER",
            format!("Name: {}\nBio: {}", character_name, character_bio),
        );
        if let Some(t) = ui_tone {
            state.set_value("UI_TONE", t);
        }
        if let Some(v) = ui_verbosity {
            state.set_value("UI_VERBOSITY", v);
        }
        if let Some(p) = prev.clone() {
            state.set_value("PREV_PROMPT", p);
        }
        if let Some(l) = last.clone() {
            state.set_value("LAST_PROMPT", l);
        }
        state.set_value("ENTITY_NAME", "User");
        state.set_value("MESSAGE_TEXT", req_clone.text.clone());
        let recent = if !recent_conversation.is_empty() {
            format!("{}\nUser: {}", recent_conversation, req_clone.text)
        } else {
            format!(
                "{}\n{}\nUser: {}",
                prev.map(|p| format!("User: {}", p)).unwrap_or_default(),
                last.map(|l| format!("User: {}", l)).unwrap_or_default(),
                req_clone.text
            )
        };
        state.set_value("RECENT_MESSAGES", recent);

        // Local-only extra: inject retrieved knowledge context when available.
        if let Some(knowledge_context) = retrieve_knowledge_context(req_clone.room_id, &req_clone.text, 5) {
            info!(
                "KNOWLEDGE_CONTEXT_INJECTED room_id={} context_len={}",
                req_clone.room_id,
                knowledge_context.len()
            );
            state.set_value("KNOWLEDGE_CONTEXT", knowledge_context);
        }

        let template = crate::templates::MESSAGE_HANDLER_TEMPLATE;
        let prompt = crate::templates::compose_prompt_from_state(&state, template)
            .unwrap_or_else(|_| req_clone.text.clone());

        // Rotate per-room prompt history, as in the OpenAI branch.
        {
            let mut rt = runtime.write().unwrap();
            rt.set_setting("ui:lastPrompt", serde_json::json!(req_clone.text.clone()), false);
            let last_key = format!("ui:lastPrompt:{}:prev", req_clone.room_id);
            if let Some(old_last) = rt.get_setting(&format!("ui:lastPrompt:{}:last", req_clone.room_id)) {
                rt.set_setting(&last_key, old_last, false);
            }
            rt.set_setting(&format!("ui:lastPrompt:{}:last", req_clone.room_id), serde_json::json!(req_clone.text.clone()), false);
        }

        // Endpoint/model/token-limit config: runtime settings first, then env
        // vars, then hard defaults.
        let (ollama_base, ollama_model, max_tokens) = {
            let rt = runtime.read().unwrap();
            let base = rt.get_setting("LOCAL_LLM_ENDPOINT")
                .and_then(|v| v.as_str().map(|s| s.to_string()))
                .unwrap_or_else(|| std::env::var("OLLAMA_BASE_URL").unwrap_or_else(|_| "http://localhost:11434".to_string()));
            let model = rt.get_setting("LOCAL_LLM_MODEL")
                .and_then(|v| v.as_str().map(|s| s.to_string()))
                .unwrap_or_else(|| std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3.2".to_string()));
            let max = rt.get_setting("LOCAL_LLM_MAX_TOKENS")
                .and_then(|v| v.as_u64().map(|u| u as usize))
                .unwrap_or(800);
            (base, model, max)
        };

        info!(
            "OLLAMA_STREAMING endpoint={} model={} prompt_len={}",
            ollama_base, ollama_model, prompt.len()
        );

        // Process-wide, lazily-initialized client for the local endpoint.
        static OLLAMA_CLIENT: OnceLock<HttpClient> = OnceLock::new();
        let client = OLLAMA_CLIENT
            .get_or_init(|| {
                reqwest::Client::builder()
                    .pool_max_idle_per_host(10)
                    .pool_idle_timeout(std::time::Duration::from_secs(120))
                    .timeout(std::time::Duration::from_secs(300))
                    .build()
                    .unwrap_or_else(|_| reqwest::Client::new())
            })
            .clone();

        let req_body = serde_json::json!({
            "model": ollama_model,
            "prompt": prompt,
            "stream": true,
            "options": {
                "temperature": 0.7,
                "num_predict": max_tokens
            }
        });

        let stream_timeout = std::env::var("OLLAMA_STREAM_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(120);
        let stream_start = Instant::now();

        let resp = tokio::time::timeout(
            Duration::from_secs(stream_timeout),
            client
                .post(format!("{}/api/generate", ollama_base))
                .json(&req_body)
                .send(),
        )
        .await;

        match resp {
            Err(_) => {
                let _ = stream_handler
                    .send_error(ZoeyError::other("Ollama streaming request timed out"))
                    .await;
            }
            Ok(Err(e)) => {
                let _ = stream_handler
                    .send_error(ZoeyError::other(format!(
                        "Ollama streaming request failed: {}. Check if Ollama is running at {}",
                        e, ollama_base
                    )))
                    .await;
            }
            Ok(Ok(mut r)) => {
                info!("OLLAMA_RESPONSE status={}", r.status());
                if !r.status().is_success() {
                    let status = r.status();
                    let error_text = r.text().await.unwrap_or_default();
                    let _ = stream_handler
                        .send_error(ZoeyError::other(format!(
                            "Ollama API error {}: {}",
                            status, error_text
                        )))
                        .await;
                    return;
                }

                // Ollama streams newline-delimited JSON objects: buffer bytes,
                // parse complete lines, keep the partial tail for next round.
                let mut buffer = String::new();
                let mut full_text = String::new();
                let mut chunks_received = 0usize;
                while let Ok(chunk_result) = tokio::time::timeout(
                    Duration::from_secs(stream_timeout),
                    r.chunk(),
                )
                .await
                {
                    let chunk = match chunk_result {
                        Ok(opt) => match opt {
                            Some(c) => c,
                            None => break,
                        },
                        Err(e) => {
                            let _ = stream_handler
                                .send_error(ZoeyError::other(format!(
                                    "Ollama streaming chunk failed: {}",
                                    e
                                )))
                                .await;
                            break;
                        }
                    };
                    let s = String::from_utf8_lossy(&chunk);
                    buffer.push_str(&s);

                    let mut parts: Vec<&str> = buffer.split('\n').collect();
                    let tail = parts.pop().unwrap_or("");
                    for line in parts {
                        let l = line.trim();
                        if l.is_empty() {
                            continue;
                        }
                        if let Ok(json) = serde_json::from_str::<JsonValue>(l) {
                            if let Some(response) = json.get("response").and_then(|v| v.as_str()) {
                                chunks_received += 1;
                                if chunks_received == 1 {
                                    info!("OLLAMA_FIRST_CHUNK received, len={}", response.len());
                                }
                                let _ = stream_handler
                                    .send_chunk(response.to_string(), false)
                                    .await;
                                full_text.push_str(response);
                            }
                            if json.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
                                // Stream complete: record a training sample,
                                // flush the final (possibly meta-bearing)
                                // chunk, persist the reply, track cost, stamp
                                // lastAddressed.
                                info!("OLLAMA_DONE total_chunks={} response_len={}", chunks_received, full_text.len());

                                let sample_id = {
                                    let collector = {
                                        let rt = runtime.read().unwrap();
                                        rt.get_training_collector()
                                    };
                                    if let Some(collector) = collector {
                                        match collector.record_interaction(
                                            req_clone.text.clone(),
                                            full_text.clone(),
                                            None,
                                            0.7,
                                        ).await {
                                            Ok(id) => {
                                                info!("TRAINING_SAMPLE_RECORDED sample_id={} prompt_len={} response_len={}",
                                                    id, req_clone.text.len(), full_text.len());
                                                Some(id)
                                            }
                                            Err(e) => {
                                                debug!("Training sample not recorded: {}", e);
                                                None
                                            }
                                        }
                                    } else {
                                        None
                                    }
                                };

                                if let Some(sid) = sample_id {
                                    let _ = stream_handler.send_chunk_with_meta(
                                        String::new(),
                                        true,
                                        Some(serde_json::json!({ "sampleId": sid.to_string() }))
                                    ).await;
                                } else {
                                    let _ = stream_handler.send_chunk(String::new(), true).await;
                                }
                                // Persist the assistant reply (best effort).
                                if let Some(adapter) = adapter.as_ref() {
                                    let response = Memory {
                                        id: Uuid::new_v4(),
                                        entity_id: agent_id,
                                        agent_id,
                                        room_id: req_clone.room_id,
                                        content: Content {
                                            text: full_text.clone(),
                                            source: Some(req_clone.source.clone()),
                                            ..Default::default()
                                        },
                                        embedding: None,
                                        metadata: None,
                                        created_at: chrono::Utc::now().timestamp(),
                                        unique: Some(false),
                                        similarity: None,
                                    };
                                    let _ = adapter.create_memory(&response, "messages").await;
                                }
                                // Track usage; result deliberately ignored.
                                if let Some(tracker) = get_global_cost_tracker() {
                                    let latency_ms = stream_start.elapsed().as_millis() as u64;
                                    let prompt_tokens = TokenCounter::estimate_tokens(&prompt);
                                    let completion_tokens = TokenCounter::estimate_tokens(&full_text);
                                    let context = LLMCallContext {
                                        agent_id,
                                        user_id: req_clone.entity_id.map(|u| u.to_string()),
                                        conversation_id: Some(req_clone.room_id),
                                        action_name: None,
                                        evaluator_name: None,
                                        temperature: Some(0.7),
                                        cached_tokens: None,
                                        ttft_ms: None,
                                        prompt_hash: None,
                                        prompt_preview: Some(req_clone.text.chars().take(100).collect()),
                                    };
                                    let _ = tracker.record_llm_call(
                                        "ollama",
                                        &ollama_model,
                                        prompt_tokens,
                                        completion_tokens,
                                        latency_ms,
                                        agent_id,
                                        context,
                                    ).await;
                                }
                                {
                                    let mut rt = runtime.write().unwrap();
                                    let key = format!("ui:lastAddressed:{}", req_clone.room_id);
                                    rt.set_setting(
                                        &key,
                                        serde_json::json!(chrono::Utc::now().timestamp()),
                                        false,
                                    );
                                }
                                break;
                            }
                        }
                    }
                    buffer = tail.to_string();
                }
                if chunks_received > 0 {
                    info!("OLLAMA_STREAM_END chunks={} response_len={}", chunks_received, full_text.len());
                } else {
                    error!("OLLAMA_STREAM_END no chunks received");
                }
            }
        }
        return;
    }

    // ---------------- Fallback: non-streaming pipeline, re-chunked ----------------
    match process_chat_task(runtime.clone(), req_clone.clone()).await {
        Ok(resp) => {
            let final_text = resp
                .messages
                .as_ref()
                .and_then(|v| v.first())
                .map(|m| m.content.text.clone())
                .unwrap_or_default();
            let chunk_size = 80usize;
            let mut idx = 0;
            if final_text.is_empty() {
                let _ = stream_handler.send_chunk(String::new(), true).await;
            } else {
                // NOTE(review): slicing by byte offset can split a multi-byte
                // UTF-8 character and panic — confirm responses here are
                // ASCII-only or switch to char-boundary chunking.
                while idx < final_text.len() {
                    let end = (idx + chunk_size).min(final_text.len());
                    let piece = final_text[idx..end].to_string();
                    let is_final = end >= final_text.len();
                    if stream_handler.send_chunk(piece, is_final).await.is_err() {
                        break;
                    }
                    idx = end;
                    if !is_final {
                        // Yield so the SSE consumer can drain between chunks.
                        tokio::task::yield_now().await;
                    }
                }
            }
        }
        Err(e) => {
            let _ = stream_handler
                .send_error(ZoeyError::other(format!("Streaming failed: {}", e)))
                .await;
        }
    }
}
844
845async fn fetch_recent_conversation(
849 adapter: &dyn IDatabaseAdapter,
850 room_id: Uuid,
851 agent_id: Uuid,
852 agent_name: &str,
853 limit: usize,
854) -> String {
855 let query = MemoryQuery {
856 room_id: Some(room_id),
857 table_name: "messages".to_string(),
858 count: Some(limit),
859 ..Default::default()
860 };
861
862 match adapter.get_memories(query).await {
863 Ok(mut memories) => {
864 if memories.is_empty() {
865 return String::new();
866 }
867
868 memories.sort_by_key(|m| m.created_at);
870
871 memories
873 .iter()
874 .map(|m| {
875 let speaker = if m.entity_id == agent_id {
877 agent_name.to_string()
878 } else {
879 m.metadata
881 .as_ref()
882 .and_then(|meta| meta.entity_name.clone())
883 .unwrap_or_else(|| "User".to_string())
884 };
885 format!("{}: {}", speaker, m.content.text)
886 })
887 .collect::<Vec<_>>()
888 .join("\n")
889 }
890 Err(e) => {
891 eprintln!("[WARN] Failed to fetch recent conversation: {}", e);
892 String::new()
893 }
894 }
895}
896
897pub async fn health_check(State(state): State<ServerState>) -> Json<HealthResponse> {
899 let runtime = state.api_state.runtime.read().unwrap();
900 Json(HealthResponse {
901 status: "ok".to_string(),
902 agent_id: runtime.agent_id,
903 agent_name: runtime.character.name.clone(),
904 uptime: state.api_state.start_time.elapsed().as_secs(),
905 timestamp: chrono::Utc::now().to_rfc3339(),
906 })
907}
908
/// Non-streaming chat endpoint.
///
/// Validates the request, records per-room prompt history, then submits the
/// actual LLM work as a background task and immediately returns a task id
/// the client polls via `/agent/task/{task_id}`. Streaming requests are
/// forwarded to `chat_stream_handler` instead.
pub async fn chat_handler(
    State(server_state): State<ServerState>,
    Json(request): Json<ChatRequest>,
) -> Response {
    let agent_name = {
        let rt = server_state.api_state.runtime.read().unwrap();
        rt.character.name.clone()
    };
    info!(
        "[{}] chat request room_id={}, stream={}, text_len={}",
        agent_name,
        request.room_id,
        request.stream,
        request.text.len()
    );

    // Streaming requests are delegated wholesale to the SSE handler.
    if request.stream {
        return chat_stream_handler(State(server_state), Json(request))
            .await
            .into_response();
    }

    // Reject empty and oversized payloads up front (default cap 512 KB,
    // overridable via API_MAX_MESSAGE_BYTES).
    if request.text.trim().is_empty() {
        return ApiError::BadRequest("Message text cannot be empty".to_string()).into_response();
    }
    let max_len = std::env::var("API_MAX_MESSAGE_BYTES")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(512_000); if request.text.len() > max_len {
        return (
            StatusCode::PAYLOAD_TOO_LARGE,
            Json(serde_json::json!({
                "success": false,
                "error": "Message too large",
                "code": StatusCode::PAYLOAD_TOO_LARGE.as_u16(),
            })),
        )
        .into_response();
    }

    // Rotate the per-room prompt history (last -> prev, new text -> last)
    // and record the first requester as the room owner.
    {
        let runtime = server_state.api_state.runtime.clone();
        let mut rt = runtime.write().unwrap();
        let last_key = format!("ui:lastPrompt:{}:last", request.room_id);
        let prev_key = format!("ui:lastPrompt:{}:prev", request.room_id);
        let prev = rt
            .get_setting(&last_key)
            .and_then(|v| v.as_str().map(|s| s.to_string()));
        if let Some(p) = prev {
            rt.set_setting(&prev_key, serde_json::json!(p), false);
        }
        rt.set_setting(&last_key, serde_json::json!(request.text.clone()), false);
        if let Some(owner) = request.entity_id {
            let owner_key = format!("ROOM_OWNER:{}", request.room_id);
            if rt.get_setting(&owner_key).is_none() {
                rt.set_setting(&owner_key, serde_json::json!(owner.to_string()), false);
            }
        }
    }

    // Run the chat task on a dedicated OS thread with its own
    // current-thread Tokio runtime, bounded by a 90-second timeout.
    // NOTE(review): this spawns one thread + runtime per request —
    // presumably to isolate blocking work from the server runtime; confirm
    // this is intentional rather than a shared worker pool.
    let task_id = server_state.task_manager.create_task();
    let task_manager = server_state.task_manager.clone();
    let runtime = server_state.api_state.runtime.clone();
    let req_clone = request.clone();
    std::thread::spawn(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("chat task runtime");
        rt.block_on(async move {
            task_manager.mark_running(task_id);
            let timeout_res = tokio::time::timeout(
                Duration::from_secs(90),
                process_chat_task(runtime, req_clone),
            )
            .await;

            // Map the nested timeout/processing result onto task states.
            match timeout_res {
                Ok(Ok(response)) => task_manager.complete_task(task_id, TaskResult::Chat(response)),
                Ok(Err(e)) => task_manager.fail_task(task_id, e.to_string()),
                Err(_) => task_manager.fail_task(task_id, "Chat task timed out".to_string()),
            }
        });
    });

    // Respond immediately; the client polls for the result.
    Json(TaskResponse {
        success: true,
        task_id,
        message: "Chat task submitted successfully. Poll /agent/task/{task_id} for results."
            .to_string(),
        estimated_time_ms: Some(3000),
    })
    .into_response()
}
1009
/// SSE streaming chat endpoint.
///
/// Validates the request, rotates per-room prompt history, then enqueues the
/// job on a lazily-started, dedicated stream-executor thread and immediately
/// returns an SSE response backed by the channel the job writes into.
/// Events emitted: "chunk" (partial text), "complete" (final), "error".
pub async fn chat_stream_handler(
    State(server_state): State<ServerState>,
    Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
    let agent_name = {
        let rt = server_state.api_state.runtime.read().unwrap();
        rt.character.name.clone()
    };
    info!(
        "INTERACTION_REQUEST_STREAM agent={} room_id={} text_len={} text_preview={}",
        agent_name,
        request.room_id,
        request.text.len(),
        request.text.chars().take(120).collect::<String>()
    );

    // Reject empty and oversized payloads (default cap 512 KB, overridable
    // via API_MAX_MESSAGE_BYTES).
    if request.text.trim().is_empty() {
        return ApiError::BadRequest("Message text cannot be empty".to_string()).into_response();
    }
    {
        let max_len = std::env::var("API_MAX_MESSAGE_BYTES")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(512_000); if request.text.len() > max_len {
            return (
                StatusCode::PAYLOAD_TOO_LARGE,
                Json(serde_json::json!({
                    "success": false,
                    "error": "Message too large",
                    "code": StatusCode::PAYLOAD_TOO_LARGE.as_u16(),
                })),
            )
            .into_response();
        }
    }

    // Rotate per-room prompt history (last -> prev, new text -> last) and
    // record the first requester as room owner.
    {
        let runtime = server_state.api_state.runtime.clone();
        let mut rt = runtime.write().unwrap();
        let last_key = format!("ui:lastPrompt:{}:last", request.room_id);
        let prev_key = format!("ui:lastPrompt:{}:prev", request.room_id);
        let prev = rt
            .get_setting(&last_key)
            .and_then(|v| v.as_str().map(|s| s.to_string()));
        if let Some(p) = prev {
            rt.set_setting(&prev_key, serde_json::json!(p), false);
        }
        rt.set_setting(&last_key, serde_json::json!(request.text.clone()), false);
        if let Some(owner) = request.entity_id {
            let owner_key = format!("ROOM_OWNER:{}", request.room_id);
            if rt.get_setting(&owner_key).is_none() {
                rt.set_setting(&owner_key, serde_json::json!(owner.to_string()), false);
            }
        }
    }

    // Channel pair: the job writes TextChunks, the SSE stream below reads
    // them and serializes each into a "chunk"/"complete"/"error" event.
    let (sender, receiver) = create_text_stream(64);
    let stream_handler = StreamHandler::new(sender.clone());

    let sse_stream: BoxStream<'static, std::result::Result<Event, std::convert::Infallible>> = ReceiverStream::new(receiver)
        .filter_map(|res| async move {
            match res {
                Ok(TextChunk { text, is_final, metadata }) => {
                    let data = serde_json::json!({ "text": text, "final": is_final, "meta": metadata });
                    Some(Ok(Event::default().event(if is_final { "complete" } else { "chunk" }).data(data.to_string())))
                }
                Err(e) => {
                    let data = serde_json::json!({ "error": e.to_string() });
                    Some(Ok(Event::default().event("error").data(data.to_string())))
                }
            }
        })
        .boxed();

    let runtime = server_state.api_state.runtime.clone();
    let req_clone = request.clone();

    // Process-wide concurrency cap on simultaneous streams (default 64,
    // overridable via MAX_CONCURRENT_STREAMS).
    static STREAM_SEMAPHORE: std::sync::OnceLock<std::sync::Arc<tokio::sync::Semaphore>> =
        std::sync::OnceLock::new();
    let semaphore = STREAM_SEMAPHORE
        .get_or_init(|| {
            let max_concurrent = std::env::var("MAX_CONCURRENT_STREAMS")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(64); std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent))
        })
        .clone();

    // At capacity: still return a valid SSE response whose only event is an
    // error chunk, so clients get structured feedback instead of a 5xx.
    let permit = match semaphore.clone().try_acquire_owned() {
        Ok(p) => p,
        Err(_) => {
            tokio::spawn(async move {
                let handler = StreamHandler::new(sender);
                let _ = handler
                    .send_error(ZoeyError::other("Server at capacity, please retry"))
                    .await;
            });
            return Sse::new(sse_stream).into_response();
        }
    };

    // Lazily started executor: one dedicated OS thread running a
    // current-thread Tokio runtime that drains queued stream jobs.
    // NOTE(review): jobs are awaited one at a time inside the recv loop, so
    // streams run sequentially even though the semaphore allows up to 64
    // permits — confirm whether this serialization is intended.
    static STREAM_EXECUTOR: std::sync::OnceLock<
        tokio::sync::mpsc::Sender<(
            tokio::sync::OwnedSemaphorePermit,
            Arc<RwLock<AgentRuntime>>,
            ChatRequest,
            StreamHandler,
        )>,
    > = std::sync::OnceLock::new();
    let tx = STREAM_EXECUTOR
        .get_or_init(|| {
            let (tx, mut rx) = tokio::sync::mpsc::channel::<(
                tokio::sync::OwnedSemaphorePermit,
                Arc<RwLock<AgentRuntime>>,
                ChatRequest,
                StreamHandler,
            )>(256);
            std::thread::Builder::new()
                .name("chat_stream_executor".to_string())
                .stack_size(16 * 1024 * 1024)
                .spawn(move || {
                    let rt = tokio::runtime::Builder::new_current_thread()
                        .enable_all()
                        .build()
                        .unwrap();
                    rt.block_on(async move {
                        while let Some((permit, runtime, req, handler)) = rx.recv().await {
                            // Hold the permit for the job's full duration;
                            // dropped (released) when the job finishes.
                            let _p = permit;
                            run_chat_stream_job(runtime.clone(), req.clone(), handler).await;
                        }
                    });
                })
                .expect("stream executor thread");
            tx
        })
        .clone();
    // Enqueue; a full queue drops the job silently (send error ignored).
    let _ = tx
        .send((permit, runtime.clone(), req_clone.clone(), stream_handler))
        .await;
    return Sse::new(sse_stream).into_response();
}
1162
1163async fn process_chat_task(
1165 runtime: Arc<RwLock<AgentRuntime>>,
1166 request: ChatRequest,
1167) -> Result<ChatResponse> {
1168 eprintln!("[TRACE] process_chat_task: START");
1169 if request
1170 .metadata
1171 .get("skip_double_processing")
1172 .and_then(|v| v.as_bool())
1173 .unwrap_or(false)
1174 {
1175 return Ok(ChatResponse {
1176 success: true,
1177 messages: None,
1178 error: None,
1179 metadata: None,
1180 });
1181 }
1182 let entity_id = request.entity_id.unwrap_or_else(Uuid::new_v4);
1183 let agent_id = {
1184 let rt = runtime.read().unwrap();
1185 rt.agent_id
1186 };
1187 eprintln!("[TRACE] process_chat_task: agent_id={}", agent_id);
1188
1189 let message = Memory {
1191 id: Uuid::new_v4(),
1192 entity_id,
1193 agent_id,
1194 room_id: request.room_id,
1195 content: Content {
1196 text: request.text.clone(),
1197 source: Some(request.source.clone()),
1198 ..Default::default()
1199 },
1200 embedding: None,
1201 metadata: None,
1202 created_at: chrono::Utc::now().timestamp(),
1203 unique: Some(false),
1204 similarity: None,
1205 };
1206
1207 let world_id = Uuid::new_v4();
1209 let adapter_opt = {
1210 let rt = runtime.read().unwrap();
1211 let adapter_lock = rt.adapter.read().unwrap();
1212 adapter_lock.clone()
1213 };
1214 if let Some(adapter) = adapter_opt.as_ref() {
1215 let world = crate::types::World {
1217 id: world_id,
1218 name: format!("API World {}", world_id),
1219 agent_id,
1220 server_id: None,
1221 metadata: HashMap::new(),
1222 created_at: Some(chrono::Utc::now().timestamp()),
1223 };
1224 let _ = adapter.ensure_world(&world).await;
1225
1226 let entity = crate::types::Entity {
1228 id: entity_id,
1229 agent_id,
1230 name: Some(format!("User {}", entity_id)),
1231 username: None,
1232 email: None,
1233 avatar_url: None,
1234 metadata: HashMap::new(),
1235 created_at: Some(chrono::Utc::now().timestamp()),
1236 };
1237 let _ = adapter.create_entities(vec![entity]).await;
1238
1239 let room_record = crate::types::Room {
1241 id: request.room_id,
1242 agent_id: Some(agent_id),
1243 name: format!("Room {}", request.room_id),
1244 source: request.source.clone(),
1245 channel_type: ChannelType::Api,
1246 channel_id: None,
1247 server_id: None,
1248 world_id,
1249 metadata: HashMap::new(),
1250 created_at: Some(chrono::Utc::now().timestamp()),
1251 };
1252 let _ = adapter.create_room(&room_record).await;
1253
1254 let _ = adapter.add_participant(entity_id, request.room_id).await;
1256 }
1257
1258 let room = Room {
1260 id: request.room_id,
1261 agent_id: Some(agent_id),
1262 name: format!("Room {}", request.room_id),
1263 source: request.source.clone(),
1264 channel_type: ChannelType::Api,
1265 channel_id: None,
1266 server_id: None,
1267 world_id, metadata: HashMap::new(),
1269 created_at: Some(chrono::Utc::now().timestamp()),
1270 };
1271
1272 eprintln!("[TRACE] process_chat_task: calling MessageProcessor::process_message");
1274 let processor = MessageProcessor::new(runtime.clone());
1275 let responses = processor.process_message(message, room).await?;
1276 eprintln!(
1277 "[TRACE] process_chat_task: MessageProcessor returned {} responses",
1278 responses.len()
1279 );
1280 let agent_name = {
1281 let rt = runtime.read().unwrap();
1282 rt.character.name.clone()
1283 };
1284 let preview = responses
1285 .get(0)
1286 .map(|m| m.content.text.chars().take(120).collect::<String>())
1287 .unwrap_or_default();
1288 info!(
1289 "[{}] chat completed responses={}, preview={}",
1290 agent_name,
1291 responses.len(),
1292 preview
1293 );
1294
1295 Ok(ChatResponse {
1296 success: true,
1297 messages: Some(responses),
1298 error: None,
1299 metadata: None,
1300 })
1301}
1302
1303async fn process_message(
1305 runtime: Arc<RwLock<AgentRuntime>>,
1306 message: Memory,
1307 room: Room,
1308) -> Result<Vec<Memory>> {
1309 let processor = MessageProcessor::new(runtime.clone());
1311
1312 processor.process_message(message, room).await
1314}
1315
1316pub async fn action_handler(
1318 State(state): State<ServerState>,
1319 Json(request): Json<ActionRequest>,
1320) -> impl IntoResponse {
1321 let state = state.api_state;
1322 let agent_name = {
1323 let rt = state.runtime.read().unwrap();
1324 rt.character.name.clone()
1325 };
1326 info!("[{}] action request action={}", agent_name, request.action);
1327
1328 if request.action.trim().is_empty() {
1330 return ApiError::BadRequest("Action name cannot be empty".to_string()).into_response();
1331 }
1332
1333 let runtime = state.runtime.read().unwrap();
1334
1335 let actions = runtime.actions.read().unwrap();
1337 let action = match actions.iter().find(|a| a.name() == request.action) {
1338 Some(a) => a,
1339 None => {
1340 return ApiError::NotFound(format!("Action '{}' not found", request.action))
1341 .into_response();
1342 }
1343 };
1344
1345 info!("Would execute action: {}", action.name());
1348
1349 Json(ActionResponse {
1350 success: true,
1351 result: Some(serde_json::json!({
1352 "action": request.action,
1353 "status": "acknowledged"
1354 })),
1355 error: None,
1356 })
1357 .into_response()
1358}
1359
1360pub async fn state_handler(
1362 State(server_state): State<ServerState>,
1363 Json(request): Json<StateRequest>,
1364) -> impl IntoResponse {
1365 let agent_name = {
1366 let rt = server_state.api_state.runtime.read().unwrap();
1367 rt.character.name.clone()
1368 };
1369 info!("[{}] state request room_id={}", agent_name, request.room_id);
1370
1371 let task_id = server_state.task_manager.create_task();
1373 let task_manager = server_state.task_manager.clone();
1374 let runtime = server_state.api_state.runtime.clone();
1375
1376 std::thread::spawn(move || {
1378 let rt = tokio::runtime::Builder::new_current_thread()
1380 .enable_all()
1381 .build()
1382 .unwrap();
1383
1384 rt.block_on(async move {
1385 task_manager.mark_running(task_id);
1386
1387 let timeout_res = tokio::time::timeout(
1388 Duration::from_secs(15),
1389 process_state_task(runtime, request),
1390 )
1391 .await;
1392
1393 match timeout_res {
1394 Ok(Ok(response)) => {
1395 task_manager.complete_task(task_id, super::task::TaskResult::State(response));
1396 }
1397 Ok(Err(e)) => {
1398 task_manager.fail_task(task_id, e.to_string());
1399 }
1400 Err(_) => {
1401 task_manager.fail_task(task_id, "State task timed out".to_string());
1402 }
1403 }
1404 });
1405 });
1406
1407 Json(TaskResponse {
1409 success: true,
1410 task_id,
1411 message:
1412 "State composition task submitted successfully. Poll /agent/task/{task_id} for results."
1413 .to_string(),
1414 estimated_time_ms: Some(2000), })
1416 .into_response()
1417}
1418
/// Composes the agent state for a synthetic (empty-content) message in the
/// requested room and returns it for the task poller.
async fn process_state_task(
    runtime: Arc<RwLock<AgentRuntime>>,
    request: StateRequest,
) -> Result<StateResponse> {
    let rt = runtime.read().unwrap();
    // Anonymous callers get a fresh entity id.
    let entity_id = request.entity_id.unwrap_or_else(Uuid::new_v4);
    let agent_id = rt.agent_id;
    // Guard explicitly dropped before building the message below.
    drop(rt); let message = Memory {
        id: Uuid::new_v4(),
        entity_id,
        agent_id,
        room_id: request.room_id,
        // Empty content: only room/entity context matters for composition.
        content: Content::default(),
        embedding: None,
        metadata: None,
        created_at: chrono::Utc::now().timestamp(),
        unique: Some(false),
        similarity: None,
    };

    // NOTE(review): a std::sync RwLock read guard is held across this `.await`.
    // This appears to rely on the caller driving the future on a dedicated
    // single-threaded runtime, but it keeps the runtime read-locked for the
    // whole composition — consider restructuring. TODO confirm intended.
    let agent_state = {
        let rt = runtime.read().unwrap();
        rt.compose_state(&message, None, false, false).await?
    };
    let agent_name = {
        let rt = runtime.read().unwrap();
        rt.character.name.clone()
    };
    info!(
        "[{}] state composed values={}",
        agent_name,
        agent_state.values.len()
    );

    Ok(StateResponse {
        success: true,
        state: Some(agent_state),
        error: None,
    })
}
1464
1465pub async fn task_status_handler(
1467 State(server_state): State<ServerState>,
1468 axum::extract::Path(task_id): axum::extract::Path<Uuid>,
1469) -> impl IntoResponse {
1470 debug!("Task status request: task_id={}", task_id);
1471
1472 match server_state.task_manager.get_task(task_id) {
1473 Some(task) => {
1474 let result_json = task
1476 .result
1477 .as_ref()
1478 .map(|r| serde_json::to_value(r).ok())
1479 .flatten();
1480
1481 Json(TaskStatusResponse {
1482 task_id,
1483 status: format!("{:?}", task.status).to_lowercase(),
1484 result: result_json,
1485 error: task.error.clone(),
1486 duration_ms: task.duration_ms(),
1487 created_at: format!("{:?}", task.created_at),
1488 completed_at: task.completed_at.map(|t| format!("{:?}", t)),
1489 })
1490 .into_response()
1491 }
1492 None => ApiError::NotFound(format!("Task {} not found", task_id)).into_response(),
1493 }
1494}
1495
/// Request body for `context_add_handler`: a single key/value written into
/// runtime settings, namespaced by room.
#[derive(Deserialize)]
pub struct ContextAddPayload {
    // Room the UI context value belongs to.
    room_id: Uuid,
    // Key suffix; full setting key is "ui:lastThought:{room_id}:{key}".
    key: String,
    // Value stored verbatim as a JSON string.
    value: String,
}
1502
1503pub async fn context_add_handler(
1505 State(server_state): State<ServerState>,
1506 Json(body): Json<ContextAddPayload>,
1507) -> impl IntoResponse {
1508 let runtime = server_state.api_state.runtime.clone();
1509 {
1510 let mut rt = runtime.write().unwrap();
1512 let key = format!("ui:lastThought:{}:{}", body.room_id, body.key);
1513 rt.set_setting(&key, serde_json::json!(body.value), false);
1514 }
1515 Json(serde_json::json!({"success": true})).into_response()
1516}
1517
/// Request body for `context_save_handler`: an ordered list of UI "thought"
/// steps to persist for a room.
#[derive(Deserialize)]
pub struct ContextSavePayload {
    // Room the thought steps belong to.
    room_id: Uuid,
    // Step texts; each one becomes a memory in the "thoughts" table.
    steps: Vec<String>,
}
1523
1524pub async fn context_save_handler(
1526 State(server_state): State<ServerState>,
1527 Json(body): Json<ContextSavePayload>,
1528) -> impl IntoResponse {
1529 let runtime = server_state.api_state.runtime.clone();
1530 let adapter = {
1531 let rt = runtime.read().unwrap();
1532 let x = rt.adapter.read().unwrap().clone();
1533 x
1534 };
1535 if let Some(adapter) = adapter {
1536 for step in body.steps.iter() {
1537 let mem = Memory {
1538 id: Uuid::new_v4(),
1539 entity_id: Uuid::new_v4(),
1540 agent_id: {
1541 let rt = runtime.read().unwrap();
1542 rt.agent_id
1543 },
1544 room_id: body.room_id,
1545 content: Content {
1546 text: step.clone(),
1547 source: Some("simpleui-thought".to_string()),
1548 ..Default::default()
1549 },
1550 embedding: None,
1551 metadata: None,
1552 created_at: chrono::Utc::now().timestamp(),
1553 unique: Some(false),
1554 similarity: None,
1555 };
1556 if let Err(e) = adapter.create_memory(&mem, "thoughts").await {
1557 error!("Failed to persist thought step: {}", e);
1558 }
1559 }
1560 return Json(serde_json::json!({"success": true})).into_response();
1561 }
1562 ApiError::Internal("No database adapter configured".to_string()).into_response()
1563}
1564
1565pub async fn character_list_handler(State(server_state): State<ServerState>) -> impl IntoResponse {
1567 let current_character = {
1569 let rt = server_state.api_state.runtime.read().unwrap();
1570 rt.character.name.clone()
1571 };
1572
1573 let mut list: Vec<String> = Vec::new();
1575 if let Ok(entries) = std::fs::read_dir("characters") {
1576 for e in entries.flatten() {
1577 if let Some(name) = e.file_name().to_str() {
1578 if name.ends_with(".xml") {
1579 list.push(name.to_string());
1580 }
1581 }
1582 }
1583 }
1584 Json(serde_json::json!({
1585 "success": true,
1586 "characters": list,
1587 "current": current_character
1588 })).into_response()
1589}
1590
1591pub async fn character_select_handler(
1593 State(server_state): State<ServerState>,
1594 Json(body): Json<serde_json::Value>,
1595) -> impl IntoResponse {
1596 let Some(filename) = body.get("filename").and_then(|v| v.as_str()) else {
1597 return ApiError::BadRequest("Missing filename".to_string()).into_response();
1598 };
1599 let path = format!("characters/{}", filename);
1600 let xml = match std::fs::read_to_string(&path) {
1601 Ok(s) => s,
1602 Err(_) => {
1603 return ApiError::NotFound("Character file not found".to_string()).into_response()
1604 }
1605 };
1606
1607 fn section<'a>(xml: &'a str, tag: &str) -> Option<&'a str> {
1609 let start = xml.find(&format!("<{}>", tag))?;
1610 let end = xml.find(&format!("</{}>", tag))?;
1611 Some(&xml[start + tag.len() + 2..end])
1612 }
1613 fn entries(xml: &str, section_name: &str) -> Vec<String> {
1614 let mut out = Vec::new();
1615 if let Some(sec) = section(xml, section_name) {
1616 let mut rest = sec;
1617 loop {
1618 if let Some(i) = rest.find("<entry>") {
1619 let r = &rest[i + 7..];
1620 if let Some(j) = r.find("</entry>") {
1621 out.push(r[..j].trim().to_string());
1622 rest = &r[j + 8..];
1623 continue;
1624 }
1625 }
1626 break;
1627 }
1628 }
1629 out
1630 }
1631
1632 let name = section(&xml, "name")
1633 .and_then(|s| s.lines().next())
1634 .unwrap_or("ZoeyAI")
1635 .trim()
1636 .to_string();
1637 let bio = entries(&xml, "bio");
1638 let lore = entries(&xml, "lore");
1639 let knowledge = entries(&xml, "knowledge");
1640
1641 {
1643 let mut rt = server_state.api_state.runtime.write().unwrap();
1644 rt.character.name = name;
1645 if !bio.is_empty() {
1646 rt.character.bio = bio;
1647 }
1648 if !lore.is_empty() {
1649 rt.character.lore = lore;
1650 }
1651 if !knowledge.is_empty() {
1652 rt.character.knowledge = knowledge;
1653 }
1654 }
1655
1656 Json(serde_json::json!({"success": true})).into_response()
1657}
1658
/// API-level error taxonomy; each variant carries a human-readable message
/// and maps to an HTTP status in its `IntoResponse` implementation.
#[derive(Debug)]
pub enum ApiError {
    // 400: malformed or missing request data.
    BadRequest(String),
    // 401: missing or invalid credentials.
    Unauthorized(String),
    // 403: authenticated but not permitted.
    Forbidden(String),
    // 404: target resource does not exist.
    NotFound(String),
    // 429: caller exceeded rate limits.
    RateLimited(String),
    // 500: unexpected server-side failure.
    Internal(String),
}
1669
1670impl IntoResponse for ApiError {
1671 fn into_response(self) -> Response {
1672 let (status, message) = match self {
1673 ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
1674 ApiError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg),
1675 ApiError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg),
1676 ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
1677 ApiError::RateLimited(msg) => (StatusCode::TOO_MANY_REQUESTS, msg),
1678 ApiError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
1679 };
1680
1681 let body = Json(serde_json::json!({
1682 "success": false,
1683 "error": message,
1684 "code": status.as_u16(),
1685 }));
1686
1687 (status, body).into_response()
1688 }
1689}
1690
1691impl From<ZoeyError> for ApiError {
1692 fn from(err: ZoeyError) -> Self {
1693 error!("ZoeyError: {}", err);
1694 ApiError::Internal(err.to_string())
1695 }
1696}
1697
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies that an ApiError variant renders with its expected HTTP
    // status code via IntoResponse.
    #[test]
    fn test_api_error_response() {
        let err = ApiError::BadRequest("test error".to_string());
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
}
/// Request body for `delete_room_handler`.
#[derive(Deserialize)]
pub struct DeleteRoomPayload {
    // Room to delete.
    room_id: Uuid,
    // Caller identity; must match the recorded room owner setting.
    entity_id: Uuid,
    // When true, also purge stored messages/thoughts and room knowledge.
    #[serde(default)]
    purge_memories: bool,
}
1716
1717pub async fn delete_room_handler(
1719 State(server_state): State<ServerState>,
1720 Json(body): Json<DeleteRoomPayload>,
1721) -> impl IntoResponse {
1722 let runtime = server_state.api_state.runtime.clone();
1723 let (adapter, authorized) = {
1724 let rt = runtime.read().unwrap();
1725 let owner_key = format!("ROOM_OWNER:{}", body.room_id);
1726 let authorized = rt
1727 .get_setting(&owner_key)
1728 .and_then(|v| v.as_str().map(|s| s.to_string()))
1729 .map(|owner| owner == body.entity_id.to_string())
1730 .unwrap_or(false);
1731 (rt.get_adapter(), authorized)
1732 };
1733
1734 if !authorized {
1735 return ApiError::Forbidden("Only the room owner can delete this room".to_string())
1736 .into_response();
1737 }
1738
1739 if let Some(adapter) = adapter {
1740 if body.purge_memories {
1741 if let Ok(memories) = adapter
1742 .get_memories(MemoryQuery {
1743 room_id: Some(body.room_id),
1744 table_name: "messages".to_string(),
1745 ..Default::default()
1746 })
1747 .await
1748 {
1749 for m in memories {
1750 let _ = adapter.remove_memory(m.id, "messages").await;
1751 }
1752 }
1753 if let Ok(thoughts) = adapter
1754 .get_memories(MemoryQuery {
1755 room_id: Some(body.room_id),
1756 table_name: "thoughts".to_string(),
1757 ..Default::default()
1758 })
1759 .await
1760 {
1761 for t in thoughts {
1762 let _ = adapter.remove_memory(t.id, "thoughts").await;
1763 }
1764 }
1765
1766 if let Err(e) = delete_room_knowledge(body.room_id) {
1768 warn!("Failed to delete room knowledge: {}", e);
1769 }
1770 }
1771 return Json(serde_json::json!({"success": true})).into_response();
1772 }
1773 ApiError::Internal("No database adapter configured".to_string()).into_response()
1774}
1775
/// A unit of work for the background memory-persistence worker.
struct MemoryWorkItem {
    // The memory record to persist into the "messages" table.
    memory: Memory,
    // Optional reply channel: Some(..) when the caller waits for the outcome
    // (Ok(memory id) or an error string); None for fire-and-forget enqueues.
    response_tx: Option<tokio::sync::oneshot::Sender<std::result::Result<Uuid, String>>>,
}
1781
/// Global handle to the memory worker's queue; installed once by
/// `init_memory_worker_pool`.
static MEMORY_QUEUE: OnceLock<tokio::sync::mpsc::Sender<MemoryWorkItem>> = OnceLock::new();
1784
1785pub fn init_memory_worker_pool(runtime: Arc<RwLock<crate::AgentRuntime>>) {
1787 let (tx, mut rx) = tokio::sync::mpsc::channel::<MemoryWorkItem>(1000);
1788
1789 let _ = MEMORY_QUEUE.set(tx);
1791
1792 for i in 0..4 {
1794 let runtime = runtime.clone();
1795 let mut rx_clone = {
1796 let (new_tx, new_rx) = tokio::sync::mpsc::channel::<MemoryWorkItem>(1000);
1798 if i > 0 {
1802 continue;
1803 } rx
1805 };
1806
1807 std::thread::Builder::new()
1808 .name(format!("memory_worker_{}", i))
1809 .stack_size(64 * 1024 * 1024) .spawn(move || {
1811 eprintln!("[DEBUG] memory_worker: thread started");
1812
1813 let rt_handle = std::sync::Arc::new(std::sync::Mutex::new(
1815 tokio::runtime::Builder::new_current_thread()
1816 .enable_all()
1817 .build()
1818 .unwrap(),
1819 ));
1820
1821 loop {
1822 let work = {
1824 let rt = rt_handle.lock().unwrap();
1825 rt.block_on(rx_clone.recv())
1826 };
1827
1828 let work = match work {
1829 Some(w) => w,
1830 None => break,
1831 };
1832
1833 eprintln!(
1834 "[DEBUG] memory_worker: got work item, id={}",
1835 work.memory.id
1836 );
1837
1838 let adapter = {
1840 let rt_guard = runtime.read().unwrap();
1841 rt_guard.get_adapter()
1842 };
1843
1844 if let Some(adapter) = adapter {
1845 let mem = work.memory.clone();
1846 let memory_id = mem.id;
1847
1848 let (result_tx, result_rx) = std::sync::mpsc::channel();
1851
1852 let rt = rt_handle.lock().unwrap();
1854 rt.spawn(async move {
1855 let result = adapter.create_memory(&mem, "messages").await;
1856 let _ = result_tx.send(result.map(|_| ()).map_err(|e| e.to_string()));
1857 });
1858 drop(rt);
1859
1860 let result =
1862 match result_rx.recv_timeout(std::time::Duration::from_secs(10)) {
1863 Ok(r) => r.map(|_| memory_id),
1864 Err(_) => Err("Memory creation timed out".to_string()),
1865 };
1866
1867 if let Some(tx) = work.response_tx {
1868 let _ = tx.send(result);
1869 }
1870 eprintln!("[DEBUG] memory_worker: work item processed");
1871 } else {
1872 if let Some(tx) = work.response_tx {
1873 let _ = tx.send(Ok(work.memory.id));
1874 }
1875 }
1876 }
1877 eprintln!("[DEBUG] memory_worker: thread exiting");
1878 })
1879 .ok();
1880
1881 break; }
1883}
1884
1885pub async fn memory_create_handler(
1888 State(server_state): State<ServerState>,
1889 Json(request): Json<super::types::MemoryCreateRequest>,
1890) -> Response {
1891 let runtime = server_state.api_state.runtime.clone();
1892
1893 let agent_id = {
1895 let rt = runtime.read().unwrap();
1896 rt.agent_id
1897 };
1898
1899 let memory_id = Uuid::new_v4();
1901 let mut content = Content {
1902 text: request.text,
1903 source: Some(request.source),
1904 ..Default::default()
1905 };
1906 for (k, v) in request.metadata {
1907 content.metadata.insert(k, v);
1908 }
1909
1910 let memory = Memory {
1911 id: memory_id,
1912 entity_id: request.entity_id,
1913 agent_id,
1914 room_id: request.room_id,
1915 content,
1916 embedding: None,
1917 metadata: None,
1918 created_at: chrono::Utc::now().timestamp(),
1919 unique: Some(false),
1920 similarity: None,
1921 };
1922
1923 let queue = match MEMORY_QUEUE.get() {
1925 Some(q) => q,
1926 None => {
1927 init_memory_worker_pool(runtime.clone());
1929 match MEMORY_QUEUE.get() {
1930 Some(q) => q,
1931 None => {
1932 return Json(super::types::MemoryCreateResponse {
1933 success: false,
1934 memory_id: None,
1935 error: Some("Memory queue not initialized".to_string()),
1936 })
1937 .into_response();
1938 }
1939 }
1940 }
1941 };
1942
1943 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
1945
1946 if queue
1948 .send(MemoryWorkItem {
1949 memory,
1950 response_tx: Some(resp_tx),
1951 })
1952 .await
1953 .is_err()
1954 {
1955 return Json(super::types::MemoryCreateResponse {
1956 success: false,
1957 memory_id: None,
1958 error: Some("Memory queue full".to_string()),
1959 })
1960 .into_response();
1961 }
1962
1963 match tokio::time::timeout(Duration::from_secs(10), resp_rx).await {
1965 Ok(Ok(Ok(id))) => Json(super::types::MemoryCreateResponse {
1966 success: true,
1967 memory_id: Some(id),
1968 error: None,
1969 })
1970 .into_response(),
1971 Ok(Ok(Err(e))) => Json(super::types::MemoryCreateResponse {
1972 success: false,
1973 memory_id: None,
1974 error: Some(e),
1975 })
1976 .into_response(),
1977 _ => Json(super::types::MemoryCreateResponse {
1978 success: false,
1979 memory_id: None,
1980 error: Some("Memory operation failed".to_string()),
1981 })
1982 .into_response(),
1983 }
1984}
1985
1986pub async fn memory_create_async(
1988 runtime: Arc<RwLock<crate::AgentRuntime>>,
1989 room_id: Uuid,
1990 entity_id: Uuid,
1991 text: String,
1992 source: String,
1993) {
1994 let agent_id = {
1995 let rt = runtime.read().unwrap();
1996 rt.agent_id
1997 };
1998
1999 let memory = Memory {
2000 id: Uuid::new_v4(),
2001 entity_id,
2002 agent_id,
2003 room_id,
2004 content: Content {
2005 text,
2006 source: Some(source),
2007 ..Default::default()
2008 },
2009 embedding: None,
2010 metadata: None,
2011 created_at: chrono::Utc::now().timestamp(),
2012 unique: Some(false),
2013 similarity: None,
2014 };
2015
2016 if let Some(queue) = MEMORY_QUEUE.get() {
2017 let _ = queue
2018 .send(MemoryWorkItem {
2019 memory,
2020 response_tx: None, })
2022 .await;
2023 }
2024}
2025
2026const KNOWLEDGE_MAX_CONTENT_SIZE: usize = 10 * 1024 * 1024; const KNOWLEDGE_MAX_FILENAME_LENGTH: usize = 255;
2033const KNOWLEDGE_MIN_CONTENT_LENGTH: usize = 10; const KNOWLEDGE_MAX_CHUNKS_PER_DOC: usize = 1000; const KNOWLEDGE_CHUNK_SIZE: usize = 512; const KNOWLEDGE_CHUNK_OVERLAP: usize = 64; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// A document ingested into the per-room knowledge store, kept alongside its
// pre-computed chunks and persisted as JSON under the knowledge directory.
pub struct KnowledgeDocument {
    pub id: Uuid,
    // Room this document belongs to.
    pub room_id: Uuid,
    // Uploader's entity id.
    pub entity_id: Uuid,
    pub agent_id: Uuid,
    // Sanitized original file name.
    pub filename: String,
    // Document type label (e.g. file extension/category) — see ingest caller.
    pub doc_type: String,
    // Full (post-scrub) text content.
    pub content: String,
    // Overlapping text chunks derived from `content` for retrieval.
    pub chunks: Vec<KnowledgeChunk>,
    pub word_count: usize,
    // Unix timestamp (seconds) at ingest time.
    pub created_at: i64,
    pub metadata: HashMap<String, serde_json::Value>,
}
2053
/// A single retrievable slice of a `KnowledgeDocument`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct KnowledgeChunk {
    pub id: Uuid,
    // Parent document id.
    pub document_id: Uuid,
    // Trimmed chunk text.
    pub text: String,
    // Ordinal position of this chunk within the document.
    pub index: usize,
    // Character (not byte) offsets into the document content.
    pub char_start: usize,
    pub char_end: usize,
}
2064
2065mod bm25 {
2068 use rust_stemmers::{Algorithm, Stemmer};
2069 use std::collections::HashMap;
2070
2071 pub struct BM25Search {
2072 corpus: Vec<String>,
2073 stemmer: Stemmer,
2074 k1: f64,
2075 b: f64,
2076 }
2077
2078 impl BM25Search {
2079 pub fn new(corpus: Vec<String>) -> Self {
2080 Self {
2081 corpus,
2082 stemmer: Stemmer::create(Algorithm::English),
2083 k1: 1.2,
2084 b: 0.75,
2085 }
2086 }
2087
2088 pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
2089 if self.corpus.is_empty() {
2090 return Vec::new();
2091 }
2092
2093 let query_terms = self.tokenize_and_stem(query);
2094 let avg_doc_len = self.average_document_length();
2095
2096 let mut scores: Vec<(usize, f64)> = self.corpus
2097 .iter()
2098 .enumerate()
2099 .map(|(idx, doc)| {
2100 let score = self.bm25_score(&query_terms, doc, avg_doc_len);
2101 (idx, score)
2102 })
2103 .collect();
2104
2105 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
2106
2107 scores
2108 .into_iter()
2109 .take(top_k)
2110 .filter(|(_, score)| *score > 0.0)
2111 .map(|(idx, score)| (self.corpus[idx].clone(), score))
2112 .collect()
2113 }
2114
2115 fn bm25_score(&self, query_terms: &[String], document: &str, avg_doc_len: f64) -> f64 {
2116 let doc_terms = self.tokenize_and_stem(document);
2117 let doc_len = doc_terms.len() as f64;
2118
2119 if doc_len == 0.0 {
2120 return 0.0;
2121 }
2122
2123 let term_freqs = self.term_frequencies(&doc_terms);
2124
2125 query_terms
2126 .iter()
2127 .map(|term| {
2128 let tf = *term_freqs.get(term).unwrap_or(&0) as f64;
2129 if tf == 0.0 {
2130 return 0.0;
2131 }
2132 let idf = self.inverse_document_frequency(term);
2133
2134 let numerator = tf * (self.k1 + 1.0);
2135 let denominator = tf + self.k1 * (1.0 - self.b + self.b * (doc_len / avg_doc_len.max(1.0)));
2136
2137 idf * (numerator / denominator)
2138 })
2139 .sum()
2140 }
2141
2142 fn inverse_document_frequency(&self, term: &str) -> f64 {
2143 let n = self.corpus.len() as f64;
2144 let df = self.corpus
2145 .iter()
2146 .filter(|doc| self.tokenize_and_stem(doc).contains(&term.to_string()))
2147 .count() as f64;
2148
2149 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
2150 }
2151
2152 fn tokenize_and_stem(&self, text: &str) -> Vec<String> {
2153 text.to_lowercase()
2154 .split_whitespace()
2155 .filter(|word| word.len() > 2)
2156 .map(|word| {
2157 let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric());
2158 self.stemmer.stem(cleaned).to_string()
2159 })
2160 .collect()
2161 }
2162
2163 fn term_frequencies(&self, terms: &[String]) -> HashMap<String, usize> {
2164 let mut freqs = HashMap::new();
2165 for term in terms {
2166 *freqs.entry(term.clone()).or_insert(0) += 1;
2167 }
2168 freqs
2169 }
2170
2171 fn average_document_length(&self) -> f64 {
2172 if self.corpus.is_empty() {
2173 return 1.0;
2174 }
2175
2176 let total: usize = self.corpus
2177 .iter()
2178 .map(|doc| self.tokenize_and_stem(doc).len())
2179 .sum();
2180
2181 (total as f64 / self.corpus.len() as f64).max(1.0)
2182 }
2183 }
2184}
2185
/// Process-wide in-memory cache of per-room knowledge documents, lazily
/// created by `get_knowledge_store` and backed by per-room JSON files.
static KNOWLEDGE_STORE: OnceLock<Arc<RwLock<HashMap<Uuid, Vec<KnowledgeDocument>>>>> =
    OnceLock::new();
2190
/// Root directory for persisted knowledge files; overridable via the
/// `KNOWLEDGE_STORAGE_DIR` environment variable.
fn get_knowledge_dir() -> std::path::PathBuf {
    let dir = std::env::var("KNOWLEDGE_STORAGE_DIR")
        .unwrap_or_else(|_| ".zoey/db/knowledge".to_string());
    std::path::PathBuf::from(dir)
}
2197
2198fn get_room_knowledge_path(room_id: Uuid) -> std::path::PathBuf {
2200 get_knowledge_dir().join(format!("{}.json", room_id))
2201}
2202
2203fn load_room_knowledge(room_id: Uuid) -> Vec<KnowledgeDocument> {
2205 let path = get_room_knowledge_path(room_id);
2206 if !path.exists() {
2207 return Vec::new();
2208 }
2209
2210 match std::fs::read_to_string(&path) {
2211 Ok(content) => {
2212 match serde_json::from_str::<Vec<KnowledgeDocument>>(&content) {
2213 Ok(docs) => {
2214 info!("KNOWLEDGE_LOADED room_id={} documents={}", room_id, docs.len());
2215 docs
2216 }
2217 Err(e) => {
2218 error!("KNOWLEDGE_LOAD_ERROR room_id={} error={}", room_id, e);
2219 Vec::new()
2220 }
2221 }
2222 }
2223 Err(e) => {
2224 error!("KNOWLEDGE_READ_ERROR room_id={} error={}", room_id, e);
2225 Vec::new()
2226 }
2227 }
2228}
2229
2230fn save_room_knowledge(room_id: Uuid, documents: &[KnowledgeDocument]) -> std::result::Result<(), String> {
2232 let dir = get_knowledge_dir();
2233
2234 if let Err(e) = std::fs::create_dir_all(&dir) {
2236 return Err(format!("Failed to create knowledge directory: {}", e));
2237 }
2238
2239 let path = get_room_knowledge_path(room_id);
2240
2241 let json = match serde_json::to_string_pretty(documents) {
2242 Ok(j) => j,
2243 Err(e) => return Err(format!("Failed to serialize knowledge: {}", e)),
2244 };
2245
2246 match std::fs::write(&path, json) {
2247 Ok(_) => {
2248 info!("KNOWLEDGE_SAVED room_id={} documents={} path={:?}", room_id, documents.len(), path);
2249 Ok(())
2250 }
2251 Err(e) => Err(format!("Failed to write knowledge file: {}", e)),
2252 }
2253}
2254
2255pub fn delete_room_knowledge(room_id: Uuid) -> std::result::Result<(), String> {
2257 {
2259 let store = get_knowledge_store();
2260 if let Ok(mut store_guard) = store.write() {
2261 store_guard.remove(&room_id);
2262 };
2263 }
2264
2265 let path = get_room_knowledge_path(room_id);
2267 if path.exists() {
2268 if let Err(e) = std::fs::remove_file(&path) {
2269 error!("KNOWLEDGE_DELETE_ERROR room_id={} error={}", room_id, e);
2270 return Err(format!("Failed to delete knowledge file: {}", e));
2271 }
2272 info!("KNOWLEDGE_DELETED room_id={}", room_id);
2273 }
2274
2275 Ok(())
2276}
2277
2278fn get_knowledge_store() -> Arc<RwLock<HashMap<Uuid, Vec<KnowledgeDocument>>>> {
2279 KNOWLEDGE_STORE
2280 .get_or_init(|| Arc::new(RwLock::new(HashMap::new())))
2281 .clone()
2282}
2283
2284fn get_room_documents(room_id: Uuid) -> Vec<KnowledgeDocument> {
2286 let store = get_knowledge_store();
2287
2288 {
2290 let store_guard = store.read().unwrap();
2291 if let Some(docs) = store_guard.get(&room_id) {
2292 return docs.clone();
2293 }
2294 }
2295
2296 let docs = load_room_knowledge(room_id);
2298
2299 if !docs.is_empty() {
2301 let mut store_guard = store.write().unwrap();
2302 store_guard.insert(room_id, docs.clone());
2303 }
2304
2305 docs
2306}
2307
2308fn validate_filename(filename: &str) -> std::result::Result<String, String> {
2310 if filename.is_empty() {
2312 return Err("Filename cannot be empty".to_string());
2313 }
2314 if filename.len() > KNOWLEDGE_MAX_FILENAME_LENGTH {
2315 return Err(format!(
2316 "Filename too long (max {} characters)",
2317 KNOWLEDGE_MAX_FILENAME_LENGTH
2318 ));
2319 }
2320
2321 let sanitized: String = filename
2323 .chars()
2324 .filter(|c| c.is_alphanumeric() || *c == '.' || *c == '_' || *c == '-' || *c == ' ')
2325 .collect();
2326
2327 let sanitized = sanitized
2329 .rsplit(['/', '\\'])
2330 .next()
2331 .unwrap_or(&sanitized)
2332 .trim()
2333 .to_string();
2334
2335 if sanitized.is_empty() {
2336 return Err("Invalid filename after sanitization".to_string());
2337 }
2338
2339 if sanitized.contains("..") {
2341 return Err("Invalid filename: path traversal detected".to_string());
2342 }
2343
2344 Ok(sanitized)
2345}
2346
2347fn validate_content(content: &str) -> std::result::Result<(), String> {
2349 if content.len() > KNOWLEDGE_MAX_CONTENT_SIZE {
2351 return Err(format!(
2352 "Content too large (max {} bytes)",
2353 KNOWLEDGE_MAX_CONTENT_SIZE
2354 ));
2355 }
2356
2357 if content.len() < KNOWLEDGE_MIN_CONTENT_LENGTH {
2358 return Err(format!(
2359 "Content too short (min {} characters)",
2360 KNOWLEDGE_MIN_CONTENT_LENGTH
2361 ));
2362 }
2363
2364 if content.contains('\0') {
2366 return Err("Content contains invalid null bytes".to_string());
2367 }
2368
2369 let non_printable_count = content
2371 .chars()
2372 .filter(|c| !c.is_ascii_graphic() && !c.is_whitespace())
2373 .count();
2374
2375 if non_printable_count > content.len() / 10 {
2376 return Err("Content contains too many non-printable characters".to_string());
2377 }
2378
2379 Ok(())
2380}
2381
2382fn chunk_content(content: &str, document_id: Uuid) -> Vec<KnowledgeChunk> {
2384 let mut chunks = Vec::new();
2385 let chars: Vec<char> = content.chars().collect();
2386 let total_len = chars.len();
2387
2388 if total_len == 0 {
2389 return chunks;
2390 }
2391
2392 let mut start = 0;
2393 let mut index = 0;
2394
2395 while start < total_len && index < KNOWLEDGE_MAX_CHUNKS_PER_DOC {
2396 let end = (start + KNOWLEDGE_CHUNK_SIZE).min(total_len);
2397
2398 let actual_end = if end < total_len {
2400 let slice: String = chars[start..end].iter().collect();
2402 let last_period = slice.rfind(|c| c == '.' || c == '!' || c == '?' || c == '\n');
2403 match last_period {
2404 Some(pos) if pos > KNOWLEDGE_CHUNK_SIZE / 2 => start + pos + 1,
2405 _ => end,
2406 }
2407 } else {
2408 end
2409 };
2410
2411 let chunk_text: String = chars[start..actual_end].iter().collect();
2412 let trimmed = chunk_text.trim();
2413
2414 if !trimmed.is_empty() {
2415 chunks.push(KnowledgeChunk {
2416 id: Uuid::new_v4(),
2417 document_id,
2418 text: trimmed.to_string(),
2419 index,
2420 char_start: start,
2421 char_end: actual_end,
2422 });
2423 index += 1;
2424 }
2425
2426 start = if actual_end >= total_len {
2428 total_len
2429 } else {
2430 (actual_end).saturating_sub(KNOWLEDGE_CHUNK_OVERLAP)
2431 };
2432
2433 if start == 0 && actual_end == 0 {
2435 break;
2436 }
2437 }
2438
2439 chunks
2440}
2441
2442fn scrub_pii_basic(content: &str) -> String {
2444 use regex::Regex;
2445
2446 let mut scrubbed = content.to_string();
2447
2448 if let Ok(re) = Regex::new(r"\b\d{3}-\d{2}-\d{4}\b") {
2450 scrubbed = re.replace_all(&scrubbed, "[SSN_REDACTED]").to_string();
2451 }
2452
2453 if let Ok(re) = Regex::new(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b") {
2455 scrubbed = re.replace_all(&scrubbed, "[CC_REDACTED]").to_string();
2456 }
2457
2458 if let Ok(re) = Regex::new(r"\b(sk-|api[_-]?key[:\s=]+)[A-Za-z0-9]{20,}\b") {
2460 scrubbed = re.replace_all(&scrubbed, "[API_KEY_REDACTED]").to_string();
2461 }
2462
2463 scrubbed
2464}
2465
2466fn extract_text_from_pdf(bytes: &[u8]) -> std::result::Result<String, String> {
2468 let temp_dir = std::env::temp_dir();
2471 let temp_path = temp_dir.join(format!("knowledge_pdf_{}.pdf", Uuid::new_v4()));
2472
2473 if let Err(e) = std::fs::write(&temp_path, bytes) {
2475 return Err(format!("Failed to write temp PDF file: {}", e));
2476 }
2477
2478 let result = pdf_extract::extract_text(&temp_path)
2480 .map_err(|e| format!("PDF extraction error: {}", e));
2481
2482 let _ = std::fs::remove_file(&temp_path);
2484
2485 result
2486}
2487
2488fn extract_text_from_excel(bytes: &[u8]) -> std::result::Result<String, String> {
2490 use calamine::{Reader, Xlsx, Data};
2491 use std::io::Cursor;
2492
2493 let cursor = Cursor::new(bytes.to_vec());
2494
2495 let mut workbook: Xlsx<_> = match Xlsx::new(cursor) {
2497 Ok(wb) => wb,
2498 Err(e) => {
2499 return Err(format!("Failed to open Excel file: {}", e));
2500 }
2501 };
2502
2503 let mut all_text = Vec::new();
2504
2505 let sheet_names: Vec<String> = workbook.sheet_names().to_vec();
2507
2508 for sheet_name in sheet_names {
2509 if let Ok(range) = workbook.worksheet_range(&sheet_name) {
2510 all_text.push(format!("## Sheet: {}\n", sheet_name));
2511
2512 for row in range.rows() {
2513 let row_text: Vec<String> = row
2514 .iter()
2515 .map(|cell| {
2516 match cell {
2517 Data::Empty => String::new(),
2518 Data::String(s) => s.clone(),
2519 Data::Float(f) => f.to_string(),
2520 Data::Int(i) => i.to_string(),
2521 Data::Bool(b) => b.to_string(),
2522 Data::Error(e) => format!("#ERR:{:?}", e),
2523 Data::DateTime(dt) => format!("{}", dt),
2524 Data::DateTimeIso(s) => s.clone(),
2525 Data::DurationIso(s) => s.clone(),
2526 }
2527 })
2528 .collect();
2529
2530 let row_str = row_text.join(" | ");
2532 if !row_str.trim().is_empty() && row_str.trim() != "|" {
2533 all_text.push(row_str);
2534 }
2535 }
2536
2537 all_text.push(String::new()); }
2539 }
2540
2541 let result = all_text.join("\n");
2542
2543 if result.trim().is_empty() {
2544 return Err("Excel file appears to be empty".to_string());
2545 }
2546
2547 Ok(result)
2548}
2549
/// POST handler that ingests a document into a room's knowledge base.
///
/// Pipeline: validate filename -> resolve document type -> (optionally)
/// base64-decode and extract text (PDF/Excel/UTF-8) -> validate content ->
/// scrub PII -> chunk -> store in the in-memory knowledge store and persist.
/// Every failure short-circuits with a `KnowledgeIngestResponse::error`;
/// non-fatal issues are accumulated in `warnings` and returned on success.
pub async fn knowledge_ingest_handler(
    State(server_state): State<ServerState>,
    Json(request): Json<super::types::KnowledgeIngestRequest>,
) -> Response {
    use super::types::{KnowledgeDocumentType, KnowledgeIngestResponse};

    let runtime = server_state.api_state.runtime.clone();
    let mut warnings: Vec<String> = Vec::new();

    info!(
        "KNOWLEDGE_INGEST_START room_id={} entity_id={} filename={}",
        request.room_id, request.entity_id, request.filename
    );

    // Sanitize/validate the client-supplied filename before anything else.
    let filename = match validate_filename(&request.filename) {
        Ok(f) => f,
        Err(e) => {
            error!("KNOWLEDGE_INGEST_ERROR filename validation failed: {}", e);
            return Json(KnowledgeIngestResponse::error(format!(
                "Invalid filename: {}",
                e
            )))
            .into_response();
        }
    };

    // Explicit document type wins; otherwise infer it from the extension.
    let doc_type = match request.document_type {
        Some(dt) => dt,
        None => match KnowledgeDocumentType::from_filename(&filename) {
            Some(dt) => dt,
            None => {
                error!(
                    "KNOWLEDGE_INGEST_ERROR unsupported file type for {}",
                    filename
                );
                return Json(KnowledgeIngestResponse::error(
                    "Unsupported file type. Allowed: .txt, .md, .csv, .json",
                ))
                .into_response();
            }
        },
    };

    // A MIME mismatch is only a warning, never a rejection.
    if let Some(ref mime) = request.mime_type {
        if !doc_type.valid_mime_type(mime) {
            warnings.push(format!(
                "MIME type '{}' may not match document type {:?}",
                mime, doc_type
            ));
        }
    }

    // Base64 payloads are decoded and, for binary formats (PDF/Excel), run
    // through the matching text extractor; everything else must be UTF-8.
    let content = if request.base64_encoded {
        use base64::{engine::general_purpose::STANDARD, Engine};
        let bytes = match STANDARD.decode(&request.content) {
            Ok(b) => b,
            Err(_) => {
                error!("KNOWLEDGE_INGEST_ERROR invalid base64 encoding");
                return Json(KnowledgeIngestResponse::error("Invalid base64 encoding"))
                    .into_response();
            }
        };

        match doc_type {
            KnowledgeDocumentType::Pdf => {
                match extract_text_from_pdf(&bytes) {
                    Ok(text) => text,
                    Err(e) => {
                        error!("KNOWLEDGE_INGEST_ERROR PDF extraction failed: {}", e);
                        return Json(KnowledgeIngestResponse::error(format!(
                            "Failed to extract text from PDF: {}",
                            e
                        )))
                        .into_response();
                    }
                }
            }
            KnowledgeDocumentType::Excel => {
                match extract_text_from_excel(&bytes) {
                    Ok(text) => text,
                    Err(e) => {
                        error!("KNOWLEDGE_INGEST_ERROR Excel extraction failed: {}", e);
                        return Json(KnowledgeIngestResponse::error(format!(
                            "Failed to extract text from Excel: {}",
                            e
                        )))
                        .into_response();
                    }
                }
            }
            _ => {
                match String::from_utf8(bytes) {
                    Ok(s) => s,
                    Err(_) => {
                        error!("KNOWLEDGE_INGEST_ERROR base64 content is not valid UTF-8");
                        return Json(KnowledgeIngestResponse::error(
                            "Base64 content is not valid UTF-8 text",
                        ))
                        .into_response();
                    }
                }
            }
        }
    } else {
        request.content.clone()
    };

    if let Err(e) = validate_content(&content) {
        error!("KNOWLEDGE_INGEST_ERROR content validation failed: {}", e);
        return Json(KnowledgeIngestResponse::error(format!(
            "Invalid content: {}",
            e
        )))
        .into_response();
    }

    // NOTE(review): the length comparison misses redactions whose replacement
    // happens to be exactly as long as the original text — the warning is
    // best-effort, not a guarantee that no redaction occurred.
    let scrubbed_content = scrub_pii_basic(&content);
    if scrubbed_content.len() != content.len() {
        warnings.push("Some PII patterns were automatically redacted".to_string());
    }

    // Only the agent id is needed; keep the runtime read-lock scope minimal.
    let agent_id = {
        let rt = runtime.read().unwrap();
        rt.agent_id
    };

    let document_id = Uuid::new_v4();
    let word_count = scrubbed_content.split_whitespace().count();

    let chunks = chunk_content(&scrubbed_content, document_id);
    let chunks_count = chunks.len();

    if chunks.is_empty() {
        error!("KNOWLEDGE_INGEST_ERROR no valid chunks created");
        return Json(KnowledgeIngestResponse::error(
            "Content produced no valid chunks",
        ))
        .into_response();
    }

    let document = KnowledgeDocument {
        id: document_id,
        room_id: request.room_id,
        entity_id: request.entity_id,
        agent_id,
        filename: filename.clone(),
        doc_type: format!("{:?}", doc_type),
        content: scrubbed_content,
        chunks,
        word_count,
        created_at: chrono::Utc::now().timestamp(),
        metadata: request.metadata,
    };

    // Append to the in-memory store (lazily loading any persisted documents
    // for this room) and persist; persistence failure is only a warning.
    {
        let store = get_knowledge_store();
        let mut store_guard = store.write().unwrap();
        let room_docs = store_guard.entry(request.room_id).or_insert_with(|| {
            load_room_knowledge(request.room_id)
        });
        room_docs.push(document);

        if let Err(e) = save_room_knowledge(request.room_id, room_docs) {
            warnings.push(format!("Warning: Failed to persist knowledge: {}", e));
        }
    }

    info!(
        "KNOWLEDGE_INGEST_SUCCESS document_id={} filename={} chunks={} words={}",
        document_id, filename, chunks_count, word_count
    );

    Json(
        KnowledgeIngestResponse::success(document_id, chunks_count, word_count)
            .with_warnings(warnings),
    )
    .into_response()
}
2760
2761pub async fn knowledge_query_handler(
2763 State(server_state): State<ServerState>,
2764 Json(request): Json<super::types::KnowledgeQueryRequest>,
2765) -> Response {
2766 use super::types::{KnowledgeChunkResult, KnowledgeQueryResponse};
2767
2768 info!(
2769 "KNOWLEDGE_QUERY room_id={} query_len={}",
2770 request.room_id,
2771 request.query.len()
2772 );
2773
2774 use bm25::BM25Search;
2775
2776 let documents = get_room_documents(request.room_id);
2778
2779 if documents.is_empty() {
2780 return Json(KnowledgeQueryResponse {
2781 success: true,
2782 results: Some(vec![]),
2783 total_documents: Some(0),
2784 error: None,
2785 })
2786 .into_response();
2787 }
2788
2789 let mut corpus: Vec<String> = Vec::new();
2791 let mut chunk_map: Vec<(Uuid, Uuid, String)> = Vec::new(); for doc in &documents {
2794 for chunk in &doc.chunks {
2795 corpus.push(chunk.text.clone());
2796 chunk_map.push((chunk.id, doc.id, doc.filename.clone()));
2797 }
2798 }
2799
2800 let bm25 = BM25Search::new(corpus.clone());
2802 let bm25_results = bm25.search(&request.query, request.max_results);
2803
2804 let final_results: Vec<KnowledgeChunkResult> = bm25_results
2806 .into_iter()
2807 .filter_map(|(text, score)| {
2808 corpus.iter().position(|c| c == &text).map(|idx| {
2810 let (chunk_id, doc_id, filename) = &chunk_map[idx];
2811 KnowledgeChunkResult {
2812 id: *chunk_id,
2813 document_id: *doc_id,
2814 text,
2815 score,
2816 filename: Some(filename.clone()),
2817 }
2818 })
2819 })
2820 .collect();
2821
2822 Json(KnowledgeQueryResponse {
2823 success: true,
2824 results: Some(final_results),
2825 total_documents: Some(documents.len()),
2826 error: None,
2827 })
2828 .into_response()
2829}
2830
2831pub async fn knowledge_list_handler(
2833 State(server_state): State<ServerState>,
2834 axum::extract::Path(room_id): axum::extract::Path<String>,
2835) -> Response {
2836 let room_uuid = match Uuid::parse_str(&room_id) {
2837 Ok(id) => id,
2838 Err(_) => {
2839 return Json(serde_json::json!({
2840 "success": false,
2841 "error": "Invalid room ID"
2842 }))
2843 .into_response();
2844 }
2845 };
2846
2847 let documents = get_room_documents(room_uuid);
2849
2850 let doc_list: Vec<serde_json::Value> = documents
2851 .iter()
2852 .map(|d| {
2853 serde_json::json!({
2854 "id": d.id,
2855 "filename": d.filename,
2856 "docType": d.doc_type,
2857 "wordCount": d.word_count,
2858 "chunksCount": d.chunks.len(),
2859 "createdAt": d.created_at,
2860 })
2861 })
2862 .collect();
2863
2864 Json(serde_json::json!({
2865 "success": true,
2866 "documents": doc_list,
2867 "totalDocuments": doc_list.len(),
2868 }))
2869 .into_response()
2870}
2871
2872pub fn retrieve_knowledge_context(room_id: Uuid, query: &str, max_chunks: usize) -> Option<String> {
2880 use bm25::BM25Search;
2881
2882 let documents = get_room_documents(room_id);
2884 if documents.is_empty() {
2885 return None;
2886 }
2887
2888 let mut corpus: Vec<String> = Vec::new();
2890 let mut chunk_sources: Vec<(usize, String, Uuid)> = Vec::new(); for doc in &documents {
2893 for chunk in &doc.chunks {
2894 let chunk_idx = corpus.len();
2895 corpus.push(chunk.text.clone());
2896 chunk_sources.push((chunk_idx, doc.filename.clone(), chunk.id));
2897 }
2898 }
2899
2900 if corpus.is_empty() {
2901 return None;
2902 }
2903
2904 let bm25 = BM25Search::new(corpus.clone());
2906 let results = bm25.search(query, max_chunks * 2); if results.is_empty() {
2909 info!("KNOWLEDGE_NO_MATCHES room_id={} query_len={}", room_id, query.len());
2910 return None;
2911 }
2912
2913 const MIN_BM25_SCORE: f64 = 0.5;
2915 let filtered_results: Vec<(String, f64, String)> = results
2916 .into_iter()
2917 .filter(|(_, score)| *score >= MIN_BM25_SCORE)
2918 .take(max_chunks)
2919 .map(|(text, score)| {
2920 let filename = corpus.iter()
2922 .position(|c| c == &text)
2923 .and_then(|idx| chunk_sources.iter().find(|(i, _, _)| *i == idx))
2924 .map(|(_, f, _)| f.clone())
2925 .unwrap_or_else(|| "unknown".to_string());
2926 (text, score, filename)
2927 })
2928 .collect();
2929
2930 if filtered_results.is_empty() {
2931 info!(
2932 "KNOWLEDGE_LOW_RELEVANCE room_id={} query_len={} (scores below threshold)",
2933 room_id, query.len()
2934 );
2935 return None;
2936 }
2937
2938 info!(
2939 "KNOWLEDGE_MATCHED room_id={} chunks={} top_score={:.2}",
2940 room_id,
2941 filtered_results.len(),
2942 filtered_results.first().map(|r| r.1).unwrap_or(0.0)
2943 );
2944
2945 let context_parts: Vec<String> = filtered_results
2947 .iter()
2948 .map(|(text, _score, filename)| {
2949 let truncated = if text.len() > 600 {
2951 format!("{}...", &text[..600])
2952 } else {
2953 text.clone()
2954 };
2955 format!("[{}]: {}", filename, truncated)
2956 })
2957 .collect();
2958
2959 Some(format!(
2960 "**Relevant excerpts from case documents:**\n\n{}",
2961 context_parts.join("\n\n")
2962 ))
2963}
2964
2965pub fn get_knowledge_summary(room_id: Uuid) -> Option<String> {
2967 let store = get_knowledge_store();
2968 let store_guard = store.read().ok()?;
2969
2970 let documents = store_guard.get(&room_id)?;
2971 if documents.is_empty() {
2972 return None;
2973 }
2974
2975 let total_words: usize = documents.iter().map(|d| d.word_count).sum();
2976 let total_chunks: usize = documents.iter().map(|d| d.chunks.len()).sum();
2977
2978 let doc_list: Vec<String> = documents
2979 .iter()
2980 .map(|d| format!("- **{}** ({} words, {} chunks)", d.filename, d.word_count, d.chunks.len()))
2981 .collect();
2982
2983 Some(format!(
2984 "### Case Knowledge Base\n\n\
2985 **{} documents** | **{} words** | **{} searchable chunks**\n\n\
2986 {}",
2987 documents.len(),
2988 total_words,
2989 total_chunks,
2990 doc_list.join("\n")
2991 ))
2992}
2993
2994pub async fn training_statistics_handler(
3000 State(state): State<ServerState>,
3001) -> impl IntoResponse {
3002 let runtime = state.api_state.runtime.read().unwrap();
3003
3004 if let Some(collector) = runtime.get_training_collector() {
3006 let stats = collector.get_statistics();
3007 let response = serde_json::json!({
3008 "status": "success",
3009 "data": {
3010 "type": "statistics",
3011 "totalSamples": stats.total_samples,
3012 "highQualityCount": stats.high_quality_count,
3013 "mediumQualityCount": stats.medium_quality_count,
3014 "lowQualityCount": stats.low_quality_count,
3015 "withThoughtsCount": stats.with_thoughts_count,
3016 "withFeedbackCount": stats.with_feedback_count,
3017 "avgQualityScore": stats.avg_quality_score,
3018 "avgFeedbackScore": stats.avg_feedback_score,
3019 "categories": stats.categories,
3020 "tags": stats.tags,
3021 "rlhfEnabled": collector.is_rlhf_enabled()
3022 }
3023 });
3024 (StatusCode::OK, Json(response))
3025 } else {
3026 let response = serde_json::json!({
3027 "status": "error",
3028 "code": "NOT_AVAILABLE",
3029 "message": "Training collector not initialized"
3030 });
3031 (StatusCode::SERVICE_UNAVAILABLE, Json(response))
3032 }
3033}
3034
/// Request body for attaching human feedback to a collected training sample.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AddFeedbackRequest {
    // Id of the previously collected sample to annotate.
    pub sample_id: Uuid,
    // Numeric feedback score; valid range is enforced by the collector —
    // TODO confirm expected scale against the collector implementation.
    pub feedback_score: f32,
    // Optional free-form feedback text.
    pub feedback_text: Option<String>,
}
3043
3044pub async fn training_feedback_handler(
3045 State(state): State<ServerState>,
3046 Json(payload): Json<AddFeedbackRequest>,
3047) -> Response {
3048 let collector = {
3050 let runtime = state.api_state.runtime.read().unwrap();
3051 runtime.get_training_collector()
3052 };
3053
3054 if let Some(collector) = collector {
3055 match collector.add_feedback(payload.sample_id, payload.feedback_score, payload.feedback_text).await {
3056 Ok(_) => {
3057 info!("Training feedback added for sample {}: score={}", payload.sample_id, payload.feedback_score);
3058 let response = serde_json::json!({
3059 "status": "success",
3060 "data": {
3061 "type": "feedbackAdded",
3062 "sampleId": payload.sample_id.to_string()
3063 }
3064 });
3065 (StatusCode::OK, Json(response)).into_response()
3066 }
3067 Err(e) => {
3068 error!("Failed to add feedback: {}", e);
3069 let response = serde_json::json!({
3070 "status": "error",
3071 "code": "FEEDBACK_FAILED",
3072 "message": e.to_string()
3073 });
3074 (StatusCode::BAD_REQUEST, Json(response)).into_response()
3075 }
3076 }
3077 } else {
3078 let response = serde_json::json!({
3079 "status": "error",
3080 "code": "NOT_AVAILABLE",
3081 "message": "Training collector not initialized"
3082 });
3083 (StatusCode::SERVICE_UNAVAILABLE, Json(response)).into_response()
3084 }
3085}
3086
/// Request body for exporting collected training data.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExportDataRequest {
    // Export format; one of "jsonl", "alpaca", "sharegpt", "openai".
    // Defaults to "jsonl" when omitted.
    pub format: Option<String>,
    // NOTE(review): not consulted by `training_export_handler` in this file —
    // either dead weight or consumed elsewhere; verify before removing.
    pub include_negative: Option<bool>,
}
3094
3095pub async fn training_export_handler(
3096 State(state): State<ServerState>,
3097 Json(payload): Json<ExportDataRequest>,
3098) -> Response {
3099 let collector = {
3101 let runtime = state.api_state.runtime.read().unwrap();
3102 runtime.get_training_collector()
3103 };
3104
3105 if let Some(collector) = collector {
3106 let format = payload.format.as_deref().unwrap_or("jsonl");
3107
3108 let export_result = match format.to_lowercase().as_str() {
3109 "jsonl" => collector.export_jsonl().await,
3110 "alpaca" => collector.export_alpaca().await,
3111 "sharegpt" => collector.export_sharegpt().await,
3112 "openai" => collector.export_openai().await,
3113 _ => {
3114 let response = serde_json::json!({
3115 "status": "error",
3116 "code": "INVALID_FORMAT",
3117 "message": "Unsupported format. Use: jsonl, alpaca, sharegpt, openai"
3118 });
3119 return (StatusCode::BAD_REQUEST, Json(response)).into_response();
3120 }
3121 };
3122
3123 match export_result {
3124 Ok(data) => {
3125 let sample_count = data.lines().count();
3126 info!("Training data exported: {} samples in {} format", sample_count, format);
3127 let response = serde_json::json!({
3128 "status": "success",
3129 "data": {
3130 "type": "exportedData",
3131 "format": format,
3132 "sampleCount": sample_count,
3133 "data": data
3134 }
3135 });
3136 (StatusCode::OK, Json(response)).into_response()
3137 }
3138 Err(e) => {
3139 error!("Failed to export training data: {}", e);
3140 let response = serde_json::json!({
3141 "status": "error",
3142 "code": "EXPORT_FAILED",
3143 "message": e.to_string()
3144 });
3145 (StatusCode::INTERNAL_SERVER_ERROR, Json(response)).into_response()
3146 }
3147 }
3148 } else {
3149 let response = serde_json::json!({
3150 "status": "error",
3151 "code": "NOT_AVAILABLE",
3152 "message": "Training collector not initialized"
3153 });
3154 (StatusCode::SERVICE_UNAVAILABLE, Json(response)).into_response()
3155 }
3156}
3157
/// Query parameters for listing training samples.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListSamplesQuery {
    // Page size; defaults to 50 and is capped at 500 by the handler.
    pub limit: Option<usize>,
    // Number of (filtered) samples to skip; defaults to 0.
    pub offset: Option<usize>,
    // When set, only samples with quality_score >= this value are returned.
    pub min_quality: Option<f32>,
}
3166
3167pub async fn training_samples_handler(
3168 State(state): State<ServerState>,
3169 axum::extract::Query(query): axum::extract::Query<ListSamplesQuery>,
3170) -> impl IntoResponse {
3171 let runtime = state.api_state.runtime.read().unwrap();
3172
3173 if let Some(collector) = runtime.get_training_collector() {
3174 let all_samples = collector.get_samples();
3175 let total = all_samples.len();
3176
3177 let filtered: Vec<_> = if let Some(min_q) = query.min_quality {
3179 all_samples.into_iter().filter(|s| s.quality_score >= min_q).collect()
3180 } else {
3181 all_samples
3182 };
3183
3184 let offset = query.offset.unwrap_or(0);
3186 let limit = query.limit.unwrap_or(50).min(500);
3187
3188 let samples: Vec<_> = filtered.into_iter()
3189 .skip(offset)
3190 .take(limit)
3191 .collect();
3192
3193 let response = serde_json::json!({
3194 "status": "success",
3195 "data": {
3196 "type": "sampleList",
3197 "samples": samples,
3198 "total": total,
3199 "offset": offset
3200 }
3201 });
3202 (StatusCode::OK, Json(response))
3203 } else {
3204 let response = serde_json::json!({
3205 "status": "error",
3206 "code": "NOT_AVAILABLE",
3207 "message": "Training collector not initialized"
3208 });
3209 (StatusCode::SERVICE_UNAVAILABLE, Json(response))
3210 }
3211}
3212
/// Request body for starting a training run (MCP-proxied or local export).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StartTrainingRequest {
    // Export format for the training data; defaults to "jsonl" when omitted.
    pub format: Option<String>,
    // Optional job configuration; handler fills in defaults when omitted.
    pub config: Option<TrainingJobConfig>,
    // Opaque backend configuration forwarded verbatim to the MCP server
    // when one is available.
    pub backend: Option<serde_json::Value>,
}
3222
/// Per-job training configuration; every field is optional and defaulted by
/// the start handler (name = generated, min_quality = 0.6,
/// include_negative = true, auto_save = true).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TrainingJobConfig {
    // Human-readable job name.
    pub name: Option<String>,
    // Minimum sample quality score to include.
    pub min_quality: Option<f32>,
    // Whether negative examples are included.
    pub include_negative: Option<bool>,
    // Whether results are persisted automatically.
    pub auto_save: Option<bool>,
}
3231
// Process-wide registry of training jobs, keyed by job id. Lazily
// initialized on first access; shared between the start handler, the
// background task it spawns, and the status/list handlers.
static TRAINING_JOBS: OnceLock<Arc<RwLock<HashMap<Uuid, TrainingJobStatus>>>> = OnceLock::new();

/// Returns the lazily-initialized global training-job registry.
pub fn get_training_jobs() -> &'static Arc<RwLock<HashMap<Uuid, TrainingJobStatus>>> {
    TRAINING_JOBS.get_or_init(|| Arc::new(RwLock::new(HashMap::new())))
}
3239
/// Serializable snapshot of a background training job's lifecycle.
#[derive(Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TrainingJobStatus {
    // Unique id assigned when the job is created.
    pub job_id: Uuid,
    // Lifecycle state: "pending" -> "running" -> "completed" | "failed"
    // (string values written by training_start_handler's spawned task).
    pub state: String,
    // Fractional progress; set to 1.0 on completion.
    pub progress: f32,
    pub samples_processed: usize,
    pub total_samples: usize,
    // Unix timestamps (seconds).
    pub started_at: i64,
    pub completed_at: Option<i64>,
    // Populated only when state == "failed".
    pub error: Option<String>,
    // Filesystem path of the exported data, set on success.
    pub result_path: Option<String>,
}
3253
/// POST handler that starts a training run.
///
/// Strategy: if an "mcp-server" service is registered, proxy the request to
/// the MCP server's /mcp/training/start endpoint; on ANY proxy failure
/// (unreachable, HTML response, unreadable/unparseable body) it logs a
/// warning and falls through to the local path, which exports the collected
/// samples to a file in a background task and tracks progress in the global
/// job registry.
pub async fn training_start_handler(
    State(state): State<ServerState>,
    Json(payload): Json<StartTrainingRequest>,
) -> Response {
    // Only check for the service under the lock; the proxy call happens later.
    let has_mcp_service = {
        let runtime = state.api_state.runtime.read().unwrap();
        runtime.get_service("mcp-server").is_some()
    };

    // Keep a copy for the local fallback path; the MCP branch consumes
    // payload.config below.
    let config_clone = payload.config.clone();

    if has_mcp_service {
        // MCP endpoint location and auth come from the environment.
        let mcp_port = std::env::var("MCP_PORT").unwrap_or_else(|_| "8443".to_string());
        let mcp_host = std::env::var("MCP_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
        let mcp_url = format!("http://{}:{}/mcp/training/start", mcp_host, mcp_port);

        let auth_token = std::env::var("MCP_AUTH_TOKEN").ok();

        // Short timeout: a slow/absent MCP server should fall back quickly.
        let client = HttpClient::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| HttpClient::new());
        let mut request = client.post(&mcp_url);

        if let Some(token) = auth_token {
            request = request.header("Authorization", format!("Bearer {}", token));
        }

        // Fill in defaults for any omitted config fields.
        let format_str = payload.format.as_deref().unwrap_or("jsonl");
        let config = payload.config.unwrap_or(TrainingJobConfig {
            name: Some(format!("Training job {}", Uuid::new_v4())),
            min_quality: Some(0.6),
            include_negative: Some(true),
            auto_save: Some(true),
        });

        let config_name = config.name.unwrap_or_else(|| format!("Training job {}", Uuid::new_v4()));
        let config_min_quality = config.min_quality.unwrap_or(0.6);
        let config_include_negative = config.include_negative.unwrap_or(true);
        let config_auto_save = config.auto_save.unwrap_or(true);

        let format_enum = format_str.to_lowercase();

        let mut mcp_payload = serde_json::json!({
            "format": format_enum,
            "config": {
                "name": config_name,
                "minQuality": config_min_quality,
                "includeNegative": config_include_negative,
                "autoSave": config_auto_save,
            }
        });

        // Backend config from the UI is forwarded verbatim when present.
        if let Some(backend_config) = &payload.backend {
            mcp_payload["backend"] = backend_config.clone();
            info!("Using dynamic backend config from UI: {:?}", backend_config);
        }

        // Every failure branch below deliberately falls through (warn only)
        // to the local export path after this block.
        match request.json(&mcp_payload).send().await {
            Ok(response) => {
                let status = response.status();
                match response.text().await {
                    Ok(text) => {
                        // An HTML body means we hit something other than the
                        // MCP API (e.g. a generic web server on that port).
                        if text.trim_start().starts_with("<!") || text.trim_start().starts_with("<html") {
                            warn!("MCP server not accessible on {}:{} (got HTML response). Falling back to export-only.", mcp_host, mcp_port);
                        } else {
                            match serde_json::from_str::<serde_json::Value>(&text) {
                                Ok(data) => {
                                    // Accept both camelCase and snake_case
                                    // job-id keys from the MCP server.
                                    let job_id = data.get("data")
                                        .and_then(|d| d.get("jobId"))
                                        .and_then(|id| id.as_str())
                                        .or_else(|| {
                                            data.get("data")
                                                .and_then(|d| d.get("job_id"))
                                                .and_then(|id| id.as_str())
                                        });

                                    let response_json = serde_json::json!({
                                        "status": "success",
                                        "data": {
                                            "type": "jobStarted",
                                            "jobId": job_id.unwrap_or("unknown"),
                                            "estimatedDuration": 30
                                        }
                                    });

                                    // Mirror the MCP server's status code;
                                    // default to 202 if conversion fails.
                                    return (StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::ACCEPTED), Json(response_json)).into_response();
                                }
                                Err(e) => {
                                    warn!("Failed to parse MCP response as JSON: {}. Response text (first 200 chars): {}. Falling back to export-only.", e, text.chars().take(200).collect::<String>());
                                }
                            }
                        }
                    }
                    Err(e) => {
                        warn!("Failed to read MCP response body: {}. Falling back to export-only.", e);
                    }
                }
            }
            Err(e) => {
                warn!("Failed to proxy to MCP server: {}. Falling back to export-only.", e);
            }
        }
    } else {
    }

    // Local fallback: export the collected samples in a background task and
    // track progress in the global job registry.
    let collector = {
        let runtime = state.api_state.runtime.read().unwrap();
        runtime.get_training_collector()
    };

    if let Some(collector) = collector {
        let job_id = Uuid::new_v4();
        let format = payload.format.as_deref().unwrap_or("jsonl");
        // Config is currently unused on the local path (export-only);
        // defaults mirror the MCP branch.
        let _config = config_clone.unwrap_or(TrainingJobConfig {
            name: Some(format!("Training job {}", job_id)),
            min_quality: Some(0.6),
            include_negative: Some(true),
            auto_save: Some(true),
        });

        let job_status = TrainingJobStatus {
            job_id,
            state: "pending".to_string(),
            progress: 0.0,
            samples_processed: 0,
            total_samples: collector.count(),
            started_at: chrono::Utc::now().timestamp(),
            completed_at: None,
            error: None,
            result_path: None,
        };

        // Register the job before spawning so status queries can see it.
        {
            let mut jobs = get_training_jobs().write().unwrap();
            jobs.insert(job_id, job_status.clone());
        }

        let collector_clone = collector.clone();
        let format_clone = format.to_string();
        tokio::spawn(async move {
            // Mark the job running; lock scope is kept tight around each
            // registry update so the export itself runs without the lock.
            {
                let mut jobs = get_training_jobs().write().unwrap();
                if let Some(job) = jobs.get_mut(&job_id) {
                    job.state = "running".to_string();
                }
            }

            // Unknown format strings silently default to JSONL.
            let training_format = match format_clone.as_str() {
                "alpaca" => crate::training::TrainingFormat::Alpaca,
                "sharegpt" => crate::training::TrainingFormat::ShareGpt,
                "openai" => crate::training::TrainingFormat::OpenAi,
                _ => crate::training::TrainingFormat::Jsonl,
            };

            match collector_clone.save_to_file(training_format).await {
                Ok(path) => {
                    let mut jobs = get_training_jobs().write().unwrap();
                    if let Some(job) = jobs.get_mut(&job_id) {
                        job.state = "completed".to_string();
                        job.progress = 1.0;
                        job.completed_at = Some(chrono::Utc::now().timestamp());
                        job.result_path = Some(path.to_string_lossy().to_string());
                    }
                    info!("Training job {} completed: {}", job_id, path.display());
                }
                Err(e) => {
                    let mut jobs = get_training_jobs().write().unwrap();
                    if let Some(job) = jobs.get_mut(&job_id) {
                        job.state = "failed".to_string();
                        job.completed_at = Some(chrono::Utc::now().timestamp());
                        job.error = Some(e.to_string());
                    }
                    error!("Training job {} failed: {}", job_id, e);
                }
            }
        });

        info!("Training job {} started", job_id);
        let response = serde_json::json!({
            "status": "success",
            "data": {
                "type": "jobStarted",
                "jobId": job_id.to_string(),
                "estimatedDuration": 30
            }
        });
        (StatusCode::ACCEPTED, Json(response)).into_response()
    } else {
        let response = serde_json::json!({
            "status": "error",
            "code": "NOT_AVAILABLE",
            "message": "Training collector not initialized"
        });
        (StatusCode::SERVICE_UNAVAILABLE, Json(response)).into_response()
    }
}
3477
/// Query parameters for looking up a single training job by id.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JobStatusQuery {
    // Id of the job to look up in the global registry.
    pub job_id: Uuid,
}
3484
3485pub async fn training_job_status_handler(
3486 axum::extract::Query(query): axum::extract::Query<JobStatusQuery>,
3487) -> impl IntoResponse {
3488 let jobs = get_training_jobs().read().unwrap();
3489
3490 if let Some(job) = jobs.get(&query.job_id) {
3491 let response = serde_json::json!({
3492 "status": "success",
3493 "data": {
3494 "type": "jobStatus",
3495 "job": job
3496 }
3497 });
3498 (StatusCode::OK, Json(response))
3499 } else {
3500 let response = serde_json::json!({
3501 "status": "error",
3502 "code": "NOT_FOUND",
3503 "message": format!("Job {} not found", query.job_id)
3504 });
3505 (StatusCode::NOT_FOUND, Json(response))
3506 }
3507}
3508
3509pub async fn training_jobs_handler() -> impl IntoResponse {
3511 let jobs = get_training_jobs().read().unwrap();
3512 let job_list: Vec<_> = jobs.values().cloned().collect();
3513 let total = job_list.len();
3514
3515 let response = serde_json::json!({
3516 "status": "success",
3517 "data": {
3518 "type": "jobList",
3519 "jobs": job_list,
3520 "total": total
3521 }
3522 });
3523 (StatusCode::OK, Json(response))
3524}