1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use std::str::FromStr;
10use std::time::Duration;
11
12use crate::{
13 definition::ToolDefinition,
14 errors::ToolError,
15 return_types::{ToolResult, ToolReturn},
16 schema::SchemaBuilder,
17 tool::Tool,
18 RunContext,
19};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CodeExecutionConfig {
24 #[serde(with = "humantime_serde")]
26 pub timeout: Duration,
27 pub max_output_size: usize,
29 pub allowed_languages: Vec<ProgrammingLanguage>,
31 pub capture_stderr: bool,
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub working_dir: Option<String>,
36 #[serde(skip_serializing_if = "Vec::is_empty", default)]
38 pub env_vars: Vec<(String, String)>,
39}
40
41impl Default for CodeExecutionConfig {
42 fn default() -> Self {
43 Self {
44 timeout: Duration::from_secs(30),
45 max_output_size: 1024 * 1024, allowed_languages: vec![ProgrammingLanguage::Python, ProgrammingLanguage::JavaScript],
47 capture_stderr: true,
48 working_dir: None,
49 env_vars: Vec::new(),
50 }
51 }
52}
53
54impl CodeExecutionConfig {
55 #[must_use]
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 #[must_use]
63 pub fn timeout(mut self, timeout: Duration) -> Self {
64 self.timeout = timeout;
65 self
66 }
67
68 #[must_use]
70 pub fn timeout_secs(self, secs: u64) -> Self {
71 self.timeout(Duration::from_secs(secs))
72 }
73
74 #[must_use]
76 pub fn max_output_size(mut self, size: usize) -> Self {
77 self.max_output_size = size;
78 self
79 }
80
81 #[must_use]
83 pub fn allowed_languages(mut self, langs: Vec<ProgrammingLanguage>) -> Self {
84 self.allowed_languages = langs;
85 self
86 }
87
88 #[must_use]
90 pub fn allow_language(mut self, lang: ProgrammingLanguage) -> Self {
91 if !self.allowed_languages.contains(&lang) {
92 self.allowed_languages.push(lang);
93 }
94 self
95 }
96
97 #[must_use]
99 pub fn capture_stderr(mut self, capture: bool) -> Self {
100 self.capture_stderr = capture;
101 self
102 }
103
104 #[must_use]
106 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
107 self.env_vars.push((key.into(), value.into()));
108 self
109 }
110}
111
/// A language the code-execution tool can be asked to run.
///
/// Serde uses the lowercase variant name (`"python"`, `"javascript"`, …),
/// matching [`ProgrammingLanguage::as_str`]. The shorter aliases accepted by
/// `FromStr` (`"py"`, `"js"`, `"bash"`, …) are NOT accepted by `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProgrammingLanguage {
    /// `"python"`
    Python,
    /// `"javascript"`
    JavaScript,
    /// `"typescript"`
    TypeScript,
    /// `"ruby"`
    Ruby,
    /// `"go"`
    Go,
    /// `"rust"`
    Rust,
    /// `"shell"`
    Shell,
    /// `"sql"`
    Sql,
}
133
134impl ProgrammingLanguage {
135 #[must_use]
137 pub fn as_str(&self) -> &'static str {
138 match self {
139 Self::Python => "python",
140 Self::JavaScript => "javascript",
141 Self::TypeScript => "typescript",
142 Self::Ruby => "ruby",
143 Self::Go => "go",
144 Self::Rust => "rust",
145 Self::Shell => "shell",
146 Self::Sql => "sql",
147 }
148 }
149
150 #[must_use]
152 pub fn all_names() -> &'static [&'static str] {
153 &[
154 "python",
155 "javascript",
156 "typescript",
157 "ruby",
158 "go",
159 "rust",
160 "shell",
161 "sql",
162 ]
163 }
164}
165
166impl std::fmt::Display for ProgrammingLanguage {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 write!(f, "{}", self.as_str())
169 }
170}
171
172impl std::str::FromStr for ProgrammingLanguage {
173 type Err = String;
174
175 fn from_str(s: &str) -> Result<Self, Self::Err> {
176 match s.to_lowercase().as_str() {
177 "python" | "py" => Ok(Self::Python),
178 "javascript" | "js" => Ok(Self::JavaScript),
179 "typescript" | "ts" => Ok(Self::TypeScript),
180 "ruby" | "rb" => Ok(Self::Ruby),
181 "go" | "golang" => Ok(Self::Go),
182 "rust" | "rs" => Ok(Self::Rust),
183 "shell" | "bash" | "sh" => Ok(Self::Shell),
184 "sql" => Ok(Self::Sql),
185 _ => Err(format!("Unknown language: {}", s)),
186 }
187 }
188}
189
/// Outcome of running one code snippet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
    /// Text written to standard output. The placeholder executor puts its
    /// "[Placeholder] ..." notice here.
    pub stdout: String,
    /// Standard error text, if any; omitted from serialized output when `None`.
    /// NOTE(review): presumably `None` when `capture_stderr` is off — depends on
    /// the executor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stderr: Option<String>,
    /// Program exit code; only `0` counts as success in `is_success`.
    pub exit_code: i32,
    /// Execution duration in milliseconds.
    pub execution_time_ms: u64,
    /// Whether the run hit the timeout; a timed-out run is never a success.
    pub timed_out: bool,
}
205
206impl ExecutionResult {
207 #[must_use]
209 pub fn is_success(&self) -> bool {
210 self.exit_code == 0 && !self.timed_out
211 }
212}
213
/// Tool (exposed as `code_execution`) that runs a code snippet in one of the
/// configured languages.
///
/// NOTE(review): execution is currently a placeholder that echoes the code back
/// instead of running it; a real sandbox backend still has to be integrated.
pub struct CodeExecutionTool {
    // Settings for this tool; `allowed_languages` drives both the parameter
    // schema and the validation performed on each call.
    config: CodeExecutionConfig,
}
240
241impl CodeExecutionTool {
242 #[must_use]
244 pub fn new() -> Self {
245 Self {
246 config: CodeExecutionConfig::default(),
247 }
248 }
249
250 #[must_use]
252 pub fn with_config(config: CodeExecutionConfig) -> Self {
253 Self { config }
254 }
255
256 fn schema(&self) -> JsonValue {
258 let lang_names: Vec<&str> = self
259 .config
260 .allowed_languages
261 .iter()
262 .map(|l| l.as_str())
263 .collect();
264
265 SchemaBuilder::new()
266 .enum_values(
267 "language",
268 "The programming language to execute",
269 &lang_names,
270 true,
271 )
272 .string("code", "The code to execute", true)
273 .string(
274 "stdin",
275 "Optional input to provide to the program via stdin",
276 false,
277 )
278 .build()
279 .expect("SchemaBuilder JSON serialization failed")
280 }
281
282 async fn execute(
284 &self,
285 language: ProgrammingLanguage,
286 code: &str,
287 _stdin: Option<&str>,
288 ) -> ExecutionResult {
289 ExecutionResult {
296 stdout: format!(
297 "[Placeholder] Would execute {} code:\n{}\n\n\
298 Integrate with a sandbox service for real execution.",
299 language, code
300 ),
301 stderr: None,
302 exit_code: 0,
303 execution_time_ms: 0,
304 timed_out: false,
305 }
306 }
307}
308
309impl Default for CodeExecutionTool {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
315#[async_trait]
316impl<Deps: Send + Sync> Tool<Deps> for CodeExecutionTool {
317 fn definition(&self) -> ToolDefinition {
318 ToolDefinition::new("code_execution", "Execute code in a sandboxed environment")
319 .with_parameters(self.schema())
320 }
321
322 async fn call(&self, _ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
323 let language_str = args
324 .get("language")
325 .and_then(|v| v.as_str())
326 .ok_or_else(|| {
327 ToolError::validation_error(
328 "code_execution",
329 Some("language".to_string()),
330 "Missing 'language' field",
331 )
332 })?;
333
334 let language = ProgrammingLanguage::from_str(language_str).map_err(|_| {
335 ToolError::validation_error(
336 "code_execution",
337 Some("language".to_string()),
338 format!("Unknown language: {}", language_str),
339 )
340 })?;
341
342 if !self.config.allowed_languages.contains(&language) {
343 return Err(ToolError::validation_error(
344 "code_execution",
345 Some("language".to_string()),
346 format!(
347 "Language '{}' is not allowed. Allowed: {:?}",
348 language, self.config.allowed_languages
349 ),
350 ));
351 }
352
353 let code = args.get("code").and_then(|v| v.as_str()).ok_or_else(|| {
354 ToolError::validation_error(
355 "code_execution",
356 Some("code".to_string()),
357 "Missing 'code' field",
358 )
359 })?;
360
361 if code.trim().is_empty() {
362 return Err(ToolError::validation_error(
363 "code_execution",
364 Some("code".to_string()),
365 "Code cannot be empty",
366 ));
367 }
368
369 let stdin = args.get("stdin").and_then(|v| v.as_str());
370
371 let result = self.execute(language, code, stdin).await;
372
373 let output = serde_json::json!({
374 "success": result.is_success(),
375 "stdout": result.stdout,
376 "stderr": result.stderr,
377 "exit_code": result.exit_code,
378 "execution_time_ms": result.execution_time_ms,
379 "timed_out": result.timed_out
380 });
381
382 Ok(ToolReturn::json(output))
383 }
384
385 fn max_retries(&self) -> Option<u32> {
386 Some(1)
387 }
388}
389
390impl std::fmt::Debug for CodeExecutionTool {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.debug_struct("CodeExecutionTool")
393 .field("config", &self.config)
394 .finish()
395 }
396}
397
/// Pluggable backend that actually runs code (for example, a sandbox service).
///
/// NOTE(review): nothing in this file implements or consumes this trait yet;
/// `CodeExecutionTool` still uses its built-in placeholder executor.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned future to be `Send`; it is silenced here.
// NOTE(review): if executors must be awaited from `Send` contexts (such as an
// `#[async_trait]` `Tool::call`), consider declaring the method as
// `fn execute(..) -> impl Future<Output = ..> + Send` instead.
#[allow(async_fn_in_trait)]
pub trait CodeExecutor: Send + Sync {
    /// Runs `code` written in `language`, optionally feeding `stdin`, under the
    /// settings in `config`.
    ///
    /// # Errors
    ///
    /// Implementation-defined `ToolError`s.
    async fn execute(
        &self,
        language: ProgrammingLanguage,
        code: &str,
        stdin: Option<&str>,
        config: &CodeExecutionConfig,
    ) -> Result<ExecutionResult, ToolError>;
}
410
411mod humantime_serde {
413 use serde::{Deserialize, Deserializer, Serialize, Serializer};
414 use std::time::Duration;
415
416 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
417 where
418 S: Serializer,
419 {
420 duration.as_secs().serialize(serializer)
421 }
422
423 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
424 where
425 D: Deserializer<'de>,
426 {
427 let secs = u64::deserialize(deserializer)?;
428 Ok(Duration::from_secs(secs))
429 }
430}
431
// Unit tests for the configuration builder, language parsing, the tool's
// schema and argument validation, and `ExecutionResult::is_success`.
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods overwrite the corresponding configuration fields.
    #[test]
    fn test_code_execution_config() {
        let config = CodeExecutionConfig::new()
            .timeout_secs(10)
            .max_output_size(1024)
            .allowed_languages(vec![ProgrammingLanguage::Python]);

        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.max_output_size, 1024);
        assert_eq!(config.allowed_languages.len(), 1);
    }

    // Canonical names and aliases parse; unknown names are rejected.
    #[test]
    fn test_programming_language() {
        assert_eq!(ProgrammingLanguage::Python.as_str(), "python");
        assert_eq!(
            ProgrammingLanguage::from_str("python"),
            Ok(ProgrammingLanguage::Python)
        );
        assert_eq!(
            ProgrammingLanguage::from_str("js"),
            Ok(ProgrammingLanguage::JavaScript)
        );
        assert!(ProgrammingLanguage::from_str("unknown").is_err());
    }

    // The generated schema marks both `language` and `code` as required.
    #[test]
    fn test_code_execution_tool_definition() {
        let tool = CodeExecutionTool::new();
        // `Tool` is generic over `Deps`; `()` is enough to reach `definition`.
        let def = <CodeExecutionTool as Tool<()>>::definition(&tool);
        assert_eq!(def.name, "code_execution");
        let required = def
            .parameters()
            .get("required")
            .and_then(|value| value.as_array())
            .unwrap();
        assert!(required
            .iter()
            .any(|value| value.as_str() == Some("language")));
        assert!(required.iter().any(|value| value.as_str() == Some("code")));
    }

    // A valid call returns a successful JSON payload from the placeholder executor.
    #[tokio::test]
    async fn test_code_execution_tool_call() {
        let tool = CodeExecutionTool::new();
        let ctx = RunContext::minimal("test");

        let result = tool
            .call(
                &ctx,
                serde_json::json!({
                    "language": "python",
                    "code": "print('hello')"
                }),
            )
            .await
            .unwrap();

        assert!(!result.is_error());
        let json = result.as_json().unwrap();
        assert!(json["success"].as_bool().unwrap());
    }

    // A known language outside `allowed_languages` is a validation failure.
    #[tokio::test]
    async fn test_code_execution_disallowed_language() {
        let tool = CodeExecutionTool::with_config(
            CodeExecutionConfig::new().allowed_languages(vec![ProgrammingLanguage::Python]),
        );
        let ctx = RunContext::minimal("test");

        let result = tool
            .call(
                &ctx,
                serde_json::json!({
                    "language": "javascript",
                    "code": "console.log('hi')"
                }),
            )
            .await;

        assert!(matches!(result, Err(ToolError::ValidationFailed { .. })));
    }

    // Omitting `code` is a validation failure.
    #[tokio::test]
    async fn test_code_execution_missing_code() {
        let tool = CodeExecutionTool::new();
        let ctx = RunContext::minimal("test");

        let result = tool
            .call(&ctx, serde_json::json!({"language": "python"}))
            .await;

        assert!(matches!(result, Err(ToolError::ValidationFailed { .. })));
    }

    // Success requires exit code 0 and no timeout.
    #[test]
    fn test_execution_result() {
        let success = ExecutionResult {
            stdout: "output".to_string(),
            stderr: None,
            exit_code: 0,
            execution_time_ms: 100,
            timed_out: false,
        };
        assert!(success.is_success());

        let failure = ExecutionResult {
            stdout: "".to_string(),
            stderr: Some("error".to_string()),
            exit_code: 1,
            execution_time_ms: 100,
            timed_out: false,
        };
        assert!(!failure.is_success());

        // Exit code 0 still fails when the run timed out.
        let timeout = ExecutionResult {
            stdout: "".to_string(),
            stderr: None,
            exit_code: 0,
            execution_time_ms: 30000,
            timed_out: true,
        };
        assert!(!timeout.is_success());
    }
}