1use bytes::Bytes;
2use bytes::BytesMut;
3use skeg_proto::{
4 ErrCode, Flags, Frame, FrameParser, decode_key_payload, decode_mget_payload,
5 decode_set_payload, decode_vindex_create_payload, decode_vname_id_payload,
6 decode_vname_payload, decode_vsearch_payload, decode_vset_payload, encode_err, encode_ok,
7 encode_ok_bool, encode_ok_mget, encode_ok_shards, encode_ok_stats, encode_ok_value,
8 encode_ok_vindex_list, encode_ok_vsearch, f32_vec_to_bytes,
9};
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::TcpStream;
12use tracing::{debug, warn};
13
14use skeg_core::Durability;
15
16use crate::shard::{ShardError, ShardSet};
17
18const DEFAULT_DURABILITY: Durability = Durability::Kernel;
23
24pub async fn handle_connection(mut stream: TcpStream, shards: ShardSet) {
25 let peer = stream.peer_addr().ok();
26 debug!(?peer, "connection accepted");
27
28 let mut parser = FrameParser::new();
29 let mut buf = BytesMut::with_capacity(64 * 1024);
30
31 loop {
32 match parser.feed(&mut buf) {
33 Ok(Some(frame)) => {
34 if let Some(response) = dispatch(&frame, &shards).await
35 && stream.write_all(&response).await.is_err()
36 {
37 break;
38 }
39 }
40 Ok(None) => match stream.read_buf(&mut buf).await {
41 Ok(0) => break,
42 Ok(_) => {}
43 Err(e) => {
44 warn!(?peer, "read error: {e}");
45 break;
46 }
47 },
48 Err(e) => {
49 warn!(?peer, "protocol error: {e}");
50 break;
51 }
52 }
53 }
54
55 debug!(?peer, "connection closed");
56}
57
58fn shard_err_to_response(req_id: u64, e: &ShardError) -> Bytes {
59 warn!("shard error: {e}");
60 encode_err(req_id, ErrCode::Internal, &e.to_string())
61}
62
63#[allow(clippy::too_many_lines)] async fn dispatch(frame: &Frame, shards: &ShardSet) -> Option<Bytes> {
66 let req_id = frame.header.req_id;
67 let payload = &frame.payload;
68
69 match frame.header.op {
70 skeg_proto::Op::Ping => Some(encode_ok(req_id)),
71
72 skeg_proto::Op::Stats => match shards.stats().await {
73 Ok(stats) => Some(encode_ok_stats(req_id, stats)),
74 Err(e) => Some(shard_err_to_response(req_id, &e)),
75 },
76
77 skeg_proto::Op::Shards => match shards.stats_per_shard().await {
78 Ok(rows) => Some(encode_ok_shards(req_id, &rows)),
79 Err(e) => Some(shard_err_to_response(req_id, &e)),
80 },
81
82 skeg_proto::Op::VindexList => match shards.vindex_list().await {
83 Ok(rows) => {
84 let info: Vec<skeg_proto::VindexInfo> = rows
85 .into_iter()
86 .map(
87 |(name, dim, kind, backend, n_vectors)| skeg_proto::VindexInfo {
88 name,
89 dim,
90 kind,
91 backend,
92 n_vectors,
93 },
94 )
95 .collect();
96 Some(encode_ok_vindex_list(req_id, &info))
97 }
98 Err(e) => Some(shard_err_to_response(req_id, &e)),
99 },
100
101 skeg_proto::Op::Get => match decode_key_payload(payload) {
102 Ok(key) => match shards.get(&key).await {
103 Ok(Some(val)) => Some(encode_ok_value(req_id, &val)),
104 Ok(None) => Some(encode_err(req_id, ErrCode::NotFound, "key not found")),
105 Err(e) => Some(shard_err_to_response(req_id, &e)),
106 },
107 Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
108 },
109
110 skeg_proto::Op::Set => match decode_set_payload(payload) {
111 Ok((key, val)) => match shards.set(&key, &val, DEFAULT_DURABILITY).await {
112 Ok(()) => {
113 if frame.header.flags.contains(Flags::NO_REPLY) {
114 None
115 } else {
116 Some(encode_ok(req_id))
117 }
118 }
119 Err(e) => Some(shard_err_to_response(req_id, &e)),
120 },
121 Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
122 },
123
124 skeg_proto::Op::Del => match decode_key_payload(payload) {
125 Ok(key) => match shards.del(&key, DEFAULT_DURABILITY).await {
126 Ok(existed) => Some(encode_ok_bool(req_id, existed)),
127 Err(e) => Some(shard_err_to_response(req_id, &e)),
128 },
129 Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
130 },
131
132 skeg_proto::Op::Mget => match decode_mget_payload(payload) {
133 Ok(keys) => match shards.mget(&keys).await {
134 Ok(results) => Some(encode_ok_mget(req_id, &results)),
135 Err(e) => Some(shard_err_to_response(req_id, &e)),
136 },
137 Err(e) => Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
138 },
139
140 skeg_proto::Op::VindexCreate => {
141 let (name, dim, kind, backend) = match decode_vindex_create_payload(payload) {
142 Ok(v) => v,
143 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
144 };
145 let Ok(name) = std::str::from_utf8(&name) else {
146 return Some(encode_err(
147 req_id,
148 ErrCode::InvalidRequest,
149 "index name not utf-8",
150 ));
151 };
152 match shards.vindex_create(name, dim, kind, backend).await {
153 Ok(()) => Some(encode_ok(req_id)),
154 Err(e) => Some(shard_err_to_response(req_id, &e)),
155 }
156 }
157
158 skeg_proto::Op::VindexDrop => {
159 let name = match decode_vname_payload(payload) {
160 Ok(v) => v,
161 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
162 };
163 let Ok(name) = std::str::from_utf8(&name) else {
164 return Some(encode_err(
165 req_id,
166 ErrCode::InvalidRequest,
167 "index name not utf-8",
168 ));
169 };
170 match shards.vindex_drop(name, 0).await {
171 Ok(()) => Some(encode_ok(req_id)),
172 Err(e) => Some(shard_err_to_response(req_id, &e)),
173 }
174 }
175
176 skeg_proto::Op::Vset => {
177 let (name, id, vector) = match decode_vset_payload(payload) {
178 Ok(v) => v,
179 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
180 };
181 let Ok(name) = std::str::from_utf8(&name) else {
182 return Some(encode_err(
183 req_id,
184 ErrCode::InvalidRequest,
185 "index name not utf-8",
186 ));
187 };
188 match shards.vset(name, id, vector, 0, None).await {
189 Ok(()) => {
190 if frame.header.flags.contains(Flags::NO_REPLY) {
191 None
192 } else {
193 Some(encode_ok(req_id))
194 }
195 }
196 Err(e) => Some(shard_err_to_response(req_id, &e)),
197 }
198 }
199
200 skeg_proto::Op::Vget => {
201 let (name, id) = match decode_vname_id_payload(payload) {
202 Ok(v) => v,
203 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
204 };
205 let Ok(name) = std::str::from_utf8(&name) else {
206 return Some(encode_err(
207 req_id,
208 ErrCode::InvalidRequest,
209 "index name not utf-8",
210 ));
211 };
212 match shards.vget(name, id).await {
213 Ok(Some(v)) => Some(encode_ok_value(req_id, &f32_vec_to_bytes(&v))),
214 Ok(None) => Some(encode_err(req_id, ErrCode::NotFound, "vector not found")),
215 Err(e) => Some(shard_err_to_response(req_id, &e)),
216 }
217 }
218
219 skeg_proto::Op::Vdel => {
220 let (name, id) = match decode_vname_id_payload(payload) {
221 Ok(v) => v,
222 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
223 };
224 let Ok(name) = std::str::from_utf8(&name) else {
225 return Some(encode_err(
226 req_id,
227 ErrCode::InvalidRequest,
228 "index name not utf-8",
229 ));
230 };
231 match shards.vdel(name, id, 0).await {
232 Ok(existed) => Some(encode_ok_bool(req_id, existed)),
233 Err(e) => Some(shard_err_to_response(req_id, &e)),
234 }
235 }
236
237 skeg_proto::Op::Vsearch => {
238 let (name, k, query, l_search) = match decode_vsearch_payload(payload) {
239 Ok(v) => v,
240 Err(e) => return Some(encode_err(req_id, ErrCode::InvalidRequest, &e.to_string())),
241 };
242 let Ok(name) = std::str::from_utf8(&name) else {
243 return Some(encode_err(
244 req_id,
245 ErrCode::InvalidRequest,
246 "index name not utf-8",
247 ));
248 };
249 let span = tracing::info_span!(
250 "vsearch",
251 protocol = "binary",
252 vindex = name,
253 k,
254 l_search,
255 vector_dim = query.len(),
256 hits = tracing::field::Empty,
257 );
258 let _guard = span.enter();
259 match shards.vsearch(name, query, k as usize, l_search).await {
260 Ok(hits) => {
261 span.record("hits", hits.len());
262 Some(encode_ok_vsearch(req_id, &hits))
263 }
264 Err(e) => Some(shard_err_to_response(req_id, &e)),
265 }
266 }
267
268 op => Some(encode_err(
269 req_id,
270 ErrCode::InvalidRequest,
271 &format!("op {op:?} not implemented"),
272 )),
273 }
274}