Skip to main content

vgi_rpc/
hooks.rs

1//! Dispatch hook interface used by observability integrations.
2//!
3//! Each call dispatches through `on_dispatch_start` before the handler runs
4//! and `on_dispatch_end` after completion (success or error). The hook
5//! receives `CallStatistics` tallied by the framework and may record
6//! spans / metrics / sentry events.
7
8use std::sync::Arc;
9
10use crate::errors::RpcError;
11use crate::wire::Metadata;
12
/// Per-call statistics accumulated during dispatch.
///
/// All fields start at zero and are incremented by the server as batches
/// are read/written. Values are a best-effort snapshot at the moment the
/// `on_dispatch_end` hook fires.
///
/// The type is plain data, so it derives the comparison traits as well:
/// hooks and tests can check a snapshot against an expected value with
/// `==` instead of comparing field by field.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CallStatistics {
    /// Number of input batches read for this call.
    pub input_batches: u64,
    /// Number of output batches written for this call.
    pub output_batches: u64,
    /// Row count tallied for the input side.
    /// NOTE(review): presumably summed over the input batches — confirm
    /// against the server's accounting.
    pub input_rows: u64,
    /// Row count tallied for the output side (same caveat as `input_rows`).
    pub output_rows: u64,
    /// Byte count tallied for the input side.
    /// NOTE(review): whether this is wire size or in-memory size is not
    /// visible from this module — confirm before relying on it.
    pub input_bytes: u64,
    /// Byte count tallied for the output side (same caveat as `input_bytes`).
    pub output_bytes: u64,
}
27
/// Information passed to a dispatch hook at start and end of each call.
///
/// [`DispatchInfo::from_request`] fills in the server, request and auth
/// derived fields and leaves `remote_addr`, `http_status`, `request_data`,
/// `stream_id` and `cancelled` at their empty / zero / `false` values.
/// NOTE(review): presumably the transport or stream layer sets those
/// afterwards — confirm at the call sites.
#[derive(Clone, Debug)]
pub struct DispatchInfo {
    /// Name of the method being dispatched (copied from the request).
    pub method: String,
    /// Kind of call; `from_request` documents this as either `"unary"`
    /// or `"stream"`.
    pub method_type: &'static str,
    /// Identifier of the serving server (copied from the server's
    /// `server_id`).
    pub server_id: String,
    /// Logical service / protocol name.
    pub protocol: String,
    /// SHA-256 hex of the canonical __describe__ payload (always required in access log).
    pub protocol_hash: String,
    /// Operator-supplied free-form protocol-contract version label (optional).
    pub protocol_version: String,
    /// Identifier of this request (copied from the request's `request_id`).
    pub request_id: String,
    /// Transport-level metadata (HTTP peer addr / pipe contextvar payload).
    pub transport_metadata: Arc<Metadata>,
    /// Authenticated principal name, empty when anonymous.
    pub principal: String,
    /// Authentication domain identifier, empty when anonymous.
    pub auth_domain: String,
    /// True when the call was authenticated.
    pub authenticated: bool,
    /// HTTP transport: remote IP:port. Empty otherwise.
    pub remote_addr: String,
    /// HTTP transport: response status; 0 when not applicable.
    pub http_status: u16,
    /// Self-contained Arrow IPC stream of the request batch (unary + stream init only).
    pub request_data: Vec<u8>,
    /// Stream lifecycle identifier (32-char lowercase hex); empty on unary.
    pub stream_id: String,
    /// True when a stream was cancelled by the client.
    pub cancelled: bool,
    /// Authentication claims — e.g. decoded JWT claims, X.509 cert
    /// extensions, OAuth introspection fields. Cloned from
    /// [`AuthContext::claims`](crate::auth::AuthContext::claims) at
    /// dispatch start. Used by the Sentry hook to enrich user / tag
    /// fields per Python `2d93987`.
    pub claims: std::collections::BTreeMap<String, String>,
}
66
67impl DispatchInfo {
68    /// Build a `DispatchInfo` from the serving server + request + resolved
69    /// auth context. `method_type` is either `"unary"` or `"stream"`.
70    pub fn from_request(
71        server: &crate::server::RpcServer,
72        req: &crate::server::Request,
73        method_type: &'static str,
74        auth: &crate::auth::AuthContext,
75    ) -> Self {
76        Self {
77            method: req.method.clone(),
78            method_type,
79            server_id: server.server_id.clone(),
80            protocol: server.protocol_name().to_string(),
81            protocol_hash: server.protocol_hash().to_string(),
82            protocol_version: server.protocol_version().to_string(),
83            request_id: req.request_id.clone(),
84            transport_metadata: Arc::new(req.metadata.clone()),
85            principal: auth.principal.clone(),
86            auth_domain: auth.domain.clone(),
87            authenticated: auth.authenticated,
88            remote_addr: String::new(),
89            http_status: 0,
90            request_data: Vec::new(),
91            stream_id: String::new(),
92            cancelled: false,
93            claims: auth.claims.clone(),
94        }
95    }
96}
97
/// Token returned by [`DispatchHook::on_dispatch_start`] and passed back to
/// [`DispatchHook::on_dispatch_end`] for the same call.
pub type HookToken = u64;
100
/// Trait implemented by dispatch observability hooks.
///
/// Implementations must be `Send + Sync`, and both callbacks take `&self`,
/// so any per-call state has to live behind interior mutability or be
/// recovered through the returned [`HookToken`].
pub trait DispatchHook: Send + Sync {
    /// Invoked just before the handler runs. Return a token that will be
    /// passed to `on_dispatch_end`.
    ///
    /// `info` describes the call about to be dispatched.
    fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken;

