Skip to main content

serdes_ai_tools/builtin/
code_execution.rs

//! Code execution tool for running code in a sandbox.
//!
//! This module provides a configurable code execution tool that can
//! execute code in various programming languages.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use std::str::FromStr;
10use std::time::Duration;
11
12use crate::{
13    definition::ToolDefinition,
14    errors::ToolError,
15    return_types::{ToolResult, ToolReturn},
16    schema::SchemaBuilder,
17    tool::Tool,
18    RunContext,
19};
20
21/// Configuration for the code execution tool.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CodeExecutionConfig {
24    /// Maximum execution time.
25    #[serde(with = "humantime_serde")]
26    pub timeout: Duration,
27    /// Maximum output size in bytes.
28    pub max_output_size: usize,
29    /// Allowed languages.
30    pub allowed_languages: Vec<ProgrammingLanguage>,
31    /// Whether to capture stderr.
32    pub capture_stderr: bool,
33    /// Working directory.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub working_dir: Option<String>,
36    /// Environment variables.
37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
38    pub env_vars: Vec<(String, String)>,
39}
40
41impl Default for CodeExecutionConfig {
42    fn default() -> Self {
43        Self {
44            timeout: Duration::from_secs(30),
45            max_output_size: 1024 * 1024, // 1MB
46            allowed_languages: vec![ProgrammingLanguage::Python, ProgrammingLanguage::JavaScript],
47            capture_stderr: true,
48            working_dir: None,
49            env_vars: Vec::new(),
50        }
51    }
52}
53
54impl CodeExecutionConfig {
55    /// Create a new config with defaults.
56    #[must_use]
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Set timeout.
62    #[must_use]
63    pub fn timeout(mut self, timeout: Duration) -> Self {
64        self.timeout = timeout;
65        self
66    }
67
68    /// Set timeout in seconds.
69    #[must_use]
70    pub fn timeout_secs(self, secs: u64) -> Self {
71        self.timeout(Duration::from_secs(secs))
72    }
73
74    /// Set max output size.
75    #[must_use]
76    pub fn max_output_size(mut self, size: usize) -> Self {
77        self.max_output_size = size;
78        self
79    }
80
81    /// Set allowed languages.
82    #[must_use]
83    pub fn allowed_languages(mut self, langs: Vec<ProgrammingLanguage>) -> Self {
84        self.allowed_languages = langs;
85        self
86    }
87
88    /// Add an allowed language.
89    #[must_use]
90    pub fn allow_language(mut self, lang: ProgrammingLanguage) -> Self {
91        if !self.allowed_languages.contains(&lang) {
92            self.allowed_languages.push(lang);
93        }
94        self
95    }
96
97    /// Set capture stderr.
98    #[must_use]
99    pub fn capture_stderr(mut self, capture: bool) -> Self {
100        self.capture_stderr = capture;
101        self
102    }
103
104    /// Add an environment variable.
105    #[must_use]
106    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
107        self.env_vars.push((key.into(), value.into()));
108        self
109    }
110}
111
112/// Programming languages supported for execution.
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "lowercase")]
115pub enum ProgrammingLanguage {
116    /// Python 3.
117    Python,
118    /// JavaScript (Node.js).
119    JavaScript,
120    /// TypeScript.
121    TypeScript,
122    /// Ruby.
123    Ruby,
124    /// Go.
125    Go,
126    /// Rust.
127    Rust,
128    /// Shell/Bash.
129    Shell,
130    /// SQL.
131    Sql,
132}
133
134impl ProgrammingLanguage {
135    /// Get the language name as a string.
136    #[must_use]
137    pub fn as_str(&self) -> &'static str {
138        match self {
139            Self::Python => "python",
140            Self::JavaScript => "javascript",
141            Self::TypeScript => "typescript",
142            Self::Ruby => "ruby",
143            Self::Go => "go",
144            Self::Rust => "rust",
145            Self::Shell => "shell",
146            Self::Sql => "sql",
147        }
148    }
149
150    /// Get all language names for schema enum.
151    #[must_use]
152    pub fn all_names() -> &'static [&'static str] {
153        &[
154            "python",
155            "javascript",
156            "typescript",
157            "ruby",
158            "go",
159            "rust",
160            "shell",
161            "sql",
162        ]
163    }
164}
165
166impl std::fmt::Display for ProgrammingLanguage {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(f, "{}", self.as_str())
169    }
170}
171
172impl std::str::FromStr for ProgrammingLanguage {
173    type Err = String;
174
175    fn from_str(s: &str) -> Result<Self, Self::Err> {
176        match s.to_lowercase().as_str() {
177            "python" | "py" => Ok(Self::Python),
178            "javascript" | "js" => Ok(Self::JavaScript),
179            "typescript" | "ts" => Ok(Self::TypeScript),
180            "ruby" | "rb" => Ok(Self::Ruby),
181            "go" | "golang" => Ok(Self::Go),
182            "rust" | "rs" => Ok(Self::Rust),
183            "shell" | "bash" | "sh" => Ok(Self::Shell),
184            "sql" => Ok(Self::Sql),
185            _ => Err(format!("Unknown language: {}", s)),
186        }
187    }
188}
189
190/// Result of code execution.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ExecutionResult {
193    /// Standard output.
194    pub stdout: String,
195    /// Standard error.
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub stderr: Option<String>,
198    /// Exit code.
199    pub exit_code: i32,
200    /// Execution time in milliseconds.
201    pub execution_time_ms: u64,
202    /// Whether execution timed out.
203    pub timed_out: bool,
204}
205
206impl ExecutionResult {
207    /// Check if execution was successful.
208    #[must_use]
209    pub fn is_success(&self) -> bool {
210        self.exit_code == 0 && !self.timed_out
211    }
212}
213
214/// Code execution tool.
215///
216/// This tool allows agents to execute code in a sandboxed environment.
217/// It requires integration with an external code execution service.
218///
219/// # Safety
220///
221/// Code execution is inherently dangerous. This tool should:
222/// - Always run in a sandboxed environment
223/// - Have strict resource limits
224/// - Only be used with trusted agents
225///
226/// # Example
227///
228/// ```ignore
229/// use serdes_ai_tools::builtin::{CodeExecutionTool, CodeExecutionConfig, ProgrammingLanguage};
230///
231/// let tool = CodeExecutionTool::with_config(
232///     CodeExecutionConfig::new()
233///         .timeout_secs(10)
234///         .allowed_languages(vec![ProgrammingLanguage::Python])
235/// );
236/// ```
237pub struct CodeExecutionTool {
238    config: CodeExecutionConfig,
239}
240
241impl CodeExecutionTool {
242    /// Create a new code execution tool with default config.
243    #[must_use]
244    pub fn new() -> Self {
245        Self {
246            config: CodeExecutionConfig::default(),
247        }
248    }
249
250    /// Create with a specific config.
251    #[must_use]
252    pub fn with_config(config: CodeExecutionConfig) -> Self {
253        Self { config }
254    }
255
256    /// Get the tool schema.
257    fn schema(&self) -> JsonValue {
258        let lang_names: Vec<&str> = self
259            .config
260            .allowed_languages
261            .iter()
262            .map(|l| l.as_str())
263            .collect();
264
265        SchemaBuilder::new()
266            .enum_values(
267                "language",
268                "The programming language to execute",
269                &lang_names,
270                true,
271            )
272            .string("code", "The code to execute", true)
273            .string(
274                "stdin",
275                "Optional input to provide to the program via stdin",
276                false,
277            )
278            .build()
279            .expect("SchemaBuilder JSON serialization failed")
280    }
281
282    /// Execute code (stub - integrate with actual sandbox).
283    async fn execute(
284        &self,
285        language: ProgrammingLanguage,
286        code: &str,
287        _stdin: Option<&str>,
288    ) -> ExecutionResult {
289        // This is a stub implementation.
290        // In a real implementation, you would:
291        // 1. Send the code to a sandbox service (e.g., Docker, Firecracker, etc.)
292        // 2. Execute with proper resource limits
293        // 3. Capture output and handle timeouts
294
295        ExecutionResult {
296            stdout: format!(
297                "[Placeholder] Would execute {} code:\n{}\n\n\
298                 Integrate with a sandbox service for real execution.",
299                language, code
300            ),
301            stderr: None,
302            exit_code: 0,
303            execution_time_ms: 0,
304            timed_out: false,
305        }
306    }
307}
308
309impl Default for CodeExecutionTool {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315#[async_trait]
316impl<Deps: Send + Sync> Tool<Deps> for CodeExecutionTool {
317    fn definition(&self) -> ToolDefinition {
318        ToolDefinition::new("code_execution", "Execute code in a sandboxed environment")
319            .with_parameters(self.schema())
320    }
321
322    async fn call(&self, _ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
323        let language_str = args
324            .get("language")
325            .and_then(|v| v.as_str())
326            .ok_or_else(|| {
327                ToolError::validation_error(
328                    "code_execution",
329                    Some("language".to_string()),
330                    "Missing 'language' field",
331                )
332            })?;
333
334        let language = ProgrammingLanguage::from_str(language_str).map_err(|_| {
335            ToolError::validation_error(
336                "code_execution",
337                Some("language".to_string()),
338                format!("Unknown language: {}", language_str),
339            )
340        })?;
341
342        if !self.config.allowed_languages.contains(&language) {
343            return Err(ToolError::validation_error(
344                "code_execution",
345                Some("language".to_string()),
346                format!(
347                    "Language '{}' is not allowed. Allowed: {:?}",
348                    language, self.config.allowed_languages
349                ),
350            ));
351        }
352
353        let code = args.get("code").and_then(|v| v.as_str()).ok_or_else(|| {
354            ToolError::validation_error(
355                "code_execution",
356                Some("code".to_string()),
357                "Missing 'code' field",
358            )
359        })?;
360
361        if code.trim().is_empty() {
362            return Err(ToolError::validation_error(
363                "code_execution",
364                Some("code".to_string()),
365                "Code cannot be empty",
366            ));
367        }
368
369        let stdin = args.get("stdin").and_then(|v| v.as_str());
370
371        let result = self.execute(language, code, stdin).await;
372
373        let output = serde_json::json!({
374            "success": result.is_success(),
375            "stdout": result.stdout,
376            "stderr": result.stderr,
377            "exit_code": result.exit_code,
378            "execution_time_ms": result.execution_time_ms,
379            "timed_out": result.timed_out
380        });
381
382        Ok(ToolReturn::json(output))
383    }
384
385    fn max_retries(&self) -> Option<u32> {
386        Some(1)
387    }
388}
389
390impl std::fmt::Debug for CodeExecutionTool {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("CodeExecutionTool")
393            .field("config", &self.config)
394            .finish()
395    }
396}
397
398/// Trait for code execution providers.
399#[allow(async_fn_in_trait)]
400pub trait CodeExecutor: Send + Sync {
401    /// Execute code in a sandbox.
402    async fn execute(
403        &self,
404        language: ProgrammingLanguage,
405        code: &str,
406        stdin: Option<&str>,
407        config: &CodeExecutionConfig,
408    ) -> Result<ExecutionResult, ToolError>;
409}
410
411/// Serde helper for Duration.
412mod humantime_serde {
413    use serde::{Deserialize, Deserializer, Serialize, Serializer};
414    use std::time::Duration;
415
416    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
417    where
418        S: Serializer,
419    {
420        duration.as_secs().serialize(serializer)
421    }
422
423    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
424    where
425        D: Deserializer<'de>,
426    {
427        let secs = u64::deserialize(deserializer)?;
428        Ok(Duration::from_secs(secs))
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters overwrite the corresponding defaults.
    #[test]
    fn test_code_execution_config() {
        let only_python = vec![ProgrammingLanguage::Python];
        let config = CodeExecutionConfig::new()
            .timeout_secs(10)
            .max_output_size(1024)
            .allowed_languages(only_python);

        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.max_output_size, 1024);
        assert_eq!(config.allowed_languages.len(), 1);
    }

    /// Canonical names and aliases parse; unknown names are rejected.
    #[test]
    fn test_programming_language() {
        assert_eq!(ProgrammingLanguage::Python.as_str(), "python");

        let canonical = ProgrammingLanguage::from_str("python");
        assert_eq!(canonical, Ok(ProgrammingLanguage::Python));

        let alias = ProgrammingLanguage::from_str("js");
        assert_eq!(alias, Ok(ProgrammingLanguage::JavaScript));

        let unknown = ProgrammingLanguage::from_str("unknown");
        assert!(unknown.is_err());
    }

    /// The definition carries the tool name and marks `language` and `code`
    /// as required.
    #[test]
    fn test_code_execution_tool_definition() {
        let tool = CodeExecutionTool::new();
        let def = <CodeExecutionTool as Tool<()>>::definition(&tool);
        assert_eq!(def.name, "code_execution");

        let required = def
            .parameters()
            .get("required")
            .and_then(|value| value.as_array())
            .unwrap();
        let is_required =
            |name: &str| required.iter().any(|value| value.as_str() == Some(name));

        assert!(is_required("language"));
        assert!(is_required("code"));
    }

    /// A valid request succeeds and reports `success: true`.
    #[tokio::test]
    async fn test_code_execution_tool_call() {
        let tool = CodeExecutionTool::new();
        let ctx = RunContext::minimal("test");
        let args = serde_json::json!({
            "language": "python",
            "code": "print('hello')"
        });

        let result = tool.call(&ctx, args).await.unwrap();

        assert!(!result.is_error());
        let json = result.as_json().unwrap();
        assert!(json["success"].as_bool().unwrap());
    }

    /// A known language that is not on the allow-list is a validation error.
    #[tokio::test]
    async fn test_code_execution_disallowed_language() {
        let python_only =
            CodeExecutionConfig::new().allowed_languages(vec![ProgrammingLanguage::Python]);
        let tool = CodeExecutionTool::with_config(python_only);
        let ctx = RunContext::minimal("test");
        let args = serde_json::json!({
            "language": "javascript",
            "code": "console.log('hi')"
        });

        let result = tool.call(&ctx, args).await;

        assert!(matches!(result, Err(ToolError::ValidationFailed { .. })));
    }

    /// Omitting `code` is a validation error.
    #[tokio::test]
    async fn test_code_execution_missing_code() {
        let tool = CodeExecutionTool::new();
        let ctx = RunContext::minimal("test");
        let args = serde_json::json!({"language": "python"});

        let result = tool.call(&ctx, args).await;

        assert!(matches!(result, Err(ToolError::ValidationFailed { .. })));
    }

    /// `is_success` needs both a zero exit code and no timeout.
    #[test]
    fn test_execution_result() {
        let clean_exit = ExecutionResult {
            stdout: String::from("output"),
            stderr: None,
            exit_code: 0,
            execution_time_ms: 100,
            timed_out: false,
        };
        assert!(clean_exit.is_success());

        let nonzero_exit = ExecutionResult {
            stdout: String::new(),
            stderr: Some(String::from("error")),
            exit_code: 1,
            execution_time_ms: 100,
            timed_out: false,
        };
        assert!(!nonzero_exit.is_success());

        let timed_out = ExecutionResult {
            stdout: String::new(),
            stderr: None,
            exit_code: 0,
            execution_time_ms: 30000,
            timed_out: true,
        };
        assert!(!timed_out.is_success());
    }
}