1use std::{borrow::Cow, marker::PhantomData};
2
3use serde::{
4 de::{MapAccess, SeqAccess, Visitor},
5 Deserialize, Deserializer, Serialize, Serializer,
6};
7use serde_json::{Map, Value};
8
9pub const GENERAL_ERROR: RpcError = RpcError {
10 code: -1,
11 message: Cow::Borrowed("General error"),
12 data: None,
13};
14pub const PARSE_ERROR: RpcError = RpcError {
15 code: -32700,
16 message: Cow::Borrowed("Parse error"),
17 data: None,
18};
19pub const INVALID_REQUEST_ERROR: RpcError = RpcError {
20 code: -32600,
21 message: Cow::Borrowed("Invalid Request"),
22 data: None,
23};
24pub const METHOD_NOT_FOUND_ERROR: RpcError = RpcError {
25 code: -32601,
26 message: Cow::Borrowed("Method not found"),
27 data: None,
28};
29pub const INVALID_PARAMS_ERROR: RpcError = RpcError {
30 code: -32602,
31 message: Cow::Borrowed("Invalid params"),
32 data: None,
33};
34pub const INTERNAL_ERROR: RpcError = RpcError {
35 code: -32603,
36 message: Cow::Borrowed("Internal error"),
37 data: None,
38};
39
40fn deserialize_some<'de, D: Deserializer<'de>, T: Deserialize<'de>>(
41 deserializer: D,
42) -> Result<Option<T>, D::Error> {
43 T::deserialize(deserializer).map(Some)
44}
45
46pub enum SingleOrBatchRpcRequest<T: RpcMethod = AnyRpcMethod<'static>> {
47 Single(RpcRequest<T>),
48 Batch(Vec<RpcRequest<T>>),
49}
50impl<T> Serialize for SingleOrBatchRpcRequest<T>
51where
52 T: RpcMethod,
53 RpcRequest<T>: Serialize,
54{
55 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
56 match self {
57 SingleOrBatchRpcRequest::Single(s) => s.serialize(serializer),
58 SingleOrBatchRpcRequest::Batch(b) => b.serialize(serializer),
59 }
60 }
61}
62impl<'de, T> Deserialize<'de> for SingleOrBatchRpcRequest<T>
63where
64 T: RpcMethod + Deserialize<'de>,
65 T::Params: Deserialize<'de>,
66{
67 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
68 struct ReqVisitor<T>(PhantomData<T>);
69 impl<'de, T> Visitor<'de> for ReqVisitor<T>
70 where
71 T: RpcMethod + Deserialize<'de>,
72 T::Params: Deserialize<'de>,
73 {
74 type Value = SingleOrBatchRpcRequest<T>;
75 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
76 write!(
77 formatter,
78 "a single rpc request, or a batch of rpc requests"
79 )
80 }
81 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
82 let mut res = Vec::with_capacity(seq.size_hint().unwrap_or(16));
83 while let Some(elem) = seq.next_element()? {
84 res.push(elem);
85 }
86 Ok(SingleOrBatchRpcRequest::Batch(res))
87 }
88 fn visit_map<A: serde::de::MapAccess<'de>>(
89 self,
90 mut map: A,
91 ) -> Result<Self::Value, A::Error> {
92 let mut id = None;
93 let mut method = None;
94 let mut params = None;
95 while let Some(key) = map.next_key::<String>()? {
96 match key.as_str() {
97 "id" => {
98 id = map.next_value()?;
99 }
100 "method" => {
101 method = map.next_value()?;
102 }
103 "params" => {
104 params = map.next_value()?;
105 }
106 _ => {
107 let _: Value = map.next_value()?;
108 }
109 }
110 }
111 Ok(SingleOrBatchRpcRequest::Single(RpcRequest {
112 id,
113 method: method.ok_or_else(|| serde::de::Error::missing_field("method"))?,
114 params: params.ok_or_else(|| serde::de::Error::missing_field("params"))?,
115 }))
116 }
117 }
118 deserializer.deserialize_any(ReqVisitor::<T>(PhantomData))
119 }
120}
121
/// Describes an RPC method: the name it is called by plus its parameter and response types.
pub trait RpcMethod {
    /// Type carried in a request's `params` member.
    type Params;
    /// Type carried in a successful response's `result` member.
    type Response;
    /// The method's name as a string slice borrowed from `self`.
    fn as_str(&self) -> &str;
}
127
128pub struct GenericRpcMethod<Method: AsRef<str>, Params = AnyParams, Response = Value> {
129 method: Method,
130 params: PhantomData<Params>,
131 response: PhantomData<Response>,
132}
133impl<Method: AsRef<str> + Clone, Params, Response> Clone
134 for GenericRpcMethod<Method, Params, Response>
135{
136 fn clone(&self) -> Self {
137 Self {
138 method: self.method.clone(),
139 params: PhantomData,
140 response: PhantomData,
141 }
142 }
143}
144impl<Method: AsRef<str>, Params, Response> std::fmt::Debug
145 for GenericRpcMethod<Method, Params, Response>
146{
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 write!(f, "{:?}", self.method.as_ref())
149 }
150}
151impl<Method: AsRef<str>, Params, Response> GenericRpcMethod<Method, Params, Response> {
152 pub fn new(method: Method) -> Self {
153 GenericRpcMethod {
154 method,
155 params: PhantomData,
156 response: PhantomData,
157 }
158 }
159}
160impl<Method: AsRef<str>, Params, Response> Serialize
161 for GenericRpcMethod<Method, Params, Response>
162{
163 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
164 where
165 S: Serializer,
166 {
167 Serialize::serialize(self.method.as_ref(), serializer)
168 }
169}
170impl<'de, Method: AsRef<str> + Deserialize<'de>, Params, Response> Deserialize<'de>
171 for GenericRpcMethod<Method, Params, Response>
172{
173 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174 where
175 D: Deserializer<'de>,
176 {
177 Ok(GenericRpcMethod::new(Deserialize::deserialize(
178 deserializer,
179 )?))
180 }
181}
182impl<Method: AsRef<str>, Params, Response> RpcMethod
183 for GenericRpcMethod<Method, Params, Response>
184{
185 type Params = Params;
186 type Response = Response;
187 fn as_str<'a>(&'a self) -> &'a str {
188 self.method.as_ref()
189 }
190}
191
192pub type AnyRpcMethod<'a> = GenericRpcMethod<Cow<'a, str>>;
193
194#[derive(Debug)]
195pub enum AnyParams {
196 Positional(Vec<Value>),
197 Named(Map<String, Value>),
198}
199impl Serialize for AnyParams {
200 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
201 match self {
202 AnyParams::Positional(s) => s.serialize(serializer),
203 AnyParams::Named(b) => b.serialize(serializer),
204 }
205 }
206}
207impl<'de> Deserialize<'de> for AnyParams {
208 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
209 struct ParamVisitor;
210 impl<'de> Visitor<'de> for ParamVisitor {
211 type Value = AnyParams;
212 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
213 write!(formatter, "an array or object")
214 }
215 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
216 let mut res = Vec::with_capacity(seq.size_hint().unwrap_or(16));
217 while let Some(elem) = seq.next_element()? {
218 res.push(elem);
219 }
220 Ok(AnyParams::Positional(res))
221 }
222 fn visit_map<A: serde::de::MapAccess<'de>>(
223 self,
224 mut map: A,
225 ) -> Result<Self::Value, A::Error> {
226 let mut res = Map::new();
227
228 while let Some((key, value)) = map.next_entry()? {
229 res.insert(key, value);
230 }
231 Ok(AnyParams::Named(res))
232 }
233 }
234 deserializer.deserialize_any(ParamVisitor)
235 }
236}
237
238#[derive(Debug, Clone, PartialEq, Eq, Hash)]
239pub enum Id {
240 Null,
241 String(String),
242 Number(serde_json::Number),
243}
244impl Serialize for Id {
245 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
246 where
247 S: Serializer,
248 {
249 match self {
250 Id::Null => serializer.serialize_none(),
251 Id::String(s) => serializer.serialize_str(s),
252 Id::Number(n) => {
253 #[cfg(feature = "strict")]
254 if !n.is_i64() {
255 return Err(serde::ser::Error::custom(
256 "Numbers SHOULD NOT contain fractional parts",
257 ));
258 }
259 serde_json::Number::serialize(n, serializer)
260 }
261 }
262 }
263}
264impl<'de> Deserialize<'de> for Id {
265 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
266 where
267 D: Deserializer<'de>,
268 {
269 struct IdVisitor;
270 impl<'de> Visitor<'de> for IdVisitor {
271 type Value = Id;
272 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
273 write!(formatter, "a String, Number, or NULL value")
274 }
275 fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
276 Ok(Id::Null)
277 }
278 fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
279 Ok(Id::Null)
280 }
281 fn visit_string<E: serde::de::Error>(self, v: String) -> Result<Self::Value, E> {
282 Ok(Id::String(v))
283 }
284 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
285 Ok(Id::String(v.to_string()))
286 }
287 fn visit_f32<E: serde::de::Error>(self, v: f32) -> Result<Self::Value, E> {
288 #[cfg(feature = "strict")]
289 if v != v.trunc() {
290 return Err(serde::de::Error::custom(
291 "Numbers SHOULD NOT contain fractional parts",
292 ));
293 }
294 Ok(Id::Number(
295 serde_json::Number::from_f64(v as f64).ok_or_else(|| {
296 serde::de::Error::custom("Infinite or NaN values are not JSON numbers")
297 })?,
298 ))
299 }
300 fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
301 #[cfg(feature = "strict")]
302 if v != v.trunc() {
303 return Err(serde::de::Error::custom(
304 "Numbers SHOULD NOT contain fractional parts",
305 ));
306 }
307 Ok(Id::Number(serde_json::Number::from_f64(v).ok_or_else(
308 || serde::de::Error::custom("Infinite or NaN values are not JSON numbers"),
309 )?))
310 }
311 fn visit_i8<E: serde::de::Error>(self, v: i8) -> Result<Self::Value, E> {
312 Ok(Id::Number(v.into()))
313 }
314 fn visit_i16<E: serde::de::Error>(self, v: i16) -> Result<Self::Value, E> {
315 Ok(Id::Number(v.into()))
316 }
317 fn visit_i32<E: serde::de::Error>(self, v: i32) -> Result<Self::Value, E> {
318 Ok(Id::Number(v.into()))
319 }
320 fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
321 Ok(Id::Number(v.into()))
322 }
323 fn visit_u8<E: serde::de::Error>(self, v: u8) -> Result<Self::Value, E> {
324 Ok(Id::Number(v.into()))
325 }
326 fn visit_u16<E: serde::de::Error>(self, v: u16) -> Result<Self::Value, E> {
327 Ok(Id::Number(v.into()))
328 }
329 fn visit_u32<E: serde::de::Error>(self, v: u32) -> Result<Self::Value, E> {
330 Ok(Id::Number(v.into()))
331 }
332 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
333 Ok(Id::Number(v.into()))
334 }
335 }
336 deserializer.deserialize_any(IdVisitor)
337 }
338}
339impl PartialOrd for Id {
340 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
341 use std::cmp::Ordering::*;
342 match (self, other) {
343 (Id::Null, Id::Null) => Some(Equal),
344 (Id::Null, _) => Some(Less),
345 (_, Id::Null) => Some(Greater),
346 (Id::String(a), Id::String(b)) => a.partial_cmp(b),
347 (Id::String(_), _) => Some(Less),
348 (_, Id::String(_)) => Some(Greater),
349 (Id::Number(a), Id::Number(b)) => a.as_f64().partial_cmp(&b.as_f64()),
350 }
351 }
352}
353impl Ord for Id {
354 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
355 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
356 }
357}
358
359#[derive(Debug, Clone, Default, Serialize, Deserialize)]
360pub struct RpcRequest<T: RpcMethod = AnyRpcMethod<'static>> {
361 #[serde(default)]
362 #[serde(skip_serializing_if = "Option::is_none")]
363 #[serde(deserialize_with = "deserialize_some")]
364 pub id: Option<Id>,
365 pub method: T,
366 pub params: T::Params,
367}
368
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct RpcError {
371 pub code: i32,
372 pub message: Cow<'static, str>,
373 #[serde(default)]
374 #[serde(skip_serializing_if = "Option::is_none")]
375 #[serde(deserialize_with = "deserialize_some")]
376 pub data: Option<Value>,
377}
378impl RpcError {
379 pub fn into_anyhow(self) -> anyhow::Error {
380 anyhow::anyhow!(self)
381 }
382}
383impl std::fmt::Display for RpcError {
384 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
385 write!(f, "rpc exited with code {}: {}", self.code, self.message)?;
386 if let Some(data) = &self.data {
387 write!(f, "\n{:#?}", data)
388 } else {
389 Ok(())
390 }
391 }
392}
393impl<E> From<E> for RpcError
394where
395 E: Into<anyhow::Error>,
396{
397 fn from(err: E) -> Self {
398 let err = err.into();
399 let mut res = if let Some(json_err) = err.downcast_ref::<serde_json::Error>() {
400 if json_err.is_syntax() {
401 PARSE_ERROR
402 } else {
403 INVALID_REQUEST_ERROR
404 }
405 } else {
406 GENERAL_ERROR
407 };
408 res.data = Some(Value::String(format!("{}", err)));
409 res
410 }
411}
412
413#[derive(Debug, Clone)]
414pub struct RpcResponse<T: RpcMethod = AnyRpcMethod<'static>> {
415 pub id: Option<Id>,
416 pub result: Result<T::Response, RpcError>,
417}
418impl<Method: RpcMethod> From<RpcError> for RpcResponse<Method> {
419 fn from(e: RpcError) -> Self {
420 RpcResponse {
421 id: None,
422 result: Err(e),
423 }
424 }
425}
426impl<T> RpcResponse<T>
427where
428 T: RpcMethod,
429{
430 pub fn from_result<E: Into<RpcError>>(res: Result<T::Response, E>) -> Self {
431 RpcResponse {
432 id: None,
433 result: res.map_err(|e| e.into()),
434 }
435 }
436}
437impl<T> Serialize for RpcResponse<T>
438where
439 T: RpcMethod,
440 T::Response: Serialize,
441{
442 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
443 where
444 S: Serializer,
445 {
446 use serde::ser::SerializeMap;
447
448 let mut map_ser = serializer.serialize_map(Some(3))?;
449 map_ser.serialize_entry("jsonrpc", "2.0")?;
450 match &self.result {
451 Ok(a) => {
452 map_ser.serialize_entry("result", a)?;
453 }
454 Err(e) => {
455 map_ser.serialize_entry("error", e)?;
456 }
457 }
458 map_ser.serialize_entry("id", &self.id)?;
459 map_ser.end()
460 }
461}
462impl<'de, T> Deserialize<'de> for RpcResponse<T>
463where
464 T: RpcMethod,
465 T::Response: Deserialize<'de>,
466{
467 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
468 where
469 D: Deserializer<'de>,
470 {
471 struct ResponseVisitor<T>(PhantomData<T>);
472 impl<'de, T> Visitor<'de> for ResponseVisitor<T>
473 where
474 T: RpcMethod,
475 T::Response: Deserialize<'de>,
476 {
477 type Value = RpcResponse<T>;
478 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
479 write!(formatter, "a String, Number, or NULL value")
480 }
481 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
482 where
483 A: MapAccess<'de>,
484 {
485 let mut verison = None;
486 let mut id = None;
487 let mut result = None;
488 let mut error = None;
489 while let Some(k) = map.next_key::<String>()? {
490 match k.as_str() {
491 "jsonrpc" => {
492 verison = Some(map.next_value::<String>()?);
493 }
494 "id" => {
495 id = Some(map.next_value()?);
496 }
497 "result" => {
498 result = Some(map.next_value()?);
499 }
500 "error" => {
501 error = Some(map.next_value()?);
502 }
503 _ => {
504 let _: Value = map.next_value()?;
505 }
506 }
507 }
508 match verison {
509 Some(v) if v == "2.0" => (),
510 Some(v) => {
511 return Err(serde::de::Error::invalid_value(
512 serde::de::Unexpected::Str(v.as_str()),
513 &"2.0",
514 ))
515 }
516 None => return Err(serde::de::Error::missing_field("jsonrpc")),
517 }
518 Ok(RpcResponse {
519 id: id.ok_or_else(|| serde::de::Error::missing_field("id"))?,
520 result: match (result, error) {
521 (None, None) => return Err(serde::de::Error::missing_field("result OR error")),
522 (None, Some(e)) => Err(e),
523 (Some(a), None) => Ok(a),
524 (Some(_), Some(_)) => return Err(serde::de::Error::custom("Either the result member or error member MUST be included, but both members MUST NOT be included.")),
525 }
526 })
527 }
528 }
529 deserializer.deserialize_map(ResponseVisitor(PhantomData))
530 }
531}