    /// Invoked once the handler has returned and all logs/batches have been
    /// written to the transport.
    ///
    /// * `token` — the value this hook returned from `on_dispatch_start`.
    /// * `info` — description of the call that just finished.
    /// * `error` — the failure, when there was one. NOTE(review): `None`
    ///   presumably means success — confirm against the dispatcher.
    /// * `stats` — the per-call counters tallied by the framework.
    fn on_dispatch_end(
        &self,
        token: HookToken,
        info: &DispatchInfo,
        error: Option<&RpcError>,
        stats: &CallStatistics,
    );
}
117
/// A shared, reference-counted handle to a type-erased dispatch hook
/// (`Arc<dyn DispatchHook>`).
pub type SharedHook = Arc<dyn DispatchHook>;
120
121/// A hook that delegates to two hooks in sequence.
122pub struct ChainHook {
123    inner: Vec<SharedHook>,
124}
125
126impl ChainHook {
127    pub fn new(hooks: Vec<SharedHook>) -> Self {
128        Self { inner: hooks }
129    }
130}
131
132impl DispatchHook for ChainHook {
133    fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken {
134        // Tokens aren't individually recoverable here; each inner hook gets
135        // a best-effort fresh token. Callers that need per-hook tokens can
136        // wrap them individually.
137        for h in &self.inner {
138            let _ = h.on_dispatch_start(info);
139        }
140        0
141    }
142
143    fn on_dispatch_end(
144        &self,
145        token: HookToken,
146        info: &DispatchInfo,
147        error: Option<&RpcError>,
148        stats: &CallStatistics,
149    ) {
150        for h in &self.inner {
151            h.on_dispatch_end(token, info, error, stats);
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test double that counts how often each callback fires.
    #[derive(Default)]
    struct CountingHook {
        starts: AtomicU64,
        ends: AtomicU64,
    }

    impl DispatchHook for CountingHook {
        fn on_dispatch_start(&self, _info: &DispatchInfo) -> HookToken {
            // The token is the 1-based ordinal of this start.
            let previous = self.starts.fetch_add(1, Ordering::Relaxed);
            previous + 1
        }

        fn on_dispatch_end(
            &self,
            _token: HookToken,
            _info: &DispatchInfo,
            _error: Option<&RpcError>,
            _stats: &CallStatistics,
        ) {
            self.ends.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Minimal anonymous unary call used as the fixture.
    fn sample_info() -> DispatchInfo {
        DispatchInfo {
            method: "echo".into(),
            method_type: "unary",
            server_id: "test".into(),
            protocol: String::new(),
            protocol_hash: String::new(),
            protocol_version: String::new(),
            request_id: String::new(),
            transport_metadata: Arc::new(Default::default()),
            principal: String::new(),
            auth_domain: String::new(),
            authenticated: false,
            remote_addr: String::new(),
            http_status: 0,
            request_data: Vec::new(),
            stream_id: String::new(),
            cancelled: false,
            claims: std::collections::BTreeMap::new(),
        }
    }

    #[test]
    fn chain_hook_fans_out() {
        let first = Arc::new(CountingHook::default());
        let second = Arc::new(CountingHook::default());
        let chain = ChainHook::new(vec![first.clone(), second.clone()]);
        let info = sample_info();

        // One full start/end cycle through the chain.
        let token = chain.on_dispatch_start(&info);
        chain.on_dispatch_end(token, &info, None, &CallStatistics::default());

        // Every member must have seen exactly one start and one end.
        let members = [&first, &second];
        for hook in members {
            assert_eq!(hook.starts.load(Ordering::Relaxed), 1);
        }
        for hook in members {
            assert_eq!(hook.ends.load(Ordering::Relaxed), 1);
        }
    }
}