1use crate::types::LiteralValue;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9pub trait FilterFunction: Send + Sync {
10 fn call(&self, args: &[LiteralValue]) -> Option<LiteralValue>;
11}
12
13#[derive(Default)]
14pub struct FunctionRegistry {
15 functions: HashMap<String, Arc<dyn FilterFunction>>,
16 function_names: Vec<String>, function_ids: HashMap<String, usize>, }
19
20impl FunctionRegistry {
21 pub fn new() -> Self {
22 Self::default()
23 }
24 pub fn register<F>(&mut self, name: impl Into<String>, func: F)
26 where
27 F: FilterFunction + 'static,
28 {
29 let name = name.into();
30 if !self.function_ids.contains_key(&name) {
31 self.function_ids
32 .insert(name.clone(), self.function_names.len());
33 self.function_names.push(name.clone());
34 }
35 self.functions.insert(name, Arc::new(func));
36 }
37 pub fn register_fn<F>(&mut self, name: impl Into<String>, func: F)
39 where
40 F: Fn(&[LiteralValue]) -> Option<LiteralValue> + Send + Sync + 'static,
41 {
42 struct ClosureFn<F>(F);
43 impl<F> FilterFunction for ClosureFn<F>
44 where
45 F: Fn(&[LiteralValue]) -> Option<LiteralValue> + Send + Sync + 'static,
46 {
47 fn call(&self, args: &[LiteralValue]) -> Option<LiteralValue> {
48 (self.0)(args)
49 }
50 }
51 self.register(name, ClosureFn(func));
52 }
53 pub fn function_id(&self, name: &str) -> Option<usize> {
55 self.function_ids.get(name).copied()
56 }
57 pub fn function_name(&self, id: usize) -> Option<&str> {
59 self.function_names.get(id).map(|s| s.as_str())
60 }
61 pub fn num_functions(&self) -> usize {
63 self.function_names.len()
64 }
65 pub fn get(&self, name: &str) -> Option<&Arc<dyn FilterFunction>> {
66 self.functions.get(name)
67 }
68 pub fn get_by_id(&self, id: usize) -> Option<&Arc<dyn FilterFunction>> {
70 self.function_names
71 .get(id)
72 .and_then(|name| self.functions.get(name))
73 }
74}
75
76impl Clone for FunctionRegistry {
77 fn clone(&self) -> Self {
78 Self {
79 functions: self.functions.clone(),
80 function_names: self.function_names.clone(),
81 function_ids: self.function_ids.clone(),
82 }
83 }
84}
85
/// Declares one unit struct per builtin and implements `FilterFunction` for it using
/// the supplied body. Also generates `register_builtins`, which registers every
/// declared struct under its string name, in declaration order.
macro_rules! builtin_functions {
    ($( $ty:ident: $fname:expr, $params:ident => $code:block ),* $(,)?) => {
        $(
            /// Builtin filter function generated by `builtin_functions!`.
            pub struct $ty;

            impl FilterFunction for $ty {
                fn call(&self, $params: &[LiteralValue]) -> Option<LiteralValue> $code
            }
        )*

        /// Registers every builtin declared in the `builtin_functions!` invocation.
        pub fn register_builtins(reg: &mut FunctionRegistry) {
            $( reg.register($fname, $ty); )*
        }
    };
}
99
100builtin_functions! {
101 LenFunction: "len", args => {
102 if let Some(LiteralValue::Array(arr)) = args.first() {
103 Some(LiteralValue::Int(arr.len() as i64))
104 } else {
105 None
106 }
107 },
108 UpperFunction: "upper", args => {
109 if let Some(LiteralValue::Bytes(bytes)) = args.first() {
110 let s = String::from_utf8_lossy(bytes).to_uppercase();
111 Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
112 } else {
113 None
114 }
115 },
116 LowerFunction: "lower", args => {
117 if let Some(LiteralValue::Bytes(bytes)) = args.first() {
118 let s = String::from_utf8_lossy(bytes).to_lowercase();
119 Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
120 } else {
121 None
122 }
123 },
124 SumFunction: "sum", args => {
125 if let Some(LiteralValue::Array(arr)) = args.first() {
126 let sum: i64 = arr.iter().filter_map(|v| if let LiteralValue::Int(i) = v { Some(*i) } else { None }).sum();
127 Some(LiteralValue::Int(sum))
128 } else {
129 None
130 }
131 },
132 StartsWithFunction: "starts_with", args => {
133 if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(prefix))) = (args.first(), args.get(1)) {
134 let h = String::from_utf8_lossy(haystack);
135 let p = String::from_utf8_lossy(prefix);
136 Some(LiteralValue::Bool(h.starts_with(&*p)))
137 } else {
138 None
139 }
140 },
141 EndsWithFunction: "ends_with", args => {
142 if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(suffix))) = (args.first(), args.get(1)) {
143 let h = String::from_utf8_lossy(haystack);
144 let s = String::from_utf8_lossy(suffix);
145 Some(LiteralValue::Bool(h.ends_with(&*s)))
146 } else {
147 None
148 }
149 },
150}
151
/// Identifies one of the builtin filter functions, for dispatch through `call_builtin`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BuiltinFunctionId {
    /// `"len"`
    Len,
    /// `"upper"`
    Upper,
    /// `"lower"`
    Lower,
    /// `"sum"`
    Sum,
    /// `"starts_with"`
    StartsWith,
    /// `"ends_with"`
    EndsWith,
}

