// soul_core/executor/llm.rs
use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::mpsc;
8
9use crate::error::SoulResult;
10use crate::tool::ToolOutput;
11use crate::types::ToolDefinition;
12
13use super::ToolExecutor;
14
/// Input bundle handed to the user-supplied LLM callback.
///
/// `model` and `system_prompt` are per-executor overrides; when `None`, the
/// callback is expected to fall back to its own defaults.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmExecutorRequest {
    /// Model identifier to use for this call, if configured.
    pub model: Option<String>,
    /// System prompt to apply, if configured.
    pub system_prompt: Option<String>,
    /// The message payload sent to the model.
    pub user_message: String,
}
22
/// Shared, thread-safe async callback that performs the actual LLM call.
///
/// Takes an [`LlmExecutorRequest`] and resolves to the model's text reply.
/// The future is boxed and pinned so implementations can capture state, and
/// the whole closure lives behind an `Arc` so it can be cloned cheaply.
pub type LlmFn = Arc<
    dyn Fn(LlmExecutorRequest) -> Pin<Box<dyn Future<Output = SoulResult<String>> + Send>>
        + Send
        + Sync,
>;
29
/// Tool executor that delegates every tool call to a user-provided LLM callback.
pub struct LlmExecutor {
    // Callback invoked for each execution.
    llm_fn: LlmFn,
    // Model forwarded with every request; `None` leaves the choice to the callback.
    default_model: Option<String>,
    // System prompt forwarded with every request, if set.
    default_system_prompt: Option<String>,
}
36
37impl LlmExecutor {
38 pub fn new(llm_fn: LlmFn) -> Self {
39 Self {
40 llm_fn,
41 default_model: None,
42 default_system_prompt: None,
43 }
44 }
45
46 pub fn with_model(mut self, model: impl Into<String>) -> Self {
47 self.default_model = Some(model.into());
48 self
49 }
50
51 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
52 self.default_system_prompt = Some(prompt.into());
53 self
54 }
55}
56
57#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
58#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
59impl ToolExecutor for LlmExecutor {
60 async fn execute(
61 &self,
62 definition: &ToolDefinition,
63 _call_id: &str,
64 arguments: serde_json::Value,
65 _partial_tx: Option<mpsc::UnboundedSender<String>>,
66 ) -> SoulResult<ToolOutput> {
67 let user_message = if let Some(text) = arguments.get("text").and_then(|v| v.as_str()) {
68 text.to_string()
69 } else {
70 format!(
71 "Execute tool '{}' with arguments: {}",
72 definition.name,
73 serde_json::to_string_pretty(&arguments).unwrap_or_default()
74 )
75 };
76
77 let request = LlmExecutorRequest {
78 model: self.default_model.clone(),
79 system_prompt: self.default_system_prompt.clone(),
80 user_message,
81 };
82
83 let result = (self.llm_fn)(request).await?;
84 Ok(ToolOutput::success(result))
85 }
86
87 fn executor_name(&self) -> &str {
88 "llm"
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Minimal tool definition shared by all tests.
    fn test_def() -> ToolDefinition {
        ToolDefinition {
            name: "summarize".into(),
            description: "Summarize text".into(),
            input_schema: json!({"type": "object"}),
        }
    }

    // Callback that echoes the user message so tests can assert it was forwarded.
    fn mock_llm_fn() -> LlmFn {
        Arc::new(|req: LlmExecutorRequest| {
            Box::pin(async move { Ok(format!("LLM response to: {}", req.user_message)) })
        })
    }

    // A `text` argument is passed through verbatim as the user message.
    #[tokio::test]
    async fn executes_with_text_argument() {
        let executor = LlmExecutor::new(mock_llm_fn());
        let result = executor
            .execute(
                &test_def(),
                "c1",
                json!({"text": "Please summarize this"}),
                None,
            )
            .await
            .unwrap();
        assert!(result.content.contains("Please summarize this"));
    }

    // Without `text`, the message falls back to describing the tool call
    // (so it should mention the tool name "summarize").
    #[tokio::test]
    async fn executes_without_text_argument() {
        let executor = LlmExecutor::new(mock_llm_fn());
        let result = executor
            .execute(&test_def(), "c1", json!({"data": [1, 2, 3]}), None)
            .await
            .unwrap();
        assert!(result.content.contains("summarize"));
    }

    // Builder-configured defaults reach the callback on every request.
    #[tokio::test]
    async fn passes_model_and_system_prompt() {
        let llm_fn: LlmFn = Arc::new(|req: LlmExecutorRequest| {
            Box::pin(async move {
                Ok(format!(
                    "model={}, sys={}",
                    req.model.unwrap_or_default(),
                    req.system_prompt.unwrap_or_default()
                ))
            })
        });

        let executor = LlmExecutor::new(llm_fn)
            .with_model("haiku")
            .with_system_prompt("Be concise");
        let result = executor
            .execute(&test_def(), "c1", json!({"text": "test"}), None)
            .await
            .unwrap();
        assert!(result.content.contains("model=haiku"));
        assert!(result.content.contains("sys=Be concise"));
    }

    #[test]
    fn executor_name() {
        let executor = LlmExecutor::new(mock_llm_fn());
        assert_eq!(executor.executor_name(), "llm");
    }

    // Compile-time check that the executor can cross thread boundaries.
    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LlmExecutor>();
    }
}