Skip to main content

tarn/
scripting.rs

1use crate::assert::types::AssertionResult;
2use crate::error::TarnError;
3use crate::http::HttpResponse;
4use mlua::prelude::*;
5use mlua::{Error as LuaError, HookTriggers, LuaOptions, StdLib, VmState};
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, Mutex};
9
/// Maximum heap the Lua VM may allocate for a single script (4 MiB),
/// applied via `Lua::set_memory_limit`.
const SCRIPT_MEMORY_LIMIT_BYTES: usize = 4 << 20;
/// The instruction-budget hook fires once every this many Lua VM instructions.
const SCRIPT_HOOK_GRANULARITY: u32 = 1000;
/// Instruction budget per script; exceeding it aborts the script with a
/// runtime error. The budget is only checked when the hook fires, so it is
/// enforced in steps of `SCRIPT_HOOK_GRANULARITY`.
const SCRIPT_MAX_INSTRUCTIONS: usize = 100 * 1000;
13
14/// Result of running a Lua script.
15#[derive(Debug)]
16pub struct ScriptResult {
17    pub captures: HashMap<String, serde_json::Value>,
18    pub assertion_results: Vec<AssertionResult>,
19}
20
21/// Execute a Lua script with access to the HTTP response and current captures.
22pub fn run_script(
23    script: &str,
24    response: &HttpResponse,
25    captures: &HashMap<String, serde_json::Value>,
26    step_name: &str,
27) -> Result<ScriptResult, TarnError> {
28    let lua = create_sandboxed_lua()?;
29
30    // Build response table
31    let response_table = lua
32        .create_table()
33        .map_err(|e| TarnError::Script(e.to_string()))?;
34    response_table
35        .set("status", response.status)
36        .map_err(|e| TarnError::Script(e.to_string()))?;
37
38    // Headers table
39    let headers_table = lua
40        .create_table()
41        .map_err(|e| TarnError::Script(e.to_string()))?;
42    for (k, v) in &response.headers {
43        headers_table
44            .set(k.as_str(), v.as_str())
45            .map_err(|e| TarnError::Script(e.to_string()))?;
46    }
47    response_table
48        .set("headers", headers_table)
49        .map_err(|e| TarnError::Script(e.to_string()))?;
50
51    // Body as Lua value (serde -> Lua conversion)
52    let body_lua = lua
53        .to_value(&response.body)
54        .map_err(|e| TarnError::Script(format!("Failed to convert body to Lua: {}", e)))?;
55    response_table
56        .set("body", body_lua)
57        .map_err(|e| TarnError::Script(e.to_string()))?;
58
59    lua.globals()
60        .set("response", response_table)
61        .map_err(|e| TarnError::Script(e.to_string()))?;
62
63    // Captures table — push typed JSON values to Lua
64    let captures_table = lua
65        .create_table()
66        .map_err(|e| TarnError::Script(e.to_string()))?;
67    for (k, v) in captures {
68        let lua_val = lua.to_value(v).map_err(|e| {
69            TarnError::Script(format!("Failed to convert capture '{}' to Lua: {}", k, e))
70        })?;
71        captures_table
72            .set(k.as_str(), lua_val)
73            .map_err(|e| TarnError::Script(e.to_string()))?;
74    }
75    lua.globals()
76        .set("captures", captures_table)
77        .map_err(|e| TarnError::Script(e.to_string()))?;
78
79    // Collect assertion results via overridden assert()
80    let assertions: Arc<Mutex<Vec<AssertionResult>>> = Arc::new(Mutex::new(Vec::new()));
81    let assertions_clone = assertions.clone();
82    let step_name_owned = step_name.to_string();
83
84    let assert_fn = lua
85        .create_function(move |_, (condition, message): (bool, Option<String>)| {
86            let msg = message.unwrap_or_else(|| "script assertion".to_string());
87            let result = if condition {
88                AssertionResult::pass(format!("script: {}", msg), "true", "true")
89            } else {
90                AssertionResult::fail(
91                    format!("script: {}", msg),
92                    "true",
93                    "false",
94                    format!("Script assertion failed in '{}': {}", step_name_owned, msg),
95                )
96            };
97            assertions_clone.lock().unwrap().push(result);
98            Ok(())
99        })
100        .map_err(|e| TarnError::Script(e.to_string()))?;
101
102    lua.globals()
103        .set("assert", assert_fn)
104        .map_err(|e| TarnError::Script(e.to_string()))?;
105
106    // Register json global (json.encode / json.decode)
107    register_json_module(&lua)?;
108
109    // Execute script
110    lua.load(script)
111        .exec()
112        .map_err(|e| TarnError::Script(format!("Lua error in step '{}': {}", step_name, e)))?;
113
114    // Extract modified captures — convert Lua types to serde_json::Value
115    let final_captures: HashMap<String, serde_json::Value> = {
116        let captures_table: LuaTable = lua
117            .globals()
118            .get("captures")
119            .map_err(|e| TarnError::Script(e.to_string()))?;
120        let mut result = HashMap::new();
121        for pair in captures_table.pairs::<String, LuaValue>() {
122            let (k, v) = pair.map_err(|e| TarnError::Script(e.to_string()))?;
123            let v_json = lua_value_to_json(v);
124            result.insert(k, v_json);
125        }
126        result
127    };
128
129    let assertion_results = assertions.lock().unwrap().clone();
130
131    Ok(ScriptResult {
132        captures: final_captures,
133        assertion_results,
134    })
135}
136
137/// Register the `json` global table with `encode` and `decode` functions.
138fn register_json_module(lua: &Lua) -> Result<(), TarnError> {
139    let json_table = lua
140        .create_table()
141        .map_err(|e| TarnError::Script(e.to_string()))?;
142
143    // json.decode(string) -> Lua value
144    let decode_fn = lua
145        .create_function(|lua, s: String| {
146            let value: serde_json::Value =
147                serde_json::from_str(&s).map_err(|e| LuaError::runtime(e.to_string()))?;
148            lua.to_value(&value)
149                .map_err(|e| LuaError::runtime(e.to_string()))
150        })
151        .map_err(|e| TarnError::Script(e.to_string()))?;
152
153    // json.encode(value) -> string
154    let encode_fn = lua
155        .create_function(|lua, value: LuaValue| {
156            let json_value: serde_json::Value = lua
157                .from_value(value)
158                .map_err(|e| LuaError::runtime(e.to_string()))?;
159            serde_json::to_string(&json_value).map_err(|e| LuaError::runtime(e.to_string()))
160        })
161        .map_err(|e| TarnError::Script(e.to_string()))?;
162
163    json_table
164        .set("decode", decode_fn)
165        .map_err(|e| TarnError::Script(e.to_string()))?;
166    json_table
167        .set("encode", encode_fn)
168        .map_err(|e| TarnError::Script(e.to_string()))?;
169
170    lua.globals()
171        .set("json", json_table)
172        .map_err(|e| TarnError::Script(e.to_string()))?;
173
174    Ok(())
175}
176
177fn create_sandboxed_lua() -> Result<Lua, TarnError> {
178    let lua = Lua::new_with(
179        StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
180        LuaOptions::default(),
181    )
182    .map_err(|e| TarnError::Script(format!("Failed to initialize Lua sandbox: {}", e)))?;
183
184    lua.set_memory_limit(SCRIPT_MEMORY_LIMIT_BYTES)
185        .map_err(|e| TarnError::Script(format!("Failed to configure Lua memory limit: {}", e)))?;
186
187    let executed = Arc::new(AtomicUsize::new(0));
188    let executed_clone = executed.clone();
189    lua.set_hook(
190        HookTriggers::new().every_nth_instruction(SCRIPT_HOOK_GRANULARITY),
191        move |_lua, _debug| {
192            let total = executed_clone
193                .fetch_add(SCRIPT_HOOK_GRANULARITY as usize, Ordering::Relaxed)
194                + SCRIPT_HOOK_GRANULARITY as usize;
195            if total > SCRIPT_MAX_INSTRUCTIONS {
196                Err(LuaError::runtime(
197                    "script exceeded the instruction limit and was terminated",
198                ))
199            } else {
200                Ok(VmState::Continue)
201            }
202        },
203    );
204
205    let globals = lua.globals();
206    for name in ["dofile", "loadfile", "collectgarbage"] {
207        globals
208            .set(name, LuaValue::Nil)
209            .map_err(|e| TarnError::Script(format!("Failed to harden Lua globals: {}", e)))?;
210    }
211
212    Ok(lua)
213}
214
215/// Convert a Lua value to a serde_json::Value.
216fn lua_value_to_json(v: LuaValue) -> serde_json::Value {
217    match v {
218        LuaValue::String(s) => serde_json::Value::String(s.to_string_lossy().to_string()),
219        LuaValue::Integer(i) => serde_json::json!(i),
220        LuaValue::Number(n) => serde_json::Number::from_f64(n)
221            .map(serde_json::Value::Number)
222            .unwrap_or(serde_json::Value::Null),
223        LuaValue::Boolean(b) => serde_json::Value::Bool(b),
224        LuaValue::Nil => serde_json::Value::Null,
225        _ => serde_json::Value::String(format!("{:?}", v)),
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a canned `HttpResponse` with the given `status` and `body`.
    ///
    /// `body_bytes` mirrors `body`:
    /// - empty for JSON null;
    /// - the raw text for a JSON string;
    /// - the serialized JSON for everything else.
    fn make_response(status: u16, body: serde_json::Value) -> HttpResponse {
        let body_bytes = if body.is_null() {
            Vec::new()
        } else if let Some(text) = body.as_str() {
            text.as_bytes().to_vec()
        } else {
            serde_json::to_vec(&body).unwrap()
        };
        let timings = crate::http::ResponseTimings {
            total_ms: 50,
            ttfb_ms: 25,
            body_read_ms: 25,
            connect_ms: None,
            tls_ms: None,
        };
        HttpResponse {
            status,
            url: "https://example.com/".to_string(),
            redirect_count: 0,
            headers: HashMap::new(),
            raw_headers: vec![],
            body_bytes,
            body,
            duration_ms: 50,
            timings,
        }
    }

    /// Run `script` against `resp` with no captures and step name "test".
    fn run_fresh(script: &str, resp: &HttpResponse) -> Result<ScriptResult, TarnError> {
        run_script(script, resp, &HashMap::new(), "test")
    }

    /// Run `script` against a 200 response whose body is an empty JSON object.
    fn run_default(script: &str) -> Result<ScriptResult, TarnError> {
        run_fresh(script, &make_response(200, json!({})))
    }

    /// Run `script` against a 200 / `{}` response with the given captures.
    fn run_with_captures(
        script: &str,
        caps: &HashMap<String, serde_json::Value>,
    ) -> Result<ScriptResult, TarnError> {
        run_script(script, &make_response(200, json!({})), caps, "test")
    }

    #[test]
    fn script_accesses_response_status() {
        let result = run_default("assert(response.status == 200, 'status ok')").unwrap();
        assert_eq!(result.assertion_results.len(), 1);
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_accesses_response_body() {
        let resp = make_response(200, json!({"name": "Alice"}));
        let result =
            run_fresh("assert(response.body.name == 'Alice', 'name check')", &resp).unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_sets_captures() {
        let resp = make_response(200, json!({"id": "usr_123"}));
        let result = run_fresh("captures.user_id = response.body.id", &resp).unwrap();
        assert_eq!(result.captures["user_id"], json!("usr_123"));
    }

    #[test]
    fn script_sets_typed_captures() {
        let resp = make_response(200, json!({"count": 42}));
        let result = run_fresh("captures.count = response.body.count", &resp).unwrap();
        assert_eq!(result.captures["count"], json!(42));
    }

    #[test]
    fn script_failed_assertion() {
        let resp = make_response(404, json!({}));
        let result = run_fresh("assert(response.status == 200, 'expected 200')", &resp).unwrap();
        assert_eq!(result.assertion_results.len(), 1);
        let only = &result.assertion_results[0];
        assert!(!only.passed);
        assert!(only.message.contains("expected 200"));
    }

    #[test]
    fn script_syntax_error() {
        let err = run_default("this is not valid lua!!!").unwrap_err();
        assert!(matches!(err, TarnError::Script(_)));
    }

    #[test]
    fn script_reads_existing_captures() {
        let caps = HashMap::from([("token".to_string(), json!("abc123"))]);
        let result =
            run_with_captures("assert(captures.token == 'abc123', 'token check')", &caps).unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_reads_typed_captures() {
        let caps = HashMap::from([("count".to_string(), json!(42))]);
        let result =
            run_with_captures("assert(captures.count == 42, 'number preserved')", &caps).unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_cannot_access_os_library() {
        let result = run_default("assert(os == nil, 'os hidden')").unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_cannot_load_files() {
        let message = run_default("dofile('secret.lua')").unwrap_err().to_string();
        assert!(message.contains("attempt to call a nil value"));
    }

    #[test]
    fn script_instruction_limit_is_enforced() {
        let message = run_default("while true do end").unwrap_err().to_string();
        assert!(message.contains("instruction limit"));
    }

    #[test]
    fn script_json_decode() {
        let result = run_default(
            r#"
            local data = json.decode('{"name":"Alice","age":30}')
            assert(data.name == 'Alice', 'name decoded')
            assert(data.age == 30, 'age decoded')
            "#,
        )
        .unwrap();
        assert_eq!(result.assertion_results.len(), 2);
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }

    #[test]
    fn script_json_encode() {
        let result = run_default(
            r#"
            local encoded = json.encode({name = 'Bob'})
            assert(type(encoded) == 'string', 'encode returns string')
            local decoded = json.decode(encoded)
            assert(decoded.name == 'Bob', 'roundtrip works')
            "#,
        )
        .unwrap();
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }

    #[test]
    fn script_json_decode_invalid() {
        assert!(run_default("json.decode('not valid json')").is_err());
    }

    #[test]
    fn script_json_global_exists() {
        let result = run_default(
            "assert(json ~= nil, 'json exists')\nassert(type(json.decode) == 'function', 'decode is function')\nassert(type(json.encode) == 'function', 'encode is function')",
        )
        .unwrap();
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }
}