runtara_protocol/
server.rs

1// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//! QUIC server helpers for runtara-core.
4
5use std::net::SocketAddr;
6use std::sync::Arc;
7
8use quinn::{Endpoint, Incoming, RecvStream, SendStream, ServerConfig, TransportConfig};
9use thiserror::Error;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::frame::{Frame, FrameError, FramedStream, read_frame, write_frame};
13
14/// Errors that can occur in the QUIC server
15#[derive(Debug, Error)]
16pub enum ServerError {
17    #[error("bind error: {0}")]
18    Bind(#[from] std::io::Error),
19
20    #[error("connection error: {0}")]
21    Connection(#[from] quinn::ConnectionError),
22
23    #[error("frame error: {0}")]
24    Frame(#[from] FrameError),
25
26    #[error("TLS error: {0}")]
27    Tls(String),
28
29    #[error("server closed")]
30    Closed,
31}
32
/// Settings used to construct the QUIC server.
#[derive(Clone, Debug)]
pub struct RuntaraServerConfig {
    /// Socket address the endpoint binds to.
    pub bind_addr: SocketAddr,
    /// PEM-encoded TLS certificate chain.
    pub cert_pem: Vec<u8>,
    /// PEM-encoded TLS private key.
    pub key_pem: Vec<u8>,
    /// Maximum concurrent connections.
    // NOTE(review): nothing visible in this file reads this field when the quinn
    // configuration is built — confirm it is enforced elsewhere.
    pub max_connections: u32,
    /// Maximum concurrent bidirectional streams per connection.
    pub max_bi_streams: u32,
    /// Maximum concurrent unidirectional streams per connection.
    pub max_uni_streams: u32,
    /// Idle timeout, in milliseconds.
    pub idle_timeout_ms: u64,
}

impl Default for RuntaraServerConfig {
    /// Defaults: bind to `0.0.0.0:7001`, no TLS material, 10 000 connections,
    /// 100 streams of each kind per connection, and a 30 second idle timeout.
    fn default() -> Self {
        Self {
            // Built from parts, so no string parsing (and no `unwrap`) is needed.
            bind_addr: SocketAddr::from(([0, 0, 0, 0], 7001)),
            cert_pem: vec![],
            key_pem: vec![],
            max_connections: 10_000,
            max_bi_streams: 100,
            max_uni_streams: 100,
            idle_timeout_ms: 30_000,
        }
    }
}
65
66/// QUIC server for runtara-core
67pub struct RuntaraServer {
68    endpoint: Endpoint,
69}
70
71impl RuntaraServer {
72    /// Create a new server with the given configuration
73    pub fn new(config: RuntaraServerConfig) -> Result<Self, ServerError> {
74        let server_config = Self::build_server_config(&config)?;
75        let endpoint = Endpoint::server(server_config, config.bind_addr)?;
76
77        info!(addr = %config.bind_addr, "QUIC server bound");
78
79        Ok(Self { endpoint })
80    }
81
82    /// Create a server with self-signed certificate for local development
83    pub fn localhost(bind_addr: SocketAddr) -> Result<Self, ServerError> {
84        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
85            .map_err(|e| ServerError::Tls(e.to_string()))?;
86
87        let cert_pem = cert.cert.pem().into_bytes();
88        let key_pem = cert.key_pair.serialize_pem().into_bytes();
89
90        let config = RuntaraServerConfig {
91            bind_addr,
92            cert_pem,
93            key_pem,
94            ..Default::default()
95        };
96
97        Self::new(config)
98    }
99
100    fn build_server_config(config: &RuntaraServerConfig) -> Result<ServerConfig, ServerError> {
101        let certs = rustls_pemfile::certs(&mut config.cert_pem.as_slice())
102            .collect::<Result<Vec<_>, _>>()
103            .map_err(|e| ServerError::Tls(format!("failed to parse certificates: {}", e)))?;
104
105        let key = rustls_pemfile::private_key(&mut config.key_pem.as_slice())
106            .map_err(|e| ServerError::Tls(format!("failed to parse private key: {}", e)))?
107            .ok_or_else(|| ServerError::Tls("no private key found".to_string()))?;
108
109        let crypto = rustls::ServerConfig::builder()
110            .with_no_client_auth()
111            .with_single_cert(certs, key)
112            .map_err(|e| ServerError::Tls(e.to_string()))?;
113
114        let mut transport = TransportConfig::default();
115        transport.max_idle_timeout(Some(
116            std::time::Duration::from_millis(config.idle_timeout_ms)
117                .try_into()
118                .unwrap(),
119        ));
120        transport.max_concurrent_bidi_streams(config.max_bi_streams.into());
121        transport.max_concurrent_uni_streams(config.max_uni_streams.into());
122
123        let mut server_config = ServerConfig::with_crypto(Arc::new(
124            quinn::crypto::rustls::QuicServerConfig::try_from(crypto)
125                .map_err(|e| ServerError::Tls(e.to_string()))?,
126        ));
127        server_config.transport_config(Arc::new(transport));
128
129        Ok(server_config)
130    }
131
132    /// Accept the next incoming connection
133    pub async fn accept(&self) -> Option<Incoming> {
134        self.endpoint.accept().await
135    }
136
137    /// Get the local address the server is bound to
138    pub fn local_addr(&self) -> Result<SocketAddr, ServerError> {
139        Ok(self.endpoint.local_addr()?)
140    }
141
142    /// Close the server
143    pub fn close(&self) {
144        self.endpoint.close(0u32.into(), b"server closing");
145    }
146
147    /// Run the server with a connection handler
148    #[instrument(skip(self, handler))]
149    pub async fn run<H, Fut>(&self, handler: H) -> Result<(), ServerError>
150    where
151        H: Fn(ConnectionHandler) -> Fut + Send + Sync + Clone + 'static,
152        Fut: std::future::Future<Output = ()> + Send + 'static,
153    {
154        info!("QUIC server running");
155
156        while let Some(incoming) = self.accept().await {
157            let handler = handler.clone();
158
159            tokio::spawn(async move {
160                match incoming.await {
161                    Ok(connection) => {
162                        let remote_addr = connection.remote_address();
163                        debug!(%remote_addr, "accepted connection");
164
165                        let conn_handler = ConnectionHandler::new(connection);
166                        handler(conn_handler).await;
167                    }
168                    Err(e) => {
169                        warn!("failed to accept connection: {}", e);
170                    }
171                }
172            });
173        }
174
175        Ok(())
176    }
177}
178
179/// Handler for an individual QUIC connection
180pub struct ConnectionHandler {
181    connection: quinn::Connection,
182}
183
184impl ConnectionHandler {
185    pub fn new(connection: quinn::Connection) -> Self {
186        Self { connection }
187    }
188
189    /// Get the remote address of the connection
190    pub fn remote_address(&self) -> SocketAddr {
191        self.connection.remote_address()
192    }
193
194    /// Accept the next bidirectional stream
195    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ServerError> {
196        Ok(self.connection.accept_bi().await?)
197    }
198
199    /// Accept the next unidirectional stream (for receiving)
200    pub async fn accept_uni(&self) -> Result<RecvStream, ServerError> {
201        Ok(self.connection.accept_uni().await?)
202    }
203
204    /// Open a bidirectional stream
205    pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ServerError> {
206        Ok(self.connection.open_bi().await?)
207    }
208
209    /// Open a unidirectional stream (for sending)
210    pub async fn open_uni(&self) -> Result<SendStream, ServerError> {
211        Ok(self.connection.open_uni().await?)
212    }
213
214    /// Run the connection handler with a stream handler
215    #[instrument(skip(self, handler), fields(remote = %self.remote_address()))]
216    pub async fn run<H, Fut>(&self, handler: H)
217    where
218        H: Fn(StreamHandler) -> Fut + Send + Sync + Clone + 'static,
219        Fut: std::future::Future<Output = ()> + Send + 'static,
220    {
221        loop {
222            tokio::select! {
223                result = self.accept_bi() => {
224                    match result {
225                        Ok((send, recv)) => {
226                            let handler = handler.clone();
227                            tokio::spawn(async move {
228                                let stream_handler = StreamHandler::new(send, recv);
229                                handler(stream_handler).await;
230                            });
231                        }
232                        Err(e) => {
233                            match &e {
234                                ServerError::Connection(quinn::ConnectionError::ApplicationClosed(_)) |
235                                ServerError::Connection(quinn::ConnectionError::LocallyClosed) => {
236                                    debug!("connection closed");
237                                }
238                                _ => {
239                                    error!("error accepting stream: {}", e);
240                                }
241                            }
242                            break;
243                        }
244                    }
245                }
246            }
247        }
248    }
249
250    /// Check if the connection is still open
251    pub fn is_open(&self) -> bool {
252        self.connection.close_reason().is_none()
253    }
254
255    /// Close the connection
256    pub fn close(&self, code: u32, reason: &[u8]) {
257        self.connection.close(code.into(), reason);
258    }
259}
260
261/// Handler for an individual QUIC stream (bidirectional)
262pub struct StreamHandler {
263    send: SendStream,
264    recv: RecvStream,
265}
266
267impl StreamHandler {
268    pub fn new(send: SendStream, recv: RecvStream) -> Self {
269        Self { send, recv }
270    }
271
272    /// Read the next frame from the stream
273    pub async fn read_frame(&mut self) -> Result<Frame, ServerError> {
274        Ok(read_frame(&mut self.recv).await?)
275    }
276
277    /// Write a frame to the stream
278    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ServerError> {
279        Ok(write_frame(&mut self.send, frame).await?)
280    }
281
282    /// Handle a request/response pattern
283    pub async fn handle_request<Req, Resp, H, Fut>(&mut self, handler: H) -> Result<(), ServerError>
284    where
285        Req: prost::Message + Default,
286        Resp: prost::Message,
287        H: FnOnce(Req) -> Fut,
288        Fut: std::future::Future<Output = Result<Resp, ServerError>>,
289    {
290        // Read request
291        let request_frame = self.read_frame().await?;
292        let request: Req = request_frame.decode()?;
293
294        // Process and respond
295        match handler(request).await {
296            Ok(response) => {
297                let response_frame = Frame::response(&response)?;
298                self.write_frame(&response_frame).await?;
299            }
300            Err(e) => {
301                error!("request handler error: {}", e);
302                // Send error frame with empty payload
303                // The frame type itself indicates an error
304                let error_frame = Frame {
305                    message_type: crate::frame::MessageType::Error,
306                    payload: bytes::Bytes::new(),
307                };
308                self.write_frame(&error_frame).await?;
309            }
310        }
311
312        Ok(())
313    }
314
315    /// Convert to a FramedStream for more complex patterns
316    pub fn into_framed(self) -> FramedStream<(SendStream, RecvStream)> {
317        FramedStream::new((self.send, self.recv))
318    }
319
320    /// Finish the send stream (signal no more data)
321    pub fn finish(&mut self) -> Result<(), ServerError> {
322        self.send
323            .finish()
324            .map_err(|e| ServerError::Frame(FrameError::Io(std::io::Error::other(e))))?;
325        Ok(())
326    }
327
328    /// Read raw bytes from the stream (for streaming uploads)
329    /// Returns the number of bytes read, or 0 if EOF
330    pub async fn read_bytes(&mut self, buf: &mut [u8]) -> Result<usize, ServerError> {
331        match self.recv.read(buf).await {
332            Ok(Some(n)) => Ok(n),
333            Ok(None) => Ok(0), // EOF
334            Err(e) => Err(ServerError::Frame(FrameError::Io(std::io::Error::other(
335                e.to_string(),
336            )))),
337        }
338    }
339
340    /// Read exact number of bytes from the stream
341    pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ServerError> {
342        self.recv.read_exact(buf).await.map_err(|e| {
343            ServerError::Frame(FrameError::Io(std::io::Error::other(e.to_string())))
344        })?;
345        Ok(())
346    }
347
348    /// Read all remaining bytes from the stream until EOF (with size limit)
349    pub async fn read_to_end(&mut self, size_limit: usize) -> Result<Vec<u8>, ServerError> {
350        self.recv
351            .read_to_end(size_limit)
352            .await
353            .map_err(|e| ServerError::Frame(FrameError::Io(std::io::Error::other(e.to_string()))))
354    }
355
356    /// Stream bytes to a writer (for large uploads without buffering all in memory)
357    pub async fn stream_to_writer<W: tokio::io::AsyncWrite + Unpin>(
358        &mut self,
359        writer: &mut W,
360        expected_size: Option<u64>,
361    ) -> Result<u64, ServerError> {
362        use tokio::io::AsyncWriteExt;
363
364        let mut total = 0u64;
365        let mut buf = [0u8; 64 * 1024]; // 64KB chunks
366
367        loop {
368            let n = match self.recv.read(&mut buf).await {
369                Ok(Some(n)) => n,
370                Ok(None) => 0, // EOF
371                Err(e) => {
372                    return Err(ServerError::Frame(FrameError::Io(std::io::Error::other(
373                        e.to_string(),
374                    ))));
375                }
376            };
377            if n == 0 {
378                break;
379            }
380            writer.write_all(&buf[..n]).await?;
381            total += n as u64;
382        }
383
384        if let Some(expected) = expected_size
385            && total != expected
386        {
387            return Err(ServerError::Frame(FrameError::Io(std::io::Error::new(
388                std::io::ErrorKind::UnexpectedEof,
389                format!("Expected {} bytes, got {}", expected, total),
390            ))));
391        }
392
393        Ok(total)
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration binds to `0.0.0.0:7001` and allows 10 000 connections.
    #[test]
    fn test_default_config() {
        let RuntaraServerConfig {
            bind_addr,
            max_connections,
            ..
        } = RuntaraServerConfig::default();

        let expected_addr: std::net::SocketAddr = "0.0.0.0:7001".parse().unwrap();
        assert_eq!(bind_addr, expected_addr);
        assert_eq!(max_connections, 10_000);
    }
}