1use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use quinn::{ClientConfig, Connection, Endpoint, TransportConfig};
10use thiserror::Error;
11use tokio::sync::Mutex;
12use tracing::{debug, info, instrument};
13
14use crate::frame::{Frame, FrameError, FramedStream};
15
16#[derive(Debug, Error)]
18pub enum ClientError {
19 #[error("connection error: {0}")]
20 Connection(#[from] quinn::ConnectionError),
21
22 #[error("connect error: {0}")]
23 Connect(#[from] quinn::ConnectError),
24
25 #[error("write error: {0}")]
26 Write(#[from] quinn::WriteError),
27
28 #[error("read error: {0}")]
29 Read(#[from] quinn::ReadExactError),
30
31 #[error("frame error: {0}")]
32 Frame(#[from] FrameError),
33
34 #[error("IO error: {0}")]
35 Io(#[from] std::io::Error),
36
37 #[error("stream closed: {0}")]
38 ClosedStream(#[from] quinn::ClosedStream),
39
40 #[error("no connection established")]
41 NotConnected,
42
43 #[error("invalid server name: {0}")]
44 InvalidServerName(String),
45
46 #[error("connection timed out after {0}ms")]
47 Timeout(u64),
48}
49
/// Settings used to build a [`RuntaraClient`].
#[derive(Debug, Clone)]
pub struct RuntaraClientConfig {
    /// Socket address of the runtara-core QUIC server.
    pub server_addr: SocketAddr,
    /// Server name passed to the endpoint's `connect` call for the TLS handshake.
    pub server_name: String,
    /// Whether 0-RTT should be used.
    ///
    /// NOTE(review): not read anywhere in this file; confirm whether it is wired up.
    pub enable_0rtt: bool,
    /// Accept any server certificate without verification. Development/testing only.
    pub dangerous_skip_cert_verification: bool,
    /// QUIC keep-alive interval in milliseconds; `0` disables keep-alives.
    pub keep_alive_interval_ms: u64,
    /// Maximum idle timeout in milliseconds, handed to quinn's transport config.
    pub idle_timeout_ms: u64,
    /// Maximum time in milliseconds to wait for a connection to be established.
    pub connect_timeout_ms: u64,
}

impl Default for RuntaraClientConfig {
    /// Local-development defaults.
    ///
    /// Connects to `127.0.0.1:8001` as `localhost` with certificate
    /// verification enabled. Uses a 10 s keep-alive, a 30 s idle timeout and
    /// a 10 s connect timeout.
    fn default() -> Self {
        const SECOND_MS: u64 = 1_000;

        Self {
            // Built from parts so no string parsing or unwrap is needed.
            server_addr: SocketAddr::from(([127, 0, 0, 1], 8001)),
            server_name: String::from("localhost"),
            enable_0rtt: true,
            dangerous_skip_cert_verification: false,
            keep_alive_interval_ms: 10 * SECOND_MS,
            idle_timeout_ms: 30 * SECOND_MS,
            connect_timeout_ms: 10 * SECOND_MS,
        }
    }
}
82
83pub struct RuntaraClient {
85 endpoint: Endpoint,
86 connection: Mutex<Option<Connection>>,
87 config: RuntaraClientConfig,
88}
89
90impl RuntaraClient {
91 pub fn new(config: RuntaraClientConfig) -> Result<Self, ClientError> {
93 let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
94
95 let client_config = Self::build_client_config(&config)?;
96 endpoint.set_default_client_config(client_config);
97
98 Ok(Self {
99 endpoint,
100 connection: Mutex::new(None),
101 config,
102 })
103 }
104
105 pub fn localhost() -> Result<Self, ClientError> {
107 Self::new(RuntaraClientConfig {
108 dangerous_skip_cert_verification: true,
109 ..Default::default()
110 })
111 }
112
113 fn build_client_config(config: &RuntaraClientConfig) -> Result<ClientConfig, ClientError> {
114 let crypto = if config.dangerous_skip_cert_verification {
115 rustls::ClientConfig::builder()
116 .dangerous()
117 .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
118 .with_no_client_auth()
119 } else {
120 let mut roots = rustls::RootCertStore::empty();
121 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
122 rustls::ClientConfig::builder()
123 .with_root_certificates(roots)
124 .with_no_client_auth()
125 };
126
127 let mut transport = TransportConfig::default();
128 if config.keep_alive_interval_ms > 0 {
129 transport.keep_alive_interval(Some(std::time::Duration::from_millis(
130 config.keep_alive_interval_ms,
131 )));
132 }
133 transport.max_idle_timeout(Some(
134 std::time::Duration::from_millis(config.idle_timeout_ms)
135 .try_into()
136 .unwrap(),
137 ));
138
139 let mut client_config = ClientConfig::new(Arc::new(
140 quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
141 ));
142 client_config.transport_config(Arc::new(transport));
143
144 Ok(client_config)
145 }
146
147 #[instrument(skip(self))]
149 pub async fn connect(&self) -> Result<(), ClientError> {
150 let mut conn_guard = self.connection.lock().await;
151
152 if let Some(ref conn) = *conn_guard
154 && conn.close_reason().is_none()
155 {
156 debug!("reusing existing connection");
157 return Ok(());
158 }
159
160 info!(addr = %self.config.server_addr, "connecting to runtara-core");
161
162 let timeout = Duration::from_millis(self.config.connect_timeout_ms);
163 let connecting = self
164 .endpoint
165 .connect(self.config.server_addr, &self.config.server_name)?;
166
167 let connection = tokio::time::timeout(timeout, connecting)
168 .await
169 .map_err(|_| ClientError::Timeout(self.config.connect_timeout_ms))??;
170
171 info!("connected to runtara-core");
172 *conn_guard = Some(connection);
173 Ok(())
174 }
175
176 async fn get_connection(&self) -> Result<Connection, ClientError> {
178 self.connect().await?;
179 let conn_guard = self.connection.lock().await;
180 conn_guard.clone().ok_or(ClientError::NotConnected)
181 }
182
183 pub async fn open_stream(
185 &self,
186 ) -> Result<FramedStream<(quinn::SendStream, quinn::RecvStream)>, ClientError> {
187 let conn = self.get_connection().await?;
188 let (send, recv) = conn.open_bi().await?;
189 Ok(FramedStream::new((send, recv)))
190 }
191
192 pub async fn open_uni_send(&self) -> Result<FramedStream<quinn::SendStream>, ClientError> {
194 let conn = self.get_connection().await?;
195 let send = conn.open_uni().await?;
196 Ok(FramedStream::new(send))
197 }
198
199 #[instrument(skip(self, request))]
201 pub async fn request<Req: prost::Message, Resp: prost::Message + Default>(
202 &self,
203 request: &Req,
204 ) -> Result<Resp, ClientError> {
205 let conn = self.get_connection().await?;
206 let (mut send, mut recv) = conn.open_bi().await?;
207
208 let frame = Frame::request(request)?;
210 crate::frame::write_frame(&mut send, &frame).await?;
211 send.finish()?;
212
213 let response_frame = crate::frame::read_frame(&mut recv).await?;
215 Ok(response_frame.decode()?)
216 }
217
218 #[instrument(skip(self, request))]
222 pub async fn send_fire_and_forget<Req: prost::Message>(
223 &self,
224 request: &Req,
225 ) -> Result<(), ClientError> {
226 let conn = self.get_connection().await?;
227 let (mut send, _recv) = conn.open_bi().await?;
228
229 let frame = Frame::request(request)?;
231 crate::frame::write_frame(&mut send, &frame).await?;
232 send.finish()?;
233
234 Ok(())
236 }
237
238 pub async fn open_raw_stream(
243 &self,
244 ) -> Result<(quinn::SendStream, quinn::RecvStream), ClientError> {
245 let conn = self.get_connection().await?;
246 Ok(conn.open_bi().await?)
247 }
248
249 pub async fn close(&self) {
251 let mut conn_guard = self.connection.lock().await;
252 if let Some(conn) = conn_guard.take() {
253 conn.close(0u32.into(), b"client closing");
254 }
255 }
256
257 pub async fn is_connected(&self) -> bool {
259 let conn_guard = self.connection.lock().await;
260 if let Some(ref conn) = *conn_guard {
261 conn.close_reason().is_none()
262 } else {
263 false
264 }
265 }
266}
267
268impl Drop for RuntaraClient {
269 fn drop(&mut self) {
270 if let Ok(mut guard) = self.connection.try_lock()
272 && let Some(conn) = guard.take()
273 {
274 conn.close(0u32.into(), b"client dropped");
275 }
276 }
277}
278
279#[derive(Debug)]
281struct SkipServerVerification;
282
283impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
284 fn verify_server_cert(
285 &self,
286 _end_entity: &rustls::pki_types::CertificateDer<'_>,
287 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
288 _server_name: &rustls::pki_types::ServerName<'_>,
289 _ocsp_response: &[u8],
290 _now: rustls::pki_types::UnixTime,
291 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
292 Ok(rustls::client::danger::ServerCertVerified::assertion())
293 }
294
295 fn verify_tls12_signature(
296 &self,
297 _message: &[u8],
298 _cert: &rustls::pki_types::CertificateDer<'_>,
299 _dss: &rustls::DigitallySignedStruct,
300 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
301 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
302 }
303
304 fn verify_tls13_signature(
305 &self,
306 _message: &[u8],
307 _cert: &rustls::pki_types::CertificateDer<'_>,
308 _dss: &rustls::DigitallySignedStruct,
309 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
311 }
312
313 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
314 vec![
315 rustls::SignatureScheme::RSA_PKCS1_SHA256,
316 rustls::SignatureScheme::RSA_PKCS1_SHA384,
317 rustls::SignatureScheme::RSA_PKCS1_SHA512,
318 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
319 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
320 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
321 rustls::SignatureScheme::RSA_PSS_SHA256,
322 rustls::SignatureScheme::RSA_PSS_SHA384,
323 rustls::SignatureScheme::RSA_PSS_SHA512,
324 rustls::SignatureScheme::ED25519,
325 ]
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the local runtara-core instance.
    #[test]
    fn test_default_config() {
        let defaults = RuntaraClientConfig::default();
        let expected_addr: std::net::SocketAddr = "127.0.0.1:8001".parse().unwrap();

        assert_eq!(defaults.server_addr, expected_addr);
        assert_eq!(defaults.server_name, "localhost");
    }

    /// Building a client succeeds without contacting any server.
    ///
    /// Runs under `tokio::test`, presumably because creating the quinn
    /// endpoint needs an async runtime.
    #[tokio::test]
    async fn test_client_creation() {
        let config = RuntaraClientConfig {
            dangerous_skip_cert_verification: true,
            ..RuntaraClientConfig::default()
        };

        let outcome = RuntaraClient::new(config);
        assert!(
            outcome.is_ok(),
            "Failed to create client: {:?}",
            outcome.err()
        );
    }
}