Skip to main content

soul_core/executor/llm.rs

1//! LLM executor — delegates tool calls to an LLM.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use tokio::sync::mpsc;
8
9use crate::error::SoulResult;
10use crate::tool::ToolOutput;
11use crate::types::ToolDefinition;
12
13use super::ToolExecutor;
14
/// A single prompt that [`LlmExecutor`] hands to its LLM callback.
#[derive(Clone, Debug)]
pub struct LlmExecutorRequest {
    /// Model identifier to request, or `None` when no model was configured.
    pub model: Option<String>,
    /// System prompt to send alongside the message, if one was configured.
    pub system_prompt: Option<String>,
    /// The user-turn text the LLM is asked to respond to.
    pub user_message: String,
}
22
23/// Function type for LLM delegation.
24pub type LlmFn = Arc<
25    dyn Fn(LlmExecutorRequest) -> Pin<Box<dyn Future<Output = SoulResult<String>> + Send>>
26        + Send
27        + Sync,
28>;
29
30/// Executes tools by delegating to an LLM.
31pub struct LlmExecutor {
32    llm_fn: LlmFn,
33    default_model: Option<String>,
34    default_system_prompt: Option<String>,
35}
36
37impl LlmExecutor {
38    pub fn new(llm_fn: LlmFn) -> Self {
39        Self {
40            llm_fn,
41            default_model: None,
42            default_system_prompt: None,
43        }
44    }
45
46    pub fn with_model(mut self, model: impl Into<String>) -> Self {
47        self.default_model = Some(model.into());
48        self
49    }
50
51    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
52        self.default_system_prompt = Some(prompt.into());
53        self
54    }
55}
56
57#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
58#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
59impl ToolExecutor for LlmExecutor {
60    async fn execute(
61        &self,
62        definition: &ToolDefinition,
63        _call_id: &str,
64        arguments: serde_json::Value,
65        _partial_tx: Option<mpsc::UnboundedSender<String>>,
66    ) -> SoulResult<ToolOutput> {
67        let user_message = if let Some(text) = arguments.get("text").and_then(|v| v.as_str()) {
68            text.to_string()
69        } else {
70            format!(
71                "Execute tool '{}' with arguments: {}",
72                definition.name,
73                serde_json::to_string_pretty(&arguments).unwrap_or_default()
74            )
75        };
76
77        let request = LlmExecutorRequest {
78            model: self.default_model.clone(),
79            system_prompt: self.default_system_prompt.clone(),
80            user_message,
81        };
82
83        let result = (self.llm_fn)(request).await?;
84        Ok(ToolOutput::success(result))
85    }
86
87    fn executor_name(&self) -> &str {
88        "llm"
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool definition shared by the tests below.
    fn test_def() -> ToolDefinition {
        ToolDefinition {
            name: "summarize".into(),
            description: "Summarize text".into(),
            input_schema: json!({"type": "object"}),
        }
    }

    /// Mock LLM that echoes back the user message it received.
    fn mock_llm_fn() -> LlmFn {
        Arc::new(|req: LlmExecutorRequest| {
            Box::pin(async move { Ok(format!("LLM response to: {}", req.user_message)) })
        })
    }

    /// Mock LLM that reports the raw `model` / `system_prompt` options it received,
    /// so tests can tell `None` apart from `Some("")`.
    fn echo_options_fn() -> LlmFn {
        Arc::new(|req: LlmExecutorRequest| {
            Box::pin(async move {
                Ok(format!(
                    "model={:?}, sys={:?}",
                    req.model, req.system_prompt
                ))
            })
        })
    }

    #[tokio::test]
    async fn executes_with_text_argument() {
        let executor = LlmExecutor::new(mock_llm_fn());
        let result = executor
            .execute(
                &test_def(),
                "c1",
                json!({"text": "Please summarize this"}),
                None,
            )
            .await
            .unwrap();
        // A string `text` argument is forwarded verbatim, not wrapped in the
        // generic "Execute tool" prompt.
        assert!(result
            .content
            .contains("LLM response to: Please summarize this"));
        assert!(!result.content.contains("Execute tool"));
    }

    #[tokio::test]
    async fn executes_without_text_argument() {
        let executor = LlmExecutor::new(mock_llm_fn());
        let result = executor
            .execute(&test_def(), "c1", json!({"data": [1, 2, 3]}), None)
            .await
            .unwrap();
        // The fallback prompt must name the tool AND carry the arguments.
        assert!(result
            .content
            .contains("Execute tool 'summarize' with arguments:"));
        assert!(result.content.contains("\"data\""));
    }

    #[tokio::test]
    async fn non_string_text_falls_back_to_generic_prompt() {
        let executor = LlmExecutor::new(mock_llm_fn());
        let result = executor
            .execute(&test_def(), "c1", json!({"text": 42}), None)
            .await
            .unwrap();
        assert!(result
            .content
            .contains("Execute tool 'summarize' with arguments:"));
        assert!(result.content.contains("\"text\": 42"));
    }

    #[tokio::test]
    async fn passes_model_and_system_prompt() {
        let executor = LlmExecutor::new(echo_options_fn())
            .with_model("haiku")
            .with_system_prompt("Be concise");
        let result = executor
            .execute(&test_def(), "c1", json!({"text": "test"}), None)
            .await
            .unwrap();
        assert!(result.content.contains("model=Some(\"haiku\")"));
        assert!(result.content.contains("sys=Some(\"Be concise\")"));
    }

    #[tokio::test]
    async fn omits_model_and_system_prompt_by_default() {
        let executor = LlmExecutor::new(echo_options_fn());
        let result = executor
            .execute(&test_def(), "c1", json!({"text": "test"}), None)
            .await
            .unwrap();
        assert!(result.content.contains("model=None, sys=None"));
    }

    #[test]
    fn executor_name() {
        let executor = LlmExecutor::new(mock_llm_fn());
        assert_eq!(executor.executor_name(), "llm");
    }

    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LlmExecutor>();
    }
}