// Source file: sierradb_server/server.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use kameo::actor::ActorRef;
5use libp2p::bytes::BytesMut;
6use redis_protocol::resp3;
7use redis_protocol::resp3::decode::complete::decode_bytes_mut;
8use redis_protocol::resp3::types::BytesFrame;
9use sierradb::bucket::BucketId;
10use sierradb::bucket::segment::EventRecord;
11use sierradb::cache::SegmentBlockCache;
12use sierradb_cluster::ClusterActor;
13use sierradb_cluster::subscription::SubscriptionEvent;
14use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
15use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
16use tokio::sync::{mpsc, watch};
17use tokio::task::JoinSet;
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, warn};
20use uuid::Uuid;
21
22use crate::request::{Command, encode_event, number, simple_str};
23
24pub struct Server {
25    cluster_ref: ActorRef<ClusterActor>,
26    caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
27    num_partitions: u16,
28    cache_capacity_bytes: usize,
29    shutdown: CancellationToken,
30    conns: JoinSet<io::Result<()>>,
31}
32
33impl Server {
34    pub fn new(
35        cluster_ref: ActorRef<ClusterActor>,
36        caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
37        num_partitions: u16,
38        cache_capacity_bytes: usize,
39        shutdown: CancellationToken,
40    ) -> Self {
41        Server {
42            cluster_ref,
43            caches,
44            num_partitions,
45            cache_capacity_bytes,
46            shutdown,
47            conns: JoinSet::new(),
48        }
49    }
50
51    pub async fn listen(mut self, addr: impl ToSocketAddrs) -> io::Result<JoinSet<io::Result<()>>> {
52        let listener = TcpListener::bind(addr).await?;
53        loop {
54            tokio::select! {
55                res = listener.accept() => {
56                    match res {
57                        Ok((stream, _)) => {
58                            stream.set_nodelay(true)?;
59                            let cluster_ref = self.cluster_ref.clone();
60                            let caches = self.caches.clone();
61                            let num_partitions = self.num_partitions;
62                            let cache_capacity_bytes = self.cache_capacity_bytes;
63                            let shutdown = self.shutdown.clone();
64                            self.conns.spawn(async move {
65                                let res = Conn::new(
66                                    cluster_ref,
67                                    caches,
68                                    num_partitions,
69                                    cache_capacity_bytes,
70                                    stream,
71                                    shutdown,
72                                )
73                                .run()
74                                .await;
75                                if let Err(err) = &res {
76                                    warn!("connection error: {err}");
77                                }
78                                res
79                            });
80                        }
81                        Err(err) => warn!("failed to accept connection: {err}"),
82                    }
83                }
84                _ = self.shutdown.cancelled() => {
85                    return Ok(self.conns);
86                }
87            }
88        }
89    }
90}
91
92pub struct Conn {
93    pub cluster_ref: ActorRef<ClusterActor>,
94    pub caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
95    pub num_partitions: u16,
96    pub cache_capacity_bytes: usize,
97    pub stream: TcpStream,
98    pub shutdown: CancellationToken,
99    pub read: BytesMut,
100    pub write: BytesMut,
101    pub subscription_channel: Option<(
102        mpsc::WeakUnboundedSender<SubscriptionEvent>,
103        mpsc::UnboundedReceiver<SubscriptionEvent>,
104    )>,
105    pub subscriptions: HashMap<Uuid, watch::Sender<Option<u64>>>,
106}
107
108impl Conn {
109    fn new(
110        cluster_ref: ActorRef<ClusterActor>,
111        caches: Arc<HashMap<BucketId, Arc<SegmentBlockCache>>>,
112        num_partitions: u16,
113        cache_capacity_bytes: usize,
114        stream: TcpStream,
115        shutdown: CancellationToken,
116    ) -> Self {
117        let read = BytesMut::new();
118        let write = BytesMut::new();
119
120        Conn {
121            cluster_ref,
122            caches,
123            num_partitions,
124            cache_capacity_bytes,
125            stream,
126            shutdown,
127            read,
128            write,
129            subscription_channel: None,
130            subscriptions: HashMap::new(),
131        }
132    }
133
134    async fn run(mut self) -> io::Result<()> {
135        loop {
136            match &mut self.subscription_channel {
137                Some((_, rx)) => {
138                    tokio::select! {
139                        res = self.stream.read_buf(&mut self.read) => {
140                            match res {
141                                Ok(bytes_read) => {
142                                    if bytes_read == 0 && self.read.is_empty() {
143                                        // Clean up subscriptions on disconnect
144                                        self.cleanup_subscriptions();
145                                        return Ok(());
146                                    }
147
148                                    // Try to decode and handle requests
149                                    while let Some((frame, _, _)) =
150                                        decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
151                                    {
152                                        let response = self.handle_request(frame).await?;
153                                        if let Some(resp) = response {
154                                            resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
155                                                .map_err(io::Error::other)?;
156
157                                            self.stream.write_all(&self.write).await?;
158                                            self.stream.flush().await?;
159                                            self.write.clear();
160                                        }
161                                    }
162                                }
163                                Err(err) => return Err(err),
164                            }
165                        }
166                        msg = rx.recv() => {
167                            match msg {
168                                Some(SubscriptionEvent::Record { subscription_id, cursor, record }) => self.send_subscription_event(subscription_id, cursor, record).await?,
169                                Some(SubscriptionEvent::Error { subscription_id, error }) => {
170                                    warn!(%subscription_id, "subscription error: {error}");
171                                }
172                                Some(SubscriptionEvent::Closed { subscription_id }) => {
173                                    debug!(
174                                        subscription_id = %subscription_id,
175                                        "closed subscription"
176                                    );
177                                    self.subscriptions.remove(&subscription_id);
178                                    if self.subscriptions.is_empty() {
179                                        self.cleanup_subscriptions();
180                                    }
181                                }
182                                None => self.cleanup_subscriptions(),
183                            }
184                        }
185                        _ = self.shutdown.cancelled() => {
186                            rx.close();
187                            return self.stream.shutdown().await;
188                        }
189                    }
190                }
191                None => {
192                    tokio::select! {
193                        res = self.stream.read_buf(&mut self.read) => {
194                            // Not in subscription mode - block normally on socket reads
195                            let bytes_read = res?;
196                            if bytes_read == 0 && self.read.is_empty() {
197                                return Ok(());
198                            }
199
200                            // Try to decode and handle requests
201                            while let Some((frame, _, _)) =
202                                decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
203                            {
204                                let response = self.handle_request(frame).await?;
205                                if let Some(resp) = response {
206                                    resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
207                                        .map_err(io::Error::other)?;
208
209                                    self.stream.write_all(&self.write).await?;
210                                    self.stream.flush().await?;
211                                    self.write.clear();
212                                }
213                            }
214                        }
215                        _ = self.shutdown.cancelled() => {
216                            return self.stream.shutdown().await;
217                        }
218                    }
219                }
220            }
221        }
222    }
223
224    fn cleanup_subscriptions(&mut self) {
225        self.subscriptions.clear();
226        self.subscription_channel = None;
227    }
228
229    async fn send_subscription_event(
230        &mut self,
231        subscription_id: Uuid,
232        cursor: u64,
233        record: EventRecord,
234    ) -> io::Result<()> {
235        resp3::encode::complete::extend_encode(
236            &mut self.write,
237            &BytesFrame::Push {
238                data: vec![
239                    simple_str("message"),
240                    simple_str(subscription_id.to_string()),
241                    number(cursor as i64),
242                    encode_event(record),
243                ],
244                attributes: None,
245            },
246            false,
247        )
248        .map_err(io::Error::other)?;
249
250        self.stream.write_all(&self.write).await?;
251        self.stream.flush().await?;
252        self.write.clear();
253
254        Ok(())
255    }
256
257    async fn handle_request(&mut self, frame: BytesFrame) -> Result<Option<BytesFrame>, io::Error> {
258        match frame {
259            BytesFrame::Array { data, .. } => {
260                if data.is_empty() {
261                    return Ok(Some(BytesFrame::SimpleError {
262                        data: "empty command".into(),
263                        attributes: None,
264                    }));
265                }
266
267                let cmd = match Command::try_from(&data[0]) {
268                    Ok(cmd) => cmd,
269                    Err(err) => {
270                        return Ok(Some(BytesFrame::SimpleError {
271                            data: err.into(),
272                            attributes: None,
273                        }));
274                    }
275                };
276                let args = &data[1..];
277                cmd.handle(args, self).await
278            }
279            _ => Ok(Some(BytesFrame::SimpleError {
280                data: "expected array command".into(),
281                attributes: None,
282            })),
283        }
284    }
285}