1use std::ops::{Deref, DerefMut};
100
101use mlua::prelude::*;
102use rusty_rules::{BoolMatcher, IpMatcher, NumberMatcher, RegexMatcher, StringMatcher, Value};
103
104pub struct Engine(rusty_rules::Engine<LuaValue>);
127
128impl Deref for Engine {
129 type Target = rusty_rules::Engine<LuaValue>;
130
131 fn deref(&self) -> &Self::Target {
132 &self.0
133 }
134}
135
136impl DerefMut for Engine {
137 fn deref_mut(&mut self) -> &mut Self::Target {
138 &mut self.0
139 }
140}
141
142impl Engine {
143 pub fn new() -> Self {
145 Engine(rusty_rules::Engine::new())
146 }
147}
148
149impl LuaUserData for Engine {
150 fn register(registry: &mut LuaUserDataRegistry<Self>) {
151 registry.add_function("new", |_, ()| Ok(Self::new()));
153
154 registry.add_method_mut(
156 "register_fetcher",
157 |lua,
158 this,
159 (name, params_or_func, func): (
160 String,
161 LuaEither<LuaTable, LuaFunction>,
162 Option<LuaFunction>,
163 )| {
164 let (params, fetcher) = match (params_or_func, func) {
165 (LuaEither::Left(params), Some(func)) => (Some(params), func),
166 (LuaEither::Right(func), _) => (None, func),
167 _ => Err(LuaError::external("fetcher function must be provided"))?,
168 };
169
170 let lua = lua.weak();
171 let fetcher = this.0.register_fetcher(&name, move |ctx: &LuaValue, args| {
172 let mut combined_args = LuaVariadic::with_capacity(args.len() + 1);
173 combined_args.push(LuaEither::Left(ctx));
174 combined_args.extend(args.iter().map(|arg| LuaEither::Right(arg.as_str())));
175
176 Ok(match fetcher.call(combined_args)? {
177 LuaValue::String(s) => Value::String(s.to_string_lossy().into()),
178 value => lua.upgrade().from_value::<serde_json::Value>(value)?.into(),
179 })
180 });
181
182 let params = params.as_ref();
183 let matcher = params.and_then(|p| p.get::<String>("matcher").ok());
184 match matcher.as_deref() {
185 Some("bool") => fetcher.with_matcher(BoolMatcher),
186 Some("ip") => fetcher.with_matcher(IpMatcher),
187 Some("number") => fetcher.with_matcher(NumberMatcher),
188 Some("regex") | Some("re") => fetcher.with_matcher(RegexMatcher),
189 Some("string") => fetcher.with_matcher(StringMatcher),
190 Some(_) => return Ok(Err("unknown matcher type")),
191 None => fetcher,
192 };
193 let raw_args = params.and_then(|p| p.get("raw_args").ok());
194 if raw_args.unwrap_or(false) {
195 fetcher.with_raw_args(true);
196 }
197 Ok(Ok(()))
198 },
199 );
200
201 registry.add_method("compile", |lua, this, rule: LuaValue| {
203 let rule = lua.from_value::<serde_json::Value>(rule)?;
204 match this.0.compile_rule(&rule) {
205 Ok(rule) => Ok(Ok(Rule(rule))),
206 Err(err) => Ok(Err(err.to_string())),
207 }
208 });
209
210 #[cfg(feature = "validation")]
212 registry.add_method("validate", |lua, this, rule: LuaValue| {
213 let rule = lua.from_value::<serde_json::Value>(rule)?;
214 match this.0.validate_rule(&rule) {
215 Ok(_) => Ok(Ok(true)),
216 Err(err) => Ok(Err(err.to_string())),
217 }
218 });
219
220 registry.add_method("json_schema", |_, this, ()| {
222 let schema = this.0.json_schema();
223 serde_json::to_string(&schema).into_lua_err()
224 });
225 }
226}
227
228pub struct Rule(rusty_rules::Rule<LuaValue>);
232
233impl Deref for Rule {
234 type Target = rusty_rules::Rule<LuaValue>;
235
236 fn deref(&self) -> &Self::Target {
237 &self.0
238 }
239}
240
241impl DerefMut for Rule {
242 fn deref_mut(&mut self) -> &mut Self::Target {
243 &mut self.0
244 }
245}
246
247impl LuaUserData for Rule {
248 fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
249 methods.add_method("evaluate", |_, this, ctx| match this.0.evaluate(&ctx) {
251 Ok(decision) => Ok(Ok(decision)),
252 Err(err) => Ok(Err(err.to_string())),
253 });
254 }
255}
256
#[cfg(test)]
mod tests {
    use mlua::prelude::*;

