1use anyhow::anyhow;
2use std::pin::Pin;
3use std::sync::Arc;
4use log::info;
5
6use crate::{
7 Agent, AgentAction, AgentFinish, AgentOutput, BaseMemory, ModelChatMessage, ChatMessageContent, ChatModel,
8 McpClient, McpToolAdapter, OpenAIChatModel, Runnable, Tool, parse_model_output
9};
10use serde_json::Value;
11
12pub struct McpAgent {
15 client: Arc<dyn McpClient>,
16 tools: Vec<Box<dyn Tool + Send + Sync>>,
17 system_prompt: String,
18 openai_model: Option<OpenAIChatModel>,
19 memory: Option<Box<dyn BaseMemory>>,
20}
21
22impl McpAgent {
23 pub fn new(client: Arc<dyn McpClient>, system_prompt: String) -> Self {
25 Self {
26 client,
27 tools: Vec::new(),
28 system_prompt,
29 openai_model: None, memory: None, }
32 }
33
34 pub fn with_openai_model(client: Arc<dyn McpClient>, system_prompt: String, openai_model: OpenAIChatModel) -> Self {
36 Self {
37 client,
38 tools: Vec::new(),
39 system_prompt,
40 openai_model: Some(openai_model),
41 memory: None, }
43 }
44
45 pub fn with_memory(client: Arc<dyn McpClient>, system_prompt: String, memory: Box<dyn BaseMemory>) -> Self {
47 Self {
48 client,
49 tools: Vec::new(),
50 system_prompt,
51 openai_model: None,
52 memory: Some(memory),
53 }
54 }
55
56 pub fn with_openai_model_and_memory(client: Arc<dyn McpClient>, system_prompt: String, openai_model: OpenAIChatModel, memory: Box<dyn BaseMemory>) -> Self {
58 Self {
59 client,
60 tools: Vec::new(),
61 system_prompt,
62 openai_model: Some(openai_model),
63 memory: Some(memory),
64 }
65 }
66
67 pub fn get_memory(&self) -> Option<&Box<dyn BaseMemory>> {
69 self.memory.as_ref()
70 }
71
72 pub fn add_tool(&mut self, tool: Box<dyn Tool + Send + Sync>) {
74 self.tools.push(tool);
75 }
76
77 pub async fn auto_add_tools(&mut self) -> Result<(), anyhow::Error> {
81 use crate::McpToolAdapter;
82
83 let tools = self.client.get_tools().await?;
85
86 for tool in &tools {
88 info!("MCP Client Get Tool: {} - {}", tool.name, tool.description);
89 }
90
91 for tool in tools {
93 let tool_adapter = McpToolAdapter::new(
94 self.client.clone(),
95 tool
96 );
97 self.add_tool(Box::new(tool_adapter));
98 }
99
100 Ok(())
101 }
102}
103
104impl Agent for McpAgent {
105 fn tools(&self) -> Vec<Box<dyn Tool + Send + Sync>> {
106 let mut cloned_tools: Vec<Box<dyn Tool + Send + Sync>> = Vec::new();
109
110 for tool in &self.tools {
113 if let Some(mcp_tool_adapter) = tool.as_any().downcast_ref::<McpToolAdapter>() {
115 let cloned_adapter = McpToolAdapter::new(
117 mcp_tool_adapter.get_client(),
118 mcp_tool_adapter.get_mcp_tool(),
119 );
120 cloned_tools.push(Box::new(cloned_adapter));
121 } else {
122 info!(
125 "Warning: Unable to clone non-McpToolAdapter type tool: {}",
126 tool.name()
127 );
128 }
129 }
130
131 cloned_tools
132 }
133
134 fn execute(
135 &self,
136 _action: &AgentAction,
137 ) -> std::pin::Pin<
138 Box<dyn std::future::Future<Output = Result<String, anyhow::Error>> + Send + '_>,
139 > {
140 Box::pin(async move {
141 Err(anyhow!("Tool execution functionality is not implemented yet"))
144 })
145 }
146
147 fn clone_agent(&self) -> Box<dyn Agent> {
148 let new_agent = McpAgent::new(
150 self.client.clone(),
151 self.system_prompt.clone(),
152 );
153
154 Box::new(new_agent)
156 }
157}
158
159impl Clone for McpAgent {
160 fn clone(&self) -> Self {
161 Self {
163 client: Arc::clone(&self.client),
164 tools: Vec::new(), system_prompt: self.system_prompt.clone(),
166 openai_model: self.openai_model.clone(), memory: self.memory.clone(), }
169 }
170}
171
172impl Runnable<std::collections::HashMap<String, String>, AgentOutput> for McpAgent {
173 fn invoke(
174 &self,
175 input: std::collections::HashMap<String, String>,
176 ) -> Pin<Box<dyn std::future::Future<Output = Result<AgentOutput, anyhow::Error>> + Send>> {
177 let system_prompt = self.system_prompt.clone();
179 let input_text = input
180 .get("input")
181 .cloned()
182 .unwrap_or_default()
183 .to_string()
184 .trim()
185 .to_string();
186
187 let tool_descriptions: String = if !self.tools.is_empty() {
189 let mut descriptions = String::new();
190 for tool in &self.tools {
191 descriptions.push_str(&format!("- {}: {}\n", tool.name(), tool.description()));
192 }
193 descriptions
194 } else {
195 String::new()
196 };
197
198 let memory_clone = self.memory.clone();
200
201 let enhanced_system_prompt = if !tool_descriptions.is_empty() {
203 format!("{}
204You are an AI assistant that follows the ReAct (Reasoning and Acting) framework.
205You should think step by step and decide whether to use tools based on user needs.
206You should carefully review and when confirming the use of the tool, if there are omissions, errors, or other issues with the parameters, you should reply and remind the user.
207Available tools:\n{}\n\nWhen you need to use a tool, please respond in the following JSON format:
208 \n{{\"call_tool\": {{\"name\": \"Tool Name\", \"parameters\": {{\"parameter_name\": \"parameter_value\"}}}}}}
209 When you don't need to use a tool, please respond in the following JSON format:\n{{\"content\": \"Your answer\"}}
210 Please think carefully about whether the user's request requires a tool to be used, and only use tools when necessary.",
211 system_prompt, tool_descriptions)
212 } else {
213 system_prompt
214 };
215
216 let openai_model_clone = self.openai_model.clone();
218
219 Box::pin(async move {
220 if input_text.is_empty() {
222 let mut return_values = std::collections::HashMap::new();
223 return_values.insert("answer".to_string(), "Please enter valid content".to_string());
224 let model_name = if let Some(ref openai_model) = openai_model_clone {
226 openai_model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string())
227 } else {
228 "unknown".to_string()
229 };
230 return_values.insert("model".to_string(), model_name);
231 return Ok(AgentOutput::Finish(AgentFinish { return_values }));
232 }
233
234 let model = if let Some(ref openai_model) = openai_model_clone {
236 openai_model
238 } else {
239 let mut return_values = std::collections::HashMap::new();
241 return_values.insert("answer".to_string(), "No OpenAI model provided".to_string());
242 return_values.insert("model".to_string(), "unknown".to_string());
243 return Ok(AgentOutput::Finish(AgentFinish { return_values }));
244 };
245
246 let mut messages = Vec::new();
248
249 let enhanced_system_prompt_with_summary = {
251 let mut enhanced_prompt = enhanced_system_prompt;
252 if let Some(memory) = &memory_clone {
253 if let Some(composite_memory) = memory.as_any().downcast_ref::<crate::memory::composite_memory::CompositeMemory>() {
256 match composite_memory.get_summary().await {
258 Ok(Some(summary)) => {
259 enhanced_prompt = format!("{}\n\nPrevious conversation summary: {}", enhanced_prompt, summary);
261 log::info!("Summary appended to system prompt");
262 },
263 Ok(None) => {
264 log::info!("No summary content found");
265 },
266 Err(e) => {
267 log::warn!("Error getting summary: {}", e);
268 }
269 }
270 } else {
271 match memory.load_memory_variables(&std::collections::HashMap::new()).await {
273 Ok(memories) => {
274 if let Some(summary) = memories.get("summary") {
275 if let Some(summary_str) = summary.as_str() {
276 enhanced_prompt = format!("{}\n\nPrevious conversation summary: {}", enhanced_prompt, summary_str);
278 log::info!("Summary retrieved from memory variables and appended to system prompt");
279 }
280 }
281 },
282 Err(e) => {
283 log::warn!("Error getting summary from memory variables: {}", e);
284 }
285 }
286 }
287 }
288 enhanced_prompt
289 };
290
291 messages.push(ModelChatMessage::System(ChatMessageContent {
293 content: enhanced_system_prompt_with_summary,
294 name: None,
295 additional_kwargs: std::collections::HashMap::new(),
296 }));
297
298 if let Some(memory) = &memory_clone {
300 match memory.load_memory_variables(&std::collections::HashMap::new()).await {
301 Ok(memories) => {
302 info!("Loaded memory variables: {:?}", memories);
303 if let Some(chat_history) = memories.get("chat_history") {
304 if let serde_json::Value::Array(messages_array) = chat_history {
305 for message in messages_array {
306 if let serde_json::Value::Object(msg_obj) = message {
307 let role = msg_obj.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
308 let content = msg_obj.get("content").and_then(|v| v.as_str()).unwrap_or("");
309
310 if content.trim().is_empty() {
312 continue;
313 }
314
315 if role == "assistant" && content.contains("user:") && content.contains("assistant:") {
317 continue;
318 }
319
320 match role {
324 "human" | "user" => {
325 log::info!("Loaded human message: content={}", content);
327 messages.push(ModelChatMessage::Human(ChatMessageContent {
328 content: content.to_string(),
329 name: None,
330 additional_kwargs: std::collections::HashMap::new(),
331 }));
332 },
333 "ai" | "assistant" => {
334 log::info!("Loaded AI message: content={}", content);
336 messages.push(ModelChatMessage::AIMessage(ChatMessageContent {
337 content: content.to_string(),
338 name: None,
339 additional_kwargs: std::collections::HashMap::new(),
340 }));
341 },
342 "tool" => {
343 let content_str = content.to_string();
345 log::info!("Loaded tool message: content={}", content_str);
347 messages.push(ModelChatMessage::ToolMessage(ChatMessageContent {
348 content: content_str,
349 name: None,
350 additional_kwargs: std::collections::HashMap::new(),
351 }));
352 },
353 _ => {
354 log::info!("Loaded unknown role message: role={}, content={}", role, content);
356 }
358 }
359 }
360 }
361 }
362 }
363 },
364 Err(e) => {
365 log::warn!("Failed to load memory variables: {}", e);
367 }
368 }
369 }
370
371 messages.push(ModelChatMessage::Human(ChatMessageContent {
373 content: input_text.clone(),
374 name: None,
375 additional_kwargs: std::collections::HashMap::new(),
376 }));
377 log::info!("Messages to be sent to model:");
381 for (i, msg) in messages.iter().enumerate() {
382 match msg {
383 ModelChatMessage::System(content) => {
384 log::info!(" {}. role=system, content={}", i+1, content.content);
385 },
386 ModelChatMessage::Human(content) => {
387 log::info!(" {}. role=user, content={}", i+1, content.content);
388 },
389 ModelChatMessage::AIMessage(content) => {
390 log::info!(" {}. role=assistant, content={}", i+1, content.content);
391 },
392 ModelChatMessage::ToolMessage(content) => {
393 log::info!(" {}. role=tool, content={}", i+1, content.content);
394 },
395 }
396 }
397
398 let result = model.invoke(messages).await;
400
401 match result {
402 Ok(completion) => {
403 let content = match completion.message {
405 ModelChatMessage::AIMessage(content) => content.content,
406 _ => { format!("{},{:?}", "Non-AI message received", completion.message) }
407 };
408
409 let model_name = model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string());
411
412 if let Some(memory) = &memory_clone {
414 let mut inputs = std::collections::HashMap::new();
415 inputs.insert("input".to_string(), serde_json::Value::String(input_text.clone()));
416
417 let processed_content = if content.starts_with('"') && content.ends_with('"') {
419 match serde_json::from_str::<serde_json::Value>(&content) {
421 Ok(serde_json::Value::String(s)) => s,
422 _ => content.clone(),
423 }
424 } else if content.starts_with('{') && content.ends_with('}') {
425 match serde_json::from_str::<serde_json::Value>(&content) {
427 Ok(json_obj) => {
428 if let Some(content_value) = json_obj.get("content") {
430 if let Some(content_str) = content_value.as_str() {
431 content_str.to_string()
432 } else {
433 content.clone()
434 }
435 } else {
436 content.clone()
437 }
438 },
439 _ => content.clone(),
440 }
441 } else {
442 content.clone()
443 };
444
445 let mut outputs = std::collections::HashMap::new();
446 outputs.insert("output".to_string(), serde_json::Value::String(processed_content));
447
448 if let Err(e) = memory.save_context(&inputs, &outputs).await {
449 log::warn!("Failed to save context to memory: {}", e);
450 }
451 }
452
453 if let Ok(parsed_output) = parse_model_output(&content) {
456 match parsed_output {
457 AgentOutput::Action(action) => {
458 return Ok(AgentOutput::Action(action));
460 }
461 AgentOutput::Finish(_) => {
462 let mut return_values = std::collections::HashMap::new();
464 return_values.insert("answer".to_string(), content.clone());
465 return_values.insert("model".to_string(), model_name);
466 return Ok(AgentOutput::Finish(AgentFinish { return_values }));
467 }
468 }
469 } else {
470 if content.contains("call_tool") {
473 if let Ok(agent_action) = parse_tool_call_from_content(&content) {
476 Ok(AgentOutput::Action(agent_action))
477 } else {
478 let mut return_values = std::collections::HashMap::new();
480 return_values.insert("answer".to_string(), content.clone());
481 return_values.insert("model".to_string(), model_name);
482 Ok(AgentOutput::Finish(AgentFinish { return_values }))
483 }
484 } else {
485 let mut return_values = std::collections::HashMap::new();
487 return_values.insert("answer".to_string(), content.clone());
488 return_values.insert("model".to_string(), model_name);
489 Ok(AgentOutput::Finish(AgentFinish { return_values }))
490 }
491 }
492 }
493 Err(e) => {
494 let model_name = if let Some(ref model) = openai_model_clone {
497 model.model_name().map(|s| s.to_string()).unwrap_or("unknown".to_string())
498 } else {
499 "unknown".to_string()
500 };
501
502 if let Some(memory) = &memory_clone {
504 let mut inputs = std::collections::HashMap::new();
505 inputs.insert("input".to_string(), serde_json::Value::String(input_text.clone()));
506
507 let mut outputs = std::collections::HashMap::new();
508 outputs.insert("output".to_string(), serde_json::Value::String(format!("Model invocation failed: {}", e)));
509
510 if let Err(e) = memory.save_context(&inputs, &outputs).await {
511 log::warn!("Failed to save context to memory: {}", e);
512 }
513 }
514
515 let mut return_values = std::collections::HashMap::new();
516 return_values.insert("answer".to_string(), format!("Model invocation failed: {}", e));
517 return_values.insert("model".to_string(), model_name);
518 Ok(AgentOutput::Finish(AgentFinish { return_values }))
519 }
520 }
521 })
522 }
523
524 fn clone_to_owned(
525 &self,
526 ) -> Box<dyn Runnable<std::collections::HashMap<String, String>, AgentOutput> + Send + Sync>
527 {
528 Box::new(self.clone())
529 }
530}
531
532fn extract_json_object(content: &str) -> Option<String> {
534 if let Some(start) = content.find('{') {
536 if let Some(end) = content.rfind('}') {
537 if end > start {
538 let json_str = &content[start..=end];
540
541 if let Ok(value) = serde_json::from_str::<serde_json::Value>(json_str) {
543 if value.is_object() {
544 return Some(json_str.to_string());
545 }
546 }
547 }
548 }
549 }
550 None
551}
552
553fn parse_tool_call_from_content(content: &str) -> Result<AgentAction, anyhow::Error> {
555 if let Some(json_str) = extract_json_object(content) {
557 let value: Value = serde_json::from_str(&json_str)?;
559
560 if let Some(call_tool) = value.get("call_tool").and_then(|v| v.as_object()) {
562 let tool_name = call_tool
564 .get("name")
565 .and_then(|v| v.as_str())
566 .ok_or_else(|| anyhow::anyhow!("Missing tool name"))?
567 .to_string();
568
569 let tool_input = call_tool
571 .get("parameters")
572 .cloned()
573 .unwrap_or(Value::Object(serde_json::Map::new()))
574 .to_string();
575
576 let action = AgentAction {
578 tool: tool_name,
579 tool_input,
580 log: content.to_string(),
581 thought: None,
582 };
583
584 return Ok(action);
585 }
586 }
587
588 Err(anyhow::anyhow!("Failed to parse tool call from content"))
590}