tf_provider/
function.rs

1// This file is part of the tf-provider project
2//
3// Copyright (C) ANEO, 2024-2024. All rights reserved.
4//
5// Licensed under the Apache License, Version 2.0 (the "License")
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17//! [`Function`] module
18use async_trait::async_trait;
19use serde::{de, Deserialize, Serialize};
20
21use crate::{raw::RawValue, schema::FunctionSchema, Diagnostics};
22
23#[async_trait]
24/// Trait for implementing a function with automatic serialization/deserialization
25///
26/// See also: [`DynamicFunction`]
27pub trait Function: Send + Sync {
28    /// Function Input
29    ///
30    /// The input will be automatically serialized/deserialized at the border of the request.
31    type Input<'a>: Deserialize<'a> + Send;
32
33    /// Function Output
34    ///
35    /// The output will be automatically serialized/deserialized at the border of the request.
36    type Output<'a>: Serialize + Send;
37
38    /// Get the schema of the function
39    ///
40    /// # Arguments
41    ///
42    /// * `diags` - Diagnostics to record warnings and errors that occured when getting back the schema
43    ///
44    /// # Remarks
45    ///
46    /// The return is ignored if there is an error in diagnostics.
47    /// If the return is [`None`], an ad-hoc error is added to diagnostics.
48    fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema>;
49
50    /// Call Function
51    ///
52    /// # Arguments
53    ///
54    /// * `diags` - Diagnostics to record warnings and errors that occured when calling the function
55    /// * `params` - Function parameters packed into the input type
56    ///
57    /// # Remarks
58    ///
59    /// The return is ignored if there is an error in diagnostics.
60    /// If the return is [`None`], an ad-hoc error is added to diagnostics.
61    async fn call<'a>(
62        &self,
63        diags: &mut Diagnostics,
64        params: Self::Input<'a>,
65    ) -> Option<Self::Output<'a>>;
66}
67
68#[async_trait]
69/// Trait for implementing a function *without* automatic serialization/deserialization
70///
71/// See also: [`Function`]
72pub trait DynamicFunction: Send + Sync {
73    /// Get the schema of the function
74    ///
75    /// # Arguments
76    ///
77    /// * `diags` - Diagnostics to record warnings and errors that occured when getting back the schema
78    ///
79    /// # Remarks
80    ///
81    /// The return is ignored if there is an error in diagnostics.
82    /// If the return is [`None`], an ad-hoc error is added to diagnostics.
83    fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema>;
84
85    /// Call Function
86    ///
87    /// # Arguments
88    ///
89    /// * `diags` - Diagnostics to record warnings and errors that occured when calling the function
90    /// * `params` - Function parameters
91    ///
92    /// # Remarks
93    ///
94    /// The return is ignored if there is an error in diagnostics.
95    /// If the return is [`None`], an ad-hoc error is added to diagnostics.
96    async fn call<'a>(&self, diags: &mut Diagnostics, params: Vec<RawValue>) -> Option<RawValue>;
97}
98
99#[async_trait]
100impl<T: Function> DynamicFunction for T {
101    /// Get the schema of the function
102    fn schema(&self, diags: &mut Diagnostics) -> Option<FunctionSchema> {
103        <T as Function>::schema(self, diags)
104    }
105    /// CallFunction
106    async fn call<'a>(&self, diags: &mut Diagnostics, params: Vec<RawValue>) -> Option<RawValue> {
107        let mut decoder = Decoder {
108            params: &params,
109            index: 0,
110        };
111        match Deserialize::deserialize(&mut decoder) {
112            Ok(params) => {
113                let value = <T as Function>::call(self, diags, params).await?;
114                RawValue::serialize(diags, &value)
115            }
116            Err(DecoderError::UnsupportedFormat) => {
117                diags.root_error("Provider Bug: Unsupported format", "This is a provider bug.\nThe input type is not a struct, a vec, or a tuple.\nTherefore, it can not be parsed as a list of arguments.");
118                None
119            }
120            Err(DecoderError::MsgPackError(index, err)) => {
121                diags.function_error(index as i64, err.to_string());
122                None
123            }
124            Err(DecoderError::JsonError(index, err)) => {
125                diags.function_error(index as i64, err.to_string());
126                None
127            }
128            Err(DecoderError::Custom(msg)) => {
129                diags.root_error_short(msg);
130                None
131            }
132        }
133    }
134}
135
136impl<T: Function + 'static> From<T> for Box<dyn DynamicFunction> {
137    fn from(value: T) -> Self {
138        Box::new(value)
139    }
140}
141
142struct Decoder<'de> {
143    params: &'de [RawValue],
144    index: usize,
145}
146
147#[derive(Debug)]
148enum DecoderError {
149    UnsupportedFormat,
150    JsonError(usize, serde_json::Error),
151    MsgPackError(usize, rmp_serde::decode::Error),
152    Custom(String),
153}
154
155impl std::fmt::Display for DecoderError {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        match self {
158            Self::UnsupportedFormat => f.write_str("Bad format"),
159            Self::JsonError(_, err) => err.fmt(f),
160            Self::MsgPackError(_, err) => err.fmt(f),
161            Self::Custom(msg) => f.write_str(msg),
162        }
163    }
164}
165impl std::error::Error for DecoderError {
166    fn source(&self) -> Option<&(dyn de::StdError + 'static)> {
167        match self {
168            Self::JsonError(_, err) => err.source(),
169            Self::MsgPackError(_, err) => err.source(),
170            _ => None,
171        }
172    }
173}
174
175impl serde::de::Error for DecoderError {
176    fn custom<T>(msg: T) -> Self
177    where
178        T: std::fmt::Display,
179    {
180        Self::Custom(msg.to_string())
181    }
182}
183
// Generate a `Deserializer` method named `$method` that ignores its visitor
// and always fails with `DecoderError::UnsupportedFormat`.
macro_rules! deserialize {
    ($method:ident) => {
        fn $method<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
        where
            V: de::Visitor<'de>,
        {
            Err(DecoderError::UnsupportedFormat)
        }
    };
}
191
192impl<'de, 'a> de::Deserializer<'de> for &'a mut Decoder<'de> {
193    type Error = DecoderError;
194
195    deserialize!(deserialize_bool);
196    deserialize!(deserialize_i8);
197    deserialize!(deserialize_i16);
198    deserialize!(deserialize_i32);
199    deserialize!(deserialize_i64);
200    deserialize!(deserialize_i128);
201    deserialize!(deserialize_u8);
202    deserialize!(deserialize_u16);
203    deserialize!(deserialize_u32);
204    deserialize!(deserialize_u64);
205    deserialize!(deserialize_u128);
206    deserialize!(deserialize_f32);
207    deserialize!(deserialize_f64);
208    deserialize!(deserialize_char);
209    deserialize!(deserialize_str);
210    deserialize!(deserialize_string);
211    deserialize!(deserialize_bytes);
212    deserialize!(deserialize_byte_buf);
213    deserialize!(deserialize_option);
214    deserialize!(deserialize_unit);
215
216    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
217    where
218        V: de::Visitor<'de>,
219    {
220        visitor.visit_seq(self)
221    }
222
223    fn deserialize_unit_struct<V>(
224        self,
225        _name: &'static str,
226        visitor: V,
227    ) -> Result<V::Value, Self::Error>
228    where
229        V: de::Visitor<'de>,
230    {
231        visitor.visit_unit()
232    }
233
234    fn deserialize_newtype_struct<V>(
235        self,
236        _name: &'static str,
237        visitor: V,
238    ) -> Result<V::Value, Self::Error>
239    where
240        V: de::Visitor<'de>,
241    {
242        visitor.visit_newtype_struct(self)
243    }
244
245    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
246    where
247        V: de::Visitor<'de>,
248    {
249        visitor.visit_seq(self)
250    }
251
252    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
253    where
254        V: de::Visitor<'de>,
255    {
256        visitor.visit_seq(self)
257    }
258
259    fn deserialize_tuple_struct<V>(
260        self,
261        _name: &'static str,
262        _len: usize,
263        visitor: V,
264    ) -> Result<V::Value, Self::Error>
265    where
266        V: de::Visitor<'de>,
267    {
268        visitor.visit_seq(self)
269    }
270
271    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
272    where
273        V: de::Visitor<'de>,
274    {
275        Err(DecoderError::UnsupportedFormat)
276    }
277
278    fn deserialize_struct<V>(
279        self,
280        _name: &'static str,
281        _fields: &'static [&'static str],
282        visitor: V,
283    ) -> Result<V::Value, Self::Error>
284    where
285        V: de::Visitor<'de>,
286    {
287        visitor.visit_seq(self)
288    }
289
290    fn deserialize_enum<V>(
291        self,
292        _name: &'static str,
293        _variants: &'static [&'static str],
294        _visitor: V,
295    ) -> Result<V::Value, Self::Error>
296    where
297        V: de::Visitor<'de>,
298    {
299        Err(DecoderError::UnsupportedFormat)
300    }
301
302    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
303    where
304        V: de::Visitor<'de>,
305    {
306        Err(DecoderError::UnsupportedFormat)
307    }
308
309    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
310    where
311        V: de::Visitor<'de>,
312    {
313        visitor.visit_seq(self)
314    }
315}
316
317impl<'de> de::SeqAccess<'de> for Decoder<'de> {
318    type Error = DecoderError;
319
320    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
321    where
322        T: de::DeserializeSeed<'de>,
323    {
324        match self.params {
325            [] => Ok(None),
326            [param, params @ ..] => {
327                let index = self.index;
328                self.index += 1;
329                self.params = params;
330                match param {
331                    RawValue::MessagePack(bytes) => {
332                        let mut deserializer =
333                            rmp_serde::Deserializer::from_read_ref(bytes.as_slice());
334                        match seed.deserialize(&mut deserializer) {
335                            Ok(value) => Ok(Some(value)),
336                            Err(err) => Err(DecoderError::MsgPackError(index, err)),
337                        }
338                    }
339                    RawValue::Json(bytes) => {
340                        let mut deserializer =
341                            serde_json::Deserializer::from_slice(bytes.as_slice());
342                        match seed.deserialize(&mut deserializer) {
343                            Ok(value) => Ok(Some(value)),
344                            Err(err) => Err(DecoderError::JsonError(index, err)),
345                        }
346                    }
347                }
348            }
349        }
350    }
351
352    fn size_hint(&self) -> Option<usize> {
353        Some(self.params.len())
354    }
355}