// skeg_server/handler.rs

1use bytes::Bytes;
2use bytes::BytesMut;
3use skeg_proto::{
4    ErrCode, Flags, Frame, FrameParser, decode_key_payload, decode_mget_payload,
5    decode_set_payload, decode_vindex_create_payload, decode_vname_id_payload,
6    decode_vname_payload, decode_vsearch_payload, decode_vset_payload, encode_err, encode_ok,
7    encode_ok_bool, encode_ok_mget, encode_ok_shards, encode_ok_stats, encode_ok_value,
8    encode_ok_vindex_list, encode_ok_vsearch, f32_vec_to_bytes,
9};
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::TcpStream;
12use tracing::{debug, warn};
13
14use skeg_core::Durability;
15
16use crate::shard::{ShardError, ShardSet};
17
18/// Durability applied to writes that do not request one explicitly.
19/// `Kernel` survives a process crash and kernel panic without paying the
20/// `F_FULLFSYNC` cost on every write - the right default for AI workloads
21/// (see `design-write-perf.md`).
22const DEFAULT_DURABILITY: Durability = Durability::Kernel;
23
24pub async fn handle_connection(mut stream: TcpStream, shards: ShardSet) {
25    let peer = stream.peer_addr().ok();
26    debug!(?peer, "connection accepted");
27
28    let mut parser = FrameParser::new();
29    let mut buf = BytesMut::with_capacity(64 * 1024);
30
31    loop {
32        match parser.feed(&mut buf) {
33            Ok(Some(frame)) => {
34                if let Some(response) = dispatch(&frame, &shards).await
35                    && stream.write_all(&response).await.is_err()
36                {
37                    break;
38                }
39            }
40            Ok(None) => match stream.read_buf(&mut buf).await {
41                Ok(0) => break,
42                Ok(_) => {}
43                Err(e) => {
44                    warn!(?peer, "read error: {e}");
45                    break;
46                }
47            },
48            Err(e) => {
49                warn!(?peer, "protocol error: {e}");
50                break;
51            }
52        }
53    }
54
55    debug!(?peer, "connection closed");
56}
57
58fn shard_err_to_response(req_id: u64, e: &ShardError) -> Bytes {
59    warn!("shard error: {e}");
60    encode_err(req_id, ErrCode::Internal, &e.to_string())
61}
62
63/// Dispatch a parsed frame to the shard set and return an optional response.
64#[allow(clippy::too_many_lines)] // one arm per protocol op; splitting hurts readability
65async fn dispatch(frame: &Frame, shards: &ShardSet) -> Option<Bytes> {
66    let req_id = frame.header.req_id;
67    let payload = &frame.payload;
68
69    match frame.header.op {
70        skeg_proto::Op::Ping => Some(encode_ok(req_id)),
71
72        skeg_proto::Op::Stats => match shards.stats().await {
73            Ok(stats) => Some(encode_ok_stats(req_id, stats)),
74            Err(e) => Some(shard_err_to_response(req_id, &e)),
75        },
76
77        skeg_proto::Op::Shards => match shards.stats_per_shard().await {
78            Ok(rows) => Some(encode_ok_shards(req_id, &rows)),
79            Err(e) => Some(shard_err_to_response(req_id, &e)),
80        },
81
82        skeg_proto::Op::VindexList => match shards.vindex_list().await {
83            Ok(rows) => {
84                let info: Vec<skeg_proto::VindexInfo> = rows
85                    .into_iter()
86                    .map(
87                        |(name, dim, kind, backend, n_vectors)| skeg_proto::VindexInfo {
88                            name,
89                            dim,
90                            kind,
91                            backend,
92                            n_vectors,
93                        },
94                    )
95                    .collect();
96                Some(encode_ok_vindex_list(req_id, &info))
97            }
98            Err(e) => Some(shard_err_to_response(req_id, &e)),
99        },
100
101        skeg_proto::Op::Get => match decode_key_payload(payload) {
102            Ok(key) => match shards.get(&key).await {
103                Ok(Some(val)) => Some(encode_ok_value(req_id, &val)),
104                Ok(None) => Some(encode_err(req_id, ErrCode::NotFound, "key not found")),
105                Err(e) => Some(shard_err_to_response(req_id, &e)),
106            },
107            Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
108        },
109
110        skeg_proto::Op::Set => match decode_set_payload(payload) {
111            Ok((key, val)) => match shards.set(&key, &val, DEFAULT_DURABILITY).await {
112                Ok(()) => {
113                    if frame.header.flags.contains(Flags::NO_REPLY) {
114                        None
115                    } else {
116                        Some(encode_ok(req_id))
117                    }
118                }
119                Err(e) => Some(shard_err_to_response(req_id, &e)),
120            },
121            Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
122        },
123
124        skeg_proto::Op::Del => match decode_key_payload(payload) {
125            Ok(key) => match shards.del(&key, DEFAULT_DURABILITY).await {
126                Ok(existed) => Some(encode_ok_bool(req_id, existed)),
127                Err(e) => Some(shard_err_to_response(req_id, &e)),
128            },
129            Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
130        },
131
132        skeg_proto::Op::Mget => match decode_mget_payload(payload) {
133            Ok(keys) => match shards.mget(&keys).await {
134                Ok(results) => Some(encode_ok_mget(req_id, &results)),
135                Err(e) => Some(shard_err_to_response(req_id, &e)),
136            },
137            Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
138        },
139
140        skeg_proto::Op::VindexCreate => {
141            let (name, dim, kind, backend) = match decode_vindex_create_payload(payload) {
142                Ok(v) => v,
143                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
144            };
145            let Ok(name) = std::str::from_utf8(&name) else {
146                return Some(encode_err(
147                    req_id,
148                    ErrCode::InvalidRequest,
149                    "index name not utf-8",
150                ));
151            };
152            match shards.vindex_create(name, dim, kind, backend).await {
153                Ok(()) => Some(encode_ok(req_id)),
154                Err(e) => Some(shard_err_to_response(req_id, &e)),
155            }
156        }
157
158        skeg_proto::Op::VindexDrop => {
159            let name = match decode_vname_payload(payload) {
160                Ok(v) => v,
161                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
162            };
163            let Ok(name) = std::str::from_utf8(&name) else {
164                return Some(encode_err(
165                    req_id,
166                    ErrCode::InvalidRequest,
167                    "index name not utf-8",
168                ));
169            };
170            match shards.vindex_drop(name, 0).await {
171                Ok(()) => Some(encode_ok(req_id)),
172                Err(e) => Some(shard_err_to_response(req_id, &e)),
173            }
174        }
175
176        skeg_proto::Op::Vset => {
177            let (name, id, vector) = match decode_vset_payload(payload) {
178                Ok(v) => v,
179                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
180            };
181            let Ok(name) = std::str::from_utf8(&name) else {
182                return Some(encode_err(
183                    req_id,
184                    ErrCode::InvalidRequest,
185                    "index name not utf-8",
186                ));
187            };
188            match shards.vset(name, id, vector, 0, None).await {
189                Ok(()) => {
190                    if frame.header.flags.contains(Flags::NO_REPLY) {
191                        None
192                    } else {
193                        Some(encode_ok(req_id))
194                    }
195                }
196                Err(e) => Some(shard_err_to_response(req_id, &e)),
197            }
198        }
199
200        skeg_proto::Op::Vget => {
201            let (name, id) = match decode_vname_id_payload(payload) {
202                Ok(v) => v,
203                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
204            };
205            let Ok(name) = std::str::from_utf8(&name) else {
206                return Some(encode_err(
207                    req_id,
208                    ErrCode::InvalidRequest,
209                    "index name not utf-8",
210                ));
211            };
212            match shards.vget(name, id).await {
213                Ok(Some(v)) => Some(encode_ok_value(req_id, &f32_vec_to_bytes(&v))),
214                Ok(None) => Some(encode_err(req_id, ErrCode::NotFound, "vector not found")),
215                Err(e) => Some(shard_err_to_response(req_id, &e)),
216            }
217        }
218
219        skeg_proto::Op::Vdel => {
220            let (name, id) = match decode_vname_id_payload(payload) {
221                Ok(v) => v,
222                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
223            };
224            let Ok(name) = std::str::from_utf8(&name) else {
225                return Some(encode_err(
226                    req_id,
227                    ErrCode::InvalidRequest,
228                    "index name not utf-8",
229                ));
230            };
231            match shards.vdel(name, id, 0).await {
232                Ok(existed) => Some(encode_ok_bool(req_id, existed)),
233                Err(e) => Some(shard_err_to_response(req_id, &e)),
234            }
235        }
236
237        skeg_proto::Op::Vsearch => {
238            let (name, k, query, l_search) = match decode_vsearch_payload(payload) {
239                Ok(v) => v,
240                Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
241            };
242            let Ok(name) = std::str::from_utf8(&name) else {
243                return Some(encode_err(
244                    req_id,
245                    ErrCode::InvalidRequest,
246                    "index name not utf-8",
247                ));
248            };
249            let span = tracing::info_span!(
250                "vsearch",
251                protocol = "binary",
252                vindex = name,
253                k,
254                l_search,
255                vector_dim = query.len(),
256                hits = tracing::field::Empty,
257            );
258            let _guard = span.enter();
259            match shards.vsearch(name, query, k as usize, l_search).await {
260                Ok(hits) => {
261                    span.record("hits", hits.len());
262                    Some(encode_ok_vsearch(req_id, &hits))
263                }
264                Err(e) => Some(shard_err_to_response(req_id, &e)),
265            }
266        }
267
268        op => Some(encode_err(
269            req_id,
270            ErrCode::InvalidRequest,
271            &format!("op {op:?} not implemented"),
272        )),
273    }
274}