use crate::agent::model::{GenerationConfig, ModelProvider};
use futures::Stream;
use std::pin::Pin;
pub use crate::agent::output::{
    AgentOutput, GraphDebugInfo, GraphDebugNode, MemoryRecallMatch, MemoryRecallStats,
    MemoryRecallStrategy, ToolInvocation,
};
use crate::config::agent::AgentProfile;
use crate::embeddings::EmbeddingsClient;
use crate::persistence::Persistence;
use crate::policy::{PolicyDecision, PolicyEngine};
use crate::spec::AgentSpec;
use crate::tools::{ToolRegistry, ToolResult};
use crate::types::{Message, MessageRole};
use crate::SYNC_GRAPH_NAMESPACE;
use anyhow::{Context, Result};
use chrono::Utc;
use serde_json::{json, Value};
use spec_ai_knowledge_graph::{EdgeType, NodeType, TraversalDirection};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

const DEFAULT_MAIN_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 0.9;
const DEFAULT_FAST_TEMPERATURE: f32 = 0.3;
const DEFAULT_ESCALATION_THRESHOLD: f32 = 0.6;

struct RecallResult {
    messages: Vec<Message>,
    stats: Option<MemoryRecallStats>,
}

struct ExtractedEntity {
    name: String,
    entity_type: String,
    confidence: f32,
}

struct ExtractedConcept {
    name: String,
    relevance: f32,
}

#[derive(Debug, Clone)]
struct GoalContext {
    message_id: i64,
    text: String,
    requires_tool: bool,
    satisfied: bool,
    node_id: Option<i64>,
}

impl GoalContext {
    fn new(message_id: i64, text: &str, requires_tool: bool, node_id: Option<i64>) -> Self {
        Self {
            message_id,
            text: text.to_string(),
            requires_tool,
            satisfied: !requires_tool,
            node_id,
        }
    }
}

pub struct AgentCore {
    profile: AgentProfile,
    provider: Arc<dyn ModelProvider>,
    fast_provider: Option<Arc<dyn ModelProvider>>,
    embeddings_client: Option<EmbeddingsClient>,
    persistence: Persistence,
    session_id: String,
    agent_name: Option<String>,
    conversation_history: Vec<Message>,
    tool_registry: Arc<ToolRegistry>,
    policy_engine: Arc<PolicyEngine>,
    tool_permission_cache: Arc<RwLock<HashMap<String, bool>>>,
    speak_responses: bool,
}

impl AgentCore {
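    /// Rewrites a session id that collides with the reserved sync namespace,
    /// returning the (possibly rewritten) id and whether a rewrite occurred.
    /// For example, a session id equal to `SYNC_GRAPH_NAMESPACE` comes back
    /// with an `-agent` suffix and `true`.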
    fn sanitize_session_id(session_id: String) -> (String, bool) {
        if session_id == SYNC_GRAPH_NAMESPACE {
            (format!("{}-agent", session_id), true)
        } else {
            (session_id, false)
        }
    }

    pub fn new(
        profile: AgentProfile,
        provider: Arc<dyn ModelProvider>,
        embeddings_client: Option<EmbeddingsClient>,
        persistence: Persistence,
        session_id: String,
        agent_name: Option<String>,
        tool_registry: Arc<ToolRegistry>,
        policy_engine: Arc<PolicyEngine>,
        speak_responses: bool,
    ) -> Self {
        let (session_id, rewrote_namespace) = Self::sanitize_session_id(session_id);
        if rewrote_namespace {
            warn!(
                "Session namespace '{}' is reserved for sync; using '{}' for agent graph state",
                SYNC_GRAPH_NAMESPACE, session_id
            );
        }

        Self {
            profile,
            provider,
            fast_provider: None,
            embeddings_client,
            persistence,
            session_id,
            agent_name,
            conversation_history: Vec::new(),
            tool_registry,
            policy_engine,
            tool_permission_cache: Arc::new(RwLock::new(HashMap::new())),
            speak_responses,
        }
    }

    pub fn with_fast_provider(mut self, fast_provider: Arc<dyn ModelProvider>) -> Self {
        self.fast_provider = Some(fast_provider);
        self
    }

    pub fn with_session(mut self, session_id: String) -> Self {
        let (session_id, rewrote_namespace) = Self::sanitize_session_id(session_id);
        if rewrote_namespace {
            warn!(
                "Session namespace '{}' is reserved for sync; using '{}' for agent graph state",
                SYNC_GRAPH_NAMESPACE, session_id
            );
        }
        self.session_id = session_id;
        self.conversation_history.clear();
        self.tool_permission_cache = Arc::new(RwLock::new(HashMap::new()));
        self
    }

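    /// Runs a single agent step: recalls memory, builds the prompt, executes
    /// any required tools (auto-inferred from the goal or requested by the
    /// model), and persists both sides of the exchange. The model/tool loop
    /// is capped at five iterations.
    ///
    /// A minimal usage sketch (illustrative only; constructing the provider,
    /// persistence, registry, and policy engine is elided, and the variable
    /// names are hypothetical):
    ///
    /// ```ignore
    /// let mut agent = AgentCore::new(
    ///     profile, provider, embeddings, persistence,
    ///     "session-1".to_string(), Some("helper".to_string()),
    ///     registry, policy, false,
    /// );
    /// let output = agent.run_step("list the files in this directory").await?;
    /// println!("{}", output.response);
    /// ```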
    pub async fn run_step(&mut self, input: &str) -> Result<AgentOutput> {
        let run_id = format!("run-{}", Utc::now().timestamp_micros());
        let total_timer = Instant::now();

        let recall_timer = Instant::now();
        let recall_result = self.recall_memories(input).await?;
        self.log_timing("run_step.recall_memories", recall_timer);
        let recalled_messages = recall_result.messages;
        let recall_stats = recall_result.stats;

        let prompt_timer = Instant::now();
        let mut prompt = self.build_prompt(input, &recalled_messages).await?;
        self.log_timing("run_step.build_prompt", prompt_timer);

        let store_user_timer = Instant::now();
        let user_message_id = self.store_message(MessageRole::User, input).await?;
        self.log_timing("run_step.store_user_message", store_user_timer);

        let mut goal_context =
            Some(self.create_goal_context(user_message_id, input, self.profile.enable_graph)?);

        let mut tool_invocations = Vec::new();
        let mut final_response = String::new();
        let mut token_usage = None;
        let mut finish_reason = None;
        let mut auto_response: Option<String> = None;
        let mut reasoning: Option<String> = None;
        let mut reasoning_summary: Option<String> = None;

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if let Some((tool_name, tool_args)) =
                    Self::infer_goal_tool_action(goal.text.as_str())
                {
                    if self.is_tool_allowed(&tool_name).await {
                        let tool_timer = Instant::now();
                        let tool_result = self.execute_tool(&run_id, &tool_name, &tool_args).await;
                        self.log_timing("run_step.tool_execution.auto", tool_timer);
                        match tool_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    &tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                if let Err(err) = self
                                    .record_goal_tool_result(goal, &tool_name, &tool_args, &result)
                                {
                                    warn!("Failed to record goal progress: {}", err);
                                }
                                if result.success {
                                    if let Err(err) =
                                        self.update_goal_status(goal, "completed", true, None)
                                    {
                                        warn!("Failed to update goal status: {}", err);
                                    } else {
                                        goal.satisfied = true;
                                    }
                                }
                                auto_response = Some(Self::format_auto_tool_response(
                                    &tool_name,
                                    invocation.output.as_deref(),
                                ));
                                tool_invocations.push(invocation);
                            }
                            Err(err) => {
                                warn!("Auto tool execution '{}' failed: {}", tool_name, err);
                            }
                        }
                    }
                }
            }
        }

        let skip_model = goal_context
            .as_ref()
            .map(|goal| goal.requires_tool && goal.satisfied && auto_response.is_some())
            .unwrap_or(false);

        let mut fast_model_final: Option<(String, f32)> = None;
        if !skip_model {
            if let Some(task_type) = self.detect_task_type(input) {
                let complexity = self.estimate_task_complexity(input);
                if self.should_use_fast_model(&task_type, complexity) {
                    let fast_timer = Instant::now();
                    let fast_result = self.fast_reasoning(&task_type, input).await;
                    self.log_timing("run_step.fast_reasoning_attempt", fast_timer);
                    match fast_result {
                        Ok((fast_text, confidence)) => {
                            if confidence >= self.escalation_threshold() {
                                fast_model_final = Some((fast_text, confidence));
                            } else {
                                prompt.push_str(&format!(
                                    "\n\nFAST_MODEL_HINT (task={} confidence={:.0}%):\n{}\n\nRefine this hint and produce a complete answer.",
                                    task_type,
                                    (confidence * 100.0).round(),
                                    fast_text
                                ));
                            }
                        }
                        Err(err) => {
                            warn!("Fast reasoning failed for task {}: {}", task_type, err);
                        }
                    }
                }
            }
        }

        if skip_model {
            final_response = auto_response.unwrap_or_else(|| "Task completed.".to_string());
            finish_reason = Some("auto_tool".to_string());
        } else if let Some((fast_text, confidence)) = fast_model_final {
            final_response = fast_text;
            finish_reason = Some(format!("fast_model ({:.0}%)", (confidence * 100.0).round()));
        } else {
            for _iteration in 0..5 {
                let generation_config = self.build_generation_config();
                let model_timer = Instant::now();
                let response_result = self.provider.generate(&prompt, &generation_config).await;
                self.log_timing("run_step.main_model_call", model_timer);
                let response = response_result.context("Failed to generate response from model")?;

                token_usage = response.usage;
                finish_reason = response.finish_reason.clone();
                final_response = response.content.clone();
                reasoning = response.reasoning.clone();

                if let Some(ref reasoning_text) = reasoning {
                    reasoning_summary = self.summarize_reasoning(reasoning_text).await;
                }

                let sdk_tool_calls = response.tool_calls.clone().unwrap_or_default();

                if sdk_tool_calls.is_empty() {
                    let is_complete = finish_reason.as_ref().is_some_and(|reason| {
                        let reason_lower = reason.to_lowercase();
                        reason_lower.contains("stop")
                            || reason_lower.contains("end_turn")
                            || reason_lower.contains("complete")
                            || reason_lower == "length"
                    });

                    let goal_needs_tool = goal_context
                        .as_ref()
                        .is_some_and(|g| g.requires_tool && !g.satisfied);

                    if is_complete && !goal_needs_tool {
                        debug!("Early termination: response complete with no tool calls needed");
                        break;
                    }
                }

                if !sdk_tool_calls.is_empty() {
                    for tool_call in sdk_tool_calls {
                        let tool_name = &tool_call.function_name;
                        let tool_args = &tool_call.arguments;

                        if !self.is_tool_allowed(tool_name).await {
                            warn!(
                                "Tool '{}' is not allowed by agent policy - prompting user",
                                tool_name
                            );

                            match self.prompt_for_tool_permission(tool_name).await {
                                Ok(true) => {
                                    info!("User granted permission for tool '{}'", tool_name);
                                }
                                Ok(false) => {
                                    let error_msg =
                                        format!("Tool '{}' was denied by user", tool_name);
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                                Err(e) => {
                                    let error_msg = format!(
                                        "Failed to get user permission for tool '{}': {}",
                                        tool_name, e
                                    );
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                            }
                        }

                        let tool_timer = Instant::now();
                        let exec_result = self.execute_tool(&run_id, tool_name, tool_args).await;
                        self.log_timing("run_step.tool_execution.sdk", tool_timer);
                        match exec_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                let tool_output = invocation.output.clone().unwrap_or_default();
                                let was_success = invocation.success;
                                let error_message = invocation
                                    .error
                                    .clone()
                                    .unwrap_or_else(|| "Tool execution failed".to_string());
                                tool_invocations.push(invocation);

                                if let Some(goal) = goal_context.as_mut() {
                                    if let Err(err) = self.record_goal_tool_result(
                                        goal, tool_name, tool_args, &result,
                                    ) {
                                        warn!("Failed to record goal progress: {}", err);
                                    }
                                    if result.success && goal.requires_tool && !goal.satisfied {
                                        if let Err(err) =
                                            self.update_goal_status(goal, "in_progress", true, None)
                                        {
                                            warn!("Failed to update goal status: {}", err);
                                        }
                                    }
                                }

                                if was_success {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_RESULT from {}:\n{}\n\nBased on this result, please continue.",
                                        tool_name, tool_output
                                    ));
                                } else {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                        error_message
                                    ));
                                }
                            }
                            Err(e) => {
                                let error_msg =
                                    format!("Error executing tool '{}': {}", tool_name, e);
                                warn!("{}", error_msg);
                                prompt.push_str(&format!(
                                    "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                    error_msg
                                ));
                                tool_invocations.push(ToolInvocation {
                                    name: tool_name.clone(),
                                    arguments: tool_args.clone(),
                                    success: false,
                                    output: None,
                                    error: Some(error_msg),
                                });
                            }
                        }
                    }

                    continue;
                }

                if let Some(goal) = goal_context.as_ref() {
                    if goal.requires_tool && !goal.satisfied {
                        prompt.push_str(
                            "\n\nGOAL_STATUS: pending. The user request requires executing an allowed tool. Please call an appropriate tool.",
                        );
                        continue;
                    }
                }

                break;
            }
        }

        let store_assistant_timer = Instant::now();
        let response_message_id = self
            .store_message_with_reasoning(
                MessageRole::Assistant,
                &final_response,
                reasoning.as_deref(),
            )
            .await?;
        self.log_timing("run_step.store_assistant_message", store_assistant_timer);

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if goal.satisfied {
                    if let Err(err) =
                        self.update_goal_status(goal, "completed", true, Some(response_message_id))
                    {
                        warn!("Failed to finalize goal status: {}", err);
                    }
                } else if let Err(err) =
                    self.update_goal_status(goal, "blocked", false, Some(response_message_id))
                {
                    warn!("Failed to record blocked goal status: {}", err);
                }
            } else if let Err(err) =
                self.update_goal_status(goal, "completed", true, Some(response_message_id))
            {
                warn!("Failed to finalize goal status: {}", err);
            }
        }

        self.conversation_history.push(Message {
            id: user_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::User,
            content: input.to_string(),
            created_at: Utc::now(),
        });

        self.conversation_history.push(Message {
            id: response_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::Assistant,
            content: final_response.clone(),
            created_at: Utc::now(),
        });

        let next_action_recommendation =
            if self.profile.enable_graph && self.conversation_history.len() >= 3 {
                let graph_timer = Instant::now();
                let recommendation = self.evaluate_graph_for_next_action(
                    user_message_id,
                    response_message_id,
                    &final_response,
                    &tool_invocations,
                )?;
                self.log_timing("run_step.evaluate_graph_for_next_action", graph_timer);
                recommendation
            } else {
                None
            };

        if let Some(ref recommendation) = next_action_recommendation {
            tracing::debug!("Knowledge graph recommends next action: {}", recommendation);
            let system_content = format!("Graph recommendation: {}", recommendation);
            let system_store_timer = Instant::now();
            let system_message_id = self
                .store_message(MessageRole::System, &system_content)
                .await?;
            self.log_timing("run_step.store_system_message", system_store_timer);

            self.conversation_history.push(Message {
                id: system_message_id,
                session_id: self.session_id.clone(),
                role: MessageRole::System,
                content: system_content,
                created_at: Utc::now(),
            });
        }

        let graph_debug = match self.snapshot_graph_debug_info() {
            Ok(info) => Some(info),
            Err(err) => {
                warn!("Failed to capture graph debug info: {}", err);
                None
            }
        };

        self.log_timing("run_step.total", total_timer);

        Ok(AgentOutput {
            response: final_response,
            response_message_id: Some(response_message_id),
            token_usage,
            tool_invocations,
            finish_reason,
            recall_stats,
            run_id,
            next_action: next_action_recommendation,
            reasoning,
            reasoning_summary,
            graph_debug,
        })
    }

    pub async fn run_spec(&mut self, spec: &AgentSpec) -> Result<AgentOutput> {
        debug!(
            "Executing structured spec '{}' (source: {:?})",
            spec.display_name(),
            spec.source_path()
        );
        let prompt = spec.to_prompt();
        self.run_step(&prompt).await
    }

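    /// Streams a response for `input` without the tool-execution or goal
    /// tracking done by [`run_step`](Self::run_step). The caller collects the
    /// stream and then calls
    /// [`finalize_streaming_step`](Self::finalize_streaming_step) with the
    /// full text so the assistant message is persisted.
    ///
    /// A rough sketch of the expected calling pattern (illustrative only;
    /// `agent` is assumed to be an already-constructed `AgentCore`):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.run_step_streaming("hello").await?;
    /// let mut full = String::new();
    /// while let Some(chunk) = stream.next().await {
    ///     full.push_str(&chunk?);
    /// }
    /// agent.finalize_streaming_step(&full).await?;
    /// ```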
    pub async fn run_step_streaming(
        &mut self,
        input: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
        let recall_result = self.recall_memories(input).await?;
        let recalled_messages = recall_result.messages;

        let prompt = self.build_prompt(input, &recalled_messages).await?;

        let user_message_id = self.store_message(MessageRole::User, input).await?;

        self.conversation_history.push(Message {
            id: user_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::User,
            content: input.to_string(),
            created_at: Utc::now(),
        });

        let generation_config = self.build_generation_config();
        let stream = self
            .provider
            .stream(&prompt, &generation_config)
            .await
            .context("Failed to start streaming response from model")?;

        Ok(stream)
    }

    pub async fn finalize_streaming_step(&mut self, content: &str) -> Result<i64> {
        let message_id = self.store_message(MessageRole::Assistant, content).await?;

        self.conversation_history.push(Message {
            id: message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::Assistant,
            content: content.to_string(),
            created_at: Utc::now(),
        });

        Ok(message_id)
    }

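    /// Builds the main-model generation config from the profile, clamping
    /// `temperature` to [0.0, 2.0] and `top_p` to [0.0, 1.0]; non-finite
    /// values fall back to the module defaults with a warning.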
    fn build_generation_config(&self) -> GenerationConfig {
        let temperature = match self.profile.temperature {
            Some(temp) if temp.is_finite() => Some(temp.clamp(0.0, 2.0)),
            Some(_) => {
                warn!(
                    "Ignoring invalid temperature for agent {:?}, falling back to {}",
                    self.agent_name, DEFAULT_MAIN_TEMPERATURE
                );
                Some(DEFAULT_MAIN_TEMPERATURE)
            }
            None => None,
        };

        let top_p = if self.profile.top_p.is_finite() {
            Some(self.profile.top_p.clamp(0.0, 1.0))
        } else {
            warn!(
                "Invalid top_p detected for agent {:?}, falling back to {}",
                self.agent_name, DEFAULT_TOP_P
            );
            Some(DEFAULT_TOP_P)
        };

        GenerationConfig {
            temperature,
            max_tokens: self.profile.max_context_tokens.map(|t| t as u32),
            stop_sequences: None,
            top_p,
            frequency_penalty: None,
            presence_penalty: None,
        }
    }

    fn snapshot_graph_debug_info(&self) -> Result<GraphDebugInfo> {
        let mut info = GraphDebugInfo {
            enabled: self.profile.enable_graph,
            graph_memory_enabled: self.profile.graph_memory,
            auto_graph_enabled: self.profile.auto_graph,
            graph_steering_enabled: self.profile.graph_steering,
            node_count: 0,
            edge_count: 0,
            recent_nodes: Vec::new(),
        };

        if !self.profile.enable_graph {
            return Ok(info);
        }

        info.node_count = self.persistence.count_graph_nodes(&self.session_id)?.max(0) as usize;
        info.edge_count = self.persistence.count_graph_edges(&self.session_id)?.max(0) as usize;

        let recent_nodes = self
            .persistence
            .list_graph_nodes(&self.session_id, None, Some(5))?;
        info.recent_nodes = recent_nodes
            .into_iter()
            .map(|node| GraphDebugNode {
                id: node.id,
                node_type: node.node_type.as_str().to_string(),
                label: node.label,
            })
            .collect();

        Ok(info)
    }

    async fn summarize_reasoning(&self, reasoning: &str) -> Option<String> {
        let fast_provider = self.fast_provider.as_ref()?;

        if reasoning.len() < 50 {
            return Some(reasoning.to_string());
        }

        let summary_prompt = format!(
            "Summarize the following reasoning in 1-2 concise sentences that explain the thought process:\n\n{}\n\nSummary:",
            reasoning
        );

        let config = GenerationConfig {
            temperature: Some(0.3),
            max_tokens: Some(100),
            stop_sequences: None,
            top_p: Some(0.9),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let timer = Instant::now();
        let response = fast_provider.generate(&summary_prompt, &config).await;
        self.log_timing("summarize_reasoning.generate", timer);
        match response {
            Ok(response) => {
                let summary = response.content.trim().to_string();
                if !summary.is_empty() {
                    debug!("Generated reasoning summary: {}", summary);
                    Some(summary)
                } else {
                    None
                }
            }
            Err(e) => {
                warn!("Failed to summarize reasoning: {}", e);
                None
            }
        }
    }

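    /// Assembles context for `query` in up to three layers: the two most
    /// recent session messages, graph-neighbor messages (when graph memory is
    /// enabled), and semantic matches from the embedding store. When both
    /// semantic and graph results are present they share `memory_k` slots,
    /// split by `graph_weight`. Without an embeddings client this falls back
    /// to plain recent-history recall.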
    async fn recall_memories(&self, query: &str) -> Result<RecallResult> {
        const RECENT_CONTEXT: i64 = 2;
        let mut context = Vec::new();
        let mut seen_ids = HashSet::new();

        let recent_messages = self
            .persistence
            .list_messages(&self.session_id, RECENT_CONTEXT)?;

        if self.conversation_history.is_empty() && recent_messages.is_empty() {
            return Ok(RecallResult {
                messages: Vec::new(),
                stats: Some(MemoryRecallStats {
                    strategy: MemoryRecallStrategy::RecentContext {
                        limit: RECENT_CONTEXT as usize,
                    },
                    matches: Vec::new(),
                }),
            });
        }

        for message in recent_messages {
            seen_ids.insert(message.id);
            context.push(message);
        }

        if self.profile.enable_graph && self.profile.graph_memory {
            let mut graph_messages = Vec::new();

            for msg in &context {
                let nodes = self.persistence.list_graph_nodes(
                    &self.session_id,
                    Some(NodeType::Message),
                    Some(10),
                )?;

                for node in nodes {
                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                        if msg_id == msg.id {
                            let neighbors = self.persistence.traverse_neighbors(
                                &self.session_id,
                                node.id,
                                TraversalDirection::Both,
                                self.profile.graph_depth,
                            )?;

                            for neighbor in neighbors {
                                if neighbor.node_type == NodeType::Message {
                                    if let Some(related_msg_id) =
                                        neighbor.properties["message_id"].as_i64()
                                    {
                                        if !seen_ids.contains(&related_msg_id) {
                                            if let Some(related_msg) =
                                                self.persistence.get_message(related_msg_id)?
                                            {
                                                seen_ids.insert(related_msg.id);
                                                graph_messages.push(related_msg);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }

            context.extend(graph_messages);
        }

        if let Some(client) = &self.embeddings_client {
            if self.profile.memory_k == 0 || query.trim().is_empty() {
                return Ok(RecallResult {
                    messages: context,
                    stats: None,
                });
            }

            let embed_timer = Instant::now();
            let embed_result = client.embed_batch(&[query]).await;
            self.log_timing("recall_memories.embed_batch", embed_timer);
            match embed_result {
                Ok(mut embeddings) => match embeddings.pop() {
                    Some(query_embedding) if !query_embedding.is_empty() => {
                        let recalled = self.persistence.recall_top_k(
                            &self.session_id,
                            &query_embedding,
                            self.profile.memory_k,
                        )?;

                        let mut matches = Vec::new();
                        let mut semantic_context = Vec::new();

                        for (memory, score) in recalled {
                            if let Some(message_id) = memory.message_id {
                                if seen_ids.contains(&message_id) {
                                    continue;
                                }

                                if let Some(message) = self.persistence.get_message(message_id)? {
                                    seen_ids.insert(message.id);
                                    matches.push(MemoryRecallMatch {
                                        message_id: Some(message.id),
                                        score,
                                        role: message.role.clone(),
                                        preview: preview_text(&message.content),
                                    });
                                    semantic_context.push(message);
                                }
                            } else {
                                if let Some(transcription_text) =
                                    self.persistence.get_transcription_by_embedding(memory.id)?
                                {
                                    let transcription_message = Message {
                                        id: memory.id,
                                        session_id: memory.session_id.clone(),
                                        role: MessageRole::User,
                                        content: format!("[Transcription] {}", transcription_text),
                                        created_at: memory.created_at,
                                    };

                                    matches.push(MemoryRecallMatch {
                                        message_id: None,
                                        score,
                                        role: MessageRole::User,
                                        preview: preview_text(&transcription_text),
                                    });
                                    semantic_context.push(transcription_message);
                                }
                            }
                        }

                        if self.profile.enable_graph && self.profile.graph_memory {
                            let mut graph_expanded = Vec::new();

                            for msg in &semantic_context {
                                let nodes = self.persistence.list_graph_nodes(
                                    &self.session_id,
                                    Some(NodeType::Message),
                                    Some(100),
                                )?;

                                for node in nodes {
                                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                                        if msg_id == msg.id {
                                            let neighbors = self.persistence.traverse_neighbors(
                                                &self.session_id,
                                                node.id,
                                                TraversalDirection::Both,
                                                self.profile.graph_depth,
                                            )?;

                                            for neighbor in neighbors {
                                                if matches!(
                                                    neighbor.node_type,
                                                    NodeType::Fact
                                                        | NodeType::Concept
                                                        | NodeType::Entity
                                                ) {
                                                    let graph_content = format!(
                                                        "[Graph Context - {} {}]: {}",
                                                        neighbor.node_type.as_str(),
                                                        neighbor.label,
                                                        neighbor.properties
                                                    );

                                                    let graph_msg = Message {
                                                        id: -1,
                                                        session_id: self.session_id.clone(),
                                                        role: MessageRole::System,
                                                        content: graph_content,
                                                        created_at: Utc::now(),
                                                    };

                                                    graph_expanded.push(graph_msg);
                                                }
                                            }
                                        }
                                    }
                                }
                            }

                            let total_slots = self.profile.memory_k.max(1);
                            let mut graph_limit =
                                ((total_slots as f32) * self.profile.graph_weight).round() as usize;
                            graph_limit = graph_limit.min(total_slots);
                            if graph_limit == 0 && !graph_expanded.is_empty() {
                                graph_limit = 1;
                            }

                            let mut semantic_limit = total_slots.saturating_sub(graph_limit);
                            if semantic_limit == 0 && !semantic_context.is_empty() {
                                semantic_limit = 1;
                                graph_limit = graph_limit.saturating_sub(1);
                            }

                            let mut limited_semantic = semantic_context;
                            if limited_semantic.len() > semantic_limit && semantic_limit > 0 {
                                limited_semantic.truncate(semantic_limit);
                            }

                            let mut limited_graph = graph_expanded;
                            if limited_graph.len() > graph_limit && graph_limit > 0 {
                                limited_graph.truncate(graph_limit);
                            }

                            context.extend(limited_semantic);
                            context.extend(limited_graph);
                        } else {
                            context.extend(semantic_context);
                        }

                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: matches.len(),
                                },
                                matches,
                            }),
                        });
                    }
                    _ => {
                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: 0,
                                },
                                matches: Vec::new(),
                            }),
                        });
                    }
                },
                Err(err) => {
                    warn!("Failed to embed recall query: {}", err);
                    return Ok(RecallResult {
                        messages: context,
                        stats: None,
                    });
                }
            }
        }

        let limit = self.profile.memory_k as i64;
        let messages = self.persistence.list_messages(&self.session_id, limit)?;

        let stats = if self.profile.memory_k > 0 {
            Some(MemoryRecallStats {
                strategy: MemoryRecallStrategy::RecentContext {
                    limit: self.profile.memory_k,
                },
                matches: Vec::new(),
            })
        } else {
            None
        };

        Ok(RecallResult { messages, stats })
    }

    async fn build_prompt(&self, input: &str, context_messages: &[Message]) -> Result<String> {
        let mut prompt = String::new();

        if let Some(system_prompt) = &self.profile.prompt {
            prompt.push_str("System: ");
            prompt.push_str(system_prompt);
            prompt.push_str("\n\n");
        }

        if self.speak_responses {
            prompt.push_str("System: Speech mode is enabled; respond with concise, natural sentences suitable for text-to-speech. Avoid markdown/code fences and keep the reply brief.\n\n");
        }

        let available_tools = self.tool_registry.list();
        tracing::debug!("Tool registry has {} tools", available_tools.len());
        if !available_tools.is_empty() {
            prompt.push_str("Available tools:\n");
            for tool_name in &available_tools {
                info!(
                    "Checking tool: {} - allowed: {}",
                    tool_name,
                    self.is_tool_allowed(tool_name).await
                );
                if self.is_tool_allowed(tool_name).await {
                    if let Some(tool) = self.tool_registry.get(tool_name) {
                        prompt.push_str(&format!("- {}: {}\n", tool_name, tool.description()));
                    }
                }
            }
            prompt.push('\n');
        }

        if !context_messages.is_empty() {
            prompt.push_str("Previous conversation:\n");
            for msg in context_messages {
                prompt.push_str(&format!("{}: {}\n", msg.role.as_str(), msg.content));
            }
            prompt.push('\n');
        }

        prompt.push_str(&format!("user: {}\n", input));

        prompt.push_str("assistant:");

        Ok(prompt)
    }

    async fn store_message(&self, role: MessageRole, content: &str) -> Result<i64> {
        self.store_message_with_reasoning(role, content, None).await
    }

    async fn store_message_with_reasoning(
        &self,
        role: MessageRole,
        content: &str,
        reasoning: Option<&str>,
    ) -> Result<i64> {
        let message_id = self
            .persistence
            .insert_message(&self.session_id, role.clone(), content)
            .context("Failed to store message")?;

        let mut embedding_id = None;

        if let Some(client) = &self.embeddings_client {
            if !content.trim().is_empty() {
                let embed_timer = Instant::now();
                let embed_result = client.embed_batch(&[content]).await;
                self.log_timing("embeddings.message_content", embed_timer);
                match embed_result {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    Some(message_id),
                                    &embedding,
                                ) {
                                    Ok(emb_id) => {
                                        embedding_id = Some(emb_id);
                                    }
                                    Err(err) => {
                                        warn!(
                                            "Failed to persist embedding for message {}: {}",
                                            message_id, err
                                        );
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!(
                            "Failed to create embedding for message {}: {}",
                            message_id, err
                        );
                    }
                }
            }
        }

        if self.profile.enable_graph && self.profile.auto_graph {
            self.build_graph_for_message(message_id, role, content, embedding_id, reasoning)?;
        }

        Ok(message_id)
    }

    fn build_graph_for_message(
        &self,
        message_id: i64,
        role: MessageRole,
        content: &str,
        embedding_id: Option<i64>,
        reasoning: Option<&str>,
    ) -> Result<()> {
        use serde_json::json;

        let mut message_props = json!({
            "message_id": message_id,
            "role": role.as_str(),
            "content_preview": preview_text(content),
            "timestamp": Utc::now().to_rfc3339(),
        });

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                message_props["has_reasoning"] = json!(true);
                message_props["reasoning_preview"] = json!(preview_text(reasoning_text));
            }
        }

        let message_node_id = self.persistence.insert_graph_node(
            &self.session_id,
            NodeType::Message,
            &format!("{:?}Message", role),
            &message_props,
            embedding_id,
        )?;

        let mut entities = self.extract_entities_from_text(content);
        let mut concepts = self.extract_concepts_from_text(content);

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                debug!(
                    "Extracting entities/concepts from reasoning for message {}",
                    message_id
                );
                let reasoning_entities = self.extract_entities_from_text(reasoning_text);
                let reasoning_concepts = self.extract_concepts_from_text(reasoning_text);

                for mut reasoning_entity in reasoning_entities {
                    if let Some(existing) = entities.iter_mut().find(|e| {
                        e.name.to_lowercase() == reasoning_entity.name.to_lowercase()
                            && e.entity_type == reasoning_entity.entity_type
                    }) {
                        existing.confidence =
                            (existing.confidence + reasoning_entity.confidence * 0.5).min(1.0);
                    } else {
                        reasoning_entity.confidence *= 0.8;
                        entities.push(reasoning_entity);
                    }
                }

                for mut reasoning_concept in reasoning_concepts {
                    if let Some(existing) = concepts
                        .iter_mut()
                        .find(|c| c.name.to_lowercase() == reasoning_concept.name.to_lowercase())
                    {
                        existing.relevance =
                            (existing.relevance + reasoning_concept.relevance * 0.5).min(1.0);
                    } else {
                        reasoning_concept.relevance *= 0.8;
                        concepts.push(reasoning_concept);
                    }
                }
            }
        }

        for entity in entities {
            let entity_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Entity,
                &entity.entity_type,
                &json!({
                    "name": entity.name,
                    "type": entity.entity_type,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                entity_node_id,
                EdgeType::Mentions,
                Some("mentions"),
                Some(&json!({"confidence": entity.confidence})),
                entity.confidence,
            )?;
        }

        for concept in concepts {
            let concept_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Concept,
                "Concept",
                &json!({
                    "name": concept.name,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                concept_node_id,
                EdgeType::RelatesTo,
                Some("discusses"),
                Some(&json!({"relevance": concept.relevance})),
                concept.relevance,
            )?;
        }

        let recent_messages = self.persistence.list_messages(&self.session_id, 2)?;
        if recent_messages.len() > 1 {
            let nodes = self.persistence.list_graph_nodes(
                &self.session_id,
                Some(NodeType::Message),
                Some(10),
            )?;

            for node in nodes {
                if let Some(prev_msg_id) = node.properties["message_id"].as_i64() {
                    if prev_msg_id != message_id && prev_msg_id == recent_messages[0].id {
                        self.persistence.insert_graph_edge(
                            &self.session_id,
                            node.id,
                            message_node_id,
                            EdgeType::FollowsFrom,
                            Some("conversation_flow"),
                            None,
                            1.0,
                        )?;
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    fn create_goal_context(
        &self,
        message_id: i64,
        input: &str,
        persist: bool,
    ) -> Result<GoalContext> {
        let requires_tool = Self::goal_requires_tool(input);
        let node_id = if self.profile.enable_graph {
            if persist {
                let properties = json!({
                    "message_id": message_id,
                    "goal_text": input,
                    "status": "pending",
                    "requires_tool": requires_tool,
                    "satisfied": false,
                    "created_at": Utc::now().to_rfc3339(),
                });
                Some(self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Goal,
                    "Goal",
                    &properties,
                    None,
                )?)
            } else {
                None
            }
        } else {
            None
        };

        Ok(GoalContext::new(message_id, input, requires_tool, node_id))
    }

    fn update_goal_status(
        &self,
        goal: &mut GoalContext,
        status: &str,
        satisfied: bool,
        response_message_id: Option<i64>,
    ) -> Result<()> {
        goal.satisfied = satisfied;
        if let Some(node_id) = goal.node_id {
            let properties = json!({
                "message_id": goal.message_id,
                "goal_text": goal.text,
                "status": status,
                "requires_tool": goal.requires_tool,
                "satisfied": satisfied,
                "response_message_id": response_message_id,
                "updated_at": Utc::now().to_rfc3339(),
            });
            self.persistence.update_graph_node(node_id, &properties)?;
        }
        Ok(())
    }

    fn record_goal_tool_result(
        &self,
        goal: &GoalContext,
        tool_name: &str,
        args: &Value,
        result: &ToolResult,
    ) -> Result<()> {
        if let Some(goal_node_id) = goal.node_id {
            let timestamp = Utc::now().to_rfc3339();
            let mut properties = json!({
                "tool": tool_name,
                "arguments": args,
                "success": result.success,
                "output_preview": preview_text(&result.output),
                "error": result.error,
                "timestamp": timestamp,
            });

            let mut prompt_payload: Option<Value> = None;
            if tool_name == "prompt_user" && result.success {
                match serde_json::from_str::<Value>(&result.output) {
                    Ok(payload) => {
                        if let Some(props) = properties.as_object_mut() {
                            props.insert("prompt_user_payload".to_string(), payload.clone());
                            if let Some(response) = payload.get("response") {
                                props.insert(
                                    "response_preview".to_string(),
                                    Value::String(preview_json_value(response)),
                                );
                            }
                        }
                        prompt_payload = Some(payload);
                    }
                    Err(err) => {
                        warn!("Failed to parse prompt_user payload for graph: {}", err);
                        if let Some(props) = properties.as_object_mut() {
                            props.insert(
                                "prompt_user_parse_error".to_string(),
                                Value::String(err.to_string()),
                            );
                        }
                    }
                }
            }

            let tool_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::ToolResult,
                tool_name,
                &properties,
                None,
            )?;
            self.persistence.insert_graph_edge(
                &self.session_id,
                tool_node_id,
                goal_node_id,
                EdgeType::Produces,
                Some("satisfies"),
                None,
                if result.success { 1.0 } else { 0.1 },
            )?;

            if let Some(payload) = prompt_payload {
                let response_preview = payload
                    .get("response")
                    .map(preview_json_value)
                    .unwrap_or_default();

                let response_properties = json!({
                    "prompt": payload_field(&payload, "prompt"),
                    "input_type": payload_field(&payload, "input_type"),
                    "response": payload_field(&payload, "response"),
                    "display_value": payload_field(&payload, "display_value"),
                    "selections": payload_field(&payload, "selections"),
                    "metadata": payload_field(&payload, "metadata"),
                    "used_default": payload_field(&payload, "used_default"),
                    "used_prefill": payload_field(&payload, "used_prefill"),
                    "response_preview": response_preview,
                    "timestamp": timestamp,
                });

                let response_node_id = self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Event,
                    "UserInput",
                    &response_properties,
                    None,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    tool_node_id,
                    response_node_id,
                    EdgeType::Produces,
                    Some("captures_input"),
                    None,
                    1.0,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    response_node_id,
                    goal_node_id,
                    EdgeType::RelatesTo,
                    Some("addresses_goal"),
                    None,
                    0.9,
                )?;
            }
        }
        Ok(())
    }

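    /// Heuristically decides whether `input` needs a tool call: either it
    /// contains an action verb ("list", "read", "run", ...) alongside other
    /// words, or it asks a question about the local project/repository.
    /// For example, "list the files here" and "tell me about this project"
    /// both return `true`, while a bare greeting returns `false`.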
    fn goal_requires_tool(input: &str) -> bool {
        let normalized = input.to_lowercase();
        const ACTION_VERBS: [&str; 18] = [
            "list", "show", "read", "write", "create", "update", "delete", "run", "execute",
            "open", "search", "fetch", "download", "scan", "compile", "test", "build", "inspect",
        ];

        if ACTION_VERBS
            .iter()
            .any(|verb| normalized.contains(verb) && normalized.contains(' '))
        {
            return true;
        }

        let mentions_local_context = normalized.contains("this directory")
            || normalized.contains("current directory")
            || normalized.contains("this folder")
            || normalized.contains("here");

        let mentions_project = normalized.contains("this project")
            || normalized.contains("this repo")
            || normalized.contains("this repository")
            || normalized.contains("this codebase")
            || ((normalized.contains("project")
                || normalized.contains("repo")
                || normalized.contains("repository")
                || normalized.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = normalized.contains("what can")
            || normalized.contains("what is")
            || normalized.contains("what does")
            || normalized.contains("tell me")
            || normalized.contains("describe")
            || normalized.contains("about");

        mentions_project && asks_about_project
    }

    fn escalation_threshold(&self) -> f32 {
        if self.profile.escalation_threshold.is_finite() {
            self.profile.escalation_threshold.clamp(0.0, 1.0)
        } else {
            warn!(
                "Invalid escalation_threshold for agent {:?}, defaulting to {}",
                self.agent_name, DEFAULT_ESCALATION_THRESHOLD
            );
            DEFAULT_ESCALATION_THRESHOLD
        }
    }

    fn detect_task_type(&self, input: &str) -> Option<String> {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return None;
        }

        let text = input.to_lowercase();

        let candidates: [(&str, &[&str]); 6] = [
            ("entity_extraction", &["entity", "extract", "named"]),
            ("decision_routing", &["classify", "categorize", "route"]),
            (
                "tool_selection",
                &["which tool", "use which tool", "tool should"],
            ),
            ("confidence_scoring", &["confidence", "certainty"]),
            ("summarization", &["summarize", "summary"]),
            ("graph_analysis", &["graph", "connection", "relationships"]),
        ];

        for (task, keywords) in candidates {
            if keywords.iter().any(|kw| text.contains(kw))
                && self.profile.fast_model_tasks.iter().any(|t| t == task)
            {
                return Some(task.to_string());
            }
        }

        None
    }

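    /// Scores input complexity in [0.0, 1.0] from length, clause count, and
    /// line structure. As a worked example, a 60-word single-line request
    /// with one " and " clause scores
    /// 0.6 * (60/120) + 0.3 * (1/4) + 0.1 * 0 = 0.375.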
    fn estimate_task_complexity(&self, input: &str) -> f32 {
        let words = input.split_whitespace().count() as f32;
        let clauses =
            input.matches(" and ").count() as f32 + input.matches(" then ").count() as f32;
        let newlines = input.matches('\n').count() as f32;

        let length_factor = (words / 120.0).min(1.0);
        let clause_factor = (clauses / 4.0).min(1.0);
        let structure_factor = (newlines / 5.0).min(1.0);

        (0.6 * length_factor + 0.3 * clause_factor + 0.1 * structure_factor).clamp(0.0, 1.0)
    }

    fn infer_goal_tool_action(goal_text: &str) -> Option<(String, Value)> {
        let text = goal_text.to_lowercase();

        let mentions_local_context = text.contains("this directory")
            || text.contains("current directory")
            || text.contains("this folder")
            || text.contains("here");

        let mentions_project = text.contains("this project")
            || text.contains("this repo")
            || text.contains("this repository")
            || text.contains("this codebase")
            || ((text.contains("project")
                || text.contains("repo")
                || text.contains("repository")
                || text.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = text.contains("what can")
            || text.contains("what is")
            || text.contains("what does")
            || text.contains("tell me")
            || text.contains("describe")
            || text.contains("about");

        if mentions_project && asks_about_project {
            for candidate in &["README.md", "Readme.md", "readme.md"] {
                if Path::new(candidate).exists() {
                    return Some((
                        "file_read".to_string(),
                        json!({
                            "path": candidate,
                            "max_bytes": 65536
                        }),
                    ));
                }
            }

            return Some((
                "search".to_string(),
                json!({
                    "query": "Cargo.toml|package.json|pyproject.toml|setup.py",
                    "regex": true,
                    "case_sensitive": false,
                    "max_results": 20
                }),
            ));
        }

        if text.contains("list")
            && (text.contains("directory") || text.contains("files") || text.contains("folder"))
        {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        if text.contains("show") && text.contains("current directory") {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        None
    }

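    /// Extracts a confidence value from a line such as "Confidence: 0.8",
    /// returning it only when it parses into [0.0, 1.0]. A percentage like
    /// "Confidence: 85" is rejected rather than rescaled, so the caller's
    /// fallback default applies.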
    fn parse_confidence(text: &str) -> Option<f32> {
        for line in text.lines() {
            let lower = line.to_lowercase();
            if lower.contains("confidence") {
                let token = lower
                    .split(|c: char| !(c.is_ascii_digit() || c == '.'))
                    .find(|chunk| !chunk.is_empty())?;
                if let Ok(value) = token.parse::<f32>() {
                    if (0.0..=1.0).contains(&value) {
                        return Some(value);
                    }
                }
            }
        }
        None
    }

    fn strip_fast_answer(text: &str) -> String {
        let mut answer = String::new();
        for line in text.lines() {
            if line.to_lowercase().starts_with("answer:") {
                answer.push_str(line.split_once(':').map(|x| x.1).unwrap_or("").trim());
                break;
            }
        }
        if answer.is_empty() {
            text.trim().to_string()
        } else {
            answer
        }
    }

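    /// Renders an auto-executed tool's JSON output as user-facing text, with
    /// dedicated formatting for `file_read` (truncated contents), `search`
    /// (top five hits), and `shell`/`bash` (stdout plus any stderr); anything
    /// else is passed through as-is.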
    fn format_auto_tool_response(tool_name: &str, output: Option<&str>) -> String {
        let sanitized = output.unwrap_or("").trim();
        if sanitized.is_empty() {
            return format!("Executed `{}` successfully.", tool_name);
        }

        if tool_name == "file_read" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let path = value.get("path").and_then(|v| v.as_str()).unwrap_or("file");
                let content = value.get("content").and_then(|v| v.as_str()).unwrap_or("");

                let max_chars = 4000;
                let display_content = if content.len() > max_chars {
                    // Back up to a char boundary so slicing cannot panic on
                    // multi-byte UTF-8 content.
                    let mut end = max_chars;
                    while !content.is_char_boundary(end) {
                        end -= 1;
                    }
                    let mut snippet = content[..end].to_string();
                    snippet.push_str("\n...\n[truncated]");
                    snippet
                } else {
                    content.to_string()
                };

                return format!("Contents of {}:\n{}", path, display_content);
            }
        }

        if tool_name == "search" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let query = value.get("query").and_then(|v| v.as_str()).unwrap_or("");

                if let Some(results) = value.get("results").and_then(|v| v.as_array()) {
                    if results.is_empty() {
                        return if query.is_empty() {
                            "Search returned no results.".to_string()
                        } else {
                            format!("Search for {:?} returned no results.", query)
                        };
                    }

                    let mut lines = Vec::new();
                    if query.is_empty() {
                        lines.push("Search results:".to_string());
                    } else {
                        lines.push(format!("Search results for {:?}:", query));
                    }

                    for entry in results.iter().take(5) {
                        let path = entry
                            .get("path")
                            .and_then(|v| v.as_str())
                            .unwrap_or("<unknown>");
                        let line = entry.get("line").and_then(|v| v.as_u64()).unwrap_or(0);
                        let snippet = entry
                            .get("snippet")
                            .and_then(|v| v.as_str())
                            .unwrap_or("")
                            .replace('\n', " ");

                        lines.push(format!("- {}:{} - {}", path, line, snippet));
                    }

                    return lines.join("\n");
                }
            }
        }

        if tool_name == "shell" || tool_name == "bash" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let std_out = value
                    .get("stdout")
                    .and_then(|v| v.as_str())
                    .unwrap_or(sanitized);
                let command = value.get("command").and_then(|v| v.as_str()).unwrap_or("");
                let stderr = value
                    .get("stderr")
                    .and_then(|v| v.as_str())
                    .map(|s| s.trim())
                    .filter(|s| !s.is_empty())
                    .unwrap_or("");
                let mut response = String::new();
                if !command.is_empty() {
                    response.push_str(&format!("Command `{}` output:\n", command));
                }
                response.push_str(std_out.trim_end());
                if !stderr.is_empty() {
                    response.push_str("\n\nstderr:\n");
                    response.push_str(stderr);
                }
                if response.trim().is_empty() {
                    return "Command completed without output.".to_string();
                }
                return response;
            }
        }

        sanitized.to_string()
    }

    fn extract_entities_from_text(&self, text: &str) -> Vec<ExtractedEntity> {
        if self.profile.fast_reasoning
            && self.fast_provider.is_some()
            && self
                .profile
                .fast_model_tasks
                .contains(&"entity_extraction".to_string())
        {
            debug!("Using fast model for entity extraction");
        }

        let mut entities = Vec::new();

        let url_regex = regex::Regex::new(r"https?://[^\s]+").unwrap();
        for mat in url_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "URL".to_string(),
                confidence: 0.9,
            });
        }

        // The TLD class must be [A-Za-z]; a literal '|' inside the class
        // would wrongly be matched as a pipe character.
        let email_regex =
            regex::Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").unwrap();
        for mat in email_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "Email".to_string(),
                confidence: 0.9,
            });
        }

        let quote_regex = regex::Regex::new(r#""([^"]+)""#).unwrap();
        for cap in quote_regex.captures_iter(text) {
            if let Some(quoted) = cap.get(1) {
                entities.push(ExtractedEntity {
                    name: quoted.as_str().to_string(),
                    entity_type: "Quote".to_string(),
                    confidence: 0.7,
                });
            }
        }

        entities
    }

    async fn fast_reasoning(&self, task: &str, input: &str) -> Result<(String, f32)> {
        let total_timer = Instant::now();
        let result = if let Some(ref fast_provider) = self.fast_provider {
            let prompt = format!(
                "You are a fast specialist model that assists a more capable agent.\nTask: {}\nInput: {}\n\nRespond with two lines:\nAnswer: <concise result>\nConfidence: <0-1 decimal>",
                task, input
            );

            let fast_temperature = if self.profile.fast_model_temperature.is_finite() {
                self.profile.fast_model_temperature.clamp(0.0, 2.0)
            } else {
                warn!(
                    "Invalid fast_model_temperature for agent {:?}, using {}",
                    self.agent_name, DEFAULT_FAST_TEMPERATURE
                );
                DEFAULT_FAST_TEMPERATURE
            };

            let config = GenerationConfig {
                temperature: Some(fast_temperature),
                max_tokens: Some(256),
                stop_sequences: None,
                top_p: Some(DEFAULT_TOP_P),
                frequency_penalty: None,
                presence_penalty: None,
            };

            let call_timer = Instant::now();
            let response_result = fast_provider.generate(&prompt, &config).await;
            self.log_timing("fast_reasoning.generate", call_timer);
            let response = response_result?;

            let confidence = Self::parse_confidence(&response.content).unwrap_or(0.7);
            let cleaned = Self::strip_fast_answer(&response.content);

            Ok((cleaned, confidence))
        } else {
            Ok((String::new(), 0.0))
        };

        self.log_timing("fast_reasoning.total", total_timer);
        result
    }

    fn should_use_fast_model(&self, task_type: &str, complexity_score: f32) -> bool {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return false;
        }

        if !self
            .profile
            .fast_model_tasks
            .contains(&task_type.to_string())
        {
            return false;
        }

        let threshold = self.escalation_threshold();
        if complexity_score > threshold {
            info!(
                "Task complexity {} exceeds threshold {}, using main model",
                complexity_score, threshold
            );
            return false;
        }

        true
    }

    fn extract_concepts_from_text(&self, text: &str) -> Vec<ExtractedConcept> {
        let mut concepts = Vec::new();

        let concept_keywords = vec![
            ("graph", "Knowledge Graph"),
            ("memory", "Memory System"),
            ("embedding", "Embeddings"),
            ("tool", "Tool Usage"),
            ("agent", "Agent System"),
            ("database", "Database"),
            ("query", "Query Processing"),
            ("node", "Graph Node"),
            ("edge", "Graph Edge"),
        ];

        let text_lower = text.to_lowercase();
        for (keyword, concept_name) in concept_keywords {
            if text_lower.contains(keyword) {
                concepts.push(ExtractedConcept {
                    name: concept_name.to_string(),
                    relevance: 0.6,
                });
            }
        }

        concepts
    }

    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn agent_name(&self) -> Option<&str> {
        self.agent_name.as_deref()
    }

    pub fn conversation_history(&self) -> &[Message] {
        &self.conversation_history
    }

    pub fn load_history(&mut self, limit: i64) -> Result<()> {
        self.conversation_history = self.persistence.list_messages(&self.session_id, limit)?;
        Ok(())
    }

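    /// Checks the per-session permission cache, then the profile allow/deny
    /// lists, and finally the policy engine; the resulting decision is cached
    /// so repeated checks for the same tool are cheap.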
    async fn is_tool_allowed(&self, tool_name: &str) -> bool {
        {
            let cache = self.tool_permission_cache.read().await;
            if let Some(&allowed) = cache.get(tool_name) {
                return allowed;
            }
        }

        let profile_allowed = self.profile.is_tool_allowed(tool_name);
        debug!(
            "Profile check for tool '{}': allowed={}, allowed_tools={:?}, denied_tools={:?}",
            tool_name, profile_allowed, self.profile.allowed_tools, self.profile.denied_tools
        );
        if !profile_allowed {
            self.tool_permission_cache
                .write()
                .await
                .insert(tool_name.to_string(), false);
            return false;
        }

        let agent_name = self.agent_name.as_deref().unwrap_or("agent");
        let decision = self.policy_engine.check(agent_name, "tool_call", tool_name);
        debug!(
            "Policy check for tool '{}': decision={:?}",
            tool_name, decision
        );

        let allowed = matches!(decision, PolicyDecision::Allow);
        self.tool_permission_cache
            .write()
            .await
            .insert(tool_name.to_string(), allowed);
        allowed
    }

    async fn prompt_for_tool_permission(&mut self, tool_name: &str) -> Result<bool> {
        info!("Requesting user permission for tool: {}", tool_name);

        let tool_description = self
            .tool_registry
            .get(tool_name)
            .map(|t| t.description().to_string())
            .unwrap_or_else(|| "No description available".to_string());

        let prompt_args = json!({
            "prompt": format!(
                "The agent wants to use the '{}' tool.\n\nDescription: {}\n\nDo you want to allow this?",
                tool_name,
                tool_description
            ),
            "input_type": "boolean",
            "required": true,
        });

        match self.tool_registry.execute("prompt_user", prompt_args).await {
            Ok(result) if result.success => {
                info!("prompt_user output: {}", result.output);

                let allowed =
                    if let Ok(response_json) = serde_json::from_str::<Value>(&result.output) {
                        info!("Parsed JSON response: {:?}", response_json);
                        let value = response_json["response"].as_bool();
                        info!("Extracted boolean value: {:?}", value);
                        value.unwrap_or(false)
                    } else {
                        info!("Failed to parse JSON, trying plain text fallback");
                        let response = result.output.trim().to_lowercase();
                        let parsed = response == "yes" || response == "y" || response == "true";
                        info!("Plain text parse result for '{}': {}", response, parsed);
                        parsed
                    };

                info!(
                    "User {} tool '{}'",
                    if allowed { "allowed" } else { "denied" },
                    tool_name
                );

                if allowed {
                    self.add_allowed_tool(tool_name).await;
                } else {
                    self.add_denied_tool(tool_name).await;
                }

                Ok(allowed)
            }
            Ok(result) => {
                warn!("Failed to prompt user: {:?}", result.error);
                Ok(false)
            }
            Err(e) => {
                warn!("Error prompting user for permission: {}", e);
                Ok(false)
            }
        }
    }

    async fn add_allowed_tool(&mut self, tool_name: &str) {
        let tools = self.profile.allowed_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to allowed tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

    async fn add_denied_tool(&mut self, tool_name: &str) {
        let tools = self.profile.denied_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to denied tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

    async fn execute_tool(
        &self,
        run_id: &str,
        tool_name: &str,
        args: &Value,
    ) -> Result<ToolResult> {
        let exec_result = self.tool_registry.execute(tool_name, args.clone()).await;
        let result = match exec_result {
            Ok(res) => res,
            Err(err) => ToolResult::failure(err.to_string()),
        };

        let result_json = serde_json::json!({
            "output": result.output,
            "success": result.success,
            "error": result.error,
        });

        let error_str = result.error.as_deref();
        self.persistence
            .log_tool(
                &self.session_id,
                self.agent_name.as_deref().unwrap_or("unknown"),
                run_id,
                tool_name,
                args,
                &result_json,
                result.success,
                error_str,
            )
            .context("Failed to log tool execution")?;

        Ok(result)
    }

    /// Returns the tool registry backing this agent.
    pub fn tool_registry(&self) -> &ToolRegistry {
        &self.tool_registry
    }

    /// Returns the policy engine used for tool-call authorization.
    pub fn policy_engine(&self) -> &PolicyEngine {
        &self.policy_engine
    }

    /// Replaces the policy engine.
    pub fn set_policy_engine(&mut self, policy_engine: Arc<PolicyEngine>) {
        self.policy_engine = policy_engine;
    }

    /// Enables or disables spoken responses. Speech is only supported on
    /// macOS; on other platforms the flag is forced to `false`.
    pub fn set_speak_responses(&mut self, enabled: bool) {
        #[cfg(target_os = "macos")]
        {
            self.speak_responses = enabled;
        }
        #[cfg(not(target_os = "macos"))]
        {
            let _ = enabled;
            self.speak_responses = false;
        }
    }

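    /// Embeds `text` with the configured embeddings client and persists the
    /// vector for this session, returning the stored embedding's id. Returns
    /// `None` when no client is configured, the text is blank, or embedding
    /// or persistence fails (those failures are logged, not propagated).
    ///
    /// Illustrative usage, assuming an embeddings client was supplied at
    /// construction:
    ///
    /// ```ignore
    /// if let Some(embedding_id) = agent.generate_embedding("Alpha detail").await {
    ///     debug!("stored embedding {}", embedding_id);
    /// }
    /// ```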
    pub async fn generate_embedding(&self, text: &str) -> Option<i64> {
        if let Some(client) = &self.embeddings_client {
            if !text.trim().is_empty() {
                match client.embed_batch(&[text]).await {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    // No originating message id; the vector is stored standalone.
                                    None,
                                    &embedding,
                                ) {
                                    Ok(emb_id) => return Some(emb_id),
                                    Err(err) => {
                                        warn!("Failed to persist embedding: {}", err);
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!("Failed to generate embedding: {}", err);
                    }
                }
            }
        }
        None
    }

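    /// Inspects the session's knowledge graph after a turn: locates the
    /// graph nodes for the current messages, gathers goal status, recent
    /// tool outcomes, and nearby entities/concepts, then asks
    /// `generate_action_recommendation` for a follow-up suggestion. Any
    /// recommendation is persisted as a `NextActionRecommendation` event
    /// node linked to the assistant message.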
    fn evaluate_graph_for_next_action(
        &self,
        user_message_id: i64,
        assistant_message_id: i64,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Result<Option<String>> {
        // Locate the graph nodes for the current user/assistant messages.
        let nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::Message),
            Some(50),
        )?;

        let mut assistant_node_id = None;
        let mut _user_node_id = None;

        for node in &nodes {
            if let Some(msg_id) = node.properties["message_id"].as_i64() {
                if msg_id == assistant_message_id {
                    assistant_node_id = Some(node.id);
                } else if msg_id == user_message_id {
                    _user_node_id = Some(node.id);
                }
            }
        }

        if assistant_node_id.is_none() {
            debug!("Assistant message node not found in graph");
            return Ok(None);
        }

        let assistant_node_id = assistant_node_id.unwrap();

        // Gather context within two hops of the assistant message.
        let neighbors = self.persistence.traverse_neighbors(
            &self.session_id,
            assistant_node_id,
            TraversalDirection::Both,
            2,
        )?;

        // Partition goals by status.
        let goal_nodes =
            self.persistence
                .list_graph_nodes(&self.session_id, Some(NodeType::Goal), Some(10))?;

        let mut pending_goals = Vec::new();
        let mut completed_goals = Vec::new();

        for goal in &goal_nodes {
            if let Some(status) = goal.properties["status"].as_str() {
                match status {
                    "pending" | "in_progress" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            pending_goals.push(goal_text.to_string());
                        }
                    }
                    "completed" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            completed_goals.push(goal_text.to_string());
                        }
                    }
                    _ => {}
                }
            }
        }

        // Collect recent tool outcomes.
        let tool_nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::ToolResult),
            Some(10),
        )?;

        let mut recent_tool_failures = Vec::new();
        let mut recent_tool_successes = Vec::new();

        for tool_node in &tool_nodes {
            if let Some(success) = tool_node.properties["success"].as_bool() {
                let tool_name = tool_node.properties["tool"].as_str().unwrap_or("unknown");
                if success {
                    recent_tool_successes.push(tool_name.to_string());
                } else {
                    recent_tool_failures.push(tool_name.to_string());
                }
            }
        }

        // Collect entities and concepts connected to the assistant message.
        let mut key_entities = HashSet::new();
        let mut key_concepts = HashSet::new();

        for neighbor in &neighbors {
            match neighbor.node_type {
                NodeType::Entity => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_entities.insert(name.to_string());
                    }
                }
                NodeType::Concept => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_concepts.insert(name.to_string());
                    }
                }
                _ => {}
            }
        }

        let recommendation = self.generate_action_recommendation(
            &pending_goals,
            &completed_goals,
            &recent_tool_failures,
            &recent_tool_successes,
            &key_entities,
            &key_concepts,
            response_content,
            tool_invocations,
        );

        // Persist the recommendation as an event node linked to the
        // assistant message.
        if let Some(ref rec) = recommendation {
            let properties = json!({
                "recommendation": rec,
                "user_message_id": user_message_id,
                "assistant_message_id": assistant_message_id,
                "pending_goals": pending_goals,
                "completed_goals": completed_goals,
                "tool_failures": recent_tool_failures,
                "tool_successes": recent_tool_successes,
                "key_entities": key_entities.into_iter().collect::<Vec<_>>(),
                "key_concepts": key_concepts.into_iter().collect::<Vec<_>>(),
                "timestamp": Utc::now().to_rfc3339(),
            });

            let rec_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Event,
                "NextActionRecommendation",
                &properties,
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                assistant_node_id,
                rec_node_id,
                EdgeType::Produces,
                Some("recommends"),
                None,
                0.8,
            )?;
        }

        Ok(recommendation)
    }

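    /// Derives a single next-action suggestion from graph-derived context
    /// using simple heuristics over pending goals, recent tool failures,
    /// concept hints, the response text, and the tool-call sequence.
    /// Suggestions are collected in priority order and only the first is
    /// returned; `None` means there is nothing actionable to add.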
    fn generate_action_recommendation(
        &self,
        pending_goals: &[String],
        completed_goals: &[String],
        recent_tool_failures: &[String],
        _recent_tool_successes: &[String],
        _key_entities: &HashSet<String>,
        key_concepts: &HashSet<String>,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Option<String> {
        let mut recommendations = Vec::new();

        // Pending goals take priority.
        if !pending_goals.is_empty() {
            let goals_str = pending_goals.join(", ");
            recommendations.push(format!("Continue working on pending goals: {}", goals_str));
        }

        // Suggest alternatives for tools that recently failed.
        if !recent_tool_failures.is_empty() {
            let unique_failures: HashSet<_> = recent_tool_failures.iter().collect();
            for tool in unique_failures {
                recommendations.push(format!(
                    "Consider alternative approach for failed tool: {}",
                    tool
                ));
            }
        }

        // Keyword heuristics over the response text.
        let response_lower = response_content.to_lowercase();
        if response_lower.contains("error") || response_lower.contains("failed") {
            recommendations.push("Investigate and resolve the reported error".to_string());
        }

        if response_lower.contains('?') || response_lower.contains("unclear") {
            recommendations.push("Clarify the uncertain aspects mentioned".to_string());
        }

        // Heuristics over the tool-call sequence.
        if tool_invocations.len() > 1 {
            let tool_sequence: Vec<_> = tool_invocations.iter().map(|t| t.name.as_str()).collect();

            if tool_sequence.contains(&"file_read") && !tool_sequence.contains(&"file_write") {
                recommendations
                    .push("Consider modifying the read files if changes are needed".to_string());
            }

            if tool_sequence.contains(&"search")
                && tool_invocations.last().is_some_and(|t| t.success)
            {
                recommendations
                    .push("Examine the search results for relevant information".to_string());
            }
        }

        if key_concepts.contains("Knowledge Graph") || key_concepts.contains("Graph Node") {
            recommendations
                .push("Consider visualizing or querying the graph structure".to_string());
        }

        if key_concepts.contains("Database") || key_concepts.contains("Query Processing") {
            recommendations.push("Verify data integrity and query performance".to_string());
        }

        // Connect completed work to what remains.
        if !completed_goals.is_empty() && !pending_goals.is_empty() {
            recommendations.push(format!(
                "Build on completed work ({} done) to address remaining goals ({} pending)",
                completed_goals.len(),
                pending_goals.len()
            ));
        }

        if recommendations.is_empty() {
            // Nothing actionable; report readiness only when the session looks healthy.
            if completed_goals.len() > pending_goals.len() && recent_tool_failures.is_empty() {
                Some(
                    "Current objectives appear satisfied. Ready for new tasks or refinements."
                        .to_string(),
                )
            } else {
                None
            }
        } else {
            // Recommendations were pushed in priority order; return the first.
            Some(recommendations[0].clone())
        }
    }

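    /// Emits a structured timing record for `stage` under the
    /// `agent_timing` tracing target, measured from `start`.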
    fn log_timing(&self, stage: &str, start: Instant) {
        let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
        let agent_label = self.agent_name.as_deref().unwrap_or("unnamed");
        info!(
            target: "agent_timing",
            "stage={} duration_ms={:.2} agent={} session_id={}",
            stage,
            duration_ms,
            agent_label,
            self.session_id
        );
    }
}

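/// Returns at most the first 80 characters of `content` (after trimming),
/// appending "..." when the text is truncated; e.g. a 200-character input
/// yields its first 80 characters followed by "...".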
fn preview_text(content: &str) -> String {
    const MAX_CHARS: usize = 80;
    let trimmed = content.trim();
    let mut preview = String::new();
    for (idx, ch) in trimmed.chars().enumerate() {
        if idx >= MAX_CHARS {
            preview.push_str("...");
            break;
        }
        preview.push(ch);
    }
    if preview.is_empty() {
        trimmed.to_string()
    } else {
        preview
    }
}

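/// Renders a short preview of a JSON value: strings go through
/// `preview_text`, `null` becomes "null", and other values are serialized
/// and truncated to roughly 80 characters.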
fn preview_json_value(value: &Value) -> String {
    match value {
        Value::String(text) => preview_text(text),
        Value::Null => "null".to_string(),
        _ => {
            let raw = value.to_string();
            // Truncate on character boundaries; slicing bytes could panic
            // in the middle of a multi-byte UTF-8 sequence.
            if raw.chars().count() > 80 {
                let head: String = raw.chars().take(77).collect();
                format!("{}...", head)
            } else {
                raw
            }
        }
    }
}

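/// Extracts `key` from a JSON payload, returning `Value::Null` when absent.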
fn payload_field(payload: &Value, key: &str) -> Value {
    payload.get(key).cloned().unwrap_or(Value::Null)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::AgentProfile;
    use crate::embeddings::{EmbeddingsClient, EmbeddingsService};
    use async_trait::async_trait;
    use tempfile::tempdir;

    fn create_test_agent(session_id: &str) -> (AgentCore, tempfile::TempDir) {
        create_test_agent_with_embeddings(session_id, None)
    }

    fn create_test_agent_with_embeddings(
        session_id: &str,
        embeddings_client: Option<EmbeddingsClient>,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
            ..Default::default()
        };

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                embeddings_client,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
                false,
            ),
            dir,
        )
    }

    fn create_fast_reasoning_agent(
        session_id: &str,
        fast_output: &str,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("fast.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: true,
            fast_model_provider: Some("mock".to_string()),
            fast_model_name: Some("mock-fast".to_string()),
            fast_model_temperature: 0.3,
            fast_model_tasks: vec!["entity_extraction".to_string()],
            escalation_threshold: 0.5,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
            ..Default::default()
        };

        profile.validate().unwrap();

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let fast_provider = Arc::new(MockProvider::new(fast_output.to_string()));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                None,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
                false,
            )
            .with_fast_provider(fast_provider),
            dir,
        )
    }

    #[derive(Clone)]
    struct KeywordEmbeddingsService;

    #[async_trait]
    impl EmbeddingsService for KeywordEmbeddingsService {
        async fn create_embeddings(
            &self,
            _model: &str,
            inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            Ok(inputs
                .into_iter()
                .map(|input| keyword_embedding(&input))
                .collect())
        }
    }

    // Two-dimensional embedding keyed on the presence of "alpha" and "beta".
    fn keyword_embedding(input: &str) -> Vec<f32> {
        let lower = input.to_ascii_lowercase();
        let alpha = if lower.contains("alpha") { 1.0 } else { 0.0 };
        let beta = if lower.contains("beta") { 1.0 } else { 0.0 };
        vec![alpha, beta]
    }

    fn test_embeddings_client() -> EmbeddingsClient {
        EmbeddingsClient::with_service(
            "test",
            Arc::new(KeywordEmbeddingsService) as Arc<dyn EmbeddingsService>,
        )
    }

    #[tokio::test]
    async fn test_agent_core_run_step() {
        let (mut agent, _dir) = create_test_agent("test-session-1");

        let output = agent.run_step("Hello, how are you?").await.unwrap();

        assert!(!output.response.is_empty());
        assert!(output.token_usage.is_some());
        assert_eq!(output.tool_invocations.len(), 0);
    }

    #[tokio::test]
    async fn fast_model_short_circuits_when_confident() {
        let (mut agent, _dir) = create_fast_reasoning_agent(
            "fast-confident",
            "Answer: Entities detected.\nConfidence: 0.9",
        );

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert!(output
            .finish_reason
            .unwrap_or_default()
            .contains("fast_model"));
        assert!(output.response.contains("Entities detected"));
    }

    #[tokio::test]
    async fn fast_model_only_hints_when_low_confidence() {
        let (mut agent, _dir) =
            create_fast_reasoning_agent("fast-hint", "Answer: Unsure.\nConfidence: 0.2");

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert_eq!(output.finish_reason.as_deref(), Some("stop"));
        assert_eq!(output.response, "This is a test response.");
    }

    #[tokio::test]
    async fn test_agent_core_conversation_history() {
        let (mut agent, _dir) = create_test_agent("test-session-2");

        agent.run_step("First message").await.unwrap();
        agent.run_step("Second message").await.unwrap();

        let history = agent.conversation_history();
        // Two turns: each run_step adds a user message and an assistant reply.
        assert_eq!(history.len(), 4);
        assert_eq!(history[0].role, MessageRole::User);
        assert_eq!(history[1].role, MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_agent_core_session_switch() {
        let (mut agent, _dir) = create_test_agent("session-1");

        agent.run_step("Message in session 1").await.unwrap();
        assert_eq!(agent.session_id(), "session-1");

        agent = agent.with_session("session-2".to_string());
        assert_eq!(agent.session_id(), "session-2");
        assert_eq!(agent.conversation_history().len(), 0);
    }

    #[tokio::test]
    async fn agent_session_avoids_sync_namespace() {
        let (mut agent, _dir) = create_test_agent(SYNC_GRAPH_NAMESPACE);

        assert_eq!(
            agent.session_id(),
            format!("{}-agent", SYNC_GRAPH_NAMESPACE)
        );

        agent = agent.with_session(SYNC_GRAPH_NAMESPACE.to_string());
        assert_eq!(
            agent.session_id(),
            format!("{}-agent", SYNC_GRAPH_NAMESPACE)
        );
    }

    #[tokio::test]
    async fn test_agent_core_build_prompt() {
        let (agent, _dir) = create_test_agent("test-session-3");

        let context = vec![
            Message {
                id: 1,
                session_id: "test-session-3".to_string(),
                role: MessageRole::User,
                content: "Previous question".to_string(),
                created_at: Utc::now(),
            },
            Message {
                id: 2,
                session_id: "test-session-3".to_string(),
                role: MessageRole::Assistant,
                content: "Previous answer".to_string(),
                created_at: Utc::now(),
            },
        ];

        let prompt = agent
            .build_prompt("Current question", &context)
            .await
            .unwrap();

        assert!(prompt.contains("You are a helpful assistant"));
        assert!(prompt.contains("Previous conversation"));
        assert!(prompt.contains("user: Previous question"));
        assert!(prompt.contains("assistant: Previous answer"));
        assert!(prompt.contains("user: Current question"));
    }

    #[tokio::test]
    async fn test_agent_core_persistence() {
        let (mut agent, _dir) = create_test_agent("persist-test");

        agent.run_step("Test message").await.unwrap();

        let messages = agent
            .persistence
            .list_messages("persist-test", 100)
            .unwrap();

        // One user message and one assistant reply are persisted.
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[0].content, "Test message");
    }

    #[tokio::test]
    async fn store_message_records_embeddings() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("embedding-store", Some(test_embeddings_client()));

        let message_id = agent
            .store_message(MessageRole::User, "Alpha detail")
            .await
            .unwrap();

        let query = vec![1.0f32, 0.0];
        let recalled = agent
            .persistence
            .recall_top_k("embedding-store", &query, 1)
            .unwrap();

        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].0.message_id, Some(message_id));
    }

    #[tokio::test]
    async fn recall_memories_appends_semantic_matches() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("semantic-recall", Some(test_embeddings_client()));

        agent
            .store_message(MessageRole::User, "Alpha question")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Alpha answer")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::User, "Beta prompt")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Beta reply")
            .await
            .unwrap();

        let recall = agent.recall_memories("alpha follow up").await.unwrap();
        assert!(matches!(
            recall.stats.as_ref().map(|s| &s.strategy),
            Some(MemoryRecallStrategy::Semantic { .. })
        ));
        assert_eq!(
            recall
                .stats
                .as_ref()
                .map(|s| s.matches.len())
                .unwrap_or_default(),
            2
        );

        let recalled = recall.messages;
        assert_eq!(recalled.len(), 4);
        assert_eq!(recalled[0].content, "Beta prompt");
        assert_eq!(recalled[1].content, "Beta reply");

        let tail: Vec<_> = recalled[2..].iter().map(|m| m.content.as_str()).collect();
        assert!(tail.contains(&"Alpha question"));
        assert!(tail.contains(&"Alpha answer"));
    }

    #[tokio::test]
    async fn test_agent_tool_permission_allowed() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let mut profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
            ..Default::default()
        };

        let provider = Arc::new(MockProvider::new("Test"));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());

        // Allow tool calls at the policy layer so the profile lists decide.
        let mut policy_engine = PolicyEngine::new();
        policy_engine.add_rule(crate::policy::PolicyRule {
            agent: "*".to_string(),
            action: "tool_call".to_string(),
            resource: "*".to_string(),
            effect: crate::policy::PolicyEffect::Allow,
        });
        let policy_engine = Arc::new(policy_engine);

        let agent = AgentCore::new(
            profile.clone(),
            provider.clone(),
            None,
            persistence.clone(),
            "test-session".to_string(),
            Some("policy-test".to_string()),
            tool_registry.clone(),
            policy_engine.clone(),
            false,
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);

        // Switch from an allow list to a deny list; everything not listed is allowed.
        profile.allowed_tools = None;
        profile.denied_tools = Some(vec!["calculator".to_string()]);

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence,
            "test-session-2".to_string(),
            Some("policy-test-2".to_string()),
            tool_registry,
            policy_engine,
            false,
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);
    }

    #[tokio::test]
    async fn test_agent_tool_execution_with_logging() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
            ..Default::default()
        };

        let provider = Arc::new(MockProvider::new("Test"));

        // Register a real tool so execution succeeds.
        let mut tool_registry = crate::tools::ToolRegistry::new();
        tool_registry.register(Arc::new(crate::tools::builtin::EchoTool::new()));

        let policy_engine = Arc::new(PolicyEngine::new());

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence.clone(),
            "tool-exec-test".to_string(),
            Some("tool-agent".to_string()),
            Arc::new(tool_registry),
            policy_engine,
            false,
        );

        let args = serde_json::json!({"message": "test message"});
        let result = agent
            .execute_tool("run-tool-test", "echo", &args)
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "test message");
    }

    #[tokio::test]
    async fn test_agent_tool_registry_access() {
        let (agent, _dir) = create_test_agent("registry-test");

        let registry = agent.tool_registry();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_goal_requires_tool_detection() {
        assert!(AgentCore::goal_requires_tool(
            "List the files in this directory"
        ));
        assert!(AgentCore::goal_requires_tool("Run the tests"));
        assert!(!AgentCore::goal_requires_tool("Explain recursion"));
        assert!(AgentCore::goal_requires_tool(
            "Tell me about the project in this directory"
        ));
    }

    #[test]
    fn test_infer_goal_tool_action_project_description() {
        let query = "Tell me about the project in this directory";
        let inferred = AgentCore::infer_goal_tool_action(query)
            .expect("Should infer a tool for project description");
        let (tool, args) = inferred;
        assert!(
            tool == "file_read" || tool == "search",
            "unexpected tool: {}",
            tool
        );
        if tool == "file_read" {
            assert!(args.get("path").is_some());
            assert!(args.get("max_bytes").is_some());
        } else {
            let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
            assert!(query.contains("Cargo.toml") || query.contains("package.json"));
        }
    }
3087}