//! Lua scripting hooks for soli_proxy (`soli_proxy/scripting/mod.rs`).

use mlua::{Function, Lua, Result as LuaResult, Table, Value};
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

/// Represents a request as seen by Lua scripts.
///
/// Converted into a Lua table for every hook call. After `on_request`, the `headers`
/// and `path` fields are read back from that table; `method`, `host` and
/// `content_length` are read-only from the script's point of view.
#[derive(Clone, Debug)]
pub struct LuaRequest {
    /// HTTP method.
    pub method: String,
    /// Request path; scripts may rewrite it in `on_request`.
    pub path: String,
    /// Header name -> value. `req:header()` / `req:set_header()` lowercase the name,
    /// so presumably callers supply lowercase keys — TODO confirm at the call site.
    pub headers: HashMap<String, String>,
    /// Host the request is addressed to.
    pub host: String,
    /// Request body length as reported to scripts.
    pub content_length: u64,
}

/// Result of calling on_request — either continue or deny early.
#[derive(Debug)]
pub enum RequestHookResult {
    /// Proceed with this request, including any header/path changes the script made.
    Continue(LuaRequest),
    /// The script returned a table with a `status` field (for example via
    /// `return req:deny(status, body)`); `body` defaults to an empty string.
    Deny { status: u16, body: String },
}

/// Result of calling on_route — either override or keep default.
#[derive(Debug)]
pub enum RouteHookResult {
    /// The hook returned a string, to be used instead of the matched target.
    Override(String),
    /// Keep the matched target: the hook returned a non-string value, is not
    /// defined, or failed.
    Default,
}

/// Modifications returned by on_response hook.
///
/// Collected from the helper methods the hook calls on its `resp` argument; applying
/// them is presumably the caller's job (not visible here).
#[derive(Debug, Default)]
pub struct ResponseMod {
    /// Headers set via `resp:set_header(name, value)`, name -> value.
    pub set_headers: HashMap<String, String>,
    /// Header names passed to `resp:remove_header(name)`.
    pub remove_headers: Vec<String>,
    /// Body passed to `resp:replace_body(body)`, if called.
    pub replace_body: Option<String>,
    /// Status code passed to `resp:set_status(code)`, if called.
    pub override_status: Option<u16>,
}

/// Shared state for cross-worker counters (used by `shared` Lua module).
///
/// One map per engine, handed to every Lua state it creates, so `shared.get`,
/// `shared.set` and `shared.incr` see the same values from every worker.
type SharedState = Arc<std::sync::RwLock<HashMap<String, f64>>>;

/// The Lua scripting engine. Thread-safe, cheaply cloneable.
///
/// Holds a pool of pre-initialized Lua states (one per worker) for global scripts,
/// plus per-script pools for route-specific scripts.
///
/// Cloning only clones the `Arc` around the shared inner state.
#[derive(Clone)]
pub struct LuaEngine {
    inner: Arc<LuaEngineInner>,
}

/// State shared by all clones of a [`LuaEngine`].
struct LuaEngineInner {
    /// Global hook pool: every state has the same global scripts loaded (all .lua
    /// files from scripts_dir for `new`, the listed global scripts for
    /// `with_route_scripts`). States are picked round-robin, each behind its own mutex.
    states: Vec<std::sync::Mutex<Lua>>,
    /// Whether the global scripts define `on_request` (probed once at construction).
    has_on_request: bool,
    /// Whether the global scripts define `on_route`.
    has_on_route: bool,
    /// Whether the global scripts define `on_response`.
    has_on_response: bool,
    /// Whether the global scripts define `on_request_end`.
    has_on_request_end: bool,
    /// Per-hook-call time budget given to the constructor.
    _hook_timeout: Duration,
    /// Per-script hook pool (script_name -> per-worker Lua states)
    route_scripts: HashMap<String, Vec<std::sync::Mutex<Lua>>>,
    /// Shared state for cross-worker counters (kept alive via Arc)
    _shared_state: SharedState,
}

