1use crate::{DebugEngine, Engine, Result, SurrealClient, SurrealError, WsCborEngine};
4
5use serde_json::Value;
6use url::Url;
7
8#[derive(Default, Debug, Clone)]
10pub struct SurrealConnection {
11 pub url: Option<String>,
13
14 namespace: Option<String>,
16
17 database: Option<String>,
19
20 auth: Option<AuthParams>,
22
23 version_check: bool,
25
26 debug: bool,
28}
29
30#[derive(Debug, Clone)]
32pub enum AuthParams {
33 Root { username: String, password: String },
35 Namespace { username: String, password: String },
37 Database { username: String, password: String },
39 Scope {
41 namespace: String,
42 database: String,
43 scope: String,
44 params: Value,
45 },
46 Token(String),
48}
49
50impl SurrealConnection {
51 pub fn new() -> Self {
53 Self {
54 version_check: true,
55 debug: false,
56 ..Default::default()
57 }
58 }
59
60 pub fn dsn(dsn: impl AsRef<str>) -> Result<Self> {
62 let mut conn = Self::new();
63 let url = Url::parse(dsn.as_ref())?;
64
65 if url.host().is_none() {
67 return Err(SurrealError::Connection(
68 "URL must have a valid host".to_string(),
69 ));
70 }
71
72 let base_url = format!("{}://{}", url.scheme(), url.host_str().unwrap());
74 let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
75 let final_url = format!("{}{}", base_url, port);
76 conn.url = Some(final_url);
77
78 if !url.username().is_empty() {
80 let username = url.username().to_string();
81 let password = url.password().unwrap_or("").to_string();
82 conn.auth = Some(AuthParams::Root { username, password });
83 }
84
85 let path_segments: Vec<&str> = url.path_segments().map(|c| c.collect()).unwrap_or_default();
87
88 if let Some(namespace) = path_segments.first().filter(|s| !s.is_empty()) {
89 conn.namespace = Some(namespace.to_string());
90 }
91 if let Some(database) = path_segments.get(1).filter(|s| !s.is_empty()) {
92 conn.database = Some(database.to_string());
93 }
94
95 for (key, value) in url.query_pairs() {
97 match key.as_ref() {
98 "namespace" => conn.namespace = Some(value.into_owned()),
99 "database" => conn.database = Some(value.into_owned()),
100 "version_check" => {
101 conn.version_check = value.parse().unwrap_or(true);
102 }
103 _ => {}
104 }
105 }
106
107 Ok(conn)
108 }
109
110 pub fn url(mut self, url: impl Into<String>) -> Self {
112 self.url = Some(url.into());
113 self
114 }
115
116 pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
118 self.namespace = Some(namespace.into());
119 self
120 }
121
122 pub fn database(mut self, database: impl Into<String>) -> Self {
124 self.database = Some(database.into());
125 self
126 }
127
128 pub fn auth_root(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
130 self.auth = Some(AuthParams::Root {
131 username: username.into(),
132 password: password.into(),
133 });
134 self
135 }
136
137 pub fn auth_namespace(
139 mut self,
140 username: impl Into<String>,
141 password: impl Into<String>,
142 ) -> Self {
143 self.auth = Some(AuthParams::Namespace {
144 username: username.into(),
145 password: password.into(),
146 });
147 self
148 }
149
150 pub fn auth_database(
152 mut self,
153 username: impl Into<String>,
154 password: impl Into<String>,
155 ) -> Self {
156 self.auth = Some(AuthParams::Database {
157 username: username.into(),
158 password: password.into(),
159 });
160 self
161 }
162
163 pub fn auth_scope(
165 mut self,
166 namespace: impl Into<String>,
167 database: impl Into<String>,
168 scope: impl Into<String>,
169 params: Value,
170 ) -> Self {
171 self.auth = Some(AuthParams::Scope {
172 namespace: namespace.into(),
173 database: database.into(),
174 scope: scope.into(),
175 params,
176 });
177 self
178 }
179
180 pub fn auth_token(mut self, token: impl Into<String>) -> Self {
182 self.auth = Some(AuthParams::Token(token.into()));
183 self
184 }
185
186 pub fn version_check(mut self, check: bool) -> Self {
188 self.version_check = check;
189 self
190 }
191
192 pub fn with_debug(mut self, enabled: bool) -> Self {
194 self.debug = enabled;
195 self
196 }
197
198 pub(crate) async fn init_engine(&self, engine: &mut crate::WsCborEngine) -> Result<()> {
205 use ciborium::Value as CborValue;
206
207 match self.auth.as_ref().ok_or(SurrealError::Connection(
208 "Attempted to connect without auth".to_string(),
209 ))? {
210 AuthParams::Root { username, password } => {
211 let auth_params = CborValue::Array(vec![CborValue::Map(vec![
212 (
213 CborValue::Text("user".to_string()),
214 CborValue::Text(username.clone()),
215 ),
216 (
217 CborValue::Text("pass".to_string()),
218 CborValue::Text(password.clone()),
219 ),
220 ])]);
221 engine.send_message_cbor("signin", auth_params).await?;
222 }
223 AuthParams::Namespace { username, password } => {
224 let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
225 "Namespace is required for namespace auth".to_string(),
226 ))?;
227 let auth_params = CborValue::Array(vec![CborValue::Map(vec![
228 (
229 CborValue::Text("user".to_string()),
230 CborValue::Text(username.clone()),
231 ),
232 (
233 CborValue::Text("pass".to_string()),
234 CborValue::Text(password.clone()),
235 ),
236 (
237 CborValue::Text("NS".to_string()),
238 CborValue::Text(namespace),
239 ),
240 ])]);
241 engine.send_message_cbor("signin", auth_params).await?;
242 }
243 AuthParams::Database { username, password } => {
244 let namespace = self.namespace.clone().ok_or(SurrealError::Connection(
245 "Namespace is required for database auth".to_string(),
246 ))?;
247 let database = self.database.clone().ok_or(SurrealError::Connection(
248 "Database is required for database auth".to_string(),
249 ))?;
250 let auth_params = CborValue::Array(vec![CborValue::Map(vec![
251 (
252 CborValue::Text("user".to_string()),
253 CborValue::Text(username.clone()),
254 ),
255 (
256 CborValue::Text("pass".to_string()),
257 CborValue::Text(password.clone()),
258 ),
259 (
260 CborValue::Text("NS".to_string()),
261 CborValue::Text(namespace),
262 ),
263 (CborValue::Text("DB".to_string()), CborValue::Text(database)),
264 ])]);
265 engine.send_message_cbor("signin", auth_params).await?;
266 }
267 _ => {
268 return Err(SurrealError::Connection(
269 "Unsupported authentication method".to_string(),
270 ));
271 }
272 }
273
274 if let Some(namespace) = &self.namespace {
275 let use_params = CborValue::Array(vec![
276 CborValue::Text(namespace.clone()),
277 CborValue::Text(self.database.as_ref().unwrap_or(&String::new()).clone()),
278 ]);
279 engine.send_message_cbor("use", use_params).await?;
280 }
281
282 Ok(())
283 }
284
285 pub async fn connect(self) -> Result<SurrealClient> {
287 let url_str = self
288 .url
289 .as_ref()
290 .ok_or_else(|| SurrealError::Connection("URL is required".to_string()))?;
291 let url = Url::parse(url_str)
292 .map_err(|e| SurrealError::Connection(format!("Invalid URL: {}", e)))?;
293
294 let mut engine: Box<dyn Engine> = match url.scheme() {
295 "ws" | "wss" | "cbor" => Box::new(WsCborEngine::from_connection(&self).await?),
296 _ => {
297 return Err(SurrealError::Protocol(
298 "Unsupported protocol. Use ws://, wss://, or cbor://".to_string(),
299 ));
300 }
301 };
302
303 if self.debug {
304 engine = DebugEngine::wrap(engine);
305 }
306
307 let client = SurrealClient::new(engine, self.namespace, self.database);
308 Ok(client.with_debug(self.debug))
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_builder() {
        let conn = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_ns")
            .database("test_db")
            .auth_root("root", "root")
            .version_check(false);

        assert_eq!(conn.url.as_deref(), Some("ws://localhost:8000"));
        assert_eq!(conn.namespace.as_deref(), Some("test_ns"));
        assert_eq!(conn.database.as_deref(), Some("test_db"));
        assert!(!conn.version_check);
        assert!(matches!(conn.auth, Some(AuthParams::Root { .. })));
    }

    #[test]
    fn test_new_defaults() {
        let conn = SurrealConnection::new();

        assert!(conn.url.is_none());
        assert!(conn.namespace.is_none());
        assert!(conn.database.is_none());
        assert!(conn.auth.is_none());
        assert!(conn.version_check);
        assert!(!conn.debug);
    }

    #[test]
    fn test_with_debug() {
        assert!(SurrealConnection::new().with_debug(true).debug);
        assert!(!SurrealConnection::new().with_debug(false).debug);
    }

    #[test]
    fn test_dsn_parsing() {
        let conn = SurrealConnection::dsn(
            "ws://root:root@localhost:8000/test_ns/test_db?version_check=false",
        )
        .unwrap();

        assert_eq!(conn.url.as_deref(), Some("ws://localhost:8000"));
        assert_eq!(conn.namespace.as_deref(), Some("test_ns"));
        assert_eq!(conn.database.as_deref(), Some("test_db"));
        assert!(!conn.version_check);
        assert!(matches!(conn.auth, Some(AuthParams::Root { .. })));
    }

    #[test]
    fn test_dsn_with_query_params() {
        let conn =
            SurrealConnection::dsn("http://localhost:8000?namespace=ns&database=db").unwrap();

        assert_eq!(conn.url.as_deref(), Some("http://localhost:8000"));
        assert_eq!(conn.namespace.as_deref(), Some("ns"));
        assert_eq!(conn.database.as_deref(), Some("db"));
    }

    #[test]
    fn test_dsn_query_overrides_path() {
        let conn = SurrealConnection::dsn("ws://localhost:8000/a/b?namespace=x").unwrap();

        assert_eq!(conn.namespace.as_deref(), Some("x"));
        assert_eq!(conn.database.as_deref(), Some("b"));
    }

    #[test]
    fn test_dsn_namespace_only_without_credentials() {
        let conn = SurrealConnection::dsn("ws://localhost:8000/ns").unwrap();

        assert_eq!(conn.namespace.as_deref(), Some("ns"));
        assert!(conn.database.is_none());
        assert!(conn.auth.is_none());
        assert!(conn.version_check);
    }

    #[test]
    fn test_dsn_drops_default_port() {
        let conn = SurrealConnection::dsn("wss://example.com:443").unwrap();

        assert_eq!(conn.url.as_deref(), Some("wss://example.com"));
    }

    #[test]
    fn test_dsn_rejects_invalid_url() {
        assert!(SurrealConnection::dsn("not a url").is_err());
    }

    #[test]
    fn test_auth_methods() {
        let conn1 = SurrealConnection::new().auth_root("admin", "pass");
        assert!(matches!(conn1.auth, Some(AuthParams::Root { .. })));

        let conn2 = SurrealConnection::new().auth_namespace("ns_user", "ns_pass");
        assert!(matches!(conn2.auth, Some(AuthParams::Namespace { .. })));

        let conn3 = SurrealConnection::new().auth_database("db_user", "db_pass");
        assert!(matches!(conn3.auth, Some(AuthParams::Database { .. })));

        let conn4 = SurrealConnection::new().auth_token("jwt_token");
        assert!(matches!(conn4.auth, Some(AuthParams::Token(_))));

        let conn5 = SurrealConnection::new().auth_scope("ns", "db", "sc", Value::Null);
        assert!(matches!(
            &conn5.auth,
            Some(AuthParams::Scope { scope, .. }) if scope == "sc"
        ));
    }

    // Pure builder checks: no async runtime is needed because nothing is awaited.
    #[test]
    fn test_connection_to_client_flow() {
        let connection = SurrealConnection::new()
            .url("ws://localhost:8000")
            .namespace("test_namespace")
            .database("test_database")
            .auth_root("admin", "password")
            .version_check(false);

        assert_eq!(connection.url.as_deref(), Some("ws://localhost:8000"));
        assert_eq!(connection.namespace.as_deref(), Some("test_namespace"));
        assert_eq!(connection.database.as_deref(), Some("test_database"));
        assert!(!connection.version_check);
        assert!(matches!(connection.auth, Some(AuthParams::Root { .. })));
    }
}