1use crate::ctx::Context;
2use crate::dbs::{Options, Transaction};
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fmt::Fmt;
6use crate::fnc;
7use crate::iam::Action;
8use crate::idiom::Idiom;
9use crate::script::Script;
10use crate::value::Value;
11use crate::Permission;
12use async_recursion::async_recursion;
13use futures::future::try_join_all;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Private serde name for [`Function`].
///
/// It is the same string as the `#[serde(rename = ...)]` on the enum below.
/// NOTE(review): presumably used elsewhere to recognise serialized `Function`
/// values — keep the two literals in sync.
pub(crate) const TOKEN: &str = "$surrealdb::private::crate::Function";

/// A function-call expression together with its (uncomputed) argument expressions.
///
/// NOTE(review): the variant order likely participates in the `revisioned`/serde
/// encoding — do not reorder or insert variants without a revision bump (confirm).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::crate::Function")]
#[revisioned(revision = 1)]
pub enum Function {
    /// A built-in function (dispatched through `fnc::run`): its name and arguments.
    Normal(String, Vec<Value>),
    /// A user-defined database function. The name is stored without the `fn::`
    /// prefix; `Display` and `to_idiom` prepend it.
    Custom(String, Vec<Value>),
    /// An embedded script function body and its arguments.
    Script(Script, Vec<Value>),
}
32
33impl PartialOrd for Function {
34 #[inline]
35 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
36 None
37 }
38}
39
40impl Function {
41 pub fn name(&self) -> Option<&str> {
43 match self {
44 Self::Normal(n, _) => Some(n.as_str()),
45 Self::Custom(n, _) => Some(n.as_str()),
46 _ => None,
47 }
48 }
49 pub fn args(&self) -> &[Value] {
51 match self {
52 Self::Normal(_, a) => a,
53 Self::Custom(_, a) => a,
54 _ => &[],
55 }
56 }
57 pub fn to_idiom(&self) -> Idiom {
59 match self {
60 Self::Script(_, _) => "function".to_string().into(),
61 Self::Normal(f, _) => f.to_owned().into(),
62 Self::Custom(f, _) => format!("fn::{f}").into(),
63 }
64 }
65 pub fn aggregate(&self, val: Value) -> Self {
67 match self {
68 Self::Normal(n, a) => {
69 let mut a = a.to_owned();
70 match a.len() {
71 0 => a.insert(0, val),
72 _ => {
73 a.remove(0);
74 a.insert(0, val);
75 }
76 }
77 Self::Normal(n.to_owned(), a)
78 }
79 _ => unreachable!(),
80 }
81 }
82 pub fn is_custom(&self) -> bool {
84 matches!(self, Self::Custom(_, _))
85 }
86
87 pub fn is_script(&self) -> bool {
89 matches!(self, Self::Script(_, _))
90 }
91
92 pub fn is_static(&self) -> bool {
94 match self {
95 Self::Normal(_, a) => a.iter().all(Value::is_static),
96 _ => false,
97 }
98 }
99
100 pub fn is_rolling(&self) -> bool {
102 match self {
103 Self::Normal(f, _) if f == "count" => true,
104 Self::Normal(f, _) if f == "math::max" => true,
105 Self::Normal(f, _) if f == "math::mean" => true,
106 Self::Normal(f, _) if f == "math::min" => true,
107 Self::Normal(f, _) if f == "math::sum" => true,
108 Self::Normal(f, _) if f == "time::max" => true,
109 Self::Normal(f, _) if f == "time::min" => true,
110 _ => false,
111 }
112 }
113 pub fn is_aggregate(&self) -> bool {
115 match self {
116 Self::Normal(f, _) if f == "array::distinct" => true,
117 Self::Normal(f, _) if f == "array::first" => true,
118 Self::Normal(f, _) if f == "array::flatten" => true,
119 Self::Normal(f, _) if f == "array::group" => true,
120 Self::Normal(f, _) if f == "array::last" => true,
121 Self::Normal(f, _) if f == "count" => true,
122 Self::Normal(f, _) if f == "math::bottom" => true,
123 Self::Normal(f, _) if f == "math::interquartile" => true,
124 Self::Normal(f, _) if f == "math::max" => true,
125 Self::Normal(f, _) if f == "math::mean" => true,
126 Self::Normal(f, _) if f == "math::median" => true,
127 Self::Normal(f, _) if f == "math::midhinge" => true,
128 Self::Normal(f, _) if f == "math::min" => true,
129 Self::Normal(f, _) if f == "math::mode" => true,
130 Self::Normal(f, _) if f == "math::nearestrank" => true,
131 Self::Normal(f, _) if f == "math::percentile" => true,
132 Self::Normal(f, _) if f == "math::sample" => true,
133 Self::Normal(f, _) if f == "math::spread" => true,
134 Self::Normal(f, _) if f == "math::stddev" => true,
135 Self::Normal(f, _) if f == "math::sum" => true,
136 Self::Normal(f, _) if f == "math::top" => true,
137 Self::Normal(f, _) if f == "math::trimean" => true,
138 Self::Normal(f, _) if f == "math::variance" => true,
139 Self::Normal(f, _) if f == "time::max" => true,
140 Self::Normal(f, _) if f == "time::min" => true,
141 _ => false,
142 }
143 }
144}
145
impl Function {
    /// Evaluates this function call and returns the resulting [`Value`].
    ///
    /// - `Normal`: checks that the built-in is allowed, evaluates the arguments
    ///   and dispatches to `fnc::run`.
    /// - `Custom`: checks that `fn::<name>` is allowed and loads the definition.
    ///   It then enforces the definition's `permissions`, validates the argument
    ///   count, and evaluates the definition's block in a child context with the
    ///   coerced arguments bound.
    /// - `Script`: runs the embedded script when the `scripting` feature is
    ///   enabled. Otherwise it returns `Error::InvalidScript`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the allow-list checks, definition lookup, argument
    /// evaluation and coercion, and the function body. It returns
    /// `Error::FunctionPermissions` when a custom function's permissions deny
    /// access, and `Error::InvalidArguments` on an arity mismatch.
    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
    pub(crate) async fn compute(
        &self,
        ctx: &Context<'_>,
        opt: &Options,
        txn: &Transaction,
        doc: Option<&'async_recursion CursorDoc<'_>>,
    ) -> Result<Value, Error> {
        // Derive options with `new_with_futures(true)` for everything evaluated below.
        // NOTE(review): presumably so that future values are actually computed here — confirm.
        let opt = &opt.new_with_futures(true);
        match self {
            Self::Normal(s, x) => {
                // Reject the call if this built-in is not allowed by the context.
                ctx.check_allowed_function(s)?;
                // Evaluate all argument expressions via `try_join_all`; the first error is returned.
                let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                // Dispatch to the built-in implementation.
                fnc::run(ctx, opt, txn, doc, s, a).await
            }
            Self::Custom(s, x) => {
                // The name is stored without the `fn::` prefix; the allow-list
                // check uses the fully-qualified name.
                let name = format!("fn::{s}");
                ctx.check_allowed_function(name.as_str())?;
                // Load the function definition. The transaction guard `run` is
                // dropped at the end of this block, before any permission
                // expression, argument or body is evaluated.
                let val = {
                    let mut run = txn.lock().await;
                    run.get_and_cache_db_function(opt.ns(), opt.db(), s).await?
                };
                // Enforce the definition's permissions when `check_perms(Action::View)` says so.
                if opt.check_perms(Action::View) {
                    match &val.permissions {
                        Permission::Full => (),
                        Permission::None => {
                            return Err(Error::FunctionPermissions {
                                name: s.to_owned(),
                            })
                        }
                        Permission::Specific(e) => {
                            // Evaluate the permission expression under
                            // `new_with_perms(false)` options (presumably so the
                            // check itself is not permission-filtered — confirm).
                            let opt = &opt.new_with_perms(false);
                            if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
                                return Err(Error::FunctionPermissions {
                                    name: s.to_owned(),
                                });
                            }
                        }
                    }
                }
                // Arity bounds. At most every declared parameter may be passed.
                // The minimum excludes the trailing run of `Kind::Option` parameters:
                // counting from the end, optional parameters are skipped only until
                // the first non-optional one is seen.
                let max_args_len = val.args.len();
                let mut min_args_len = 0;
                val.args.iter().rev().for_each(|(_, kind)| match kind {
                    Kind::Option(_) if min_args_len == 0 => {}
                    _ => min_args_len += 1,
                });
                if x.len() < min_args_len || max_args_len < x.len() {
                    return Err(Error::InvalidArguments {
                        name: format!("fn::{}", val.name),
                        message: match (min_args_len, max_args_len) {
                            (1, 1) => String::from("The function expects 1 argument."),
                            (r, t) if r == t => format!("The function expects {r} arguments."),
                            (r, t) => format!("The function expects {r} to {t} arguments."),
                        },
                    });
                }
                // Arguments are only evaluated after the permission and arity checks pass.
                let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                // Bind each argument, coerced to its declared kind, in a child context.
                // `zip` stops at the shorter side, so omitted optional parameters are not bound.
                let mut ctx = Context::new(ctx);
                for (val, (name, kind)) in a.into_iter().zip(&val.args) {
                    ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
                }
                // Run the function body with the parameter bindings in scope.
                val.block.compute(&ctx, opt, txn, doc).await
            }
            // Without the `scripting` feature, `s` and `x` are unused.
            #[allow(unused_variables)]
            Self::Script(s, x) => {
                #[cfg(feature = "scripting")]
                {
                    // Reject the call if scripting is not allowed by the context.
                    ctx.check_allowed_scripting()?;
                    let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                    fnc::script::run(ctx, opt, txn, doc, s, a).await
                }
                #[cfg(not(feature = "scripting"))]
                {
                    Err(Error::InvalidScript {
                        message: String::from("Embedded functions are not enabled."),
                    })
                }
            }
        }
    }
}
254
255impl fmt::Display for Function {
256 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
257 match self {
258 Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
259 Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
260 Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
261 }
262 }
263}