tonic_debug/
connection.rs1use std::{
8 fmt,
9 future::Future,
10 net::SocketAddr,
11 pin::Pin,
12 sync::{
13 atomic::{AtomicU64, Ordering},
14 Arc,
15 },
16 task::{Context, Poll},
17};
18
19use tower_layer::Layer;
20use tower_service::Service;
21
22#[derive(Debug, Clone)]
24pub struct ConnectionMetrics {
25 inner: Arc<ConnectionMetricsInner>,
26}
27
/// Counter storage shared by every clone of a `ConnectionMetrics` handle.
#[derive(Debug)]
struct ConnectionMetricsInner {
    /// Connections recorded since creation; only ever incremented.
    total_connections: AtomicU64,
    /// Connections recorded as opened and not yet closed.
    active_connections: AtomicU64,
    /// Connection errors recorded.
    connection_errors: AtomicU64,
}
37
38impl ConnectionMetrics {
39 pub fn new() -> Self {
41 Self {
42 inner: Arc::new(ConnectionMetricsInner {
43 total_connections: AtomicU64::new(0),
44 active_connections: AtomicU64::new(0),
45 connection_errors: AtomicU64::new(0),
46 }),
47 }
48 }
49
50 pub fn total_connections(&self) -> u64 {
52 self.inner.total_connections.load(Ordering::Relaxed)
53 }
54
55 pub fn active_connections(&self) -> u64 {
57 self.inner.active_connections.load(Ordering::Relaxed)
58 }
59
60 pub fn connection_errors(&self) -> u64 {
62 self.inner.connection_errors.load(Ordering::Relaxed)
63 }
64
65 fn on_connect(&self) {
66 self.inner.total_connections.fetch_add(1, Ordering::Relaxed);
67 self.inner
68 .active_connections
69 .fetch_add(1, Ordering::Relaxed);
70 }
71
72 fn on_disconnect(&self) {
73 self.inner
74 .active_connections
75 .fetch_sub(1, Ordering::Relaxed);
76 }
77
78 fn on_error(&self) {
79 self.inner.connection_errors.fetch_add(1, Ordering::Relaxed);
80 }
81}
82
83impl Default for ConnectionMetrics {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl fmt::Display for ConnectionMetrics {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 write!(
92 f,
93 "connections(total={}, active={}, errors={})",
94 self.total_connections(),
95 self.active_connections(),
96 self.connection_errors()
97 )
98 }
99}
100
101#[derive(Debug, Clone)]
107pub struct ConnectionTrackerLayer {
108 metrics: ConnectionMetrics,
109}
110
111impl ConnectionTrackerLayer {
112 pub fn new() -> Self {
114 Self {
115 metrics: ConnectionMetrics::new(),
116 }
117 }
118
119 pub fn with_metrics(metrics: ConnectionMetrics) -> Self {
121 Self { metrics }
122 }
123
124 pub fn metrics(&self) -> &ConnectionMetrics {
126 &self.metrics
127 }
128}
129
130impl Default for ConnectionTrackerLayer {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl<S> Layer<S> for ConnectionTrackerLayer {
137 type Service = ConnectionTrackerService<S>;
138
139 fn layer(&self, inner: S) -> Self::Service {
140 ConnectionTrackerService {
141 inner,
142 metrics: self.metrics.clone(),
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
152pub struct ConnectionTrackerService<S> {
153 inner: S,
154 metrics: ConnectionMetrics,
155}
156
157impl<S> ConnectionTrackerService<S> {
158 pub fn metrics(&self) -> &ConnectionMetrics {
160 &self.metrics
161 }
162}
163
164impl<S, Target> Service<Target> for ConnectionTrackerService<S>
165where
166 S: Service<Target> + Clone + Send + 'static,
167 S::Response: Send + 'static,
168 S::Error: fmt::Display + Send + 'static,
169 S::Future: Send + 'static,
170 Target: fmt::Debug + Send + 'static,
171{
172 type Response = S::Response;
173 type Error = S::Error;
174 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
175
176 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.inner.poll_ready(cx)
178 }
179
180 fn call(&mut self, target: Target) -> Self::Future {
181 let metrics = self.metrics.clone();
182 let mut inner = self.inner.clone();
183 std::mem::swap(&mut self.inner, &mut inner);
184
185 metrics.on_connect();
186
187 tracing::info!(
188 peer = ?target,
189 active_connections = metrics.active_connections(),
190 total_connections = metrics.total_connections(),
191 "⚡ New connection established"
192 );
193
194 Box::pin(async move {
195 let result = inner.call(target).await;
196 match &result {
197 Ok(_) => {
198 metrics.on_disconnect();
199 tracing::info!(
200 active_connections = metrics.active_connections(),
201 "🔌 Connection closed"
202 );
203 }
204 Err(e) => {
205 metrics.on_error();
206 metrics.on_disconnect();
207 tracing::error!(
208 error = %e,
209 active_connections = metrics.active_connections(),
210 connection_errors = metrics.connection_errors(),
211 "❌ Connection error"
212 );
213 }
214 }
215 result
216 })
217 }
218}
219
220#[derive(Debug)]
225pub struct ConnectionGuard {
226 metrics: ConnectionMetrics,
227 peer: Option<SocketAddr>,
228}
229
230impl ConnectionGuard {
231 pub fn new(metrics: ConnectionMetrics, peer: Option<SocketAddr>) -> Self {
233 metrics.on_connect();
234 tracing::info!(
235 peer = ?peer,
236 active_connections = metrics.active_connections(),
237 total_connections = metrics.total_connections(),
238 "⚡ New connection established"
239 );
240 Self { metrics, peer }
241 }
242}
243
244impl Drop for ConnectionGuard {
245 fn drop(&mut self) {
246 self.metrics.on_disconnect();
247 tracing::info!(
248 peer = ?self.peer,
249 active_connections = self.metrics.active_connections(),
250 "🔌 Connection closed"
251 );
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of a metrics handle as `(total, active, errors)`.
    fn counts(metrics: &ConnectionMetrics) -> (u64, u64, u64) {
        (
            metrics.total_connections(),
            metrics.active_connections(),
            metrics.connection_errors(),
        )
    }

    #[test]
    fn test_connection_metrics() {
        let metrics = ConnectionMetrics::new();
        assert_eq!(counts(&metrics), (0, 0, 0));

        metrics.on_connect();
        assert_eq!(counts(&metrics), (1, 1, 0));

        metrics.on_connect();
        assert_eq!(counts(&metrics), (2, 2, 0));

        metrics.on_disconnect();
        assert_eq!(counts(&metrics), (2, 1, 0));

        metrics.on_error();
        assert_eq!(counts(&metrics), (2, 1, 1));
    }

    #[test]
    fn test_connection_metrics_display() {
        let metrics = ConnectionMetrics::new();
        metrics.on_connect();
        let rendered = metrics.to_string();
        for needle in ["total=1", "active=1", "errors=0"] {
            assert!(rendered.contains(needle), "missing {needle} in {rendered}");
        }
    }

    #[test]
    fn test_metrics_shared_across_clones() {
        let original = ConnectionMetrics::new();
        let shared = original.clone();

        original.on_connect();
        assert_eq!(shared.active_connections(), 1);

        shared.on_connect();
        assert_eq!(original.active_connections(), 2);
    }

    #[test]
    fn test_connection_guard_drop() {
        let metrics = ConnectionMetrics::new();
        let guard = ConnectionGuard::new(metrics.clone(), None);
        assert_eq!(metrics.active_connections(), 1);

        drop(guard);
        assert_eq!(metrics.active_connections(), 0);
        assert_eq!(metrics.total_connections(), 1);
    }
}