1use mlua::{Function, Lua, Value};
4
5use super::{json_to_lua, lua_to_json};
6use crate::error::{SoulError, SoulResult};
7
8pub fn register_skill_functions(lua: &Lua) -> SoulResult<()> {
12 let globals = lua.globals();
13
14 globals
16 .set("json_decode", create_json_decode(lua)?)
17 .map_err(lua_err)?;
18
19 globals
21 .set("json_encode", create_json_encode(lua)?)
22 .map_err(lua_err)?;
23
24 globals
26 .set("url_encode", create_url_encode(lua)?)
27 .map_err(lua_err)?;
28
29 globals
31 .set("fetch", create_fetch(lua)?)
32 .map_err(lua_err)?;
33
34 globals
36 .set("fetch_json", create_fetch_json(lua)?)
37 .map_err(lua_err)?;
38
39 globals
41 .set("post", create_post(lua)?)
42 .map_err(lua_err)?;
43
44 super::crypto::register_crypto_functions(lua)?;
46
47 Ok(())
48}
49
50pub fn register_rlm_functions(lua: &Lua) -> SoulResult<()> {
54 let globals = lua.globals();
55
56 globals
58 .set("chunk_by_lines", create_chunk_by_lines(lua)?)
59 .map_err(lua_err)?;
60
61 globals
63 .set("chunk_by_chars", create_chunk_by_chars(lua)?)
64 .map_err(lua_err)?;
65
66 globals
68 .set("chunk_by_regex", create_chunk_by_regex(lua)?)
69 .map_err(lua_err)?;
70
71 globals
73 .set("slice", create_slice(lua)?)
74 .map_err(lua_err)?;
75
76 globals
78 .set("search", create_search(lua)?)
79 .map_err(lua_err)?;
80
81 Ok(())
82}
83
84fn create_json_decode(lua: &Lua) -> SoulResult<Function> {
87 lua.create_function(|lua, s: String| {
88 let value: serde_json::Value = serde_json::from_str(&s).map_err(|e| {
89 mlua::Error::external(format!("json_decode error: {e}"))
90 })?;
91 json_to_lua(lua, &value)
92 })
93 .map_err(lua_err)
94}
95
96fn create_json_encode(lua: &Lua) -> SoulResult<Function> {
97 lua.create_function(|_, val: Value| {
98 let json = lua_to_json(&val);
99 serde_json::to_string(&json)
100 .map_err(|e| mlua::Error::external(format!("json_encode error: {e}")))
101 })
102 .map_err(lua_err)
103}
104
105fn create_url_encode(lua: &Lua) -> SoulResult<Function> {
106 lua.create_function(|_, s: String| {
107 Ok(percent_encode(&s))
108 })
109 .map_err(lua_err)
110}
111
112fn create_fetch(lua: &Lua) -> SoulResult<Function> {
113 lua.create_function(|_, url: String| {
114 let body = blocking_get(&url)
117 .map_err(|e| mlua::Error::external(format!("fetch error: {e}")))?;
118 Ok(body)
119 })
120 .map_err(lua_err)
121}
122
123fn create_fetch_json(lua: &Lua) -> SoulResult<Function> {
124 lua.create_function(|lua, url: String| {
125 let body = blocking_get(&url)
126 .map_err(|e| mlua::Error::external(format!("fetch_json error: {e}")))?;
127 let value: serde_json::Value = serde_json::from_str(&body)
128 .map_err(|e| mlua::Error::external(format!("fetch_json parse error: {e}")))?;
129 json_to_lua(lua, &value)
130 })
131 .map_err(lua_err)
132}
133
134fn create_post(lua: &Lua) -> SoulResult<Function> {
135 lua.create_function(|_, (url, body, content_type): (String, String, Option<String>)| {
136 let ct = content_type.unwrap_or_else(|| "application/json".to_string());
137 let result = blocking_post(&url, &body, &ct)
138 .map_err(|e| mlua::Error::external(format!("post error: {e}")))?;
139 Ok(result)
140 })
141 .map_err(lua_err)
142}
143
144fn create_chunk_by_lines(lua: &Lua) -> SoulResult<Function> {
145 lua.create_function(|lua, (text, n): (String, usize)| {
146 let n = n.max(1);
147 let lines: Vec<&str> = text.lines().collect();
148 let chunks: Vec<String> = lines.chunks(n).map(|c| c.join("\n")).collect();
149
150 let table = lua.create_table()?;
151 for (i, chunk) in chunks.iter().enumerate() {
152 table.set(i + 1, chunk.as_str())?;
153 }
154 Ok(Value::Table(table))
155 })
156 .map_err(lua_err)
157}
158
159fn create_chunk_by_chars(lua: &Lua) -> SoulResult<Function> {
160 lua.create_function(|lua, (text, n): (String, usize)| {
161 let n = n.max(1);
162 let chars: Vec<char> = text.chars().collect();
163 let chunks: Vec<String> = chars.chunks(n).map(|c| c.iter().collect()).collect();
164
165 let table = lua.create_table()?;
166 for (i, chunk) in chunks.iter().enumerate() {
167 table.set(i + 1, chunk.as_str())?;
168 }
169 Ok(Value::Table(table))
170 })
171 .map_err(lua_err)
172}
173
174fn create_chunk_by_regex(lua: &Lua) -> SoulResult<Function> {
175 lua.create_function(|lua, (text, pattern): (String, String)| {
176 let re = regex::Regex::new(&pattern)
177 .map_err(|e| mlua::Error::external(format!("Invalid regex '{pattern}': {e}")))?;
178
179 let chunks: Vec<&str> = re.split(&text).collect();
180
181 let table = lua.create_table()?;
182 for (i, chunk) in chunks.iter().enumerate() {
183 table.set(i + 1, *chunk)?;
184 }
185 Ok(Value::Table(table))
186 })
187 .map_err(lua_err)
188}
189
190fn create_slice(lua: &Lua) -> SoulResult<Function> {
191 lua.create_function(|_, (text, start, len): (String, usize, usize)| {
192 let chars: Vec<char> = text.chars().collect();
193 let start = start.min(chars.len());
194 let end = (start + len).min(chars.len());
195 let result: String = chars[start..end].iter().collect();
196 Ok(result)
197 })
198 .map_err(lua_err)
199}
200
201fn create_search(lua: &Lua) -> SoulResult<Function> {
202 lua.create_function(|lua, (text, query, top_k): (String, String, Option<usize>)| {
203 let top_k = top_k.unwrap_or(5);
204
205 let query_terms: Vec<String> = query
207 .split(|c: char| !c.is_alphanumeric() && c != '_')
208 .filter(|s| s.len() >= 2)
209 .map(|s| s.to_lowercase())
210 .collect();
211
212 if query_terms.is_empty() {
213 return Ok(Value::Table(lua.create_table()?));
214 }
215
216 let lines: Vec<&str> = text.lines().collect();
218 let mut sections: Vec<(usize, String)> = Vec::new();
219 let mut section_start = 0;
220
221 for (i, line) in lines.iter().enumerate() {
222 if line.trim().is_empty() || line.starts_with("## ") || line.starts_with("### ") {
223 if i > section_start {
224 let section_text = lines[section_start..i].join("\n");
225 if !section_text.trim().is_empty() {
226 sections.push((section_start + 1, section_text));
227 }
228 }
229 if line.starts_with('#') {
230 section_start = i; } else {
232 section_start = i + 1;
233 }
234 }
235 }
236 if section_start < lines.len() {
238 let section_text = lines[section_start..].join("\n");
239 if !section_text.trim().is_empty() {
240 sections.push((section_start + 1, section_text));
241 }
242 }
243
244 if sections.is_empty() {
245 return Ok(Value::Table(lua.create_table()?));
246 }
247
248 let n = sections.len() as f64;
250 let mut idf_map: Vec<f64> = Vec::with_capacity(query_terms.len());
251 for term in &query_terms {
252 let df = sections
253 .iter()
254 .filter(|(_, s)| s.to_lowercase().contains(term.as_str()))
255 .count();
256 let idf = if df > 0 { (n / df as f64).ln() + 1.0 } else { 0.0 };
257 idf_map.push(idf);
258 }
259
260 let mut scored: Vec<(f64, usize, usize)> = sections
262 .iter()
263 .enumerate()
264 .map(|(idx, (line, section))| {
265 let lower = section.to_lowercase();
266 let word_count = lower.split_whitespace().count().max(1) as f64;
267 let mut score = 0.0f64;
268
269 for (ti, term) in query_terms.iter().enumerate() {
270 let tf = lower.matches(term.as_str()).count() as f64 / word_count;
271 score += tf * idf_map[ti];
272 }
273
274 (score, *line, idx)
275 })
276 .filter(|(score, _, _)| *score > 0.0)
277 .collect();
278
279 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
280 scored.truncate(top_k);
281
282 let result_table = lua.create_table()?;
284 for (i, (score, line, idx)) in scored.iter().enumerate() {
285 let entry = lua.create_table()?;
286 entry.set("text", sections[*idx].1.as_str())?;
287 entry.set("score", *score)?;
288 entry.set("line", *line)?;
289 result_table.set(i + 1, entry)?;
290 }
291 Ok(Value::Table(result_table))
292 })
293 .map_err(lua_err)
294}
295
/// Percent-encodes `s` byte by byte.
///
/// ASCII letters, ASCII digits and the characters `-`, `_`, `.` and `~` are
/// copied through unchanged. Every other byte, including each byte of a
/// multi-byte UTF-8 sequence, becomes `%XX` with two uppercase hexadecimal
/// digits.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &b in s.as_bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    }
    out
}
313
314fn blocking_get(url: &str) -> Result<String, String> {
316 let client = reqwest::blocking::Client::builder()
319 .timeout(std::time::Duration::from_secs(30))
320 .redirect(reqwest::redirect::Policy::limited(5))
321 .build()
322 .map_err(|e| format!("HTTP client error: {e}"))?;
323
324 let response = client
325 .get(url)
326 .header("User-Agent", "amai-agent/lua-sandbox")
327 .send()
328 .map_err(|e| format!("HTTP request failed: {e}"))?;
329
330 let status = response.status();
331 if !status.is_success() {
332 return Err(format!("HTTP {status} from {url}"));
333 }
334
335 response
336 .text()
337 .map_err(|e| format!("Failed to read response: {e}"))
338}
339
340fn blocking_post(url: &str, body: &str, content_type: &str) -> Result<String, String> {
342 let client = reqwest::blocking::Client::builder()
343 .timeout(std::time::Duration::from_secs(30))
344 .redirect(reqwest::redirect::Policy::limited(5))
345 .build()
346 .map_err(|e| format!("HTTP client error: {e}"))?;
347
348 let response = client
349 .post(url)
350 .header("User-Agent", "amai-agent/lua-sandbox")
351 .header("Content-Type", content_type)
352 .body(body.to_string())
353 .send()
354 .map_err(|e| format!("HTTP request failed: {e}"))?;
355
356 let status = response.status();
357 if !status.is_success() {
358 return Err(format!("HTTP {status} from {url}"));
359 }
360
361 response
362 .text()
363 .map_err(|e| format!("Failed to read response: {e}"))
364}
365
366fn lua_err(e: mlua::Error) -> SoulError {
368 SoulError::ToolExecution {
369 tool_name: "lua_sandbox".into(),
370 message: format!("Lua function registration error: {e}"),
371 }
372}
373
#[cfg(test)]
mod tests {
    //! Tests for the Lua helper functions registered by this module.
    //!
    //! Most tests run a small Lua script through `LuaSandbox::exec` and
    //! compare its result as a string; this assumes `exec` renders the
    //! script's return value as text, which is how the assertions below use
    //! it. None of the tests touches the network: `fetch`, `fetch_json` and
    //! `post` are not exercised here.

