Skip to main content

synapse_sdk/
client.rs

1//! Unified RPC client with connection pooling
2//!
3//! SynClient provides a connection pool for all outgoing RPC traffic.
4//! Generated service clients use this pool for efficient communication.
5
6use anyhow::{Context, Result, bail};
7use bytes::Bytes;
8use std::{path::Path, sync::Arc};
9use synapse_primitives::{InterfaceId, MethodId};
10use synapse_proto::RpcRequest;
11use synapse_proto::RpcStatus;
12#[cfg(feature = "otlp")]
13use synapse_proto::{HeaderEntry, header_entry};
14use synapse_rpc::HttpRpcClient;
15use tracing::{debug, warn};
16
17/// Unified RPC client with connection pooling
18///
19/// This is the central client that manages all outgoing RPC connections.
20/// Generated service clients use this pool rather than creating their own connections.
21///
22/// The underlying HTTP client (reqwest) automatically pools connections,
23/// so SynClient can be cloned cheaply and shared across the application.
24#[derive(Clone)]
25pub struct SynClient {
26    inner: Arc<SynClientInner>,
27}
28
29struct SynClientInner {
30    http_client: HttpRpcClient,
31}
32
33impl SynClient {
34    /// Create a new unified RPC client
35    ///
36    /// # Arguments
37    /// * `gateway_url` - URL of the gateway (e.g., "http://localhost:8080")
38    ///
39    /// # Example
40    /// ```ignore
41    /// let client = SynClient::new("http://localhost:8080");
42    /// ```
43    pub fn new(gateway_url: impl Into<String>) -> Self {
44        Self {
45            inner: Arc::new(SynClientInner {
46                http_client: HttpRpcClient::json(gateway_url),
47            }),
48        }
49    }
50
51    /// Create a client with protobuf encoding (more efficient)
52    pub fn with_protobuf(gateway_url: impl Into<String>) -> Self {
53        Self {
54            inner: Arc::new(SynClientInner {
55                http_client: HttpRpcClient::protobuf(gateway_url),
56            }),
57        }
58    }
59
60    /// Create a client with mTLS authentication (production)
61    pub fn with_mtls(
62        gateway_url: impl Into<String>,
63        cert_path: impl AsRef<Path>,
64        key_path: impl AsRef<Path>,
65        ca_cert_path: impl AsRef<Path>,
66    ) -> Result<Self> {
67        Ok(Self {
68            inner: Arc::new(SynClientInner {
69                http_client: HttpRpcClient::protobuf_mtls(
70                    gateway_url,
71                    cert_path,
72                    key_path,
73                    ca_cert_path,
74                )?,
75            }),
76        })
77    }
78
79    /// Set the request timeout
80    pub fn with_timeout(self, timeout: std::time::Duration) -> Self {
81        Self {
82            inner: Arc::new(SynClientInner {
83                http_client: HttpRpcClient::new(
84                    self.gateway_url(),
85                    self.inner.http_client.content_type(),
86                )
87                .with_timeout(timeout),
88            }),
89        }
90    }
91
92    /// Call a service method with raw bytes
93    ///
94    /// # Arguments
95    /// * `interface` - The interface name or ID
96    /// * `method` - The method name or ID
97    /// * `payload` - The request payload (serialized)
98    pub async fn call(
99        &self,
100        interface: impl Into<InterfaceId>,
101        method: impl Into<MethodId>,
102        payload: Bytes,
103    ) -> Result<Bytes> {
104        let interface_id = interface.into();
105        let method_id = method.into();
106
107        debug!(
108            "Calling {}.{} ({} bytes)",
109            u32::from(interface_id),
110            u32::from(method_id),
111            payload.len()
112        );
113
114        // Build RPC request (request_id is in the SynapseMessage envelope, not here)
115        #[allow(unused_mut)]
116        let mut headers = Vec::new();
117
118        // Inject trace context if OpenTelemetry is enabled
119        #[cfg(feature = "otlp")]
120        {
121            use opentelemetry::trace::TraceContextExt;
122            use tracing_opentelemetry::OpenTelemetrySpanExt;
123
124            let span = tracing::Span::current();
125            let otel_ctx = span.context();
126            let span_ref = otel_ctx.span();
127            let span_ctx = span_ref.span_context();
128
129            if span_ctx.is_valid() {
130                headers.push(HeaderEntry {
131                    key: u32::from(*synapse_primitives::id::well_known::TRACE_ID),
132                    value: Some(header_entry::Value::StringValue(format!(
133                        "{:032x}",
134                        span_ctx.trace_id()
135                    ))),
136                });
137                headers.push(HeaderEntry {
138                    key: u32::from(*synapse_primitives::id::well_known::SPAN_ID),
139                    value: Some(header_entry::Value::StringValue(format!(
140                        "{:016x}",
141                        span_ctx.span_id()
142                    ))),
143                });
144            }
145        }
146
147        let request = RpcRequest {
148            interface_id: interface_id.into(),
149            method_id: method_id.into(),
150            headers,
151            payload,
152            sent_at_unix_ms: chrono::Utc::now().timestamp_millis(),
153        };
154
155        // Make the call
156        let response = self
157            .inner
158            .http_client
159            .call(request)
160            .await
161            .context("RPC call failed")?;
162
163        // Check status
164        if response.status != RpcStatus::Ok as i32 {
165            let error = response.error.unwrap_or_else(|| synapse_proto::RpcError {
166                code: response.status as u32,
167                message: format!("RPC failed with status {}", response.status),
168                details: vec![],
169            });
170
171            warn!("RPC call failed: {} - {}", error.code, error.message);
172            bail!("RPC error {}: {}", error.code, error.message);
173        }
174
175        debug!("RPC call succeeded ({} bytes)", response.payload.len());
176        Ok(response.payload)
177    }
178
179    /// Call with typed request and response (JSON serialization)
180    ///
181    /// This is a convenience method for JSON-serializable types.
182    pub async fn call_json<TReq, TResp>(
183        &self,
184        interface: impl Into<InterfaceId>,
185        method: impl Into<MethodId>,
186        request: &TReq,
187    ) -> Result<TResp>
188    where
189        TReq: serde::Serialize,
190        TResp: serde::de::DeserializeOwned,
191    {
192        // Serialize request
193        let payload = serde_json::to_vec(request).context("Failed to serialize request")?;
194
195        // Make call
196        let response_bytes = self.call(interface, method, Bytes::from(payload)).await?;
197
198        // Deserialize response
199        let response =
200            serde_json::from_slice(&response_bytes).context("Failed to deserialize response")?;
201
202        Ok(response)
203    }
204
205    /// Call with typed request and response (Protobuf serialization)
206    ///
207    /// This is the efficient method for protobuf types.
208    /// Generated service clients use this method.
209    pub async fn call_proto<TReq, TResp>(
210        &self,
211        interface: impl Into<InterfaceId>,
212        method: impl Into<MethodId>,
213        request: &TReq,
214    ) -> Result<TResp>
215    where
216        TReq: prost::Message,
217        TResp: prost::Message + Default,
218    {
219        // Serialize request with prost
220        let payload = request.encode_to_vec();
221
222        // Make call
223        let response_bytes = self.call(interface, method, Bytes::from(payload)).await?;
224
225        // Deserialize response with prost
226        let response =
227            TResp::decode(response_bytes.as_ref()).context("Failed to deserialize response")?;
228
229        Ok(response)
230    }
231
232    /// Get the gateway URL
233    pub fn gateway_url(&self) -> &str {
234        self.inner.http_client.gateway_url()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = SynClient::new("http://localhost:8080");
        assert_eq!(client.gateway_url(), "http://localhost:8080");

        // with_timeout rebuilds the HTTP client; the gateway URL must survive the rebuild.
        let client =
            SynClient::new("http://gateway:5000").with_timeout(std::time::Duration::from_secs(60));
        assert_eq!(client.gateway_url(), "http://gateway:5000");
    }

    #[test]
    fn test_client_clone() {
        let client1 = SynClient::new("http://localhost:8080");
        let client2 = client1.clone();

        // Cloning must share the inner state, not build a second HTTP client.
        assert!(Arc::ptr_eq(&client1.inner, &client2.inner));
        assert_eq!(client1.gateway_url(), client2.gateway_url());
    }

    #[tokio::test]
    #[ignore] // Requires running gateway
    async fn test_call() {
        let client = SynClient::new("http://localhost:8080");

        let interface = InterfaceId::from_name("test.Service");
        let method = MethodId::from_name("Echo");
        let payload = Bytes::from("hello");

        let _response = client.call(interface, method, payload).await;
    }
}