1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Private serialization token naming the `Function` type.
///
/// This is the same string passed to `#[serde(rename = ...)]` on the
/// `Function` enum, so the two must be kept in sync.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
22
23#[revisioned(revision = 1)]
24#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
25#[serde(rename = "$surrealdb::private::sql::Function")]
26#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
27#[non_exhaustive]
28pub enum Function {
29 Normal(String, Vec<Value>),
30 Custom(String, Vec<Value>),
31 Script(Script, Vec<Value>),
32 Anonymous(Value, Vec<Value>),
33 }
35
/// Classification produced by `Function::get_optimised_aggregate`.
///
/// Each non-`None` variant corresponds to exactly one built-in function name.
pub(crate) enum OptimisedAggregate {
    /// The function is not one of the recognised aggregates.
    None,
    /// `count` called without arguments.
    Count,
    /// `count` called with at least one argument.
    CountFunction,
    /// `math::max`
    MathMax,
    /// `math::min`
    MathMin,
    /// `math::sum`
    MathSum,
    /// `math::mean`
    MathMean,
    /// `time::max`
    TimeMax,
    /// `time::min`
    TimeMin,
}
47
48impl PartialOrd for Function {
49 #[inline]
50 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
51 None
52 }
53}
54
55impl Function {
56 pub fn name(&self) -> Option<&str> {
58 match self {
59 Self::Normal(n, _) => Some(n.as_str()),
60 Self::Custom(n, _) => Some(n.as_str()),
61 _ => None,
62 }
63 }
64 pub fn args(&self) -> &[Value] {
66 match self {
67 Self::Normal(_, a) => a,
68 Self::Custom(_, a) => a,
69 _ => &[],
70 }
71 }
72 pub fn to_idiom(&self) -> Idiom {
74 match self {
75 Self::Anonymous(_, _) => "function".to_string().into(),
76 Self::Script(_, _) => "function".to_string().into(),
77 Self::Normal(f, _) => f.to_owned().into(),
78 Self::Custom(f, _) => format!("fn::{f}").into(),
79 }
80 }
81 pub fn aggregate(&self, val: Value) -> Self {
83 match self {
84 Self::Normal(n, a) => {
85 let mut a = a.to_owned();
86 match a.len() {
87 0 => a.insert(0, val),
88 _ => {
89 a.remove(0);
90 a.insert(0, val);
91 }
92 }
93 Self::Normal(n.to_owned(), a)
94 }
95 _ => unreachable!(),
96 }
97 }
98 pub fn is_custom(&self) -> bool {
100 matches!(self, Self::Custom(_, _))
101 }
102
103 pub fn is_script(&self) -> bool {
105 matches!(self, Self::Script(_, _))
106 }
107
108 pub fn is_static(&self) -> bool {
110 match self {
111 Self::Normal(_, a) => a.iter().all(Value::is_static),
112 _ => false,
113 }
114 }
115
116 pub fn is_inline(&self) -> bool {
118 matches!(self, Self::Anonymous(_, _))
119 }
120
121 pub fn is_rolling(&self) -> bool {
123 match self {
124 Self::Normal(f, _) if f == "count" => true,
125 Self::Normal(f, _) if f == "math::max" => true,
126 Self::Normal(f, _) if f == "math::mean" => true,
127 Self::Normal(f, _) if f == "math::min" => true,
128 Self::Normal(f, _) if f == "math::sum" => true,
129 Self::Normal(f, _) if f == "time::max" => true,
130 Self::Normal(f, _) if f == "time::min" => true,
131 _ => false,
132 }
133 }
134 pub fn is_aggregate(&self) -> bool {
136 match self {
137 Self::Normal(f, _) if f == "array::distinct" => true,
138 Self::Normal(f, _) if f == "array::first" => true,
139 Self::Normal(f, _) if f == "array::flatten" => true,
140 Self::Normal(f, _) if f == "array::group" => true,
141 Self::Normal(f, _) if f == "array::last" => true,
142 Self::Normal(f, _) if f == "count" => true,
143 Self::Normal(f, _) if f == "math::bottom" => true,
144 Self::Normal(f, _) if f == "math::interquartile" => true,
145 Self::Normal(f, _) if f == "math::max" => true,
146 Self::Normal(f, _) if f == "math::mean" => true,
147 Self::Normal(f, _) if f == "math::median" => true,
148 Self::Normal(f, _) if f == "math::midhinge" => true,
149 Self::Normal(f, _) if f == "math::min" => true,
150 Self::Normal(f, _) if f == "math::mode" => true,
151 Self::Normal(f, _) if f == "math::nearestrank" => true,
152 Self::Normal(f, _) if f == "math::percentile" => true,
153 Self::Normal(f, _) if f == "math::sample" => true,
154 Self::Normal(f, _) if f == "math::spread" => true,
155 Self::Normal(f, _) if f == "math::stddev" => true,
156 Self::Normal(f, _) if f == "math::sum" => true,
157 Self::Normal(f, _) if f == "math::top" => true,
158 Self::Normal(f, _) if f == "math::trimean" => true,
159 Self::Normal(f, _) if f == "math::variance" => true,
160 Self::Normal(f, _) if f == "time::max" => true,
161 Self::Normal(f, _) if f == "time::min" => true,
162 _ => false,
163 }
164 }
165 pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
166 match self {
167 Self::Normal(f, v) if f == "count" => {
168 if v.is_empty() {
169 OptimisedAggregate::Count
170 } else {
171 OptimisedAggregate::CountFunction
172 }
173 }
174 Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
175 Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
176 Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
177 Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
178 Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
179 Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
180 _ => OptimisedAggregate::None,
181 }
182 }
183}
184
185impl Function {
186 pub(crate) async fn compute(
190 &self,
191 stk: &mut Stk,
192 ctx: &Context<'_>,
193 opt: &Options,
194 doc: Option<&CursorDoc<'_>>,
195 ) -> Result<Value, Error> {
196 let opt = &opt.new_with_futures(true);
198 match self {
200 Self::Normal(s, x) => {
201 ctx.check_allowed_function(s)?;
203 let a = stk
205 .scope(|scope| {
206 try_join_all(
207 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
208 )
209 })
210 .await?;
211 fnc::run(stk, ctx, opt, doc, s, a).await
213 }
214 Self::Anonymous(v, x) => {
215 let val = match v {
216 c @ Value::Closure(_) => c.clone(),
217 Value::Param(p) => ctx.value(p).cloned().unwrap_or(Value::None),
218 Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
219 stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
220 }
221 _ => Value::None,
222 };
223
224 match val {
225 Value::Closure(closure) => {
226 let a = stk
228 .scope(|scope| {
229 try_join_all(
230 x.iter()
231 .map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
232 )
233 })
234 .await?;
235 stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
236 }
237 v => Err(Error::InvalidFunction {
238 name: "ANONYMOUS".to_string(),
239 message: format!("'{}' is not a function", v.kindof()),
240 }),
241 }
242 }
243 Self::Custom(s, x) => {
244 let name = format!("fn::{s}");
246 ctx.check_allowed_function(name.as_str())?;
248 let val = ctx.tx().get_db_function(opt.ns()?, opt.db()?, s).await?;
250 if opt.check_perms(Action::View)? {
252 match &val.permissions {
253 Permission::Full => (),
254 Permission::None => {
255 return Err(Error::FunctionPermissions {
256 name: s.to_owned(),
257 })
258 }
259 Permission::Specific(e) => {
260 let opt = &opt.new_with_perms(false);
262 if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
264 return Err(Error::FunctionPermissions {
265 name: s.to_owned(),
266 });
267 }
268 }
269 }
270 }
271 let max_args_len = val.args.len();
273 let mut min_args_len = 0;
275 val.args.iter().rev().for_each(|(_, kind)| match kind {
277 Kind::Option(_) if min_args_len == 0 => {}
278 _ => min_args_len += 1,
279 });
280 if x.len() < min_args_len || max_args_len < x.len() {
282 return Err(Error::InvalidArguments {
283 name: format!("fn::{}", val.name),
284 message: match (min_args_len, max_args_len) {
285 (1, 1) => String::from("The function expects 1 argument."),
286 (r, t) if r == t => format!("The function expects {r} arguments."),
287 (r, t) => format!("The function expects {r} to {t} arguments."),
288 },
289 });
290 }
291 let a = stk
293 .scope(|scope| {
294 try_join_all(
295 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
296 )
297 })
298 .await?;
299 let mut ctx = Context::new_isolated(ctx);
301 for (val, (name, kind)) in a.into_iter().zip(&val.args) {
303 ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
304 }
305 let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
307 Err(Error::Return {
308 value,
309 }) => Ok(value),
310 res => res,
311 }?;
312
313 if let Some(ref returns) = val.returns {
314 result
315 .coerce_to(returns)
316 .map_err(|e| e.function_check_from_coerce(val.name.to_string()))
317 } else {
318 Ok(result)
319 }
320 }
321 #[allow(unused_variables)]
322 Self::Script(s, x) => {
323 #[cfg(feature = "scripting")]
324 {
325 ctx.check_allowed_scripting()?;
327 let a = stk
329 .scope(|scope| {
330 try_join_all(
331 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
332 )
333 })
334 .await?;
335 fnc::script::run(ctx, opt, doc, s, a).await
337 }
338 #[cfg(not(feature = "scripting"))]
339 {
340 Err(Error::InvalidScript {
341 message: String::from("Embedded functions are not enabled."),
342 })
343 }
344 }
345 }
346 }
347}
348
349impl fmt::Display for Function {
350 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
351 match self {
352 Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
353 Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
354 Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
355 Self::Anonymous(p, e) => write!(f, "{p}({})", Fmt::comma_separated(e)),
356 }
357 }
358}