runtara_protocol/
client.rs

1// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//! QUIC client helpers for connecting to runtara-core.
4
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use quinn::{ClientConfig, Connection, Endpoint, TransportConfig};
10use thiserror::Error;
11use tokio::sync::Mutex;
12use tracing::{debug, info, instrument};
13
14use crate::frame::{Frame, FrameError, FramedStream};
15
16/// Errors that can occur in the QUIC client
17#[derive(Debug, Error)]
18pub enum ClientError {
19    #[error("connection error: {0}")]
20    Connection(#[from] quinn::ConnectionError),
21
22    #[error("connect error: {0}")]
23    Connect(#[from] quinn::ConnectError),
24
25    #[error("write error: {0}")]
26    Write(#[from] quinn::WriteError),
27
28    #[error("read error: {0}")]
29    Read(#[from] quinn::ReadExactError),
30
31    #[error("frame error: {0}")]
32    Frame(#[from] FrameError),
33
34    #[error("IO error: {0}")]
35    Io(#[from] std::io::Error),
36
37    #[error("stream closed: {0}")]
38    ClosedStream(#[from] quinn::ClosedStream),
39
40    #[error("no connection established")]
41    NotConnected,
42
43    #[error("invalid server name: {0}")]
44    InvalidServerName(String),
45
46    #[error("connection timed out after {0}ms")]
47    Timeout(u64),
48}
49
/// Configuration for the QUIC client.
///
/// Usually built from [`Default::default`] with individual fields overridden via
/// struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntaraClientConfig {
    /// Server address to connect to.
    pub server_addr: SocketAddr,
    /// Server name for TLS verification (use "localhost" for local dev).
    pub server_name: String,
    /// Request 0-RTT for lower latency (requires server support).
    ///
    /// NOTE: `RuntaraClient` does not currently read this flag; every connection
    /// waits for a full handshake regardless of its value.
    pub enable_0rtt: bool,
    /// Skip certificate verification (for development only!).
    ///
    /// When `true`, any server certificate is accepted, so the server is not
    /// authenticated at all.
    pub dangerous_skip_cert_verification: bool,
    /// Keep-alive interval in milliseconds (0 to disable).
    pub keep_alive_interval_ms: u64,
    /// Idle timeout in milliseconds.
    ///
    /// Must be representable as a QUIC idle timeout; out-of-range values make
    /// client construction fail.
    pub idle_timeout_ms: u64,
    /// Maximum time in milliseconds to wait for the connection handshake.
    pub connect_timeout_ms: u64,
}
68
69impl Default for RuntaraClientConfig {
70    fn default() -> Self {
71        Self {
72            server_addr: "127.0.0.1:8001".parse().unwrap(),
73            server_name: "localhost".to_string(),
74            enable_0rtt: true,
75            dangerous_skip_cert_verification: false,
76            keep_alive_interval_ms: 10_000,
77            idle_timeout_ms: 30_000,
78            connect_timeout_ms: 10_000,
79        }
80    }
81}
82
83/// QUIC client for communicating with runtara-core
84pub struct RuntaraClient {
85    endpoint: Endpoint,
86    connection: Mutex<Option<Connection>>,
87    config: RuntaraClientConfig,
88}
89
90impl RuntaraClient {
91    /// Create a new client with the given configuration
92    pub fn new(config: RuntaraClientConfig) -> Result<Self, ClientError> {
93        let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
94
95        let client_config = Self::build_client_config(&config)?;
96        endpoint.set_default_client_config(client_config);
97
98        Ok(Self {
99            endpoint,
100            connection: Mutex::new(None),
101            config,
102        })
103    }
104
105    /// Create a client with default configuration for local development
106    pub fn localhost() -> Result<Self, ClientError> {
107        Self::new(RuntaraClientConfig {
108            dangerous_skip_cert_verification: true,
109            ..Default::default()
110        })
111    }
112
113    fn build_client_config(config: &RuntaraClientConfig) -> Result<ClientConfig, ClientError> {
114        let crypto = if config.dangerous_skip_cert_verification {
115            rustls::ClientConfig::builder()
116                .dangerous()
117                .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
118                .with_no_client_auth()
119        } else {
120            let mut roots = rustls::RootCertStore::empty();
121            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
122            rustls::ClientConfig::builder()
123                .with_root_certificates(roots)
124                .with_no_client_auth()
125        };
126
127        let mut transport = TransportConfig::default();
128        if config.keep_alive_interval_ms > 0 {
129            transport.keep_alive_interval(Some(std::time::Duration::from_millis(
130                config.keep_alive_interval_ms,
131            )));
132        }
133        transport.max_idle_timeout(Some(
134            std::time::Duration::from_millis(config.idle_timeout_ms)
135                .try_into()
136                .unwrap(),
137        ));
138
139        let mut client_config = ClientConfig::new(Arc::new(
140            quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
141        ));
142        client_config.transport_config(Arc::new(transport));
143
144        Ok(client_config)
145    }
146
147    /// Connect to the server
148    #[instrument(skip(self))]
149    pub async fn connect(&self) -> Result<(), ClientError> {
150        let mut conn_guard = self.connection.lock().await;
151
152        // Check if we already have a valid connection
153        if let Some(ref conn) = *conn_guard
154            && conn.close_reason().is_none()
155        {
156            debug!("reusing existing connection");
157            return Ok(());
158        }
159
160        info!(addr = %self.config.server_addr, "connecting to runtara-core");
161
162        let timeout = Duration::from_millis(self.config.connect_timeout_ms);
163        let connecting = self
164            .endpoint
165            .connect(self.config.server_addr, &self.config.server_name)?;
166
167        let connection = tokio::time::timeout(timeout, connecting)
168            .await
169            .map_err(|_| ClientError::Timeout(self.config.connect_timeout_ms))??;
170
171        info!("connected to runtara-core");
172        *conn_guard = Some(connection);
173        Ok(())
174    }
175
176    /// Get the current connection, connecting if necessary
177    async fn get_connection(&self) -> Result<Connection, ClientError> {
178        self.connect().await?;
179        let conn_guard = self.connection.lock().await;
180        conn_guard.clone().ok_or(ClientError::NotConnected)
181    }
182
183    /// Open a new bidirectional stream for a request/response
184    pub async fn open_stream(
185        &self,
186    ) -> Result<FramedStream<(quinn::SendStream, quinn::RecvStream)>, ClientError> {
187        let conn = self.get_connection().await?;
188        let (send, recv) = conn.open_bi().await?;
189        Ok(FramedStream::new((send, recv)))
190    }
191
192    /// Open a unidirectional stream for sending (e.g., events)
193    pub async fn open_uni_send(&self) -> Result<FramedStream<quinn::SendStream>, ClientError> {
194        let conn = self.get_connection().await?;
195        let send = conn.open_uni().await?;
196        Ok(FramedStream::new(send))
197    }
198
199    /// Send a request and receive a response using a new stream
200    #[instrument(skip(self, request))]
201    pub async fn request<Req: prost::Message, Resp: prost::Message + Default>(
202        &self,
203        request: &Req,
204    ) -> Result<Resp, ClientError> {
205        let conn = self.get_connection().await?;
206        let (mut send, mut recv) = conn.open_bi().await?;
207
208        // Send request
209        let frame = Frame::request(request)?;
210        crate::frame::write_frame(&mut send, &frame).await?;
211        send.finish()?;
212
213        // Read response
214        let response_frame = crate::frame::read_frame(&mut recv).await?;
215        Ok(response_frame.decode()?)
216    }
217
218    /// Send a fire-and-forget request (no response expected).
219    ///
220    /// Use this for events that don't require acknowledgement.
221    #[instrument(skip(self, request))]
222    pub async fn send_fire_and_forget<Req: prost::Message>(
223        &self,
224        request: &Req,
225    ) -> Result<(), ClientError> {
226        let conn = self.get_connection().await?;
227        let (mut send, _recv) = conn.open_bi().await?;
228
229        // Send request
230        let frame = Frame::request(request)?;
231        crate::frame::write_frame(&mut send, &frame).await?;
232        send.finish()?;
233
234        // No response expected - just return
235        Ok(())
236    }
237
238    /// Open a raw bidirectional stream for streaming operations.
239    ///
240    /// This returns the raw QUIC streams for advanced use cases like
241    /// streaming large data that doesn't fit in a single frame.
242    pub async fn open_raw_stream(
243        &self,
244    ) -> Result<(quinn::SendStream, quinn::RecvStream), ClientError> {
245        let conn = self.get_connection().await?;
246        Ok(conn.open_bi().await?)
247    }
248
249    /// Close the connection gracefully
250    pub async fn close(&self) {
251        let mut conn_guard = self.connection.lock().await;
252        if let Some(conn) = conn_guard.take() {
253            conn.close(0u32.into(), b"client closing");
254        }
255    }
256
257    /// Check if the client is currently connected
258    pub async fn is_connected(&self) -> bool {
259        let conn_guard = self.connection.lock().await;
260        if let Some(ref conn) = *conn_guard {
261            conn.close_reason().is_none()
262        } else {
263            false
264        }
265    }
266}
267
268impl Drop for RuntaraClient {
269    fn drop(&mut self) {
270        // Close connection on drop (non-async, best effort)
271        if let Ok(mut guard) = self.connection.try_lock()
272            && let Some(conn) = guard.take()
273        {
274            conn.close(0u32.into(), b"client dropped");
275        }
276    }
277}
278
279/// Certificate verifier that skips all verification (for development only!)
280#[derive(Debug)]
281struct SkipServerVerification;
282
283impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
284    fn verify_server_cert(
285        &self,
286        _end_entity: &rustls::pki_types::CertificateDer<'_>,
287        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
288        _server_name: &rustls::pki_types::ServerName<'_>,
289        _ocsp_response: &[u8],
290        _now: rustls::pki_types::UnixTime,
291    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
292        Ok(rustls::client::danger::ServerCertVerified::assertion())
293    }
294
295    fn verify_tls12_signature(
296        &self,
297        _message: &[u8],
298        _cert: &rustls::pki_types::CertificateDer<'_>,
299        _dss: &rustls::DigitallySignedStruct,
300    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
301        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
302    }
303
304    fn verify_tls13_signature(
305        &self,
306        _message: &[u8],
307        _cert: &rustls::pki_types::CertificateDer<'_>,
308        _dss: &rustls::DigitallySignedStruct,
309    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
311    }
312
313    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
314        vec![
315            rustls::SignatureScheme::RSA_PKCS1_SHA256,
316            rustls::SignatureScheme::RSA_PKCS1_SHA384,
317            rustls::SignatureScheme::RSA_PKCS1_SHA512,
318            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
319            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
320            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
321            rustls::SignatureScheme::RSA_PSS_SHA256,
322            rustls::SignatureScheme::RSA_PSS_SHA384,
323            rustls::SignatureScheme::RSA_PSS_SHA512,
324            rustls::SignatureScheme::ED25519,
325        ]
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the local development server.
    #[test]
    fn test_default_config() {
        let RuntaraClientConfig {
            server_addr,
            server_name,
            ..
        } = RuntaraClientConfig::default();
        assert_eq!(server_addr, "127.0.0.1:8001".parse().unwrap());
        assert_eq!(server_name, "localhost");
    }

    /// Building a client (without connecting) succeeds with verification disabled.
    #[tokio::test]
    async fn test_client_creation() {
        let config = RuntaraClientConfig {
            dangerous_skip_cert_verification: true,
            ..RuntaraClientConfig::default()
        };
        let outcome = RuntaraClient::new(config);
        assert!(
            outcome.is_ok(),
            "Failed to create client: {:?}",
            outcome.err()
        );
    }
}