Skip to main content

surreal_client/
record.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::fmt;
4
5/// Represents a SurrealDB record ID
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub struct RecordId {
8    pub table: String,
9    pub id: RecordIdValue,
10}
11
12/// The value part of a record ID
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(untagged)]
15pub enum RecordIdValue {
16    String(String),
17    Number(i64),
18    Object(Value),
19    Array(Vec<Value>),
20}
21
22impl RecordId {
23    /// Create a new record ID
24    pub fn new(table: impl Into<String>, id: impl Into<RecordIdValue>) -> Self {
25        Self {
26            table: table.into(),
27            id: id.into(),
28        }
29    }
30
31    /// Create a record ID with a string ID
32    pub fn string(table: impl Into<String>, id: impl Into<String>) -> Self {
33        Self::new(table, RecordIdValue::String(id.into()))
34    }
35
36    /// Create a record ID with a numeric ID
37    pub fn number(table: impl Into<String>, id: i64) -> Self {
38        Self::new(table, RecordIdValue::Number(id))
39    }
40
41    /// Create a record ID with an object ID
42    pub fn object(table: impl Into<String>, id: Value) -> Self {
43        Self::new(table, RecordIdValue::Object(id))
44    }
45
46    /// Create a record ID with an array ID
47    pub fn array(table: impl Into<String>, id: Vec<Value>) -> Self {
48        Self::new(table, RecordIdValue::Array(id))
49    }
50
51    /// Parse a record ID from a string format "table:id"
52    pub fn parse(input: &str) -> Result<Self, RecordParseError> {
53        let parts: Vec<&str> = input.splitn(2, ':').collect();
54        if parts.len() != 2 {
55            return Err(RecordParseError::InvalidFormat);
56        }
57
58        let table = parts[0].to_string();
59        let id_str = parts[1];
60
61        // Try to parse as number first
62        if let Ok(num) = id_str.parse::<i64>() {
63            return Ok(Self::number(table, num));
64        }
65
66        // Try to parse as JSON object/array
67        if (id_str.starts_with('{') || id_str.starts_with('['))
68            && let Ok(value) = serde_json::from_str::<Value>(id_str)
69        {
70            match value {
71                Value::Object(_) => return Ok(Self::object(table, value)),
72                Value::Array(arr) => return Ok(Self::array(table, arr)),
73                _ => {}
74            }
75        }
76
77        // Default to string
78        Ok(Self::string(table, id_str))
79    }
80
81    /// Convert to SurrealQL string representation
82    pub fn to_surql(&self) -> String {
83        let table = escape_identifier(&self.table);
84        let id = match &self.id {
85            RecordIdValue::String(s) => escape_identifier(s),
86            RecordIdValue::Number(n) => n.to_string(),
87            RecordIdValue::Object(obj) => obj.to_string(),
88            RecordIdValue::Array(arr) => {
89                format!(
90                    "[{}]",
91                    arr.iter()
92                        .map(|v| v.to_string())
93                        .collect::<Vec<_>>()
94                        .join(", ")
95                )
96            }
97        };
98        format!("{}:{}", table, id)
99    }
100
101    /// Get the table name
102    pub fn table(&self) -> &str {
103        &self.table
104    }
105
106    /// Get the ID value
107    pub fn id(&self) -> &RecordIdValue {
108        &self.id
109    }
110}
111
112impl fmt::Display for RecordId {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "{}", self.to_surql())
115    }
116}
117
118impl From<&str> for RecordId {
119    fn from(s: &str) -> Self {
120        Self::parse(s).unwrap_or_else(|_| Self::string("unknown", s))
121    }
122}
123
124impl From<String> for RecordId {
125    fn from(s: String) -> Self {
126        Self::from(s.as_str())
127    }
128}
129
130impl From<RecordId> for Value {
131    fn from(record_id: RecordId) -> Self {
132        Value::String(record_id.to_string())
133    }
134}
135
136impl From<&RecordId> for Value {
137    fn from(record_id: &RecordId) -> Self {
138        Value::String(record_id.to_string())
139    }
140}
141
142impl From<String> for RecordIdValue {
143    fn from(s: String) -> Self {
144        RecordIdValue::String(s)
145    }
146}
147
148impl From<&str> for RecordIdValue {
149    fn from(s: &str) -> Self {
150        RecordIdValue::String(s.to_string())
151    }
152}
153
154impl From<i64> for RecordIdValue {
155    fn from(n: i64) -> Self {
156        RecordIdValue::Number(n)
157    }
158}
159
160impl From<Value> for RecordIdValue {
161    fn from(v: Value) -> Self {
162        match v {
163            Value::String(s) => RecordIdValue::String(s),
164            Value::Number(n) => {
165                if let Some(i) = n.as_i64() {
166                    RecordIdValue::Number(i)
167                } else {
168                    RecordIdValue::Object(Value::Number(n))
169                }
170            }
171            Value::Array(arr) => RecordIdValue::Array(arr),
172            other => RecordIdValue::Object(other),
173        }
174    }
175}
176
177/// Table reference for queries
178#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
179pub struct Table {
180    pub name: String,
181}
182
183impl Table {
184    /// Create a new table reference
185    pub fn new(name: impl Into<String>) -> Self {
186        Self { name: name.into() }
187    }
188
189    /// Get the table name
190    pub fn name(&self) -> &str {
191        &self.name
192    }
193
194    /// Convert to SurrealQL string representation
195    pub fn to_surql(&self) -> String {
196        escape_identifier(&self.name)
197    }
198}
199
200impl fmt::Display for Table {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        write!(f, "{}", self.to_surql())
203    }
204}
205
206impl From<&str> for Table {
207    fn from(s: &str) -> Self {
208        Self::new(s)
209    }
210}
211
212impl From<String> for Table {
213    fn from(s: String) -> Self {
214        Self::new(s)
215    }
216}
217
218impl From<Table> for Value {
219    fn from(table: Table) -> Self {
220        Value::String(table.name)
221    }
222}
223
224impl From<&Table> for Value {
225    fn from(table: &Table) -> Self {
226        Value::String(table.name.clone())
227    }
228}
229
230/// Record range for selecting multiple records
231#[derive(Debug, Clone, PartialEq)]
232pub struct RecordRange {
233    pub table: String,
234    pub start: Option<RecordIdValue>,
235    pub end: Option<RecordIdValue>,
236    pub start_inclusive: bool,
237    pub end_inclusive: bool,
238}
239
240impl RecordRange {
241    /// Create a new record range
242    pub fn new(table: impl Into<String>) -> Self {
243        Self {
244            table: table.into(),
245            start: None,
246            end: None,
247            start_inclusive: true,
248            end_inclusive: true,
249        }
250    }
251
252    /// Set the start of the range
253    pub fn start(mut self, start: impl Into<RecordIdValue>, inclusive: bool) -> Self {
254        self.start = Some(start.into());
255        self.start_inclusive = inclusive;
256        self
257    }
258
259    /// Set the end of the range
260    pub fn end(mut self, end: impl Into<RecordIdValue>, inclusive: bool) -> Self {
261        self.end = Some(end.into());
262        self.end_inclusive = inclusive;
263        self
264    }
265
266    /// Convert to SurrealQL string representation
267    pub fn to_surql(&self) -> String {
268        let table = escape_identifier(&self.table);
269
270        let start_str = match &self.start {
271            Some(start) => {
272                let start_val = match start {
273                    RecordIdValue::String(s) => escape_identifier(s),
274                    RecordIdValue::Number(n) => n.to_string(),
275                    RecordIdValue::Object(obj) => obj.to_string(),
276                    RecordIdValue::Array(arr) => {
277                        format!(
278                            "[{}]",
279                            arr.iter()
280                                .map(|v| v.to_string())
281                                .collect::<Vec<_>>()
282                                .join(", ")
283                        )
284                    }
285                };
286                if self.start_inclusive {
287                    start_val
288                } else {
289                    format!(">{}", start_val)
290                }
291            }
292            None => String::new(),
293        };
294
295        let end_str = match &self.end {
296            Some(end) => {
297                let end_val = match end {
298                    RecordIdValue::String(s) => escape_identifier(s),
299                    RecordIdValue::Number(n) => n.to_string(),
300                    RecordIdValue::Object(obj) => obj.to_string(),
301                    RecordIdValue::Array(arr) => {
302                        format!(
303                            "[{}]",
304                            arr.iter()
305                                .map(|v| v.to_string())
306                                .collect::<Vec<_>>()
307                                .join(", ")
308                        )
309                    }
310                };
311                if self.end_inclusive {
312                    end_val
313                } else {
314                    format!("={}", end_val)
315                }
316            }
317            None => String::new(),
318        };
319
320        if start_str.is_empty() && end_str.is_empty() {
321            table
322        } else {
323            format!("{}:{}..{}", table, start_str, end_str)
324        }
325    }
326}
327
328impl fmt::Display for RecordRange {
329    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330        write!(f, "{}", self.to_surql())
331    }
332}
333
334/// Error type for record ID parsing
335#[derive(Debug, Clone, PartialEq, Eq)]
336pub enum RecordParseError {
337    InvalidFormat,
338    InvalidId,
339}
340
341impl fmt::Display for RecordParseError {
342    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343        match self {
344            RecordParseError::InvalidFormat => write!(f, "Invalid record ID format"),
345            RecordParseError::InvalidId => write!(f, "Invalid record ID value"),
346        }
347    }
348}
349
350impl std::error::Error for RecordParseError {}
351
352/// Escape a SurrealDB identifier if needed
353fn escape_identifier(ident: &str) -> String {
354    // Check if identifier needs escaping
355    if ident.is_empty() {
356        return "⟨⟩".to_string();
357    }
358
359    // Check if it's numeric
360    if ident.parse::<i64>().is_ok() || ident.parse::<f64>().is_ok() {
361        return format!("⟨{}⟩", ident);
362    }
363
364    // Check if it contains special characters or starts with a number
365    if ident.chars().next().unwrap().is_ascii_digit()
366        || ident.chars().any(|c| !c.is_alphanumeric() && c != '_')
367    {
368        return format!("⟨{}⟩", ident.replace('⟩', "\\⟩"));
369    }
370
371    ident.to_string()
372}
373
374#[cfg(test)]
375mod tests {
376    use super::*;
377
378    #[test]
379    fn test_record_id_creation() {
380        let record = RecordId::string("user", "john");
381        assert_eq!(record.table, "user");
382        assert_eq!(record.id, RecordIdValue::String("john".to_string()));
383    }
384
385    #[test]
386    fn test_record_id_parsing() {
387        let record = RecordId::parse("user:123").unwrap();
388        assert_eq!(record.table, "user");
389        assert_eq!(record.id, RecordIdValue::Number(123));
390
391        let record = RecordId::parse("user:john").unwrap();
392        assert_eq!(record.table, "user");
393        assert_eq!(record.id, RecordIdValue::String("john".to_string()));
394    }
395
396    #[test]
397    fn test_record_id_surql() {
398        let record = RecordId::string("user", "john");
399        assert_eq!(record.to_surql(), "user:john");
400
401        let record = RecordId::number("user", 123);
402        assert_eq!(record.to_surql(), "user:123");
403    }
404
405    #[test]
406    fn test_table_creation() {
407        let table = Table::new("users");
408        assert_eq!(table.name(), "users");
409        assert_eq!(table.to_surql(), "users");
410    }
411
412    #[test]
413    fn test_escape_identifier() {
414        assert_eq!(escape_identifier("normal"), "normal");
415        assert_eq!(escape_identifier("123"), "⟨123⟩");
416        assert_eq!(escape_identifier("with-dash"), "⟨with-dash⟩");
417        assert_eq!(escape_identifier(""), "⟨⟩");
418    }
419
420    #[test]
421    fn test_record_range() {
422        let range = RecordRange::new("user").start("a", true).end("z", false);
423
424        assert_eq!(range.to_surql(), "user:a..=z");
425    }
426
427    #[test]
428    fn test_conversions() {
429        let record = RecordId::string("user", "john");
430        let value: Value = record.into();
431        assert_eq!(value, Value::String("user:john".to_string()));
432
433        let table = Table::new("users");
434        let value: Value = table.into();
435        assert_eq!(value, Value::String("users".to_string()));
436    }
437}