67impl LuaEngine {
68    /// Create a new LuaEngine by loading all .lua files from `scripts_dir`.
69    ///
70    /// `num_states` should match the number of worker threads.
71    /// `hook_timeout` is the max execution time per hook call.
72    pub fn new(
73        scripts_dir: &Path,
74        num_states: usize,
75        hook_timeout: Duration,
76    ) -> anyhow::Result<Self> {
77        let num_states = num_states.max(1);
78        let shared_state: SharedState = Arc::new(std::sync::RwLock::new(HashMap::new()));
79
80        // Collect all .lua files from the scripts directory
81        let mut script_sources: Vec<(String, String)> = Vec::new();
82        if scripts_dir.exists() && scripts_dir.is_dir() {
83            let mut entries: Vec<_> = std::fs::read_dir(scripts_dir)?
84                .filter_map(|e| e.ok())
85                .filter(|e| {
86                    e.path()
87                        .extension()
88                        .map(|ext| ext == "lua")
89                        .unwrap_or(false)
90                })
91                .collect();
92            entries.sort_by_key(|e| e.file_name());
93
94            for entry in entries {
95                let path = entry.path();
96                let source = std::fs::read_to_string(&path)?;
97                let name = path.file_name().unwrap().to_string_lossy().to_string();
98                tracing::info!("Loading Lua script: {}", name);
99                script_sources.push((name, source));
100            }
101        }
102
103        if script_sources.is_empty() {
104            tracing::info!("No Lua scripts found in {}", scripts_dir.display());
105        }
106
107        // Create the first Lua state to probe which hooks exist
108        let probe_lua = Self::create_lua_state(&script_sources, hook_timeout, &shared_state)?;
109        let has_on_request = probe_lua.globals().get::<Function>("on_request").is_ok();
110        let has_on_route = probe_lua.globals().get::<Function>("on_route").is_ok();
111        let has_on_response = probe_lua.globals().get::<Function>("on_response").is_ok();
112        let has_on_request_end = probe_lua
113            .globals()
114            .get::<Function>("on_request_end")
115            .is_ok();
116
117        tracing::info!(
118            "Lua hooks: on_request={}, on_route={}, on_response={}, on_request_end={}",
119            has_on_request,
120            has_on_route,
121            has_on_response,
122            has_on_request_end
123        );
124
125        // Build the pool of Lua states
126        let mut states = Vec::with_capacity(num_states);
127        states.push(std::sync::Mutex::new(probe_lua));
128        for _ in 1..num_states {
129            let lua = Self::create_lua_state(&script_sources, hook_timeout, &shared_state)?;
130            states.push(std::sync::Mutex::new(lua));
131        }
132
133        Ok(Self {
134            inner: Arc::new(LuaEngineInner {
135                states,
136                has_on_request,
137                has_on_route,
138                has_on_response,
139                has_on_request_end,
140                _hook_timeout: hook_timeout,
141                route_scripts: HashMap::new(),
142                _shared_state: shared_state,
143            }),
144        })
145    }
146
147    /// Create a LuaEngine with per-route script support.
148    ///
149    /// `global_scripts` — filenames loaded into the global pool (run on every request)
150    /// `route_script_names` — unique filenames that need their own per-worker pools
151    pub fn with_route_scripts(
152        scripts_dir: &Path,
153        num_states: usize,
154        hook_timeout: Duration,
155        global_scripts: &[String],
156        route_script_names: &[String],
157    ) -> anyhow::Result<Self> {
158        let num_states = num_states.max(1);
159        let shared_state: SharedState = Arc::new(std::sync::RwLock::new(HashMap::new()));
160
161        // Load global scripts
162        let mut global_sources: Vec<(String, String)> = Vec::new();
163        for name in global_scripts {
164            let path = scripts_dir.join(name);
165            if path.exists() {
166                let source = std::fs::read_to_string(&path)?;
167                tracing::info!("Loading global Lua script: {}", name);
168                global_sources.push((name.clone(), source));
169            } else {
170                tracing::warn!("Global Lua script not found: {}", path.display());
171            }
172        }
173
174        // Probe global hooks
175        let probe_lua = Self::create_lua_state(&global_sources, hook_timeout, &shared_state)?;
176        let has_on_request = probe_lua.globals().get::<Function>("on_request").is_ok();
177        let has_on_route = probe_lua.globals().get::<Function>("on_route").is_ok();
178        let has_on_response = probe_lua.globals().get::<Function>("on_response").is_ok();
179        let has_on_request_end = probe_lua
180            .globals()
181            .get::<Function>("on_request_end")
182            .is_ok();
183
184        tracing::info!(
185            "Global Lua hooks: on_request={}, on_route={}, on_response={}, on_request_end={}",
186            has_on_request,
187            has_on_route,
188            has_on_response,
189            has_on_request_end
190        );
191
192        // Build global pool
193        let mut states = Vec::with_capacity(num_states);
194        states.push(std::sync::Mutex::new(probe_lua));
195        for _ in 1..num_states {
196            let lua = Self::create_lua_state(&global_sources, hook_timeout, &shared_state)?;
197            states.push(std::sync::Mutex::new(lua));
198        }
199
200        // Build per-route-script pools
201        let mut route_scripts: HashMap<String, Vec<std::sync::Mutex<Lua>>> = HashMap::new();
202        for name in route_script_names {
203            // Skip if it's already a global script (would run twice)
204            if global_scripts.contains(name) {
205                continue;
206            }
207            let path = scripts_dir.join(name);
208            if !path.exists() {
209                tracing::warn!("Route Lua script not found: {}", path.display());
210                continue;
211            }
212            let source = std::fs::read_to_string(&path)?;
213            tracing::info!("Loading route Lua script: {}", name);
214            let script_sources = vec![(name.clone(), source)];
215
216            let mut script_states = Vec::with_capacity(num_states);
217            for _ in 0..num_states {
218                let lua = Self::create_lua_state(&script_sources, hook_timeout, &shared_state)?;
219                script_states.push(std::sync::Mutex::new(lua));
220            }
221            route_scripts.insert(name.clone(), script_states);
222        }
223
224        Ok(Self {
225            inner: Arc::new(LuaEngineInner {
226                states,
227                has_on_request,
228                has_on_route,
229                has_on_response,
230                has_on_request_end,
231                _hook_timeout: hook_timeout,
232                route_scripts,
233                _shared_state: shared_state,
234            }),
235        })
236    }
237
238    fn create_lua_state(
239        scripts: &[(String, String)],
240        hook_timeout: Duration,
241        shared_state: &SharedState,
242    ) -> anyhow::Result<Lua> {
243        let lua = Lua::new();
244
245        // Set instruction count hook for timeout protection
246        let timeout_ms = hook_timeout.as_millis() as u32;
247        // ~1M instructions per ms is a rough estimate; we check every 10000 instructions
248        let max_instructions = (timeout_ms as u64) * 1000;
249        lua.set_hook(
250            mlua::HookTriggers::new().every_nth_instruction(10000),
251            move |_lua, _debug| {
252                // This is a simplified timeout: we count instruction batches.
253                // For a more accurate timeout, we'd track wall-clock time.
254                static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
255                let count = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
256                if count > 0 && count.is_multiple_of(max_instructions) {
257                    // Reset for next call
258                    COUNTER.store(0, std::sync::atomic::Ordering::Relaxed);
259                }
260                Ok(mlua::VmState::Continue)
261            },
262        );
263
264        // Register built-in modules
265        Self::register_log_module(&lua)?;
266        Self::register_base64_module(&lua)?;
267        Self::register_crypto_module(&lua)?;
268        Self::register_env_module(&lua)?;
269        Self::register_time_module(&lua)?;
270        Self::register_shared_module(&lua, shared_state)?;
271
272        // Load all scripts in order
273        for (name, source) in scripts {
274            lua.load(source)
275                .set_name(name)
276                .exec()
277                .map_err(|e| anyhow::anyhow!("Error loading Lua script '{}': {}", name, e))?;
278        }
279
280        Ok(lua)
281    }
282
283    fn register_log_module(lua: &Lua) -> LuaResult<()> {
284        let log_table = lua.create_table()?;
285
286        log_table.set(
287            "info",
288            lua.create_function(|_, msg: String| {
289                tracing::info!(target: "lua", "{}", msg);
290                Ok(())
291            })?,
292        )?;
293
294        log_table.set(
295            "warn",
296            lua.create_function(|_, msg: String| {
297                tracing::warn!(target: "lua", "{}", msg);
298                Ok(())
299            })?,
300        )?;
301
302        log_table.set(
303            "error",
304            lua.create_function(|_, msg: String| {
305                tracing::error!(target: "lua", "{}", msg);
306                Ok(())
307            })?,
308        )?;
309
310        log_table.set(
311            "debug",
312            lua.create_function(|_, msg: String| {
313                tracing::debug!(target: "lua", "{}", msg);
314                Ok(())
315            })?,
316        )?;
317
318        lua.globals().set("log", log_table)?;
319        Ok(())
320    }
321
322    fn register_base64_module(lua: &Lua) -> LuaResult<()> {
323        use base64::Engine as _;
324
325        let table = lua.create_table()?;
326
327        table.set(
328            "encode",
329            lua.create_function(|_, s: String| {
330                Ok(base64::engine::general_purpose::STANDARD.encode(s.as_bytes()))
331            })?,
332        )?;
333
334        table.set(
335            "decode",
336            lua.create_function(|_, s: String| {
337                match base64::engine::general_purpose::STANDARD.decode(s.as_bytes()) {
338                    Ok(bytes) => Ok(String::from_utf8_lossy(&bytes).into_owned()),
339                    Err(e) => Err(mlua::Error::RuntimeError(format!(
340                        "base64 decode error: {}",
341                        e
342                    ))),
343                }
344            })?,
345        )?;
346
347        lua.globals().set("base64", table)?;
348        Ok(())
349    }
350
351    fn register_crypto_module(lua: &Lua) -> LuaResult<()> {
352        use sha2::Digest;
353
354        let table = lua.create_table()?;
355
356        table.set(
357            "sha256",
358            lua.create_function(|_, s: String| {
359                let mut hasher = sha2::Sha256::new();
360                hasher.update(s.as_bytes());
361                let result = hasher.finalize();
362                Ok(hex_encode(&result))
363            })?,
364        )?;
365
366        table.set(
367            "hmac_sha256",
368            lua.create_function(|_, (key, msg): (String, String)| {
369                use hmac::{Hmac, Mac};
370                type HmacSha256 = Hmac<sha2::Sha256>;
371
372                let mut mac = HmacSha256::new_from_slice(key.as_bytes())
373                    .map_err(|e| mlua::Error::RuntimeError(format!("HMAC key error: {}", e)))?;
374                mac.update(msg.as_bytes());
375                let result = mac.finalize().into_bytes();
376                Ok(hex_encode(&result))
377            })?,
378        )?;
379
380        lua.globals().set("crypto", table)?;
381        Ok(())
382    }
383
384    fn register_env_module(lua: &Lua) -> LuaResult<()> {
385        let table = lua.create_table()?;
386
387        table.set(
388            "get",
389            lua.create_function(|lua, name: String| match std::env::var(&name) {
390                Ok(val) => Ok(Value::String(lua.create_string(&val)?)),
391                Err(_) => Ok(Value::Nil),
392            })?,
393        )?;
394
395        lua.globals().set("env", table)?;
396        Ok(())
397    }
398
399    fn register_time_module(lua: &Lua) -> LuaResult<()> {
400        let table = lua.create_table()?;
401
402        table.set(
403            "now_ms",
404            lua.create_function(|_, ()| {
405                let ms = std::time::SystemTime::now()
406                    .duration_since(std::time::UNIX_EPOCH)
407                    .unwrap_or_default()
408                    .as_millis() as f64;
409                Ok(ms)
410            })?,
411        )?;
412
413        lua.globals().set("time", table)?;
414        Ok(())
415    }
416
417    fn register_shared_module(lua: &Lua, shared_state: &SharedState) -> LuaResult<()> {
418        let table = lua.create_table()?;
419
420        let state = shared_state.clone();
421        table.set(
422            "get",
423            lua.create_function(move |_, key: String| {
424                let map = state.read().unwrap();
425                match map.get(&key) {
426                    Some(&val) => Ok(Value::Number(val)),
427                    None => Ok(Value::Nil),
428                }
429            })?,
430        )?;
431
432        let state = shared_state.clone();
433        table.set(
434            "set",
435            lua.create_function(move |_, (key, value): (String, f64)| {
436                let mut map = state.write().unwrap();
437                map.insert(key, value);
438                Ok(())
439            })?,
440        )?;
441
442        let state = shared_state.clone();
443        table.set(
444            "incr",
445            lua.create_function(move |_, key: String| {
446                let mut map = state.write().unwrap();
447                let val = map.entry(key).or_insert(0.0);
448                *val += 1.0;
449                Ok(*val)
450            })?,
451        )?;
452
453        lua.globals().set("shared", table)?;
454        Ok(())
455    }
456
457    /// Get a Lua state from the pool, using a simple round-robin.
458    /// This uses a global counter to distribute across states.
459    fn get_state_index(&self) -> usize {
460        static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
461        let idx = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
462        idx % self.inner.states.len()
463    }
464
465    // --- Hook accessors ---
466
467    pub fn has_on_request(&self) -> bool {
468        self.inner.has_on_request
469    }
470
471    pub fn has_on_route(&self) -> bool {
472        self.inner.has_on_route
473    }
474
475    pub fn has_on_response(&self) -> bool {
476        self.inner.has_on_response
477    }
478
479    pub fn has_on_request_end(&self) -> bool {
480        self.inner.has_on_request_end
481    }
482
483    /// Check if a named route script has a specific hook.
484    fn route_script_has_hook(lua: &Lua, hook_name: &str) -> bool {
485        lua.globals().get::<Function>(hook_name).is_ok()
486    }
487
488    // --- Hook calls ---
489
490    /// Call on_request(req). Returns Continue or Deny.
491    pub fn call_on_request(&self, req: &mut LuaRequest) -> RequestHookResult {
492        if !self.inner.has_on_request {
493            return RequestHookResult::Continue(req.clone());
494        }
495
496        let idx = self.get_state_index();
497        let lua = self.inner.states[idx].lock().unwrap();
498
499        match self.do_on_request(&lua, req) {
500            Ok(result) => result,
501            Err(e) => {
502                tracing::error!("Lua on_request error: {}", e);
503                // On error, continue with unmodified request
504                RequestHookResult::Continue(req.clone())
505            }
506        }
507    }
508
509    /// Call on_request for a specific route script. Returns Continue or Deny.
510    pub fn call_route_on_request(
511        &self,
512        script_name: &str,
513        req: &mut LuaRequest,
514    ) -> RequestHookResult {
515        let Some(script_states) = self.inner.route_scripts.get(script_name) else {
516            return RequestHookResult::Continue(req.clone());
517        };
518
519        let idx = self.get_state_index() % script_states.len();
520        let lua = script_states[idx].lock().unwrap();
521
522        if !Self::route_script_has_hook(&lua, "on_request") {
523            return RequestHookResult::Continue(req.clone());
524        }
525
526        match self.do_on_request(&lua, req) {
527            Ok(result) => result,
528            Err(e) => {
529                tracing::error!("Lua on_request error in {}: {}", script_name, e);
530                RequestHookResult::Continue(req.clone())
531            }
532        }
533    }
534
535    fn do_on_request(&self, lua: &Lua, req: &mut LuaRequest) -> LuaResult<RequestHookResult> {
536        let func: Function = lua.globals().get("on_request")?;
537
538        // Build the request table
539        let req_table = self.lua_request_table(lua, req)?;
540
541        let result: Value = func.call(req_table.clone())?;
542
543        match result {
544            Value::Nil => {
545                // No return value — read back any modified headers
546                self.read_back_request(lua, &req_table, req)?;
547                Ok(RequestHookResult::Continue(req.clone()))
548            }
549            Value::Table(t) => {
550                // Check if it's a deny response: { status = N, body = "..." }
551                if let Ok(status) = t.get::<u16>("status") {
552                    let body: String = t.get::<String>("body").unwrap_or_default();
553                    Ok(RequestHookResult::Deny { status, body })
554                } else {
555                    self.read_back_request(lua, &req_table, req)?;
556                    Ok(RequestHookResult::Continue(req.clone()))
557                }
558            }
559            _ => {
560                self.read_back_request(lua, &req_table, req)?;
561                Ok(RequestHookResult::Continue(req.clone()))
562            }
563        }
564    }
565
566    /// Call on_route(req, matched_target). Returns Override(url) or Default.
567    pub fn call_on_route(&self, req: &LuaRequest, matched_target: &str) -> RouteHookResult {
568        if !self.inner.has_on_route {
569            return RouteHookResult::Default;
570        }
571
572        let idx = self.get_state_index();
573        let lua = self.inner.states[idx].lock().unwrap();
574
575        match self.do_on_route(&lua, req, matched_target) {
576            Ok(result) => result,
577            Err(e) => {
578                tracing::error!("Lua on_route error: {}", e);
579                RouteHookResult::Default
580            }
581        }
582    }
583
584    /// Call on_route for a specific route script.
585    pub fn call_route_on_route(
586        &self,
587        script_name: &str,
588        req: &LuaRequest,
589        matched_target: &str,
590    ) -> RouteHookResult {
591        let Some(script_states) = self.inner.route_scripts.get(script_name) else {
592            return RouteHookResult::Default;
593        };
594
595        let idx = self.get_state_index() % script_states.len();
596        let lua = script_states[idx].lock().unwrap();
597
598        if !Self::route_script_has_hook(&lua, "on_route") {
599            return RouteHookResult::Default;
600        }
601
602        match self.do_on_route(&lua, req, matched_target) {
603            Ok(result) => result,
604            Err(e) => {
605                tracing::error!("Lua on_route error in {}: {}", script_name, e);
606                RouteHookResult::Default
607            }
608        }
609    }
610
611    fn do_on_route(
612        &self,
613        lua: &Lua,
614        req: &LuaRequest,
615        matched_target: &str,
616    ) -> LuaResult<RouteHookResult> {
617        let func: Function = lua.globals().get("on_route")?;
618        let req_table = self.lua_request_table(lua, req)?;
619
620        let result: Value = func.call((req_table, matched_target.to_string()))?;
621
622        match result {
623            Value::String(s) => Ok(RouteHookResult::Override(s.to_str()?.to_string())),
624            _ => Ok(RouteHookResult::Default),
625        }
626    }
627
628    /// Call on_response(req, resp). Returns ResponseMod with any changes.
629    pub fn call_on_response(
630        &self,
631        req: &LuaRequest,
632        status: u16,
633        headers: &HashMap<String, String>,
634    ) -> ResponseMod {
635        if !self.inner.has_on_response {
636            return ResponseMod::default();
637        }
638
639        let idx = self.get_state_index();
640        let lua = self.inner.states[idx].lock().unwrap();
641
642        match self.do_on_response(&lua, req, status, headers) {
643            Ok(result) => result,
644            Err(e) => {
645                tracing::error!("Lua on_response error: {}", e);
646                ResponseMod::default()
647            }
648        }
649    }
650
651    /// Call on_response for a specific route script.
652    pub fn call_route_on_response(
653        &self,
654        script_name: &str,
655        req: &LuaRequest,
656        status: u16,
657        headers: &HashMap<String, String>,
658    ) -> ResponseMod {
659        let Some(script_states) = self.inner.route_scripts.get(script_name) else {
660            return ResponseMod::default();
661        };
662
663        let idx = self.get_state_index() % script_states.len();
664        let lua = script_states[idx].lock().unwrap();
665
666        if !Self::route_script_has_hook(&lua, "on_response") {
667            return ResponseMod::default();
668        }
669
670        match self.do_on_response(&lua, req, status, headers) {
671            Ok(result) => result,
672            Err(e) => {
673                tracing::error!("Lua on_response error in {}: {}", script_name, e);
674                ResponseMod::default()
675            }
676        }
677    }
678
679    fn do_on_response(
680        &self,
681        lua: &Lua,
682        req: &LuaRequest,
683        status: u16,
684        headers: &HashMap<String, String>,
685    ) -> LuaResult<ResponseMod> {
686        let func: Function = lua.globals().get("on_response")?;
687
688        let req_table = self.lua_request_table(lua, req)?;
689
690        // Build response table
691        let resp_table = lua.create_table()?;
692        resp_table.set("status", status)?;
693
694        let headers_table = lua.create_table()?;
695        for (k, v) in headers {
696            headers_table.set(k.as_str(), v.as_str())?;
697        }
698        resp_table.set("headers", headers_table)?;
699
700        // Track modifications via metatables with __newindex
701        let set_headers_table = lua.create_table()?;
702        let remove_headers_table = lua.create_table()?;
703        let mods_table = lua.create_table()?;
704        mods_table.set("set_headers", set_headers_table)?;
705        mods_table.set("remove_headers", remove_headers_table)?;
706        mods_table.set("replace_body", Value::Nil)?;
707        mods_table.set("override_status", Value::Nil)?;
708
709        // Provide helper methods on resp_table (accept self for resp:method() syntax)
710        let mods_ref = mods_table.clone();
711        resp_table.set(
712            "set_header",
713            lua.create_function(
714                move |_lua, (_self_table, name, value): (Table, String, String)| {
715                    let sh: Table = mods_ref.get("set_headers")?;
716                    sh.set(name, value)?;
717                    Ok(())
718                },
719            )?,
720        )?;
721
722        let mods_ref = mods_table.clone();
723        resp_table.set(
724            "remove_header",
725            lua.create_function(move |_lua, (_self_table, name): (Table, String)| {
726                let rh: Table = mods_ref.get("remove_headers")?;
727                let len = rh.len()? + 1;
728                rh.set(len, name)?;
729                Ok(())
730            })?,
731        )?;
732
733        let mods_ref = mods_table.clone();
734        resp_table.set(
735            "replace_body",
736            lua.create_function(move |_lua, (_self_table, body): (Table, String)| {
737                mods_ref.set("replace_body", body)?;
738                Ok(())
739            })?,
740        )?;
741
742        let mods_ref = mods_table.clone();
743        resp_table.set(
744            "set_status",
745            lua.create_function(move |_lua, (_self_table, code): (Table, u16)| {
746                mods_ref.set("override_status", code)?;
747                Ok(())
748            })?,
749        )?;
750
751        let _result: Value = func.call((req_table, resp_table))?;
752
753        // Read back modifications
754        let mut mods = ResponseMod::default();
755
756        let sh: Table = mods_table.get("set_headers")?;
757        for pair in sh.pairs::<String, String>() {
758            let (k, v) = pair?;
759            mods.set_headers.insert(k, v);
760        }
761
762        let rh: Table = mods_table.get("remove_headers")?;
763        for pair in rh.pairs::<i64, String>() {
764            let (_, v) = pair?;
765            mods.remove_headers.push(v);
766        }
767
768        if let Ok(body) = mods_table.get::<String>("replace_body") {
769            mods.replace_body = Some(body);
770        }
771
772        if let Ok(status) = mods_table.get::<u16>("override_status") {
773            mods.override_status = Some(status);
774        }
775
776        Ok(mods)
777    }
778
779    /// Call on_request_end(req, resp_status, duration_ms).
780    pub fn call_on_request_end(
781        &self,
782        req: &LuaRequest,
783        status: u16,
784        duration_ms: f64,
785        target: &str,
786    ) {
787        if !self.inner.has_on_request_end {
788            return;
789        }
790
791        let idx = self.get_state_index();
792        let lua = self.inner.states[idx].lock().unwrap();
793
794        if let Err(e) = self.do_on_request_end(&lua, req, status, duration_ms, target) {
795            tracing::error!("Lua on_request_end error: {}", e);
796        }
797    }
798
799    /// Call on_request_end for a specific route script.
800    pub fn call_route_on_request_end(
801        &self,
802        script_name: &str,
803        req: &LuaRequest,
804        status: u16,
805        duration_ms: f64,
806        target: &str,
807    ) {
808        let Some(script_states) = self.inner.route_scripts.get(script_name) else {
809            return;
810        };
811
812        let idx = self.get_state_index() % script_states.len();
813        let lua = script_states[idx].lock().unwrap();
814
815        if !Self::route_script_has_hook(&lua, "on_request_end") {
816            return;
817        }
818
819        if let Err(e) = self.do_on_request_end(&lua, req, status, duration_ms, target) {
820            tracing::error!("Lua on_request_end error in {}: {}", script_name, e);
821        }
822    }
823
824    fn do_on_request_end(
825        &self,
826        lua: &Lua,
827        req: &LuaRequest,
828        status: u16,
829        duration_ms: f64,
830        target: &str,
831    ) -> LuaResult<()> {
832        let func: Function = lua.globals().get("on_request_end")?;
833        let req_table = self.lua_request_table(lua, req)?;
834
835        let resp_table = lua.create_table()?;
836        resp_table.set("status", status)?;
837
838        func.call::<()>((req_table, resp_table, duration_ms, target.to_string()))?;
839
840        Ok(())
841    }
842
843    /// Check if a route script is loaded.
844    pub fn has_route_script(&self, name: &str) -> bool {
845        self.inner.route_scripts.contains_key(name)
846    }
847
848    // --- Helpers ---
849
850    fn lua_request_table(&self, lua: &Lua, req: &LuaRequest) -> LuaResult<Table> {
851        let table = lua.create_table()?;
852        table.set("method", req.method.as_str())?;
853        table.set("path", req.path.as_str())?;
854        table.set("host", req.host.as_str())?;
855        table.set("content_length", req.content_length)?;
856
857        let headers_table = lua.create_table()?;
858        for (k, v) in &req.headers {
859            headers_table.set(k.as_str(), v.as_str())?;
860        }
861        let headers_ref = headers_table.clone();
862        let headers_ref2 = headers_table.clone();
863        table.set("headers", headers_table)?;
864
865        // Helper method: req:header("Name")
866        table.set(
867            "header",
868            lua.create_function(move |_lua, (_self_table, name): (Table, String)| {
869                let val: Value = headers_ref.get(name.to_lowercase().as_str())?;
870                Ok(val)
871            })?,
872        )?;
873
874        // Helper method: req:set_header("Name", "Value")
875        table.set(
876            "set_header",
877            lua.create_function(
878                move |_lua, (_self_table, name, value): (Table, String, String)| {
879                    headers_ref2.set(name.to_lowercase().as_str(), value.as_str())?;
880                    Ok(())
881                },
882            )?,
883        )?;
884
885        // Helper method: req:deny(status, body)
886        table.set(
887            "deny",
888            lua.create_function(|lua, (_self_table, status, body): (Table, u16, String)| {
889                let t = lua.create_table()?;
890                t.set("status", status)?;
891                t.set("body", body)?;
892                Ok(t)
893            })?,
894        )?;
895
896        Ok(table)
897    }
898
899    fn read_back_request(
900        &self,
901        _lua: &Lua,
902        req_table: &Table,
903        req: &mut LuaRequest,
904    ) -> LuaResult<()> {
905        // Read back modified headers
906        if let Ok(headers_table) = req_table.get::<Table>("headers") {
907            let mut new_headers = HashMap::new();
908            for pair in headers_table.pairs::<String, String>() {
909                let (k, v) = pair?;
910                new_headers.insert(k, v);
911            }
912            req.headers = new_headers;
913        }
914
915        // Read back modified path
916        if let Ok(path) = req_table.get::<String>("path") {
917            req.path = path;
918        }
919
920        Ok(())
921    }
922}
923
/// Hex-encode a byte slice as lowercase ASCII, two characters per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        // Emit the two nibbles directly rather than allocating a `format!` string per byte.
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}

/// Configuration for the scripting engine.
#[derive(Clone, Debug)]
pub struct ScriptingConfig {
    /// Whether Lua scripting is turned on (presumably checked by the caller before
    /// building a `LuaEngine` — confirm at the call site).
    pub enabled: bool,
    /// Directory the `.lua` scripts are loaded from.
    pub scripts_dir: PathBuf,
    /// Per-hook time budget in milliseconds; presumably converted into the
    /// `hook_timeout` passed to `LuaEngine` — TODO confirm.
    pub hook_timeout_ms: u64,
}

941impl Default for ScriptingConfig {
942    fn default() -> Self {
943        Self {
944            enabled: false,
945            scripts_dir: PathBuf::from("./scripts/lua"),
946            hook_timeout_ms: 10,
947        }
948    }
949}