wirerust/
functions.rs

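//! Functions module: built-in and user-defined filter functions.
//!
//! Provides the [`FilterFunction`] trait, the [`FunctionRegistry`] that maps
//! function names to numeric IDs and implementations, the built-in function
//! implementations, and [`call_builtin`] for invoking a built-in directly by
//! [`BuiltinFunctionId`].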
4
5use crate::types::LiteralValue;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9pub trait FilterFunction: Send + Sync {
10    fn call(&self, args: &[LiteralValue]) -> Option<LiteralValue>;
11}
12
13#[derive(Default)]
14pub struct FunctionRegistry {
15    functions: HashMap<String, Arc<dyn FilterFunction>>,
16    function_names: Vec<String>,          // index = FunctionId
17    function_ids: HashMap<String, usize>, // name -> id
18}
19
20impl FunctionRegistry {
21    pub fn new() -> Self {
22        Self::default()
23    }
24    /// Register a function and assign it a unique ID if not already present.
25    pub fn register<F>(&mut self, name: impl Into<String>, func: F)
26    where
27        F: FilterFunction + 'static,
28    {
29        let name = name.into();
30        if !self.function_ids.contains_key(&name) {
31            self.function_ids
32                .insert(name.clone(), self.function_names.len());
33            self.function_names.push(name.clone());
34        }
35        self.functions.insert(name, Arc::new(func));
36    }
37    /// Register a closure as a filter function.
38    pub fn register_fn<F>(&mut self, name: impl Into<String>, func: F)
39    where
40        F: Fn(&[LiteralValue]) -> Option<LiteralValue> + Send + Sync + 'static,
41    {
42        struct ClosureFn<F>(F);
43        impl<F> FilterFunction for ClosureFn<F>
44        where
45            F: Fn(&[LiteralValue]) -> Option<LiteralValue> + Send + Sync + 'static,
46        {
47            fn call(&self, args: &[LiteralValue]) -> Option<LiteralValue> {
48                (self.0)(args)
49            }
50        }
51        self.register(name, ClosureFn(func));
52    }
53    /// Get the function ID for a given function name, if it exists.
54    pub fn function_id(&self, name: &str) -> Option<usize> {
55        self.function_ids.get(name).copied()
56    }
57    /// Get the function name for a given function ID, if it exists.
58    pub fn function_name(&self, id: usize) -> Option<&str> {
59        self.function_names.get(id).map(|s| s.as_str())
60    }
61    /// Get the total number of functions.
62    pub fn num_functions(&self) -> usize {
63        self.function_names.len()
64    }
65    pub fn get(&self, name: &str) -> Option<&Arc<dyn FilterFunction>> {
66        self.functions.get(name)
67    }
68    /// Get a function by ID.
69    pub fn get_by_id(&self, id: usize) -> Option<&Arc<dyn FilterFunction>> {
70        self.function_names
71            .get(id)
72            .and_then(|name| self.functions.get(name))
73    }
74}
75
76impl Clone for FunctionRegistry {
77    fn clone(&self) -> Self {
78        Self {
79            functions: self.functions.clone(),
80            function_names: self.function_names.clone(),
81            function_ids: self.function_ids.clone(),
82        }
83    }
84}
85
/// Declares built-in filter functions and generates `register_builtins`.
///
/// Each entry `StructName: "filter_name", args => { body }` expands to a public
/// unit struct `StructName` implementing `FilterFunction`, whose `call` binds
/// the argument slice to `args` and evaluates `body`. The generated
/// `register_builtins` registers every entry under its filter name, in
/// declaration order.
macro_rules! builtin_functions {
    ($( $name:ident: $func_name:expr, $args:ident => $body:block ),* $(,)?) => {
        $(
            #[doc = concat!("Built-in `", $func_name, "` filter function.")]
            #[derive(Debug, Clone, Copy, Default)]
            pub struct $name;

            impl FilterFunction for $name {
                fn call(&self, $args: &[LiteralValue]) -> Option<LiteralValue> $body
            }
        )*

        /// Register every built-in filter function into `reg` under its filter name.
        pub fn register_builtins(reg: &mut FunctionRegistry) {
            $(reg.register($func_name, $name);)*
        }
    };
}
99
100builtin_functions! {
101    LenFunction: "len", args => {
102        if let Some(LiteralValue::Array(arr)) = args.first() {
103            Some(LiteralValue::Int(arr.len() as i64))
104        } else {
105            None
106        }
107    },
108    UpperFunction: "upper", args => {
109        if let Some(LiteralValue::Bytes(bytes)) = args.first() {
110            let s = String::from_utf8_lossy(bytes).to_uppercase();
111            Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
112        } else {
113            None
114        }
115    },
116    LowerFunction: "lower", args => {
117        if let Some(LiteralValue::Bytes(bytes)) = args.first() {
118            let s = String::from_utf8_lossy(bytes).to_lowercase();
119            Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
120        } else {
121            None
122        }
123    },
124    SumFunction: "sum", args => {
125        if let Some(LiteralValue::Array(arr)) = args.first() {
126            let sum: i64 = arr.iter().filter_map(|v| if let LiteralValue::Int(i) = v { Some(*i) } else { None }).sum();
127            Some(LiteralValue::Int(sum))
128        } else {
129            None
130        }
131    },
132    StartsWithFunction: "starts_with", args => {
133        if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(prefix))) = (args.first(), args.get(1)) {
134            let h = String::from_utf8_lossy(haystack);
135            let p = String::from_utf8_lossy(prefix);
136            Some(LiteralValue::Bool(h.starts_with(&*p)))
137        } else {
138            None
139        }
140    },
141    EndsWithFunction: "ends_with", args => {
142        if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(suffix))) = (args.first(), args.get(1)) {
143            let h = String::from_utf8_lossy(haystack);
144            let s = String::from_utf8_lossy(suffix);
145            Some(LiteralValue::Bool(h.ends_with(&*s)))
146        } else {
147            None
148        }
149    },
150}
151
/// Statically known identifier for each built-in filter function.
///
/// Resolve a name with [`BuiltinFunctionId::from_name`] and invoke the function
/// with [`call_builtin`] without going through a registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BuiltinFunctionId {
    /// `len`
    Len,
    /// `upper`
    Upper,
    /// `lower`
    Lower,
    /// `sum`
    Sum,
    /// `starts_with`
    StartsWith,
    /// `ends_with`
    EndsWith,
}
161
162impl BuiltinFunctionId {
163    pub fn from_name(name: &str) -> Option<Self> {
164        match name {
165            "len" => Some(Self::Len),
166            "upper" => Some(Self::Upper),
167            "lower" => Some(Self::Lower),
168            "sum" => Some(Self::Sum),
169            "starts_with" => Some(Self::StartsWith),
170            "ends_with" => Some(Self::EndsWith),
171            _ => None,
172        }
173    }
174}
175
176pub fn call_builtin(id: BuiltinFunctionId, args: &[LiteralValue]) -> Option<LiteralValue> {
177    match id {
178        BuiltinFunctionId::Len => {
179            if let Some(LiteralValue::Array(arr)) = args.first() {
180                Some(LiteralValue::Int(arr.len() as i64))
181            } else {
182                None
183            }
184        }
185        BuiltinFunctionId::Upper => {
186            if let Some(LiteralValue::Bytes(bytes)) = args.first() {
187                let s = String::from_utf8_lossy(bytes).to_uppercase();
188                Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
189            } else {
190                None
191            }
192        }
193        BuiltinFunctionId::Lower => {
194            if let Some(LiteralValue::Bytes(bytes)) = args.first() {
195                let s = String::from_utf8_lossy(bytes).to_lowercase();
196                Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
197            } else {
198                None
199            }
200        }
201        BuiltinFunctionId::Sum => {
202            if let Some(LiteralValue::Array(arr)) = args.first() {
203                let sum: i64 = arr
204                    .iter()
205                    .filter_map(|v| {
206                        if let LiteralValue::Int(i) = v {
207                            Some(*i)
208                        } else {
209                            None
210                        }
211                    })
212                    .sum();
213                Some(LiteralValue::Int(sum))
214            } else {
215                None
216            }
217        }
218        BuiltinFunctionId::StartsWith => {
219            if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(prefix))) =
220                (args.first(), args.get(1))
221            {
222                let h = String::from_utf8_lossy(haystack);
223                let p = String::from_utf8_lossy(prefix);
224                Some(LiteralValue::Bool(h.starts_with(&*p)))
225            } else {
226                None
227            }
228        }
229        BuiltinFunctionId::EndsWith => {
230            if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(suffix))) =
231                (args.first(), args.get(1))
232            {
233                let h = String::from_utf8_lossy(haystack);
234                let s = String::from_utf8_lossy(suffix);
235                Some(LiteralValue::Bool(h.ends_with(&*s)))
236            } else {
237                None
238            }
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Bytes` literal from a byte string.
    fn bytes(raw: &[u8]) -> LiteralValue {
        LiteralValue::Bytes(Arc::new(raw.to_vec()))
    }

    /// Build an `Array` literal of `Int` elements.
    fn ints(values: &[i64]) -> LiteralValue {
        LiteralValue::Array(Arc::new(
            values.iter().copied().map(LiteralValue::Int).collect::<Vec<_>>(),
        ))
    }

    #[test]
    fn test_register_and_call_len() {
        let mut reg = FunctionRegistry::new();
        reg.register("len", LenFunction);
        let arr = LiteralValue::Array(Arc::new(vec![LiteralValue::Int(1), LiteralValue::Int(2)]));
        let result = reg.get("len").unwrap().call(&[arr]);
        assert_eq!(result, Some(LiteralValue::Int(2)));
    }

    #[test]
    fn test_upper_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("upper", UpperFunction);
        let val = LiteralValue::Bytes(Arc::new(b"hello".to_vec()));
        let result = reg.get("upper").unwrap().call(&[val]);
        assert_eq!(
            result,
            Some(LiteralValue::Bytes(Arc::new(b"HELLO".to_vec())))
        );
    }

    #[test]
    fn test_lower_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("lower", LowerFunction);
        let result = reg.get("lower").unwrap().call(&[bytes(b"HeLLo")]);
        assert_eq!(result, Some(bytes(b"hello")));
    }

    #[test]
    fn test_sum_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("sum", SumFunction);
        let arr = LiteralValue::Array(Arc::new(vec![
            LiteralValue::Int(1),
            LiteralValue::Int(2),
            LiteralValue::Int(3),
        ]));
        let result = reg.get("sum").unwrap().call(&[arr]);
        assert_eq!(result, Some(LiteralValue::Int(6)));
    }

    #[test]
    fn test_sum_skips_non_integers() {
        let arr = LiteralValue::Array(Arc::new(vec![
            LiteralValue::Int(1),
            LiteralValue::Bool(true),
            LiteralValue::Int(2),
        ]));
        assert_eq!(SumFunction.call(&[arr]), Some(LiteralValue::Int(3)));
    }

    #[test]
    fn test_starts_with_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("starts_with", StartsWithFunction);
        let val = LiteralValue::Bytes(Arc::new(b"foobar".to_vec()));
        let prefix = LiteralValue::Bytes(Arc::new(b"foo".to_vec()));
        let wrong = LiteralValue::Bytes(Arc::new(b"bar".to_vec()));
        assert_eq!(
            reg.get("starts_with")
                .unwrap()
                .call(&[val.clone(), prefix.clone()]),
            Some(LiteralValue::Bool(true))
        );
        assert_eq!(
            reg.get("starts_with")
                .unwrap()
                .call(&[val.clone(), wrong.clone()]),
            Some(LiteralValue::Bool(false))
        );
        assert_eq!(
            reg.get("starts_with")
                .unwrap()
                .call(&[wrong.clone(), prefix.clone()]),
            Some(LiteralValue::Bool(false))
        );
    }

    #[test]
    fn test_ends_with_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("ends_with", EndsWithFunction);
        let val = LiteralValue::Bytes(Arc::new(b"foobar".to_vec()));
        let suffix = LiteralValue::Bytes(Arc::new(b"bar".to_vec()));
        let wrong = LiteralValue::Bytes(Arc::new(b"foo".to_vec()));
        assert_eq!(
            reg.get("ends_with")
                .unwrap()
                .call(&[val.clone(), suffix.clone()]),
            Some(LiteralValue::Bool(true))
        );
        assert_eq!(
            reg.get("ends_with")
                .unwrap()
                .call(&[val.clone(), wrong.clone()]),
            Some(LiteralValue::Bool(false))
        );
        assert_eq!(
            reg.get("ends_with")
                .unwrap()
                .call(&[wrong.clone(), suffix.clone()]),
            Some(LiteralValue::Bool(false))
        );
    }

    #[test]
    fn test_register_closure() {
        let mut reg = FunctionRegistry::new();
        reg.register_fn("always_true", |_args| Some(LiteralValue::Bool(true)));
        let result = reg.get("always_true").unwrap().call(&[]);
        assert_eq!(result, Some(LiteralValue::Bool(true)));
    }

    #[test]
    fn test_reregistration_keeps_id_and_replaces_function() {
        let mut reg = FunctionRegistry::new();
        reg.register_fn("a", |_args| Some(LiteralValue::Int(1)));
        reg.register_fn("b", |_args| Some(LiteralValue::Int(2)));
        reg.register_fn("a", |_args| Some(LiteralValue::Int(3)));
        assert_eq!(reg.num_functions(), 2);
        assert_eq!(reg.function_id("a"), Some(0));
        assert_eq!(reg.function_id("b"), Some(1));
        assert_eq!(reg.function_name(0), Some("a"));
        assert_eq!(reg.function_name(1), Some("b"));
        assert_eq!(reg.get("a").unwrap().call(&[]), Some(LiteralValue::Int(3)));
        assert_eq!(reg.get_by_id(0).unwrap().call(&[]), Some(LiteralValue::Int(3)));
        assert_eq!(reg.get_by_id(1).unwrap().call(&[]), Some(LiteralValue::Int(2)));
    }

    #[test]
    fn test_missing_lookups_return_none() {
        let reg = FunctionRegistry::new();
        assert!(reg.get("nope").is_none());
        assert!(reg.get_by_id(0).is_none());
        assert_eq!(reg.function_id("nope"), None);
        assert_eq!(reg.function_name(0), None);
        assert_eq!(reg.num_functions(), 0);
    }

    #[test]
    fn test_register_builtins_matches_call_builtin() {
        let mut reg = FunctionRegistry::new();
        register_builtins(&mut reg);
        assert_eq!(reg.num_functions(), 6);
        assert_eq!(reg.function_name(0), Some("len"));
        let cases: Vec<(&str, Vec<LiteralValue>)> = vec![
            ("len", vec![ints(&[1, 2, 3])]),
            ("upper", vec![bytes(b"MiXed")]),
            ("lower", vec![bytes(b"MiXed")]),
            ("sum", vec![ints(&[4, 5, 6])]),
            ("starts_with", vec![bytes(b"foobar"), bytes(b"foo")]),
            ("ends_with", vec![bytes(b"foobar"), bytes(b"bar")]),
        ];
        for (name, args) in cases {
            let id = BuiltinFunctionId::from_name(name).expect("every builtin name has an id");
            let via_registry = reg.get(name).expect("builtin is registered").call(&args);
            assert!(via_registry.is_some(), "{} returned None", name);
            assert_eq!(call_builtin(id, &args), via_registry, "{}", name);
        }
    }

    #[test]
    fn test_wrong_argument_types_return_none() {
        let mut reg = FunctionRegistry::new();
        register_builtins(&mut reg);
        assert_eq!(reg.get("len").unwrap().call(&[LiteralValue::Int(1)]), None);
        assert_eq!(reg.get("upper").unwrap().call(&[ints(&[1])]), None);
        assert_eq!(reg.get("lower").unwrap().call(&[]), None);
        assert_eq!(reg.get("sum").unwrap().call(&[]), None);
        assert_eq!(reg.get("starts_with").unwrap().call(&[bytes(b"foo")]), None);
        assert_eq!(
            reg.get("ends_with")
                .unwrap()
                .call(&[bytes(b"foo"), LiteralValue::Int(1)]),
            None
        );
    }
}