sierradb_server/
server.rs

1use std::collections::HashMap;
2
3use kameo::actor::ActorRef;
4use libp2p::bytes::BytesMut;
5use redis_protocol::resp3;
6use redis_protocol::resp3::decode::complete::decode_bytes_mut;
7use redis_protocol::resp3::types::BytesFrame;
8use sierradb::bucket::segment::EventRecord;
9use sierradb_cluster::ClusterActor;
10use sierradb_cluster::subscription::SubscriptionEvent;
11use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
13use tokio::sync::{mpsc, watch};
14use tracing::{debug, warn};
15use uuid::Uuid;
16
17use crate::request::{Command, encode_event, number, simple_str};
18
19pub struct Server {
20    cluster_ref: ActorRef<ClusterActor>,
21    num_partitions: u16,
22}
23
24impl Server {
25    pub fn new(cluster_ref: ActorRef<ClusterActor>, num_partitions: u16) -> Self {
26        Server {
27            cluster_ref,
28            num_partitions,
29        }
30    }
31
32    pub async fn listen(&mut self, addr: impl ToSocketAddrs) -> io::Result<()> {
33        let listener = TcpListener::bind(addr).await?;
34        loop {
35            match listener.accept().await {
36                Ok((socket, _)) => {
37                    let cluster_ref = self.cluster_ref.clone();
38                    let num_partitions = self.num_partitions;
39                    tokio::spawn(async move {
40                        let res = Conn::new(cluster_ref, num_partitions, socket).run().await;
41                        if let Err(err) = res {
42                            warn!("connection error: {err}");
43                        }
44                    });
45                }
46                Err(err) => warn!("failed to accept connection: {err}"),
47            }
48        }
49    }
50}
51
52pub struct Conn {
53    pub cluster_ref: ActorRef<ClusterActor>,
54    pub num_partitions: u16,
55    pub socket: TcpStream,
56    pub read: BytesMut,
57    pub write: BytesMut,
58    pub subscription_channel: Option<(
59        mpsc::WeakUnboundedSender<SubscriptionEvent>,
60        mpsc::UnboundedReceiver<SubscriptionEvent>,
61    )>,
62    pub subscriptions: HashMap<Uuid, watch::Sender<Option<u64>>>,
63}
64
65impl Conn {
66    fn new(cluster_ref: ActorRef<ClusterActor>, num_partitions: u16, socket: TcpStream) -> Self {
67        let read = BytesMut::new();
68        let write = BytesMut::new();
69
70        Conn {
71            cluster_ref,
72            socket,
73            read,
74            write,
75            num_partitions,
76            subscription_channel: None,
77            subscriptions: HashMap::new(),
78        }
79    }
80
81    async fn run(mut self) -> io::Result<()> {
82        loop {
83            match &mut self.subscription_channel {
84                Some((_, rx)) => {
85                    tokio::select! {
86                        res = self.socket.read_buf(&mut self.read) => {
87                            match res {
88                                Ok(bytes_read) => {
89                                    if bytes_read == 0 && self.read.is_empty() {
90                                        // Clean up subscriptions on disconnect
91                                        self.cleanup_subscriptions();
92                                        return Ok(());
93                                    }
94
95                                    // Try to decode and handle requests
96                                    while let Some((frame, _, _)) =
97                                        decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
98                                    {
99                                        let response = self.handle_request(frame).await?;
100                                        if let Some(resp) = response {
101                                            resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
102                                                .map_err(io::Error::other)?;
103
104                                            self.socket.write_all(&self.write).await?;
105                                            self.socket.flush().await?;
106                                            self.write.clear();
107                                        }
108                                    }
109                                }
110                                Err(err) => return Err(err),
111                            }
112                        }
113                        msg = rx.recv() => {
114                            match msg {
115                                Some(SubscriptionEvent::Record { subscription_id, cursor, record }) => self.send_subscription_event(subscription_id, cursor, record).await?,
116                                Some(SubscriptionEvent::Error { subscription_id, error }) => {
117                                    warn!(%subscription_id, "subscription error: {error}");
118                                }
119                                Some(SubscriptionEvent::Closed { subscription_id }) => {
120                                    debug!(
121                                        subscription_id = %subscription_id,
122                                        "closed subscription"
123                                    );
124                                    self.subscriptions.remove(&subscription_id);
125                                    if self.subscriptions.is_empty() {
126                                        self.cleanup_subscriptions();
127                                    }
128                                }
129                                None => self.cleanup_subscriptions(),
130                            }
131                        }
132                    }
133                }
134                None => {
135                    // Not in subscription mode - block normally on socket reads
136                    let bytes_read = self.socket.read_buf(&mut self.read).await?;
137                    if bytes_read == 0 && self.read.is_empty() {
138                        return Ok(());
139                    }
140
141                    // Try to decode and handle requests
142                    while let Some((frame, _, _)) =
143                        decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
144                    {
145                        let response = self.handle_request(frame).await?;
146                        if let Some(resp) = response {
147                            resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
148                                .map_err(io::Error::other)?;
149
150                            self.socket.write_all(&self.write).await?;
151                            self.socket.flush().await?;
152                            self.write.clear();
153                        }
154                    }
155                }
156            }
157        }
158    }
159
160    fn cleanup_subscriptions(&mut self) {
161        self.subscriptions.clear();
162        self.subscription_channel = None;
163    }
164
165    async fn send_subscription_event(
166        &mut self,
167        subscription_id: Uuid,
168        cursor: u64,
169        record: EventRecord,
170    ) -> io::Result<()> {
171        resp3::encode::complete::extend_encode(
172            &mut self.write,
173            &BytesFrame::Push {
174                data: vec![
175                    simple_str("message"),
176                    simple_str(subscription_id.to_string()),
177                    number(cursor as i64),
178                    encode_event(record),
179                ],
180                attributes: None,
181            },
182            false,
183        )
184        .map_err(io::Error::other)?;
185
186        self.socket.write_all(&self.write).await?;
187        self.socket.flush().await?;
188        self.write.clear();
189
190        Ok(())
191    }
192
193    async fn handle_request(&mut self, frame: BytesFrame) -> Result<Option<BytesFrame>, io::Error> {
194        match frame {
195            BytesFrame::Array { data, .. } => {
196                if data.is_empty() {
197                    return Ok(Some(BytesFrame::SimpleError {
198                        data: "empty command".into(),
199                        attributes: None,
200                    }));
201                }
202
203                let cmd = match Command::try_from(&data[0]) {
204                    Ok(cmd) => cmd,
205                    Err(err) => {
206                        return Ok(Some(BytesFrame::SimpleError {
207                            data: err.into(),
208                            attributes: None,
209                        }));
210                    }
211                };
212                let args = &data[1..];
213                cmd.handle(args, self).await
214            }
215            _ => Ok(Some(BytesFrame::SimpleError {
216                data: "expected array command".into(),
217                attributes: None,
218            })),
219        }
220    }
221}