// solidb_client/client/mod.rs

1mod builder;
2mod bulk;
3mod collection;
4mod database;
5mod document;
6mod http_client;
7mod index;
8mod query;
9mod transaction;
10
11pub use builder::SoliDBClientBuilder;
12pub use builder::Transport;
13pub use http_client::HttpClient;
14
15use serde_json::Value;
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
18use tokio::net::TcpStream;
19
20use super::protocol::{
21    decode_message, encode_command, Command, DriverError, Response, DRIVER_MAGIC, MAX_MESSAGE_SIZE,
22};
23
/// Number of TCP connections opened by [`SoliDBClient::connect`].
const DEFAULT_POOL_SIZE: usize = 4;
25
26struct PooledConnection {
27    read: OwnedReadHalf,
28    write: OwnedWriteHalf,
29}
30
31pub struct SoliDBClient {
32    pool: Vec<PooledConnection>,
33    next_index: usize,
34    current_tx: Option<String>,
35}
36
37impl SoliDBClient {
38    pub async fn connect(addr: &str) -> Result<Self, DriverError> {
39        Self::connect_with_pool(addr, DEFAULT_POOL_SIZE).await
40    }
41
42    pub async fn connect_with_pool(addr: &str, pool_size: usize) -> Result<Self, DriverError> {
43        let mut pool_connections: Vec<PooledConnection> = Vec::with_capacity(pool_size);
44
45        for _ in 0..pool_size {
46            let stream = TcpStream::connect(addr).await.map_err(|e| {
47                DriverError::ConnectionError(format!("Failed to connect to {}: {}", addr, e))
48            })?;
49
50            stream.set_nodelay(true).map_err(|e| {
51                DriverError::ConnectionError(format!("Failed to set TCP_NODELAY: {}", e))
52            })?;
53
54            let (read, mut write) = stream.into_split();
55
56            write.write_all(DRIVER_MAGIC).await.map_err(|e| {
57                DriverError::ConnectionError(format!("Failed to send magic header: {}", e))
58            })?;
59
60            pool_connections.push(PooledConnection { read, write });
61        }
62
63        Ok(Self {
64            pool: pool_connections,
65            next_index: 0,
66            current_tx: None,
67        })
68    }
69
70    fn get_next_connection(&mut self) -> &mut PooledConnection {
71        let idx = self.next_index;
72        self.next_index = (self.next_index + 1) % self.pool.len();
73        &mut self.pool[idx]
74    }
75
76    pub(crate) async fn send_command(&mut self, command: Command) -> Result<Response, DriverError> {
77        let conn = self.get_next_connection();
78
79        let data = encode_command(&command)?;
80        conn.write
81            .write_all(&data)
82            .await
83            .map_err(|e| DriverError::ConnectionError(format!("Write failed: {}", e)))?;
84        conn.write
85            .flush()
86            .await
87            .map_err(|e| DriverError::ConnectionError(format!("Flush failed: {}", e)))?;
88
89        let mut len_buf = [0u8; 4];
90        conn.read
91            .read_exact(&mut len_buf)
92            .await
93            .map_err(|e| DriverError::ConnectionError(format!("Read length failed: {}", e)))?;
94
95        let msg_len = u32::from_be_bytes(len_buf) as usize;
96        if msg_len > MAX_MESSAGE_SIZE {
97            return Err(DriverError::MessageTooLarge);
98        }
99
100        let mut payload = vec![0u8; msg_len];
101        conn.read
102            .read_exact(&mut payload)
103            .await
104            .map_err(|e| DriverError::ConnectionError(format!("Read payload failed: {}", e)))?;
105
106        decode_message(&payload)
107    }
108
109    pub(crate) fn extract_data(response: Response) -> Result<Option<Value>, DriverError> {
110        match response {
111            Response::Ok { data, .. } => Ok(data),
112            Response::Error { error } => Err(error),
113            Response::Pong { .. } => Ok(None),
114            Response::Batch { .. } => Ok(None),
115        }
116    }
117
118    pub(crate) fn extract_tx_id(response: Response) -> Result<String, DriverError> {
119        match response {
120            Response::Ok {
121                tx_id: Some(id), ..
122            } => Ok(id),
123            Response::Ok { .. } => Err(DriverError::ProtocolError(
124                "Expected transaction ID".to_string(),
125            )),
126            Response::Error { error } => Err(error),
127            _ => Err(DriverError::ProtocolError(
128                "Unexpected response type".to_string(),
129            )),
130        }
131    }
132
133    pub async fn ping(&mut self) -> Result<i64, DriverError> {
134        let response = self.send_command(Command::Ping).await?;
135        match response {
136            Response::Pong { timestamp } => Ok(timestamp),
137            Response::Error { error } => Err(error),
138            _ => Err(DriverError::ProtocolError(
139                "Expected pong response".to_string(),
140            )),
141        }
142    }
143
144    pub async fn auth(
145        &mut self,
146        database: &str,
147        username: &str,
148        password: &str,
149    ) -> Result<(), DriverError> {
150        let response = self
151            .send_command(Command::Auth {
152                database: database.to_string(),
153                username: username.to_string(),
154                password: password.to_string(),
155                api_key: None,
156            })
157            .await?;
158
159        match response {
160            Response::Ok { .. } => Ok(()),
161            Response::Error { error } => Err(error),
162            _ => Err(DriverError::ProtocolError(
163                "Unexpected response".to_string(),
164            )),
165        }
166    }
167
168    pub async fn auth_with_api_key(
169        &mut self,
170        database: &str,
171        api_key: &str,
172    ) -> Result<(), DriverError> {
173        let response = self
174            .send_command(Command::Auth {
175                database: database.to_string(),
176                username: String::new(),
177                password: String::new(),
178                api_key: Some(api_key.to_string()),
179            })
180            .await?;
181
182        match response {
183            Response::Ok { .. } => Ok(()),
184            Response::Error { error } => Err(error),
185            _ => Err(DriverError::ProtocolError(
186                "Unexpected response".to_string(),
187            )),
188        }
189    }
190}