Skip to main content

solidb_client/client/
mod.rs

1mod builder;
2mod bulk;
3mod collection;
4mod database;
5mod document;
6mod index;
7mod query;
8mod transaction;
9
10pub use builder::SoliDBClientBuilder;
11
12use serde_json::Value;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
15use tokio::net::TcpStream;
16
17use super::protocol::{
18    decode_message, encode_command, Command, DriverError, Response, DRIVER_MAGIC, MAX_MESSAGE_SIZE,
19};
20
/// Number of TCP connections opened by [`SoliDBClient::connect`].
const DEFAULT_POOL_SIZE: usize = 4;
22
23struct PooledConnection {
24    read: OwnedReadHalf,
25    write: OwnedWriteHalf,
26}
27
28pub struct SoliDBClient {
29    pool: Vec<PooledConnection>,
30    next_index: usize,
31    current_tx: Option<String>,
32}
33
34impl SoliDBClient {
35    pub async fn connect(addr: &str) -> Result<Self, DriverError> {
36        Self::connect_with_pool(addr, DEFAULT_POOL_SIZE).await
37    }
38
39    pub async fn connect_with_pool(addr: &str, pool_size: usize) -> Result<Self, DriverError> {
40        let mut pool_connections: Vec<PooledConnection> = Vec::with_capacity(pool_size);
41
42        for _ in 0..pool_size {
43            let stream = TcpStream::connect(addr).await.map_err(|e| {
44                DriverError::ConnectionError(format!("Failed to connect to {}: {}", addr, e))
45            })?;
46
47            stream.set_nodelay(true).map_err(|e| {
48                DriverError::ConnectionError(format!("Failed to set TCP_NODELAY: {}", e))
49            })?;
50
51            let (read, mut write) = stream.into_split();
52
53            write.write_all(DRIVER_MAGIC).await.map_err(|e| {
54                DriverError::ConnectionError(format!("Failed to send magic header: {}", e))
55            })?;
56
57            pool_connections.push(PooledConnection { read, write });
58        }
59
60        Ok(Self {
61            pool: pool_connections,
62            next_index: 0,
63            current_tx: None,
64        })
65    }
66
67    fn get_next_connection(&mut self) -> &mut PooledConnection {
68        let idx = self.next_index;
69        self.next_index = (self.next_index + 1) % self.pool.len();
70        &mut self.pool[idx]
71    }
72
73    pub(crate) async fn send_command(&mut self, command: Command) -> Result<Response, DriverError> {
74        let conn = self.get_next_connection();
75
76        let data = encode_command(&command)?;
77        conn.write
78            .write_all(&data)
79            .await
80            .map_err(|e| DriverError::ConnectionError(format!("Write failed: {}", e)))?;
81        conn.write
82            .flush()
83            .await
84            .map_err(|e| DriverError::ConnectionError(format!("Flush failed: {}", e)))?;
85
86        let mut len_buf = [0u8; 4];
87        conn.read
88            .read_exact(&mut len_buf)
89            .await
90            .map_err(|e| DriverError::ConnectionError(format!("Read length failed: {}", e)))?;
91
92        let msg_len = u32::from_be_bytes(len_buf) as usize;
93        if msg_len > MAX_MESSAGE_SIZE {
94            return Err(DriverError::MessageTooLarge);
95        }
96
97        let mut payload = vec![0u8; msg_len];
98        conn.read
99            .read_exact(&mut payload)
100            .await
101            .map_err(|e| DriverError::ConnectionError(format!("Read payload failed: {}", e)))?;
102
103        decode_message(&payload)
104    }
105
106    pub(crate) fn extract_data(response: Response) -> Result<Option<Value>, DriverError> {
107        match response {
108            Response::Ok { data, .. } => Ok(data),
109            Response::Error { error } => Err(error),
110            Response::Pong { .. } => Ok(None),
111            Response::Batch { .. } => Ok(None),
112        }
113    }
114
115    pub(crate) fn extract_tx_id(response: Response) -> Result<String, DriverError> {
116        match response {
117            Response::Ok {
118                tx_id: Some(id), ..
119            } => Ok(id),
120            Response::Ok { .. } => Err(DriverError::ProtocolError(
121                "Expected transaction ID".to_string(),
122            )),
123            Response::Error { error } => Err(error),
124            _ => Err(DriverError::ProtocolError(
125                "Unexpected response type".to_string(),
126            )),
127        }
128    }
129
130    pub async fn ping(&mut self) -> Result<i64, DriverError> {
131        let response = self.send_command(Command::Ping).await?;
132        match response {
133            Response::Pong { timestamp } => Ok(timestamp),
134            Response::Error { error } => Err(error),
135            _ => Err(DriverError::ProtocolError(
136                "Expected pong response".to_string(),
137            )),
138        }
139    }
140
141    pub async fn auth(
142        &mut self,
143        database: &str,
144        username: &str,
145        password: &str,
146    ) -> Result<(), DriverError> {
147        let response = self
148            .send_command(Command::Auth {
149                database: database.to_string(),
150                username: username.to_string(),
151                password: password.to_string(),
152                api_key: None,
153            })
154            .await?;
155
156        match response {
157            Response::Ok { .. } => Ok(()),
158            Response::Error { error } => Err(error),
159            _ => Err(DriverError::ProtocolError(
160                "Unexpected response".to_string(),
161            )),
162        }
163    }
164
165    pub async fn auth_with_api_key(
166        &mut self,
167        database: &str,
168        api_key: &str,
169    ) -> Result<(), DriverError> {
170        let response = self
171            .send_command(Command::Auth {
172                database: database.to_string(),
173                username: String::new(),
174                password: String::new(),
175                api_key: Some(api_key.to_string()),
176            })
177            .await?;
178
179        match response {
180            Response::Ok { .. } => Ok(()),
181            Response::Error { error } => Err(error),
182            _ => Err(DriverError::ProtocolError(
183                "Unexpected response".to_string(),
184            )),
185        }
186    }
187}