1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Crate-internal marker string for the `Function` type; it is the same
/// string used as the serde rename on the `Function` enum in this file.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
22
/// A function invocation together with its (not yet evaluated) arguments.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Function {
    /// A function referenced by its plain name (for example `count` or
    /// `math::max`); it is executed through `fnc::run`.
    Normal(String, Vec<Value>),
    /// A user-defined function. The stored name excludes the `fn::` prefix,
    /// which is added when the function is displayed or permission-checked;
    /// the definition is loaded from the database at compute time.
    Custom(String, Vec<Value>),
    /// An embedded script with its arguments; it only runs when the
    /// `scripting` feature is enabled.
    Script(Script, Vec<Value>),
}
34
/// Classification returned by `Function::get_optimised_aggregate`.
pub(crate) enum OptimisedAggregate {
    /// The function is not one of the recognised aggregates.
    None,
    /// `count` called without arguments.
    Count,
    /// `count` called with at least one argument.
    CountFunction,
    /// `math::max`
    MathMax,
    /// `math::min`
    MathMin,
    /// `math::sum`
    MathSum,
    /// `math::mean`
    MathMean,
    /// `time::max`
    TimeMax,
    /// `time::min`
    TimeMin,
}
46
47impl PartialOrd for Function {
48 #[inline]
49 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
50 None
51 }
52}
53
54impl Function {
55 pub fn name(&self) -> Option<&str> {
57 match self {
58 Self::Normal(n, _) => Some(n.as_str()),
59 Self::Custom(n, _) => Some(n.as_str()),
60 _ => None,
61 }
62 }
63 pub fn args(&self) -> &[Value] {
65 match self {
66 Self::Normal(_, a) => a,
67 Self::Custom(_, a) => a,
68 _ => &[],
69 }
70 }
71 pub fn to_idiom(&self) -> Idiom {
73 match self {
74 Self::Script(_, _) => "function".to_string().into(),
75 Self::Normal(f, _) => f.to_owned().into(),
76 Self::Custom(f, _) => format!("fn::{f}").into(),
77 }
78 }
79 pub fn aggregate(&self, val: Value) -> Self {
81 match self {
82 Self::Normal(n, a) => {
83 let mut a = a.to_owned();
84 match a.len() {
85 0 => a.insert(0, val),
86 _ => {
87 a.remove(0);
88 a.insert(0, val);
89 }
90 }
91 Self::Normal(n.to_owned(), a)
92 }
93 _ => unreachable!(),
94 }
95 }
96 pub fn is_custom(&self) -> bool {
98 matches!(self, Self::Custom(_, _))
99 }
100
101 pub fn is_script(&self) -> bool {
103 matches!(self, Self::Script(_, _))
104 }
105
106 pub fn is_static(&self) -> bool {
108 match self {
109 Self::Normal(_, a) => a.iter().all(Value::is_static),
110 _ => false,
111 }
112 }
113
114 pub fn is_rolling(&self) -> bool {
116 match self {
117 Self::Normal(f, _) if f == "count" => true,
118 Self::Normal(f, _) if f == "math::max" => true,
119 Self::Normal(f, _) if f == "math::mean" => true,
120 Self::Normal(f, _) if f == "math::min" => true,
121 Self::Normal(f, _) if f == "math::sum" => true,
122 Self::Normal(f, _) if f == "time::max" => true,
123 Self::Normal(f, _) if f == "time::min" => true,
124 _ => false,
125 }
126 }
127 pub fn is_aggregate(&self) -> bool {
129 match self {
130 Self::Normal(f, _) if f == "array::distinct" => true,
131 Self::Normal(f, _) if f == "array::first" => true,
132 Self::Normal(f, _) if f == "array::flatten" => true,
133 Self::Normal(f, _) if f == "array::group" => true,
134 Self::Normal(f, _) if f == "array::last" => true,
135 Self::Normal(f, _) if f == "count" => true,
136 Self::Normal(f, _) if f == "math::bottom" => true,
137 Self::Normal(f, _) if f == "math::interquartile" => true,
138 Self::Normal(f, _) if f == "math::max" => true,
139 Self::Normal(f, _) if f == "math::mean" => true,
140 Self::Normal(f, _) if f == "math::median" => true,
141 Self::Normal(f, _) if f == "math::midhinge" => true,
142 Self::Normal(f, _) if f == "math::min" => true,
143 Self::Normal(f, _) if f == "math::mode" => true,
144 Self::Normal(f, _) if f == "math::nearestrank" => true,
145 Self::Normal(f, _) if f == "math::percentile" => true,
146 Self::Normal(f, _) if f == "math::sample" => true,
147 Self::Normal(f, _) if f == "math::spread" => true,
148 Self::Normal(f, _) if f == "math::stddev" => true,
149 Self::Normal(f, _) if f == "math::sum" => true,
150 Self::Normal(f, _) if f == "math::top" => true,
151 Self::Normal(f, _) if f == "math::trimean" => true,
152 Self::Normal(f, _) if f == "math::variance" => true,
153 Self::Normal(f, _) if f == "time::max" => true,
154 Self::Normal(f, _) if f == "time::min" => true,
155 _ => false,
156 }
157 }
158 pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
159 match self {
160 Self::Normal(f, v) if f == "count" => {
161 if v.is_empty() {
162 OptimisedAggregate::Count
163 } else {
164 OptimisedAggregate::CountFunction
165 }
166 }
167 Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
168 Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
169 Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
170 Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
171 Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
172 Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
173 _ => OptimisedAggregate::None,
174 }
175 }
176}
177
178impl Function {
179 pub(crate) async fn compute(
183 &self,
184 stk: &mut Stk,
185 ctx: &Context<'_>,
186 opt: &Options,
187 doc: Option<&CursorDoc<'_>>,
188 ) -> Result<Value, Error> {
189 let opt = &opt.new_with_futures(true);
191 match self {
193 Self::Normal(s, x) => {
194 ctx.check_allowed_function(s)?;
196 let a = stk
198 .scope(|scope| {
199 try_join_all(
200 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
201 )
202 })
203 .await?;
204 fnc::run(stk, ctx, opt, doc, s, a).await
206 }
207 Self::Custom(s, x) => {
208 let name = format!("fn::{s}");
210 ctx.check_allowed_function(name.as_str())?;
212 let val = {
214 let mut run = ctx.tx_lock().await;
216 let val = run.get_and_cache_db_function(opt.ns()?, opt.db()?, s).await?;
218 drop(run);
219 val
220 };
221 if opt.check_perms(Action::View)? {
223 match &val.permissions {
224 Permission::Full => (),
225 Permission::None => {
226 return Err(Error::FunctionPermissions {
227 name: s.to_owned(),
228 })
229 }
230 Permission::Specific(e) => {
231 let opt = &opt.new_with_perms(false);
233 if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
235 return Err(Error::FunctionPermissions {
236 name: s.to_owned(),
237 });
238 }
239 }
240 }
241 }
242 let max_args_len = val.args.len();
244 let mut min_args_len = 0;
246 val.args.iter().rev().for_each(|(_, kind)| match kind {
248 Kind::Option(_) if min_args_len == 0 => {}
249 _ => min_args_len += 1,
250 });
251 if x.len() < min_args_len || max_args_len < x.len() {
253 return Err(Error::InvalidArguments {
254 name: format!("fn::{}", val.name),
255 message: match (min_args_len, max_args_len) {
256 (1, 1) => String::from("The function expects 1 argument."),
257 (r, t) if r == t => format!("The function expects {r} arguments."),
258 (r, t) => format!("The function expects {r} to {t} arguments."),
259 },
260 });
261 }
262 let a = stk
264 .scope(|scope| {
265 try_join_all(
266 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
267 )
268 })
269 .await?;
270 let mut ctx = Context::new(ctx);
272 for (val, (name, kind)) in a.into_iter().zip(&val.args) {
274 ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
275 }
276 stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await
278 }
279 #[allow(unused_variables)]
280 Self::Script(s, x) => {
281 #[cfg(feature = "scripting")]
282 {
283 ctx.check_allowed_scripting()?;
285 let a = stk
287 .scope(|scope| {
288 try_join_all(
289 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
290 )
291 })
292 .await?;
293 fnc::script::run(ctx, opt, doc, s, a).await
295 }
296 #[cfg(not(feature = "scripting"))]
297 {
298 Err(Error::InvalidScript {
299 message: String::from("Embedded functions are not enabled."),
300 })
301 }
302 }
303 }
304 }
305}
306
307impl fmt::Display for Function {
308 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
309 match self {
310 Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
311 Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
312 Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
313 }
314 }
315}