Skip to main content

soul_core/lua/
functions.rs

1//! Registered Lua functions for skill execution and RLM REPL.
2
3use mlua::{Function, Lua, Value};
4
5use super::{json_to_lua, lua_to_json};
6use crate::error::{SoulError, SoulResult};
7
8/// Register skill-specific functions on a Lua sandbox.
9///
10/// Functions: `json_decode`, `json_encode`, `url_encode`, `fetch`, `post`, `fetch_json`
11pub fn register_skill_functions(lua: &Lua) -> SoulResult<()> {
12    let globals = lua.globals();
13
14    // json_decode(s) → Lua table
15    globals
16        .set("json_decode", create_json_decode(lua)?)
17        .map_err(lua_err)?;
18
19    // json_encode(t) → JSON string
20    globals
21        .set("json_encode", create_json_encode(lua)?)
22        .map_err(lua_err)?;
23
24    // url_encode(s) → percent-encoded string
25    globals
26        .set("url_encode", create_url_encode(lua)?)
27        .map_err(lua_err)?;
28
29    // fetch(url) → body string (sync HTTP GET)
30    globals
31        .set("fetch", create_fetch(lua)?)
32        .map_err(lua_err)?;
33
34    // fetch_json(url) → Lua table (sync HTTP GET + JSON parse)
35    globals
36        .set("fetch_json", create_fetch_json(lua)?)
37        .map_err(lua_err)?;
38
39    // post(url, body, content_type?) → body string (sync HTTP POST)
40    globals
41        .set("post", create_post(lua)?)
42        .map_err(lua_err)?;
43
44    // Crypto, encoding, time, and random functions
45    super::crypto::register_crypto_functions(lua)?;
46
47    Ok(())
48}
49
50/// Register RLM context functions on a Lua sandbox.
51///
52/// Functions: `chunk_by_lines`, `chunk_by_chars`, `chunk_by_regex`, `slice`, `search`
53pub fn register_rlm_functions(lua: &Lua) -> SoulResult<()> {
54    let globals = lua.globals();
55
56    // chunk_by_lines(text, n) → table of chunks
57    globals
58        .set("chunk_by_lines", create_chunk_by_lines(lua)?)
59        .map_err(lua_err)?;
60
61    // chunk_by_chars(text, n) → table of chunks
62    globals
63        .set("chunk_by_chars", create_chunk_by_chars(lua)?)
64        .map_err(lua_err)?;
65
66    // chunk_by_regex(text, pattern) → table of chunks
67    globals
68        .set("chunk_by_regex", create_chunk_by_regex(lua)?)
69        .map_err(lua_err)?;
70
71    // slice(text, start, len) → substring
72    globals
73        .set("slice", create_slice(lua)?)
74        .map_err(lua_err)?;
75
76    // search(text, query, top_k?) → table of {text, score, line} results
77    globals
78        .set("search", create_search(lua)?)
79        .map_err(lua_err)?;
80
81    Ok(())
82}
83
84// ─── Function Constructors ──────────────────────────────────────────────────
85
86fn create_json_decode(lua: &Lua) -> SoulResult<Function> {
87    lua.create_function(|lua, s: String| {
88        let value: serde_json::Value = serde_json::from_str(&s).map_err(|e| {
89            mlua::Error::external(format!("json_decode error: {e}"))
90        })?;
91        json_to_lua(lua, &value)
92    })
93    .map_err(lua_err)
94}
95
96fn create_json_encode(lua: &Lua) -> SoulResult<Function> {
97    lua.create_function(|_, val: Value| {
98        let json = lua_to_json(&val);
99        serde_json::to_string(&json)
100            .map_err(|e| mlua::Error::external(format!("json_encode error: {e}")))
101    })
102    .map_err(lua_err)
103}
104
105fn create_url_encode(lua: &Lua) -> SoulResult<Function> {
106    lua.create_function(|_, s: String| {
107        Ok(percent_encode(&s))
108    })
109    .map_err(lua_err)
110}
111
112fn create_fetch(lua: &Lua) -> SoulResult<Function> {
113    lua.create_function(|_, url: String| {
114        // Use a blocking HTTP request — Lua execution is synchronous.
115        // This is safe because the outer LuaSandbox wraps execution in a timeout.
116        let body = blocking_get(&url)
117            .map_err(|e| mlua::Error::external(format!("fetch error: {e}")))?;
118        Ok(body)
119    })
120    .map_err(lua_err)
121}
122
123fn create_fetch_json(lua: &Lua) -> SoulResult<Function> {
124    lua.create_function(|lua, url: String| {
125        let body = blocking_get(&url)
126            .map_err(|e| mlua::Error::external(format!("fetch_json error: {e}")))?;
127        let value: serde_json::Value = serde_json::from_str(&body)
128            .map_err(|e| mlua::Error::external(format!("fetch_json parse error: {e}")))?;
129        json_to_lua(lua, &value)
130    })
131    .map_err(lua_err)
132}
133
134fn create_post(lua: &Lua) -> SoulResult<Function> {
135    lua.create_function(|_, (url, body, content_type): (String, String, Option<String>)| {
136        let ct = content_type.unwrap_or_else(|| "application/json".to_string());
137        let result = blocking_post(&url, &body, &ct)
138            .map_err(|e| mlua::Error::external(format!("post error: {e}")))?;
139        Ok(result)
140    })
141    .map_err(lua_err)
142}
143
144fn create_chunk_by_lines(lua: &Lua) -> SoulResult<Function> {
145    lua.create_function(|lua, (text, n): (String, usize)| {
146        let n = n.max(1);
147        let lines: Vec<&str> = text.lines().collect();
148        let chunks: Vec<String> = lines.chunks(n).map(|c| c.join("\n")).collect();
149
150        let table = lua.create_table()?;
151        for (i, chunk) in chunks.iter().enumerate() {
152            table.set(i + 1, chunk.as_str())?;
153        }
154        Ok(Value::Table(table))
155    })
156    .map_err(lua_err)
157}
158
159fn create_chunk_by_chars(lua: &Lua) -> SoulResult<Function> {
160    lua.create_function(|lua, (text, n): (String, usize)| {
161        let n = n.max(1);
162        let chars: Vec<char> = text.chars().collect();
163        let chunks: Vec<String> = chars.chunks(n).map(|c| c.iter().collect()).collect();
164
165        let table = lua.create_table()?;
166        for (i, chunk) in chunks.iter().enumerate() {
167            table.set(i + 1, chunk.as_str())?;
168        }
169        Ok(Value::Table(table))
170    })
171    .map_err(lua_err)
172}
173
174fn create_chunk_by_regex(lua: &Lua) -> SoulResult<Function> {
175    lua.create_function(|lua, (text, pattern): (String, String)| {
176        let re = regex::Regex::new(&pattern)
177            .map_err(|e| mlua::Error::external(format!("Invalid regex '{pattern}': {e}")))?;
178
179        let chunks: Vec<&str> = re.split(&text).collect();
180
181        let table = lua.create_table()?;
182        for (i, chunk) in chunks.iter().enumerate() {
183            table.set(i + 1, *chunk)?;
184        }
185        Ok(Value::Table(table))
186    })
187    .map_err(lua_err)
188}
189
190fn create_slice(lua: &Lua) -> SoulResult<Function> {
191    lua.create_function(|_, (text, start, len): (String, usize, usize)| {
192        let chars: Vec<char> = text.chars().collect();
193        let start = start.min(chars.len());
194        let end = (start + len).min(chars.len());
195        let result: String = chars[start..end].iter().collect();
196        Ok(result)
197    })
198    .map_err(lua_err)
199}
200
201fn create_search(lua: &Lua) -> SoulResult<Function> {
202    lua.create_function(|lua, (text, query, top_k): (String, String, Option<usize>)| {
203        let top_k = top_k.unwrap_or(5);
204
205        // Tokenize query into lowercase terms (skip short noise words)
206        let query_terms: Vec<String> = query
207            .split(|c: char| !c.is_alphanumeric() && c != '_')
208            .filter(|s| s.len() >= 2)
209            .map(|s| s.to_lowercase())
210            .collect();
211
212        if query_terms.is_empty() {
213            return Ok(Value::Table(lua.create_table()?));
214        }
215
216        // Split text into sections at blank lines or markdown headers
217        let lines: Vec<&str> = text.lines().collect();
218        let mut sections: Vec<(usize, String)> = Vec::new();
219        let mut section_start = 0;
220
221        for (i, line) in lines.iter().enumerate() {
222            if line.trim().is_empty() || line.starts_with("## ") || line.starts_with("### ") {
223                if i > section_start {
224                    let section_text = lines[section_start..i].join("\n");
225                    if !section_text.trim().is_empty() {
226                        sections.push((section_start + 1, section_text));
227                    }
228                }
229                if line.starts_with('#') {
230                    section_start = i; // include header in next section
231                } else {
232                    section_start = i + 1;
233                }
234            }
235        }
236        // Last section
237        if section_start < lines.len() {
238            let section_text = lines[section_start..].join("\n");
239            if !section_text.trim().is_empty() {
240                sections.push((section_start + 1, section_text));
241            }
242        }
243
244        if sections.is_empty() {
245            return Ok(Value::Table(lua.create_table()?));
246        }
247
248        // IDF: log(N / df) for each query term
249        let n = sections.len() as f64;
250        let mut idf_map: Vec<f64> = Vec::with_capacity(query_terms.len());
251        for term in &query_terms {
252            let df = sections
253                .iter()
254                .filter(|(_, s)| s.to_lowercase().contains(term.as_str()))
255                .count();
256            let idf = if df > 0 { (n / df as f64).ln() + 1.0 } else { 0.0 };
257            idf_map.push(idf);
258        }
259
260        // Score each section: BM25-lite (tf * idf, length-normalized)
261        let mut scored: Vec<(f64, usize, usize)> = sections
262            .iter()
263            .enumerate()
264            .map(|(idx, (line, section))| {
265                let lower = section.to_lowercase();
266                let word_count = lower.split_whitespace().count().max(1) as f64;
267                let mut score = 0.0f64;
268
269                for (ti, term) in query_terms.iter().enumerate() {
270                    let tf = lower.matches(term.as_str()).count() as f64 / word_count;
271                    score += tf * idf_map[ti];
272                }
273
274                (score, *line, idx)
275            })
276            .filter(|(score, _, _)| *score > 0.0)
277            .collect();
278
279        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
280        scored.truncate(top_k);
281
282        // Build result table: [{text=..., score=..., line=...}, ...]
283        let result_table = lua.create_table()?;
284        for (i, (score, line, idx)) in scored.iter().enumerate() {
285            let entry = lua.create_table()?;
286            entry.set("text", sections[*idx].1.as_str())?;
287            entry.set("score", *score)?;
288            entry.set("line", *line)?;
289            result_table.set(i + 1, entry)?;
290        }
291        Ok(Value::Table(result_table))
292    })
293    .map_err(lua_err)
294}
295
296// ─── Helpers ────────────────────────────────────────────────────────────────
297
/// Percent-encode `s` for safe inclusion in a URL component.
///
/// Bytes in the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex
/// digits. This includes space, `/`, `?`, `=`, `&`, and each byte of a multi-byte
/// character. A space is encoded as `%20`, never `+`.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Lower bound only: unreserved bytes copy 1:1, escaped bytes grow to 3 chars.
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            // Write `%XX` directly instead of allocating a temporary String with `format!`.
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
313
314/// Blocking HTTP GET. Uses reqwest blocking client.
315fn blocking_get(url: &str) -> Result<String, String> {
316    // Build a new tokio runtime is not feasible here since we're already in one.
317    // Use reqwest's blocking client which uses its own thread pool.
318    let client = reqwest::blocking::Client::builder()
319        .timeout(std::time::Duration::from_secs(30))
320        .redirect(reqwest::redirect::Policy::limited(5))
321        .build()
322        .map_err(|e| format!("HTTP client error: {e}"))?;
323
324    let response = client
325        .get(url)
326        .header("User-Agent", "amai-agent/lua-sandbox")
327        .send()
328        .map_err(|e| format!("HTTP request failed: {e}"))?;
329
330    let status = response.status();
331    if !status.is_success() {
332        return Err(format!("HTTP {status} from {url}"));
333    }
334
335    response
336        .text()
337        .map_err(|e| format!("Failed to read response: {e}"))
338}
339
340/// Blocking HTTP POST.
341fn blocking_post(url: &str, body: &str, content_type: &str) -> Result<String, String> {
342    let client = reqwest::blocking::Client::builder()
343        .timeout(std::time::Duration::from_secs(30))
344        .redirect(reqwest::redirect::Policy::limited(5))
345        .build()
346        .map_err(|e| format!("HTTP client error: {e}"))?;
347
348    let response = client
349        .post(url)
350        .header("User-Agent", "amai-agent/lua-sandbox")
351        .header("Content-Type", content_type)
352        .body(body.to_string())
353        .send()
354        .map_err(|e| format!("HTTP request failed: {e}"))?;
355
356    let status = response.status();
357    if !status.is_success() {
358        return Err(format!("HTTP {status} from {url}"));
359    }
360
361    response
362        .text()
363        .map_err(|e| format!("Failed to read response: {e}"))
364}
365
366/// Convert mlua::Error to SoulError.
367fn lua_err(e: mlua::Error) -> SoulError {
368    SoulError::ToolExecution {
369        tool_name: "lua_sandbox".into(),
370        message: format!("Lua function registration error: {e}"),
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lua::LuaSandbox;

    /// Chat-style transcript shared by the search tests.
    ///
    /// `search` splits it into nine sections starting at lines 1, 2, 4, 7, 8, 10, 13,
    /// 14, and 16. Only the two TURN 3 sections (lines 14 and 16) contain "error", and
    /// only line 14 contains "database".
    const TRANSCRIPT: &str = "## TURN 1\n### USER\nHow do I use async rust?\n### ASSISTANT\nUse tokio and async/await\n\n## TURN 2\n### USER\nWrite a test for the server\n### ASSISTANT\nHere is the test code\n\n## TURN 3\n### USER\nFix the database error\n### ASSISTANT\nThe error was in the connection pool";

    /// Fresh sandbox with only the skill functions registered.
    fn sandbox_with_skill_fns() -> LuaSandbox {
        let sb = LuaSandbox::new().unwrap();
        register_skill_functions(sb.lua()).unwrap();
        sb
    }

    /// Fresh sandbox with only the RLM functions registered.
    fn sandbox_with_rlm_fns() -> LuaSandbox {
        let sb = LuaSandbox::new().unwrap();
        register_rlm_functions(sb.lua()).unwrap();
        sb
    }

    // ─── JSON ─────────────────────────────────────────────────────────

    #[test]
    fn json_decode_object() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"local t = json_decode('{"name":"test","count":3}'); return t.name"#)
            .unwrap();
        assert_eq!(result, "test");
    }

    #[test]
    fn json_decode_array() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"local t = json_decode('[1,2,3]'); return #t"#)
            .unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn json_decode_invalid() {
        let sb = sandbox_with_skill_fns();
        let result = sb.exec(r#"json_decode('not json')"#);
        assert!(result.is_err());
    }

    #[test]
    fn json_encode_table() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return json_encode({name = "test"})"#)
            .unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    #[test]
    fn json_encode_decode_roundtrip() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(
                r#"
            local obj = {name = "test", count = 42}
            local encoded = json_encode(obj)
            local decoded = json_decode(encoded)
            return decoded.name .. ":" .. decoded.count
        "#,
            )
            .unwrap();
        assert_eq!(result, "test:42");
    }

    // ─── URL Encode ──────────────────────────────────────────────────

    #[test]
    fn url_encode_basic() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return url_encode("hello world")"#)
            .unwrap();
        assert_eq!(result, "hello%20world");
    }

    #[test]
    fn url_encode_special_chars() {
        let sb = sandbox_with_skill_fns();
        let result = sb
            .exec(r#"return url_encode("a=b&c=d")"#)
            .unwrap();
        assert_eq!(result, "a%3Db%26c%3Dd");
    }

    // ─── Chunk Functions ─────────────────────────────────────────────

    #[test]
    fn chunk_by_lines_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "line1\nline2\nline3\nline4\nline5")
            .unwrap();
        let result = sb.exec("local chunks = chunk_by_lines(text, 2); return #chunks").unwrap();
        assert_eq!(result, "3"); // 2+2+1
    }

    #[test]
    fn chunk_by_lines_content() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "a\nb\nc\nd").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_lines(text, 2); return chunks[1]")
            .unwrap();
        assert_eq!(result, "a\nb");
    }

    #[test]
    fn chunk_by_lines_zero_size_treated_as_one() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "a\nb\nc").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_lines(text, 0); return #chunks")
            .unwrap();
        assert_eq!(result, "3"); // n = 0 is clamped to 1 line per chunk
    }

    #[test]
    fn chunk_by_chars_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "abcdefghij").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_chars(text, 3); return #chunks")
            .unwrap();
        assert_eq!(result, "4"); // 3+3+3+1
    }

    #[test]
    fn chunk_by_chars_counts_unicode_chars() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "héllo").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_chars(text, 2); return table.concat(chunks, '|')")
            .unwrap();
        // "é" is two bytes but one char, so it never gets split.
        assert_eq!(result, "hé|ll|o");
    }

    #[test]
    fn chunk_by_regex_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "part1---part2---part3").unwrap();
        let result = sb
            .exec("local chunks = chunk_by_regex(text, '---'); return table.concat(chunks, ',')")
            .unwrap();
        assert_eq!(result, "part1,part2,part3");
    }

    #[test]
    fn slice_basic() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "hello world").unwrap();
        let result = sb.exec("return slice(text, 0, 5)").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn slice_out_of_bounds() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "hello").unwrap();
        let result = sb.exec("return slice(text, 3, 100)").unwrap();
        assert_eq!(result, "lo");
    }

    #[test]
    fn slice_start_past_end_is_empty() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "hello").unwrap();
        let result = sb.exec("return slice(text, 10, 2)").unwrap();
        assert_eq!(result, "");
    }

    #[test]
    fn slice_counts_unicode_chars() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "héllo wörld").unwrap();
        let result = sb.exec("return slice(text, 1, 4)").unwrap();
        assert_eq!(result, "éllo");
    }

    // ─── Search (BM25) ────────────────────────────────────────────

    #[test]
    fn search_basic_keyword() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", TRANSCRIPT).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "error database", 2); return #r"#)
            .unwrap();
        let count: usize = result.parse().unwrap();
        // Exactly the two TURN 3 sections mention "error".
        assert_eq!(count, 2);
    }

    #[test]
    fn search_ranks_best_match_first() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", TRANSCRIPT).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "error database"); return r[1].line"#)
            .unwrap();
        let line: usize = result.parse().unwrap();
        // The TURN 3 USER section (line 14) is the only one containing both terms.
        assert_eq!(line, 14);
    }

    #[test]
    fn search_returns_scored_results() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string(
            "text",
            "## Section A\nThe cat sat on the mat\n\n## Section B\nThe dog ran in the park\n\n## Section C\nThe cat chased the dog across the mat",
        ).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "cat mat"); return r[1].score > 0"#)
            .unwrap();
        assert_eq!(result, "true");
    }

    #[test]
    fn search_respects_top_k() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string(
            "text",
            "## A\nalpha bravo\n\n## B\ncharlie bravo\n\n## C\ndelta bravo\n\n## D\necho bravo",
        ).unwrap();
        let result = sb
            .exec(r#"local r = search(text, "bravo", 2); return #r"#)
            .unwrap();
        assert_eq!(result, "2");
    }

    #[test]
    fn search_top_k_zero_returns_nothing() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "## A\nalpha bravo\n\n## B\ncharlie bravo").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "bravo", 0); return #r"#)
            .unwrap();
        assert_eq!(result, "0");
    }

    #[test]
    fn search_empty_query() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "some text here").unwrap();
        // Single char query terms are filtered out
        let result = sb
            .exec(r#"local r = search(text, "a b"); return #r"#)
            .unwrap();
        assert_eq!(result, "0");
    }

    #[test]
    fn search_no_matches() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "## Section\nThe quick brown fox").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "zebra elephant"); return #r"#)
            .unwrap();
        assert_eq!(result, "0");
    }

    #[test]
    fn search_result_has_line_numbers() {
        let sb = sandbox_with_rlm_fns();
        sb.set_string("text", "## First\nHello world\n\n## Second\nGoodbye world").unwrap();
        let result = sb
            .exec(r#"local r = search(text, "goodbye"); return r[1].line"#)
            .unwrap();
        let line: usize = result.parse().unwrap();
        assert_eq!(line, 4); // "## Second" starts at line 4
    }

    // ─── percent_encode ─────────────────────────────────────────────

    #[test]
    fn percent_encode_passthrough() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("a-b_c.d~e"), "a-b_c.d~e");
    }

    #[test]
    fn percent_encode_spaces_and_specials() {
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("/path?q=1"), "%2Fpath%3Fq%3D1");
    }

    #[test]
    fn percent_encode_multibyte_and_empty() {
        assert_eq!(percent_encode(""), "");
        // Each UTF-8 byte of a multi-byte char is escaped separately.
        assert_eq!(percent_encode("é"), "%C3%A9");
    }
}