    use super::Engine;

    /// Fetcher registration, compilation and evaluation, including error paths and
    /// checking that a compiled rule is unaffected by later edits to its source table.
    #[test]
    fn test() -> LuaResult<()> {
        let lua = Lua::new();

        let engine = lua.create_proxy::<Engine>()?;
        lua.globals().set("Engine", engine)?;

        lua.load(
            r#"
            local engine = Engine.new()
            engine:register_fetcher("ctx_fetcher", function(ctx, arg)
                return ctx[arg]
            end)

            local rule = engine:compile({["ctx_fetcher(key)"] = "my_value"})
            assert(rule:evaluate({key = "my_value"}) == true, "(1) should evaluate to true")
            assert(rule:evaluate({key = "other_value"}) == false, "(2) should evaluate to false")
            assert(rule:evaluate({}) == false, "(3) should evaluate to false")

            -- Edge cases
            local ok, err = rule:evaluate()
            assert(ok == nil and err:find("attempt to index a nil value"), "(4) should return an error")
            ok, err = engine:compile({["ctx_fetcher(key)"] = {op = 123}})
            assert(ok == nil and err:find("unknown operator 'op'"), "(5) should return an error")

            -- Check complex struct
            local complex_struct = { array = {1, 2, 3}, map = { key = "value" } }
            local rule = engine:compile({
                ["ctx_fetcher(key)"] = { ["=="] = complex_struct }
            })
            assert(rule:evaluate({key = complex_struct}) == true, "(6) should evaluate to true")
            complex_struct.array[1] = 42 -- Modify to check immutability of compiled rule
            assert(rule:evaluate({key = complex_struct}) == false, "(7) should evaluate to false")
            "#,
        )
        .exec()?;

        Ok(())
    }

    /// A fetcher registered with an explicit `matcher`, plus rejection of an
    /// unknown matcher name.
    #[test]
    fn test_custom_matcher() -> LuaResult<()> {
        let lua = Lua::new();

        let engine = lua.create_proxy::<Engine>()?;
        lua.globals().set("Engine", engine)?;

        lua.load(
            r#"
            local engine = Engine.new()
            engine:register_fetcher("ip", { matcher = "ip" }, function(ctx)
                return ctx.ip
            end)

            local rule = engine:compile({ip = "127.0.0.1/8"})
            assert(rule:evaluate({ip = "127.0.0.1"}) == true, "(1) should evaluate to true")
            assert(rule:evaluate({ip = "172.16.0.0"}) == false, "(2) should evaluate to false")

            local _, err = engine:register_fetcher("ip", { matcher = "abc" }, function(ctx) end)
            assert(err:find("unknown matcher type"), "(3) should return an error for unknown matcher type")
            "#,
        )
        .exec()?;

        Ok(())
    }

    /// Rule validation: a known fetcher validates, an unknown one is rejected.
    #[cfg(feature = "validation")]
    #[test]
    fn test_validation() -> LuaResult<()> {
        let lua = Lua::new();

        let engine = lua.create_proxy::<Engine>()?;
        lua.globals().set("Engine", engine)?;

        lua.load(
            r#"
            local engine = Engine.new()
            engine:register_fetcher("ctx_fetcher", function(ctx, arg)
                return ctx[arg]
            end)

            local ok, err = engine:validate({["ctx_fetcher(key)"] = "my_value"})
            assert(ok == true and err == nil, "(1) should validate successfully")

            ok, err = engine:validate({unknown_fetcher = "my_value"})
            assert(not ok and err:find("'unknown_fetcher' was unexpected"), "(2) should return an error for unknown fetcher")
            "#,
        )
        .exec()?;

        Ok(())
    }
}