surrealdb_sql/
function.rs

1use crate::ctx::Context;
2use crate::dbs::{Options, Transaction};
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fmt::Fmt;
6use crate::fnc;
7use crate::iam::Action;
8use crate::idiom::Idiom;
9use crate::script::Script;
10use crate::value::Value;
11use crate::Permission;
12use async_recursion::async_recursion;
13use futures::future::try_join_all;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Serialization name of [`Function`]; this is the same string literal that is
/// passed to `#[serde(rename = ...)]` on the enum below.
pub(crate) const TOKEN: &str = "$surrealdb::private::crate::Function";

/// A function call expression: either a built-in function, a user-defined
/// (`fn::`) function, or an embedded script, each with its argument list.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::crate::Function")]
#[revisioned(revision = 1)]
pub enum Function {
	/// A built-in function call: the function name (e.g. `math::max`) and its arguments.
	Normal(String, Vec<Value>),
	/// A user-defined function call: the name stored without its `fn::` prefix,
	/// and its arguments.
	Custom(String, Vec<Value>),
	/// An embedded script function: the script body and its arguments.
	Script(Script, Vec<Value>),
	// Add new variants here
	// NOTE(review): presumably new variants must be appended after the existing
	// ones (never inserted or reordered) so previously serialized/revisioned data
	// keeps decoding — confirm against the `revision` crate's encoding.
}
32
33impl PartialOrd for Function {
34	#[inline]
35	fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
36		None
37	}
38}
39
40impl Function {
41	/// Get function name if applicable
42	pub fn name(&self) -> Option<&str> {
43		match self {
44			Self::Normal(n, _) => Some(n.as_str()),
45			Self::Custom(n, _) => Some(n.as_str()),
46			_ => None,
47		}
48	}
49	/// Get function arguments if applicable
50	pub fn args(&self) -> &[Value] {
51		match self {
52			Self::Normal(_, a) => a,
53			Self::Custom(_, a) => a,
54			_ => &[],
55		}
56	}
57	/// Convert function call to a field name
58	pub fn to_idiom(&self) -> Idiom {
59		match self {
60			Self::Script(_, _) => "function".to_string().into(),
61			Self::Normal(f, _) => f.to_owned().into(),
62			Self::Custom(f, _) => format!("fn::{f}").into(),
63		}
64	}
65	/// Convert this function to an aggregate
66	pub fn aggregate(&self, val: Value) -> Self {
67		match self {
68			Self::Normal(n, a) => {
69				let mut a = a.to_owned();
70				match a.len() {
71					0 => a.insert(0, val),
72					_ => {
73						a.remove(0);
74						a.insert(0, val);
75					}
76				}
77				Self::Normal(n.to_owned(), a)
78			}
79			_ => unreachable!(),
80		}
81	}
82	/// Check if this function is a custom function
83	pub fn is_custom(&self) -> bool {
84		matches!(self, Self::Custom(_, _))
85	}
86
87	/// Check if this function is a scripting function
88	pub fn is_script(&self) -> bool {
89		matches!(self, Self::Script(_, _))
90	}
91
92	/// Check if this function has static arguments
93	pub fn is_static(&self) -> bool {
94		match self {
95			Self::Normal(_, a) => a.iter().all(Value::is_static),
96			_ => false,
97		}
98	}
99
100	/// Check if this function is a rolling function
101	pub fn is_rolling(&self) -> bool {
102		match self {
103			Self::Normal(f, _) if f == "count" => true,
104			Self::Normal(f, _) if f == "math::max" => true,
105			Self::Normal(f, _) if f == "math::mean" => true,
106			Self::Normal(f, _) if f == "math::min" => true,
107			Self::Normal(f, _) if f == "math::sum" => true,
108			Self::Normal(f, _) if f == "time::max" => true,
109			Self::Normal(f, _) if f == "time::min" => true,
110			_ => false,
111		}
112	}
113	/// Check if this function is a grouping function
114	pub fn is_aggregate(&self) -> bool {
115		match self {
116			Self::Normal(f, _) if f == "array::distinct" => true,
117			Self::Normal(f, _) if f == "array::first" => true,
118			Self::Normal(f, _) if f == "array::flatten" => true,
119			Self::Normal(f, _) if f == "array::group" => true,
120			Self::Normal(f, _) if f == "array::last" => true,
121			Self::Normal(f, _) if f == "count" => true,
122			Self::Normal(f, _) if f == "math::bottom" => true,
123			Self::Normal(f, _) if f == "math::interquartile" => true,
124			Self::Normal(f, _) if f == "math::max" => true,
125			Self::Normal(f, _) if f == "math::mean" => true,
126			Self::Normal(f, _) if f == "math::median" => true,
127			Self::Normal(f, _) if f == "math::midhinge" => true,
128			Self::Normal(f, _) if f == "math::min" => true,
129			Self::Normal(f, _) if f == "math::mode" => true,
130			Self::Normal(f, _) if f == "math::nearestrank" => true,
131			Self::Normal(f, _) if f == "math::percentile" => true,
132			Self::Normal(f, _) if f == "math::sample" => true,
133			Self::Normal(f, _) if f == "math::spread" => true,
134			Self::Normal(f, _) if f == "math::stddev" => true,
135			Self::Normal(f, _) if f == "math::sum" => true,
136			Self::Normal(f, _) if f == "math::top" => true,
137			Self::Normal(f, _) if f == "math::trimean" => true,
138			Self::Normal(f, _) if f == "math::variance" => true,
139			Self::Normal(f, _) if f == "time::max" => true,
140			Self::Normal(f, _) if f == "time::min" => true,
141			_ => false,
142		}
143	}
144}
145
impl Function {
	/// Process this type returning a computed simple Value
	///
	/// Evaluates the call according to its variant:
	/// - `Normal`: checks the built-in is allowed, computes every argument via
	///   `try_join_all` (stopping at the first error), then dispatches to
	///   `fnc::run`.
	/// - `Custom`: checks `fn::<name>` is allowed, fetches the definition for the
	///   current namespace and database, enforces its PERMISSIONS clause,
	///   validates the argument count, binds each computed argument (coerced to
	///   its declared kind) by name into a new `Context` built from the
	///   caller's, and evaluates the function body in it.
	/// - `Script`: with the `scripting` feature, checks scripting is allowed,
	///   computes the arguments and runs the script via `fnc::script::run`;
	///   without the feature it returns `Error::InvalidScript`.
	///
	/// # Errors
	///
	/// Returns an error when the function or scripting is not allowed, when a
	/// custom function's permissions deny the call (`Error::FunctionPermissions`),
	/// when a custom function receives too few or too many arguments
	/// (`Error::InvalidArguments`), when an argument cannot be coerced to its
	/// declared kind, or when evaluating an argument or the function fails.
	// `async_recursion` boxes this async fn's future so it may be called
	// recursively; on wasm32 the boxed future is not required to be `Send`.
	#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
	#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
	pub(crate) async fn compute(
		&self,
		ctx: &Context<'_>,
		opt: &Options,
		txn: &Transaction,
		// `'async_recursion` is the lifetime introduced by the `async_recursion` macro.
		doc: Option<&'async_recursion CursorDoc<'_>>,
	) -> Result<Value, Error> {
		// Ensure futures are run
		let opt = &opt.new_with_futures(true);
		// Process the function type
		match self {
			Self::Normal(s, x) => {
				// Check this function is allowed
				ctx.check_allowed_function(s)?;
				// Compute the function arguments
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Run the normal function
				fnc::run(ctx, opt, txn, doc, s, a).await
			}
			Self::Custom(s, x) => {
				// Get the full name of this function
				let name = format!("fn::{s}");
				// Check this function is allowed
				ctx.check_allowed_function(name.as_str())?;
				// Get the function definition. The transaction lock guard lives
				// only inside this block, so it is released before the permission
				// clause and the arguments (which also receive `txn`) are computed.
				let val = {
					// Claim transaction
					let mut run = txn.lock().await;
					// Get the function definition
					run.get_and_cache_db_function(opt.ns(), opt.db(), s).await?
				};
				// Check permissions
				if opt.check_perms(Action::View) {
					match &val.permissions {
						Permission::Full => (),
						Permission::None => {
							return Err(Error::FunctionPermissions {
								name: s.to_owned(),
							})
						}
						Permission::Specific(e) => {
							// Disable permissions while evaluating the clause itself
							let opt = &opt.new_with_perms(false);
							// Process the PERMISSION clause
							if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
								return Err(Error::FunctionPermissions {
									name: s.to_owned(),
								});
							}
						}
					}
				}
				// Get the number of function arguments
				let max_args_len = val.args.len();
				// Track the number of required arguments
				let mut min_args_len = 0;
				// Check for any final optional arguments. Walking the parameters
				// from last to first, `option<...>` parameters are skipped only
				// while no required parameter has been seen yet. Once one is seen,
				// every earlier parameter (optional or not) counts as required.
				val.args.iter().rev().for_each(|(_, kind)| match kind {
					Kind::Option(_) if min_args_len == 0 => {}
					_ => min_args_len += 1,
				});
				// Check the necessary arguments are passed
				if x.len() < min_args_len || max_args_len < x.len() {
					return Err(Error::InvalidArguments {
						name: format!("fn::{}", val.name),
						message: match (min_args_len, max_args_len) {
							(1, 1) => String::from("The function expects 1 argument."),
							(r, t) if r == t => format!("The function expects {r} arguments."),
							(r, t) => format!("The function expects {r} to {t} arguments."),
						},
					});
				}
				// Compute the function arguments
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Duplicate context
				let mut ctx = Context::new(ctx);
				// Process the function arguments. `&val.args` is evaluated before the
				// loop pattern binds, so it refers to the function definition; inside
				// the loop `val` is the computed argument value. `zip` stops at the
				// shorter side, so omitted trailing optional parameters are not bound.
				// NOTE(review): whether such an omitted parameter then reads as NONE or
				// as a same-named variable from the caller's context depends on
				// `Context` lookup semantics — confirm.
				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
					ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
				}
				// Run the custom function (`val` is the definition again here)
				val.block.compute(&ctx, opt, txn, doc).await
			}
			// `s` and `x` are only used when the `scripting` feature is enabled.
			#[allow(unused_variables)]
			Self::Script(s, x) => {
				#[cfg(feature = "scripting")]
				{
					// Check if scripting is allowed
					ctx.check_allowed_scripting()?;
					// Compute the function arguments
					let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
					// Run the script function
					fnc::script::run(ctx, opt, txn, doc, s, a).await
				}
				#[cfg(not(feature = "scripting"))]
				{
					Err(Error::InvalidScript {
						message: String::from("Embedded functions are not enabled."),
					})
				}
			}
		}
	}
}
254
255impl fmt::Display for Function {
256	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
257		match self {
258			Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
259			Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
260			Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
261		}
262	}
263}