surrealdb_core/sql/function.rs

1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Serde type name for [`Function`].
///
/// Holds the same string as the `#[serde(rename = ...)]` attribute on
/// [`Function`]. Serde attributes require a string literal, so the value is
/// repeated there and the two must be kept in sync by hand.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";

/// A function call expression: a callee together with the argument
/// expressions that are evaluated when the call is computed.
///
/// NOTE(review): the `// Add new variants here` marker suggests variants are
/// append-only; the `revisioned` encoding presumably depends on variant order.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Function {
	/// Call to a built-in function (e.g. `math::max`): `(name, arguments)`.
	Normal(String, Vec<Value>),
	/// Call to a user-defined function: `(name, arguments)`. The name is
	/// stored without the `fn::` prefix that is added when displayed.
	Custom(String, Vec<Value>),
	/// Embedded script function, displayed as `function(args) {script}`.
	/// Executing it requires the `scripting` feature.
	Script(Script, Vec<Value>),
	/// Call whose callee is itself an expression (e.g. a closure or a
	/// parameter) that must evaluate to a closure at runtime.
	Anonymous(Value, Vec<Value>),
	// Add new variants here
}
35
/// Classification of a call for the optimised aggregate paths.
///
/// Produced by `Function::get_optimised_aggregate`. `None` means the call has
/// no dedicated optimised handling.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum OptimisedAggregate {
	/// Not one of the optimised aggregate functions.
	None,
	/// `count()` called with no arguments.
	Count,
	/// `count(...)` called with one or more arguments.
	CountFunction,
	/// `math::max`.
	MathMax,
	/// `math::min`.
	MathMin,
	/// `math::sum`.
	MathSum,
	/// `math::mean`.
	MathMean,
	/// `time::max`.
	TimeMax,
	/// `time::min`.
	TimeMin,
}
47
48impl PartialOrd for Function {
49	#[inline]
50	fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
51		None
52	}
53}
54
55impl Function {
56	/// Get function name if applicable
57	pub fn name(&self) -> Option<&str> {
58		match self {
59			Self::Normal(n, _) => Some(n.as_str()),
60			Self::Custom(n, _) => Some(n.as_str()),
61			_ => None,
62		}
63	}
64	/// Get function arguments if applicable
65	pub fn args(&self) -> &[Value] {
66		match self {
67			Self::Normal(_, a) => a,
68			Self::Custom(_, a) => a,
69			_ => &[],
70		}
71	}
72	/// Convert function call to a field name
73	pub fn to_idiom(&self) -> Idiom {
74		match self {
75			Self::Anonymous(_, _) => "function".to_string().into(),
76			Self::Script(_, _) => "function".to_string().into(),
77			Self::Normal(f, _) => f.to_owned().into(),
78			Self::Custom(f, _) => format!("fn::{f}").into(),
79		}
80	}
81	/// Convert this function to an aggregate
82	pub fn aggregate(&self, val: Value) -> Self {
83		match self {
84			Self::Normal(n, a) => {
85				let mut a = a.to_owned();
86				match a.len() {
87					0 => a.insert(0, val),
88					_ => {
89						a.remove(0);
90						a.insert(0, val);
91					}
92				}
93				Self::Normal(n.to_owned(), a)
94			}
95			_ => unreachable!(),
96		}
97	}
98	/// Check if this function is a custom function
99	pub fn is_custom(&self) -> bool {
100		matches!(self, Self::Custom(_, _))
101	}
102
103	/// Check if this function is a scripting function
104	pub fn is_script(&self) -> bool {
105		matches!(self, Self::Script(_, _))
106	}
107
108	/// Check if this function has static arguments
109	pub fn is_static(&self) -> bool {
110		match self {
111			Self::Normal(_, a) => a.iter().all(Value::is_static),
112			_ => false,
113		}
114	}
115
116	/// Check if this function is a closure function
117	pub fn is_inline(&self) -> bool {
118		matches!(self, Self::Anonymous(_, _))
119	}
120
121	/// Check if this function is a rolling function
122	pub fn is_rolling(&self) -> bool {
123		match self {
124			Self::Normal(f, _) if f == "count" => true,
125			Self::Normal(f, _) if f == "math::max" => true,
126			Self::Normal(f, _) if f == "math::mean" => true,
127			Self::Normal(f, _) if f == "math::min" => true,
128			Self::Normal(f, _) if f == "math::sum" => true,
129			Self::Normal(f, _) if f == "time::max" => true,
130			Self::Normal(f, _) if f == "time::min" => true,
131			_ => false,
132		}
133	}
134	/// Check if this function is a grouping function
135	pub fn is_aggregate(&self) -> bool {
136		match self {
137			Self::Normal(f, _) if f == "array::distinct" => true,
138			Self::Normal(f, _) if f == "array::first" => true,
139			Self::Normal(f, _) if f == "array::flatten" => true,
140			Self::Normal(f, _) if f == "array::group" => true,
141			Self::Normal(f, _) if f == "array::last" => true,
142			Self::Normal(f, _) if f == "count" => true,
143			Self::Normal(f, _) if f == "math::bottom" => true,
144			Self::Normal(f, _) if f == "math::interquartile" => true,
145			Self::Normal(f, _) if f == "math::max" => true,
146			Self::Normal(f, _) if f == "math::mean" => true,
147			Self::Normal(f, _) if f == "math::median" => true,
148			Self::Normal(f, _) if f == "math::midhinge" => true,
149			Self::Normal(f, _) if f == "math::min" => true,
150			Self::Normal(f, _) if f == "math::mode" => true,
151			Self::Normal(f, _) if f == "math::nearestrank" => true,
152			Self::Normal(f, _) if f == "math::percentile" => true,
153			Self::Normal(f, _) if f == "math::sample" => true,
154			Self::Normal(f, _) if f == "math::spread" => true,
155			Self::Normal(f, _) if f == "math::stddev" => true,
156			Self::Normal(f, _) if f == "math::sum" => true,
157			Self::Normal(f, _) if f == "math::top" => true,
158			Self::Normal(f, _) if f == "math::trimean" => true,
159			Self::Normal(f, _) if f == "math::variance" => true,
160			Self::Normal(f, _) if f == "time::max" => true,
161			Self::Normal(f, _) if f == "time::min" => true,
162			_ => false,
163		}
164	}
165	pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
166		match self {
167			Self::Normal(f, v) if f == "count" => {
168				if v.is_empty() {
169					OptimisedAggregate::Count
170				} else {
171					OptimisedAggregate::CountFunction
172				}
173			}
174			Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
175			Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
176			Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
177			Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
178			Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
179			Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
180			_ => OptimisedAggregate::None,
181		}
182	}
183}
184
impl Function {
	/// Process this type returning a computed simple Value
	///
	/// Evaluates the call and returns its result. `opt` is replaced by a copy
	/// with futures enabled for the whole evaluation. Argument expressions are
	/// computed through `try_join_all`, so the first argument that fails aborts
	/// the call with that error.
	///
	/// Was marked recursive
	///
	/// # Errors
	///
	/// - Errors from `ctx.check_allowed_function` / `ctx.check_allowed_scripting`
	///   when the function or scripting is not permitted.
	/// - `Error::InvalidFunction` when an anonymous callee does not evaluate to
	///   a closure.
	/// - `Error::FunctionPermissions` when a custom function's permissions deny
	///   access.
	/// - `Error::InvalidArguments` when a custom function receives too few or
	///   too many arguments.
	/// - `Error::InvalidScript` when the `scripting` feature is compiled out.
	/// - Any error raised while evaluating arguments, coercing values, or
	///   running the function body.
	pub(crate) async fn compute(
		&self,
		stk: &mut Stk,
		ctx: &Context<'_>,
		opt: &Options,
		doc: Option<&CursorDoc<'_>>,
	) -> Result<Value, Error> {
		// Ensure futures are run
		let opt = &opt.new_with_futures(true);
		// Process the function type
		match self {
			Self::Normal(s, x) => {
				// Check this function is allowed
				ctx.check_allowed_function(s)?;
				// Compute the function arguments
				let a = stk
					.scope(|scope| {
						try_join_all(
							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
						)
					})
					.await?;
				// Run the normal function
				fnc::run(stk, ctx, opt, doc, s, a).await
			}
			Self::Anonymous(v, x) => {
				// Resolve the callee to a value; only a closure can be invoked.
				let val = match v {
					// A literal closure is used as-is
					c @ Value::Closure(_) => c.clone(),
					// A parameter is looked up in the context; if unset it becomes NONE
					Value::Param(p) => ctx.value(p).cloned().unwrap_or(Value::None),
					// Expressions that may yield a closure are evaluated first
					Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
						stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
					}
					// Any other value cannot produce a closure
					_ => Value::None,
				};

				match val {
					Value::Closure(closure) => {
						// Compute the function arguments (only once the callee is
						// known to be a closure)
						let a = stk
							.scope(|scope| {
								try_join_all(
									x.iter()
										.map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
								)
							})
							.await?;
						stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
					}
					v => Err(Error::InvalidFunction {
						name: "ANONYMOUS".to_string(),
						message: format!("'{}' is not a function", v.kindof()),
					}),
				}
			}
			Self::Custom(s, x) => {
				// Get the full name of this function
				let name = format!("fn::{s}");
				// Check this function is allowed
				ctx.check_allowed_function(name.as_str())?;
				// Get the function definition
				let val = ctx.tx().get_db_function(opt.ns()?, opt.db()?, s).await?;
				// Check permissions. This happens before any argument is evaluated;
				// the permission clause only applies when `opt` asks for checks.
				if opt.check_perms(Action::View)? {
					match &val.permissions {
						Permission::Full => (),
						Permission::None => {
							return Err(Error::FunctionPermissions {
								name: s.to_owned(),
							})
						}
						Permission::Specific(e) => {
							// Disable permissions
							let opt = &opt.new_with_perms(false);
							// Process the PERMISSION clause
							if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
								return Err(Error::FunctionPermissions {
									name: s.to_owned(),
								});
							}
						}
					}
				}
				// Get the number of function arguments
				let max_args_len = val.args.len();
				// Track the number of required arguments
				let mut min_args_len = 0;
				// Check for any final optional arguments: walking the parameters
				// from the end, `Kind::Option` parameters are skipped only until the
				// first non-optional one is seen. So every parameter up to and
				// including the last non-optional one is required.
				val.args.iter().rev().for_each(|(_, kind)| match kind {
					Kind::Option(_) if min_args_len == 0 => {}
					_ => min_args_len += 1,
				});
				// Check the necessary arguments are passed
				if x.len() < min_args_len || max_args_len < x.len() {
					return Err(Error::InvalidArguments {
						name: format!("fn::{}", val.name),
						message: match (min_args_len, max_args_len) {
							(1, 1) => String::from("The function expects 1 argument."),
							(r, t) if r == t => format!("The function expects {r} arguments."),
							(r, t) => format!("The function expects {r} to {t} arguments."),
						},
					});
				}
				// Compute the function arguments
				let a = stk
					.scope(|scope| {
						try_join_all(
							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
						)
					})
					.await?;
				// Duplicate context: the body runs in a new isolated context, so the
				// argument bindings below are added to that context, not the caller's.
				// NOTE(review): `new_isolated` presumably also hides the caller's
				// own parameters from the body — confirm in `Context`.
				let mut ctx = Context::new_isolated(ctx);
				// Process the function arguments: bind each passed argument to its
				// parameter name, coerced to the declared kind. Omitted trailing
				// optional parameters are not bound here (zip stops at the shorter).
				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
					ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
				}
				// Run the custom function. `Error::Return` carries the function's
				// return value (presumably raised by a RETURN statement) and is
				// turned into a normal result; any other error propagates.
				let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
					Err(Error::Return {
						value,
					}) => Ok(value),
					res => res,
				}?;

				// Coerce the result to the declared return type, if any
				if let Some(ref returns) = val.returns {
					result
						.coerce_to(returns)
						.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
				} else {
					Ok(result)
				}
			}
			// `s` and `x` are unused when the `scripting` feature is disabled.
			#[allow(unused_variables)]
			Self::Script(s, x) => {
				#[cfg(feature = "scripting")]
				{
					// Check if scripting is allowed
					ctx.check_allowed_scripting()?;
					// Compute the function arguments
					let a = stk
						.scope(|scope| {
							try_join_all(
								x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
							)
						})
						.await?;
					// Run the script function
					fnc::script::run(ctx, opt, doc, s, a).await
				}
				#[cfg(not(feature = "scripting"))]
				{
					Err(Error::InvalidScript {
						message: String::from("Embedded functions are not enabled."),
					})
				}
			}
		}
	}
}
348
349impl fmt::Display for Function {
350	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
351		match self {
352			Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
353			Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
354			Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
355			Self::Anonymous(p, e) => write!(f, "{p}({})", Fmt::comma_separated(e)),
356		}
357	}
358}