Skip to main content

we_trust_vector/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3use yykv_types::DsError;
4
/// A vector search query as decoded from the wire by [`VectorCodec`].
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchRequest {
    /// Collection name taken from the request frame.
    pub collection: String,
    /// Query vector; its length equals the dimension announced in the frame header.
    pub vector: Vec<f32>,
    /// `top_k` value from the request frame (presumably the number of results wanted).
    pub top_k: u32,
}
11
/// The result of a vector search, encoded onto the wire by [`VectorCodec`].
///
/// `ids` and `scores` are parallel arrays (`scores[i]` belongs to `ids[i]`) and
/// must have the same length: the wire format carries a single count prefix
/// that describes both.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchResponse {
    /// Matched ids.
    pub ids: Vec<u64>,
    /// Score for each entry of `ids`, in the same order.
    pub scores: Vec<f32>,
}
17
/// Stateless codec that decodes [`VectorSearchRequest`]s and encodes
/// [`VectorSearchResponse`]s for use with `tokio_util::codec` framing.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorCodec;
19
20impl Decoder for VectorCodec {
21    type Item = VectorSearchRequest;
22    type Error = DsError;
23
24    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
25        if src.len() < 8 {
26            return Ok(None);
27        }
28
29        let name_len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
30        let vec_dim = u32::from_le_bytes([src[4], src[5], src[6], src[7]]) as usize;
31        let total_len = 8 + name_len + vec_dim * 4 + 4;
32
33        if src.len() < total_len {
34            src.reserve(total_len);
35            return Ok(None);
36        }
37
38        src.advance(8);
39        let collection = String::from_utf8_lossy(&src.split_to(name_len)).to_string();
40
41        let mut vector = Vec::with_capacity(vec_dim);
42        for _ in 0..vec_dim {
43            vector.push(src.get_f32_le());
44        }
45
46        let top_k = src.get_u32_le();
47
48        Ok(Some(VectorSearchRequest {
49            collection,
50            vector,
51            top_k,
52        }))
53    }
54}
55
56impl Encoder<VectorSearchResponse> for VectorCodec {
57    type Error = DsError;
58
59    fn encode(
60        &mut self,
61        item: VectorSearchResponse,
62        dst: &mut BytesMut,
63    ) -> Result<(), Self::Error> {
64        dst.put_u32_le(item.ids.len() as u32);
65        for id in item.ids {
66            dst.put_u64_le(id);
67        }
68        for score in item.scores {
69            dst.put_f32_le(score);
70        }
71        Ok(())
72    }
73}