xee_interpreter/function/static_function.rs

1use ahash::{HashMap, HashMapExt};
2use std::fmt::{Debug, Formatter};
3use xot::xmlname::NameStrInfo;
4
5use xee_name::{Name, Namespaces};
6use xee_xpath_ast::ast;
7
8use crate::context::DynamicContext;
9use crate::error;
10use crate::function;
11use crate::interpreter;
12use crate::library::static_function_descriptions;
13use crate::sequence;
14use crate::stack;
15
/// How a built-in function receives an implicit argument, as declared by the
/// `KIND` string of an `#[xpath_fn]` function (see `FunctionKind::parse`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum FunctionKind {
    /// Generate a function with one less arity that takes the item as the
    /// first argument.
    ItemFirst,
    /// Generate a function with one less arity that takes the item as the
    /// last argument.
    ItemLast,
    /// Generate just one function, which takes an additional last argument
    /// holding an optional context item.
    ItemLastOptional,
    /// The function takes the position as its implicit only argument.
    Position,
    /// The function takes the size as its implicit only argument.
    Size,
    /// Generate a function with one less arity that takes the collation as
    /// the last argument.
    Collation,
}
35
36impl FunctionKind {
37    pub(crate) fn parse(s: &str) -> Option<FunctionKind> {
38        match s {
39            "" => None,
40            "context_first" => Some(FunctionKind::ItemFirst),
41            "context_last" => Some(FunctionKind::ItemLast),
42            "context_last_optional" => Some(FunctionKind::ItemLastOptional),
43            "position" => Some(FunctionKind::Position),
44            "size" => Some(FunctionKind::Size),
45            "collation" => Some(FunctionKind::Collation),
46            _ => panic!("Unknown function kind {}", s),
47        }
48    }
49}
50
51pub(crate) type StaticFunctionType = fn(
52    context: &DynamicContext,
53    interpreter: &mut interpreter::Interpreter,
54    arguments: &[sequence::Sequence],
55) -> error::Result<sequence::Sequence>;
56
57pub(crate) struct StaticFunctionDescription {
58    pub(crate) name: Name,
59    pub(crate) signature: function::Signature,
60    pub(crate) function_kind: Option<FunctionKind>,
61    pub(crate) func: StaticFunctionType,
62}
63
/// Wraps a Rust function annotated with `#[xpath_fn]` and turns it into a
/// `StaticFunctionDescription`.
///
/// `$function` must name the item generated by `#[xpath_fn]`; the expansion
/// reads its `WRAPPER`, `SIGNATURE` and `KIND` members. The signature is
/// parsed with default namespaces.
#[macro_export]
macro_rules! wrap_xpath_fn {
    ($function:path) => {{
        // A `path` fragment cannot be extended with `::MEMBER`, so alias it.
        use $function as xpath_fn_item;
        let default_namespaces = xee_name::Namespaces::default();
        let function_kind = $crate::function::FunctionKind::parse(xpath_fn_item::KIND);
        $crate::function::StaticFunctionDescription::new(
            xpath_fn_item::WRAPPER,
            xpath_fn_item::SIGNATURE,
            function_kind,
            &default_namespaces,
        )
    }};
}
79
80impl StaticFunctionDescription {
81    pub(crate) fn new(
82        func: StaticFunctionType,
83        signature: &str,
84        function_kind: Option<FunctionKind>,
85        namespaces: &Namespaces,
86    ) -> Self {
87        // TODO reparse signature; the macro could have stored the parsed
88        // version as code, but that's more work than I'm prepared to do
89        // right now.
90        let signature = ast::Signature::parse(signature, namespaces)
91            .expect("Signature parse failed unexpectedly");
92        let name = signature.name.value.clone();
93        let signature: function::Signature = signature.into();
94        Self {
95            name,
96            signature,
97            function_kind,
98            func,
99        }
100    }
101
102    fn functions(&self) -> Vec<StaticFunction> {
103        if let Some(function_kind) = &self.function_kind {
104            self.signature
105                .alternative_signatures(*function_kind)
106                .into_iter()
107                .map(|(signature, function_kind)| {
108                    StaticFunction::new(self.func, self.name.clone(), signature, function_kind)
109                })
110                .collect()
111        } else {
112            vec![StaticFunction::new(
113                self.func,
114                self.name.clone(),
115                self.signature.clone(),
116                None,
117            )]
118        }
119    }
120}
121
/// Rule that `StaticFunction::invoke` applies to supply an implicit argument.
///
/// Derived from `FunctionKind` via `From`; `Position` and `Size` become
/// `PositionFirst` and `SizeFirst`. It is a fieldless enum, so it is `Copy`
/// like `FunctionKind`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum FunctionRule {
    /// Prepend the first closure value (the item).
    ItemFirst,
    /// Append the first closure value (the item).
    ItemLast,
    /// Append the first closure value if present and not absent, otherwise a
    /// default sequence.
    ItemLastOptional,
    /// Prepend the first closure value (the position).
    PositionFirst,
    /// Prepend the first closure value (the size).
    SizeFirst,
    /// Append the static context's default collation URI.
    Collation,
}
131
132impl From<FunctionKind> for FunctionRule {
133    fn from(function_kind: FunctionKind) -> Self {
134        match function_kind {
135            FunctionKind::ItemFirst => FunctionRule::ItemFirst,
136            FunctionKind::ItemLast => FunctionRule::ItemLast,
137            FunctionKind::ItemLastOptional => FunctionRule::ItemLastOptional,
138            FunctionKind::Position => FunctionRule::PositionFirst,
139            FunctionKind::Size => FunctionRule::SizeFirst,
140            FunctionKind::Collation => FunctionRule::Collation,
141        }
142    }
143}
144
145pub struct StaticFunction {
146    name: Name,
147    signature: function::Signature,
148    arity: usize,
149    pub function_rule: Option<FunctionRule>,
150    func: StaticFunctionType,
151}
152
153impl Debug for StaticFunction {
154    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
155        f.debug_struct("StaticFunction")
156            .field("name", &self.name)
157            .field("arity", &self.arity)
158            .field("function_rule", &self.function_rule)
159            .finish()
160    }
161}
162
163impl StaticFunction {
164    pub(crate) fn new(
165        func: StaticFunctionType,
166        name: Name,
167        signature: function::Signature,
168        function_kind: Option<FunctionKind>,
169    ) -> Self {
170        let function_rule = function_kind.map(|k| k.into());
171        let arity = signature.arity();
172        Self {
173            name,
174            signature,
175            arity,
176            function_rule,
177            func,
178        }
179    }
180
181    pub(crate) fn needs_context(&self) -> bool {
182        match self.function_rule {
183            None | Some(FunctionRule::Collation) => false,
184            Some(_) => true,
185        }
186    }
187
188    pub(crate) fn invoke(
189        &self,
190        context: &DynamicContext,
191        interpreter: &mut interpreter::Interpreter,
192        arguments: Vec<sequence::Sequence>,
193        closure_values: &[stack::Value],
194    ) -> error::Result<sequence::Sequence> {
195        if let Some(function_rule) = &self.function_rule {
196            match function_rule {
197                FunctionRule::ItemFirst | FunctionRule::PositionFirst | FunctionRule::SizeFirst => {
198                    let mut new_arguments: Vec<sequence::Sequence> =
199                        vec![closure_values[0].clone().try_into()?];
200                    new_arguments.extend(arguments);
201                    (self.func)(context, interpreter, &new_arguments)
202                }
203                FunctionRule::ItemLast => {
204                    let mut new_arguments = arguments;
205                    new_arguments.push(closure_values[0].clone().try_into()?);
206                    (self.func)(context, interpreter, &new_arguments)
207                }
208                FunctionRule::ItemLastOptional => {
209                    let mut new_arguments = arguments;
210                    let value: sequence::Sequence =
211                        if !closure_values.is_empty() && !closure_values[0].is_absent() {
212                            closure_values[0].clone().try_into()?
213                        } else {
214                            sequence::Sequence::default()
215                        };
216                    new_arguments.push(value);
217                    (self.func)(context, interpreter, &new_arguments)
218                }
219                FunctionRule::Collation => {
220                    let mut new_arguments = arguments;
221                    // the default collation query
222                    new_arguments.push(context.static_context().default_collation_uri().into());
223                    (self.func)(context, interpreter, &new_arguments)
224                }
225            }
226        } else {
227            (self.func)(context, interpreter, &arguments)
228        }
229    }
230
231    pub(crate) fn name(&self) -> &Name {
232        &self.name
233    }
234
235    pub(crate) fn arity(&self) -> usize {
236        self.arity
237    }
238
239    pub(crate) fn signature(&self) -> &function::Signature {
240        &self.signature
241    }
242
243    pub fn display_representation(&self) -> String {
244        let name = self.name.full_name();
245        let signature = self.signature.display_representation();
246        format!("{}{}", name, signature)
247    }
248}
249
250fn into_sequences(values: &[stack::Value]) -> error::Result<Vec<sequence::Sequence>> {
251    values
252        .iter()
253        .map(|v| match v {
254            stack::Value::Sequence(sequence) => Ok(sequence.clone()),
255            stack::Value::Absent => Err(error::Error::XPDY0002),
256        })
257        .collect()
258}
259
260#[derive(Debug)]
261pub struct StaticFunctions {
262    by_name: HashMap<(Name, u8), function::StaticFunctionId>,
263    by_index: Vec<StaticFunction>,
264}
265
266impl StaticFunctions {
267    pub(crate) fn new() -> Self {
268        let mut by_name = HashMap::new();
269        let descriptions = static_function_descriptions();
270        let mut by_index = Vec::new();
271        for description in descriptions {
272            by_index.extend(description.functions());
273        }
274
275        for (i, static_function) in by_index.iter().enumerate() {
276            by_name.insert(
277                (static_function.name.clone(), static_function.arity as u8),
278                function::StaticFunctionId(i),
279            );
280        }
281        Self { by_name, by_index }
282    }
283
284    pub fn get_by_name(&self, name: &Name, arity: u8) -> Option<function::StaticFunctionId> {
285        // TODO annoying clone
286        self.by_name.get(&(name.clone(), arity)).copied()
287    }
288
289    pub fn get_by_index(&self, static_function_id: function::StaticFunctionId) -> &StaticFunction {
290        &self.by_index[static_function_id.0]
291    }
292}