1use bytes::{Buf, BufMut, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3use yykv_types::DsError;
4
/// A vector search request, as decoded from the wire by [`VectorCodec`].
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchRequest {
    /// Name of the collection to search; decoded from the frame with lossy UTF-8 conversion.
    pub collection: String,
    /// Query vector; its length is the dimension announced in the frame header.
    pub vector: Vec<f32>,
    /// Requested number of results (the `k` in top-k).
    pub top_k: u32,
}
11
/// A vector search response, as encoded onto the wire by [`VectorCodec`].
///
/// The wire format carries a single result count (taken from `ids.len()`) for
/// both vectors, so `scores` is expected to have the same length as `ids`.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchResponse {
    /// Identifiers of the returned entries.
    pub ids: Vec<u64>,
    /// One score per entry of `ids`; presumably index-aligned with `ids`.
    pub scores: Vec<f32>,
}
17
/// Stateless `tokio_util` codec that decodes [`VectorSearchRequest`] frames and
/// encodes [`VectorSearchResponse`] frames.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorCodec;
19
20impl Decoder for VectorCodec {
21 type Item = VectorSearchRequest;
22 type Error = DsError;
23
24 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
25 if src.len() < 8 {
26 return Ok(None);
27 }
28
29 let name_len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
30 let vec_dim = u32::from_le_bytes([src[4], src[5], src[6], src[7]]) as usize;
31 let total_len = 8 + name_len + vec_dim * 4 + 4;
32
33 if src.len() < total_len {
34 src.reserve(total_len);
35 return Ok(None);
36 }
37
38 src.advance(8);
39 let collection = String::from_utf8_lossy(&src.split_to(name_len)).to_string();
40
41 let mut vector = Vec::with_capacity(vec_dim);
42 for _ in 0..vec_dim {
43 vector.push(src.get_f32_le());
44 }
45
46 let top_k = src.get_u32_le();
47
48 Ok(Some(VectorSearchRequest {
49 collection,
50 vector,
51 top_k,
52 }))
53 }
54}
55
56impl Encoder<VectorSearchResponse> for VectorCodec {
57 type Error = DsError;
58
59 fn encode(
60 &mut self,
61 item: VectorSearchResponse,
62 dst: &mut BytesMut,
63 ) -> Result<(), Self::Error> {
64 dst.put_u32_le(item.ids.len() as u32);
65 for id in item.ids {
66 dst.put_u64_le(id);
67 }
68 for score in item.scores {
69 dst.put_f32_le(score);
70 }
71 Ok(())
72 }
73}