soil_rpc/server/middleware/
mod.rs1use std::{
10 num::NonZeroU32,
11 time::{Duration, Instant},
12};
13
14use futures::future::{BoxFuture, FutureExt};
15use governor::{clock::Clock, Jitter};
16use jsonrpsee::{
17 server::middleware::rpc::RpcServiceT,
18 types::{ErrorObject, Id, Request},
19 MethodResponse,
20};
21
22mod metrics;
23mod node_health;
24mod rate_limit;
25
26pub use metrics::*;
27pub use node_health::*;
28pub use rate_limit::*;
29
/// Upper bound of the random jitter added to every rate-limit back-off sleep.
const MAX_JITTER: Duration = Duration::from_millis(50);
/// How many rate-limit back-offs a single call may go through before it is rejected.
const MAX_RETRIES: usize = 10;
32
33#[derive(Debug, Clone, Default)]
35pub struct MiddlewareLayer {
36 rate_limit: Option<RateLimit>,
37 metrics: Option<Metrics>,
38}
39
40impl MiddlewareLayer {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self {
48 Self { rate_limit: Some(RateLimit::per_minute(n)), metrics: self.metrics }
49 }
50
51 pub fn with_metrics(self, metrics: Metrics) -> Self {
53 Self { rate_limit: self.rate_limit, metrics: Some(metrics) }
54 }
55
56 pub fn ws_connect(&self) {
58 self.metrics.as_ref().map(|m| m.ws_connect());
59 }
60
61 pub fn ws_disconnect(&self, now: Instant) {
63 self.metrics.as_ref().map(|m| m.ws_disconnect(now));
64 }
65}
66
67impl<S> tower::Layer<S> for MiddlewareLayer {
68 type Service = Middleware<S>;
69
70 fn layer(&self, service: S) -> Self::Service {
71 Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() }
72 }
73}
74
75pub struct Middleware<S> {
83 service: S,
84 rate_limit: Option<RateLimit>,
85 metrics: Option<Metrics>,
86}
87
88impl<'a, S> RpcServiceT<'a> for Middleware<S>
89where
90 S: Send + Sync + RpcServiceT<'a> + Clone + 'static,
91{
92 type Future = BoxFuture<'a, MethodResponse>;
93
94 fn call(&self, req: Request<'a>) -> Self::Future {
95 let now = Instant::now();
96
97 self.metrics.as_ref().map(|m| m.on_call(&req));
98
99 let service = self.service.clone();
100 let rate_limit = self.rate_limit.clone();
101 let metrics = self.metrics.clone();
102
103 async move {
104 let mut is_rate_limited = false;
105
106 if let Some(limit) = rate_limit.as_ref() {
107 let mut attempts = 0;
108 let jitter = Jitter::up_to(MAX_JITTER);
109
110 loop {
111 if attempts >= MAX_RETRIES {
112 return reject_too_many_calls(req.id);
113 }
114
115 if let Err(rejected) = limit.inner.check() {
116 tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now()))
117 .await;
118 } else {
119 break;
120 }
121
122 is_rate_limited = true;
123 attempts += 1;
124 }
125 }
126
127 let rp = service.call(req.clone()).await;
128 metrics.as_ref().map(|m| m.on_response(&req, &rp, is_rate_limited, now));
129
130 rp
131 }
132 .boxed()
133 }
134}
135
136fn reject_too_many_calls(id: Id) -> MethodResponse {
137 MethodResponse::error(id, ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>))
138}