impl BuiltinFunctionId {
    /// Maps a builtin's name (e.g. `"starts_with"`) to its id.
    ///
    /// Matching is exact and case-sensitive; any other name yields `None`.
    pub fn from_name(name: &str) -> Option<Self> {
        let id = match name {
            "len" => Self::Len,
            "upper" => Self::Upper,
            "lower" => Self::Lower,
            "sum" => Self::Sum,
            "starts_with" => Self::StartsWith,
            "ends_with" => Self::EndsWith,
            _ => return None,
        };
        Some(id)
    }
}
175
176pub fn call_builtin(id: BuiltinFunctionId, args: &[LiteralValue]) -> Option<LiteralValue> {
177 match id {
178 BuiltinFunctionId::Len => {
179 if let Some(LiteralValue::Array(arr)) = args.first() {
180 Some(LiteralValue::Int(arr.len() as i64))
181 } else {
182 None
183 }
184 }
185 BuiltinFunctionId::Upper => {
186 if let Some(LiteralValue::Bytes(bytes)) = args.first() {
187 let s = String::from_utf8_lossy(bytes).to_uppercase();
188 Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
189 } else {
190 None
191 }
192 }
193 BuiltinFunctionId::Lower => {
194 if let Some(LiteralValue::Bytes(bytes)) = args.first() {
195 let s = String::from_utf8_lossy(bytes).to_lowercase();
196 Some(LiteralValue::Bytes(Arc::new(s.into_bytes())))
197 } else {
198 None
199 }
200 }
201 BuiltinFunctionId::Sum => {
202 if let Some(LiteralValue::Array(arr)) = args.first() {
203 let sum: i64 = arr
204 .iter()
205 .filter_map(|v| {
206 if let LiteralValue::Int(i) = v {
207 Some(*i)
208 } else {
209 None
210 }
211 })
212 .sum();
213 Some(LiteralValue::Int(sum))
214 } else {
215 None
216 }
217 }
218 BuiltinFunctionId::StartsWith => {
219 if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(prefix))) =
220 (args.first(), args.get(1))
221 {
222 let h = String::from_utf8_lossy(haystack);
223 let p = String::from_utf8_lossy(prefix);
224 Some(LiteralValue::Bool(h.starts_with(&*p)))
225 } else {
226 None
227 }
228 }
229 BuiltinFunctionId::EndsWith => {
230 if let (Some(LiteralValue::Bytes(haystack)), Some(LiteralValue::Bytes(suffix))) =
231 (args.first(), args.get(1))
232 {
233 let h = String::from_utf8_lossy(haystack);
234 let s = String::from_utf8_lossy(suffix);
235 Some(LiteralValue::Bool(h.ends_with(&*s)))
236 } else {
237 None
238 }
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Bytes` literal from a byte string.
    fn bytes(raw: &[u8]) -> LiteralValue {
        LiteralValue::Bytes(Arc::new(raw.to_vec()))
    }

    /// Builds an `Array` literal whose elements are all `Int`s.
    fn ints(values: &[i64]) -> LiteralValue {
        LiteralValue::Array(Arc::new(
            values
                .iter()
                .map(|&v| LiteralValue::Int(v))
                .collect::<Vec<_>>(),
        ))
    }

    #[test]
    fn test_register_and_call_len() {
        let mut reg = FunctionRegistry::new();
        reg.register("len", LenFunction);
        let result = reg.get("len").unwrap().call(&[ints(&[1, 2])]);
        assert_eq!(result, Some(LiteralValue::Int(2)));
    }

    #[test]
    fn test_upper_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("upper", UpperFunction);
        let result = reg.get("upper").unwrap().call(&[bytes(b"hello")]);
        assert_eq!(result, Some(bytes(b"HELLO")));
    }

    #[test]
    fn test_lower_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("lower", LowerFunction);
        let result = reg.get("lower").unwrap().call(&[bytes(b"HeLLo")]);
        assert_eq!(result, Some(bytes(b"hello")));
    }

    #[test]
    fn test_sum_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("sum", SumFunction);
        let result = reg.get("sum").unwrap().call(&[ints(&[1, 2, 3])]);
        assert_eq!(result, Some(LiteralValue::Int(6)));
    }

    #[test]
    fn test_sum_skips_non_int_elements() {
        let mixed = LiteralValue::Array(Arc::new(vec![
            LiteralValue::Int(1),
            bytes(b"x"),
            LiteralValue::Int(2),
        ]));
        assert_eq!(
            call_builtin(BuiltinFunctionId::Sum, &[mixed]),
            Some(LiteralValue::Int(3))
        );
        assert_eq!(
            call_builtin(BuiltinFunctionId::Sum, &[ints(&[])]),
            Some(LiteralValue::Int(0))
        );
    }

    #[test]
    fn test_starts_with_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("starts_with", StartsWithFunction);
        let starts_with = reg.get("starts_with").unwrap();
        let (val, prefix, wrong) = (bytes(b"foobar"), bytes(b"foo"), bytes(b"bar"));
        assert_eq!(
            starts_with.call(&[val.clone(), prefix.clone()]),
            Some(LiteralValue::Bool(true))
        );
        assert_eq!(
            starts_with.call(&[val, wrong.clone()]),
            Some(LiteralValue::Bool(false))
        );
        assert_eq!(
            starts_with.call(&[wrong, prefix]),
            Some(LiteralValue::Bool(false))
        );
    }

    #[test]
    fn test_ends_with_function() {
        let mut reg = FunctionRegistry::new();
        reg.register("ends_with", EndsWithFunction);
        let ends_with = reg.get("ends_with").unwrap();
        let (val, suffix, wrong) = (bytes(b"foobar"), bytes(b"bar"), bytes(b"foo"));
        assert_eq!(
            ends_with.call(&[val.clone(), suffix.clone()]),
            Some(LiteralValue::Bool(true))
        );
        assert_eq!(
            ends_with.call(&[val, wrong.clone()]),
            Some(LiteralValue::Bool(false))
        );
        assert_eq!(
            ends_with.call(&[wrong, suffix]),
            Some(LiteralValue::Bool(false))
        );
    }

    #[test]
    fn test_wrong_argument_shapes_return_none() {
        let all = [
            BuiltinFunctionId::Len,
            BuiltinFunctionId::Upper,
            BuiltinFunctionId::Lower,
            BuiltinFunctionId::Sum,
            BuiltinFunctionId::StartsWith,
            BuiltinFunctionId::EndsWith,
        ];
        for id in all {
            assert_eq!(call_builtin(id, &[]), None, "{:?} accepted no arguments", id);
        }
        assert_eq!(call_builtin(BuiltinFunctionId::Len, &[bytes(b"abc")]), None);
        assert_eq!(call_builtin(BuiltinFunctionId::Upper, &[ints(&[1])]), None);
        assert_eq!(call_builtin(BuiltinFunctionId::StartsWith, &[bytes(b"abc")]), None);
        assert_eq!(
            call_builtin(BuiltinFunctionId::EndsWith, &[ints(&[1]), bytes(b"a")]),
            None
        );
    }

    #[test]
    fn test_register_closure() {
        let mut reg = FunctionRegistry::new();
        reg.register_fn("always_true", |_args| Some(LiteralValue::Bool(true)));
        let result = reg.get("always_true").unwrap().call(&[]);
        assert_eq!(result, Some(LiteralValue::Bool(true)));
    }

    #[test]
    fn test_function_ids_are_stable_across_reregistration() {
        let mut reg = FunctionRegistry::new();
        reg.register_fn("f", |_| Some(LiteralValue::Int(1)));
        reg.register_fn("g", |_| Some(LiteralValue::Int(2)));
        // Re-registering replaces the implementation but keeps the original id.
        reg.register_fn("f", |_| Some(LiteralValue::Int(3)));

        assert_eq!(reg.num_functions(), 2);
        assert_eq!(reg.function_id("f"), Some(0));
        assert_eq!(reg.function_id("g"), Some(1));
        assert_eq!(reg.function_id("h"), None);
        assert_eq!(reg.function_name(0), Some("f"));
        assert_eq!(reg.function_name(1), Some("g"));
        assert_eq!(reg.function_name(2), None);
        assert_eq!(
            reg.get_by_id(0).unwrap().call(&[]),
            Some(LiteralValue::Int(3))
        );
        assert_eq!(
            reg.get_by_id(1).unwrap().call(&[]),
            Some(LiteralValue::Int(2))
        );
        assert!(reg.get_by_id(2).is_none());
    }

    #[test]
    fn test_register_builtins_matches_call_builtin() {
        let mut reg = FunctionRegistry::new();
        register_builtins(&mut reg);

        let names = ["len", "upper", "lower", "sum", "starts_with", "ends_with"];
        assert_eq!(reg.num_functions(), names.len());

        let samples: Vec<Vec<LiteralValue>> = vec![
            vec![ints(&[1, 2, 3])],
            vec![bytes(b"MiXeD")],
            vec![bytes(b"foobar"), bytes(b"foo")],
            vec![bytes(b"foobar"), bytes(b"bar")],
            vec![],
        ];
        for (expected_id, name) in names.iter().enumerate() {
            assert_eq!(reg.function_id(name), Some(expected_id));
            let id = BuiltinFunctionId::from_name(name)
                .expect("every registered builtin has a BuiltinFunctionId");
            let func = reg.get(name).unwrap();
            for args in &samples {
                assert_eq!(
                    func.call(args),
                    call_builtin(id, args),
                    "{} disagrees with call_builtin on {:?}",
                    name,
                    args
                );
            }
        }
    }
}