1#![allow(clippy::type_complexity)]
2use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15
16use crate::bridges::{Bridge, BridgeError, BridgeKind};
17use crate::canonicalize;
18use crate::rpc::RpcTransport;
19use crate::session::SessionFrame;
20
/// Per-call parameters handed to [`GrpcChannel::unary`].
///
/// `GrpcBridge::send_frame` builds one of these for every call by cloning the
/// corresponding fields of its [`GrpcBridgeConfig`].
#[derive(Clone, Debug)]
pub struct GrpcCallContext {
    /// Method to invoke (the tests in this module use `"trustforge.ProofRpc/Unary"`).
    pub method: String,
    /// String key/value metadata passed to the channel along with the call.
    pub metadata: HashMap<String, String>,
    /// Optional authority; `None` when the bridge configuration does not set one.
    pub authority: Option<String>,
}
28
/// Reply returned by a successful [`GrpcChannel::unary`] call.
#[derive(Clone, Debug)]
pub struct GrpcReply {
    /// Raw response payload. `GrpcBridge::send_frame` returns it to its caller
    /// and also tries to decode it as a `SessionFrame` for registered listeners.
    pub body: Vec<u8>,
    /// Response metadata reported by the channel. `GrpcBridge::send_frame` does
    /// not inspect it.
    pub metadata: HashMap<String, String>,
}
34
/// Transport abstraction over which [`GrpcBridge`] issues unary calls.
///
/// Implementations must be `Send + Sync`; [`GrpcBridge`] stores one as
/// `Arc<dyn GrpcChannel>`.
pub trait GrpcChannel: Send + Sync {
    /// Performs one unary call, sending `body` with the method, metadata and
    /// authority described by `call`, and returns the reply or a [`BridgeError`].
    fn unary(&self, call: &GrpcCallContext, body: &[u8]) -> Result<GrpcReply, BridgeError>;
    /// Closes the channel. Invoked by `GrpcBridge::close`.
    fn close(&self) -> Result<(), BridgeError>;
}
42
/// Static configuration for a [`GrpcBridge`].
#[derive(Clone, Debug)]
pub struct GrpcBridgeConfig {
    /// Identifier returned by the bridge's `Bridge::bridge_id`.
    pub bridge_id: String,
    /// Trust domain returned by the bridge's `Bridge::trust_domain`.
    pub trust_domain: String,
    /// Method used for every unary call (copied into `GrpcCallContext::method`).
    pub service_method: String,
    /// Optional authority copied into every `GrpcCallContext`.
    pub authority: Option<String>,
    /// Metadata copied into every `GrpcCallContext`.
    pub metadata: HashMap<String, String>,
}
54
/// [`Bridge`] and [`RpcTransport`] implementation that sends session frames as
/// unary calls over a [`GrpcChannel`].
pub struct GrpcBridge {
    /// Configuration applied to every call.
    cfg: GrpcBridgeConfig,
    /// Underlying transport used for every unary call and for `close`.
    channel: Arc<dyn GrpcChannel>,
    /// Callbacks registered through `RpcTransport::on_frame`; each one receives
    /// every reply body that decodes as a `SessionFrame`.
    listeners: Mutex<Vec<Arc<dyn Fn(SessionFrame) + Send + Sync>>>,
}
60
61impl GrpcBridge {
62 pub fn new(channel: Arc<dyn GrpcChannel>, cfg: GrpcBridgeConfig) -> Self {
63 Self {
64 cfg,
65 channel,
66 listeners: Mutex::new(Vec::new()),
67 }
68 }
69
70 pub fn send_frame(&self, frame_canonical_json: &[u8]) -> Result<Vec<u8>, BridgeError> {
75 let ctx = GrpcCallContext {
76 method: self.cfg.service_method.clone(),
77 metadata: self.cfg.metadata.clone(),
78 authority: self.cfg.authority.clone(),
79 };
80 let reply = self.channel.unary(&ctx, frame_canonical_json)?;
81 if let Ok(listeners) = self.listeners.lock() {
82 if let Ok(frame) = serde_json::from_slice::<SessionFrame>(&reply.body) {
83 for l in listeners.iter() {
84 l(frame.clone());
85 }
86 }
87 }
88 Ok(reply.body)
89 }
90
91 pub fn send_value(&self, frame: &serde_json::Value) -> Result<Vec<u8>, BridgeError> {
94 let bytes = canonicalize(frame).map_err(|e| BridgeError::InvalidInput(e.to_string()))?;
95 self.send_frame(bytes.as_bytes())
96 }
97
98 pub fn close(&self) -> Result<(), BridgeError> {
99 self.channel.close()
100 }
101}
102
103impl RpcTransport for GrpcBridge {
104 fn send(&self, frame: SessionFrame) {
105 let json = serde_json::to_vec(&frame).unwrap_or_default();
106 let _ = self.send_frame(&json);
107 }
108
109 fn on_frame(&self, listener: Arc<dyn Fn(SessionFrame) + Send + Sync>) {
110 if let Ok(mut listeners) = self.listeners.lock() {
111 listeners.push(listener);
112 }
113 }
114}
115
116impl Bridge for GrpcBridge {
117 fn bridge_id(&self) -> &str {
118 &self.cfg.bridge_id
119 }
120 fn kind(&self) -> BridgeKind {
121 BridgeKind::Grpc
122 }
123 fn trust_domain(&self) -> &str {
124 &self.cfg.trust_domain
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory `GrpcChannel`:
    /// - records the last request body;
    /// - answers every call with a fixed canned body;
    /// - counts how often `close` is called.
    struct FakeChannel {
        last_body: Mutex<Vec<u8>>,
        echo: Mutex<Vec<u8>>,
        close_calls: AtomicUsize,
    }

    impl FakeChannel {
        fn new(echo: Vec<u8>) -> Self {
            Self {
                echo: Mutex::new(echo),
                last_body: Mutex::default(),
                close_calls: AtomicUsize::default(),
            }
        }
    }

    impl GrpcChannel for FakeChannel {
        fn unary(&self, _call: &GrpcCallContext, body: &[u8]) -> Result<GrpcReply, BridgeError> {
            {
                let mut recorded = self.last_body.lock().unwrap();
                *recorded = body.to_vec();
            }
            let canned = self.echo.lock().unwrap().clone();
            Ok(GrpcReply {
                body: canned,
                metadata: HashMap::new(),
            })
        }

        fn close(&self) -> Result<(), BridgeError> {
            self.close_calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a bridge over `chan` with the shared test configuration.
    fn bridge_over(chan: Arc<FakeChannel>, authority: Option<String>) -> GrpcBridge {
        let cfg = GrpcBridgeConfig {
            bridge_id: "tf-grpc".into(),
            trust_domain: "example.com".into(),
            service_method: "trustforge.ProofRpc/Unary".into(),
            authority,
            metadata: HashMap::new(),
        };
        GrpcBridge::new(chan, cfg)
    }

    #[test]
    fn send_and_receive_round_trip() {
        let canned = serde_json::to_vec(&SessionFrame::Data {
            payload: serde_json::json!({"ok": true}),
        })
        .unwrap();
        let chan = Arc::new(FakeChannel::new(canned));
        let bridge = bridge_over(Arc::clone(&chan), Some("rpc.example.com".into()));

        let seen = Arc::new(AtomicUsize::new(0));
        let seen_in_listener = Arc::clone(&seen);
        bridge.on_frame(Arc::new(move |incoming: SessionFrame| {
            match incoming {
                SessionFrame::Data { payload } => {
                    assert_eq!(payload, serde_json::json!({"ok": true}));
                }
                _ => panic!("unexpected frame"),
            }
            seen_in_listener.fetch_add(1, Ordering::SeqCst);
        }));

        bridge.send(SessionFrame::Data {
            payload: serde_json::json!("hello"),
        });
        assert_eq!(seen.load(Ordering::SeqCst), 1);

        bridge.close().expect("close");
        assert_eq!(chan.close_calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn bridge_metadata_round_trip() {
        let bridge = bridge_over(Arc::new(FakeChannel::new(Vec::new())), None);
        assert_eq!(bridge.bridge_id(), "tf-grpc");
        assert_eq!(bridge.kind(), BridgeKind::Grpc);
        assert_eq!(bridge.trust_domain(), "example.com");
    }
}