use crate::agent::model::{GenerationConfig, ModelProvider};
pub use crate::agent::output::{
    AgentOutput, GraphDebugInfo, GraphDebugNode, MemoryRecallMatch, MemoryRecallStats,
    MemoryRecallStrategy, ToolInvocation,
};
use crate::config::agent::AgentProfile;
use crate::embeddings::EmbeddingsClient;
use crate::persistence::Persistence;
use crate::policy::{PolicyDecision, PolicyEngine};
use crate::spec::AgentSpec;
use crate::tools::{ToolRegistry, ToolResult};
use crate::types::{Message, MessageRole};
use crate::SYNC_GRAPH_NAMESPACE;
use anyhow::{Context, Result};
use chrono::Utc;
use serde_json::{json, Value};
use spec_ai_knowledge_graph::{EdgeType, NodeType, TraversalDirection};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

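// Fallback generation parameters, used when a profile supplies missing or
// non-finite values.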
const DEFAULT_MAIN_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 0.9;
const DEFAULT_FAST_TEMPERATURE: f32 = 0.3;
const DEFAULT_ESCALATION_THRESHOLD: f32 = 0.6;

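/// Messages recalled for prompt context plus optional stats about how they
/// were selected.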
struct RecallResult {
    messages: Vec<Message>,
    stats: Option<MemoryRecallStats>,
}

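/// An entity (URL, email address, quoted phrase) heuristically extracted from text.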
struct ExtractedEntity {
    name: String,
    entity_type: String,
    confidence: f32,
}

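/// A concept keyword match extracted from text.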
struct ExtractedConcept {
    name: String,
    relevance: f32,
}

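/// Tracks the user's goal for a single run: whether it needs a tool, whether it
/// has been satisfied, and the backing graph node (when graph persistence is on).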
#[derive(Debug, Clone)]
struct GoalContext {
    message_id: i64,
    text: String,
    requires_tool: bool,
    satisfied: bool,
    node_id: Option<i64>,
}

impl GoalContext {
    fn new(message_id: i64, text: &str, requires_tool: bool, node_id: Option<i64>) -> Self {
        Self {
            message_id,
            text: text.to_string(),
            requires_tool,
            satisfied: !requires_tool,
            node_id,
        }
    }
}

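/// The core agent loop: memory recall, prompt assembly, model and tool
/// execution, and knowledge-graph bookkeeping for one session.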
pub struct AgentCore {
    profile: AgentProfile,
    provider: Arc<dyn ModelProvider>,
    fast_provider: Option<Arc<dyn ModelProvider>>,
    embeddings_client: Option<EmbeddingsClient>,
    persistence: Persistence,
    session_id: String,
    agent_name: Option<String>,
    conversation_history: Vec<Message>,
    tool_registry: Arc<ToolRegistry>,
    policy_engine: Arc<PolicyEngine>,
    tool_permission_cache: Arc<RwLock<HashMap<String, bool>>>,
    speak_responses: bool,
}

impl AgentCore {
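    /// Remaps the reserved sync namespace to a per-agent namespace, returning the
    /// (possibly rewritten) session id and whether a rewrite occurred.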
    fn sanitize_session_id(session_id: String) -> (String, bool) {
        if session_id == SYNC_GRAPH_NAMESPACE {
            (format!("{}-agent", session_id), true)
        } else {
            (session_id, false)
        }
    }

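    /// Creates an agent core for `session_id`; the fast provider starts unset and
    /// can be attached with [`Self::with_fast_provider`].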
    pub fn new(
        profile: AgentProfile,
        provider: Arc<dyn ModelProvider>,
        embeddings_client: Option<EmbeddingsClient>,
        persistence: Persistence,
        session_id: String,
        agent_name: Option<String>,
        tool_registry: Arc<ToolRegistry>,
        policy_engine: Arc<PolicyEngine>,
        speak_responses: bool,
    ) -> Self {
        let (session_id, rewrote_namespace) = Self::sanitize_session_id(session_id);
        if rewrote_namespace {
            warn!(
                "Session namespace '{}' is reserved for sync; using '{}' for agent graph state",
                SYNC_GRAPH_NAMESPACE, session_id
            );
        }

        Self {
            profile,
            provider,
            fast_provider: None,
            embeddings_client,
            persistence,
            session_id,
            agent_name,
            conversation_history: Vec::new(),
            tool_registry,
            policy_engine,
            tool_permission_cache: Arc::new(RwLock::new(HashMap::new())),
            speak_responses,
        }
    }

    pub fn with_fast_provider(mut self, fast_provider: Arc<dyn ModelProvider>) -> Self {
        self.fast_provider = Some(fast_provider);
        self
    }

    pub fn with_session(mut self, session_id: String) -> Self {
        let (session_id, rewrote_namespace) = Self::sanitize_session_id(session_id);
        if rewrote_namespace {
            warn!(
                "Session namespace '{}' is reserved for sync; using '{}' for agent graph state",
                SYNC_GRAPH_NAMESPACE, session_id
            );
        }
        self.session_id = session_id;
        self.conversation_history.clear();
        self.tool_permission_cache = Arc::new(RwLock::new(HashMap::new()));
        self
    }

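    /// Executes one full agent step for `input`: recalls memory, builds the
    /// prompt, tries the auto-tool and fast-model paths, falls back to the main
    /// model loop with tool calls, persists both messages, and returns the
    /// structured output.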
    pub async fn run_step(&mut self, input: &str) -> Result<AgentOutput> {
        let run_id = format!("run-{}", Utc::now().timestamp_micros());
        let total_timer = Instant::now();

        let recall_timer = Instant::now();
        let recall_result = self.recall_memories(input).await?;
        self.log_timing("run_step.recall_memories", recall_timer);
        let recalled_messages = recall_result.messages;
        let recall_stats = recall_result.stats;

        let prompt_timer = Instant::now();
        let mut prompt = self.build_prompt(input, &recalled_messages).await?;
        self.log_timing("run_step.build_prompt", prompt_timer);

        let store_user_timer = Instant::now();
        let user_message_id = self.store_message(MessageRole::User, input).await?;
        self.log_timing("run_step.store_user_message", store_user_timer);

        let mut goal_context =
            Some(self.create_goal_context(user_message_id, input, self.profile.enable_graph)?);

        let mut tool_invocations = Vec::new();
        let mut final_response = String::new();
        let mut token_usage = None;
        let mut finish_reason = None;
        let mut auto_response: Option<String> = None;
        let mut reasoning: Option<String> = None;
        let mut reasoning_summary: Option<String> = None;

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if let Some((tool_name, tool_args)) =
                    Self::infer_goal_tool_action(goal.text.as_str())
                {
                    if self.is_tool_allowed(&tool_name).await {
                        let tool_timer = Instant::now();
                        let tool_result = self.execute_tool(&run_id, &tool_name, &tool_args).await;
                        self.log_timing("run_step.tool_execution.auto", tool_timer);
                        match tool_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    &tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                if let Err(err) = self
                                    .record_goal_tool_result(goal, &tool_name, &tool_args, &result)
                                {
                                    warn!("Failed to record goal progress: {}", err);
                                }
                                if result.success {
                                    if let Err(err) =
                                        self.update_goal_status(goal, "completed", true, None)
                                    {
                                        warn!("Failed to update goal status: {}", err);
                                    } else {
                                        goal.satisfied = true;
                                    }
                                }
                                auto_response = Some(Self::format_auto_tool_response(
                                    &tool_name,
                                    invocation.output.as_deref(),
                                ));
                                tool_invocations.push(invocation);
                            }
                            Err(err) => {
                                warn!("Auto tool execution '{}' failed: {}", tool_name, err);
                            }
                        }
                    }
                }
            }
        }

        let skip_model = goal_context
            .as_ref()
            .map(|goal| goal.requires_tool && goal.satisfied && auto_response.is_some())
            .unwrap_or(false);

        let mut fast_model_final: Option<(String, f32)> = None;
        if !skip_model {
            if let Some(task_type) = self.detect_task_type(input) {
                let complexity = self.estimate_task_complexity(input);
                if self.should_use_fast_model(&task_type, complexity) {
                    let fast_timer = Instant::now();
                    let fast_result = self.fast_reasoning(&task_type, input).await;
                    self.log_timing("run_step.fast_reasoning_attempt", fast_timer);
                    match fast_result {
                        Ok((fast_text, confidence)) => {
                            if confidence >= self.escalation_threshold() {
                                fast_model_final = Some((fast_text, confidence));
                            } else {
                                prompt.push_str(&format!(
                                    "\n\nFAST_MODEL_HINT (task={} confidence={:.0}%):\n{}\n\nRefine this hint and produce a complete answer.",
                                    task_type,
                                    (confidence * 100.0).round(),
                                    fast_text
                                ));
                            }
                        }
                        Err(err) => {
                            warn!("Fast reasoning failed for task {}: {}", task_type, err);
                        }
                    }
                }
            }
        }

        if skip_model {
            final_response = auto_response.unwrap_or_else(|| "Task completed.".to_string());
            finish_reason = Some("auto_tool".to_string());
        } else if let Some((fast_text, confidence)) = fast_model_final {
            final_response = fast_text;
            finish_reason = Some(format!("fast_model ({:.0}%)", (confidence * 100.0).round()));
        } else {
            for _iteration in 0..5 {
                let generation_config = self.build_generation_config();
                let model_timer = Instant::now();
                let response_result = self.provider.generate(&prompt, &generation_config).await;
                self.log_timing("run_step.main_model_call", model_timer);
                let response = response_result.context("Failed to generate response from model")?;

                token_usage = response.usage;
                finish_reason = response.finish_reason.clone();
                final_response = response.content.clone();
                reasoning = response.reasoning.clone();

                if let Some(ref reasoning_text) = reasoning {
                    reasoning_summary = self.summarize_reasoning(reasoning_text).await;
                }

                let sdk_tool_calls = response.tool_calls.clone().unwrap_or_default();

                if sdk_tool_calls.is_empty() {
                    let is_complete = finish_reason.as_ref().is_some_and(|reason| {
                        let reason_lower = reason.to_lowercase();
                        reason_lower.contains("stop")
                            || reason_lower.contains("end_turn")
                            || reason_lower.contains("complete")
                            || reason_lower == "length"
                    });

                    let goal_needs_tool = goal_context
                        .as_ref()
                        .is_some_and(|g| g.requires_tool && !g.satisfied);

                    if is_complete && !goal_needs_tool {
                        debug!("Early termination: response complete with no tool calls needed");
                        break;
                    }
                }

                if !sdk_tool_calls.is_empty() {
                    for tool_call in sdk_tool_calls {
                        let tool_name = &tool_call.function_name;
                        let tool_args = &tool_call.arguments;

                        if !self.is_tool_allowed(tool_name).await {
                            warn!(
                                "Tool '{}' is not allowed by agent policy - prompting user",
                                tool_name
                            );

                            match self.prompt_for_tool_permission(tool_name).await {
                                Ok(true) => {
                                    info!("User granted permission for tool '{}'", tool_name);
                                }
                                Ok(false) => {
                                    let error_msg =
                                        format!("Tool '{}' was denied by user", tool_name);
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                                Err(e) => {
                                    let error_msg = format!(
                                        "Failed to get user permission for tool '{}': {}",
                                        tool_name, e
                                    );
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                            }
                        }

                        let tool_timer = Instant::now();
                        let exec_result = self.execute_tool(&run_id, tool_name, tool_args).await;
                        self.log_timing("run_step.tool_execution.sdk", tool_timer);
                        match exec_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                let tool_output = invocation.output.clone().unwrap_or_default();
                                let was_success = invocation.success;
                                let error_message = invocation
                                    .error
                                    .clone()
                                    .unwrap_or_else(|| "Tool execution failed".to_string());
                                tool_invocations.push(invocation);

                                if let Some(goal) = goal_context.as_mut() {
                                    if let Err(err) = self.record_goal_tool_result(
                                        goal, tool_name, tool_args, &result,
                                    ) {
                                        warn!("Failed to record goal progress: {}", err);
                                    }
                                    if result.success && goal.requires_tool && !goal.satisfied {
                                        if let Err(err) =
                                            self.update_goal_status(goal, "in_progress", true, None)
                                        {
                                            warn!("Failed to update goal status: {}", err);
                                        }
                                    }
                                }

                                if was_success {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_RESULT from {}:\n{}\n\nBased on this result, please continue.",
                                        tool_name, tool_output
                                    ));
                                } else {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                        error_message
                                    ));
                                }
                            }
                            Err(e) => {
                                let error_msg =
                                    format!("Error executing tool '{}': {}", tool_name, e);
                                warn!("{}", error_msg);
                                prompt.push_str(&format!(
                                    "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                    error_msg
                                ));
                                tool_invocations.push(ToolInvocation {
                                    name: tool_name.clone(),
                                    arguments: tool_args.clone(),
                                    success: false,
                                    output: None,
                                    error: Some(error_msg),
                                });
                            }
                        }
                    }

                    continue;
                }

                if let Some(goal) = goal_context.as_ref() {
                    if goal.requires_tool && !goal.satisfied {
                        prompt.push_str(
                            "\n\nGOAL_STATUS: pending. The user request requires executing an allowed tool. Please call an appropriate tool.",
                        );
                        continue;
                    }
                }

                break;
            }
        }

        let store_assistant_timer = Instant::now();
        let response_message_id = self
            .store_message_with_reasoning(
                MessageRole::Assistant,
                &final_response,
                reasoning.as_deref(),
            )
            .await?;
        self.log_timing("run_step.store_assistant_message", store_assistant_timer);

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if goal.satisfied {
                    if let Err(err) =
                        self.update_goal_status(goal, "completed", true, Some(response_message_id))
                    {
                        warn!("Failed to finalize goal status: {}", err);
                    }
                } else if let Err(err) =
                    self.update_goal_status(goal, "blocked", false, Some(response_message_id))
                {
                    warn!("Failed to record blocked goal status: {}", err);
                }
            } else if let Err(err) =
                self.update_goal_status(goal, "completed", true, Some(response_message_id))
            {
                warn!("Failed to finalize goal status: {}", err);
            }
        }

        self.conversation_history.push(Message {
            id: user_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::User,
            content: input.to_string(),
            created_at: Utc::now(),
        });

        self.conversation_history.push(Message {
            id: response_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::Assistant,
            content: final_response.clone(),
            created_at: Utc::now(),
        });

        let next_action_recommendation =
            if self.profile.enable_graph && self.conversation_history.len() >= 3 {
                let graph_timer = Instant::now();
                let recommendation = self.evaluate_graph_for_next_action(
                    user_message_id,
                    response_message_id,
                    &final_response,
                    &tool_invocations,
                )?;
                self.log_timing("run_step.evaluate_graph_for_next_action", graph_timer);
                recommendation
            } else {
                None
            };

        if let Some(ref recommendation) = next_action_recommendation {
            debug!("Knowledge graph recommends next action: {}", recommendation);
            let system_content = format!("Graph recommendation: {}", recommendation);
            let system_store_timer = Instant::now();
            let system_message_id = self
                .store_message(MessageRole::System, &system_content)
                .await?;
            self.log_timing("run_step.store_system_message", system_store_timer);

            self.conversation_history.push(Message {
                id: system_message_id,
                session_id: self.session_id.clone(),
                role: MessageRole::System,
                content: system_content,
                created_at: Utc::now(),
            });
        }

        let graph_debug = match self.snapshot_graph_debug_info() {
            Ok(info) => Some(info),
            Err(err) => {
                warn!("Failed to capture graph debug info: {}", err);
                None
            }
        };

        self.log_timing("run_step.total", total_timer);

        Ok(AgentOutput {
            response: final_response,
            response_message_id: Some(response_message_id),
            token_usage,
            tool_invocations,
            finish_reason,
            recall_stats,
            run_id,
            next_action: next_action_recommendation,
            reasoning,
            reasoning_summary,
            graph_debug,
        })
    }

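    /// Renders a structured spec to a prompt and runs it as a normal step.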
    pub async fn run_spec(&mut self, spec: &AgentSpec) -> Result<AgentOutput> {
        debug!(
            "Executing structured spec '{}' (source: {:?})",
            spec.display_name(),
            spec.source_path()
        );
        let prompt = spec.to_prompt();
        self.run_step(&prompt).await
    }

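    /// Builds the main-model generation config, clamping profile temperature and
    /// top_p into valid ranges and falling back to the defaults above.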
    fn build_generation_config(&self) -> GenerationConfig {
        let temperature = match self.profile.temperature {
            Some(temp) if temp.is_finite() => Some(temp.clamp(0.0, 2.0)),
            Some(_) => {
                warn!(
                    "Ignoring invalid temperature for agent {:?}, falling back to {}",
                    self.agent_name, DEFAULT_MAIN_TEMPERATURE
                );
                Some(DEFAULT_MAIN_TEMPERATURE)
            }
            None => None,
        };

        let top_p = if self.profile.top_p.is_finite() {
            Some(self.profile.top_p.clamp(0.0, 1.0))
        } else {
            warn!(
                "Invalid top_p detected for agent {:?}, falling back to {}",
                self.agent_name, DEFAULT_TOP_P
            );
            Some(DEFAULT_TOP_P)
        };

        GenerationConfig {
            temperature,
            max_tokens: self.profile.max_context_tokens.map(|t| t as u32),
            stop_sequences: None,
            top_p,
            frequency_penalty: None,
            presence_penalty: None,
        }
    }

    fn snapshot_graph_debug_info(&self) -> Result<GraphDebugInfo> {
        let mut info = GraphDebugInfo {
            enabled: self.profile.enable_graph,
            graph_memory_enabled: self.profile.graph_memory,
            auto_graph_enabled: self.profile.auto_graph,
            graph_steering_enabled: self.profile.graph_steering,
            node_count: 0,
            edge_count: 0,
            recent_nodes: Vec::new(),
        };

        if !self.profile.enable_graph {
            return Ok(info);
        }

        info.node_count = self.persistence.count_graph_nodes(&self.session_id)?.max(0) as usize;
        info.edge_count = self.persistence.count_graph_edges(&self.session_id)?.max(0) as usize;

        let recent_nodes = self
            .persistence
            .list_graph_nodes(&self.session_id, None, Some(5))?;
        info.recent_nodes = recent_nodes
            .into_iter()
            .map(|node| GraphDebugNode {
                id: node.id,
                node_type: node.node_type.as_str().to_string(),
                label: node.label,
            })
            .collect();

        Ok(info)
    }

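    /// Condenses model reasoning into one or two sentences via the fast provider;
    /// short reasoning is returned verbatim, and `None` is returned when no fast
    /// provider is configured or summarization fails.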
    async fn summarize_reasoning(&self, reasoning: &str) -> Option<String> {
        let fast_provider = self.fast_provider.as_ref()?;

        if reasoning.len() < 50 {
            return Some(reasoning.to_string());
        }

        let summary_prompt = format!(
            "Summarize the following reasoning in 1-2 concise sentences that explain the thought process:\n\n{}\n\nSummary:",
            reasoning
        );

        let config = GenerationConfig {
            temperature: Some(0.3),
            max_tokens: Some(100),
            stop_sequences: None,
            top_p: Some(0.9),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let timer = Instant::now();
        let response = fast_provider.generate(&summary_prompt, &config).await;
        self.log_timing("summarize_reasoning.generate", timer);
        match response {
            Ok(response) => {
                let summary = response.content.trim().to_string();
                if !summary.is_empty() {
                    debug!("Generated reasoning summary: {}", summary);
                    Some(summary)
                } else {
                    None
                }
            }
            Err(e) => {
                warn!("Failed to summarize reasoning: {}", e);
                None
            }
        }
    }

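    /// Gathers prompt context: recent messages, graph-linked neighbors (when
    /// graph memory is on), and semantic matches via embeddings, blended
    /// according to `memory_k` and `graph_weight`.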
    async fn recall_memories(&self, query: &str) -> Result<RecallResult> {
        const RECENT_CONTEXT: i64 = 2;
        let mut context = Vec::new();
        let mut seen_ids = HashSet::new();

        let recent_messages = self
            .persistence
            .list_messages(&self.session_id, RECENT_CONTEXT)?;

        if self.conversation_history.is_empty() && recent_messages.is_empty() {
            return Ok(RecallResult {
                messages: Vec::new(),
                stats: Some(MemoryRecallStats {
                    strategy: MemoryRecallStrategy::RecentContext {
                        limit: RECENT_CONTEXT as usize,
                    },
                    matches: Vec::new(),
                }),
            });
        }

        for message in recent_messages {
            seen_ids.insert(message.id);
            context.push(message);
        }

        if self.profile.enable_graph && self.profile.graph_memory {
            let mut graph_messages = Vec::new();

            for msg in &context {
                let nodes = self.persistence.list_graph_nodes(
                    &self.session_id,
                    Some(NodeType::Message),
                    Some(10),
                )?;

                for node in nodes {
                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                        if msg_id == msg.id {
                            let neighbors = self.persistence.traverse_neighbors(
                                &self.session_id,
                                node.id,
                                TraversalDirection::Both,
                                self.profile.graph_depth,
                            )?;

                            for neighbor in neighbors {
                                if neighbor.node_type == NodeType::Message {
                                    if let Some(related_msg_id) =
                                        neighbor.properties["message_id"].as_i64()
                                    {
                                        if !seen_ids.contains(&related_msg_id) {
                                            if let Some(related_msg) =
                                                self.persistence.get_message(related_msg_id)?
                                            {
                                                seen_ids.insert(related_msg.id);
                                                graph_messages.push(related_msg);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }

            context.extend(graph_messages);
        }

        if let Some(client) = &self.embeddings_client {
            if self.profile.memory_k == 0 || query.trim().is_empty() {
                return Ok(RecallResult {
                    messages: context,
                    stats: None,
                });
            }

            let embed_timer = Instant::now();
            let embed_result = client.embed_batch(&[query]).await;
            self.log_timing("recall_memories.embed_batch", embed_timer);
            match embed_result {
                Ok(mut embeddings) => match embeddings.pop() {
                    Some(query_embedding) if !query_embedding.is_empty() => {
                        let recalled = self.persistence.recall_top_k(
                            &self.session_id,
                            &query_embedding,
                            self.profile.memory_k,
                        )?;

                        let mut matches = Vec::new();
                        let mut semantic_context = Vec::new();

                        for (memory, score) in recalled {
                            if let Some(message_id) = memory.message_id {
                                if seen_ids.contains(&message_id) {
                                    continue;
                                }

                                if let Some(message) = self.persistence.get_message(message_id)? {
                                    seen_ids.insert(message.id);
                                    matches.push(MemoryRecallMatch {
                                        message_id: Some(message.id),
                                        score,
                                        role: message.role.clone(),
                                        preview: preview_text(&message.content),
                                    });
                                    semantic_context.push(message);
                                }
                            } else if let Some(transcription_text) =
                                self.persistence.get_transcription_by_embedding(memory.id)?
                            {
                                let transcription_message = Message {
                                    id: memory.id,
                                    session_id: memory.session_id.clone(),
                                    role: MessageRole::User,
                                    content: format!("[Transcription] {}", transcription_text),
                                    created_at: memory.created_at,
                                };

                                matches.push(MemoryRecallMatch {
                                    message_id: None,
                                    score,
                                    role: MessageRole::User,
                                    preview: preview_text(&transcription_text),
                                });
                                semantic_context.push(transcription_message);
                            }
                        }

                        if self.profile.enable_graph && self.profile.graph_memory {
                            let mut graph_expanded = Vec::new();

                            for msg in &semantic_context {
                                let nodes = self.persistence.list_graph_nodes(
                                    &self.session_id,
                                    Some(NodeType::Message),
                                    Some(100),
                                )?;

                                for node in nodes {
                                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                                        if msg_id == msg.id {
                                            let neighbors = self.persistence.traverse_neighbors(
                                                &self.session_id,
                                                node.id,
                                                TraversalDirection::Both,
                                                self.profile.graph_depth,
                                            )?;

                                            for neighbor in neighbors {
                                                if matches!(
                                                    neighbor.node_type,
                                                    NodeType::Fact
                                                        | NodeType::Concept
                                                        | NodeType::Entity
                                                ) {
                                                    let graph_content = format!(
                                                        "[Graph Context - {} {}]: {}",
                                                        neighbor.node_type.as_str(),
                                                        neighbor.label,
                                                        neighbor.properties
                                                    );

                                                    let graph_msg = Message {
                                                        id: -1,
                                                        session_id: self.session_id.clone(),
                                                        role: MessageRole::System,
                                                        content: graph_content,
                                                        created_at: Utc::now(),
                                                    };

                                                    graph_expanded.push(graph_msg);
                                                }
                                            }
                                        }
                                    }
                                }
                            }

                            let total_slots = self.profile.memory_k.max(1);
                            let mut graph_limit =
                                ((total_slots as f32) * self.profile.graph_weight).round() as usize;
                            graph_limit = graph_limit.min(total_slots);
                            if graph_limit == 0 && !graph_expanded.is_empty() {
                                graph_limit = 1;
                            }

                            let mut semantic_limit = total_slots.saturating_sub(graph_limit);
                            if semantic_limit == 0 && !semantic_context.is_empty() {
                                semantic_limit = 1;
                                graph_limit = graph_limit.saturating_sub(1);
                            }

                            let mut limited_semantic = semantic_context;
                            if limited_semantic.len() > semantic_limit && semantic_limit > 0 {
                                limited_semantic.truncate(semantic_limit);
                            }

                            let mut limited_graph = graph_expanded;
                            if limited_graph.len() > graph_limit && graph_limit > 0 {
                                limited_graph.truncate(graph_limit);
                            }

                            context.extend(limited_semantic);
                            context.extend(limited_graph);
                        } else {
                            context.extend(semantic_context);
                        }

                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: matches.len(),
                                },
                                matches,
                            }),
                        });
                    }
                    _ => {
                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: 0,
                                },
                                matches: Vec::new(),
                            }),
                        });
                    }
                },
                Err(err) => {
                    warn!("Failed to embed recall query: {}", err);
                    return Ok(RecallResult {
                        messages: context,
                        stats: None,
                    });
                }
            }
        }

        let limit = self.profile.memory_k as i64;
        let messages = self.persistence.list_messages(&self.session_id, limit)?;

        let stats = if self.profile.memory_k > 0 {
            Some(MemoryRecallStats {
                strategy: MemoryRecallStrategy::RecentContext {
                    limit: self.profile.memory_k,
                },
                matches: Vec::new(),
            })
        } else {
            None
        };

        Ok(RecallResult { messages, stats })
    }

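    /// Assembles the flat text prompt: system prompt, speech-mode instruction,
    /// allowed tools, prior conversation, and the new user turn.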
    async fn build_prompt(&self, input: &str, context_messages: &[Message]) -> Result<String> {
        let mut prompt = String::new();

        if let Some(system_prompt) = &self.profile.prompt {
            prompt.push_str("System: ");
            prompt.push_str(system_prompt);
            prompt.push_str("\n\n");
        }

        if self.speak_responses {
            prompt.push_str("System: Speech mode is enabled; respond with concise, natural sentences suitable for text-to-speech. Avoid markdown/code fences and keep the reply brief.\n\n");
        }

        let available_tools = self.tool_registry.list();
        debug!("Tool registry has {} tools", available_tools.len());
        if !available_tools.is_empty() {
            prompt.push_str("Available tools:\n");
            for tool_name in &available_tools {
                info!(
                    "Checking tool: {} - allowed: {}",
                    tool_name,
                    self.is_tool_allowed(tool_name).await
                );
                if self.is_tool_allowed(tool_name).await {
                    if let Some(tool) = self.tool_registry.get(tool_name) {
                        prompt.push_str(&format!("- {}: {}\n", tool_name, tool.description()));
                    }
                }
            }
            prompt.push('\n');
        }

        if !context_messages.is_empty() {
            prompt.push_str("Previous conversation:\n");
            for msg in context_messages {
                prompt.push_str(&format!("{}: {}\n", msg.role.as_str(), msg.content));
            }
            prompt.push('\n');
        }

        prompt.push_str(&format!("user: {}\n", input));

        prompt.push_str("assistant:");

        Ok(prompt)
    }

    async fn store_message(&self, role: MessageRole, content: &str) -> Result<i64> {
        self.store_message_with_reasoning(role, content, None).await
    }

    async fn store_message_with_reasoning(
        &self,
        role: MessageRole,
        content: &str,
        reasoning: Option<&str>,
    ) -> Result<i64> {
        let message_id = self
            .persistence
            .insert_message(&self.session_id, role.clone(), content)
            .context("Failed to store message")?;

        let mut embedding_id = None;

        if let Some(client) = &self.embeddings_client {
            if !content.trim().is_empty() {
                let embed_timer = Instant::now();
                let embed_result = client.embed_batch(&[content]).await;
                self.log_timing("embeddings.message_content", embed_timer);
                match embed_result {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    Some(message_id),
                                    &embedding,
                                ) {
                                    Ok(emb_id) => {
                                        embedding_id = Some(emb_id);
                                    }
                                    Err(err) => {
                                        warn!(
                                            "Failed to persist embedding for message {}: {}",
                                            message_id, err
                                        );
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!(
                            "Failed to create embedding for message {}: {}",
                            message_id, err
                        );
                    }
                }
            }
        }

        if self.profile.enable_graph && self.profile.auto_graph {
            self.build_graph_for_message(message_id, role, content, embedding_id, reasoning)?;
        }

        Ok(message_id)
    }

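    /// Inserts a Message node for the stored message, links extracted entities
    /// and concepts to it, and chains it to the previous message.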
    fn build_graph_for_message(
        &self,
        message_id: i64,
        role: MessageRole,
        content: &str,
        embedding_id: Option<i64>,
        reasoning: Option<&str>,
    ) -> Result<()> {
        let mut message_props = json!({
            "message_id": message_id,
            "role": role.as_str(),
            "content_preview": preview_text(content),
            "timestamp": Utc::now().to_rfc3339(),
        });

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                message_props["has_reasoning"] = json!(true);
                message_props["reasoning_preview"] = json!(preview_text(reasoning_text));
            }
        }

        let message_node_id = self.persistence.insert_graph_node(
            &self.session_id,
            NodeType::Message,
            &format!("{:?}Message", role),
            &message_props,
            embedding_id,
        )?;

        let mut entities = self.extract_entities_from_text(content);
        let mut concepts = self.extract_concepts_from_text(content);

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                debug!(
                    "Extracting entities/concepts from reasoning for message {}",
                    message_id
                );
                let reasoning_entities = self.extract_entities_from_text(reasoning_text);
                let reasoning_concepts = self.extract_concepts_from_text(reasoning_text);

                for mut reasoning_entity in reasoning_entities {
                    if let Some(existing) = entities.iter_mut().find(|e| {
                        e.name.to_lowercase() == reasoning_entity.name.to_lowercase()
                            && e.entity_type == reasoning_entity.entity_type
                    }) {
                        existing.confidence =
                            (existing.confidence + reasoning_entity.confidence * 0.5).min(1.0);
                    } else {
                        reasoning_entity.confidence *= 0.8;
                        entities.push(reasoning_entity);
                    }
                }

                for mut reasoning_concept in reasoning_concepts {
                    if let Some(existing) = concepts
                        .iter_mut()
                        .find(|c| c.name.to_lowercase() == reasoning_concept.name.to_lowercase())
                    {
                        existing.relevance =
                            (existing.relevance + reasoning_concept.relevance * 0.5).min(1.0);
                    } else {
                        reasoning_concept.relevance *= 0.8;
                        concepts.push(reasoning_concept);
                    }
                }
            }
        }

        for entity in entities {
            let entity_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Entity,
                &entity.entity_type,
                &json!({
                    "name": entity.name,
                    "type": entity.entity_type,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                entity_node_id,
                EdgeType::Mentions,
                Some("mentions"),
                Some(&json!({"confidence": entity.confidence})),
                entity.confidence,
            )?;
        }

        for concept in concepts {
            let concept_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Concept,
                "Concept",
                &json!({
                    "name": concept.name,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                concept_node_id,
                EdgeType::RelatesTo,
                Some("discusses"),
                Some(&json!({"relevance": concept.relevance})),
                concept.relevance,
            )?;
        }

        let recent_messages = self.persistence.list_messages(&self.session_id, 2)?;
        if recent_messages.len() > 1 {
            let nodes = self.persistence.list_graph_nodes(
                &self.session_id,
                Some(NodeType::Message),
                Some(10),
            )?;

            for node in nodes {
                if let Some(prev_msg_id) = node.properties["message_id"].as_i64() {
                    if prev_msg_id != message_id && prev_msg_id == recent_messages[0].id {
                        self.persistence.insert_graph_edge(
                            &self.session_id,
                            node.id,
                            message_node_id,
                            EdgeType::FollowsFrom,
                            Some("conversation_flow"),
                            None,
                            1.0,
                        )?;
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    fn create_goal_context(
        &self,
        message_id: i64,
        input: &str,
        persist: bool,
    ) -> Result<GoalContext> {
        let requires_tool = Self::goal_requires_tool(input);
        let node_id = if self.profile.enable_graph {
            if persist {
                let properties = json!({
                    "message_id": message_id,
                    "goal_text": input,
                    "status": "pending",
                    "requires_tool": requires_tool,
                    "satisfied": false,
                    "created_at": Utc::now().to_rfc3339(),
                });
                Some(self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Goal,
                    "Goal",
                    &properties,
                    None,
                )?)
            } else {
                None
            }
        } else {
            None
        };

        Ok(GoalContext::new(message_id, input, requires_tool, node_id))
    }

    fn update_goal_status(
        &self,
        goal: &mut GoalContext,
        status: &str,
        satisfied: bool,
        response_message_id: Option<i64>,
    ) -> Result<()> {
        goal.satisfied = satisfied;
        if let Some(node_id) = goal.node_id {
            let properties = json!({
                "message_id": goal.message_id,
                "goal_text": goal.text,
                "status": status,
                "requires_tool": goal.requires_tool,
                "satisfied": satisfied,
                "response_message_id": response_message_id,
                "updated_at": Utc::now().to_rfc3339(),
            });
            self.persistence.update_graph_node(node_id, &properties)?;
        }
        Ok(())
    }

    fn record_goal_tool_result(
        &self,
        goal: &GoalContext,
        tool_name: &str,
        args: &Value,
        result: &ToolResult,
    ) -> Result<()> {
        if let Some(goal_node_id) = goal.node_id {
            let timestamp = Utc::now().to_rfc3339();
            let mut properties = json!({
                "tool": tool_name,
                "arguments": args,
                "success": result.success,
                "output_preview": preview_text(&result.output),
                "error": result.error,
                "timestamp": timestamp,
            });

            let mut prompt_payload: Option<Value> = None;
            if tool_name == "prompt_user" && result.success {
                match serde_json::from_str::<Value>(&result.output) {
                    Ok(payload) => {
                        if let Some(props) = properties.as_object_mut() {
                            props.insert("prompt_user_payload".to_string(), payload.clone());
                            if let Some(response) = payload.get("response") {
                                props.insert(
                                    "response_preview".to_string(),
                                    Value::String(preview_json_value(response)),
                                );
                            }
                        }
                        prompt_payload = Some(payload);
                    }
                    Err(err) => {
                        warn!("Failed to parse prompt_user payload for graph: {}", err);
                        if let Some(props) = properties.as_object_mut() {
                            props.insert(
                                "prompt_user_parse_error".to_string(),
                                Value::String(err.to_string()),
                            );
                        }
                    }
                }
            }

            let tool_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::ToolResult,
                tool_name,
                &properties,
                None,
            )?;
            self.persistence.insert_graph_edge(
                &self.session_id,
                tool_node_id,
                goal_node_id,
                EdgeType::Produces,
                Some("satisfies"),
                None,
                if result.success { 1.0 } else { 0.1 },
            )?;

            if let Some(payload) = prompt_payload {
                let response_preview = payload
                    .get("response")
                    .map(preview_json_value)
                    .unwrap_or_default();

                let response_properties = json!({
                    "prompt": payload_field(&payload, "prompt"),
                    "input_type": payload_field(&payload, "input_type"),
                    "response": payload_field(&payload, "response"),
                    "display_value": payload_field(&payload, "display_value"),
                    "selections": payload_field(&payload, "selections"),
                    "metadata": payload_field(&payload, "metadata"),
                    "used_default": payload_field(&payload, "used_default"),
                    "used_prefill": payload_field(&payload, "used_prefill"),
                    "response_preview": response_preview,
                    "timestamp": timestamp,
                });

                let response_node_id = self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Event,
                    "UserInput",
                    &response_properties,
                    None,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    tool_node_id,
                    response_node_id,
                    EdgeType::Produces,
                    Some("captures_input"),
                    None,
                    1.0,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    response_node_id,
                    goal_node_id,
                    EdgeType::RelatesTo,
                    Some("addresses_goal"),
                    None,
                    0.9,
                )?;
            }
        }
        Ok(())
    }

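    /// Heuristic: does this request need a tool call? Triggers on common action
    /// verbs and on questions about the local project or directory.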
    fn goal_requires_tool(input: &str) -> bool {
        let normalized = input.to_lowercase();
        const ACTION_VERBS: [&str; 18] = [
            "list", "show", "read", "write", "create", "update", "delete", "run", "execute",
            "open", "search", "fetch", "download", "scan", "compile", "test", "build", "inspect",
        ];

        if ACTION_VERBS
            .iter()
            .any(|verb| normalized.contains(verb) && normalized.contains(' '))
        {
            return true;
        }

        let mentions_local_context = normalized.contains("this directory")
            || normalized.contains("current directory")
            || normalized.contains("this folder")
            || normalized.contains("here");

        let mentions_project = normalized.contains("this project")
            || normalized.contains("this repo")
            || normalized.contains("this repository")
            || normalized.contains("this codebase")
            || ((normalized.contains("project")
                || normalized.contains("repo")
                || normalized.contains("repository")
                || normalized.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = normalized.contains("what can")
            || normalized.contains("what is")
            || normalized.contains("what does")
            || normalized.contains("tell me")
            || normalized.contains("describe")
            || normalized.contains("about");

        mentions_project && asks_about_project
    }

    fn escalation_threshold(&self) -> f32 {
        if self.profile.escalation_threshold.is_finite() {
            self.profile.escalation_threshold.clamp(0.0, 1.0)
        } else {
            warn!(
                "Invalid escalation_threshold for agent {:?}, defaulting to {}",
                self.agent_name, DEFAULT_ESCALATION_THRESHOLD
            );
            DEFAULT_ESCALATION_THRESHOLD
        }
    }

    fn detect_task_type(&self, input: &str) -> Option<String> {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return None;
        }

        let text = input.to_lowercase();

        let candidates: [(&str, &[&str]); 6] = [
            ("entity_extraction", &["entity", "extract", "named"]),
            ("decision_routing", &["classify", "categorize", "route"]),
            (
                "tool_selection",
                &["which tool", "use which tool", "tool should"],
            ),
            ("confidence_scoring", &["confidence", "certainty"]),
            ("summarization", &["summarize", "summary"]),
            ("graph_analysis", &["graph", "connection", "relationships"]),
        ];

        for (task, keywords) in candidates {
            if keywords.iter().any(|kw| text.contains(kw))
                && self.profile.fast_model_tasks.iter().any(|t| t == task)
            {
                return Some(task.to_string());
            }
        }

        None
    }

    fn estimate_task_complexity(&self, input: &str) -> f32 {
        let words = input.split_whitespace().count() as f32;
        let clauses =
            input.matches(" and ").count() as f32 + input.matches(" then ").count() as f32;
        let newlines = input.matches('\n').count() as f32;

        let length_factor = (words / 120.0).min(1.0);
        let clause_factor = (clauses / 4.0).min(1.0);
        let structure_factor = (newlines / 5.0).min(1.0);

        (0.6 * length_factor + 0.3 * clause_factor + 0.1 * structure_factor).clamp(0.0, 1.0)
    }

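    /// Maps common goal phrasings to a concrete tool call: read a README,
    /// search for a project manifest, or list the current directory.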
    fn infer_goal_tool_action(goal_text: &str) -> Option<(String, Value)> {
        let text = goal_text.to_lowercase();

        let mentions_local_context = text.contains("this directory")
            || text.contains("current directory")
            || text.contains("this folder")
            || text.contains("here");

        let mentions_project = text.contains("this project")
            || text.contains("this repo")
            || text.contains("this repository")
            || text.contains("this codebase")
            || ((text.contains("project")
                || text.contains("repo")
                || text.contains("repository")
                || text.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = text.contains("what can")
            || text.contains("what is")
            || text.contains("what does")
            || text.contains("tell me")
            || text.contains("describe")
            || text.contains("about");

        if mentions_project && asks_about_project {
            for candidate in &["README.md", "Readme.md", "readme.md"] {
                if Path::new(candidate).exists() {
                    return Some((
                        "file_read".to_string(),
                        json!({
                            "path": candidate,
                            "max_bytes": 65536
                        }),
                    ));
                }
            }

            return Some((
                "search".to_string(),
                json!({
                    "query": "Cargo.toml|package.json|pyproject.toml|setup.py",
                    "regex": true,
                    "case_sensitive": false,
                    "max_results": 20
                }),
            ));
        }

        if text.contains("list")
            && (text.contains("directory") || text.contains("files") || text.contains("folder"))
        {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        if text.contains("show") && text.contains("current directory") {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        None
    }

    fn parse_confidence(text: &str) -> Option<f32> {
        for line in text.lines() {
            let lower = line.to_lowercase();
            if lower.contains("confidence") {
                // Take the first numeric token on the line; keep scanning later
                // lines if this one does not yield a value in [0, 1].
                if let Some(token) = lower
                    .split(|c: char| !(c.is_ascii_digit() || c == '.'))
                    .find(|chunk| !chunk.is_empty())
                {
                    if let Ok(value) = token.parse::<f32>() {
                        if (0.0..=1.0).contains(&value) {
                            return Some(value);
                        }
                    }
                }
            }
        }
        None
    }

    fn strip_fast_answer(text: &str) -> String {
        let mut answer = String::new();
        for line in text.lines() {
            if line.to_lowercase().starts_with("answer:") {
                answer.push_str(line.split_once(':').map(|x| x.1).unwrap_or("").trim());
                break;
            }
        }
        if answer.is_empty() {
            text.trim().to_string()
        } else {
            answer
        }
    }

    fn format_auto_tool_response(tool_name: &str, output: Option<&str>) -> String {
        let sanitized = output.unwrap_or("").trim();
        if sanitized.is_empty() {
            return format!("Executed `{}` successfully.", tool_name);
        }

        if tool_name == "file_read" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let path = value.get("path").and_then(|v| v.as_str()).unwrap_or("file");
                let content = value.get("content").and_then(|v| v.as_str()).unwrap_or("");

                let max_chars = 4000;
                let display_content = if content.len() > max_chars {
                    // Back up to a char boundary so slicing cannot panic on
                    // multi-byte UTF-8 content.
                    let mut end = max_chars;
                    while !content.is_char_boundary(end) {
                        end -= 1;
                    }
                    let mut snippet = content[..end].to_string();
                    snippet.push_str("\n...\n[truncated]");
                    snippet
                } else {
                    content.to_string()
                };

                return format!("Contents of {}:\n{}", path, display_content);
            }
        }

        if tool_name == "search" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let query = value.get("query").and_then(|v| v.as_str()).unwrap_or("");

                if let Some(results) = value.get("results").and_then(|v| v.as_array()) {
                    if results.is_empty() {
                        return if query.is_empty() {
                            "Search returned no results.".to_string()
                        } else {
                            format!("Search for {:?} returned no results.", query)
                        };
                    }

                    let mut lines = Vec::new();
                    if query.is_empty() {
                        lines.push("Search results:".to_string());
                    } else {
                        lines.push(format!("Search results for {:?}:", query));
                    }

                    for entry in results.iter().take(5) {
                        let path = entry
                            .get("path")
                            .and_then(|v| v.as_str())
                            .unwrap_or("<unknown>");
                        let line = entry.get("line").and_then(|v| v.as_u64()).unwrap_or(0);
                        let snippet = entry
                            .get("snippet")
                            .and_then(|v| v.as_str())
                            .unwrap_or("")
                            .replace('\n', " ");

                        lines.push(format!("- {}:{} - {}", path, line, snippet));
                    }

                    return lines.join("\n");
                }
            }
        }

        if tool_name == "shell" || tool_name == "bash" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let std_out = value
                    .get("stdout")
                    .and_then(|v| v.as_str())
                    .unwrap_or(sanitized);
                let command = value.get("command").and_then(|v| v.as_str()).unwrap_or("");
                let stderr = value
                    .get("stderr")
                    .and_then(|v| v.as_str())
                    .map(|s| s.trim())
                    .filter(|s| !s.is_empty())
                    .unwrap_or("");
                let mut response = String::new();
                if !command.is_empty() {
                    response.push_str(&format!("Command `{}` output:\n", command));
                }
                response.push_str(std_out.trim_end());
                if !stderr.is_empty() {
                    response.push_str("\n\nstderr:\n");
                    response.push_str(stderr);
                }
                if response.trim().is_empty() {
                    return "Command completed without output.".to_string();
                }
                return response;
            }
        }

        sanitized.to_string()
    }

    fn extract_entities_from_text(&self, text: &str) -> Vec<ExtractedEntity> {
        if self.profile.fast_reasoning
            && self.fast_provider.is_some()
            && self
                .profile
                .fast_model_tasks
                .contains(&"entity_extraction".to_string())
        {
            debug!("Using fast model for entity extraction");
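            // The fast model is only advertised here; extraction currently falls
            // through to the regex heuristics below.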
        }

        let mut entities = Vec::new();

        let url_regex = regex::Regex::new(r"https?://[^\s]+").unwrap();
        for mat in url_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "URL".to_string(),
                confidence: 0.9,
            });
        }

        let email_regex =
            regex::Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").unwrap();
        for mat in email_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "Email".to_string(),
                confidence: 0.9,
            });
        }

        let quote_regex = regex::Regex::new(r#""([^"]+)""#).unwrap();
        for cap in quote_regex.captures_iter(text) {
            if let Some(quoted) = cap.get(1) {
                entities.push(ExtractedEntity {
                    name: quoted.as_str().to_string(),
                    entity_type: "Quote".to_string(),
                    confidence: 0.7,
                });
            }
        }

        entities
    }

    async fn fast_reasoning(&self, task: &str, input: &str) -> Result<(String, f32)> {
        let total_timer = Instant::now();
        let result = if let Some(ref fast_provider) = self.fast_provider {
            let prompt = format!(
                "You are a fast specialist model that assists a more capable agent.\nTask: {}\nInput: {}\n\nRespond with two lines:\nAnswer: <concise result>\nConfidence: <0-1 decimal>",
                task, input
            );

            let fast_temperature = if self.profile.fast_model_temperature.is_finite() {
                self.profile.fast_model_temperature.clamp(0.0, 2.0)
            } else {
                warn!(
                    "Invalid fast_model_temperature for agent {:?}, using {}",
                    self.agent_name, DEFAULT_FAST_TEMPERATURE
                );
                DEFAULT_FAST_TEMPERATURE
            };

            let config = GenerationConfig {
                temperature: Some(fast_temperature),
                max_tokens: Some(256),
                stop_sequences: None,
                top_p: Some(DEFAULT_TOP_P),
                frequency_penalty: None,
                presence_penalty: None,
            };

            let call_timer = Instant::now();
            let response_result = fast_provider.generate(&prompt, &config).await;
            self.log_timing("fast_reasoning.generate", call_timer);
            let response = response_result?;

            let confidence = Self::parse_confidence(&response.content).unwrap_or(0.7);
            let cleaned = Self::strip_fast_answer(&response.content);

            Ok((cleaned, confidence))
        } else {
            Ok((String::new(), 0.0))
        };

        self.log_timing("fast_reasoning.total", total_timer);
        result
    }

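    /// Routes to the fast model only when fast reasoning is enabled, the task
    /// type is whitelisted, and the estimated complexity is within the
    /// escalation threshold.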
    fn should_use_fast_model(&self, task_type: &str, complexity_score: f32) -> bool {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return false;
        }

        if !self
            .profile
            .fast_model_tasks
            .contains(&task_type.to_string())
        {
            return false;
        }

        let threshold = self.escalation_threshold();
        if complexity_score > threshold {
            info!(
                "Task complexity {} exceeds threshold {}, using main model",
                complexity_score, threshold
            );
            return false;
        }

        true
    }

    fn extract_concepts_from_text(&self, text: &str) -> Vec<ExtractedConcept> {
        let mut concepts = Vec::new();

        let concept_keywords = vec![
            ("graph", "Knowledge Graph"),
            ("memory", "Memory System"),
            ("embedding", "Embeddings"),
            ("tool", "Tool Usage"),
            ("agent", "Agent System"),
            ("database", "Database"),
            ("query", "Query Processing"),
            ("node", "Graph Node"),
            ("edge", "Graph Edge"),
        ];

        let text_lower = text.to_lowercase();
        for (keyword, concept_name) in concept_keywords {
            if text_lower.contains(keyword) {
                concepts.push(ExtractedConcept {
                    name: concept_name.to_string(),
                    relevance: 0.6,
                });
            }
        }

        concepts
    }

    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn agent_name(&self) -> Option<&str> {
        self.agent_name.as_deref()
    }

    pub fn conversation_history(&self) -> &[Message] {
        &self.conversation_history
    }

    pub fn load_history(&mut self, limit: i64) -> Result<()> {
        self.conversation_history = self.persistence.list_messages(&self.session_id, limit)?;
        Ok(())
    }

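    /// Checks the per-session permission cache, then the profile allow/deny
    /// lists, then the policy engine; the final decision is cached per tool.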
    async fn is_tool_allowed(&self, tool_name: &str) -> bool {
        {
            let cache = self.tool_permission_cache.read().await;
            if let Some(&allowed) = cache.get(tool_name) {
                return allowed;
            }
        }

        let profile_allowed = self.profile.is_tool_allowed(tool_name);
        debug!(
            "Profile check for tool '{}': allowed={}, allowed_tools={:?}, denied_tools={:?}",
            tool_name, profile_allowed, self.profile.allowed_tools, self.profile.denied_tools
        );
        if !profile_allowed {
            self.tool_permission_cache
                .write()
                .await
                .insert(tool_name.to_string(), false);
            return false;
        }

        let agent_name = self.agent_name.as_deref().unwrap_or("agent");
        let decision = self.policy_engine.check(agent_name, "tool_call", tool_name);
        debug!(
            "Policy check for tool '{}': decision={:?}",
            tool_name, decision
        );

        let allowed = matches!(decision, PolicyDecision::Allow);
        self.tool_permission_cache
            .write()
            .await
            .insert(tool_name.to_string(), allowed);
        allowed
    }

    async fn prompt_for_tool_permission(&mut self, tool_name: &str) -> Result<bool> {
        info!("Requesting user permission for tool: {}", tool_name);

        let tool_description = self
            .tool_registry
            .get(tool_name)
            .map(|t| t.description().to_string())
            .unwrap_or_else(|| "No description available".to_string());

        let prompt_args = json!({
            "prompt": format!(
                "The agent wants to use the '{}' tool.\n\nDescription: {}\n\nDo you want to allow this?",
                tool_name,
                tool_description
            ),
            "input_type": "boolean",
            "required": true,
        });

        match self.tool_registry.execute("prompt_user", prompt_args).await {
            Ok(result) if result.success => {
                info!("prompt_user output: {}", result.output);

                let allowed =
                    if let Ok(response_json) = serde_json::from_str::<Value>(&result.output) {
                        info!("Parsed JSON response: {:?}", response_json);
                        let value = response_json["response"].as_bool();
                        info!("Extracted boolean value: {:?}", value);
                        value.unwrap_or(false)
                    } else {
                        info!("Failed to parse JSON, trying plain text fallback");
                        let response = result.output.trim().to_lowercase();
                        let parsed = response == "yes" || response == "y" || response == "true";
                        info!("Plain text parse result for '{}': {}", response, parsed);
                        parsed
                    };

                info!(
                    "User {} tool '{}'",
                    if allowed { "allowed" } else { "denied" },
                    tool_name
                );

                if allowed {
                    self.add_allowed_tool(tool_name).await;
                } else {
                    self.add_denied_tool(tool_name).await;
                }

                Ok(allowed)
            }
            Ok(result) => {
                warn!("Failed to prompt user: {:?}", result.error);
                Ok(false)
            }
            Err(e) => {
                warn!("Error prompting user for permission: {}", e);
                Ok(false)
            }
        }
    }

    async fn add_allowed_tool(&mut self, tool_name: &str) {
        let tools = self.profile.allowed_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to allowed tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

    async fn add_denied_tool(&mut self, tool_name: &str) {
        let tools = self.profile.denied_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to denied tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

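    /// Runs a tool through the registry, converts transport errors into failed
    /// results, and logs every execution to persistence.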
    async fn execute_tool(
        &self,
        run_id: &str,
        tool_name: &str,
        args: &Value,
    ) -> Result<ToolResult> {
        let exec_result = self.tool_registry.execute(tool_name, args.clone()).await;
        let result = match exec_result {
            Ok(res) => res,
            Err(err) => ToolResult::failure(err.to_string()),
        };

        let result_json = json!({
            "output": result.output,
            "success": result.success,
            "error": result.error,
        });

        let error_str = result.error.as_deref();
        self.persistence
            .log_tool(
                &self.session_id,
                self.agent_name.as_deref().unwrap_or("unknown"),
                run_id,
                tool_name,
                args,
                &result_json,
                result.success,
                error_str,
            )
            .context("Failed to log tool execution")?;

        Ok(result)
    }

    pub fn tool_registry(&self) -> &ToolRegistry {
        &self.tool_registry
    }

    pub fn policy_engine(&self) -> &PolicyEngine {
        &self.policy_engine
    }

    pub fn set_policy_engine(&mut self, policy_engine: Arc<PolicyEngine>) {
        self.policy_engine = policy_engine;
    }

    pub fn set_speak_responses(&mut self, enabled: bool) {
        #[cfg(target_os = "macos")]
        {
            self.speak_responses = enabled;
        }
        #[cfg(not(target_os = "macos"))]
        {
            let _ = enabled;
            self.speak_responses = false;
        }
    }

    pub async fn generate_embedding(&self, text: &str) -> Option<i64> {
        if let Some(client) = &self.embeddings_client {
            if !text.trim().is_empty() {
                match client.embed_batch(&[text]).await {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    None,
                                    &embedding,
                                ) {
                                    Ok(emb_id) => return Some(emb_id),
                                    Err(err) => {
                                        warn!("Failed to persist embedding: {}", err);
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!("Failed to generate embedding: {}", err);
                    }
                }
            }
        }
        None
    }

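    /// Mines the session graph (goal status, tool outcomes, nearby entities and
    /// concepts) to produce and persist a next-action recommendation.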
    /// Inspects the session's knowledge graph around the latest exchange and,
    /// when a useful next step emerges, records it as an Event node linked to
    /// the assistant's reply.
    fn evaluate_graph_for_next_action(
        &self,
        user_message_id: i64,
        assistant_message_id: i64,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Result<Option<String>> {
        // Locate the graph nodes for this exchange.
        let nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::Message),
            Some(50),
        )?;

        let mut assistant_node_id = None;
        let mut _user_node_id = None;

        for node in &nodes {
            if let Some(msg_id) = node.properties["message_id"].as_i64() {
                if msg_id == assistant_message_id {
                    assistant_node_id = Some(node.id);
                } else if msg_id == user_message_id {
                    _user_node_id = Some(node.id);
                }
            }
        }

        if assistant_node_id.is_none() {
            debug!("Assistant message node not found in graph");
            return Ok(None);
        }

        let assistant_node_id = assistant_node_id.unwrap();

        let neighbors = self.persistence.traverse_neighbors(
            &self.session_id,
            assistant_node_id,
            TraversalDirection::Both,
            2, // traversal depth
        )?;

        // Collect goal status from the graph.
        let goal_nodes =
            self.persistence
                .list_graph_nodes(&self.session_id, Some(NodeType::Goal), Some(10))?;

        let mut pending_goals = Vec::new();
        let mut completed_goals = Vec::new();

        for goal in &goal_nodes {
            if let Some(status) = goal.properties["status"].as_str() {
                match status {
                    "pending" | "in_progress" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            pending_goals.push(goal_text.to_string());
                        }
                    }
                    "completed" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            completed_goals.push(goal_text.to_string());
                        }
                    }
                    _ => {}
                }
            }
        }

        // Review recent tool outcomes.
        let tool_nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::ToolResult),
            Some(10),
        )?;

        let mut recent_tool_failures = Vec::new();
        let mut recent_tool_successes = Vec::new();

        for tool_node in &tool_nodes {
            if let Some(success) = tool_node.properties["success"].as_bool() {
                let tool_name = tool_node.properties["tool"].as_str().unwrap_or("unknown");
                if success {
                    recent_tool_successes.push(tool_name.to_string());
                } else {
                    recent_tool_failures.push(tool_name.to_string());
                }
            }
        }

        // Gather entities and concepts near the assistant message.
        let mut key_entities = HashSet::new();
        let mut key_concepts = HashSet::new();

        for neighbor in &neighbors {
            match neighbor.node_type {
                NodeType::Entity => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_entities.insert(name.to_string());
                    }
                }
                NodeType::Concept => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_concepts.insert(name.to_string());
                    }
                }
                _ => {}
            }
        }

        let recommendation = self.generate_action_recommendation(
            &pending_goals,
            &completed_goals,
            &recent_tool_failures,
            &recent_tool_successes,
            &key_entities,
            &key_concepts,
            response_content,
            tool_invocations,
        );

        // Persist the recommendation as an Event node linked to the reply.
        if let Some(ref rec) = recommendation {
            let properties = json!({
                "recommendation": rec,
                "user_message_id": user_message_id,
                "assistant_message_id": assistant_message_id,
                "pending_goals": pending_goals,
                "completed_goals": completed_goals,
                "tool_failures": recent_tool_failures,
                "tool_successes": recent_tool_successes,
                "key_entities": key_entities.into_iter().collect::<Vec<_>>(),
                "key_concepts": key_concepts.into_iter().collect::<Vec<_>>(),
                "timestamp": Utc::now().to_rfc3339(),
            });

            let rec_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Event,
                "NextActionRecommendation",
                &properties,
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                assistant_node_id,
                rec_node_id,
                EdgeType::Produces,
                Some("recommends"),
                None,
                0.8,
            )?;
        }

        Ok(recommendation)
    }

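    /// Worked example of the heuristics below: with `pending_goals = ["ship v2"]`
    /// and no other signals, the first rule fires and only the highest-priority
    /// recommendation is returned, i.e.
    /// `Some("Continue working on pending goals: ship v2")`.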
    fn generate_action_recommendation(
        &self,
        pending_goals: &[String],
        completed_goals: &[String],
        recent_tool_failures: &[String],
        _recent_tool_successes: &[String],
        _key_entities: &HashSet<String>,
        key_concepts: &HashSet<String>,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Option<String> {
        let mut recommendations = Vec::new();

        // Goal-based recommendations.
        if !pending_goals.is_empty() {
            let goals_str = pending_goals.join(", ");
            recommendations.push(format!("Continue working on pending goals: {}", goals_str));
        }

        // Tool-failure recovery.
        if !recent_tool_failures.is_empty() {
            let unique_failures: HashSet<_> = recent_tool_failures.iter().collect();
            for tool in unique_failures {
                recommendations.push(format!(
                    "Consider alternative approach for failed tool: {}",
                    tool
                ));
            }
        }

        // Signals in the response text itself.
        let response_lower = response_content.to_lowercase();
        if response_lower.contains("error") || response_lower.contains("failed") {
            recommendations.push("Investigate and resolve the reported error".to_string());
        }

        if response_lower.contains('?') || response_lower.contains("unclear") {
            recommendations.push("Clarify the uncertain aspects mentioned".to_string());
        }

        // Patterns across the tool sequence.
        if tool_invocations.len() > 1 {
            let tool_sequence: Vec<_> = tool_invocations.iter().map(|t| t.name.as_str()).collect();

            if tool_sequence.contains(&"file_read") && !tool_sequence.contains(&"file_write") {
                recommendations
                    .push("Consider modifying the read files if changes are needed".to_string());
            }

            if tool_sequence.contains(&"search")
                && tool_invocations.last().is_some_and(|t| t.success)
            {
                recommendations
                    .push("Examine the search results for relevant information".to_string());
            }
        }

        // Concept-driven suggestions.
        if key_concepts.contains("Knowledge Graph") || key_concepts.contains("Graph Node") {
            recommendations
                .push("Consider visualizing or querying the graph structure".to_string());
        }

        if key_concepts.contains("Database") || key_concepts.contains("Query Processing") {
            recommendations.push("Verify data integrity and query performance".to_string());
        }

        if !completed_goals.is_empty() && !pending_goals.is_empty() {
            recommendations.push(format!(
                "Build on completed work ({} done) to address remaining goals ({} pending)",
                completed_goals.len(),
                pending_goals.len()
            ));
        }

        if recommendations.is_empty() {
            // Default when nothing specific stands out.
            if completed_goals.len() > pending_goals.len() && recent_tool_failures.is_empty() {
                Some(
                    "Current objectives appear satisfied. Ready for new tasks or refinements."
                        .to_string(),
                )
            } else {
                None
            }
        } else {
            // Only the highest-priority recommendation is surfaced.
            Some(recommendations[0].clone())
        }
    }

    fn log_timing(&self, stage: &str, start: Instant) {
        let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
        let agent_label = self.agent_name.as_deref().unwrap_or("unnamed");
        info!(
            target: "agent_timing",
            "stage={} duration_ms={:.2} agent={} session_id={}",
            stage,
            duration_ms,
            agent_label,
            self.session_id
        );
    }
}

/// Returns the first 80 characters of `content` (trimmed), appending "..." when
/// the text is longer.
fn preview_text(content: &str) -> String {
    const MAX_CHARS: usize = 80;
    let trimmed = content.trim();
    let mut preview = String::new();
    for (idx, ch) in trimmed.chars().enumerate() {
        if idx >= MAX_CHARS {
            preview.push_str("...");
            break;
        }
        preview.push(ch);
    }
    if preview.is_empty() {
        trimmed.to_string()
    } else {
        preview
    }
}

fn preview_json_value(value: &Value) -> String {
    match value {
        Value::String(text) => preview_text(text),
        Value::Null => "null".to_string(),
        _ => {
            let raw = value.to_string();
            if raw.chars().count() > 80 {
                // Truncate on a character boundary; byte slicing (`&raw[..77]`)
                // could panic on multi-byte UTF-8 content.
                let head: String = raw.chars().take(77).collect();
                format!("{}...", head)
            } else {
                raw
            }
        }
    }
}

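/// Sketch of the helper's behavior: `payload_field(&json!({"tool": "echo"}), "tool")`
/// yields `json!("echo")`, while a missing key yields `Value::Null` rather than
/// panicking.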
fn payload_field(payload: &Value, key: &str) -> Value {
    payload.get(key).cloned().unwrap_or(Value::Null)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::AgentProfile;
    use crate::embeddings::{EmbeddingsClient, EmbeddingsService};
    use async_trait::async_trait;
    use tempfile::tempdir;

    fn create_test_agent(session_id: &str) -> (AgentCore, tempfile::TempDir) {
        create_test_agent_with_embeddings(session_id, None)
    }

    fn create_test_agent_with_embeddings(
        session_id: &str,
        embeddings_client: Option<EmbeddingsClient>,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                embeddings_client,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
                false,
            ),
            dir,
        )
    }

    fn create_fast_reasoning_agent(
        session_id: &str,
        fast_output: &str,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("fast.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: true,
            fast_model_provider: Some("mock".to_string()),
            fast_model_name: Some("mock-fast".to_string()),
            fast_model_temperature: 0.3,
            fast_model_tasks: vec!["entity_extraction".to_string()],
            escalation_threshold: 0.5,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        profile.validate().unwrap();

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let fast_provider = Arc::new(MockProvider::new(fast_output.to_string()));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                None,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
                false,
            )
            .with_fast_provider(fast_provider),
            dir,
        )
    }

    #[derive(Clone)]
    struct KeywordEmbeddingsService;

    #[async_trait]
    impl EmbeddingsService for KeywordEmbeddingsService {
        async fn create_embeddings(
            &self,
            _model: &str,
            inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            Ok(inputs
                .into_iter()
                .map(|input| keyword_embedding(&input))
                .collect())
        }
    }

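    /// Maps text to a two-dimensional vector: dimension 0 flags "alpha" and
    /// dimension 1 flags "beta", so for example
    /// `keyword_embedding("Alpha then beta")` returns `vec![1.0, 1.0]`.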
    fn keyword_embedding(input: &str) -> Vec<f32> {
        let lower = input.to_ascii_lowercase();
        let alpha = if lower.contains("alpha") { 1.0 } else { 0.0 };
        let beta = if lower.contains("beta") { 1.0 } else { 0.0 };
        vec![alpha, beta]
    }

    fn test_embeddings_client() -> EmbeddingsClient {
        EmbeddingsClient::with_service(
            "test",
            Arc::new(KeywordEmbeddingsService) as Arc<dyn EmbeddingsService>,
        )
    }

    #[tokio::test]
    async fn test_agent_core_run_step() {
        let (mut agent, _dir) = create_test_agent("test-session-1");

        let output = agent.run_step("Hello, how are you?").await.unwrap();

        assert!(!output.response.is_empty());
        assert!(output.token_usage.is_some());
        assert_eq!(output.tool_invocations.len(), 0);
    }

    #[tokio::test]
    async fn fast_model_short_circuits_when_confident() {
        let (mut agent, _dir) = create_fast_reasoning_agent(
            "fast-confident",
            "Answer: Entities detected.\nConfidence: 0.9",
        );

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert!(output
            .finish_reason
            .unwrap_or_default()
            .contains("fast_model"));
        assert!(output.response.contains("Entities detected"));
    }

    #[tokio::test]
    async fn fast_model_only_hints_when_low_confidence() {
        let (mut agent, _dir) =
            create_fast_reasoning_agent("fast-hint", "Answer: Unsure.\nConfidence: 0.2");

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert_eq!(output.finish_reason.as_deref(), Some("stop"));
        assert_eq!(output.response, "This is a test response.");
    }

    #[tokio::test]
    async fn test_agent_core_conversation_history() {
        let (mut agent, _dir) = create_test_agent("test-session-2");

        agent.run_step("First message").await.unwrap();
        agent.run_step("Second message").await.unwrap();

        let history = agent.conversation_history();
        assert_eq!(history.len(), 4); // two user turns plus two assistant replies
        assert_eq!(history[0].role, MessageRole::User);
        assert_eq!(history[1].role, MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_agent_core_session_switch() {
        let (mut agent, _dir) = create_test_agent("session-1");

        agent.run_step("Message in session 1").await.unwrap();
        assert_eq!(agent.session_id(), "session-1");

        agent = agent.with_session("session-2".to_string());
        assert_eq!(agent.session_id(), "session-2");
        assert_eq!(agent.conversation_history().len(), 0);
    }

    #[tokio::test]
    async fn agent_session_avoids_sync_namespace() {
        let (mut agent, _dir) = create_test_agent(SYNC_GRAPH_NAMESPACE);

        assert_eq!(
            agent.session_id(),
            format!("{}-agent", SYNC_GRAPH_NAMESPACE)
        );

        agent = agent.with_session(SYNC_GRAPH_NAMESPACE.to_string());
        assert_eq!(
            agent.session_id(),
            format!("{}-agent", SYNC_GRAPH_NAMESPACE)
        );
    }

    #[tokio::test]
    async fn test_agent_core_build_prompt() {
        let (agent, _dir) = create_test_agent("test-session-3");

        let context = vec![
            Message {
                id: 1,
                session_id: "test-session-3".to_string(),
                role: MessageRole::User,
                content: "Previous question".to_string(),
                created_at: Utc::now(),
            },
            Message {
                id: 2,
                session_id: "test-session-3".to_string(),
                role: MessageRole::Assistant,
                content: "Previous answer".to_string(),
                created_at: Utc::now(),
            },
        ];

        let prompt = agent
            .build_prompt("Current question", &context)
            .await
            .unwrap();

        assert!(prompt.contains("You are a helpful assistant"));
        assert!(prompt.contains("Previous conversation"));
        assert!(prompt.contains("user: Previous question"));
        assert!(prompt.contains("assistant: Previous answer"));
        assert!(prompt.contains("user: Current question"));
    }

    #[tokio::test]
    async fn test_agent_core_persistence() {
        let (mut agent, _dir) = create_test_agent("persist-test");

        agent.run_step("Test message").await.unwrap();

        let messages = agent
            .persistence
            .list_messages("persist-test", 100)
            .unwrap();

        assert_eq!(messages.len(), 2); // the user message plus the assistant reply
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[0].content, "Test message");
    }

    #[tokio::test]
    async fn store_message_records_embeddings() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("embedding-store", Some(test_embeddings_client()));

        let message_id = agent
            .store_message(MessageRole::User, "Alpha detail")
            .await
            .unwrap();

        let query = vec![1.0f32, 0.0]; // the "alpha" dimension of the keyword embedder
        let recalled = agent
            .persistence
            .recall_top_k("embedding-store", &query, 1)
            .unwrap();

        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].0.message_id, Some(message_id));
    }

    #[tokio::test]
    async fn recall_memories_appends_semantic_matches() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("semantic-recall", Some(test_embeddings_client()));

        agent
            .store_message(MessageRole::User, "Alpha question")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Alpha answer")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::User, "Beta prompt")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Beta reply")
            .await
            .unwrap();

        let recall = agent.recall_memories("alpha follow up").await.unwrap();
        assert!(matches!(
            recall.stats.as_ref().map(|s| &s.strategy),
            Some(MemoryRecallStrategy::Semantic { .. })
        ));
        assert_eq!(
            recall
                .stats
                .as_ref()
                .map(|s| s.matches.len())
                .unwrap_or_default(),
            2
        );

        let recalled = recall.messages;
        assert_eq!(recalled.len(), 4);
        assert_eq!(recalled[0].content, "Beta prompt");
        assert_eq!(recalled[1].content, "Beta reply");

        let tail: Vec<_> = recalled[2..].iter().map(|m| m.content.as_str()).collect();
        assert!(tail.contains(&"Alpha question"));
        assert!(tail.contains(&"Alpha answer"));
    }

    #[tokio::test]
    async fn test_agent_tool_permission_allowed() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let mut profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("Test"));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());

        let mut policy_engine = PolicyEngine::new();
        // Allow tool calls broadly at the policy layer so the profile's
        // allow/deny lists are what gate access in this test.
        policy_engine.add_rule(crate::policy::PolicyRule {
            agent: "*".to_string(),
            action: "tool_call".to_string(),
            resource: "*".to_string(),
            effect: crate::policy::PolicyEffect::Allow,
        });
        let policy_engine = Arc::new(policy_engine);

        let agent = AgentCore::new(
            profile.clone(),
            provider.clone(),
            None,
            persistence.clone(),
            "test-session".to_string(),
            Some("policy-test".to_string()),
            tool_registry.clone(),
            policy_engine.clone(),
            false,
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);

        // Switch from an allow list to a deny list and verify the inverse gating.
        profile.allowed_tools = None;
        profile.denied_tools = Some(vec!["calculator".to_string()]);

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence,
            "test-session-2".to_string(),
            Some("policy-test-2".to_string()),
            tool_registry,
            policy_engine,
            false,
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);
    }

    #[tokio::test]
    async fn test_agent_tool_execution_with_logging() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("Test"));

        let mut tool_registry = crate::tools::ToolRegistry::new();
        tool_registry.register(Arc::new(crate::tools::builtin::EchoTool::new()));

        let policy_engine = Arc::new(PolicyEngine::new());

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence.clone(),
            "tool-exec-test".to_string(),
            Some("tool-agent".to_string()),
            Arc::new(tool_registry),
            policy_engine,
            false,
        );

        let args = serde_json::json!({"message": "test message"});
        let result = agent
            .execute_tool("run-tool-test", "echo", &args)
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "test message");
    }

    #[tokio::test]
    async fn test_agent_tool_registry_access() {
        let (agent, _dir) = create_test_agent("registry-test");

        let registry = agent.tool_registry();
        assert!(registry.is_empty());
    }

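    // Sketch tests for the preview helpers defined earlier in this file; they
    // assert only behavior visible above (80-char truncation plus a "..." suffix).
    #[test]
    fn test_preview_text_truncates_long_content() {
        let long = "x".repeat(100);
        let preview = preview_text(&long);
        assert!(preview.ends_with("..."));
        assert_eq!(preview.chars().count(), 83); // 80 kept chars + "..."
    }

    #[test]
    fn test_preview_text_returns_short_content_unchanged() {
        assert_eq!(preview_text("  hello  "), "hello");
    }
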
    #[test]
    fn test_goal_requires_tool_detection() {
        assert!(AgentCore::goal_requires_tool(
            "List the files in this directory"
        ));
        assert!(AgentCore::goal_requires_tool("Run the tests"));
        assert!(!AgentCore::goal_requires_tool("Explain recursion"));
        assert!(AgentCore::goal_requires_tool(
            "Tell me about the project in this directory"
        ));
    }

    #[test]
    fn test_infer_goal_tool_action_project_description() {
        let query = "Tell me about the project in this directory";
        let inferred = AgentCore::infer_goal_tool_action(query)
            .expect("Should infer a tool for project description");
        let (tool, args) = inferred;
        assert!(
            tool == "file_read" || tool == "search",
            "unexpected tool: {}",
            tool
        );
        if tool == "file_read" {
            assert!(args.get("path").is_some());
            assert!(args.get("max_bytes").is_some());
        } else {
            let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
            assert!(query.contains("Cargo.toml") || query.contains("package.json"));
        }
    }
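
    // Hedged companion tests for preview_json_value and payload_field, relying
    // only on serde_json items already imported at module scope.
    #[test]
    fn test_preview_json_value_variants() {
        assert_eq!(preview_json_value(&Value::Null), "null");
        assert_eq!(preview_json_value(&json!("short")), "short");
        let long = Value::String("y".repeat(100));
        assert!(preview_json_value(&long).ends_with("..."));
    }

    #[test]
    fn test_payload_field_missing_key_defaults_to_null() {
        let payload = json!({"tool": "echo"});
        assert_eq!(payload_field(&payload, "tool"), json!("echo"));
        assert_eq!(payload_field(&payload, "missing"), Value::Null);
    }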
}