    use super::*;
    use crate::lua::LuaSandbox;

    /// Creates a sandbox with the skill helpers (JSON, URL encoding, HTTP,
    /// crypto) registered.
    fn sandbox_with_skill_fns() -> LuaSandbox {
        let sb = LuaSandbox::new().unwrap();
        register_skill_functions(sb.lua()).unwrap();
        sb
    }

    /// Creates a sandbox with the text-processing helpers (chunking, slice,
    /// search) registered.
    fn sandbox_with_rlm_fns() -> LuaSandbox {
        let sb = LuaSandbox::new().unwrap();
        register_rlm_functions(sb.lua()).unwrap();
        sb
    }

    // --- json_decode / json_encode ---

    // A decoded JSON object exposes its keys as Lua table fields.
    #[test]
    fn json_decode_object() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"local t = json_decode('{"name":"test","count":3}'); return t.name"#)
            .unwrap();
        assert_eq!(result, "test");
    }

    // A decoded JSON array has the expected Lua length.
    #[test]
    fn json_decode_array() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"local t = json_decode('[1,2,3]'); return #t"#)
            .unwrap();
        assert_eq!(result, "3");
    }

    // Malformed JSON surfaces as a script error instead of a value.
    #[test]
    fn json_decode_invalid() {
        let sb = sandbox_with_skill_fns();
        let result = sb.exec(r#"json_decode('not json')"#);
        assert!(result.is_err());
    }

    // The encoder output is parsed again on the Rust side, so the test does
    // not depend on the exact formatting of the JSON text.
    #[test]
    fn json_encode_table() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return json_encode({name = "test"})"#)
            .unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    // Encoding and then decoding inside Lua preserves string and integer
    // fields.
    #[test]
    fn json_encode_decode_roundtrip() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(
                r#"
 local obj = {name = "test", count = 42}
 local encoded = json_encode(obj)
 local decoded = json_decode(encoded)
 return decoded.name .. ":" .. decoded.count
 "#,
            )
            .unwrap();
        assert_eq!(result, "test:42");
    }

    // --- url_encode ---

    // A space is escaped as %20 (not '+').
    #[test]
    fn url_encode_basic() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return url_encode("hello world")"#)
            .unwrap();
        assert_eq!(result, "hello%20world");
    }

    // Query delimiters '=' and '&' are escaped as well.
    #[test]
    fn url_encode_special_chars() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return url_encode("a=b&c=d")"#)
            .unwrap();
        assert_eq!(result, "a%3Db%26c%3Dd");
    }

    // --- chunking and slicing ---

    // Five lines in groups of two give three chunks (2 + 2 + 1).
    #[test]
    fn chunk_by_lines_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "line1\nline2\nline3\nline4\nline5")
            .unwrap();
        let result = sb.exec("local chunks = chunk_by_lines(text, 2); return #chunks").unwrap();
        assert_eq!(result, "3");
    }

    // The lines inside a chunk are re-joined with a newline.
    #[test]
    fn chunk_by_lines_content() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "a\nb\nc\nd").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_lines(text, 2); return chunks[1]")
            .unwrap();
        assert_eq!(result, "a\nb");
    }

    // Ten characters in pieces of three give four chunks (3 + 3 + 3 + 1).
    #[test]
    fn chunk_by_chars_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "abcdefghij").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_chars(text, 3); return #chunks")
            .unwrap();
        assert_eq!(result, "4");
    }

    // The separator matched by the pattern is not part of any chunk.
    #[test]
    fn chunk_by_regex_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "part1---part2---part3").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_regex(text, '---'); return table.concat(chunks, ',')")
            .unwrap();
        assert_eq!(result, "part1,part2,part3");
    }

    // `slice` takes a zero-based start offset and a length.
    #[test]
    fn slice_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "hello world").unwrap();
        let result = sb.exec("return slice(text, 0, 5)").unwrap();
        assert_eq!(result, "hello");
    }

    // A length running past the end of the text is clamped rather than
    // rejected.
    #[test]
    fn slice_out_of_bounds() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "hello").unwrap();
        let result = sb.exec("return slice(text, 3, 100)").unwrap();
        assert_eq!(result, "lo");
    }

    // --- search ---

    // Only the third turn mentions "error"/"database", so at least one
    // section must be returned; the exact count is deliberately not pinned.
    #[test]
    fn search_basic_keyword() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string(
            "text",
            "## TURN 1\n### USER\nHow do I use async rust?\n### ASSISTANT\nUse tokio and async/await\n\n## TURN 2\n### USER\nWrite a test for the server\n### ASSISTANT\nHere is the test code\n\n## TURN 3\n### USER\nFix the database error\n### ASSISTANT\nThe error was in the connection pool",
        ).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "error database", 2); return #r"#)
            .unwrap();
        let count: usize = result.parse().unwrap();
        assert!(count >= 1);
    }

    // Returned entries carry a strictly positive score.
    #[test]
    fn search_returns_scored_results() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string(
            "text",
            "## Section A\nThe cat sat on the mat\n\n## Section B\nThe dog ran in the park\n\n## Section C\nThe cat chased the dog across the mat",
        ).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "cat mat"); return r[1].score > 0"#)
            .unwrap();
        assert_eq!(result, "true");
    }

    // All four sections contain "bravo", but only `top_k` = 2 are returned.
    #[test]
    fn search_respects_top_k() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string(
            "text",
            "## A\nalpha bravo\n\n## B\ncharlie bravo\n\n## C\ndelta bravo\n\n## D\necho bravo",
        ).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "bravo", 2); return #r"#)
            .unwrap();
        assert_eq!(result, "2");
    }

    // Both query terms are shorter than two characters and are therefore
    // discarded, which leaves no terms and an empty result.
    #[test]
    fn search_empty_query() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "some text here").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "a b"); return #r"#)
            .unwrap();
        assert_eq!(result, "0");
    }

    // Terms that occur nowhere in the text produce an empty result.
    #[test]
    fn search_no_matches() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "## Section\nThe quick brown fox").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "zebra elephant"); return #r"#)
            .unwrap();
        assert_eq!(result, "0");
    }

    // The matching section starts at the "## Second" heading, which is the
    // fourth line of the text (line numbers are 1-based).
    #[test]
    fn search_result_has_line_numbers() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "## First\nHello world\n\n## Second\nGoodbye world").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "goodbye"); return r[1].line"#)
            .unwrap();
        let line: usize = result.parse().unwrap();
        assert!(line >= 4);
    }

    // --- percent_encode (called directly, without a sandbox) ---

    // Unreserved characters are left untouched.
    #[test]
    fn percent_encode_passthrough() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("a-b_c.d~e"), "a-b_c.d~e");
    }

    // Spaces and URL delimiters are escaped as uppercase %XX sequences.
    #[test]
    fn percent_encode_spaces_and_specials() {
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("/path?q=1"), "%2Fpath%3Fq%3D1");
    }
}