sierradb_server/
server.rs1use std::collections::HashMap;
2
3use kameo::actor::ActorRef;
4use libp2p::bytes::BytesMut;
5use redis_protocol::resp3;
6use redis_protocol::resp3::decode::complete::decode_bytes_mut;
7use redis_protocol::resp3::types::BytesFrame;
8use sierradb::bucket::segment::EventRecord;
9use sierradb_cluster::ClusterActor;
10use sierradb_cluster::subscription::SubscriptionEvent;
11use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
13use tokio::sync::{mpsc, watch};
14use tracing::{debug, warn};
15use uuid::Uuid;
16
17use crate::request::{Command, encode_event, number, simple_str};
18
19pub struct Server {
20 cluster_ref: ActorRef<ClusterActor>,
21 num_partitions: u16,
22}
23
24impl Server {
25 pub fn new(cluster_ref: ActorRef<ClusterActor>, num_partitions: u16) -> Self {
26 Server {
27 cluster_ref,
28 num_partitions,
29 }
30 }
31
32 pub async fn listen(&mut self, addr: impl ToSocketAddrs) -> io::Result<()> {
33 let listener = TcpListener::bind(addr).await?;
34 loop {
35 match listener.accept().await {
36 Ok((socket, _)) => {
37 let cluster_ref = self.cluster_ref.clone();
38 let num_partitions = self.num_partitions;
39 tokio::spawn(async move {
40 let res = Conn::new(cluster_ref, num_partitions, socket).run().await;
41 if let Err(err) = res {
42 warn!("connection error: {err}");
43 }
44 });
45 }
46 Err(err) => warn!("failed to accept connection: {err}"),
47 }
48 }
49 }
50}
51
52pub struct Conn {
53 pub cluster_ref: ActorRef<ClusterActor>,
54 pub num_partitions: u16,
55 pub socket: TcpStream,
56 pub read: BytesMut,
57 pub write: BytesMut,
58 pub subscription_channel: Option<(
59 mpsc::WeakUnboundedSender<SubscriptionEvent>,
60 mpsc::UnboundedReceiver<SubscriptionEvent>,
61 )>,
62 pub subscriptions: HashMap<Uuid, watch::Sender<Option<u64>>>,
63}
64
65impl Conn {
66 fn new(cluster_ref: ActorRef<ClusterActor>, num_partitions: u16, socket: TcpStream) -> Self {
67 let read = BytesMut::new();
68 let write = BytesMut::new();
69
70 Conn {
71 cluster_ref,
72 socket,
73 read,
74 write,
75 num_partitions,
76 subscription_channel: None,
77 subscriptions: HashMap::new(),
78 }
79 }
80
81 async fn run(mut self) -> io::Result<()> {
82 loop {
83 match &mut self.subscription_channel {
84 Some((_, rx)) => {
85 tokio::select! {
86 res = self.socket.read_buf(&mut self.read) => {
87 match res {
88 Ok(bytes_read) => {
89 if bytes_read == 0 && self.read.is_empty() {
90 self.cleanup_subscriptions();
92 return Ok(());
93 }
94
95 while let Some((frame, _, _)) =
97 decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
98 {
99 let response = self.handle_request(frame).await?;
100 if let Some(resp) = response {
101 resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
102 .map_err(io::Error::other)?;
103
104 self.socket.write_all(&self.write).await?;
105 self.socket.flush().await?;
106 self.write.clear();
107 }
108 }
109 }
110 Err(err) => return Err(err),
111 }
112 }
113 msg = rx.recv() => {
114 match msg {
115 Some(SubscriptionEvent::Record { subscription_id, cursor, record }) => self.send_subscription_event(subscription_id, cursor, record).await?,
116 Some(SubscriptionEvent::Error { subscription_id, error }) => {
117 warn!(%subscription_id, "subscription error: {error}");
118 }
119 Some(SubscriptionEvent::Closed { subscription_id }) => {
120 debug!(
121 subscription_id = %subscription_id,
122 "closed subscription"
123 );
124 self.subscriptions.remove(&subscription_id);
125 if self.subscriptions.is_empty() {
126 self.cleanup_subscriptions();
127 }
128 }
129 None => self.cleanup_subscriptions(),
130 }
131 }
132 }
133 }
134 None => {
135 let bytes_read = self.socket.read_buf(&mut self.read).await?;
137 if bytes_read == 0 && self.read.is_empty() {
138 return Ok(());
139 }
140
141 while let Some((frame, _, _)) =
143 decode_bytes_mut(&mut self.read).map_err(io::Error::other)?
144 {
145 let response = self.handle_request(frame).await?;
146 if let Some(resp) = response {
147 resp3::encode::complete::extend_encode(&mut self.write, &resp, false)
148 .map_err(io::Error::other)?;
149
150 self.socket.write_all(&self.write).await?;
151 self.socket.flush().await?;
152 self.write.clear();
153 }
154 }
155 }
156 }
157 }
158 }
159
160 fn cleanup_subscriptions(&mut self) {
161 self.subscriptions.clear();
162 self.subscription_channel = None;
163 }
164
165 async fn send_subscription_event(
166 &mut self,
167 subscription_id: Uuid,
168 cursor: u64,
169 record: EventRecord,
170 ) -> io::Result<()> {
171 resp3::encode::complete::extend_encode(
172 &mut self.write,
173 &BytesFrame::Push {
174 data: vec![
175 simple_str("message"),
176 simple_str(subscription_id.to_string()),
177 number(cursor as i64),
178 encode_event(record),
179 ],
180 attributes: None,
181 },
182 false,
183 )
184 .map_err(io::Error::other)?;
185
186 self.socket.write_all(&self.write).await?;
187 self.socket.flush().await?;
188 self.write.clear();
189
190 Ok(())
191 }
192
193 async fn handle_request(&mut self, frame: BytesFrame) -> Result<Option<BytesFrame>, io::Error> {
194 match frame {
195 BytesFrame::Array { data, .. } => {
196 if data.is_empty() {
197 return Ok(Some(BytesFrame::SimpleError {
198 data: "empty command".into(),
199 attributes: None,
200 }));
201 }
202
203 let cmd = match Command::try_from(&data[0]) {
204 Ok(cmd) => cmd,
205 Err(err) => {
206 return Ok(Some(BytesFrame::SimpleError {
207 data: err.into(),
208 attributes: None,
209 }));
210 }
211 };
212 let args = &data[1..];
213 cmd.handle(args, self).await
214 }
215 _ => Ok(Some(BytesFrame::SimpleError {
216 data: "expected array command".into(),
217 attributes: None,
218 })),
219 }
220 }
221}