1use async_trait::async_trait;
19use serde::{de, Deserialize, Serialize};
20
21use crate::{raw::RawValue, schema::FunctionSchema, Diagnostics};
22
23#[async_trait]
24pub trait Function: Send + Sync {
28 type Input<'a>: Deserialize<'a> + Send;
32
33 type Output<'a>: Serialize + Send;
37
38 fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema>;
49
50 async fn call<'a>(
62 &self,
63 diags: &mut Diagnostics,
64 params: Self::Input<'a>,
65 ) -> Option<Self::Output<'a>>;
66}
67
68#[async_trait]
69pub trait DynamicFunction: Send + Sync {
73 fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema>;
84
85 async fn call<'a>(&self, diags: &mut Diagnostics, params: Vec<RawValue>) -> Option<RawValue>;
97}
98
99#[async_trait]
100impl<T: Function> DynamicFunction for T {
101 fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema> {
103 <T as Function>::schema(self, diags)
104 }
105 async fn call<'a>(&self, diags: &mut Diagnostics, params: Vec<RawValue>) -> Option<RawValue> {
107 let mut decoder = Decoder {
108 params: ¶ms,
109 index: 0,
110 };
111 match Deserialize::deserialize(&mut decoder) {
112 Ok(params) => {
113 let value = <T as Function>::call(self, diags, params).await?;
114 RawValue::serialize(diags, &value)
115 }
116 Err(DecoderError::UnsupportedFormat) => {
117 diags.root_error("Provider Bug: Unsupported format", "This is a provider bug.\nThe input type is not a struct, a vec, or a tuple.\nTherefore, it can not be parsed as a list of arguments.");
118 None
119 }
120 Err(DecoderError::MsgPackError(index, err)) => {
121 diags.function_error(index as i64, err.to_string());
122 None
123 }
124 Err(DecoderError::JsonError(index, err)) => {
125 diags.function_error(index as i64, err.to_string());
126 None
127 }
128 Err(DecoderError::Custom(msg)) => {
129 diags.root_error_short(msg);
130 None
131 }
132 }
133 }
134}
135
136impl<T: Function + 'static> From<T> for Box<dyn DynamicFunction> {
137 fn from(value: T) -> Self {
138 Box::new(value)
139 }
140}
141
142struct Decoder<'de> {
143 params: &'de [RawValue],
144 index: usize,
145}
146
147#[derive(Debug)]
148enum DecoderError {
149 UnsupportedFormat,
150 JsonError(usize, serde_json::Error),
151 MsgPackError(usize, rmp_serde::decode::Error),
152 Custom(String),
153}
154
155impl std::fmt::Display for DecoderError {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 match self {
158 Self::UnsupportedFormat => f.write_str("Bad format"),
159 Self::JsonError(_, err) => err.fmt(f),
160 Self::MsgPackError(_, err) => err.fmt(f),
161 Self::Custom(msg) => f.write_str(msg),
162 }
163 }
164}
165impl std::error::Error for DecoderError {
166 fn source(&self) -> Option<&(dyn de::StdError + 'static)> {
167 match self {
168 Self::JsonError(_, err) => err.source(),
169 Self::MsgPackError(_, err) => err.source(),
170 _ => None,
171 }
172 }
173}
174
175impl serde::de::Error for DecoderError {
176 fn custom<T>(msg: T) -> Self
177 where
178 T: std::fmt::Display,
179 {
180 Self::Custom(msg.to_string())
181 }
182}
183
/// Generates the `Deserializer` method `$method` so that it ignores its visitor
/// and fails with [`DecoderError::UnsupportedFormat`].
///
/// Must be expanded inside an `impl de::Deserializer<'de>` block: it relies on the
/// surrounding `'de` lifetime and `Self::Error`.
macro_rules! deserialize {
    ($method:ident) => {
        fn $method<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
        where
            V: de::Visitor<'de>,
        {
            Err(DecoderError::UnsupportedFormat)
        }
    };
}
191
192impl<'de, 'a> de::Deserializer<'de> for &'a mut Decoder<'de> {
193 type Error = DecoderError;
194
195 deserialize!(deserialize_bool);
196 deserialize!(deserialize_i8);
197 deserialize!(deserialize_i16);
198 deserialize!(deserialize_i32);
199 deserialize!(deserialize_i64);
200 deserialize!(deserialize_i128);
201 deserialize!(deserialize_u8);
202 deserialize!(deserialize_u16);
203 deserialize!(deserialize_u32);
204 deserialize!(deserialize_u64);
205 deserialize!(deserialize_u128);
206 deserialize!(deserialize_f32);
207 deserialize!(deserialize_f64);
208 deserialize!(deserialize_char);
209 deserialize!(deserialize_str);
210 deserialize!(deserialize_string);
211 deserialize!(deserialize_bytes);
212 deserialize!(deserialize_byte_buf);
213 deserialize!(deserialize_option);
214 deserialize!(deserialize_unit);
215
216 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
217 where
218 V: de::Visitor<'de>,
219 {
220 visitor.visit_seq(self)
221 }
222
223 fn deserialize_unit_struct<V>(
224 self,
225 _name: &'static str,
226 visitor: V,
227 ) -> Result<V::Value, Self::Error>
228 where
229 V: de::Visitor<'de>,
230 {
231 visitor.visit_unit()
232 }
233
234 fn deserialize_newtype_struct<V>(
235 self,
236 _name: &'static str,
237 visitor: V,
238 ) -> Result<V::Value, Self::Error>
239 where
240 V: de::Visitor<'de>,
241 {
242 visitor.visit_newtype_struct(self)
243 }
244
245 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
246 where
247 V: de::Visitor<'de>,
248 {
249 visitor.visit_seq(self)
250 }
251
252 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
253 where
254 V: de::Visitor<'de>,
255 {
256 visitor.visit_seq(self)
257 }
258
259 fn deserialize_tuple_struct<V>(
260 self,
261 _name: &'static str,
262 _len: usize,
263 visitor: V,
264 ) -> Result<V::Value, Self::Error>
265 where
266 V: de::Visitor<'de>,
267 {
268 visitor.visit_seq(self)
269 }
270
271 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
272 where
273 V: de::Visitor<'de>,
274 {
275 Err(DecoderError::UnsupportedFormat)
276 }
277
278 fn deserialize_struct<V>(
279 self,
280 _name: &'static str,
281 _fields: &'static [&'static str],
282 visitor: V,
283 ) -> Result<V::Value, Self::Error>
284 where
285 V: de::Visitor<'de>,
286 {
287 visitor.visit_seq(self)
288 }
289
290 fn deserialize_enum<V>(
291 self,
292 _name: &'static str,
293 _variants: &'static [&'static str],
294 _visitor: V,
295 ) -> Result<V::Value, Self::Error>
296 where
297 V: de::Visitor<'de>,
298 {
299 Err(DecoderError::UnsupportedFormat)
300 }
301
302 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
303 where
304 V: de::Visitor<'de>,
305 {
306 Err(DecoderError::UnsupportedFormat)
307 }
308
309 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
310 where
311 V: de::Visitor<'de>,
312 {
313 visitor.visit_seq(self)
314 }
315}
316
317impl<'de> de::SeqAccess<'de> for Decoder<'de> {
318 type Error = DecoderError;
319
320 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
321 where
322 T: de::DeserializeSeed<'de>,
323 {
324 match self.params {
325 [] => Ok(None),
326 [param, params @ ..] => {
327 let index = self.index;
328 self.index += 1;
329 self.params = params;
330 match param {
331 RawValue::MessagePack(bytes) => {
332 let mut deserializer =
333 rmp_serde::Deserializer::from_read_ref(bytes.as_slice());
334 match seed.deserialize(&mut deserializer) {
335 Ok(value) => Ok(Some(value)),
336 Err(err) => Err(DecoderError::MsgPackError(index, err)),
337 }
338 }
339 RawValue::Json(bytes) => {
340 let mut deserializer =
341 serde_json::Deserializer::from_slice(bytes.as_slice());
342 match seed.deserialize(&mut deserializer) {
343 Ok(value) => Ok(Some(value)),
344 Err(err) => Err(DecoderError::JsonError(index, err)),
345 }
346 }
347 }
348 }
349 }
350 }
351
352 fn size_hint(&self) -> Option<usize> {
353 Some(self.params.len())
354 }
355}