yajrc/
lib.rs

1use std::{borrow::Cow, marker::PhantomData};
2
3use serde::{
4    de::{MapAccess, SeqAccess, Visitor},
5    Deserialize, Deserializer, Serialize, Serializer,
6};
7use serde_json::{Map, Value};
8
9pub const GENERAL_ERROR: RpcError = RpcError {
10    code: -1,
11    message: Cow::Borrowed("General error"),
12    data: None,
13};
14pub const PARSE_ERROR: RpcError = RpcError {
15    code: -32700,
16    message: Cow::Borrowed("Parse error"),
17    data: None,
18};
19pub const INVALID_REQUEST_ERROR: RpcError = RpcError {
20    code: -32600,
21    message: Cow::Borrowed("Invalid Request"),
22    data: None,
23};
24pub const METHOD_NOT_FOUND_ERROR: RpcError = RpcError {
25    code: -32601,
26    message: Cow::Borrowed("Method not found"),
27    data: None,
28};
29pub const INVALID_PARAMS_ERROR: RpcError = RpcError {
30    code: -32602,
31    message: Cow::Borrowed("Invalid params"),
32    data: None,
33};
34pub const INTERNAL_ERROR: RpcError = RpcError {
35    code: -32603,
36    message: Cow::Borrowed("Internal error"),
37    data: None,
38};
39
40fn deserialize_some<'de, D: Deserializer<'de>, T: Deserialize<'de>>(
41    deserializer: D,
42) -> Result<Option<T>, D::Error> {
43    T::deserialize(deserializer).map(Some)
44}
45
46pub enum SingleOrBatchRpcRequest<T: RpcMethod = AnyRpcMethod<'static>> {
47    Single(RpcRequest<T>),
48    Batch(Vec<RpcRequest<T>>),
49}
50impl<T> Serialize for SingleOrBatchRpcRequest<T>
51where
52    T: RpcMethod,
53    RpcRequest<T>: Serialize,
54{
55    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
56        match self {
57            SingleOrBatchRpcRequest::Single(s) => s.serialize(serializer),
58            SingleOrBatchRpcRequest::Batch(b) => b.serialize(serializer),
59        }
60    }
61}
62impl<'de, T> Deserialize<'de> for SingleOrBatchRpcRequest<T>
63where
64    T: RpcMethod + Deserialize<'de>,
65    T::Params: Deserialize<'de>,
66{
67    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
68        struct ReqVisitor<T>(PhantomData<T>);
69        impl<'de, T> Visitor<'de> for ReqVisitor<T>
70        where
71            T: RpcMethod + Deserialize<'de>,
72            T::Params: Deserialize<'de>,
73        {
74            type Value = SingleOrBatchRpcRequest<T>;
75            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
76                write!(
77                    formatter,
78                    "a single rpc request, or a batch of rpc requests"
79                )
80            }
81            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
82                let mut res = Vec::with_capacity(seq.size_hint().unwrap_or(16));
83                while let Some(elem) = seq.next_element()? {
84                    res.push(elem);
85                }
86                Ok(SingleOrBatchRpcRequest::Batch(res))
87            }
88            fn visit_map<A: serde::de::MapAccess<'de>>(
89                self,
90                mut map: A,
91            ) -> Result<Self::Value, A::Error> {
92                let mut id = None;
93                let mut method = None;
94                let mut params = None;
95                while let Some(key) = map.next_key::<String>()? {
96                    match key.as_str() {
97                        "id" => {
98                            id = map.next_value()?;
99                        }
100                        "method" => {
101                            method = map.next_value()?;
102                        }
103                        "params" => {
104                            params = map.next_value()?;
105                        }
106                        _ => {
107                            let _: Value = map.next_value()?;
108                        }
109                    }
110                }
111                Ok(SingleOrBatchRpcRequest::Single(RpcRequest {
112                    id,
113                    method: method.ok_or_else(|| serde::de::Error::missing_field("method"))?,
114                    params: params.ok_or_else(|| serde::de::Error::missing_field("params"))?,
115                }))
116            }
117        }
118        deserializer.deserialize_any(ReqVisitor::<T>(PhantomData))
119    }
120}
121
/// An RPC method: its wire name plus the types of its params and result.
pub trait RpcMethod {
    /// Type of a request's `params` member.
    type Params;
    /// Type of a successful response's `result` member.
    type Response;
    /// Returns the method name.
    fn as_str(&self) -> &str;
}
127
128pub struct GenericRpcMethod<Method: AsRef<str>, Params = AnyParams, Response = Value> {
129    method: Method,
130    params: PhantomData<Params>,
131    response: PhantomData<Response>,
132}
133impl<Method: AsRef<str> + Clone, Params, Response> Clone
134    for GenericRpcMethod<Method, Params, Response>
135{
136    fn clone(&self) -> Self {
137        Self {
138            method: self.method.clone(),
139            params: PhantomData,
140            response: PhantomData,
141        }
142    }
143}
144impl<Method: AsRef<str>, Params, Response> std::fmt::Debug
145    for GenericRpcMethod<Method, Params, Response>
146{
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        write!(f, "{:?}", self.method.as_ref())
149    }
150}
151impl<Method: AsRef<str>, Params, Response> GenericRpcMethod<Method, Params, Response> {
152    pub fn new(method: Method) -> Self {
153        GenericRpcMethod {
154            method,
155            params: PhantomData,
156            response: PhantomData,
157        }
158    }
159}
160impl<Method: AsRef<str>, Params, Response> Serialize
161    for GenericRpcMethod<Method, Params, Response>
162{
163    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
164    where
165        S: Serializer,
166    {
167        Serialize::serialize(self.method.as_ref(), serializer)
168    }
169}
170impl<'de, Method: AsRef<str> + Deserialize<'de>, Params, Response> Deserialize<'de>
171    for GenericRpcMethod<Method, Params, Response>
172{
173    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174    where
175        D: Deserializer<'de>,
176    {
177        Ok(GenericRpcMethod::new(Deserialize::deserialize(
178            deserializer,
179        )?))
180    }
181}
182impl<Method: AsRef<str>, Params, Response> RpcMethod
183    for GenericRpcMethod<Method, Params, Response>
184{
185    type Params = Params;
186    type Response = Response;
187    fn as_str<'a>(&'a self) -> &'a str {
188        self.method.as_ref()
189    }
190}
191
192pub type AnyRpcMethod<'a> = GenericRpcMethod<Cow<'a, str>>;
193
194#[derive(Debug)]
195pub enum AnyParams {
196    Positional(Vec<Value>),
197    Named(Map<String, Value>),
198}
199impl Serialize for AnyParams {
200    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
201        match self {
202            AnyParams::Positional(s) => s.serialize(serializer),
203            AnyParams::Named(b) => b.serialize(serializer),
204        }
205    }
206}
207impl<'de> Deserialize<'de> for AnyParams {
208    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
209        struct ParamVisitor;
210        impl<'de> Visitor<'de> for ParamVisitor {
211            type Value = AnyParams;
212            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
213                write!(formatter, "an array or object")
214            }
215            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
216                let mut res = Vec::with_capacity(seq.size_hint().unwrap_or(16));
217                while let Some(elem) = seq.next_element()? {
218                    res.push(elem);
219                }
220                Ok(AnyParams::Positional(res))
221            }
222            fn visit_map<A: serde::de::MapAccess<'de>>(
223                self,
224                mut map: A,
225            ) -> Result<Self::Value, A::Error> {
226                let mut res = Map::new();
227
228                while let Some((key, value)) = map.next_entry()? {
229                    res.insert(key, value);
230                }
231                Ok(AnyParams::Named(res))
232            }
233        }
234        deserializer.deserialize_any(ParamVisitor)
235    }
236}
237
238#[derive(Debug, Clone, PartialEq, Eq, Hash)]
239pub enum Id {
240    Null,
241    String(String),
242    Number(serde_json::Number),
243}
244impl Serialize for Id {
245    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
246    where
247        S: Serializer,
248    {
249        match self {
250            Id::Null => serializer.serialize_none(),
251            Id::String(s) => serializer.serialize_str(s),
252            Id::Number(n) => {
253                #[cfg(feature = "strict")]
254                if !n.is_i64() {
255                    return Err(serde::ser::Error::custom(
256                        "Numbers SHOULD NOT contain fractional parts",
257                    ));
258                }
259                serde_json::Number::serialize(n, serializer)
260            }
261        }
262    }
263}
264impl<'de> Deserialize<'de> for Id {
265    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
266    where
267        D: Deserializer<'de>,
268    {
269        struct IdVisitor;
270        impl<'de> Visitor<'de> for IdVisitor {
271            type Value = Id;
272            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
273                write!(formatter, "a String, Number, or NULL value")
274            }
275            fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
276                Ok(Id::Null)
277            }
278            fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
279                Ok(Id::Null)
280            }
281            fn visit_string<E: serde::de::Error>(self, v: String) -> Result<Self::Value, E> {
282                Ok(Id::String(v))
283            }
284            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
285                Ok(Id::String(v.to_string()))
286            }
287            fn visit_f32<E: serde::de::Error>(self, v: f32) -> Result<Self::Value, E> {
288                #[cfg(feature = "strict")]
289                if v != v.trunc() {
290                    return Err(serde::de::Error::custom(
291                        "Numbers SHOULD NOT contain fractional parts",
292                    ));
293                }
294                Ok(Id::Number(
295                    serde_json::Number::from_f64(v as f64).ok_or_else(|| {
296                        serde::de::Error::custom("Infinite or NaN values are not JSON numbers")
297                    })?,
298                ))
299            }
300            fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
301                #[cfg(feature = "strict")]
302                if v != v.trunc() {
303                    return Err(serde::de::Error::custom(
304                        "Numbers SHOULD NOT contain fractional parts",
305                    ));
306                }
307                Ok(Id::Number(serde_json::Number::from_f64(v).ok_or_else(
308                    || serde::de::Error::custom("Infinite or NaN values are not JSON numbers"),
309                )?))
310            }
311            fn visit_i8<E: serde::de::Error>(self, v: i8) -> Result<Self::Value, E> {
312                Ok(Id::Number(v.into()))
313            }
314            fn visit_i16<E: serde::de::Error>(self, v: i16) -> Result<Self::Value, E> {
315                Ok(Id::Number(v.into()))
316            }
317            fn visit_i32<E: serde::de::Error>(self, v: i32) -> Result<Self::Value, E> {
318                Ok(Id::Number(v.into()))
319            }
320            fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
321                Ok(Id::Number(v.into()))
322            }
323            fn visit_u8<E: serde::de::Error>(self, v: u8) -> Result<Self::Value, E> {
324                Ok(Id::Number(v.into()))
325            }
326            fn visit_u16<E: serde::de::Error>(self, v: u16) -> Result<Self::Value, E> {
327                Ok(Id::Number(v.into()))
328            }
329            fn visit_u32<E: serde::de::Error>(self, v: u32) -> Result<Self::Value, E> {
330                Ok(Id::Number(v.into()))
331            }
332            fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
333                Ok(Id::Number(v.into()))
334            }
335        }
336        deserializer.deserialize_any(IdVisitor)
337    }
338}
339impl PartialOrd for Id {
340    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
341        use std::cmp::Ordering::*;
342        match (self, other) {
343            (Id::Null, Id::Null) => Some(Equal),
344            (Id::Null, _) => Some(Less),
345            (_, Id::Null) => Some(Greater),
346            (Id::String(a), Id::String(b)) => a.partial_cmp(b),
347            (Id::String(_), _) => Some(Less),
348            (_, Id::String(_)) => Some(Greater),
349            (Id::Number(a), Id::Number(b)) => a.as_f64().partial_cmp(&b.as_f64()),
350        }
351    }
352}
353impl Ord for Id {
354    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
355        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
356    }
357}
358
359#[derive(Debug, Clone, Default, Serialize, Deserialize)]
360pub struct RpcRequest<T: RpcMethod = AnyRpcMethod<'static>> {
361    #[serde(default)]
362    #[serde(skip_serializing_if = "Option::is_none")]
363    #[serde(deserialize_with = "deserialize_some")]
364    pub id: Option<Id>,
365    pub method: T,
366    pub params: T::Params,
367}
368
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct RpcError {
371    pub code: i32,
372    pub message: Cow<'static, str>,
373    #[serde(default)]
374    #[serde(skip_serializing_if = "Option::is_none")]
375    #[serde(deserialize_with = "deserialize_some")]
376    pub data: Option<Value>,
377}
378impl RpcError {
379    pub fn into_anyhow(self) -> anyhow::Error {
380        anyhow::anyhow!(self)
381    }
382}
383impl std::fmt::Display for RpcError {
384    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
385        write!(f, "rpc exited with code {}: {}", self.code, self.message)?;
386        if let Some(data) = &self.data {
387            write!(f, "\n{:#?}", data)
388        } else {
389            Ok(())
390        }
391    }
392}
393impl<E> From<E> for RpcError
394where
395    E: Into<anyhow::Error>,
396{
397    fn from(err: E) -> Self {
398        let err = err.into();
399        let mut res = if let Some(json_err) = err.downcast_ref::<serde_json::Error>() {
400            if json_err.is_syntax() {
401                PARSE_ERROR
402            } else {
403                INVALID_REQUEST_ERROR
404            }
405        } else {
406            GENERAL_ERROR
407        };
408        res.data = Some(Value::String(format!("{}", err)));
409        res
410    }
411}
412
413#[derive(Debug, Clone)]
414pub struct RpcResponse<T: RpcMethod = AnyRpcMethod<'static>> {
415    pub id: Option<Id>,
416    pub result: Result<T::Response, RpcError>,
417}
418impl<Method: RpcMethod> From<RpcError> for RpcResponse<Method> {
419    fn from(e: RpcError) -> Self {
420        RpcResponse {
421            id: None,
422            result: Err(e),
423        }
424    }
425}
426impl<T> RpcResponse<T>
427where
428    T: RpcMethod,
429{
430    pub fn from_result<E: Into<RpcError>>(res: Result<T::Response, E>) -> Self {
431        RpcResponse {
432            id: None,
433            result: res.map_err(|e| e.into()),
434        }
435    }
436}
437impl<T> Serialize for RpcResponse<T>
438where
439    T: RpcMethod,
440    T::Response: Serialize,
441{
442    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
443    where
444        S: Serializer,
445    {
446        use serde::ser::SerializeMap;
447
448        let mut map_ser = serializer.serialize_map(Some(3))?;
449        map_ser.serialize_entry("jsonrpc", "2.0")?;
450        match &self.result {
451            Ok(a) => {
452                map_ser.serialize_entry("result", a)?;
453            }
454            Err(e) => {
455                map_ser.serialize_entry("error", e)?;
456            }
457        }
458        map_ser.serialize_entry("id", &self.id)?;
459        map_ser.end()
460    }
461}
462impl<'de, T> Deserialize<'de> for RpcResponse<T>
463where
464    T: RpcMethod,
465    T::Response: Deserialize<'de>,
466{
467    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
468    where
469        D: Deserializer<'de>,
470    {
471        struct ResponseVisitor<T>(PhantomData<T>);
472        impl<'de, T> Visitor<'de> for ResponseVisitor<T>
473        where
474            T: RpcMethod,
475            T::Response: Deserialize<'de>,
476        {
477            type Value = RpcResponse<T>;
478            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
479                write!(formatter, "a String, Number, or NULL value")
480            }
481            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
482            where
483                A: MapAccess<'de>,
484            {
485                let mut verison = None;
486                let mut id = None;
487                let mut result = None;
488                let mut error = None;
489                while let Some(k) = map.next_key::<String>()? {
490                    match k.as_str() {
491                        "jsonrpc" => {
492                            verison = Some(map.next_value::<String>()?);
493                        }
494                        "id" => {
495                            id = Some(map.next_value()?);
496                        }
497                        "result" => {
498                            result = Some(map.next_value()?);
499                        }
500                        "error" => {
501                            error = Some(map.next_value()?);
502                        }
503                        _ => {
504                            let _: Value = map.next_value()?;
505                        }
506                    }
507                }
508                match verison {
509                    Some(v) if v == "2.0" => (),
510                    Some(v) => {
511                        return Err(serde::de::Error::invalid_value(
512                            serde::de::Unexpected::Str(v.as_str()),
513                            &"2.0",
514                        ))
515                    }
516                    None => return Err(serde::de::Error::missing_field("jsonrpc")),
517                }
518                Ok(RpcResponse {
519                    id: id.ok_or_else(|| serde::de::Error::missing_field("id"))?,
520                    result: match (result, error) {
521                        (None, None) => return Err(serde::de::Error::missing_field("result OR error")),
522                        (None, Some(e)) => Err(e),
523                        (Some(a), None) => Ok(a),
524                        (Some(_), Some(_)) => return Err(serde::de::Error::custom("Either the result member or error member MUST be included, but both members MUST NOT be included.")),
525                    }
526                })
527            }
528        }
529        deserializer.deserialize_map(ResponseVisitor(PhantomData))
530    }
531}