// surrealdb_core/sql/function.rs

1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Serde type name for [`Function`].
///
/// This string is identical to the `#[serde(rename = ...)]` value on [`Function`];
/// the two must be kept in sync. NOTE(review): presumably used by custom
/// (de)serializers to recognise `Function` values — confirm at the use sites.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";

/// A function call expression.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Function {
	/// A built-in function call: the full name (e.g. `math::max`) and its argument expressions.
	Normal(String, Vec<Value>),
	/// A user-defined function call: the name is stored without the `fn::` prefix,
	/// which is re-added when displaying or converting to an idiom.
	Custom(String, Vec<Value>),
	/// An embedded script function: the script source and its argument expressions.
	Script(Script, Vec<Value>),
	// Add new variants here
}
34
/// Aggregate functions that have a specialised evaluation path.
///
/// Produced by `Function::get_optimised_aggregate`; [`OptimisedAggregate::None`]
/// means the function has no specialised form and must be evaluated normally.
///
/// The derives are cheap (all variants are unit variants). They allow the value to be
/// logged (`Debug`), copied freely (`Copy`) and compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum OptimisedAggregate {
	/// Not a specialised aggregate.
	None,
	/// `count()` called with no arguments.
	Count,
	/// `count(...)` called with one or more arguments.
	CountFunction,
	/// `math::max`
	MathMax,
	/// `math::min`
	MathMin,
	/// `math::sum`
	MathSum,
	/// `math::mean`
	MathMean,
	/// `time::max`
	TimeMax,
	/// `time::min`
	TimeMin,
}
46
47impl PartialOrd for Function {
48	#[inline]
49	fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
50		None
51	}
52}
53
54impl Function {
55	/// Get function name if applicable
56	pub fn name(&self) -> Option<&str> {
57		match self {
58			Self::Normal(n, _) => Some(n.as_str()),
59			Self::Custom(n, _) => Some(n.as_str()),
60			_ => None,
61		}
62	}
63	/// Get function arguments if applicable
64	pub fn args(&self) -> &[Value] {
65		match self {
66			Self::Normal(_, a) => a,
67			Self::Custom(_, a) => a,
68			_ => &[],
69		}
70	}
71	/// Convert function call to a field name
72	pub fn to_idiom(&self) -> Idiom {
73		match self {
74			Self::Script(_, _) => "function".to_string().into(),
75			Self::Normal(f, _) => f.to_owned().into(),
76			Self::Custom(f, _) => format!("fn::{f}").into(),
77		}
78	}
79	/// Convert this function to an aggregate
80	pub fn aggregate(&self, val: Value) -> Self {
81		match self {
82			Self::Normal(n, a) => {
83				let mut a = a.to_owned();
84				match a.len() {
85					0 => a.insert(0, val),
86					_ => {
87						a.remove(0);
88						a.insert(0, val);
89					}
90				}
91				Self::Normal(n.to_owned(), a)
92			}
93			_ => unreachable!(),
94		}
95	}
96	/// Check if this function is a custom function
97	pub fn is_custom(&self) -> bool {
98		matches!(self, Self::Custom(_, _))
99	}
100
101	/// Check if this function is a scripting function
102	pub fn is_script(&self) -> bool {
103		matches!(self, Self::Script(_, _))
104	}
105
106	/// Check if this function has static arguments
107	pub fn is_static(&self) -> bool {
108		match self {
109			Self::Normal(_, a) => a.iter().all(Value::is_static),
110			_ => false,
111		}
112	}
113
114	/// Check if this function is a rolling function
115	pub fn is_rolling(&self) -> bool {
116		match self {
117			Self::Normal(f, _) if f == "count" => true,
118			Self::Normal(f, _) if f == "math::max" => true,
119			Self::Normal(f, _) if f == "math::mean" => true,
120			Self::Normal(f, _) if f == "math::min" => true,
121			Self::Normal(f, _) if f == "math::sum" => true,
122			Self::Normal(f, _) if f == "time::max" => true,
123			Self::Normal(f, _) if f == "time::min" => true,
124			_ => false,
125		}
126	}
127	/// Check if this function is a grouping function
128	pub fn is_aggregate(&self) -> bool {
129		match self {
130			Self::Normal(f, _) if f == "array::distinct" => true,
131			Self::Normal(f, _) if f == "array::first" => true,
132			Self::Normal(f, _) if f == "array::flatten" => true,
133			Self::Normal(f, _) if f == "array::group" => true,
134			Self::Normal(f, _) if f == "array::last" => true,
135			Self::Normal(f, _) if f == "count" => true,
136			Self::Normal(f, _) if f == "math::bottom" => true,
137			Self::Normal(f, _) if f == "math::interquartile" => true,
138			Self::Normal(f, _) if f == "math::max" => true,
139			Self::Normal(f, _) if f == "math::mean" => true,
140			Self::Normal(f, _) if f == "math::median" => true,
141			Self::Normal(f, _) if f == "math::midhinge" => true,
142			Self::Normal(f, _) if f == "math::min" => true,
143			Self::Normal(f, _) if f == "math::mode" => true,
144			Self::Normal(f, _) if f == "math::nearestrank" => true,
145			Self::Normal(f, _) if f == "math::percentile" => true,
146			Self::Normal(f, _) if f == "math::sample" => true,
147			Self::Normal(f, _) if f == "math::spread" => true,
148			Self::Normal(f, _) if f == "math::stddev" => true,
149			Self::Normal(f, _) if f == "math::sum" => true,
150			Self::Normal(f, _) if f == "math::top" => true,
151			Self::Normal(f, _) if f == "math::trimean" => true,
152			Self::Normal(f, _) if f == "math::variance" => true,
153			Self::Normal(f, _) if f == "time::max" => true,
154			Self::Normal(f, _) if f == "time::min" => true,
155			_ => false,
156		}
157	}
158	pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
159		match self {
160			Self::Normal(f, v) if f == "count" => {
161				if v.is_empty() {
162					OptimisedAggregate::Count
163				} else {
164					OptimisedAggregate::CountFunction
165				}
166			}
167			Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
168			Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
169			Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
170			Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
171			Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
172			Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
173			_ => OptimisedAggregate::None,
174		}
175	}
176}
177
178impl Function {
179	/// Process this type returning a computed simple Value
180	///
181	/// Was marked recursive
182	pub(crate) async fn compute(
183		&self,
184		stk: &mut Stk,
185		ctx: &Context<'_>,
186		opt: &Options,
187		doc: Option<&CursorDoc<'_>>,
188	) -> Result<Value, Error> {
189		// Ensure futures are run
190		let opt = &opt.new_with_futures(true);
191		// Process the function type
192		match self {
193			Self::Normal(s, x) => {
194				// Check this function is allowed
195				ctx.check_allowed_function(s)?;
196				// Compute the function arguments
197				let a = stk
198					.scope(|scope| {
199						try_join_all(
200							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
201						)
202					})
203					.await?;
204				// Run the normal function
205				fnc::run(stk, ctx, opt, doc, s, a).await
206			}
207			Self::Custom(s, x) => {
208				// Get the full name of this function
209				let name = format!("fn::{s}");
210				// Check this function is allowed
211				ctx.check_allowed_function(name.as_str())?;
212				// Get the function definition
213				let val = {
214					// Claim transaction
215					let mut run = ctx.tx_lock().await;
216					// Get the function definition
217					let val = run.get_and_cache_db_function(opt.ns()?, opt.db()?, s).await?;
218					drop(run);
219					val
220				};
221				// Check permissions
222				if opt.check_perms(Action::View)? {
223					match &val.permissions {
224						Permission::Full => (),
225						Permission::None => {
226							return Err(Error::FunctionPermissions {
227								name: s.to_owned(),
228							})
229						}
230						Permission::Specific(e) => {
231							// Disable permissions
232							let opt = &opt.new_with_perms(false);
233							// Process the PERMISSION clause
234							if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
235								return Err(Error::FunctionPermissions {
236									name: s.to_owned(),
237								});
238							}
239						}
240					}
241				}
242				// Get the number of function arguments
243				let max_args_len = val.args.len();
244				// Track the number of required arguments
245				let mut min_args_len = 0;
246				// Check for any final optional arguments
247				val.args.iter().rev().for_each(|(_, kind)| match kind {
248					Kind::Option(_) if min_args_len == 0 => {}
249					_ => min_args_len += 1,
250				});
251				// Check the necessary arguments are passed
252				if x.len() < min_args_len || max_args_len < x.len() {
253					return Err(Error::InvalidArguments {
254						name: format!("fn::{}", val.name),
255						message: match (min_args_len, max_args_len) {
256							(1, 1) => String::from("The function expects 1 argument."),
257							(r, t) if r == t => format!("The function expects {r} arguments."),
258							(r, t) => format!("The function expects {r} to {t} arguments."),
259						},
260					});
261				}
262				// Compute the function arguments
263				let a = stk
264					.scope(|scope| {
265						try_join_all(
266							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
267						)
268					})
269					.await?;
270				// Duplicate context
271				let mut ctx = Context::new(ctx);
272				// Process the function arguments
273				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
274					ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
275				}
276				// Run the custom function
277				stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await
278			}
279			#[allow(unused_variables)]
280			Self::Script(s, x) => {
281				#[cfg(feature = "scripting")]
282				{
283					// Check if scripting is allowed
284					ctx.check_allowed_scripting()?;
285					// Compute the function arguments
286					let a = stk
287						.scope(|scope| {
288							try_join_all(
289								x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
290							)
291						})
292						.await?;
293					// Run the script function
294					fnc::script::run(ctx, opt, doc, s, a).await
295				}
296				#[cfg(not(feature = "scripting"))]
297				{
298					Err(Error::InvalidScript {
299						message: String::from("Embedded functions are not enabled."),
300					})
301				}
302			}
303		}
304	}
305}
306
307impl fmt::Display for Function {
308	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
309		match self {
310			Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
311			Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
312			Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
313		}
314	}
315}