1use crate::assert::types::AssertionResult;
2use crate::error::TarnError;
3use crate::http::HttpResponse;
4use mlua::prelude::*;
5use mlua::{Error as LuaError, HookTriggers, LuaOptions, StdLib, VmState};
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::{Arc, Mutex};
9
/// Maximum number of bytes the sandboxed Lua state may allocate (4 MiB).
///
/// The limit is applied when the state is created, so it also covers the
/// `response`/`captures` data injected before the script itself runs.
const SCRIPT_MEMORY_LIMIT_BYTES: usize = 4 << 20;
/// Number of VM instructions between two invocations of the budget hook.
const SCRIPT_HOOK_GRANULARITY: u32 = 1000;
/// Instruction budget for a single script; enforced in steps of
/// [`SCRIPT_HOOK_GRANULARITY`].
const SCRIPT_MAX_INSTRUCTIONS: usize = 100 * 1000;
13
14#[derive(Debug)]
16pub struct ScriptResult {
17 pub captures: HashMap<String, serde_json::Value>,
18 pub assertion_results: Vec<AssertionResult>,
19}
20
21pub fn run_script(
23 script: &str,
24 response: &HttpResponse,
25 captures: &HashMap<String, serde_json::Value>,
26 step_name: &str,
27) -> Result<ScriptResult, TarnError> {
28 let lua = create_sandboxed_lua()?;
29
30 let response_table = lua
32 .create_table()
33 .map_err(|e| TarnError::Script(e.to_string()))?;
34 response_table
35 .set("status", response.status)
36 .map_err(|e| TarnError::Script(e.to_string()))?;
37
38 let headers_table = lua
40 .create_table()
41 .map_err(|e| TarnError::Script(e.to_string()))?;
42 for (k, v) in &response.headers {
43 headers_table
44 .set(k.as_str(), v.as_str())
45 .map_err(|e| TarnError::Script(e.to_string()))?;
46 }
47 response_table
48 .set("headers", headers_table)
49 .map_err(|e| TarnError::Script(e.to_string()))?;
50
51 let body_lua = lua
53 .to_value(&response.body)
54 .map_err(|e| TarnError::Script(format!("Failed to convert body to Lua: {}", e)))?;
55 response_table
56 .set("body", body_lua)
57 .map_err(|e| TarnError::Script(e.to_string()))?;
58
59 lua.globals()
60 .set("response", response_table)
61 .map_err(|e| TarnError::Script(e.to_string()))?;
62
63 let captures_table = lua
65 .create_table()
66 .map_err(|e| TarnError::Script(e.to_string()))?;
67 for (k, v) in captures {
68 let lua_val = lua.to_value(v).map_err(|e| {
69 TarnError::Script(format!("Failed to convert capture '{}' to Lua: {}", k, e))
70 })?;
71 captures_table
72 .set(k.as_str(), lua_val)
73 .map_err(|e| TarnError::Script(e.to_string()))?;
74 }
75 lua.globals()
76 .set("captures", captures_table)
77 .map_err(|e| TarnError::Script(e.to_string()))?;
78
79 let assertions: Arc<Mutex<Vec<AssertionResult>>> = Arc::new(Mutex::new(Vec::new()));
81 let assertions_clone = assertions.clone();
82 let step_name_owned = step_name.to_string();
83
84 let assert_fn = lua
85 .create_function(move |_, (condition, message): (bool, Option<String>)| {
86 let msg = message.unwrap_or_else(|| "script assertion".to_string());
87 let result = if condition {
88 AssertionResult::pass(format!("script: {}", msg), "true", "true")
89 } else {
90 AssertionResult::fail(
91 format!("script: {}", msg),
92 "true",
93 "false",
94 format!("Script assertion failed in '{}': {}", step_name_owned, msg),
95 )
96 };
97 assertions_clone.lock().unwrap().push(result);
98 Ok(())
99 })
100 .map_err(|e| TarnError::Script(e.to_string()))?;
101
102 lua.globals()
103 .set("assert", assert_fn)
104 .map_err(|e| TarnError::Script(e.to_string()))?;
105
106 register_json_module(&lua)?;
108
109 lua.load(script)
111 .exec()
112 .map_err(|e| TarnError::Script(format!("Lua error in step '{}': {}", step_name, e)))?;
113
114 let final_captures: HashMap<String, serde_json::Value> = {
116 let captures_table: LuaTable = lua
117 .globals()
118 .get("captures")
119 .map_err(|e| TarnError::Script(e.to_string()))?;
120 let mut result = HashMap::new();
121 for pair in captures_table.pairs::<String, LuaValue>() {
122 let (k, v) = pair.map_err(|e| TarnError::Script(e.to_string()))?;
123 let v_json = lua_value_to_json(v);
124 result.insert(k, v_json);
125 }
126 result
127 };
128
129 let assertion_results = assertions.lock().unwrap().clone();
130
131 Ok(ScriptResult {
132 captures: final_captures,
133 assertion_results,
134 })
135}
136
137fn register_json_module(lua: &Lua) -> Result<(), TarnError> {
139 let json_table = lua
140 .create_table()
141 .map_err(|e| TarnError::Script(e.to_string()))?;
142
143 let decode_fn = lua
145 .create_function(|lua, s: String| {
146 let value: serde_json::Value =
147 serde_json::from_str(&s).map_err(|e| LuaError::runtime(e.to_string()))?;
148 lua.to_value(&value)
149 .map_err(|e| LuaError::runtime(e.to_string()))
150 })
151 .map_err(|e| TarnError::Script(e.to_string()))?;
152
153 let encode_fn = lua
155 .create_function(|lua, value: LuaValue| {
156 let json_value: serde_json::Value = lua
157 .from_value(value)
158 .map_err(|e| LuaError::runtime(e.to_string()))?;
159 serde_json::to_string(&json_value).map_err(|e| LuaError::runtime(e.to_string()))
160 })
161 .map_err(|e| TarnError::Script(e.to_string()))?;
162
163 json_table
164 .set("decode", decode_fn)
165 .map_err(|e| TarnError::Script(e.to_string()))?;
166 json_table
167 .set("encode", encode_fn)
168 .map_err(|e| TarnError::Script(e.to_string()))?;
169
170 lua.globals()
171 .set("json", json_table)
172 .map_err(|e| TarnError::Script(e.to_string()))?;
173
174 Ok(())
175}
176
177fn create_sandboxed_lua() -> Result<Lua, TarnError> {
178 let lua = Lua::new_with(
179 StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
180 LuaOptions::default(),
181 )
182 .map_err(|e| TarnError::Script(format!("Failed to initialize Lua sandbox: {}", e)))?;
183
184 lua.set_memory_limit(SCRIPT_MEMORY_LIMIT_BYTES)
185 .map_err(|e| TarnError::Script(format!("Failed to configure Lua memory limit: {}", e)))?;
186
187 let executed = Arc::new(AtomicUsize::new(0));
188 let executed_clone = executed.clone();
189 lua.set_hook(
190 HookTriggers::new().every_nth_instruction(SCRIPT_HOOK_GRANULARITY),
191 move |_lua, _debug| {
192 let total = executed_clone
193 .fetch_add(SCRIPT_HOOK_GRANULARITY as usize, Ordering::Relaxed)
194 + SCRIPT_HOOK_GRANULARITY as usize;
195 if total > SCRIPT_MAX_INSTRUCTIONS {
196 Err(LuaError::runtime(
197 "script exceeded the instruction limit and was terminated",
198 ))
199 } else {
200 Ok(VmState::Continue)
201 }
202 },
203 );
204
205 let globals = lua.globals();
206 for name in ["dofile", "loadfile", "collectgarbage"] {
207 globals
208 .set(name, LuaValue::Nil)
209 .map_err(|e| TarnError::Script(format!("Failed to harden Lua globals: {}", e)))?;
210 }
211
212 Ok(lua)
213}
214
215fn lua_value_to_json(v: LuaValue) -> serde_json::Value {
217 match v {
218 LuaValue::String(s) => serde_json::Value::String(s.to_string_lossy().to_string()),
219 LuaValue::Integer(i) => serde_json::json!(i),
220 LuaValue::Number(n) => serde_json::Number::from_f64(n)
221 .map(serde_json::Value::Number)
222 .unwrap_or(serde_json::Value::Null),
223 LuaValue::Boolean(b) => serde_json::Value::Bool(b),
224 LuaValue::Nil => serde_json::Value::Null,
225 _ => serde_json::Value::String(format!("{:?}", v)),
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a fixed `HttpResponse` with the given `status` whose parsed body
    /// is `body`.
    ///
    /// The raw `body_bytes` mirror `body`:
    /// - empty for `null`;
    /// - the bare text for a JSON string;
    /// - the serialized JSON for anything else.
    fn make_response(status: u16, body: serde_json::Value) -> HttpResponse {
        let body_bytes = match &body {
            serde_json::Value::Null => vec![],
            serde_json::Value::String(text) => text.clone().into_bytes(),
            structured => serde_json::to_vec(structured).unwrap(),
        };
        let timings = crate::http::ResponseTimings {
            total_ms: 50,
            ttfb_ms: 25,
            body_read_ms: 25,
            connect_ms: None,
            tls_ms: None,
        };
        HttpResponse {
            status,
            url: String::from("https://example.com/"),
            redirect_count: 0,
            headers: HashMap::new(),
            raw_headers: Vec::new(),
            body_bytes,
            body,
            duration_ms: 50,
            timings,
        }
    }

    /// Runs `script` against `resp` with no prior captures, as step "test".
    fn run_plain(script: &str, resp: &HttpResponse) -> Result<ScriptResult, TarnError> {
        run_script(script, resp, &HashMap::new(), "test")
    }

    /// Same as `run_plain`, but for scripts expected to succeed.
    fn run_ok(script: &str, resp: &HttpResponse) -> ScriptResult {
        run_plain(script, resp).unwrap()
    }

    #[test]
    fn script_accesses_response_status() {
        let resp = make_response(200, json!({}));
        let result = run_ok("assert(response.status == 200, 'status ok')", &resp);
        assert_eq!(result.assertion_results.len(), 1);
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_accesses_response_body() {
        let resp = make_response(200, json!({"name": "Alice"}));
        let result = run_ok("assert(response.body.name == 'Alice', 'name check')", &resp);
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_sets_captures() {
        let resp = make_response(200, json!({"id": "usr_123"}));
        let result = run_ok("captures.user_id = response.body.id", &resp);
        assert_eq!(result.captures["user_id"], json!("usr_123"));
    }

    #[test]
    fn script_sets_typed_captures() {
        let resp = make_response(200, json!({"count": 42}));
        let result = run_ok("captures.count = response.body.count", &resp);
        assert_eq!(result.captures["count"], json!(42));
    }

    #[test]
    fn script_failed_assertion() {
        let resp = make_response(404, json!({}));
        let result = run_ok("assert(response.status == 200, 'expected 200')", &resp);
        assert_eq!(result.assertion_results.len(), 1);
        let only = &result.assertion_results[0];
        assert!(!only.passed);
        assert!(only.message.contains("expected 200"));
    }

    #[test]
    fn script_syntax_error() {
        let resp = make_response(200, json!({}));
        let err = run_plain("this is not valid lua!!!", &resp).unwrap_err();
        assert!(matches!(err, TarnError::Script(_)));
    }

    #[test]
    fn script_reads_existing_captures() {
        let resp = make_response(200, json!({}));
        let caps = HashMap::from([("token".to_string(), json!("abc123"))]);
        let result = run_script(
            "assert(captures.token == 'abc123', 'token check')",
            &resp,
            &caps,
            "test",
        )
        .unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_reads_typed_captures() {
        let resp = make_response(200, json!({}));
        let caps = HashMap::from([("count".to_string(), json!(42))]);
        let result = run_script(
            "assert(captures.count == 42, 'number preserved')",
            &resp,
            &caps,
            "test",
        )
        .unwrap();
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_cannot_access_os_library() {
        let resp = make_response(200, json!({}));
        let result = run_ok("assert(os == nil, 'os hidden')", &resp);
        assert!(result.assertion_results[0].passed);
    }

    #[test]
    fn script_cannot_load_files() {
        let resp = make_response(200, json!({}));
        let message = run_plain("dofile('secret.lua')", &resp)
            .unwrap_err()
            .to_string();
        assert!(message.contains("attempt to call a nil value"));
    }

    #[test]
    fn script_instruction_limit_is_enforced() {
        let resp = make_response(200, json!({}));
        let message = run_plain("while true do end", &resp)
            .unwrap_err()
            .to_string();
        assert!(message.contains("instruction limit"));
    }

    #[test]
    fn script_json_decode() {
        let resp = make_response(200, json!({}));
        let result = run_ok(
            r#"
            local data = json.decode('{"name":"Alice","age":30}')
            assert(data.name == 'Alice', 'name decoded')
            assert(data.age == 30, 'age decoded')
            "#,
            &resp,
        );
        assert_eq!(result.assertion_results.len(), 2);
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }

    #[test]
    fn script_json_encode() {
        let resp = make_response(200, json!({}));
        let result = run_ok(
            r#"
            local encoded = json.encode({name = 'Bob'})
            assert(type(encoded) == 'string', 'encode returns string')
            local decoded = json.decode(encoded)
            assert(decoded.name == 'Bob', 'roundtrip works')
            "#,
            &resp,
        );
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }

    #[test]
    fn script_json_decode_invalid() {
        let resp = make_response(200, json!({}));
        assert!(run_plain("json.decode('not valid json')", &resp).is_err());
    }

    #[test]
    fn script_json_global_exists() {
        let resp = make_response(200, json!({}));
        let script = concat!(
            "assert(json ~= nil, 'json exists')\n",
            "assert(type(json.decode) == 'function', 'decode is function')\n",
            "assert(type(json.encode) == 'function', 'encode is function')",
        );
        let result = run_ok(script, &resp);
        assert!(result.assertion_results.iter().all(|a| a.passed));
    }
}