Skip to main content

we_trust/
lib.rs

1#![warn(missing_docs)]
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use uuid::Uuid;
5use yykv_types::layout::{DsValueDecoder, DsValueEncoder};
6pub use yykv_types::{DsError, DsValue, Redundancy};
7
/// Kind of database engine behind a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseBackend {
    /// Limbo / SQLite backend.
    Limbo,
    /// PostgreSQL backend.
    Postgres,
    /// MySQL backend.
    MySql,
}
18
19/// 数据库连接结果类型
20pub type DatabaseResult<T> = Result<T, DsError>;
21
22/// 数据库连接 trait
23#[async_trait::async_trait]
24pub trait DatabaseConnection {
25    /// 获取数据库后端类型
26    fn backend(&self) -> DatabaseBackend;
27    
28    /// 执行查询
29    async fn query(&self, sql: &str) -> DatabaseResult<Box<dyn RowIterator>>;
30}
31
32/// 行迭代器 trait
33#[async_trait::async_trait]
34pub trait RowIterator: Send {
35    /// 获取下一行
36    async fn next(&mut self) -> DatabaseResult<Option<Box<dyn Row>>>;
37}
38
39/// 行 trait
40pub trait Row: Send {
41    /// 获取指定索引的字符串值
42    fn get_string(&self, index: usize) -> DatabaseResult<String>;
43    
44    /// 获取指定索引的整数值
45    fn get_i64(&self, index: usize) -> DatabaseResult<i64>;
46    
47    /// 获取指定索引的布尔值
48    fn get_bool(&self, index: usize) -> DatabaseResult<bool>;
49    
50    /// 获取指定索引的可选字符串值
51    fn get_option_string(&self, index: usize) -> DatabaseResult<Option<String>>;
52}
53
54// 导出 schema 模块
55pub mod schema;
56
57use crc32fast::Hasher;
58use futures::{SinkExt, StreamExt};
59use sha2::{Digest, Sha256};
60use std::net::SocketAddr;
61use std::str::FromStr;
62use tokio::net::TcpStream;
63use tokio_util::codec::Framed;
64
65/// Connection options for WeTrust.
66#[derive(Debug, Clone)]
67pub struct ConnectionOptions {
68    pub addr: SocketAddr,
69    pub tenant_id: Uuid,
70    pub secret_key: Vec<u8>,
71}
72
73impl FromStr for ConnectionOptions {
74    type Err = DsError;
75
76    fn from_str(s: &str) -> Result<Self, Self::Err> {
77        let mut options = ConnectionOptions {
78            addr: "127.0.0.1:8889".parse().unwrap(),
79            tenant_id: Uuid::nil(),
80            secret_key: b"yykv-secret-key-2026".to_vec(),
81        };
82
83        for part in s.split(';') {
84            let kv: Vec<&str> = part.split('=').collect();
85            if kv.len() == 2 {
86                match kv[0].to_lowercase().as_str() {
87                    "server" | "host" => {
88                        let host = kv[1];
89                        options.addr = format!("{}:8889", host)
90                            .parse()
91                            .map_err(|e| DsError::internal(format!("Invalid host: {}", e)))?;
92                    }
93                    "port" => {
94                        let port: u16 = kv[1]
95                            .parse()
96                            .map_err(|e| DsError::internal(format!("Invalid port: {}", e)))?;
97                        let mut addr = options.addr;
98                        addr.set_port(port);
99                        options.addr = addr;
100                    }
101                    "tenantid" => {
102                        options.tenant_id = Uuid::parse_str(kv[1])
103                            .map_err(|e| DsError::internal(format!("Invalid TenantID: {}", e)))?;
104                    }
105                    "secretkey" => {
106                        options.secret_key = kv[1].as_bytes().to_vec();
107                    }
108                    _ => {}
109                }
110            }
111        }
112
113        Ok(options)
114    }
115}
116
/// Magic bytes that open every Trust protocol header: ASCII `"YY"`.
pub const MAGIC: [u8; 2] = [b'Y', b'Y'];
119
/// Message kinds of the Trust protocol.
///
/// Each discriminant is the byte value stored in the header's `msg_type`
/// field (`variant as u8`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageType {
    /// Put request (byte value 1).
    Put = 1,
    /// Get request (byte value 2).
    Get = 2,
    /// Delete request (byte value 3).
    Delete = 3,
    /// Query using OpsGraph (byte value 4).
    Query = 4,
    /// `Rbq` message (byte value 5).
    Rbq = 5,
    /// Generic response (byte value 6); the reply type expected during the
    /// client handshake.
    Response = 6,
    /// Authentication message (byte value 7); sent first by the client.
    Auth = 7,
    /// Encoded `DsValue` frame (byte value 8), as produced by `DsValueCodec`.
    Value = 8,
    /// `Push` message (byte value 9).
    Push = 9,
    /// `Pull` message (byte value 10).
    Pull = 10,
    /// Heartbeat (byte value 11).
    Heartbeat = 11,
    /// KQL query (byte value 12); the payload is the query text.
    Kql = 12,

    /// Response to [`MessageType::Put`] (byte value 101).
    PutResp = 101,
    /// Response to [`MessageType::Get`] (byte value 102).
    GetResp = 102,
    /// Response to [`MessageType::Delete`] (byte value 103).
    DeleteResp = 103,
    /// Response to [`MessageType::Query`] (byte value 104).
    QueryResp = 104,

    /// Error (byte value 255); also the fallback when an unknown byte is
    /// converted with `From<u8>`.
    Error = 255,
}
142
143impl From<u8> for MessageType {
144    fn from(v: u8) -> Self {
145        match v {
146            1 => MessageType::Put,
147            2 => MessageType::Get,
148            3 => MessageType::Delete,
149            4 => MessageType::Query,
150            5 => MessageType::Rbq,
151            6 => MessageType::Response,
152            7 => MessageType::Auth,
153            8 => MessageType::Value,
154            9 => MessageType::Push,
155            10 => MessageType::Pull,
156            11 => MessageType::Heartbeat,
157            12 => MessageType::Kql,
158            101 => MessageType::PutResp,
159            102 => MessageType::GetResp,
160            103 => MessageType::DeleteResp,
161            104 => MessageType::QueryResp,
162            _ => MessageType::Error,
163        }
164    }
165}
166
167/// Trust Message Header (80 bytes)
168/// | Magic (2B) | Version (1B) | Type (1B) | Flags (4B) | Length (4B) | Checksum (4B) |
169/// | RequestID (16B) | TenantID (16B) | Signature (32B) |
170/// Flags:
171/// - Bits 0-7: SDR Level
172#[derive(Debug, Clone, PartialEq)]
173pub struct TrustHeader {
174    pub version: u8,
175    pub msg_type: u8,
176    pub flags: u32,
177    pub length: u32,
178    pub checksum: u32,
179    pub request_id: Uuid,
180    pub tenant_id: Uuid,
181    pub signature: [u8; 32],
182}
183
184impl TrustHeader {
185    pub const SIZE: usize = 80;
186
187    pub fn sdr_level(&self) -> Redundancy {
188        Redundancy::from_u8((self.flags & 0xFF) as u8)
189    }
190
191    pub fn set_sdr_level(&mut self, level: Redundancy) {
192        self.flags = (self.flags & !0xFF) | (level.0 as u32 & 0xFF);
193    }
194
195    pub fn sign(&mut self, secret: &[u8]) {
196        let mut hasher = Sha256::new();
197        hasher.update(secret);
198        hasher.update(self.request_id.as_bytes());
199        hasher.update(self.tenant_id.as_bytes());
200        hasher.update(self.checksum.to_be_bytes());
201        hasher.update(self.flags.to_be_bytes());
202        let hash = hasher.finalize();
203        self.signature.copy_from_slice(&hash);
204    }
205
206    pub fn verify(&self, secret: &[u8]) -> bool {
207        let mut hasher = Sha256::new();
208        hasher.update(secret);
209        hasher.update(self.request_id.as_bytes());
210        hasher.update(self.tenant_id.as_bytes());
211        hasher.update(self.checksum.to_be_bytes());
212        hasher.update(self.flags.to_be_bytes());
213        let hash = hasher.finalize();
214        self.signature == hash.as_slice()
215    }
216
217    pub fn encode<B: BufMut>(&self, mut dst: B) {
218        dst.put_slice(&MAGIC);
219        dst.put_u8(self.version);
220        dst.put_u8(self.msg_type);
221        dst.put_u32(self.flags);
222        dst.put_u32(self.length);
223        dst.put_u32(self.checksum);
224        dst.put_slice(self.request_id.as_bytes());
225        dst.put_slice(self.tenant_id.as_bytes());
226        dst.put_slice(&self.signature);
227    }
228
229    pub fn decode(src: &mut BytesMut) -> Result<Self, DsError> {
230        if src.len() < Self::SIZE {
231            return Err(DsError::protocol("Insufficient data for header"));
232        }
233
234        let magic = [src[0], src[1]];
235        if magic != MAGIC {
236            return Err(DsError::protocol(format!("Invalid magic: {:?}", magic)));
237        }
238
239        let version = src[2];
240        let msg_type = src[3];
241        let flags = u32::from_be_bytes([src[4], src[5], src[6], src[7]]);
242        let length = u32::from_be_bytes([src[8], src[9], src[10], src[11]]);
243        let checksum = u32::from_be_bytes([src[12], src[13], src[14], src[15]]);
244
245        let request_id = Uuid::from_slice(&src[16..32])
246            .map_err(|e| DsError::protocol(format!("Invalid request ID: {}", e)))?;
247        let tenant_id = Uuid::from_slice(&src[32..48])
248            .map_err(|e| DsError::protocol(format!("Invalid tenant ID: {}", e)))?;
249
250        let mut signature = [0u8; 32];
251        signature.copy_from_slice(&src[48..80]);
252
253        src.advance(Self::SIZE);
254
255        Ok(Self {
256            version,
257            msg_type,
258            flags,
259            length,
260            checksum,
261            request_id,
262            tenant_id,
263            signature,
264        })
265    }
266}
267
268/// A Trust protocol message
269#[derive(Debug)]
270pub struct TrustMessage {
271    pub header: TrustHeader,
272    pub payload: Bytes,
273}
274
275impl TrustMessage {
276    pub fn new(msg_type: MessageType, tenant_id: Uuid, payload: Bytes) -> Self {
277        let mut hasher = Hasher::new();
278        hasher.update(&payload);
279        let checksum = hasher.finalize();
280
281        TrustMessage {
282            header: TrustHeader {
283                version: 1,
284                msg_type: msg_type as u8,
285                flags: 0,
286                length: payload.len() as u32,
287                checksum,
288                request_id: Uuid::new_v4(),
289                tenant_id,
290                signature: [0u8; 32],
291            },
292            payload,
293        }
294    }
295
296    pub fn encode<B: BufMut>(&self, mut dst: B) {
297        self.header.encode(&mut dst);
298        dst.put(self.payload.clone());
299    }
300}
301
/// Stateless `tokio_util` codec that frames a byte stream into [`TrustMessage`]s.
pub struct TrustCodec;
303
304impl tokio_util::codec::Decoder for TrustCodec {
305    type Item = TrustMessage;
306    type Error = DsError;
307
308    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
309        if src.len() < TrustHeader::SIZE {
310            return Ok(None);
311        }
312
313        // Peek payload length without consuming
314        let mut length_bytes = [0u8; 4];
315        length_bytes.copy_from_slice(&src[8..12]);
316        let payload_len = u32::from_be_bytes(length_bytes) as usize;
317        let total_length = TrustHeader::SIZE + payload_len;
318
319        if src.len() < total_length {
320            src.reserve(total_length - src.len());
321            return Ok(None);
322        }
323
324        let header = TrustHeader::decode(src)?;
325        let payload = src.split_to(payload_len).freeze();
326
327        // Verify payload checksum
328        let mut hasher = Hasher::new();
329        hasher.update(&payload);
330        if hasher.finalize() != header.checksum {
331            return Err(DsError::protocol("Payload checksum mismatch"));
332        }
333
334        Ok(Some(TrustMessage { header, payload }))
335    }
336}
337
338impl tokio_util::codec::Encoder<TrustMessage> for TrustCodec {
339    type Error = DsError;
340
341    fn encode(&mut self, item: TrustMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
342        item.encode(dst);
343        Ok(())
344    }
345}
346
/// Codec that encodes a `DsValue` behind a [`TrustHeader`] and decodes it back.
pub struct DsValueCodec;
349
350impl DsValueCodec {
351    pub fn encode(
352        value: &DsValue,
353        tenant_id: Uuid,
354        sdr_level: Redundancy,
355    ) -> Result<Bytes, DsError> {
356        let mut result = BytesMut::new();
357        // Reserve space for TrustHeader
358        result.resize(TrustHeader::SIZE, 0);
359
360        // Encode DsValue directly into result after TrustHeader
361        DsValueEncoder::encode_to_buf(value, &mut result)?;
362
363        let total_len = result.len();
364
365        // Calculate checksum on the encoded payload
366        let mut hasher = Hasher::new();
367        hasher.update(&result[TrustHeader::SIZE..]);
368        let checksum = hasher.finalize();
369
370        let mut header = TrustHeader {
371            version: 1,
372            msg_type: MessageType::Value as u8,
373            flags: 0,
374            length: total_len as u32,
375            checksum,
376            request_id: Uuid::new_v4(),
377            tenant_id,
378            signature: [0u8; 32],
379        };
380        header.set_sdr_level(sdr_level);
381
382        // Encode header into the reserved space at the beginning
383        let mut header_part = &mut result[..TrustHeader::SIZE];
384        header.encode(&mut header_part);
385
386        Ok(result.freeze())
387    }
388
389    pub fn decode(mut data: BytesMut) -> Result<(DsValue, TrustHeader), DsError> {
390        let header = TrustHeader::decode(&mut data)?;
391        let mut payload = data.freeze();
392        let value = DsValueDecoder::decode(&mut payload)?;
393        Ok((value, header))
394    }
395}
396
397/// A client for the WeTrust protocol, similar to the JDBC implementation.
398pub struct WeTrustClient {
399    framed: Framed<TcpStream, TrustCodec>,
400    tenant_id: Uuid,
401    secret_key: Vec<u8>,
402}
403
404impl WeTrustClient {
405    pub async fn connect(
406        addr: SocketAddr,
407        tenant_id: Uuid,
408        secret_key: Vec<u8>,
409    ) -> Result<Self, DsError> {
410        let stream = TcpStream::connect(addr)
411            .await
412            .map_err(|e| DsError::io_raw(e, Some(addr.to_string().into())))?;
413        let mut framed = Framed::new(stream, TrustCodec);
414
415        // Handshake: Send Auth message
416        let mut auth_msg = TrustMessage::new(MessageType::Auth, tenant_id, Bytes::from("auth-v1"));
417        auth_msg.header.sign(&secret_key);
418
419        framed.send(auth_msg).await?;
420
421        // Wait for Auth response
422        if let Some(resp) = framed.next().await {
423            let resp = resp?;
424            if resp.header.msg_type != MessageType::Response as u8 {
425                return Err(DsError::protocol(
426                    "Unexpected message type during handshake",
427                ));
428            }
429            if !resp.header.verify(&secret_key) {
430                return Err(DsError::protocol("Handshake signature verification failed"));
431            }
432        } else {
433            return Err(DsError::protocol("Connection closed during handshake"));
434        }
435
436        Ok(Self {
437            framed,
438            tenant_id,
439            secret_key,
440        })
441    }
442
443    pub async fn send_request(
444        &mut self,
445        msg_type: MessageType,
446        payload: Bytes,
447    ) -> Result<TrustMessage, DsError> {
448        let mut msg = TrustMessage::new(msg_type, self.tenant_id, payload);
449        msg.header.sign(&self.secret_key);
450
451        self.framed.send(msg).await?;
452
453        if let Some(resp) = self.framed.next().await {
454            let resp = resp?;
455            if !resp.header.verify(&self.secret_key) {
456                return Err(DsError::protocol("Message signature verification failed"));
457            }
458            Ok(resp)
459        } else {
460            Err(DsError::protocol("Connection closed by server"))
461        }
462    }
463
464    pub async fn send_query(&mut self, sql: &str) -> Result<Vec<Vec<DsValue>>, DsError> {
465        let _resp = self
466            .send_request(MessageType::Kql, Bytes::copy_from_slice(sql.as_bytes()))
467            .await?;
468
469        // In a real implementation, we would decode the payload into rows.
470        // For now, we return a mock success result.
471        Ok(vec![vec![DsValue::Text(format!("Executed: {}", sql))]])
472    }
473
474    pub async fn put(&mut self, key: &str, value: DsValue) -> Result<(), DsError> {
475        let value_data = DsValueEncoder::encode(&value)?;
476        let mut payload = BytesMut::with_capacity(4 + key.len() + value_data.len());
477        payload.put_u32(key.len() as u32);
478        payload.put_slice(key.as_bytes());
479        payload.put(value_data);
480
481        self.send_request(MessageType::Put, payload.freeze())
482            .await?;
483        Ok(())
484    }
485
486    pub async fn get(&mut self, key: &str) -> Result<Option<DsValue>, DsError> {
487        let mut payload = BytesMut::with_capacity(4 + key.len());
488        payload.put_u32(key.len() as u32);
489        payload.put_slice(key.as_bytes());
490
491        let resp = self
492            .send_request(MessageType::Get, payload.freeze())
493            .await?;
494        if resp.header.msg_type == MessageType::Error as u8 {
495            return Ok(None);
496        }
497
498        let mut data = resp.payload;
499        if data.is_empty() {
500            return Ok(None);
501        }
502
503        // Use DsValueDecoder to decode the value from the response
504        let val = DsValueDecoder::decode(&mut data)?;
505        Ok(Some(val))
506    }
507
508    pub async fn delete(&mut self, key: &str) -> Result<(), DsError> {
509        let mut payload = BytesMut::with_capacity(4 + key.len());
510        payload.put_u32(key.len() as u32);
511        payload.put_slice(key.as_bytes());
512        self.send_request(MessageType::Delete, payload.freeze())
513            .await?;
514        Ok(())
515    }
516
517    pub async fn kql(&mut self, query: &str) -> Result<DsValue, DsError> {
518        let resp = self
519            .send_request(MessageType::Kql, Bytes::copy_from_slice(query.as_bytes()))
520            .await?;
521        if resp.header.msg_type == MessageType::Error as u8 {
522            return Err(DsError::query_with_sql(
523                query,
524                String::from_utf8_lossy(&resp.payload).to_string(),
525            ));
526        }
527        let mut data = resp.payload;
528        let value = DsValueDecoder::decode(&mut data)?;
529        Ok(value)
530    }
531
532    pub async fn heartbeat(&mut self) -> Result<(), DsError> {
533        self.send_request(MessageType::Heartbeat, Bytes::new())
534            .await?;
535        Ok(())
536    }
537}