1pub mod crypto;
8pub mod functions;
9
10use std::sync::{Arc, Mutex};
11
12use mlua::{Lua, MultiValue, Result as LuaResult, Value};
13
14use crate::error::{SoulError, SoulResult};
15
16pub struct LlmQueryChannel {
21 pub request_tx: std::sync::mpsc::Sender<LlmSubRequest>,
22 pub response_rx: Arc<Mutex<std::sync::mpsc::Receiver<String>>>,
23}
24
/// A sub-query issued from Lua through an `LlmQueryChannel`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmSubRequest {
    /// A single prompt from `llm_query`; the Lua side waits for exactly one response.
    Single(String),
    /// Prompts from `llm_query_batched`; the Lua side waits for one response per
    /// prompt and stores them in the same order.
    Batched(Vec<String>),
}
32
/// A Lua interpreter with selected host-access globals removed and output-capture
/// helpers (`print`, `print_var`, `final_answer`, `final_var`) installed.
pub struct LuaSandbox {
    /// The underlying Lua state that all scripts run in.
    lua: Lua,
    /// Lines captured from `print` / `print_var`; drained by `take_output`.
    output_buffer: Arc<Mutex<Vec<String>>>,
    /// Value recorded by `final_answer` / `final_var`; removed by `take_final_answer`.
    final_answer: Arc<Mutex<Option<String>>>,
}
38}
39
40impl LuaSandbox {
41 pub fn new() -> SoulResult<Self> {
46 let lua = Lua::new();
47
48 {
50 let globals = lua.globals();
51 for name in &[
52 "io",
53 "os",
54 "debug",
55 "require",
56 "loadfile",
57 "dofile",
58 "load",
59 ] {
60 globals
61 .set(*name, Value::Nil)
62 .map_err(|e| SoulError::ToolExecution {
63 tool_name: "lua_sandbox".into(),
64 message: format!("Failed to remove {name}: {e}"),
65 })?;
66 }
67 }
68
69 let output_buffer = Arc::new(Mutex::new(Vec::new()));
70 let final_answer = Arc::new(Mutex::new(None));
71
72 let buf = output_buffer.clone();
74 lua.globals()
75 .set(
76 "print_var",
77 lua.create_function(move |_, val: Value| {
78 let s = lua_value_to_string(&val);
79 buf.lock().unwrap().push(s);
80 Ok(())
81 })
82 .map_err(|e| SoulError::ToolExecution {
83 tool_name: "lua_sandbox".into(),
84 message: format!("Failed to register print_var: {e}"),
85 })?,
86 )
87 .map_err(|e| SoulError::ToolExecution {
88 tool_name: "lua_sandbox".into(),
89 message: format!("Failed to set print_var: {e}"),
90 })?;
91
92 let buf2 = output_buffer.clone();
94 lua.globals()
95 .set(
96 "print",
97 lua.create_function(move |_, vals: MultiValue| {
98 let parts: Vec<String> = vals.iter().map(lua_value_to_string).collect();
99 buf2.lock().unwrap().push(parts.join("\t"));
100 Ok(())
101 })
102 .map_err(|e| SoulError::ToolExecution {
103 tool_name: "lua_sandbox".into(),
104 message: format!("Failed to register print: {e}"),
105 })?,
106 )
107 .map_err(|e| SoulError::ToolExecution {
108 tool_name: "lua_sandbox".into(),
109 message: format!("Failed to set print: {e}"),
110 })?;
111
112 let fa = final_answer.clone();
114 lua.globals()
115 .set(
116 "final_answer",
117 lua.create_function(move |_, val: Value| {
118 let s = lua_value_to_string(&val);
119 *fa.lock().unwrap() = Some(s);
120 Ok(())
121 })
122 .map_err(|e| SoulError::ToolExecution {
123 tool_name: "lua_sandbox".into(),
124 message: format!("Failed to register final_answer: {e}"),
125 })?,
126 )
127 .map_err(|e| SoulError::ToolExecution {
128 tool_name: "lua_sandbox".into(),
129 message: format!("Failed to set final_answer: {e}"),
130 })?;
131
132 let fa2 = final_answer.clone();
134 lua.globals()
135 .set(
136 "final_var",
137 lua.create_function(move |lua, varname: String| {
138 let val: Value = lua.globals().get(varname.clone()).map_err(|e| {
139 mlua::Error::external(format!("final_var: variable '{varname}' not found: {e}"))
140 })?;
141 let s = lua_value_to_string(&val);
142 *fa2.lock().unwrap() = Some(s);
143 Ok(())
144 })
145 .map_err(|e| SoulError::ToolExecution {
146 tool_name: "lua_sandbox".into(),
147 message: format!("Failed to register final_var: {e}"),
148 })?,
149 )
150 .map_err(|e| SoulError::ToolExecution {
151 tool_name: "lua_sandbox".into(),
152 message: format!("Failed to set final_var: {e}"),
153 })?;
154
155 Ok(Self {
156 lua,
157 output_buffer,
158 final_answer,
159 })
160 }
161
162 pub fn lua(&self) -> &Lua {
164 &self.lua
165 }
166
167 pub fn set_string(&self, name: &str, value: &str) -> SoulResult<()> {
169 self.lua
170 .globals()
171 .set(name.to_string(), value.to_string())
172 .map_err(|e| SoulError::ToolExecution {
173 tool_name: "lua_sandbox".into(),
174 message: format!("Failed to set string '{name}': {e}"),
175 })
176 }
177
178 pub fn set_json(&self, name: &str, value: &serde_json::Value) -> SoulResult<()> {
180 let lua_val =
181 json_to_lua(&self.lua, value).map_err(|e| SoulError::ToolExecution {
182 tool_name: "lua_sandbox".into(),
183 message: format!("Failed to convert JSON to Lua for '{name}': {e}"),
184 })?;
185 self.lua
186 .globals()
187 .set(name.to_string(), lua_val)
188 .map_err(|e| SoulError::ToolExecution {
189 tool_name: "lua_sandbox".into(),
190 message: format!("Failed to set JSON '{name}': {e}"),
191 })
192 }
193
194 pub fn exec(&self, code: &str) -> SoulResult<String> {
196 let result: Value = self
197 .lua
198 .load(code)
199 .eval()
200 .map_err(|e| SoulError::ToolExecution {
201 tool_name: "lua_sandbox".into(),
202 message: format!("Lua execution error: {e}"),
203 })?;
204
205 Ok(lua_value_to_string(&result))
206 }
207
208 pub async fn exec_with_timeout(&self, code: &str, timeout_secs: u64) -> SoulResult<String> {
210 let timeout = std::time::Duration::from_secs(timeout_secs);
211
212 let code = code.to_string();
214 let lua_code = code.clone();
215
216 match tokio::time::timeout(timeout, async { self.exec(&lua_code) }).await {
218 Ok(result) => result,
219 Err(_) => Err(SoulError::ToolExecution {
220 tool_name: "lua_sandbox".into(),
221 message: format!("Lua execution timed out after {timeout_secs}s"),
222 }),
223 }
224 }
225
226 pub fn take_output(&self) -> Vec<String> {
228 std::mem::take(&mut self.output_buffer.lock().unwrap())
229 }
230
231 pub fn take_final_answer(&self) -> Option<String> {
233 self.final_answer.lock().unwrap().take()
234 }
235
236 pub fn has_final_answer(&self) -> bool {
238 self.final_answer.lock().unwrap().is_some()
239 }
240
241 pub fn get_global_string(&self, name: &str) -> SoulResult<Option<String>> {
246 let val: Value = self.lua.globals().get(name.to_string()).map_err(|e| {
247 SoulError::ToolExecution {
248 tool_name: "lua_sandbox".into(),
249 message: format!("Failed to read global '{name}': {e}"),
250 }
251 })?;
252 match val {
253 Value::Nil => Ok(None),
254 _ => Ok(Some(lua_value_to_string(&val))),
255 }
256 }
257
258 pub fn register_llm_query(&self, channel: LlmQueryChannel) -> SoulResult<()> {
264 let req_tx = channel.request_tx.clone();
265 let resp_rx = channel.response_rx.clone();
266
267 let req_tx_single = req_tx.clone();
269 let resp_rx_single = resp_rx.clone();
270 self.lua
271 .globals()
272 .set(
273 "llm_query",
274 self.lua
275 .create_function(move |_, prompt: String| {
276 req_tx_single
277 .send(LlmSubRequest::Single(prompt))
278 .map_err(|e| {
279 mlua::Error::external(format!("llm_query channel error: {e}"))
280 })?;
281 let response = resp_rx_single
282 .lock()
283 .unwrap()
284 .recv()
285 .map_err(|e| {
286 mlua::Error::external(format!("llm_query response error: {e}"))
287 })?;
288 Ok(response)
289 })
290 .map_err(|e| SoulError::ToolExecution {
291 tool_name: "lua_sandbox".into(),
292 message: format!("Failed to create llm_query: {e}"),
293 })?,
294 )
295 .map_err(|e| SoulError::ToolExecution {
296 tool_name: "lua_sandbox".into(),
297 message: format!("Failed to set llm_query: {e}"),
298 })?;
299
300 self.lua
302 .globals()
303 .set(
304 "llm_query_batched",
305 self.lua
306 .create_function(move |lua, prompts: Value| {
307 let table = match prompts {
309 Value::Table(t) => t,
310 _ => {
311 return Err(mlua::Error::external(
312 "llm_query_batched expects a table of strings",
313 ));
314 }
315 };
316 let mut prompt_vec = Vec::new();
317 for pair in table.sequence_values::<String>() {
318 prompt_vec.push(pair.map_err(|e| {
319 mlua::Error::external(format!(
320 "llm_query_batched: invalid prompt: {e}"
321 ))
322 })?);
323 }
324
325 req_tx
326 .send(LlmSubRequest::Batched(prompt_vec.clone()))
327 .map_err(|e| {
328 mlua::Error::external(format!(
329 "llm_query_batched channel error: {e}"
330 ))
331 })?;
332
333 let result_table = lua.create_table()?;
335 for i in 0..prompt_vec.len() {
336 let response =
337 resp_rx.lock().unwrap().recv().map_err(|e| {
338 mlua::Error::external(format!(
339 "llm_query_batched response error: {e}"
340 ))
341 })?;
342 result_table.set(i + 1, response)?;
343 }
344 Ok(Value::Table(result_table))
345 })
346 .map_err(|e| SoulError::ToolExecution {
347 tool_name: "lua_sandbox".into(),
348 message: format!("Failed to create llm_query_batched: {e}"),
349 })?,
350 )
351 .map_err(|e| SoulError::ToolExecution {
352 tool_name: "lua_sandbox".into(),
353 message: format!("Failed to set llm_query_batched: {e}"),
354 })?;
355
356 Ok(())
357 }
358}
359
360pub fn json_to_lua(lua: &Lua, value: &serde_json::Value) -> LuaResult<Value> {
364 match value {
365 serde_json::Value::Null => Ok(Value::Nil),
366 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
367 serde_json::Value::Number(n) => {
368 if let Some(i) = n.as_i64() {
369 Ok(Value::Integer(i))
370 } else if let Some(f) = n.as_f64() {
371 Ok(Value::Number(f))
372 } else {
373 Ok(Value::Nil)
374 }
375 }
376 serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
377 serde_json::Value::Array(arr) => {
378 let table = lua.create_table()?;
379 for (i, item) in arr.iter().enumerate() {
380 table.set(i + 1, json_to_lua(lua, item)?)?; }
382 Ok(Value::Table(table))
383 }
384 serde_json::Value::Object(obj) => {
385 let table = lua.create_table()?;
386 for (key, val) in obj {
387 table.set(key.as_str(), json_to_lua(lua, val)?)?;
388 }
389 Ok(Value::Table(table))
390 }
391 }
392}
393
394pub fn lua_to_json(value: &Value) -> serde_json::Value {
396 match value {
397 Value::Nil => serde_json::Value::Null,
398 Value::Boolean(b) => serde_json::Value::Bool(*b),
399 Value::Integer(i) => serde_json::json!(*i),
400 Value::Number(f) => serde_json::json!(*f),
401 Value::String(s) => {
402 serde_json::Value::String(
403 s.to_str().map(|s| s.to_string()).unwrap_or_default(),
404 )
405 }
406 Value::Table(t) => {
407 let len = t.raw_len();
409 if len > 0 {
410 let mut is_array = true;
412 for i in 1..=len {
413 if t.raw_get::<Value>(i).is_err() {
414 is_array = false;
415 break;
416 }
417 }
418
419 if is_array {
420 let mut total_pairs = 0;
422 if let Ok(pairs) = t.clone().pairs::<Value, Value>().try_fold(0, |acc, r| {
423 r.map(|_| acc + 1)
424 }) {
425 total_pairs = pairs;
426 }
427 if total_pairs != len as usize {
428 is_array = false;
429 }
430 }
431
432 if is_array {
433 let arr: Vec<serde_json::Value> = (1..=len)
434 .filter_map(|i| t.raw_get::<Value>(i).ok().map(|v| lua_to_json(&v)))
435 .collect();
436 return serde_json::Value::Array(arr);
437 }
438 }
439
440 let mut map = serde_json::Map::new();
442 if let Ok(pairs) = t.clone().pairs::<Value, Value>().collect::<LuaResult<Vec<_>>>()
443 {
444 for (k, v) in pairs {
445 let key = lua_value_to_string(&k);
446 map.insert(key, lua_to_json(&v));
447 }
448 }
449 serde_json::Value::Object(map)
450 }
451 _ => serde_json::Value::Null,
452 }
453}
454
455pub(crate) fn lua_value_to_string(val: &Value) -> String {
457 match val {
458 Value::Nil => "nil".to_string(),
459 Value::Boolean(b) => b.to_string(),
460 Value::Integer(i) => i.to_string(),
461 Value::Number(f) => f.to_string(),
462 Value::String(s) => s.to_str().map(|s| s.to_string()).unwrap_or_default(),
463 Value::Table(t) => {
464 let json = lua_to_json(&Value::Table(t.clone()));
466 serde_json::to_string_pretty(&json).unwrap_or_else(|_| "table: <?>".to_string())
467 }
468 _ => format!("{val:?}"),
469 }
470}
471
/// Extracts the bodies of all ```` ```lua ```` fenced code blocks in `text`.
///
/// A block opens on any line whose trimmed text starts with ```` ```lua ```` and
/// closes on the next line whose trimmed text is exactly ```` ``` ````. Body lines
/// keep their original indentation and each is terminated with `\n`. Fences for other
/// languages are ignored, and a block still open when `text` ends is discarded.
pub fn extract_lua_blocks(text: &str) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    // `Some(body)` while inside a lua fence, accumulating its lines.
    let mut open: Option<String> = None;

    for line in text.lines() {
        let fence = line.trim();
        if let Some(body) = open.as_mut() {
            if fence != "```" {
                body.push_str(line);
                body.push('\n');
                continue;
            }
            found.extend(open.take());
        } else if fence.starts_with("```lua") {
            open = Some(String::new());
        }
    }

    found
}
494
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn sandbox_creation() {
        let sandbox = LuaSandbox::new().unwrap();
        let result = sandbox.exec("return 2 + 2").unwrap();
        assert_eq!(result, "4");
    }

    #[test]
    fn sandbox_removes_dangerous_modules() {
        let sandbox = LuaSandbox::new().unwrap();

        let result = sandbox.exec("return type(io)").unwrap();
        assert_eq!(result, "nil");

        let result = sandbox.exec("return type(os)").unwrap();
        assert_eq!(result, "nil");

        let result = sandbox.exec("return type(debug)").unwrap();
        assert_eq!(result, "nil");

        let result = sandbox.exec("return type(require)").unwrap();
        assert_eq!(result, "nil");
    }

    #[test]
    fn set_string_variable() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.set_string("greeting", "hello world").unwrap();
        let result = sandbox.exec("return greeting").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn set_json_object() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .set_json("data", &json!({"name": "test", "count": 42}))
            .unwrap();

        let name = sandbox.exec("return data.name").unwrap();
        assert_eq!(name, "test");

        let count = sandbox.exec("return data.count").unwrap();
        assert_eq!(count, "42");
    }

    #[test]
    fn set_json_array() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .set_json("items", &json!(["a", "b", "c"]))
            .unwrap();

        let first = sandbox.exec("return items[1]").unwrap();
        assert_eq!(first, "a");

        let len = sandbox.exec("return #items").unwrap();
        assert_eq!(len, "3");
    }

    #[test]
    fn set_json_nested() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .set_json(
                "config",
                &json!({"server": {"host": "localhost", "port": 8080}}),
            )
            .unwrap();

        let host = sandbox.exec("return config.server.host").unwrap();
        assert_eq!(host, "localhost");

        let port = sandbox.exec("return config.server.port").unwrap();
        assert_eq!(port, "8080");
    }

    #[test]
    fn print_var_captures_output() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.exec("print_var('hello')").unwrap();
        sandbox.exec("print_var(42)").unwrap();
        let output = sandbox.take_output();
        assert_eq!(output, vec!["hello", "42"]);
    }

    #[test]
    fn print_captures_output() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.exec("print('hello', 'world')").unwrap();
        let output = sandbox.take_output();
        assert_eq!(output, vec!["hello\tworld"]);
    }

    #[test]
    fn take_output_drains_buffer() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.exec("print('a')").unwrap();
        assert_eq!(sandbox.take_output(), vec!["a"]);
        assert!(sandbox.take_output().is_empty());
    }

    #[test]
    fn final_answer_captures() {
        let sandbox = LuaSandbox::new().unwrap();
        assert!(!sandbox.has_final_answer());

        sandbox.exec("final_answer('the answer is 42')").unwrap();
        assert!(sandbox.has_final_answer());

        let answer = sandbox.take_final_answer();
        assert_eq!(answer, Some("the answer is 42".to_string()));
    }

    #[test]
    fn final_var_captures_global() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.exec("result = 'done'").unwrap();
        sandbox.exec("final_var('result')").unwrap();
        assert_eq!(sandbox.take_final_answer(), Some("done".to_string()));
        assert!(!sandbox.has_final_answer());
    }

    #[test]
    fn get_global_string_reads_globals() {
        let sandbox = LuaSandbox::new().unwrap();
        assert_eq!(sandbox.get_global_string("missing").unwrap(), None);

        sandbox.set_string("g", "v").unwrap();
        assert_eq!(sandbox.get_global_string("g").unwrap(), Some("v".to_string()));

        sandbox.exec("cfg = {a = 1}").unwrap();
        assert_eq!(
            sandbox.get_global_string("cfg").unwrap(),
            Some("{\n  \"a\": 1\n}".to_string())
        );
    }

    #[test]
    fn exec_error_returns_soul_error() {
        let sandbox = LuaSandbox::new().unwrap();
        let result = sandbox.exec("return undefined_function()");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn exec_with_timeout_succeeds() {
        let sandbox = LuaSandbox::new().unwrap();
        let result = sandbox.exec_with_timeout("return 'ok'", 5).await.unwrap();
        assert_eq!(result, "ok");
    }

    #[test]
    fn llm_query_roundtrip() {
        let (request_tx, request_rx) = std::sync::mpsc::channel();
        let (response_tx, response_rx) = std::sync::mpsc::channel();
        // Echo responder; detached so the test never waits on channel shutdown.
        std::thread::spawn(move || {
            for request in request_rx {
                match request {
                    LlmSubRequest::Single(prompt) => {
                        response_tx.send(format!("re:{prompt}")).unwrap();
                    }
                    LlmSubRequest::Batched(prompts) => {
                        for prompt in prompts {
                            response_tx.send(format!("re:{prompt}")).unwrap();
                        }
                    }
                }
            }
        });

        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .register_llm_query(LlmQueryChannel {
                request_tx,
                response_rx: Arc::new(Mutex::new(response_rx)),
            })
            .unwrap();

        assert_eq!(sandbox.exec("return llm_query('a')").unwrap(), "re:a");
        assert_eq!(
            sandbox
                .exec("local r = llm_query_batched({'x', 'y'}) return r[1] .. ',' .. r[2]")
                .unwrap(),
            "re:x,re:y"
        );
        assert!(sandbox.exec("return llm_query_batched('x')").is_err());
    }

    #[test]
    fn json_to_lua_roundtrip_object() {
        let sandbox = LuaSandbox::new().unwrap();
        let original = json!({"key": "value", "num": 123, "flag": true});
        sandbox.set_json("data", &original).unwrap();

        let key = sandbox.exec("return data.key").unwrap();
        assert_eq!(key, "value");
        let num = sandbox.exec("return data.num").unwrap();
        assert_eq!(num, "123");
        let flag = sandbox.exec("return tostring(data.flag)").unwrap();
        assert_eq!(flag, "true");
    }

    #[test]
    fn json_to_lua_null() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.set_json("x", &json!(null)).unwrap();
        let result = sandbox.exec("return type(x)").unwrap();
        assert_eq!(result, "nil");
    }

    #[test]
    fn extract_lua_blocks_basic() {
        let text = r#"Some text
```lua
local x = 1
return x
```
More text"#;
        let blocks = extract_lua_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("local x = 1"));
    }

    #[test]
    fn extract_lua_blocks_multiple() {
        let text = r#"
```lua
print("first")
```
middle
```lua
print("second")
```
"#;
        let blocks = extract_lua_blocks(text);
        assert_eq!(blocks.len(), 2);
        assert!(blocks[0].contains("first"));
        assert!(blocks[1].contains("second"));
    }

    #[test]
    fn extract_lua_blocks_none() {
        let blocks = extract_lua_blocks("no code blocks here");
        assert!(blocks.is_empty());
    }

    #[test]
    fn extract_lua_blocks_ignores_other_languages() {
        let text = r#"
```python
print("not lua")
```
```lua
print("lua!")
```
```rust
fn main() {}
```
"#;
        let blocks = extract_lua_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("lua!"));
    }

    #[test]
    fn extract_lua_blocks_drops_unterminated() {
        let blocks = extract_lua_blocks("```lua\nprint('never closed')\n");
        assert!(blocks.is_empty());
    }

    #[test]
    fn lua_string_operations() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.set_string("text", "hello world").unwrap();
        let result = sandbox.exec("return string.upper(text)").unwrap();
        assert_eq!(result, "HELLO WORLD");
    }

    #[test]
    fn lua_table_operations() {
        let sandbox = LuaSandbox::new().unwrap();
        let result = sandbox
            .exec(
                r#"
            local t = {1, 2, 3, 4, 5}
            local sum = 0
            for _, v in ipairs(t) do
                sum = sum + v
            end
            return sum
        "#,
            )
            .unwrap();
        assert_eq!(result, "15");
    }

    #[test]
    fn lua_to_json_array() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .set_json("items", &json!(["a", "b", "c"]))
            .unwrap();
        let lua_val: Value = sandbox.lua.globals().get("items").unwrap();
        let json_val = lua_to_json(&lua_val);
        assert_eq!(json_val, json!(["a", "b", "c"]));
    }

    #[test]
    fn lua_to_json_object() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox
            .set_json("obj", &json!({"name": "test"}))
            .unwrap();
        let lua_val: Value = sandbox.lua.globals().get("obj").unwrap();
        let json_val = lua_to_json(&lua_val);
        assert_eq!(json_val["name"], "test");
    }

    #[test]
    fn lua_to_json_nested() {
        let sandbox = LuaSandbox::new().unwrap();
        let original = json!({"list": [1, 2], "inner": {"k": "v"}});
        sandbox.set_json("cfg", &original).unwrap();
        let lua_val: Value = sandbox.lua.globals().get("cfg").unwrap();
        assert_eq!(lua_to_json(&lua_val), original);
    }

    #[test]
    fn lua_to_json_mixed_table_becomes_object() {
        let sandbox = LuaSandbox::new().unwrap();
        sandbox.exec("mixed = {10, 20, x = 'y'}").unwrap();
        let lua_val: Value = sandbox.lua.globals().get("mixed").unwrap();
        assert_eq!(
            lua_to_json(&lua_val),
            json!({"1": 10, "2": 20, "x": "y"})
        );
    }
}