use crate::agent::model::{GenerationConfig, ModelProvider};
pub use crate::agent::output::{
    AgentOutput, GraphDebugInfo, GraphDebugNode, MemoryRecallMatch, MemoryRecallStats,
    MemoryRecallStrategy, ToolInvocation,
};
use crate::config::agent::AgentProfile;
use crate::embeddings::EmbeddingsClient;
use crate::persistence::Persistence;
use crate::policy::{PolicyDecision, PolicyEngine};
use crate::spec::AgentSpec;
use crate::tools::{ToolRegistry, ToolResult};
use crate::types::{EdgeType, Message, MessageRole, NodeType, TraversalDirection};
use anyhow::{Context, Result};
use chrono::Utc;
use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

const DEFAULT_MAIN_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 0.9;
const DEFAULT_FAST_TEMPERATURE: f32 = 0.3;
const DEFAULT_ESCALATION_THRESHOLD: f32 = 0.6;

struct RecallResult {
    messages: Vec<Message>,
    stats: Option<MemoryRecallStats>,
}

struct ExtractedEntity {
    name: String,
    entity_type: String,
    confidence: f32,
}

struct ExtractedConcept {
    name: String,
    relevance: f32,
}

#[derive(Debug, Clone)]
struct GoalContext {
    message_id: i64,
    text: String,
    requires_tool: bool,
    satisfied: bool,
    node_id: Option<i64>,
}

impl GoalContext {
    fn new(message_id: i64, text: &str, requires_tool: bool, node_id: Option<i64>) -> Self {
        Self {
            message_id,
            text: text.to_string(),
            requires_tool,
            satisfied: !requires_tool,
            node_id,
        }
    }
}

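/// Core agent runtime: owns the model providers, persistence, tool registry,
/// and policy engine for one session, and drives each conversational step.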
pub struct AgentCore {
    profile: AgentProfile,
    provider: Arc<dyn ModelProvider>,
    fast_provider: Option<Arc<dyn ModelProvider>>,
    embeddings_client: Option<EmbeddingsClient>,
    persistence: Persistence,
    session_id: String,
    agent_name: Option<String>,
    conversation_history: Vec<Message>,
    tool_registry: Arc<ToolRegistry>,
    policy_engine: Arc<PolicyEngine>,
    tool_permission_cache: Arc<RwLock<HashMap<String, bool>>>,
}

impl AgentCore {
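    /// Creates an agent bound to the given session. Below is a minimal
    /// construction sketch; the `profile`, `provider`, `persistence`,
    /// `tool_registry`, `policy_engine`, and `fast_provider` values are
    /// assumed to come from the host application's own setup code.
    ///
    /// ```ignore
    /// let agent = AgentCore::new(
    ///     profile,
    ///     provider,            // Arc<dyn ModelProvider>
    ///     None,                // no embeddings client: semantic recall disabled
    ///     persistence,
    ///     "session-1".to_string(),
    ///     Some("helper".to_string()),
    ///     tool_registry,
    ///     policy_engine,
    /// )
    /// .with_fast_provider(fast_provider); // optional small model
    /// ```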
    pub fn new(
        profile: AgentProfile,
        provider: Arc<dyn ModelProvider>,
        embeddings_client: Option<EmbeddingsClient>,
        persistence: Persistence,
        session_id: String,
        agent_name: Option<String>,
        tool_registry: Arc<ToolRegistry>,
        policy_engine: Arc<PolicyEngine>,
    ) -> Self {
        Self {
            profile,
            provider,
            fast_provider: None,
            embeddings_client,
            persistence,
            session_id,
            agent_name,
            conversation_history: Vec::new(),
            tool_registry,
            policy_engine,
            tool_permission_cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    pub fn with_fast_provider(mut self, fast_provider: Arc<dyn ModelProvider>) -> Self {
        self.fast_provider = Some(fast_provider);
        self
    }

    pub fn with_session(mut self, session_id: String) -> Self {
        self.session_id = session_id;
        self.conversation_history.clear();
        self.tool_permission_cache = Arc::new(RwLock::new(HashMap::new()));
        self
    }

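    /// Runs one conversational step: recalls memories, builds the prompt,
    /// optionally satisfies the goal via an auto-executed tool or the fast
    /// model, then loops the main model (up to five iterations) handling
    /// tool calls before persisting the exchange and any graph-derived
    /// next-action recommendation.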
    pub async fn run_step(&mut self, input: &str) -> Result<AgentOutput> {
        let run_id = format!("run-{}", Utc::now().timestamp_micros());
        let total_timer = Instant::now();

        let recall_timer = Instant::now();
        let recall_result = self.recall_memories(input).await?;
        self.log_timing("run_step.recall_memories", recall_timer);
        let recalled_messages = recall_result.messages;
        let recall_stats = recall_result.stats;

        let prompt_timer = Instant::now();
        let mut prompt = self.build_prompt(input, &recalled_messages).await?;
        self.log_timing("run_step.build_prompt", prompt_timer);

        let store_user_timer = Instant::now();
        let user_message_id = self.store_message(MessageRole::User, input).await?;
        self.log_timing("run_step.store_user_message", store_user_timer);

        let mut goal_context =
            Some(self.create_goal_context(user_message_id, input, self.profile.enable_graph)?);

        let mut tool_invocations = Vec::new();
        let mut final_response = String::new();
        let mut token_usage = None;
        let mut finish_reason = None;
        let mut auto_response: Option<String> = None;
        let mut reasoning: Option<String> = None;
        let mut reasoning_summary: Option<String> = None;

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if let Some((tool_name, tool_args)) =
                    Self::infer_goal_tool_action(goal.text.as_str())
                {
                    if self.is_tool_allowed(&tool_name).await {
                        let tool_timer = Instant::now();
                        let tool_result = self.execute_tool(&run_id, &tool_name, &tool_args).await;
                        self.log_timing("run_step.tool_execution.auto", tool_timer);
                        match tool_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    &tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                if let Err(err) = self
                                    .record_goal_tool_result(goal, &tool_name, &tool_args, &result)
                                {
                                    warn!("Failed to record goal progress: {}", err);
                                }
                                if result.success {
                                    if let Err(err) =
                                        self.update_goal_status(goal, "completed", true, None)
                                    {
                                        warn!("Failed to update goal status: {}", err);
                                    } else {
                                        goal.satisfied = true;
                                    }
                                }
                                auto_response = Some(Self::format_auto_tool_response(
                                    &tool_name,
                                    invocation.output.as_deref(),
                                ));
                                tool_invocations.push(invocation);
                            }
                            Err(err) => {
                                warn!("Auto tool execution '{}' failed: {}", tool_name, err);
                            }
                        }
                    }
                }
            }
        }

        let skip_model = goal_context
            .as_ref()
            .map(|goal| goal.requires_tool && goal.satisfied && auto_response.is_some())
            .unwrap_or(false);

        let mut fast_model_final: Option<(String, f32)> = None;
        if !skip_model {
            if let Some(task_type) = self.detect_task_type(input) {
                let complexity = self.estimate_task_complexity(input);
                if self.should_use_fast_model(&task_type, complexity) {
                    let fast_timer = Instant::now();
                    let fast_result = self.fast_reasoning(&task_type, input).await;
                    self.log_timing("run_step.fast_reasoning_attempt", fast_timer);
                    match fast_result {
                        Ok((fast_text, confidence)) => {
                            if confidence >= self.escalation_threshold() {
                                fast_model_final = Some((fast_text, confidence));
                            } else {
                                prompt.push_str(&format!(
                                    "\n\nFAST_MODEL_HINT (task={} confidence={:.0}%):\n{}\n\nRefine this hint and produce a complete answer.",
                                    task_type,
                                    (confidence * 100.0).round(),
                                    fast_text
                                ));
                            }
                        }
                        Err(err) => {
                            warn!("Fast reasoning failed for task {}: {}", task_type, err);
                        }
                    }
                }
            }
        }

        if skip_model {
            final_response = auto_response.unwrap_or_else(|| "Task completed.".to_string());
            finish_reason = Some("auto_tool".to_string());
        } else if let Some((fast_text, confidence)) = fast_model_final {
            final_response = fast_text;
            finish_reason = Some(format!("fast_model ({:.0}%)", (confidence * 100.0).round()));
        } else {
            for _iteration in 0..5 {
                let generation_config = self.build_generation_config();
                let model_timer = Instant::now();
                let response_result = self.provider.generate(&prompt, &generation_config).await;
                self.log_timing("run_step.main_model_call", model_timer);
                let response = response_result.context("Failed to generate response from model")?;

                token_usage = response.usage;
                finish_reason = response.finish_reason.clone();
                final_response = response.content.clone();
                reasoning = response.reasoning.clone();

                if let Some(ref reasoning_text) = reasoning {
                    reasoning_summary = self.summarize_reasoning(reasoning_text).await;
                }

                let sdk_tool_calls = response.tool_calls.clone().unwrap_or_default();

                if sdk_tool_calls.is_empty() {
                    let is_complete = finish_reason.as_ref().is_some_and(|reason| {
                        let reason_lower = reason.to_lowercase();
                        reason_lower.contains("stop")
                            || reason_lower.contains("end_turn")
                            || reason_lower.contains("complete")
                            || reason_lower == "length"
                    });

                    let goal_needs_tool = goal_context
                        .as_ref()
                        .is_some_and(|g| g.requires_tool && !g.satisfied);

                    if is_complete && !goal_needs_tool {
                        debug!("Early termination: response complete with no tool calls needed");
                        break;
                    }
                }

                if !sdk_tool_calls.is_empty() {
                    for tool_call in sdk_tool_calls {
                        let tool_name = &tool_call.function_name;
                        let tool_args = &tool_call.arguments;

                        if !self.is_tool_allowed(tool_name).await {
                            warn!(
                                "Tool '{}' is not allowed by agent policy - prompting user",
                                tool_name
                            );

                            match self.prompt_for_tool_permission(tool_name).await {
                                Ok(true) => {
                                    info!("User granted permission for tool '{}'", tool_name);
                                }
                                Ok(false) => {
                                    let error_msg =
                                        format!("Tool '{}' was denied by user", tool_name);
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                                Err(e) => {
                                    let error_msg = format!(
                                        "Failed to get user permission for tool '{}': {}",
                                        tool_name, e
                                    );
                                    warn!("{}", error_msg);
                                    tool_invocations.push(ToolInvocation {
                                        name: tool_name.clone(),
                                        arguments: tool_args.clone(),
                                        success: false,
                                        output: None,
                                        error: Some(error_msg),
                                    });
                                    continue;
                                }
                            }
                        }

                        let tool_timer = Instant::now();
                        let exec_result = self.execute_tool(&run_id, tool_name, tool_args).await;
                        self.log_timing("run_step.tool_execution.sdk", tool_timer);
                        match exec_result {
                            Ok(result) => {
                                let invocation = ToolInvocation::from_result(
                                    tool_name,
                                    tool_args.clone(),
                                    &result,
                                );
                                let tool_output = invocation.output.clone().unwrap_or_default();
                                let was_success = invocation.success;
                                let error_message = invocation
                                    .error
                                    .clone()
                                    .unwrap_or_else(|| "Tool execution failed".to_string());
                                tool_invocations.push(invocation);

                                if let Some(goal) = goal_context.as_mut() {
                                    if let Err(err) = self.record_goal_tool_result(
                                        goal, tool_name, tool_args, &result,
                                    ) {
                                        warn!("Failed to record goal progress: {}", err);
                                    }
                                    if result.success && goal.requires_tool && !goal.satisfied {
                                        if let Err(err) =
                                            self.update_goal_status(goal, "in_progress", true, None)
                                        {
                                            warn!("Failed to update goal status: {}", err);
                                        }
                                    }
                                }

                                if was_success {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_RESULT from {}:\n{}\n\nBased on this result, please continue.",
                                        tool_name, tool_output
                                    ));
                                } else {
                                    prompt.push_str(&format!(
                                        "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                        error_message
                                    ));
                                }
                            }
                            Err(e) => {
                                let error_msg =
                                    format!("Error executing tool '{}': {}", tool_name, e);
                                warn!("{}", error_msg);
                                prompt.push_str(&format!(
                                    "\n\nTOOL_ERROR: {}\n\nPlease continue without this tool.",
                                    error_msg
                                ));
                                tool_invocations.push(ToolInvocation {
                                    name: tool_name.clone(),
                                    arguments: tool_args.clone(),
                                    success: false,
                                    output: None,
                                    error: Some(error_msg),
                                });
                            }
                        }
                    }

                    continue;
                }

                if let Some(goal) = goal_context.as_ref() {
                    if goal.requires_tool && !goal.satisfied {
                        prompt.push_str(
                            "\n\nGOAL_STATUS: pending. The user request requires executing an allowed tool. Please call an appropriate tool.",
                        );
                        continue;
                    }
                }

                break;
            }
        }

        let store_assistant_timer = Instant::now();
        let response_message_id = self
            .store_message_with_reasoning(
                MessageRole::Assistant,
                &final_response,
                reasoning.as_deref(),
            )
            .await?;
        self.log_timing("run_step.store_assistant_message", store_assistant_timer);

        if let Some(goal) = goal_context.as_mut() {
            if goal.requires_tool {
                if goal.satisfied {
                    if let Err(err) =
                        self.update_goal_status(goal, "completed", true, Some(response_message_id))
                    {
                        warn!("Failed to finalize goal status: {}", err);
                    }
                } else if let Err(err) =
                    self.update_goal_status(goal, "blocked", false, Some(response_message_id))
                {
                    warn!("Failed to record blocked goal status: {}", err);
                }
            } else if let Err(err) =
                self.update_goal_status(goal, "completed", true, Some(response_message_id))
            {
                warn!("Failed to finalize goal status: {}", err);
            }
        }

        self.conversation_history.push(Message {
            id: user_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::User,
            content: input.to_string(),
            created_at: Utc::now(),
        });

        self.conversation_history.push(Message {
            id: response_message_id,
            session_id: self.session_id.clone(),
            role: MessageRole::Assistant,
            content: final_response.clone(),
            created_at: Utc::now(),
        });

        let next_action_recommendation =
            if self.profile.enable_graph && self.conversation_history.len() >= 3 {
                let graph_timer = Instant::now();
                let recommendation = self.evaluate_graph_for_next_action(
                    user_message_id,
                    response_message_id,
                    &final_response,
                    &tool_invocations,
                )?;
                self.log_timing("run_step.evaluate_graph_for_next_action", graph_timer);
                recommendation
            } else {
                None
            };

        if let Some(ref recommendation) = next_action_recommendation {
            tracing::debug!("Knowledge graph recommends next action: {}", recommendation);
            let system_content = format!("Graph recommendation: {}", recommendation);
            let system_store_timer = Instant::now();
            let system_message_id = self
                .store_message(MessageRole::System, &system_content)
                .await?;
            self.log_timing("run_step.store_system_message", system_store_timer);

            self.conversation_history.push(Message {
                id: system_message_id,
                session_id: self.session_id.clone(),
                role: MessageRole::System,
                content: system_content,
                created_at: Utc::now(),
            });
        }

        let graph_debug = match self.snapshot_graph_debug_info() {
            Ok(info) => Some(info),
            Err(err) => {
                warn!("Failed to capture graph debug info: {}", err);
                None
            }
        };

        self.log_timing("run_step.total", total_timer);

        Ok(AgentOutput {
            response: final_response,
            response_message_id: Some(response_message_id),
            token_usage,
            tool_invocations,
            finish_reason,
            recall_stats,
            run_id,
            next_action: next_action_recommendation,
            reasoning,
            reasoning_summary,
            graph_debug,
        })
    }

    pub async fn run_spec(&mut self, spec: &AgentSpec) -> Result<AgentOutput> {
        debug!(
            "Executing structured spec '{}' (source: {:?})",
            spec.display_name(),
            spec.source_path()
        );
        let prompt = spec.to_prompt();
        self.run_step(&prompt).await
    }

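    /// Builds the main-model generation config, clamping `temperature` to
    /// [0.0, 2.0] and `top_p` to [0.0, 1.0]; non-finite profile values fall
    /// back to the module defaults with a warning.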
    fn build_generation_config(&self) -> GenerationConfig {
        let temperature = match self.profile.temperature {
            Some(temp) if temp.is_finite() => Some(temp.clamp(0.0, 2.0)),
            Some(_) => {
                warn!(
                    "Ignoring invalid temperature for agent {:?}, falling back to {}",
                    self.agent_name, DEFAULT_MAIN_TEMPERATURE
                );
                Some(DEFAULT_MAIN_TEMPERATURE)
            }
            None => None,
        };

        let top_p = if self.profile.top_p.is_finite() {
            Some(self.profile.top_p.clamp(0.0, 1.0))
        } else {
            warn!(
                "Invalid top_p detected for agent {:?}, falling back to {}",
                self.agent_name, DEFAULT_TOP_P
            );
            Some(DEFAULT_TOP_P)
        };

        GenerationConfig {
            temperature,
            max_tokens: self.profile.max_context_tokens.map(|t| t as u32),
            stop_sequences: None,
            top_p,
            frequency_penalty: None,
            presence_penalty: None,
        }
    }

    fn snapshot_graph_debug_info(&self) -> Result<GraphDebugInfo> {
        let mut info = GraphDebugInfo {
            enabled: self.profile.enable_graph,
            graph_memory_enabled: self.profile.graph_memory,
            auto_graph_enabled: self.profile.auto_graph,
            graph_steering_enabled: self.profile.graph_steering,
            node_count: 0,
            edge_count: 0,
            recent_nodes: Vec::new(),
        };

        if !self.profile.enable_graph {
            return Ok(info);
        }

        info.node_count = self.persistence.count_graph_nodes(&self.session_id)?.max(0) as usize;
        info.edge_count = self.persistence.count_graph_edges(&self.session_id)?.max(0) as usize;

        let recent_nodes = self
            .persistence
            .list_graph_nodes(&self.session_id, None, Some(5))?;
        info.recent_nodes = recent_nodes
            .into_iter()
            .map(|node| GraphDebugNode {
                id: node.id,
                node_type: node.node_type.as_str().to_string(),
                label: node.label,
            })
            .collect();

        Ok(info)
    }

    async fn summarize_reasoning(&self, reasoning: &str) -> Option<String> {
        let fast_provider = self.fast_provider.as_ref()?;

        if reasoning.len() < 50 {
            return Some(reasoning.to_string());
        }

        let summary_prompt = format!(
            "Summarize the following reasoning in 1-2 concise sentences that explain the thought process:\n\n{}\n\nSummary:",
            reasoning
        );

        let config = GenerationConfig {
            temperature: Some(0.3),
            max_tokens: Some(100),
            stop_sequences: None,
            top_p: Some(0.9),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let timer = Instant::now();
        let response = fast_provider.generate(&summary_prompt, &config).await;
        self.log_timing("summarize_reasoning.generate", timer);
        match response {
            Ok(response) => {
                let summary = response.content.trim().to_string();
                if !summary.is_empty() {
                    debug!("Generated reasoning summary: {}", summary);
                    Some(summary)
                } else {
                    None
                }
            }
            Err(e) => {
                warn!("Failed to summarize reasoning: {}", e);
                None
            }
        }
    }

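    /// Recalls context for `query` in tiers: the two most recent messages,
    /// graph-linked neighbors of those messages (when graph memory is on),
    /// then top-k semantic matches via embeddings, with graph/semantic slots
    /// split by `graph_weight`. Without an embeddings client it falls back
    /// to the `memory_k` most recent messages.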
    async fn recall_memories(&self, query: &str) -> Result<RecallResult> {
        const RECENT_CONTEXT: i64 = 2;
        let mut context = Vec::new();
        let mut seen_ids = HashSet::new();

        let recent_messages = self
            .persistence
            .list_messages(&self.session_id, RECENT_CONTEXT)?;

        if self.conversation_history.is_empty() && recent_messages.is_empty() {
            return Ok(RecallResult {
                messages: Vec::new(),
                stats: Some(MemoryRecallStats {
                    strategy: MemoryRecallStrategy::RecentContext {
                        limit: RECENT_CONTEXT as usize,
                    },
                    matches: Vec::new(),
                }),
            });
        }

        for message in recent_messages {
            seen_ids.insert(message.id);
            context.push(message);
        }

        if self.profile.enable_graph && self.profile.graph_memory {
            let mut graph_messages = Vec::new();

            for msg in &context {
                let nodes = self.persistence.list_graph_nodes(
                    &self.session_id,
                    Some(NodeType::Message),
                    Some(10),
                )?;

                for node in nodes {
                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                        if msg_id == msg.id {
                            let neighbors = self.persistence.traverse_neighbors(
                                &self.session_id,
                                node.id,
                                TraversalDirection::Both,
                                self.profile.graph_depth,
                            )?;

                            for neighbor in neighbors {
                                if neighbor.node_type == NodeType::Message {
                                    if let Some(related_msg_id) =
                                        neighbor.properties["message_id"].as_i64()
                                    {
                                        if !seen_ids.contains(&related_msg_id) {
                                            if let Some(related_msg) =
                                                self.persistence.get_message(related_msg_id)?
                                            {
                                                seen_ids.insert(related_msg.id);
                                                graph_messages.push(related_msg);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }

            context.extend(graph_messages);
        }

        if let Some(client) = &self.embeddings_client {
            if self.profile.memory_k == 0 || query.trim().is_empty() {
                return Ok(RecallResult {
                    messages: context,
                    stats: None,
                });
            }

            let embed_timer = Instant::now();
            let embed_result = client.embed_batch(&[query]).await;
            self.log_timing("recall_memories.embed_batch", embed_timer);
            match embed_result {
                Ok(mut embeddings) => match embeddings.pop() {
                    Some(query_embedding) if !query_embedding.is_empty() => {
                        let recalled = self.persistence.recall_top_k(
                            &self.session_id,
                            &query_embedding,
                            self.profile.memory_k,
                        )?;

                        let mut matches = Vec::new();
                        let mut semantic_context = Vec::new();

                        for (memory, score) in recalled {
                            if let Some(message_id) = memory.message_id {
                                if seen_ids.contains(&message_id) {
                                    continue;
                                }

                                if let Some(message) = self.persistence.get_message(message_id)? {
                                    seen_ids.insert(message.id);
                                    matches.push(MemoryRecallMatch {
                                        message_id: Some(message.id),
                                        score,
                                        role: message.role.clone(),
                                        preview: preview_text(&message.content),
                                    });
                                    semantic_context.push(message);
                                }
                            } else if let Some(transcription_text) =
                                self.persistence.get_transcription_by_embedding(memory.id)?
                            {
                                let transcription_message = Message {
                                    id: memory.id,
                                    session_id: memory.session_id.clone(),
                                    role: MessageRole::User,
                                    content: format!("[Transcription] {}", transcription_text),
                                    created_at: memory.created_at,
                                };

                                matches.push(MemoryRecallMatch {
                                    message_id: None,
                                    score,
                                    role: MessageRole::User,
                                    preview: preview_text(&transcription_text),
                                });
                                semantic_context.push(transcription_message);
                            }
                        }

                        if self.profile.enable_graph && self.profile.graph_memory {
                            let mut graph_expanded = Vec::new();

                            for msg in &semantic_context {
                                let nodes = self.persistence.list_graph_nodes(
                                    &self.session_id,
                                    Some(NodeType::Message),
                                    Some(100),
                                )?;

                                for node in nodes {
                                    if let Some(msg_id) = node.properties["message_id"].as_i64() {
                                        if msg_id == msg.id {
                                            let neighbors = self.persistence.traverse_neighbors(
                                                &self.session_id,
                                                node.id,
                                                TraversalDirection::Both,
                                                self.profile.graph_depth,
                                            )?;

                                            for neighbor in neighbors {
                                                if matches!(
                                                    neighbor.node_type,
                                                    NodeType::Fact
                                                        | NodeType::Concept
                                                        | NodeType::Entity
                                                ) {
                                                    let graph_content = format!(
                                                        "[Graph Context - {} {}]: {}",
                                                        neighbor.node_type.as_str(),
                                                        neighbor.label,
                                                        neighbor.properties
                                                    );

                                                    let graph_msg = Message {
                                                        id: -1,
                                                        session_id: self.session_id.clone(),
                                                        role: MessageRole::System,
                                                        content: graph_content,
                                                        created_at: Utc::now(),
                                                    };

                                                    graph_expanded.push(graph_msg);
                                                }
                                            }
                                        }
                                    }
                                }
                            }

                            let total_slots = self.profile.memory_k.max(1);
                            let mut graph_limit =
                                ((total_slots as f32) * self.profile.graph_weight).round() as usize;
                            graph_limit = graph_limit.min(total_slots);
                            if graph_limit == 0 && !graph_expanded.is_empty() {
                                graph_limit = 1;
                            }

                            let mut semantic_limit = total_slots.saturating_sub(graph_limit);
                            if semantic_limit == 0 && !semantic_context.is_empty() {
                                semantic_limit = 1;
                                graph_limit = graph_limit.saturating_sub(1);
                            }

                            let mut limited_semantic = semantic_context;
                            if limited_semantic.len() > semantic_limit && semantic_limit > 0 {
                                limited_semantic.truncate(semantic_limit);
                            }

                            let mut limited_graph = graph_expanded;
                            if limited_graph.len() > graph_limit && graph_limit > 0 {
                                limited_graph.truncate(graph_limit);
                            }

                            context.extend(limited_semantic);
                            context.extend(limited_graph);
                        } else {
                            context.extend(semantic_context);
                        }

                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: matches.len(),
                                },
                                matches,
                            }),
                        });
                    }
                    _ => {
                        return Ok(RecallResult {
                            messages: context,
                            stats: Some(MemoryRecallStats {
                                strategy: MemoryRecallStrategy::Semantic {
                                    requested: self.profile.memory_k,
                                    returned: 0,
                                },
                                matches: Vec::new(),
                            }),
                        });
                    }
                },
                Err(err) => {
                    warn!("Failed to embed recall query: {}", err);
                    return Ok(RecallResult {
                        messages: context,
                        stats: None,
                    });
                }
            }
        }

        let limit = self.profile.memory_k as i64;
        let messages = self.persistence.list_messages(&self.session_id, limit)?;

        let stats = if self.profile.memory_k > 0 {
            Some(MemoryRecallStats {
                strategy: MemoryRecallStrategy::RecentContext {
                    limit: self.profile.memory_k,
                },
                matches: Vec::new(),
            })
        } else {
            None
        };

        Ok(RecallResult { messages, stats })
    }

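    /// Assembles the flat text prompt: optional system prompt, the list of
    /// currently allowed tools with descriptions, recalled context messages,
    /// and finally the user input with an `assistant:` completion cue.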
    async fn build_prompt(&self, input: &str, context_messages: &[Message]) -> Result<String> {
        let mut prompt = String::new();

        if let Some(system_prompt) = &self.profile.prompt {
            prompt.push_str("System: ");
            prompt.push_str(system_prompt);
            prompt.push_str("\n\n");
        }

        let available_tools = self.tool_registry.list();
        tracing::debug!("Tool registry has {} tools", available_tools.len());
        if !available_tools.is_empty() {
            prompt.push_str("Available tools:\n");
            for tool_name in &available_tools {
                info!(
                    "Checking tool: {} - allowed: {}",
                    tool_name,
                    self.is_tool_allowed(tool_name).await
                );
                if self.is_tool_allowed(tool_name).await {
                    if let Some(tool) = self.tool_registry.get(tool_name) {
                        prompt.push_str(&format!("- {}: {}\n", tool_name, tool.description()));
                    }
                }
            }
            prompt.push('\n');
        }

        if !context_messages.is_empty() {
            prompt.push_str("Previous conversation:\n");
            for msg in context_messages {
                prompt.push_str(&format!("{}: {}\n", msg.role.as_str(), msg.content));
            }
            prompt.push('\n');
        }

        prompt.push_str(&format!("user: {}\n", input));
        prompt.push_str("assistant:");

        Ok(prompt)
    }

    async fn store_message(&self, role: MessageRole, content: &str) -> Result<i64> {
        self.store_message_with_reasoning(role, content, None).await
    }

    async fn store_message_with_reasoning(
        &self,
        role: MessageRole,
        content: &str,
        reasoning: Option<&str>,
    ) -> Result<i64> {
        let message_id = self
            .persistence
            .insert_message(&self.session_id, role.clone(), content)
            .context("Failed to store message")?;

        let mut embedding_id = None;

        if let Some(client) = &self.embeddings_client {
            if !content.trim().is_empty() {
                let embed_timer = Instant::now();
                let embed_result = client.embed_batch(&[content]).await;
                self.log_timing("embeddings.message_content", embed_timer);
                match embed_result {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    Some(message_id),
                                    &embedding,
                                ) {
                                    Ok(emb_id) => {
                                        embedding_id = Some(emb_id);
                                    }
                                    Err(err) => {
                                        warn!(
                                            "Failed to persist embedding for message {}: {}",
                                            message_id, err
                                        );
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!(
                            "Failed to create embedding for message {}: {}",
                            message_id, err
                        );
                    }
                }
            }
        }

        if self.profile.enable_graph && self.profile.auto_graph {
            self.build_graph_for_message(message_id, role, content, embedding_id, reasoning)?;
        }

        Ok(message_id)
    }

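    /// Inserts a Message node for the stored message, links extracted
    /// entities (`Mentions`) and concepts (`RelatesTo`) to it, and adds a
    /// `FollowsFrom` edge from the preceding message to preserve
    /// conversation order. Entities and concepts found only in the model's
    /// reasoning are kept at reduced confidence (scaled by 0.8).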
    fn build_graph_for_message(
        &self,
        message_id: i64,
        role: MessageRole,
        content: &str,
        embedding_id: Option<i64>,
        reasoning: Option<&str>,
    ) -> Result<()> {
        use serde_json::json;

        let mut message_props = json!({
            "message_id": message_id,
            "role": role.as_str(),
            "content_preview": preview_text(content),
            "timestamp": Utc::now().to_rfc3339(),
        });

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                message_props["has_reasoning"] = json!(true);
                message_props["reasoning_preview"] = json!(preview_text(reasoning_text));
            }
        }

        let message_node_id = self.persistence.insert_graph_node(
            &self.session_id,
            NodeType::Message,
            &format!("{:?}Message", role),
            &message_props,
            embedding_id,
        )?;

        let mut entities = self.extract_entities_from_text(content);
        let mut concepts = self.extract_concepts_from_text(content);

        if let Some(reasoning_text) = reasoning {
            if !reasoning_text.is_empty() {
                debug!(
                    "Extracting entities/concepts from reasoning for message {}",
                    message_id
                );
                let reasoning_entities = self.extract_entities_from_text(reasoning_text);
                let reasoning_concepts = self.extract_concepts_from_text(reasoning_text);

                for mut reasoning_entity in reasoning_entities {
                    if let Some(existing) = entities.iter_mut().find(|e| {
                        e.name.to_lowercase() == reasoning_entity.name.to_lowercase()
                            && e.entity_type == reasoning_entity.entity_type
                    }) {
                        existing.confidence =
                            (existing.confidence + reasoning_entity.confidence * 0.5).min(1.0);
                    } else {
                        reasoning_entity.confidence *= 0.8;
                        entities.push(reasoning_entity);
                    }
                }

                for mut reasoning_concept in reasoning_concepts {
                    if let Some(existing) = concepts
                        .iter_mut()
                        .find(|c| c.name.to_lowercase() == reasoning_concept.name.to_lowercase())
                    {
                        existing.relevance =
                            (existing.relevance + reasoning_concept.relevance * 0.5).min(1.0);
                    } else {
                        reasoning_concept.relevance *= 0.8;
                        concepts.push(reasoning_concept);
                    }
                }
            }
        }

        for entity in entities {
            let entity_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Entity,
                &entity.entity_type,
                &json!({
                    "name": entity.name,
                    "type": entity.entity_type,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                entity_node_id,
                EdgeType::Mentions,
                Some("mentions"),
                Some(&json!({"confidence": entity.confidence})),
                entity.confidence,
            )?;
        }

        for concept in concepts {
            let concept_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Concept,
                "Concept",
                &json!({
                    "name": concept.name,
                    "extracted_from": message_id,
                }),
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                message_node_id,
                concept_node_id,
                EdgeType::RelatesTo,
                Some("discusses"),
                Some(&json!({"relevance": concept.relevance})),
                concept.relevance,
            )?;
        }

        let recent_messages = self.persistence.list_messages(&self.session_id, 2)?;
        if recent_messages.len() > 1 {
            let nodes = self.persistence.list_graph_nodes(
                &self.session_id,
                Some(NodeType::Message),
                Some(10),
            )?;

            for node in nodes {
                if let Some(prev_msg_id) = node.properties["message_id"].as_i64() {
                    if prev_msg_id != message_id && prev_msg_id == recent_messages[0].id {
                        self.persistence.insert_graph_edge(
                            &self.session_id,
                            node.id,
                            message_node_id,
                            EdgeType::FollowsFrom,
                            Some("conversation_flow"),
                            None,
                            1.0,
                        )?;
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    fn create_goal_context(
        &self,
        message_id: i64,
        input: &str,
        persist: bool,
    ) -> Result<GoalContext> {
        let requires_tool = Self::goal_requires_tool(input);
        let node_id = if self.profile.enable_graph {
            if persist {
                let properties = json!({
                    "message_id": message_id,
                    "goal_text": input,
                    "status": "pending",
                    "requires_tool": requires_tool,
                    "satisfied": false,
                    "created_at": Utc::now().to_rfc3339(),
                });
                Some(self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Goal,
                    "Goal",
                    &properties,
                    None,
                )?)
            } else {
                None
            }
        } else {
            None
        };

        Ok(GoalContext::new(message_id, input, requires_tool, node_id))
    }

    fn update_goal_status(
        &self,
        goal: &mut GoalContext,
        status: &str,
        satisfied: bool,
        response_message_id: Option<i64>,
    ) -> Result<()> {
        goal.satisfied = satisfied;
        if let Some(node_id) = goal.node_id {
            let properties = json!({
                "message_id": goal.message_id,
                "goal_text": goal.text,
                "status": status,
                "requires_tool": goal.requires_tool,
                "satisfied": satisfied,
                "response_message_id": response_message_id,
                "updated_at": Utc::now().to_rfc3339(),
            });
            self.persistence.update_graph_node(node_id, &properties)?;
        }
        Ok(())
    }

    fn record_goal_tool_result(
        &self,
        goal: &GoalContext,
        tool_name: &str,
        args: &Value,
        result: &ToolResult,
    ) -> Result<()> {
        if let Some(goal_node_id) = goal.node_id {
            let timestamp = Utc::now().to_rfc3339();
            let mut properties = json!({
                "tool": tool_name,
                "arguments": args,
                "success": result.success,
                "output_preview": preview_text(&result.output),
                "error": result.error,
                "timestamp": timestamp,
            });

            let mut prompt_payload: Option<Value> = None;
            if tool_name == "prompt_user" && result.success {
                match serde_json::from_str::<Value>(&result.output) {
                    Ok(payload) => {
                        if let Some(props) = properties.as_object_mut() {
                            props.insert("prompt_user_payload".to_string(), payload.clone());
                            if let Some(response) = payload.get("response") {
                                props.insert(
                                    "response_preview".to_string(),
                                    Value::String(preview_json_value(response)),
                                );
                            }
                        }
                        prompt_payload = Some(payload);
                    }
                    Err(err) => {
                        warn!("Failed to parse prompt_user payload for graph: {}", err);
                        if let Some(props) = properties.as_object_mut() {
                            props.insert(
                                "prompt_user_parse_error".to_string(),
                                Value::String(err.to_string()),
                            );
                        }
                    }
                }
            }

            let tool_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::ToolResult,
                tool_name,
                &properties,
                None,
            )?;
            self.persistence.insert_graph_edge(
                &self.session_id,
                tool_node_id,
                goal_node_id,
                EdgeType::Produces,
                Some("satisfies"),
                None,
                if result.success { 1.0 } else { 0.1 },
            )?;

            if let Some(payload) = prompt_payload {
                let response_preview = payload
                    .get("response")
                    .map(preview_json_value)
                    .unwrap_or_default();

                let response_properties = json!({
                    "prompt": payload_field(&payload, "prompt"),
                    "input_type": payload_field(&payload, "input_type"),
                    "response": payload_field(&payload, "response"),
                    "display_value": payload_field(&payload, "display_value"),
                    "selections": payload_field(&payload, "selections"),
                    "metadata": payload_field(&payload, "metadata"),
                    "used_default": payload_field(&payload, "used_default"),
                    "used_prefill": payload_field(&payload, "used_prefill"),
                    "response_preview": response_preview,
                    "timestamp": timestamp,
                });

                let response_node_id = self.persistence.insert_graph_node(
                    &self.session_id,
                    NodeType::Event,
                    "UserInput",
                    &response_properties,
                    None,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    tool_node_id,
                    response_node_id,
                    EdgeType::Produces,
                    Some("captures_input"),
                    None,
                    1.0,
                )?;

                self.persistence.insert_graph_edge(
                    &self.session_id,
                    response_node_id,
                    goal_node_id,
                    EdgeType::RelatesTo,
                    Some("addresses_goal"),
                    None,
                    0.9,
                )?;
            }
        }
        Ok(())
    }

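    /// Heuristically decides whether `input` needs a tool call: true when it
    /// contains one of a fixed set of action verbs (list, read, run, ...) or
    /// asks a question about the local project/repository.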
    fn goal_requires_tool(input: &str) -> bool {
        let normalized = input.to_lowercase();
        const ACTION_VERBS: [&str; 18] = [
            "list", "show", "read", "write", "create", "update", "delete", "run", "execute",
            "open", "search", "fetch", "download", "scan", "compile", "test", "build", "inspect",
        ];

        if ACTION_VERBS
            .iter()
            .any(|verb| normalized.contains(verb) && normalized.contains(' '))
        {
            return true;
        }

        let mentions_local_context = normalized.contains("this directory")
            || normalized.contains("current directory")
            || normalized.contains("this folder")
            || normalized.contains("here");

        let mentions_project = normalized.contains("this project")
            || normalized.contains("this repo")
            || normalized.contains("this repository")
            || normalized.contains("this codebase")
            || ((normalized.contains("project")
                || normalized.contains("repo")
                || normalized.contains("repository")
                || normalized.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = normalized.contains("what can")
            || normalized.contains("what is")
            || normalized.contains("what does")
            || normalized.contains("tell me")
            || normalized.contains("describe")
            || normalized.contains("about");

        mentions_project && asks_about_project
    }

    fn escalation_threshold(&self) -> f32 {
        if self.profile.escalation_threshold.is_finite() {
            self.profile.escalation_threshold.clamp(0.0, 1.0)
        } else {
            warn!(
                "Invalid escalation_threshold for agent {:?}, defaulting to {}",
                self.agent_name, DEFAULT_ESCALATION_THRESHOLD
            );
            DEFAULT_ESCALATION_THRESHOLD
        }
    }

    fn detect_task_type(&self, input: &str) -> Option<String> {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return None;
        }

        let text = input.to_lowercase();

        let candidates: [(&str, &[&str]); 6] = [
            ("entity_extraction", &["entity", "extract", "named"]),
            ("decision_routing", &["classify", "categorize", "route"]),
            (
                "tool_selection",
                &["which tool", "use which tool", "tool should"],
            ),
            ("confidence_scoring", &["confidence", "certainty"]),
            ("summarization", &["summarize", "summary"]),
            ("graph_analysis", &["graph", "connection", "relationships"]),
        ];

        for (task, keywords) in candidates {
            if keywords.iter().any(|kw| text.contains(kw))
                && self.profile.fast_model_tasks.iter().any(|t| t == task)
            {
                return Some(task.to_string());
            }
        }

        None
    }

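    /// Scores input complexity in [0.0, 1.0] as a weighted blend:
    /// 0.6 * (words / 120) + 0.3 * (" and "/" then " clauses / 4)
    /// + 0.1 * (newlines / 5), each factor capped at 1.0. For example, a
    /// 60-word, two-clause, single-line request scores
    /// 0.6 * 0.5 + 0.3 * 0.5 + 0.1 * 0.0 = 0.45.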
    fn estimate_task_complexity(&self, input: &str) -> f32 {
        let words = input.split_whitespace().count() as f32;
        let clauses =
            input.matches(" and ").count() as f32 + input.matches(" then ").count() as f32;
        let newlines = input.matches('\n').count() as f32;

        let length_factor = (words / 120.0).min(1.0);
        let clause_factor = (clauses / 4.0).min(1.0);
        let structure_factor = (newlines / 5.0).min(1.0);

        (0.6 * length_factor + 0.3 * clause_factor + 0.1 * structure_factor).clamp(0.0, 1.0)
    }

    fn infer_goal_tool_action(goal_text: &str) -> Option<(String, Value)> {
        let text = goal_text.to_lowercase();

        let mentions_local_context = text.contains("this directory")
            || text.contains("current directory")
            || text.contains("this folder")
            || text.contains("here");

        let mentions_project = text.contains("this project")
            || text.contains("this repo")
            || text.contains("this repository")
            || text.contains("this codebase")
            || ((text.contains("project")
                || text.contains("repo")
                || text.contains("repository")
                || text.contains("codebase"))
                && mentions_local_context);

        let asks_about_project = text.contains("what can")
            || text.contains("what is")
            || text.contains("what does")
            || text.contains("tell me")
            || text.contains("describe")
            || text.contains("about");

        if mentions_project && asks_about_project {
            for candidate in &["README.md", "Readme.md", "readme.md"] {
                if Path::new(candidate).exists() {
                    return Some((
                        "file_read".to_string(),
                        json!({
                            "path": candidate,
                            "max_bytes": 65536
                        }),
                    ));
                }
            }

            return Some((
                "search".to_string(),
                json!({
                    "query": "Cargo.toml|package.json|pyproject.toml|setup.py",
                    "regex": true,
                    "case_sensitive": false,
                    "max_results": 20
                }),
            ));
        }

        if text.contains("list")
            && (text.contains("directory") || text.contains("files") || text.contains("folder"))
        {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        if text.contains("show") && text.contains("current directory") {
            return Some((
                "shell".to_string(),
                json!({
                    "command": "ls -a"
                }),
            ));
        }

        None
    }

    fn parse_confidence(text: &str) -> Option<f32> {
        for line in text.lines() {
            let lower = line.to_lowercase();
            if lower.contains("confidence") {
                let token = lower
                    .split(|c: char| !(c.is_ascii_digit() || c == '.'))
                    .find(|chunk| !chunk.is_empty())?;
                if let Ok(value) = token.parse::<f32>() {
                    if (0.0..=1.0).contains(&value) {
                        return Some(value);
                    }
                }
            }
        }
        None
    }

    fn strip_fast_answer(text: &str) -> String {
        let mut answer = String::new();
        for line in text.lines() {
            if line.to_lowercase().starts_with("answer:") {
                answer.push_str(line.split_once(':').map(|x| x.1).unwrap_or("").trim());
                break;
            }
        }
        if answer.is_empty() {
            text.trim().to_string()
        } else {
            answer
        }
    }

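    /// Renders a user-facing reply from raw tool output. `file_read`,
    /// `search`, and `shell`/`bash` payloads get dedicated formatting
    /// (truncation, result lists, stdout/stderr sections); anything else is
    /// passed through trimmed.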
    fn format_auto_tool_response(tool_name: &str, output: Option<&str>) -> String {
        let sanitized = output.unwrap_or("").trim();
        if sanitized.is_empty() {
            return format!("Executed `{}` successfully.", tool_name);
        }

        if tool_name == "file_read" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let path = value.get("path").and_then(|v| v.as_str()).unwrap_or("file");
                let content = value.get("content").and_then(|v| v.as_str()).unwrap_or("");

                let max_chars = 4000;
                let display_content = if content.len() > max_chars {
                    let mut snippet = content[..max_chars].to_string();
                    snippet.push_str("\n...\n[truncated]");
                    snippet
                } else {
                    content.to_string()
                };

                return format!("Contents of {}:\n{}", path, display_content);
            }
        }

        if tool_name == "search" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let query = value.get("query").and_then(|v| v.as_str()).unwrap_or("");

                if let Some(results) = value.get("results").and_then(|v| v.as_array()) {
                    if results.is_empty() {
                        return if query.is_empty() {
                            "Search returned no results.".to_string()
                        } else {
                            format!("Search for {:?} returned no results.", query)
                        };
                    }

                    let mut lines = Vec::new();
                    if query.is_empty() {
                        lines.push("Search results:".to_string());
                    } else {
                        lines.push(format!("Search results for {:?}:", query));
                    }

                    for entry in results.iter().take(5) {
                        let path = entry
                            .get("path")
                            .and_then(|v| v.as_str())
                            .unwrap_or("<unknown>");
                        let line = entry.get("line").and_then(|v| v.as_u64()).unwrap_or(0);
                        let snippet = entry
                            .get("snippet")
                            .and_then(|v| v.as_str())
                            .unwrap_or("")
                            .replace('\n', " ");

                        lines.push(format!("- {}:{} - {}", path, line, snippet));
                    }

                    return lines.join("\n");
                }
            }
        }

        if tool_name == "shell" || tool_name == "bash" {
            if let Ok(value) = serde_json::from_str::<Value>(sanitized) {
                let std_out = value
                    .get("stdout")
                    .and_then(|v| v.as_str())
                    .unwrap_or(sanitized);
                let command = value.get("command").and_then(|v| v.as_str()).unwrap_or("");
                let stderr = value
                    .get("stderr")
                    .and_then(|v| v.as_str())
                    .map(|s| s.trim())
                    .filter(|s| !s.is_empty())
                    .unwrap_or("");
                let mut response = String::new();
                if !command.is_empty() {
                    response.push_str(&format!("Command `{}` output:\n", command));
                }
                response.push_str(std_out.trim_end());
                if !stderr.is_empty() {
                    response.push_str("\n\nstderr:\n");
                    response.push_str(stderr);
                }
                if response.trim().is_empty() {
                    return "Command completed without output.".to_string();
                }
                return response;
            }
        }

        sanitized.to_string()
    }

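    /// Heuristic entity extraction: URLs, email addresses, and quoted
    /// strings via regex. When fast reasoning is enabled for
    /// `entity_extraction` this currently only logs; the regex heuristics
    /// below run in every case.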
    fn extract_entities_from_text(&self, text: &str) -> Vec<ExtractedEntity> {
        if self.profile.fast_reasoning
            && self.fast_provider.is_some()
            && self
                .profile
                .fast_model_tasks
                .contains(&"entity_extraction".to_string())
        {
            debug!("Using fast model for entity extraction");
        }

        let mut entities = Vec::new();

        let url_regex = regex::Regex::new(r"https?://[^\s]+").unwrap();
        for mat in url_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "URL".to_string(),
                confidence: 0.9,
            });
        }

        let email_regex =
            regex::Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").unwrap();
        for mat in email_regex.find_iter(text) {
            entities.push(ExtractedEntity {
                name: mat.as_str().to_string(),
                entity_type: "Email".to_string(),
                confidence: 0.9,
            });
        }

        let quote_regex = regex::Regex::new(r#""([^"]+)""#).unwrap();
        for cap in quote_regex.captures_iter(text) {
            if let Some(quoted) = cap.get(1) {
                entities.push(ExtractedEntity {
                    name: quoted.as_str().to_string(),
                    entity_type: "Quote".to_string(),
                    confidence: 0.7,
                });
            }
        }

        entities
    }

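    /// Delegates `task` to the fast provider with an "Answer:/Confidence:"
    /// response contract, returning the cleaned answer and a parsed
    /// confidence (defaulting to 0.7 when none is found). Returns an empty
    /// answer with 0.0 confidence when no fast provider is configured.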
    async fn fast_reasoning(&self, task: &str, input: &str) -> Result<(String, f32)> {
        let total_timer = Instant::now();
        let result = if let Some(ref fast_provider) = self.fast_provider {
            let prompt = format!(
                "You are a fast specialist model that assists a more capable agent.\nTask: {}\nInput: {}\n\nRespond with two lines:\nAnswer: <concise result>\nConfidence: <0-1 decimal>",
                task, input
            );

            let fast_temperature = if self.profile.fast_model_temperature.is_finite() {
                self.profile.fast_model_temperature.clamp(0.0, 2.0)
            } else {
                warn!(
                    "Invalid fast_model_temperature for agent {:?}, using {}",
                    self.agent_name, DEFAULT_FAST_TEMPERATURE
                );
                DEFAULT_FAST_TEMPERATURE
            };

            let config = GenerationConfig {
                temperature: Some(fast_temperature),
                max_tokens: Some(256),
                stop_sequences: None,
                top_p: Some(DEFAULT_TOP_P),
                frequency_penalty: None,
                presence_penalty: None,
            };

            let call_timer = Instant::now();
            let response_result = fast_provider.generate(&prompt, &config).await;
            self.log_timing("fast_reasoning.generate", call_timer);
            let response = response_result?;

            let confidence = Self::parse_confidence(&response.content).unwrap_or(0.7);
            let cleaned = Self::strip_fast_answer(&response.content);

            Ok((cleaned, confidence))
        } else {
            Ok((String::new(), 0.0))
        };

        self.log_timing("fast_reasoning.total", total_timer);
        result
    }

    fn should_use_fast_model(&self, task_type: &str, complexity_score: f32) -> bool {
        if !self.profile.fast_reasoning || self.fast_provider.is_none() {
            return false;
        }

        if !self
            .profile
            .fast_model_tasks
            .contains(&task_type.to_string())
        {
            return false;
        }

        let threshold = self.escalation_threshold();
        if complexity_score > threshold {
            info!(
                "Task complexity {} exceeds threshold {}, using main model",
                complexity_score, threshold
            );
            return false;
        }

        true
    }

    fn extract_concepts_from_text(&self, text: &str) -> Vec<ExtractedConcept> {
        let mut concepts = Vec::new();

        let concept_keywords = vec![
            ("graph", "Knowledge Graph"),
            ("memory", "Memory System"),
            ("embedding", "Embeddings"),
            ("tool", "Tool Usage"),
            ("agent", "Agent System"),
            ("database", "Database"),
            ("query", "Query Processing"),
            ("node", "Graph Node"),
            ("edge", "Graph Edge"),
        ];

        let text_lower = text.to_lowercase();
        for (keyword, concept_name) in concept_keywords {
            if text_lower.contains(keyword) {
                concepts.push(ExtractedConcept {
                    name: concept_name.to_string(),
                    relevance: 0.6,
                });
            }
        }

        concepts
    }

    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn agent_name(&self) -> Option<&str> {
        self.agent_name.as_deref()
    }

    pub fn conversation_history(&self) -> &[Message] {
        &self.conversation_history
    }

    pub fn load_history(&mut self, limit: i64) -> Result<()> {
        self.conversation_history = self.persistence.list_messages(&self.session_id, limit)?;
        Ok(())
    }

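    /// Checks whether a tool may run, in three layers: the per-session
    /// permission cache, then the agent profile's allow/deny lists, then the
    /// policy engine. The decision is cached so repeated checks within a
    /// session are cheap.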
    async fn is_tool_allowed(&self, tool_name: &str) -> bool {
        {
            let cache = self.tool_permission_cache.read().await;
            if let Some(&allowed) = cache.get(tool_name) {
                return allowed;
            }
        }

        let profile_allowed = self.profile.is_tool_allowed(tool_name);
        debug!(
            "Profile check for tool '{}': allowed={}, allowed_tools={:?}, denied_tools={:?}",
            tool_name, profile_allowed, self.profile.allowed_tools, self.profile.denied_tools
        );
        if !profile_allowed {
            self.tool_permission_cache
                .write()
                .await
                .insert(tool_name.to_string(), false);
            return false;
        }

        let agent_name = self.agent_name.as_deref().unwrap_or("agent");
        let decision = self.policy_engine.check(agent_name, "tool_call", tool_name);
        debug!(
            "Policy check for tool '{}': decision={:?}",
            tool_name, decision
        );

        let allowed = matches!(decision, PolicyDecision::Allow);
        self.tool_permission_cache
            .write()
            .await
            .insert(tool_name.to_string(), allowed);
        allowed
    }

    async fn prompt_for_tool_permission(&mut self, tool_name: &str) -> Result<bool> {
        info!("Requesting user permission for tool: {}", tool_name);

        let tool_description = self
            .tool_registry
            .get(tool_name)
            .map(|t| t.description().to_string())
            .unwrap_or_else(|| "No description available".to_string());

        let prompt_args = json!({
            "prompt": format!(
                "The agent wants to use the '{}' tool.\n\nDescription: {}\n\nDo you want to allow this?",
                tool_name,
                tool_description
            ),
            "input_type": "boolean",
            "required": true,
        });

        match self.tool_registry.execute("prompt_user", prompt_args).await {
            Ok(result) if result.success => {
                info!("prompt_user output: {}", result.output);

                let allowed =
                    if let Ok(response_json) = serde_json::from_str::<Value>(&result.output) {
                        info!("Parsed JSON response: {:?}", response_json);
                        let value = response_json["response"].as_bool();
                        info!("Extracted boolean value: {:?}", value);
                        value.unwrap_or(false)
                    } else {
                        info!("Failed to parse JSON, trying plain text fallback");
                        let response = result.output.trim().to_lowercase();
                        let parsed = response == "yes" || response == "y" || response == "true";
                        info!("Plain text parse result for '{}': {}", response, parsed);
                        parsed
                    };

                info!(
                    "User {} tool '{}'",
                    if allowed { "allowed" } else { "denied" },
                    tool_name
                );

                if allowed {
                    self.add_allowed_tool(tool_name).await;
                } else {
                    self.add_denied_tool(tool_name).await;
                }

                Ok(allowed)
            }
            Ok(result) => {
                warn!("Failed to prompt user: {:?}", result.error);
                Ok(false)
            }
            Err(e) => {
                warn!("Error prompting user for permission: {}", e);
                Ok(false)
            }
        }
    }

    async fn add_allowed_tool(&mut self, tool_name: &str) {
        let tools = self.profile.allowed_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to allowed tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

    async fn add_denied_tool(&mut self, tool_name: &str) {
        let tools = self.profile.denied_tools.get_or_insert_with(Vec::new);
        if !tools.contains(&tool_name.to_string()) {
            tools.push(tool_name.to_string());
            info!("Added '{}' to denied tools list", tool_name);
        }
        self.tool_permission_cache.write().await.remove(tool_name);
    }

    async fn execute_tool(
        &self,
        run_id: &str,
        tool_name: &str,
        args: &Value,
    ) -> Result<ToolResult> {
        let exec_result = self.tool_registry.execute(tool_name, args.clone()).await;
        let result = match exec_result {
            Ok(res) => res,
            Err(err) => ToolResult::failure(err.to_string()),
        };

        let result_json = serde_json::json!({
            "output": result.output,
            "success": result.success,
            "error": result.error,
        });

        let error_str = result.error.as_deref();
        self.persistence
            .log_tool(
                &self.session_id,
                self.agent_name.as_deref().unwrap_or("unknown"),
                run_id,
                tool_name,
                args,
                &result_json,
                result.success,
                error_str,
            )
            .context("Failed to log tool execution")?;

        Ok(result)
    }

    pub fn tool_registry(&self) -> &ToolRegistry {
        &self.tool_registry
    }

    pub fn policy_engine(&self) -> &PolicyEngine {
        &self.policy_engine
    }

    pub fn set_policy_engine(&mut self, policy_engine: Arc<PolicyEngine>) {
        self.policy_engine = policy_engine;
    }

    pub async fn generate_embedding(&self, text: &str) -> Option<i64> {
        if let Some(client) = &self.embeddings_client {
            if !text.trim().is_empty() {
                match client.embed_batch(&[text]).await {
                    Ok(mut embeddings) => {
                        if let Some(embedding) = embeddings.pop() {
                            if !embedding.is_empty() {
                                match self.persistence.insert_memory_vector(
                                    &self.session_id,
                                    None,
                                    &embedding,
                                ) {
                                    Ok(emb_id) => return Some(emb_id),
                                    Err(err) => {
                                        warn!("Failed to persist embedding: {}", err);
                                    }
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!("Failed to generate embedding: {}", err);
                    }
                }
            }
        }
        None
    }

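    /// Inspects the session graph after a turn: collects pending/completed
    /// goals, recent tool successes and failures, and entities/concepts near
    /// the assistant's message node, then derives a next-action
    /// recommendation and records it as an `Event` node linked to that
    /// message.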
    fn evaluate_graph_for_next_action(
        &self,
        user_message_id: i64,
        assistant_message_id: i64,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Result<Option<String>> {
        let nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::Message),
            Some(50),
        )?;

        let mut assistant_node_id = None;
        let mut _user_node_id = None;

        for node in &nodes {
            if let Some(msg_id) = node.properties["message_id"].as_i64() {
                if msg_id == assistant_message_id {
                    assistant_node_id = Some(node.id);
                } else if msg_id == user_message_id {
                    _user_node_id = Some(node.id);
                }
            }
        }

        let Some(assistant_node_id) = assistant_node_id else {
            debug!("Assistant message node not found in graph");
            return Ok(None);
        };

        let neighbors = self.persistence.traverse_neighbors(
            &self.session_id,
            assistant_node_id,
            TraversalDirection::Both,
            2,
        )?;

        let goal_nodes =
            self.persistence
                .list_graph_nodes(&self.session_id, Some(NodeType::Goal), Some(10))?;

        let mut pending_goals = Vec::new();
        let mut completed_goals = Vec::new();

        for goal in &goal_nodes {
            if let Some(status) = goal.properties["status"].as_str() {
                match status {
                    "pending" | "in_progress" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            pending_goals.push(goal_text.to_string());
                        }
                    }
                    "completed" => {
                        if let Some(goal_text) = goal.properties["goal_text"].as_str() {
                            completed_goals.push(goal_text.to_string());
                        }
                    }
                    _ => {}
                }
            }
        }

        let tool_nodes = self.persistence.list_graph_nodes(
            &self.session_id,
            Some(NodeType::ToolResult),
            Some(10),
        )?;

        let mut recent_tool_failures = Vec::new();
        let mut recent_tool_successes = Vec::new();

        for tool_node in &tool_nodes {
            if let Some(success) = tool_node.properties["success"].as_bool() {
                let tool_name = tool_node.properties["tool"].as_str().unwrap_or("unknown");
                if success {
                    recent_tool_successes.push(tool_name.to_string());
                } else {
                    recent_tool_failures.push(tool_name.to_string());
                }
            }
        }

        let mut key_entities = HashSet::new();
        let mut key_concepts = HashSet::new();

        for neighbor in &neighbors {
            match neighbor.node_type {
                NodeType::Entity => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_entities.insert(name.to_string());
                    }
                }
                NodeType::Concept => {
                    if let Some(name) = neighbor.properties["name"].as_str() {
                        key_concepts.insert(name.to_string());
                    }
                }
                _ => {}
            }
        }

        let recommendation = self.generate_action_recommendation(
            &pending_goals,
            &completed_goals,
            &recent_tool_failures,
            &recent_tool_successes,
            &key_entities,
            &key_concepts,
            response_content,
            tool_invocations,
        );

        if let Some(ref rec) = recommendation {
            let properties = json!({
                "recommendation": rec,
                "user_message_id": user_message_id,
                "assistant_message_id": assistant_message_id,
                "pending_goals": pending_goals,
                "completed_goals": completed_goals,
                "tool_failures": recent_tool_failures,
                "tool_successes": recent_tool_successes,
                "key_entities": key_entities.into_iter().collect::<Vec<_>>(),
                "key_concepts": key_concepts.into_iter().collect::<Vec<_>>(),
                "timestamp": Utc::now().to_rfc3339(),
            });

            let rec_node_id = self.persistence.insert_graph_node(
                &self.session_id,
                NodeType::Event,
                "NextActionRecommendation",
                &properties,
                None,
            )?;

            self.persistence.insert_graph_edge(
                &self.session_id,
                assistant_node_id,
                rec_node_id,
                EdgeType::Produces,
                Some("recommends"),
                None,
                0.8,
            )?;
        }

        Ok(recommendation)
    }

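    /// Picks the single highest-priority next action from the collected
    /// signals: pending goals first, then failed tools, error or uncertainty
    /// markers in the response, tool-sequence patterns, and concept-specific
    /// hints. Returns `None` when nothing actionable stands out.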
    fn generate_action_recommendation(
        &self,
        pending_goals: &[String],
        completed_goals: &[String],
        recent_tool_failures: &[String],
        _recent_tool_successes: &[String],
        _key_entities: &HashSet<String>,
        key_concepts: &HashSet<String>,
        response_content: &str,
        tool_invocations: &[ToolInvocation],
    ) -> Option<String> {
        let mut recommendations = Vec::new();

        if !pending_goals.is_empty() {
            let goals_str = pending_goals.join(", ");
            recommendations.push(format!("Continue working on pending goals: {}", goals_str));
        }

        if !recent_tool_failures.is_empty() {
            let unique_failures: HashSet<_> = recent_tool_failures.iter().collect();
            for tool in unique_failures {
                recommendations.push(format!(
                    "Consider alternative approach for failed tool: {}",
                    tool
                ));
            }
        }

        let response_lower = response_content.to_lowercase();
        if response_lower.contains("error") || response_lower.contains("failed") {
            recommendations.push("Investigate and resolve the reported error".to_string());
        }

        if response_lower.contains('?') || response_lower.contains("unclear") {
            recommendations.push("Clarify the uncertain aspects mentioned".to_string());
        }

        if tool_invocations.len() > 1 {
            let tool_sequence: Vec<_> = tool_invocations.iter().map(|t| t.name.as_str()).collect();

            if tool_sequence.contains(&"file_read") && !tool_sequence.contains(&"file_write") {
                recommendations
                    .push("Consider modifying the read files if changes are needed".to_string());
            }

            if tool_sequence.contains(&"search")
                && tool_invocations.last().is_some_and(|t| t.success)
            {
                recommendations
                    .push("Examine the search results for relevant information".to_string());
            }
        }

        if key_concepts.contains("Knowledge Graph") || key_concepts.contains("Graph Node") {
            recommendations
                .push("Consider visualizing or querying the graph structure".to_string());
        }

        if key_concepts.contains("Database") || key_concepts.contains("Query Processing") {
            recommendations.push("Verify data integrity and query performance".to_string());
        }

        if !completed_goals.is_empty() && !pending_goals.is_empty() {
            recommendations.push(format!(
                "Build on completed work ({} done) to address remaining goals ({} pending)",
                completed_goals.len(),
                pending_goals.len()
            ));
        }

        if recommendations.is_empty() {
            if completed_goals.len() > pending_goals.len() && recent_tool_failures.is_empty() {
                Some(
                    "Current objectives appear satisfied. Ready for new tasks or refinements."
                        .to_string(),
                )
            } else {
                None
            }
        } else {
            Some(recommendations[0].clone())
        }
    }

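    /// Emits a structured `agent_timing` log line recording how many
    /// milliseconds have elapsed since `start` for the given pipeline stage.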
    fn log_timing(&self, stage: &str, start: Instant) {
        let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
        let agent_label = self.agent_name.as_deref().unwrap_or("unnamed");
        info!(
            target: "agent_timing",
            "stage={} duration_ms={:.2} agent={} session_id={}",
            stage,
            duration_ms,
            agent_label,
            self.session_id
        );
    }
}

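/// Trims `content` and truncates it to at most 80 characters, appending `...`
/// when truncation occurs.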
fn preview_text(content: &str) -> String {
    const MAX_CHARS: usize = 80;
    let trimmed = content.trim();
    let mut preview = String::new();
    for (idx, ch) in trimmed.chars().enumerate() {
        if idx >= MAX_CHARS {
            preview.push_str("...");
            break;
        }
        preview.push(ch);
    }
    if preview.is_empty() {
        trimmed.to_string()
    } else {
        preview
    }
}

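/// Renders a compact preview of a JSON value: strings go through
/// `preview_text`, `null` becomes `"null"`, and anything else is serialized
/// and truncated to roughly 80 characters.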
fn preview_json_value(value: &Value) -> String {
    match value {
        Value::String(text) => preview_text(text),
        Value::Null => "null".to_string(),
        _ => {
            let raw = value.to_string();
            if raw.len() > 80 {
                // Back up to a char boundary so slicing cannot panic on
                // multi-byte UTF-8 sequences in the serialized value.
                let mut end = 77;
                while !raw.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &raw[..end])
            } else {
                raw
            }
        }
    }
}

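/// Fetches `key` from a JSON payload, returning `Value::Null` when absent.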
fn payload_field(payload: &Value, key: &str) -> Value {
    payload.get(key).cloned().unwrap_or(Value::Null)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::AgentProfile;
    use crate::embeddings::{EmbeddingsClient, EmbeddingsService};
    use async_trait::async_trait;
    use tempfile::tempdir;

    fn create_test_agent(session_id: &str) -> (AgentCore, tempfile::TempDir) {
        create_test_agent_with_embeddings(session_id, None)
    }

    fn create_test_agent_with_embeddings(
        session_id: &str,
        embeddings_client: Option<EmbeddingsClient>,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                embeddings_client,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
            ),
            dir,
        )
    }

    fn create_fast_reasoning_agent(
        session_id: &str,
        fast_output: &str,
    ) -> (AgentCore, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("fast.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("You are a helpful assistant.".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: true,
            fast_model_provider: Some("mock".to_string()),
            fast_model_name: Some("mock-fast".to_string()),
            fast_model_temperature: 0.3,
            fast_model_tasks: vec!["entity_extraction".to_string()],
            escalation_threshold: 0.5,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        profile.validate().unwrap();

        let provider = Arc::new(MockProvider::new("This is a test response."));
        let fast_provider = Arc::new(MockProvider::new(fast_output.to_string()));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());
        let policy_engine = Arc::new(PolicyEngine::new());

        (
            AgentCore::new(
                profile,
                provider,
                None,
                persistence,
                session_id.to_string(),
                Some(session_id.to_string()),
                tool_registry,
                policy_engine,
            )
            .with_fast_provider(fast_provider),
            dir,
        )
    }

    #[derive(Clone)]
    struct KeywordEmbeddingsService;

    #[async_trait]
    impl EmbeddingsService for KeywordEmbeddingsService {
        async fn create_embeddings(
            &self,
            _model: &str,
            inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            Ok(inputs
                .into_iter()
                .map(|input| keyword_embedding(&input))
                .collect())
        }
    }

    fn keyword_embedding(input: &str) -> Vec<f32> {
        let lower = input.to_ascii_lowercase();
        let alpha = if lower.contains("alpha") { 1.0 } else { 0.0 };
        let beta = if lower.contains("beta") { 1.0 } else { 0.0 };
        vec![alpha, beta]
    }

    fn test_embeddings_client() -> EmbeddingsClient {
        EmbeddingsClient::with_service(
            "test",
            Arc::new(KeywordEmbeddingsService) as Arc<dyn EmbeddingsService>,
        )
    }

    #[tokio::test]
    async fn test_agent_core_run_step() {
        let (mut agent, _dir) = create_test_agent("test-session-1");

        let output = agent.run_step("Hello, how are you?").await.unwrap();

        assert!(!output.response.is_empty());
        assert!(output.token_usage.is_some());
        assert_eq!(output.tool_invocations.len(), 0);
    }

    #[tokio::test]
    async fn fast_model_short_circuits_when_confident() {
        let (mut agent, _dir) = create_fast_reasoning_agent(
            "fast-confident",
            "Answer: Entities detected.\nConfidence: 0.9",
        );

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert!(output
            .finish_reason
            .unwrap_or_default()
            .contains("fast_model"));
        assert!(output.response.contains("Entities detected"));
    }

    #[tokio::test]
    async fn fast_model_only_hints_when_low_confidence() {
        let (mut agent, _dir) =
            create_fast_reasoning_agent("fast-hint", "Answer: Unsure.\nConfidence: 0.2");

        let output = agent
            .run_step("Extract the entities mentioned in this string.")
            .await
            .unwrap();

        assert_eq!(output.finish_reason.as_deref(), Some("stop"));
        assert_eq!(output.response, "This is a test response.");
    }

    #[tokio::test]
    async fn test_agent_core_conversation_history() {
        let (mut agent, _dir) = create_test_agent("test-session-2");

        agent.run_step("First message").await.unwrap();
        agent.run_step("Second message").await.unwrap();

        let history = agent.conversation_history();
        assert_eq!(history.len(), 4);
        assert_eq!(history[0].role, MessageRole::User);
        assert_eq!(history[1].role, MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_agent_core_session_switch() {
        let (mut agent, _dir) = create_test_agent("session-1");

        agent.run_step("Message in session 1").await.unwrap();
        assert_eq!(agent.session_id(), "session-1");

        agent = agent.with_session("session-2".to_string());
        assert_eq!(agent.session_id(), "session-2");
        assert_eq!(agent.conversation_history().len(), 0);
    }

    #[tokio::test]
    async fn test_agent_core_build_prompt() {
        let (agent, _dir) = create_test_agent("test-session-3");

        let context = vec![
            Message {
                id: 1,
                session_id: "test-session-3".to_string(),
                role: MessageRole::User,
                content: "Previous question".to_string(),
                created_at: Utc::now(),
            },
            Message {
                id: 2,
                session_id: "test-session-3".to_string(),
                role: MessageRole::Assistant,
                content: "Previous answer".to_string(),
                created_at: Utc::now(),
            },
        ];

        let prompt = agent
            .build_prompt("Current question", &context)
            .await
            .unwrap();

        assert!(prompt.contains("You are a helpful assistant"));
        assert!(prompt.contains("Previous conversation"));
        assert!(prompt.contains("user: Previous question"));
        assert!(prompt.contains("assistant: Previous answer"));
        assert!(prompt.contains("user: Current question"));
    }

    #[tokio::test]
    async fn test_agent_core_persistence() {
        let (mut agent, _dir) = create_test_agent("persist-test");

        agent.run_step("Test message").await.unwrap();

        let messages = agent
            .persistence
            .list_messages("persist-test", 100)
            .unwrap();

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[0].content, "Test message");
    }

    #[tokio::test]
    async fn store_message_records_embeddings() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("embedding-store", Some(test_embeddings_client()));

        let message_id = agent
            .store_message(MessageRole::User, "Alpha detail")
            .await
            .unwrap();

        let query = vec![1.0f32, 0.0];
        let recalled = agent
            .persistence
            .recall_top_k("embedding-store", &query, 1)
            .unwrap();

        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].0.message_id, Some(message_id));
    }

    #[tokio::test]
    async fn recall_memories_appends_semantic_matches() {
        let (agent, _dir) =
            create_test_agent_with_embeddings("semantic-recall", Some(test_embeddings_client()));

        agent
            .store_message(MessageRole::User, "Alpha question")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Alpha answer")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::User, "Beta prompt")
            .await
            .unwrap();
        agent
            .store_message(MessageRole::Assistant, "Beta reply")
            .await
            .unwrap();

        let recall = agent.recall_memories("alpha follow up").await.unwrap();
        assert!(matches!(
            recall.stats.as_ref().map(|s| &s.strategy),
            Some(MemoryRecallStrategy::Semantic { .. })
        ));
        assert_eq!(
            recall
                .stats
                .as_ref()
                .map(|s| s.matches.len())
                .unwrap_or_default(),
            2
        );

        let recalled = recall.messages;
        assert_eq!(recalled.len(), 4);
        assert_eq!(recalled[0].content, "Beta prompt");
        assert_eq!(recalled[1].content, "Beta reply");

        let tail: Vec<_> = recalled[2..].iter().map(|m| m.content.as_str()).collect();
        assert!(tail.contains(&"Alpha question"));
        assert!(tail.contains(&"Alpha answer"));
    }

    #[tokio::test]
    async fn test_agent_tool_permission_allowed() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let mut profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("Test"));
        let tool_registry = Arc::new(crate::tools::ToolRegistry::new());

        let mut policy_engine = PolicyEngine::new();
        policy_engine.add_rule(crate::policy::PolicyRule {
            agent: "*".to_string(),
            action: "tool_call".to_string(),
            resource: "*".to_string(),
            effect: crate::policy::PolicyEffect::Allow,
        });
        let policy_engine = Arc::new(policy_engine);

        let agent = AgentCore::new(
            profile.clone(),
            provider.clone(),
            None,
            persistence.clone(),
            "test-session".to_string(),
            Some("policy-test".to_string()),
            tool_registry.clone(),
            policy_engine.clone(),
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);

        profile.allowed_tools = None;
        profile.denied_tools = Some(vec!["calculator".to_string()]);

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence,
            "test-session-2".to_string(),
            Some("policy-test-2".to_string()),
            tool_registry,
            policy_engine,
        );

        assert!(agent.is_tool_allowed("echo").await);
        assert!(!agent.is_tool_allowed("calculator").await);
    }

    #[tokio::test]
    async fn test_agent_tool_execution_with_logging() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = AgentProfile {
            prompt: Some("Test".to_string()),
            style: None,
            temperature: Some(0.7),
            model_provider: None,
            model_name: None,
            allowed_tools: Some(vec!["echo".to_string()]),
            denied_tools: None,
            memory_k: 5,
            top_p: 0.9,
            max_context_tokens: Some(2048),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        };

        let provider = Arc::new(MockProvider::new("Test"));

        let mut tool_registry = crate::tools::ToolRegistry::new();
        tool_registry.register(Arc::new(crate::tools::builtin::EchoTool::new()));

        let policy_engine = Arc::new(PolicyEngine::new());

        let agent = AgentCore::new(
            profile,
            provider,
            None,
            persistence.clone(),
            "tool-exec-test".to_string(),
            Some("tool-agent".to_string()),
            Arc::new(tool_registry),
            policy_engine,
        );

        let args = serde_json::json!({"message": "test message"});
        let result = agent
            .execute_tool("run-tool-test", "echo", &args)
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "test message");
    }

    #[tokio::test]
    async fn test_agent_tool_registry_access() {
        let (agent, _dir) = create_test_agent("registry-test");

        let registry = agent.tool_registry();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_goal_requires_tool_detection() {
        assert!(AgentCore::goal_requires_tool(
            "List the files in this directory"
        ));
        assert!(AgentCore::goal_requires_tool("Run the tests"));
        assert!(!AgentCore::goal_requires_tool("Explain recursion"));
        assert!(AgentCore::goal_requires_tool(
            "Tell me about the project in this directory"
        ));
    }

    #[test]
    fn test_infer_goal_tool_action_project_description() {
        let query = "Tell me about the project in this directory";
        let inferred = AgentCore::infer_goal_tool_action(query)
            .expect("Should infer a tool for project description");
        let (tool, args) = inferred;
        assert!(
            tool == "file_read" || tool == "search",
            "unexpected tool: {}",
            tool
        );
        if tool == "file_read" {
            assert!(args.get("path").is_some());
            assert!(args.get("max_bytes").is_some());
        } else {
            let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
            assert!(query.contains("Cargo.toml") || query.contains("package.json"));
        }
    }
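
    // Sketch tests for the file-local preview helpers; expected values are
    // derived directly from the implementations above.
    #[test]
    fn test_preview_text_truncates_long_content() {
        let long = "a".repeat(100);
        assert_eq!(preview_text(&long), format!("{}...", "a".repeat(80)));
        assert_eq!(preview_text("  hi  "), "hi");
        assert_eq!(preview_text(""), "");
    }

    #[test]
    fn test_preview_json_value_variants() {
        assert_eq!(preview_json_value(&Value::Null), "null");
        assert_eq!(preview_json_value(&json!("short")), "short");
        let long = json!("x".repeat(200));
        assert!(preview_json_value(&long).ends_with("..."));
    }

    #[test]
    fn test_payload_field_missing_key_is_null() {
        let payload = json!({"present": 42});
        assert_eq!(payload_field(&payload, "present"), json!(42));
        assert_eq!(payload_field(&payload, "absent"), Value::Null);
    }

    // Minimal sketch of the recommendation heuristics: with only a pending
    // goal supplied, the first (highest-priority) recommendation should be
    // the pending-goals reminder.
    #[test]
    fn generate_action_recommendation_prioritizes_pending_goals() {
        let (agent, _dir) = create_test_agent("recommendation-test");

        let pending = vec!["ship the release".to_string()];
        let none: Vec<String> = Vec::new();
        let empty_set = HashSet::new();

        let rec = agent.generate_action_recommendation(
            &pending, &none, &none, &none, &empty_set, &empty_set, "done", &[],
        );

        assert_eq!(
            rec.as_deref(),
            Some("Continue working on pending goals: ship the release")
        );
    }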
}