Skip to main content

we_trust_sqlite/
utils.rs

//! Native SQLite parsing utility functions.
2
3use yykv_types::{DsError, DsValue};
4
5type Result<T> = std::result::Result<T, DsError>;
6
/// Decodes a SQLite variable-length integer (varint) from the start of `data`.
///
/// The first eight bytes each contribute their low 7 bits (most significant
/// group first); a clear high bit marks the final byte. If all eight bytes have
/// the high bit set, a ninth byte contributes all 8 of its bits.
///
/// Returns `(value, bytes_consumed)`. When `data` ends before the varint is
/// complete, the bits gathered so far are returned together with the number of
/// bytes that were available (so an empty slice yields `(0, 0)`).
pub fn parse_varint(data: &[u8]) -> (u64, usize) {
    let mut acc: u64 = 0;

    // Bytes 1..=8: seven payload bits each, high bit is the continuation flag.
    for (idx, &byte) in data.iter().take(8).enumerate() {
        acc = (acc << 7) | u64::from(byte & 0x7F);
        if byte < 0x80 {
            return (acc, idx + 1);
        }
    }

    match data.get(8) {
        // The ninth byte uses all eight bits.
        Some(&last) => ((acc << 8) | u64::from(last), 9),
        // Input ran out (at most 8 bytes, all flagged as "more follows").
        None => (acc, data.len()),
    }
}
29
30/// 解析 SQLite Record 格式
31pub fn parse_record(data: &[u8]) -> Result<Vec<DsValue>> {
32    let (header_size, mut offset) = parse_varint(data);
33    let header_end = header_size as usize;
34
35    let mut serial_types = Vec::new();
36    while offset < header_end {
37        let (serial_type, consumed) = parse_varint(&data[offset..]);
38        serial_types.push(serial_type);
39        offset += consumed;
40    }
41
42    let mut values = Vec::new();
43    let mut data_offset = header_end;
44
45    for serial_type in serial_types {
46        match serial_type {
47            0 => values.push(DsValue::Null),
48            1 => {
49                values.push(DsValue::Int(data[data_offset] as i8 as i64));
50                data_offset += 1;
51            }
52            2 => {
53                let val = i16::from_be_bytes([data[data_offset], data[data_offset + 1]]);
54                values.push(DsValue::Int(val as i64));
55                data_offset += 2;
56            }
57            3 => {
58                let val = ((data[data_offset] as i32) << 16)
59                    | ((data[data_offset + 1] as i32) << 8)
60                    | (data[data_offset + 2] as i32);
61                let val = if val & 0x800000 != 0 {
62                    val | !0xFFFFFF
63                } else {
64                    val
65                };
66                values.push(DsValue::Int(val as i64));
67                data_offset += 3;
68            }
69            4 => {
70                let val = i32::from_be_bytes([
71                    data[data_offset],
72                    data[data_offset + 1],
73                    data[data_offset + 2],
74                    data[data_offset + 3],
75                ]);
76                values.push(DsValue::Int(val as i64));
77                data_offset += 4;
78            }
79            5 => {
80                let val = ((data[data_offset] as i64) << 40)
81                    | ((data[data_offset + 1] as i64) << 32)
82                    | ((data[data_offset + 2] as i64) << 24)
83                    | ((data[data_offset + 3] as i64) << 16)
84                    | ((data[data_offset + 4] as i64) << 8)
85                    | (data[data_offset + 5] as i64);
86                let val = if val & 0x800000000000 != 0 {
87                    val | !0xFFFFFFFFFFFF
88                } else {
89                    val
90                };
91                values.push(DsValue::Int(val));
92                data_offset += 6;
93            }
94            6 => {
95                let val = i64::from_be_bytes([
96                    data[data_offset],
97                    data[data_offset + 1],
98                    data[data_offset + 2],
99                    data[data_offset + 3],
100                    data[data_offset + 4],
101                    data[data_offset + 5],
102                    data[data_offset + 6],
103                    data[data_offset + 7],
104                ]);
105                values.push(DsValue::Int(val));
106                data_offset += 8;
107            }
108            7 => {
109                let val = f64::from_be_bytes([
110                    data[data_offset],
111                    data[data_offset + 1],
112                    data[data_offset + 2],
113                    data[data_offset + 3],
114                    data[data_offset + 4],
115                    data[data_offset + 5],
116                    data[data_offset + 6],
117                    data[data_offset + 7],
118                ]);
119                values.push(DsValue::Float(val));
120                data_offset += 8;
121            }
122            8 => {
123                values.push(DsValue::Int(0));
124            }
125            9 => {
126                values.push(DsValue::Int(1));
127            }
128            t if t >= 12 && t % 2 == 0 => {
129                let len = ((t - 12) / 2) as usize;
130                let val = data[data_offset..data_offset + len].to_vec();
131                values.push(DsValue::Bytes(val.into()));
132                data_offset += len;
133            }
134            t if t >= 13 && t % 2 == 1 => {
135                let len = ((t - 13) / 2) as usize;
136                let val =
137                    String::from_utf8_lossy(&data[data_offset..data_offset + len]).to_string();
138                values.push(DsValue::Text(val));
139                data_offset += len;
140            }
141            _ => {
142                values.push(DsValue::Null);
143            }
144        }
145    }
146
147    Ok(values)
148}
149
150/// 比较两个 Value,遵循 SQLite 的比较规则:
151/// NULL < INTEGER/REAL < TEXT < BLOB
152pub fn compare_values(a: &DsValue, b: &DsValue) -> std::cmp::Ordering {
153    use DsValue::*;
154    use std::cmp::Ordering;
155
156    match (a, b) {
157        (Null, Null) => Ordering::Equal,
158        (Null, _) => Ordering::Less,
159        (_, Null) => Ordering::Greater,
160
161        // 数字比较 (Int, Float, Decimal)
162        (Int(v1), Int(v2)) => v1.cmp(v2),
163        (Float(v1), Float(v2)) => v1.partial_cmp(v2).unwrap_or(Ordering::Equal),
164        (Int(v1), Float(v2)) => (*v1 as f64).partial_cmp(v2).unwrap_or(Ordering::Equal),
165        (Float(v1), Int(v2)) => v1.partial_cmp(&(*v2 as f64)).unwrap_or(Ordering::Equal),
166
167        // Text 比较
168        (Text(s1), Text(s2)) => s1.cmp(s2),
169
170        // Bytes 比较
171        (Bytes(b1), Bytes(b2))
172        | (Binary(b1), Binary(b2))
173        | (Bytes(b1), Binary(b2))
174        | (Binary(b1), Bytes(b2)) => b1.cmp(b2),
175
176        // 不同类型的比较规则: NULL < INTEGER/REAL < TEXT < BLOB
177        (Int(_) | Float(_) | Decimal(_), Text(_)) => Ordering::Less,
178        (Text(_), Int(_) | Float(_) | Decimal(_)) => Ordering::Greater,
179
180        (Int(_) | Float(_) | Decimal(_) | Text(_), Bytes(_) | Binary(_)) => Ordering::Less,
181        (Bytes(_) | Binary(_), Int(_) | Float(_) | Decimal(_) | Text(_)) => Ordering::Greater,
182
183        // 默认实现
184        _ => Ordering::Equal,
185    }
186}
187
188/// 比较两个 Value 切片 (用于索引查找)
189pub fn compare_value_slices(a: &[DsValue], b: &[DsValue]) -> std::cmp::Ordering {
190    let len = std::cmp::min(a.len(), b.len());
191    for i in 0..len {
192        let cmp = compare_values(&a[i], &b[i]);
193        if cmp != std::cmp::Ordering::Equal {
194            return cmp;
195        }
196    }
197    a.len().cmp(&b.len())
198}
199
/// Encodes a `u64` as a SQLite variable-length integer (varint).
///
/// Values below 2^56 use 1 to 8 bytes: 7 payload bits per byte, most
/// significant group first, with the high bit set on every byte except the
/// last. Larger values use 9 bytes: eight flagged 7-bit groups followed by a
/// final byte that carries 8 payload bits.
pub fn encode_varint(mut value: u64) -> Vec<u8> {
    /// Largest value that still fits in the 8-byte form (8 * 7 = 56 bits).
    const MAX_SHORT_FORM: u64 = (1 << 56) - 1;

    // Fill a fixed buffer from the back; `start` is the first used byte.
    let mut buf = [0u8; 9];
    let mut start = buf.len();

    if value > MAX_SHORT_FORM {
        // Nine-byte form: the last byte holds a full 8 bits.
        start -= 1;
        buf[start] = value as u8;
        value >>= 8;
        for _ in 0..8 {
            start -= 1;
            buf[start] = (value & 0x7F) as u8 | 0x80;
            value >>= 7;
        }
    } else {
        // Short form: the last byte has no continuation flag.
        start -= 1;
        buf[start] = (value & 0x7F) as u8;
        value >>= 7;
        while value != 0 {
            start -= 1;
            buf[start] = (value & 0x7F) as u8 | 0x80;
            value >>= 7;
        }
    }

    buf[start..].to_vec()
}
227
228/// 将 YYValue 列表编码为 SQLite Record 格式
229pub fn encode_record(values: &[DsValue]) -> Vec<u8> {
230    let mut header = Vec::new();
231    let mut body = Vec::new();
232
233    for value in values {
234        match value {
235            DsValue::Null => header.extend(encode_varint(0)),
236            DsValue::Bool(b) => header.extend(encode_varint(if *b { 9 } else { 8 })),
237            DsValue::Int(v) => {
238                let v = *v;
239                if v == 0 {
240                    header.extend(encode_varint(8));
241                } else if v == 1 {
242                    header.extend(encode_varint(9));
243                } else if v >= -128 && v <= 127 {
244                    header.extend(encode_varint(1));
245                    body.push(v as u8);
246                } else if v >= -32768 && v <= 32767 {
247                    header.extend(encode_varint(2));
248                    body.extend_from_slice(&(v as i16).to_be_bytes());
249                } else if v >= -8388608 && v <= 8388607 {
250                    header.extend(encode_varint(3));
251                    let b = (v as i32).to_be_bytes();
252                    body.extend_from_slice(&b[1..]);
253                } else if v >= -2147483648 && v <= 2147483647 {
254                    header.extend(encode_varint(4));
255                    body.extend_from_slice(&(v as i32).to_be_bytes());
256                } else {
257                    header.extend(encode_varint(6));
258                    body.extend_from_slice(&v.to_be_bytes());
259                }
260            }
261            DsValue::Float(v) => {
262                header.extend(encode_varint(7));
263                body.extend_from_slice(&v.to_be_bytes());
264            }
265            DsValue::Text(s) => {
266                let len = s.len();
267                header.extend(encode_varint((len * 2 + 13) as u64));
268                body.extend_from_slice(s.as_bytes());
269            }
270            DsValue::Bytes(b) | DsValue::Binary(b) => {
271                let len = b.len();
272                header.extend(encode_varint((len * 2 + 12) as u64));
273                body.extend_from_slice(b);
274            }
275            _ => {
276                header.extend(encode_varint(0));
277            }
278        }
279    }
280
281    let mut record = encode_varint((header.len() + 1) as u64);
282    record.extend(header);
283    record.extend(body);
284    record
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips values through `encode_varint` / `parse_varint`, checking
    /// the exact byte encoding, including the 8-byte / 9-byte boundary.
    #[test]
    fn test_varint() {
        let cases: Vec<(u64, Vec<u8>)> = vec![
            (0, vec![0x00]),
            (1, vec![0x01]),
            (127, vec![0x7F]),
            (128, vec![0x81, 0x00]),
            (16383, vec![0xFF, 0x7F]),
            (16384, vec![0x81, 0x80, 0x00]),
            // Largest value that fits in the 8-byte form.
            (
                (1u64 << 56) - 1,
                vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
            ),
            // Smallest value that needs the 9-byte form.
            (
                1u64 << 56,
                vec![0x80, 0xC0, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00],
            ),
            (u64::MAX, vec![0xFF; 9]),
        ];

        for (val, expected) in cases {
            let encoded = encode_varint(val);
            assert_eq!(encoded, expected);
            let (decoded, size) = parse_varint(&encoded);
            assert_eq!(decoded, val);
            assert_eq!(size, encoded.len());
        }
    }

    /// Truncated input yields the bits read so far and the bytes available.
    #[test]
    fn test_varint_truncated_input() {
        assert_eq!(parse_varint(&[]), (0, 0));
        assert_eq!(parse_varint(&[0x81]), (1, 1));
        assert_eq!(parse_varint(&[0xFF; 8]), ((1u64 << 56) - 1, 8));
    }
}