scylladb_parse/statements/
dml.rs

1use super::*;
2use crate::{
3    ArithmeticOp,
4    BindMarker,
5    DurationLiteral,
6    ListLiteral,
7    Operator,
8    ReservedKeyword,
9    TupleLiteral,
10};
11
/// A CQL data manipulation statement: `SELECT`, `INSERT`, `UPDATE`, `DELETE` or `BATCH`.
///
/// Parsing is routed through [`TaggedDataManipulationStatement`] (see `parse_via`) and the
/// result is converted with `TryFrom<TaggedDataManipulationStatement>`.
#[derive(ParseFromStr, Clone, Debug, TryInto, From, ToTokens, PartialEq, Eq)]
#[parse_via(TaggedDataManipulationStatement)]
pub enum DataManipulationStatement {
    /// A `SELECT` statement.
    Select(SelectStatement),
    /// An `INSERT` statement.
    Insert(InsertStatement),
    /// An `UPDATE` statement.
    Update(UpdateStatement),
    /// A `DELETE` statement.
    Delete(DeleteStatement),
    /// A `BATCH` statement.
    Batch(BatchStatement),
}
21
22impl TryFrom<TaggedDataManipulationStatement> for DataManipulationStatement {
23    type Error = anyhow::Error;
24    fn try_from(value: TaggedDataManipulationStatement) -> Result<Self, Self::Error> {
25        Ok(match value {
26            TaggedDataManipulationStatement::Select(value) => DataManipulationStatement::Select(value.try_into()?),
27            TaggedDataManipulationStatement::Insert(value) => DataManipulationStatement::Insert(value.try_into()?),
28            TaggedDataManipulationStatement::Update(value) => DataManipulationStatement::Update(value.try_into()?),
29            TaggedDataManipulationStatement::Delete(value) => DataManipulationStatement::Delete(value.try_into()?),
30            TaggedDataManipulationStatement::Batch(value) => DataManipulationStatement::Batch(value.try_into()?),
31        })
32    }
33}
34
/// Tagged counterpart of [`DataManipulationStatement`], as produced by the parser.
///
/// Each variant wraps the tagged form of the corresponding statement. Token output is
/// generated as a `DataManipulationStatement` (see `tokenize_as`).
#[derive(ParseFromStr, Clone, Debug, TryInto, From, ToTokens, PartialEq, Eq)]
#[tokenize_as(DataManipulationStatement)]
pub enum TaggedDataManipulationStatement {
    /// A tagged `SELECT` statement.
    Select(TaggedSelectStatement),
    /// A tagged `INSERT` statement.
    Insert(TaggedInsertStatement),
    /// A tagged `UPDATE` statement.
    Update(TaggedUpdateStatement),
    /// A tagged `DELETE` statement.
    Delete(TaggedDeleteStatement),
    /// A tagged `BATCH` statement.
    Batch(TaggedBatchStatement),
}
44
45impl Parse for TaggedDataManipulationStatement {
46    type Output = Self;
47    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
48        Ok(if let Some(keyword) = s.find::<ReservedKeyword>() {
49            match keyword {
50                ReservedKeyword::SELECT => Self::Select(s.parse()?),
51                ReservedKeyword::INSERT => Self::Insert(s.parse()?),
52                ReservedKeyword::UPDATE => Self::Update(s.parse()?),
53                ReservedKeyword::DELETE => Self::Delete(s.parse()?),
54                ReservedKeyword::BATCH => Self::Batch(s.parse()?),
55                _ => anyhow::bail!("Expected a data manipulation statement, found {}", s.info()),
56            }
57        } else {
58            anyhow::bail!("Expected a data manipulation statement, found {}", s.info())
59        })
60    }
61}
62
63impl Display for DataManipulationStatement {
64    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Select(s) => s.fmt(f),
67            Self::Insert(s) => s.fmt(f),
68            Self::Update(s) => s.fmt(f),
69            Self::Delete(s) => s.fmt(f),
70            Self::Batch(s) => s.fmt(f),
71        }
72    }
73}
74
75impl KeyspaceExt for DataManipulationStatement {
76    fn get_keyspace(&self) -> Option<String> {
77        match self {
78            DataManipulationStatement::Select(s) => s.get_keyspace(),
79            DataManipulationStatement::Insert(s) => s.get_keyspace(),
80            DataManipulationStatement::Update(s) => s.get_keyspace(),
81            DataManipulationStatement::Delete(s) => s.get_keyspace(),
82            DataManipulationStatement::Batch(_) => None,
83        }
84    }
85
86    fn set_keyspace(&mut self, keyspace: impl Into<Name>) {
87        match self {
88            DataManipulationStatement::Select(s) => s.set_keyspace(keyspace),
89            DataManipulationStatement::Insert(s) => s.set_keyspace(keyspace),
90            DataManipulationStatement::Update(s) => s.set_keyspace(keyspace),
91            DataManipulationStatement::Delete(s) => s.set_keyspace(keyspace),
92            DataManipulationStatement::Batch(_) => (),
93        }
94    }
95}
96
/// A CQL `SELECT` statement.
///
/// Parsing is routed through [`TaggedSelectStatement`] (see `parse_via`). A builder,
/// `SelectStatementBuilder`, is derived; its `build` runs `Self::validate`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[parse_via(TaggedSelectStatement)]
pub struct SelectStatement {
    /// Whether `DISTINCT` was specified.
    #[builder(setter(name = "set_distinct"), default)]
    pub distinct: bool,
    /// The selection: `*` or an explicit list of selectors.
    #[builder(setter(into))]
    pub select_clause: SelectClause,
    /// The queried table, optionally keyspace-qualified.
    #[builder(setter(into))]
    pub from: KeyspaceQualifiedName,
    /// Optional `WHERE` clause.
    #[builder(setter(into), default)]
    pub where_clause: Option<WhereClause>,
    /// Optional `GROUP BY` clause.
    #[builder(setter(into), default)]
    pub group_by_clause: Option<GroupByClause>,
    /// Optional `ORDER BY` clause.
    #[builder(setter(into), default)]
    pub order_by_clause: Option<OrderByClause>,
    /// Optional `PER PARTITION LIMIT` value.
    #[builder(setter(into), default)]
    pub per_partition_limit: Option<Limit>,
    /// Optional `LIMIT` value.
    #[builder(setter(into), default)]
    pub limit: Option<Limit>,
    /// Whether `ALLOW FILTERING` was specified.
    #[builder(setter(name = "set_allow_filtering"), default)]
    pub allow_filtering: bool,
    /// Whether `BYPASS CACHE` was specified.
    #[builder(setter(name = "set_bypass_cache"), default)]
    pub bypass_cache: bool,
    /// Optional `USING TIMEOUT` duration.
    #[builder(setter(into), default)]
    pub timeout: Option<DurationLiteral>,
}
124
125impl TryFrom<TaggedSelectStatement> for SelectStatement {
126    type Error = anyhow::Error;
127    fn try_from(value: TaggedSelectStatement) -> Result<Self, Self::Error> {
128        Ok(Self {
129            distinct: value.distinct,
130            select_clause: value.select_clause.into_value()?,
131            from: value.from.try_into()?,
132            where_clause: value.where_clause.map(|v| v.into_value()).transpose()?,
133            group_by_clause: value.group_by_clause.map(|v| v.into_value()).transpose()?,
134            order_by_clause: value.order_by_clause.map(|v| v.into_value()).transpose()?,
135            per_partition_limit: value.per_partition_limit.map(|v| v.into_value()).transpose()?,
136            limit: value.limit.map(|v| v.into_value()).transpose()?,
137            allow_filtering: value.allow_filtering,
138            bypass_cache: value.bypass_cache,
139            timeout: value.timeout.map(|v| v.into_value()).transpose()?,
140        })
141    }
142}
143
/// Tagged counterpart of [`SelectStatement`], as produced by the parser.
///
/// Most sub-clauses are wrapped in `Tag`. A `Tag` may hold a concrete value (`Tag::Value`)
/// or, presumably, a not-yet-resolved placeholder (NOTE(review): confirm). Conversion to
/// [`SelectStatement`] goes through `TryFrom`. Token output is generated as a
/// `SelectStatement` (see `tokenize_as`).
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[tokenize_as(SelectStatement)]
pub struct TaggedSelectStatement {
    /// Whether `DISTINCT` was specified.
    #[builder(setter(name = "set_distinct"), default)]
    pub distinct: bool,
    /// The selection: `*` or an explicit list of selectors.
    pub select_clause: Tag<SelectClause>,
    /// The queried table, optionally keyspace-qualified.
    pub from: TaggedKeyspaceQualifiedName,
    /// Optional `WHERE` clause.
    #[builder(default)]
    pub where_clause: Option<Tag<WhereClause>>,
    /// Optional `GROUP BY` clause.
    #[builder(default)]
    pub group_by_clause: Option<Tag<GroupByClause>>,
    /// Optional `ORDER BY` clause.
    #[builder(default)]
    pub order_by_clause: Option<Tag<OrderByClause>>,
    /// Optional `PER PARTITION LIMIT` value.
    #[builder(default)]
    pub per_partition_limit: Option<Tag<Limit>>,
    /// Optional `LIMIT` value.
    #[builder(default)]
    pub limit: Option<Tag<Limit>>,
    /// Whether `ALLOW FILTERING` was specified.
    #[builder(setter(name = "set_allow_filtering"), default)]
    pub allow_filtering: bool,
    /// Whether `BYPASS CACHE` was specified.
    #[builder(setter(name = "set_bypass_cache"), default)]
    pub bypass_cache: bool,
    /// Optional `USING TIMEOUT` duration.
    #[builder(default)]
    pub timeout: Option<Tag<DurationLiteral>>,
}
169
170impl SelectStatementBuilder {
171    /// Set DISTINCT on the statement
172    /// To undo this, use `set_distinct(false)`
173    pub fn distinct(&mut self) -> &mut Self {
174        self.distinct.replace(true);
175        self
176    }
177
178    /// Set ALLOW FILTERING on the statement
179    /// To undo this, use `set_allow_filtering(false)`
180    pub fn allow_filtering(&mut self) -> &mut Self {
181        self.allow_filtering.replace(true);
182        self
183    }
184
185    /// Set BYPASS CACHE on the statement
186    /// To undo this, use `set_bypass_cache(false)`
187    pub fn bypass_cache(&mut self) -> &mut Self {
188        self.bypass_cache.replace(true);
189        self
190    }
191
192    fn validate(&self) -> Result<(), String> {
193        if self
194            .select_clause
195            .as_ref()
196            .map(|s| match s {
197                SelectClause::Selectors(s) => s.is_empty(),
198                _ => false,
199            })
200            .unwrap_or(false)
201        {
202            return Err("SELECT clause selectors cannot be empty".to_string());
203        }
204        Ok(())
205    }
206}
207
208impl TaggedSelectStatementBuilder {
209    fn validate(&self) -> Result<(), String> {
210        if self
211            .select_clause
212            .as_ref()
213            .map(|s| match s {
214                Tag::Value(SelectClause::Selectors(s)) => s.is_empty(),
215                _ => false,
216            })
217            .unwrap_or(false)
218        {
219            return Err("SELECT clause selectors cannot be empty".to_string());
220        }
221        Ok(())
222    }
223}
224
225impl Parse for TaggedSelectStatement {
226    type Output = Self;
227    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
228    where
229        Self: Sized,
230    {
231        s.parse::<SELECT>()?;
232        let mut res = TaggedSelectStatementBuilder::default();
233        res.set_distinct(s.parse::<Option<DISTINCT>>()?.is_some())
234            .select_clause(s.parse()?)
235            .from(s.parse::<(FROM, _)>()?.1);
236        loop {
237            if s.remaining() == 0 || s.parse::<Option<Semicolon>>()?.is_some() {
238                break;
239            }
240            if let Some(where_clause) = s.parse()? {
241                if res.where_clause.is_some() {
242                    anyhow::bail!("Duplicate WHERE clause!");
243                }
244                res.where_clause(where_clause);
245            } else if let Some(group_by_clause) = s.parse()? {
246                if res.group_by_clause.is_some() {
247                    anyhow::bail!("Duplicate GROUP BY clause!");
248                }
249                res.group_by_clause(group_by_clause);
250            } else if let Some(order_by_clause) = s.parse()? {
251                if res.order_by_clause.is_some() {
252                    anyhow::bail!("Duplicate ORDER BY clause!");
253                }
254                res.order_by_clause(order_by_clause);
255            } else if s.parse::<Option<(PER, PARTITION, LIMIT)>>()?.is_some() {
256                if res.per_partition_limit.is_some() {
257                    anyhow::bail!("Duplicate PER PARTITION LIMIT clause!");
258                }
259                res.per_partition_limit(s.parse()?);
260            } else if s.parse::<Option<LIMIT>>()?.is_some() {
261                if res.limit.is_some() {
262                    anyhow::bail!("Duplicate LIMIT clause!");
263                }
264                res.limit(s.parse()?);
265            } else if s.parse::<Option<(ALLOW, FILTERING)>>()?.is_some() {
266                if res.allow_filtering.is_some() {
267                    anyhow::bail!("Duplicate ALLOW FILTERING clause!");
268                }
269                res.set_allow_filtering(true);
270            } else if s.parse::<Option<(BYPASS, CACHE)>>()?.is_some() {
271                if res.bypass_cache.is_some() {
272                    anyhow::bail!("Duplicate BYPASS CACHE clause!");
273                }
274                res.set_bypass_cache(true);
275            } else if let Some(t) = s.parse_from::<If<(USING, TIMEOUT), Tag<DurationLiteral>>>()? {
276                if res.timeout.is_some() {
277                    anyhow::bail!("Duplicate USING TIMEOUT clause!");
278                }
279                res.timeout(t);
280            } else {
281                return Ok(res
282                    .build()
283                    .map_err(|_| anyhow::anyhow!("Invalid tokens in SELECT statement: {}", s.info()))?);
284            }
285        }
286        Ok(res
287            .build()
288            .map_err(|e| anyhow::anyhow!("Invalid SELECT statement: {}", e))?)
289    }
290}
291
292impl Display for SelectStatement {
293    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294        write!(
295            f,
296            "SELECT {}{} FROM {}",
297            if self.distinct { "DISTINCT " } else { "" },
298            self.select_clause,
299            self.from
300        )?;
301        if let Some(where_clause) = &self.where_clause {
302            write!(f, " {}", where_clause)?;
303        }
304        if let Some(group_by_clause) = &self.group_by_clause {
305            write!(f, " {}", group_by_clause)?;
306        }
307        if let Some(order_by_clause) = &self.order_by_clause {
308            write!(f, " {}", order_by_clause)?;
309        }
310        if let Some(per_partition_limit) = &self.per_partition_limit {
311            write!(f, " PER PARTITION LIMIT {}", per_partition_limit)?;
312        }
313        if let Some(limit) = &self.limit {
314            write!(f, " LIMIT {}", limit)?;
315        }
316        if self.allow_filtering {
317            write!(f, " ALLOW FILTERING")?;
318        }
319        if self.bypass_cache {
320            write!(f, " BYPASS CACHE")?;
321        }
322        if let Some(timeout) = &self.timeout {
323            write!(f, " USING TIMEOUT {}", timeout)?;
324        }
325        Ok(())
326    }
327}
328
329impl KeyspaceExt for SelectStatement {
330    fn get_keyspace(&self) -> Option<String> {
331        self.from.keyspace.as_ref().map(|n| n.to_string())
332    }
333
334    fn set_keyspace(&mut self, keyspace: impl Into<Name>) {
335        self.from.keyspace.replace(keyspace.into());
336    }
337}
338
339impl WhereExt for SelectStatement {
340    fn iter_where(&self) -> Option<std::slice::Iter<Relation>> {
341        self.where_clause.as_ref().map(|w| w.relations.iter())
342    }
343}
344
/// The selection part of a `SELECT`: either `*` or an explicit list of selectors.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum SelectClause {
    /// `*` — select all columns.
    All,
    /// An explicit, comma-separated list of selectors.
    Selectors(Vec<Selector>),
}
350
351impl Parse for SelectClause {
352    type Output = Self;
353    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
354    where
355        Self: Sized,
356    {
357        Ok(if s.parse::<Option<Star>>()?.is_some() {
358            SelectClause::All
359        } else {
360            SelectClause::Selectors(s.parse_from::<List<Selector, Comma>>()?)
361        })
362    }
363}
364
365impl Display for SelectClause {
366    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
367        match self {
368            SelectClause::All => write!(f, "*"),
369            SelectClause::Selectors(selectors) => {
370                for (i, selector) in selectors.iter().enumerate() {
371                    if i > 0 {
372                        write!(f, ", ")?;
373                    }
374                    selector.fmt(f)?;
375                }
376                Ok(())
377            }
378        }
379    }
380}
381
382impl From<Vec<Selector>> for SelectClause {
383    fn from(selectors: Vec<Selector>) -> Self {
384        SelectClause::Selectors(selectors)
385    }
386}
387
/// One item of a `SELECT` list: a selector expression with an optional `AS` alias.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct Selector {
    /// The selector expression.
    #[builder(setter(into))]
    pub kind: SelectorKind,
    /// Optional alias, rendered as `AS <name>`.
    #[builder(setter(strip_option), default)]
    pub as_id: Option<Name>,
}
395
396impl Selector {
397    pub fn column(name: impl Into<Name>) -> Self {
398        Selector {
399            kind: SelectorKind::Column(name.into()),
400            as_id: Default::default(),
401        }
402    }
403
404    pub fn term(term: impl Into<Term>) -> Self {
405        Selector {
406            kind: SelectorKind::Term(term.into()),
407            as_id: Default::default(),
408        }
409    }
410
411    pub fn cast(self, ty: impl Into<CqlType>) -> Self {
412        Selector {
413            kind: SelectorKind::Cast(Box::new(self), ty.into()),
414            as_id: Default::default(),
415        }
416    }
417
418    pub fn function(function: SelectorFunction) -> Self {
419        Selector {
420            kind: SelectorKind::Function(function),
421            as_id: Default::default(),
422        }
423    }
424
425    pub fn count() -> Self {
426        Selector {
427            kind: SelectorKind::Count,
428            as_id: Default::default(),
429        }
430    }
431
432    pub fn as_id(self, name: impl Into<Name>) -> Self {
433        Self {
434            kind: self.kind,
435            as_id: Some(name.into()),
436        }
437    }
438}
439
440impl Parse for Selector {
441    type Output = Self;
442    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
443    where
444        Self: Sized,
445    {
446        let (kind, as_id) = s.parse::<(SelectorKind, Option<(AS, Name)>)>()?;
447        Ok(Self {
448            kind,
449            as_id: as_id.map(|(_, id)| id),
450        })
451    }
452}
453
454impl Display for Selector {
455    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
456        self.kind.fmt(f)?;
457        if let Some(id) = &self.as_id {
458            write!(f, " AS {}", id)?;
459        }
460        Ok(())
461    }
462}
463
/// A function call used as a selector, rendered as `name(arg1, arg2, ...)`.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct SelectorFunction {
    /// The function name.
    pub function: Name,
    /// The call arguments; each argument is itself a selector.
    pub args: Vec<Selector>,
}
469
470impl SelectorFunction {
471    pub fn new(function: Name) -> Self {
472        SelectorFunction {
473            function,
474            args: Vec::new(),
475        }
476    }
477
478    pub fn arg(mut self, arg: Selector) -> Self {
479        self.args.push(arg);
480        self
481    }
482}
483
484impl Parse for SelectorFunction {
485    type Output = Self;
486    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
487    where
488        Self: Sized,
489    {
490        let (function, args) = s.parse_from::<(Name, Parens<List<Selector, Comma>>)>()?;
491        Ok(SelectorFunction { function, args })
492    }
493}
494
495impl Display for SelectorFunction {
496    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
497        write!(
498            f,
499            "{}({})",
500            self.function,
501            self.args.iter().map(|s| s.to_string()).collect::<Vec<_>>().join(", ")
502        )
503    }
504}
505
/// The expression part of a [`Selector`], without its alias.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum SelectorKind {
    /// A column reference.
    Column(Name),
    /// A term.
    Term(Term),
    /// `CAST(<selector> AS <type>)`.
    Cast(Box<Selector>, CqlType),
    /// A function call.
    Function(SelectorFunction),
    /// `COUNT(*)`.
    Count,
}
514
impl Parse for SelectorKind {
    type Output = Self;

    /// Parses a single selector expression (without any `AS` alias).
    ///
    /// Alternatives are tried in this order, and the first match wins:
    /// 1. `CAST(<selector> AS <type>)`
    /// 2. `COUNT(...)`
    /// 3. a function call `name(args...)`
    /// 4. a bare column name
    /// 5. a term
    ///
    /// # Errors
    /// Fails if none of the alternatives match. Also fails if a `CAST` or `COUNT` keyword
    /// is found but what follows is malformed.
    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
    where
        Self: Sized,
    {
        Ok(if s.parse::<Option<CAST>>()?.is_some() {
            let (selector, _, cql_type) = s.parse_from::<Parens<(Selector, AS, CqlType)>>()?;
            Self::Cast(Box::new(selector), cql_type)
        } else if s.parse::<Option<COUNT>>()?.is_some() {
            // TODO: Double check that this is ok
            // NOTE(review): `Parens<char>` appears to accept any single character inside the
            // parentheses, not only `*`, yet the result always renders as `COUNT(*)`.
            // A multi-character argument such as `COUNT(col)` presumably errors here instead of
            // falling through to the function-call branch — confirm against the grammar.
            s.parse_from::<Parens<char>>()?;
            Self::Count
        } else if let Some(f) = s.parse()? {
            // Tried before `Column`, presumably because a function call also starts with a `Name`.
            Self::Function(f)
        } else if let Some(id) = s.parse()? {
            Self::Column(id)
        } else if let Some(term) = s.parse()? {
            Self::Term(term)
        } else {
            anyhow::bail!("Expected selector, found {}", s.info())
        })
    }
}
539
540impl Display for SelectorKind {
541    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
542        match self {
543            SelectorKind::Column(id) => id.fmt(f),
544            SelectorKind::Term(term) => term.fmt(f),
545            SelectorKind::Cast(selector, cql_type) => write!(f, "CAST({} AS {})", selector, cql_type),
546            SelectorKind::Function(func) => func.fmt(f),
547            SelectorKind::Count => write!(f, "COUNT(*)"),
548        }
549    }
550}
551
/// A CQL `INSERT` statement.
///
/// Parsing is routed through [`TaggedInsertStatement`] (see `parse_via`). A builder,
/// `InsertStatementBuilder`, is derived.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option))]
#[parse_via(TaggedInsertStatement)]
pub struct InsertStatement {
    /// The target table, optionally keyspace-qualified.
    #[builder(setter(into))]
    pub table: KeyspaceQualifiedName,
    /// What is inserted: column names with a values tuple, or a JSON document.
    #[builder(setter(into))]
    pub kind: InsertKind,
    /// Whether `IF NOT EXISTS` was specified.
    #[builder(setter(name = "set_if_not_exists"), default)]
    pub if_not_exists: bool,
    /// Optional `USING` parameters (`TTL`, `TIMESTAMP`, `TIMEOUT`), rendered joined by `AND`.
    #[builder(default)]
    pub using: Option<Vec<UpdateParameter>>,
}
565
566impl TryFrom<TaggedInsertStatement> for InsertStatement {
567    type Error = anyhow::Error;
568    fn try_from(value: TaggedInsertStatement) -> Result<Self, Self::Error> {
569        Ok(InsertStatement {
570            table: value.table.try_into()?,
571            kind: value.kind,
572            if_not_exists: value.if_not_exists,
573            using: value.using.map(|v| v.into_value()).transpose()?,
574        })
575    }
576}
577
/// Tagged counterpart of [`InsertStatement`], as produced by the parser.
///
/// Token output is generated as an `InsertStatement` (see `tokenize_as`). Conversion to
/// [`InsertStatement`] goes through `TryFrom`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option))]
#[tokenize_as(InsertStatement)]
pub struct TaggedInsertStatement {
    /// The target table, optionally keyspace-qualified.
    pub table: TaggedKeyspaceQualifiedName,
    /// What is inserted: column names with a values tuple, or a JSON document.
    pub kind: InsertKind,
    /// Whether `IF NOT EXISTS` was specified.
    #[builder(setter(name = "set_if_not_exists"), default)]
    pub if_not_exists: bool,
    /// Optional tagged `USING` parameter list.
    #[builder(default)]
    pub using: Option<Tag<Vec<UpdateParameter>>>,
}
589
590impl InsertStatementBuilder {
591    /// Set IF NOT EXISTS on the statement.
592    /// To undo this, use `set_if_not_exists(false)`.
593    pub fn if_not_exists(&mut self) -> &mut Self {
594        self.if_not_exists.replace(true);
595        self
596    }
597}
598
599impl Parse for TaggedInsertStatement {
600    type Output = Self;
601    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
602        s.parse::<(INSERT, INTO)>()?;
603        let mut res = TaggedInsertStatementBuilder::default();
604        res.table(s.parse()?).kind(s.parse()?);
605        loop {
606            if s.remaining() == 0 || s.parse::<Option<Semicolon>>()?.is_some() {
607                break;
608            }
609            if s.parse::<Option<(IF, NOT, EXISTS)>>()?.is_some() {
610                if res.if_not_exists.is_some() {
611                    anyhow::bail!("Duplicate IF NOT EXISTS clause!");
612                }
613                res.set_if_not_exists(true);
614            } else if s.parse::<Option<USING>>()?.is_some() {
615                if res.using.is_some() {
616                    anyhow::bail!("Duplicate USING clause!");
617                }
618                res.using(s.parse_from::<Tag<List<UpdateParameter, AND>>>()?);
619            } else {
620                return Ok(res
621                    .build()
622                    .map_err(|_| anyhow::anyhow!("Invalid tokens in INSERT statement: {}", s.info()))?);
623            }
624        }
625        Ok(res
626            .build()
627            .map_err(|e| anyhow::anyhow!("Invalid INSERT statement: {}", e))?)
628    }
629}
630
631impl Display for InsertStatement {
632    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
633        write!(f, "INSERT INTO {} {}", self.table, self.kind)?;
634        if self.if_not_exists {
635            write!(f, " IF NOT EXISTS")?;
636        }
637        if let Some(using) = &self.using {
638            if !using.is_empty() {
639                write!(
640                    f,
641                    " USING {}",
642                    using.iter().map(|p| p.to_string()).collect::<Vec<_>>().join(" AND ")
643                )?;
644            }
645        }
646        Ok(())
647    }
648}
649
650impl KeyspaceExt for InsertStatement {
651    fn get_keyspace(&self) -> Option<String> {
652        self.table.keyspace.as_ref().map(|n| n.to_string())
653    }
654
655    fn set_keyspace(&mut self, keyspace: impl Into<Name>) {
656        self.table.keyspace.replace(keyspace.into());
657    }
658}
659
/// The payload of an `INSERT`.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum InsertKind {
    /// `(<names>) VALUES <tuple>`.
    NameValue {
        /// Column names, in order.
        names: Vec<Name>,
        /// The values tuple. Both the parser and `InsertKind::name_value` check that its
        /// length matches `names`.
        values: TupleLiteral,
    },
    /// `JSON <string> [DEFAULT <column default>]`.
    Json {
        /// The JSON document, given as a string literal.
        json: LitStr,
        /// Optional `DEFAULT` setting.
        default: Option<ColumnDefault>,
    },
}
671
672impl InsertKind {
673    pub fn name_value(names: Vec<Name>, values: Vec<Term>) -> anyhow::Result<Self> {
674        if names.is_empty() {
675            anyhow::bail!("No column names specified!");
676        }
677        if values.is_empty() {
678            anyhow::bail!("No values specified!");
679        }
680        if names.len() != values.len() {
681            anyhow::bail!(
682                "Number of column names and values do not match! ({} names vs {} values)",
683                names.len(),
684                values.len()
685            );
686        }
687        Ok(Self::NameValue {
688            names,
689            values: values.into(),
690        })
691    }
692
693    pub fn json<S: Into<LitStr>, O: Into<Option<ColumnDefault>>>(json: S, default: O) -> Self {
694        Self::Json {
695            json: json.into(),
696            default: default.into(),
697        }
698    }
699}
700
701impl Parse for InsertKind {
702    type Output = Self;
703    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
704        if s.parse::<Option<JSON>>()?.is_some() {
705            let (json, default) = s.parse_from::<(LitStr, Option<(DEFAULT, ColumnDefault)>)>()?;
706            Ok(Self::Json {
707                json,
708                default: default.map(|(_, d)| d),
709            })
710        } else {
711            let (names, _, values) = s.parse_from::<(Parens<List<Name, Comma>>, VALUES, TupleLiteral)>()?;
712            if names.len() != values.elements.len() {
713                anyhow::bail!(
714                    "Number of column names and values do not match! ({} names vs {} values)",
715                    names.len(),
716                    values.elements.len()
717                );
718            }
719            Ok(Self::NameValue { names, values })
720        }
721    }
722}
723
724impl Display for InsertKind {
725    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
726        match self {
727            InsertKind::NameValue { names, values } => write!(
728                f,
729                "({}) VALUES {}",
730                names.iter().map(|p| p.to_string()).collect::<Vec<_>>().join(", "),
731                values
732            ),
733            InsertKind::Json { json, default } => {
734                write!(f, "JSON {}", json)?;
735                if let Some(default) = default {
736                    write!(f, " DEFAULT {}", default)?;
737                }
738                Ok(())
739            }
740        }
741    }
742}
743
/// A single parameter of a `USING` clause.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum UpdateParameter {
    /// `TTL <value>`.
    TTL(Limit),
    /// `TIMESTAMP <value>`.
    Timestamp(Limit),
    /// `TIMEOUT <duration>`.
    Timeout(DurationLiteral),
}
750
751impl UpdateParameter {
752    pub fn ttl(limit: impl Into<Limit>) -> Self {
753        Self::TTL(limit.into())
754    }
755
756    pub fn timestamp(limit: impl Into<Limit>) -> Self {
757        Self::Timestamp(limit.into())
758    }
759
760    pub fn timeout(duration: impl Into<DurationLiteral>) -> Self {
761        Self::Timeout(duration.into())
762    }
763}
764
765impl Parse for UpdateParameter {
766    type Output = Self;
767    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
768        if s.parse::<Option<TTL>>()?.is_some() {
769            Ok(UpdateParameter::TTL(s.parse()?))
770        } else if s.parse::<Option<TIMESTAMP>>()?.is_some() {
771            Ok(UpdateParameter::Timestamp(s.parse()?))
772        } else if s.parse::<Option<TIMEOUT>>()?.is_some() {
773            Ok(UpdateParameter::Timeout(s.parse()?))
774        } else {
775            anyhow::bail!("Expected update parameter, found {}", s.info())
776        }
777    }
778}
779
780impl Display for UpdateParameter {
781    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
782        match self {
783            UpdateParameter::TTL(limit) => write!(f, "TTL {}", limit),
784            UpdateParameter::Timestamp(limit) => write!(f, "TIMESTAMP {}", limit),
785            UpdateParameter::Timeout(duration) => write!(f, "TIMEOUT {}", duration),
786        }
787    }
788}
789
/// A CQL `UPDATE` statement.
///
/// Parsing is routed through [`TaggedUpdateStatement`] (see `parse_via`). A builder,
/// `UpdateStatementBuilder`, is derived; its `build` runs `Self::validate`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[parse_via(TaggedUpdateStatement)]
pub struct UpdateStatement {
    /// The target table, optionally keyspace-qualified.
    #[builder(setter(into))]
    pub table: KeyspaceQualifiedName,
    /// Optional `USING` parameters, rendered joined by `AND`.
    #[builder(default)]
    pub using: Option<Vec<UpdateParameter>>,
    /// The `SET` assignments; builder validation rejects an empty list.
    pub set_clause: Vec<Assignment>,
    /// The mandatory `WHERE` clause; builder validation rejects empty relations.
    #[builder(setter(into))]
    pub where_clause: WhereClause,
    /// Optional `IF` condition.
    #[builder(setter(into), default)]
    pub if_clause: Option<IfClause>,
}
804
805impl TryFrom<TaggedUpdateStatement> for UpdateStatement {
806    type Error = anyhow::Error;
807    fn try_from(value: TaggedUpdateStatement) -> Result<Self, Self::Error> {
808        Ok(Self {
809            table: value.table.try_into()?,
810            using: value.using.map(|v| v.into_value()).transpose()?,
811            set_clause: value.set_clause.into_value()?,
812            where_clause: value.where_clause.into_value()?,
813            if_clause: value.if_clause.map(|v| v.into_value()).transpose()?,
814        })
815    }
816}
817
/// Tagged counterpart of [`UpdateStatement`], as produced by the parser.
///
/// Token output is generated as an `UpdateStatement` (see `tokenize_as`). Conversion to
/// [`UpdateStatement`] goes through `TryFrom`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[tokenize_as(UpdateStatement)]
pub struct TaggedUpdateStatement {
    /// The target table, optionally keyspace-qualified.
    pub table: TaggedKeyspaceQualifiedName,
    /// Optional tagged `USING` parameter list.
    #[builder(default)]
    pub using: Option<Tag<Vec<UpdateParameter>>>,
    /// The tagged `SET` assignments.
    pub set_clause: Tag<Vec<Assignment>>,
    /// The tagged, mandatory `WHERE` clause.
    pub where_clause: Tag<WhereClause>,
    /// Optional tagged `IF` condition.
    #[builder(default)]
    pub if_clause: Option<Tag<IfClause>>,
}
830
831impl UpdateStatementBuilder {
832    /// Set IF EXISTS on the statement.
833    pub fn if_exists(&mut self) -> &mut Self {
834        self.if_clause.replace(Some(IfClause::Exists));
835        self
836    }
837
838    fn validate(&self) -> Result<(), String> {
839        if self.set_clause.as_ref().map(|s| s.is_empty()).unwrap_or(false) {
840            return Err("SET clause assignments cannot be empty".to_string());
841        }
842        if self
843            .where_clause
844            .as_ref()
845            .map(|s| s.relations.is_empty())
846            .unwrap_or(false)
847        {
848            return Err("WHERE clause cannot be empty".to_string());
849        }
850        Ok(())
851    }
852}
853
854impl TaggedUpdateStatementBuilder {
855    fn validate(&self) -> Result<(), String> {
856        if self
857            .set_clause
858            .as_ref()
859            .map(|s| match s {
860                Tag::Value(v) => v.is_empty(),
861                _ => false,
862            })
863            .unwrap_or(false)
864        {
865            return Err("SET clause assignments cannot be empty".to_string());
866        }
867        if self
868            .where_clause
869            .as_ref()
870            .map(|s| match s {
871                Tag::Value(v) => v.relations.is_empty(),
872                _ => false,
873            })
874            .unwrap_or(false)
875        {
876            return Err("WHERE clause cannot be empty".to_string());
877        }
878        Ok(())
879    }
880}
881
882impl Parse for TaggedUpdateStatement {
883    type Output = Self;
884    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
885        s.parse::<UPDATE>()?;
886        let mut res = TaggedUpdateStatementBuilder::default();
887        res.table(s.parse()?);
888        if let Some(u) = s.parse_from::<If<USING, Tag<List<UpdateParameter, AND>>>>()? {
889            res.using(u);
890        }
891        res.set_clause(s.parse_from::<(SET, Tag<List<Assignment, Comma>>)>()?.1)
892            .where_clause(s.parse()?);
893        if let Some(i) = s.parse()? {
894            res.if_clause(i);
895        }
896        s.parse::<Option<Semicolon>>()?;
897        Ok(res
898            .build()
899            .map_err(|e| anyhow::anyhow!("Invalid UPDATE statement: {}", e))?)
900    }
901}
902
903impl Display for UpdateStatement {
904    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
905        write!(f, "UPDATE {}", self.table)?;
906        if let Some(using) = &self.using {
907            if !using.is_empty() {
908                write!(
909                    f,
910                    " USING {}",
911                    using.iter().map(|p| p.to_string()).collect::<Vec<_>>().join(" AND ")
912                )?;
913            }
914        }
915        write!(
916            f,
917            " SET {} {}",
918            self.set_clause
919                .iter()
920                .map(|p| p.to_string())
921                .collect::<Vec<_>>()
922                .join(", "),
923            self.where_clause
924        )?;
925        if let Some(if_clause) = &self.if_clause {
926            write!(f, " {}", if_clause)?;
927        }
928        Ok(())
929    }
930}
931
932impl KeyspaceExt for UpdateStatement {
933    fn get_keyspace(&self) -> Option<String> {
934        self.table.keyspace.as_ref().map(|n| n.to_string())
935    }
936
937    fn set_keyspace(&mut self, keyspace: impl Into<Name>) {
938        self.table.keyspace.replace(keyspace.into());
939    }
940}
941
942impl WhereExt for UpdateStatement {
943    fn iter_where(&self) -> Option<std::slice::Iter<Relation>> {
944        Some(self.where_clause.relations.iter())
945    }
946}
947
/// A single assignment in the SET clause of an UPDATE statement.
///
/// The `Parse` impl tries the `Append` form, then `Arithmetic`, then falls back to `Simple`.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum Assignment {
    /// `selection = term`, e.g. `director = 'Joss Whedon'`.
    Simple {
        selection: SimpleSelection,
        term: Term,
    },
    /// `assignee = lhs <op> rhs`, e.g. `year = year + 10`.
    Arithmetic {
        assignee: Name,
        lhs: Name,
        op: ArithmeticOp,
        rhs: Term,
    },
    /// `assignee = [list...] + item`: a list literal, then `+`, then a name.
    ///
    /// NOTE(review): in CQL `c = [..] + c` prepends to the collection, while the
    /// variant is named "Append" — confirm the intended semantics with callers.
    Append {
        assignee: Name,
        list: ListLiteral,
        item: Name,
    },
}
966
967impl Assignment {
968    pub fn simple(selection: impl Into<SimpleSelection>, term: impl Into<Term>) -> Self {
969        Self::Simple {
970            selection: selection.into(),
971            term: term.into(),
972        }
973    }
974
975    pub fn arithmetic(assignee: impl Into<Name>, lhs: impl Into<Name>, op: ArithmeticOp, rhs: impl Into<Term>) -> Self {
976        Self::Arithmetic {
977            assignee: assignee.into(),
978            lhs: lhs.into(),
979            op,
980            rhs: rhs.into(),
981        }
982    }
983
984    pub fn append(assignee: impl Into<Name>, list: Vec<impl Into<Term>>, item: impl Into<Name>) -> Self {
985        Self::Append {
986            assignee: assignee.into(),
987            list: list.into_iter().map(Into::into).collect::<Vec<_>>().into(),
988            item: item.into(),
989        }
990    }
991}
992
993impl Parse for Assignment {
994    type Output = Self;
995    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
996        Ok(
997            if let Some((assignee, _, list, _, item)) = s.parse::<Option<(_, Equals, _, Plus, _)>>()? {
998                Self::Append { assignee, list, item }
999            } else if let Some((assignee, _, lhs, op, rhs)) = s.parse::<Option<(_, Equals, _, _, _)>>()? {
1000                Self::Arithmetic { assignee, lhs, op, rhs }
1001            } else {
1002                let (selection, _, term) = s.parse::<(_, Equals, _)>()?;
1003                Self::Simple { selection, term }
1004            },
1005        )
1006    }
1007}
1008
1009impl Display for Assignment {
1010    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1011        match self {
1012            Assignment::Simple { selection, term } => write!(f, "{} = {}", selection, term),
1013            Assignment::Arithmetic { assignee, lhs, op, rhs } => write!(f, "{} = {} {} {}", assignee, lhs, op, rhs),
1014            Assignment::Append { assignee, list, item } => {
1015                write!(f, "{} = {} + {}", assignee, list, item)
1016            }
1017        }
1018    }
1019}
1020
/// A column reference as used in SET assignments, IF conditions and DELETE selections.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum SimpleSelection {
    /// A bare column: `name`.
    Column(Name),
    /// A subscripted column: `name[term]`.
    Term(Name, Term),
    /// A dotted field access: `column.field`.
    Field(Name, Name),
}
1027
1028impl SimpleSelection {
1029    pub fn column<T: Into<Name>>(name: T) -> Self {
1030        Self::Column(name.into())
1031    }
1032
1033    pub fn term<N: Into<Name>, T: Into<Term>>(name: N, term: T) -> Self {
1034        Self::Term(name.into(), term.into())
1035    }
1036
1037    pub fn field<T: Into<Name>>(name: T, field: T) -> Self {
1038        Self::Field(name.into(), field.into())
1039    }
1040}
1041
1042impl Parse for SimpleSelection {
1043    type Output = Self;
1044    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1045        Ok(if let Some((column, _, field)) = s.parse::<Option<(_, Dot, _)>>()? {
1046            Self::Field(column, field)
1047        } else if let Some((column, term)) = s.parse_from::<Option<(Name, Brackets<Term>)>>()? {
1048            Self::Term(column, term)
1049        } else {
1050            Self::Column(s.parse()?)
1051        })
1052    }
1053}
1054
1055impl Display for SimpleSelection {
1056    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1057        match self {
1058            Self::Column(name) => name.fmt(f),
1059            Self::Term(name, term) => write!(f, "{}[{}]", name, term),
1060            Self::Field(column, field) => write!(f, "{}.{}", column, field),
1061        }
1062    }
1063}
1064
1065impl<N: Into<Name>> From<N> for SimpleSelection {
1066    fn from(name: N) -> Self {
1067        Self::Column(name.into())
1068    }
1069}
1070
/// One `lhs op rhs` comparison; an `IF` clause joins several of these with `AND`.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct Condition {
    /// The column (or element / field) being compared.
    pub lhs: SimpleSelection,
    /// The comparison operator.
    pub op: Operator,
    /// The value compared against.
    pub rhs: Term,
}
1077
1078impl Condition {
1079    pub fn new(lhs: impl Into<SimpleSelection>, op: impl Into<Operator>, rhs: impl Into<Term>) -> Self {
1080        Self {
1081            lhs: lhs.into(),
1082            op: op.into(),
1083            rhs: rhs.into(),
1084        }
1085    }
1086}
1087
1088impl Parse for Condition {
1089    type Output = Self;
1090    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1091        let (lhs, op, rhs) = s.parse()?;
1092        Ok(Condition { lhs, op, rhs })
1093    }
1094}
1095
1096impl Display for Condition {
1097    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1098        write!(f, "{} {} {}", self.lhs, self.op, self.rhs)
1099    }
1100}
1101
/// The optional `IF` clause of an UPDATE or DELETE statement.
#[derive(ParseFromStr, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum IfClause {
    /// `IF EXISTS`
    Exists,
    /// `IF cond AND cond ...`; an empty list renders as nothing.
    Conditions(Vec<Condition>),
}
1107
1108impl IfClause {
1109    pub fn exists() -> Self {
1110        Self::Exists
1111    }
1112
1113    pub fn conditions<T: Into<Condition>>(conditions: Vec<T>) -> Self {
1114        Self::Conditions(conditions.into_iter().map(Into::into).collect())
1115    }
1116}
1117
1118impl Parse for IfClause {
1119    type Output = Self;
1120    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1121        s.parse::<IF>()?;
1122        Ok(if s.parse::<Option<EXISTS>>()?.is_some() {
1123            IfClause::Exists
1124        } else {
1125            IfClause::Conditions(s.parse_from::<List<Condition, AND>>()?)
1126        })
1127    }
1128}
1129
1130impl Display for IfClause {
1131    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1132        match self {
1133            Self::Exists => write!(f, "IF EXISTS"),
1134            Self::Conditions(conditions) => {
1135                if conditions.is_empty() {
1136                    return Ok(());
1137                }
1138                write!(
1139                    f,
1140                    "IF {}",
1141                    conditions
1142                        .iter()
1143                        .map(|c| c.to_string())
1144                        .collect::<Vec<_>>()
1145                        .join(" AND ")
1146                )
1147            }
1148        }
1149    }
1150}
1151
/// A CQL `DELETE` statement:
/// `DELETE [selections] FROM table [USING params] WHERE relations [IF ...]`.
///
/// Parsing goes through [`TaggedDeleteStatement`]. The builder's `validate` hook
/// rejects an explicitly provided, empty WHERE clause.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[parse_via(TaggedDeleteStatement)]
pub struct DeleteStatement {
    /// Optional selections to delete; omitted from the output when `None` or empty.
    #[builder(default)]
    pub selections: Option<Vec<SimpleSelection>>,
    /// The (optionally keyspace-qualified) table to delete from.
    #[builder(setter(into))]
    pub from: KeyspaceQualifiedName,
    /// Optional `USING` update parameters; omitted from the output when `None` or empty.
    #[builder(default)]
    pub using: Option<Vec<UpdateParameter>>,
    /// The required `WHERE` clause.
    #[builder(setter(into))]
    pub where_clause: WhereClause,
    /// Optional `IF EXISTS` / `IF <conditions>` clause.
    #[builder(default)]
    pub if_clause: Option<IfClause>,
}
1167
1168impl TryFrom<TaggedDeleteStatement> for DeleteStatement {
1169    type Error = anyhow::Error;
1170    fn try_from(value: TaggedDeleteStatement) -> Result<Self, Self::Error> {
1171        Ok(Self {
1172            selections: value.selections.map(|v| v.into_value()).transpose()?,
1173            from: value.from.try_into()?,
1174            using: value.using.map(|v| v.into_value()).transpose()?,
1175            where_clause: value.where_clause.into_value()?,
1176            if_clause: value.if_clause.map(|v| v.into_value()).transpose()?,
1177        })
1178    }
1179}
1180
/// Parse-time form of [`DeleteStatement`], where clause parts are wrapped in [`Tag`].
///
/// A `Tag` may hold a concrete `Tag::Value` or another variant that `into_value`
/// has to resolve (and may fail on); `TryFrom` performs that resolution.
/// Tokenized as a `DeleteStatement`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(setter(strip_option), build_fn(validate = "Self::validate"))]
#[tokenize_as(DeleteStatement)]
pub struct TaggedDeleteStatement {
    /// Optional selections between `DELETE` and `FROM`.
    #[builder(default)]
    pub selections: Option<Tag<Vec<SimpleSelection>>>,
    /// The target table.
    pub from: TaggedKeyspaceQualifiedName,
    /// Optional `USING` update parameters.
    #[builder(default)]
    pub using: Option<Tag<Vec<UpdateParameter>>>,
    /// The required `WHERE` clause.
    pub where_clause: Tag<WhereClause>,
    /// Optional `IF` clause.
    #[builder(default)]
    pub if_clause: Option<Tag<IfClause>>,
}
1194
1195impl DeleteStatementBuilder {
1196    /// Set IF EXISTS on the statement.
1197    pub fn if_exists(&mut self) -> &mut Self {
1198        self.if_clause.replace(Some(IfClause::Exists));
1199        self
1200    }
1201
1202    fn validate(&self) -> Result<(), String> {
1203        if self
1204            .where_clause
1205            .as_ref()
1206            .map(|s| s.relations.is_empty())
1207            .unwrap_or(false)
1208        {
1209            return Err("WHERE clause cannot be empty".to_string());
1210        }
1211        Ok(())
1212    }
1213}
1214
1215impl TaggedDeleteStatementBuilder {
1216    fn validate(&self) -> Result<(), String> {
1217        if self
1218            .where_clause
1219            .as_ref()
1220            .map(|s| match s {
1221                Tag::Value(v) => v.relations.is_empty(),
1222                _ => false,
1223            })
1224            .unwrap_or(false)
1225        {
1226            return Err("WHERE clause cannot be empty".to_string());
1227        }
1228        Ok(())
1229    }
1230}
1231
1232impl Parse for TaggedDeleteStatement {
1233    type Output = Self;
1234    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1235        s.parse::<DELETE>()?;
1236        let mut res = TaggedDeleteStatementBuilder::default();
1237        if let Some(s) = s.parse_from::<Option<Tag<List<SimpleSelection, Comma>>>>()? {
1238            res.selections(s);
1239        }
1240        res.from(s.parse::<(FROM, _)>()?.1);
1241        if let Some(u) = s.parse_from::<If<USING, Tag<List<UpdateParameter, AND>>>>()? {
1242            res.using(u);
1243        }
1244        res.where_clause(s.parse()?);
1245        if let Some(i) = s.parse()? {
1246            res.if_clause(i);
1247        }
1248        s.parse::<Option<Semicolon>>()?;
1249        Ok(res
1250            .build()
1251            .map_err(|e| anyhow::anyhow!("Invalid DELETE statement: {}", e))?)
1252    }
1253}
1254
1255impl Display for DeleteStatement {
1256    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1257        write!(f, "DELETE")?;
1258        if let Some(selections) = &self.selections {
1259            if !selections.is_empty() {
1260                write!(
1261                    f,
1262                    " {}",
1263                    selections.iter().map(|s| s.to_string()).collect::<Vec<_>>().join(", ")
1264                )?;
1265            }
1266        }
1267        write!(f, " FROM {}", self.from)?;
1268        if let Some(using) = &self.using {
1269            if !using.is_empty() {
1270                write!(
1271                    f,
1272                    " USING {}",
1273                    using.iter().map(|p| p.to_string()).collect::<Vec<_>>().join(", ")
1274                )?;
1275            }
1276        }
1277        write!(f, " {}", self.where_clause)?;
1278        if let Some(if_clause) = &self.if_clause {
1279            write!(f, " {}", if_clause)?;
1280        }
1281        Ok(())
1282    }
1283}
1284
1285impl KeyspaceExt for DeleteStatement {
1286    fn get_keyspace(&self) -> Option<String> {
1287        self.from.keyspace.as_ref().map(|n| n.to_string())
1288    }
1289
1290    fn set_keyspace(&mut self, keyspace: impl Into<Name>) {
1291        self.from.keyspace.replace(keyspace.into());
1292    }
1293}
1294
1295impl WhereExt for DeleteStatement {
1296    fn iter_where(&self) -> Option<std::slice::Iter<Relation>> {
1297        Some(self.where_clause.relations.iter())
1298    }
1299}
1300
/// A CQL batch:
/// `BEGIN [UNLOGGED | COUNTER] BATCH [USING ...] stmt; stmt ... APPLY BATCH`.
///
/// Parsing goes through [`TaggedBatchStatement`]. The builder's `validate` hook
/// rejects a statement list that was set but is empty.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(build_fn(validate = "Self::validate"))]
#[parse_via(TaggedBatchStatement)]
pub struct BatchStatement {
    /// Batch kind; defaults to `BatchKind::Logged`.
    #[builder(default)]
    pub kind: BatchKind,
    /// Optional `USING` parameters applied to the whole batch.
    #[builder(setter(strip_option), default)]
    pub using: Option<Vec<UpdateParameter>>,
    /// The INSERT / UPDATE / DELETE statements in the batch.
    pub statements: Vec<ModificationStatement>,
}
1311
1312impl TryFrom<TaggedBatchStatement> for BatchStatement {
1313    type Error = anyhow::Error;
1314    fn try_from(value: TaggedBatchStatement) -> Result<Self, Self::Error> {
1315        Ok(Self {
1316            kind: value.kind,
1317            using: value.using.map(|v| v.into_value()).transpose()?,
1318            statements: value
1319                .statements
1320                .into_iter()
1321                .map(|v| v.into_value().and_then(|v| v.try_into()))
1322                .collect::<Result<_, _>>()?,
1323        })
1324    }
1325}
1326
/// Parse-time form of [`BatchStatement`], with the USING list and each statement
/// wrapped in [`Tag`]; converted into `BatchStatement` via `TryFrom`.
#[derive(ParseFromStr, Builder, Clone, Debug, ToTokens, PartialEq, Eq)]
#[builder(build_fn(validate = "Self::validate"))]
#[tokenize_as(BatchStatement)]
pub struct TaggedBatchStatement {
    /// Batch kind; defaults to `BatchKind::Logged`.
    #[builder(default)]
    pub kind: BatchKind,
    /// Optional `USING` parameters for the whole batch.
    #[builder(setter(strip_option), default)]
    pub using: Option<Tag<Vec<UpdateParameter>>>,
    /// The tagged INSERT / UPDATE / DELETE statements, in source order.
    pub statements: Vec<Tag<TaggedModificationStatement>>,
}
1337
1338impl BatchStatement {
1339    pub fn add_parse_statement(&mut self, statement: &str) -> anyhow::Result<()> {
1340        self.statements.push(statement.parse()?);
1341        Ok(())
1342    }
1343
1344    pub fn add_statement(&mut self, statement: ModificationStatement) {
1345        self.statements.push(statement);
1346    }
1347
1348    pub fn parse_statement(mut self, statement: &str) -> anyhow::Result<Self> {
1349        self.add_parse_statement(statement)?;
1350        Ok(self)
1351    }
1352
1353    pub fn statement(mut self, statement: ModificationStatement) -> Self {
1354        self.add_statement(statement);
1355        self
1356    }
1357
1358    pub fn insert(mut self, statement: InsertStatement) -> Self {
1359        self.statements.push(statement.into());
1360        self
1361    }
1362
1363    pub fn update(mut self, statement: UpdateStatement) -> Self {
1364        self.statements.push(statement.into());
1365        self
1366    }
1367
1368    pub fn delete(mut self, statement: DeleteStatement) -> Self {
1369        self.statements.push(statement.into());
1370        self
1371    }
1372}
1373
1374impl BatchStatementBuilder {
1375    pub fn parse_statement(&mut self, statement: &str) -> anyhow::Result<&mut Self> {
1376        self.statements
1377            .get_or_insert_with(Default::default)
1378            .push(statement.parse()?);
1379        Ok(self)
1380    }
1381
1382    pub fn statement(&mut self, statement: ModificationStatement) -> &mut Self {
1383        self.statements.get_or_insert_with(Default::default).push(statement);
1384        self
1385    }
1386
1387    pub fn insert(&mut self, statement: InsertStatement) -> &mut Self {
1388        self.statements
1389            .get_or_insert_with(Default::default)
1390            .push(statement.into());
1391        self
1392    }
1393
1394    pub fn update(&mut self, statement: UpdateStatement) -> &mut Self {
1395        self.statements
1396            .get_or_insert_with(Default::default)
1397            .push(statement.into());
1398        self
1399    }
1400
1401    pub fn delete(&mut self, statement: DeleteStatement) -> &mut Self {
1402        self.statements
1403            .get_or_insert_with(Default::default)
1404            .push(statement.into());
1405        self
1406    }
1407
1408    fn validate(&self) -> Result<(), String> {
1409        if self.statements.as_ref().map(|s| s.is_empty()).unwrap_or(false) {
1410            return Err("Batch cannot contain zero statements".to_string());
1411        }
1412        Ok(())
1413    }
1414}
1415
1416impl TaggedBatchStatementBuilder {
1417    fn validate(&self) -> Result<(), String> {
1418        if self.statements.as_ref().map(|s| s.is_empty()).unwrap_or(false) {
1419            return Err("Batch cannot contain zero statements".to_string());
1420        }
1421        Ok(())
1422    }
1423}
1424
1425impl Parse for TaggedBatchStatement {
1426    type Output = Self;
1427    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1428        s.parse::<BEGIN>()?;
1429        let mut res = TaggedBatchStatementBuilder::default();
1430        res.kind(s.parse()?);
1431        s.parse::<BATCH>()?;
1432        if let Some(u) = s.parse_from::<If<USING, Tag<List<UpdateParameter, AND>>>>()? {
1433            res.using(u);
1434        }
1435        let mut statements = Vec::new();
1436        while let Some(res) = s.parse()? {
1437            statements.push(res);
1438        }
1439        res.statements(statements);
1440        s.parse::<(APPLY, BATCH)>()?;
1441        s.parse::<Option<Semicolon>>()?;
1442        Ok(res
1443            .build()
1444            .map_err(|e| anyhow::anyhow!("Invalid BATCH statement: {}", e))?)
1445    }
1446}
1447
1448impl Display for BatchStatement {
1449    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1450        write!(f, "BEGIN")?;
1451        match self.kind {
1452            BatchKind::Logged => (),
1453            BatchKind::Unlogged => write!(f, " UNLOGGED")?,
1454            BatchKind::Counter => write!(f, " COUNTER")?,
1455        };
1456        write!(f, " BATCH")?;
1457        if let Some(using) = &self.using {
1458            if !using.is_empty() {
1459                write!(
1460                    f,
1461                    " USING {}",
1462                    using.iter().map(|p| p.to_string()).collect::<Vec<_>>().join(" AND ")
1463                )?;
1464            }
1465        }
1466        write!(
1467            f,
1468            " {}",
1469            self.statements
1470                .iter()
1471                .map(|s| s.to_string())
1472                .collect::<Vec<_>>()
1473                .join("; ")
1474        )?;
1475        write!(f, " APPLY BATCH")?;
1476        Ok(())
1477    }
1478}
1479
/// A data-modification statement: the statement kinds a [`BatchStatement`] may contain.
///
/// Parsing goes through [`TaggedModificationStatement`].
#[derive(ParseFromStr, Clone, Debug, TryInto, From, ToTokens, PartialEq, Eq)]
#[parse_via(TaggedModificationStatement)]
pub enum ModificationStatement {
    /// An `INSERT` statement.
    Insert(InsertStatement),
    /// An `UPDATE` statement.
    Update(UpdateStatement),
    /// A `DELETE` statement.
    Delete(DeleteStatement),
}
1487
1488impl Into<DataManipulationStatement> for ModificationStatement {
1489    fn into(self) -> DataManipulationStatement {
1490        match self {
1491            ModificationStatement::Insert(i) => DataManipulationStatement::Insert(i),
1492            ModificationStatement::Update(u) => DataManipulationStatement::Update(u),
1493            ModificationStatement::Delete(d) => DataManipulationStatement::Delete(d),
1494        }
1495    }
1496}
1497
1498impl TryFrom<TaggedModificationStatement> for ModificationStatement {
1499    type Error = anyhow::Error;
1500    fn try_from(value: TaggedModificationStatement) -> Result<Self, Self::Error> {
1501        Ok(match value {
1502            TaggedModificationStatement::Insert(s) => ModificationStatement::Insert(s.try_into()?),
1503            TaggedModificationStatement::Update(s) => ModificationStatement::Update(s.try_into()?),
1504            TaggedModificationStatement::Delete(s) => ModificationStatement::Delete(s.try_into()?),
1505        })
1506    }
1507}
1508
/// Parse-time form of [`ModificationStatement`], holding the tagged statement
/// variants; converted via `TryFrom` and tokenized as a `ModificationStatement`.
#[derive(ParseFromStr, Clone, Debug, TryInto, From, ToTokens, PartialEq, Eq)]
#[tokenize_as(ModificationStatement)]
pub enum TaggedModificationStatement {
    /// A tagged `INSERT` statement.
    Insert(TaggedInsertStatement),
    /// A tagged `UPDATE` statement.
    Update(TaggedUpdateStatement),
    /// A tagged `DELETE` statement.
    Delete(TaggedDeleteStatement),
}
1516
1517impl Parse for TaggedModificationStatement {
1518    type Output = Self;
1519    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1520        Ok(if let Some(keyword) = s.find::<ReservedKeyword>() {
1521            match keyword {
1522                ReservedKeyword::INSERT => Self::Insert(s.parse()?),
1523                ReservedKeyword::UPDATE => Self::Update(s.parse()?),
1524                ReservedKeyword::DELETE => Self::Delete(s.parse()?),
1525                _ => anyhow::bail!(
1526                    "Expected a data modification statement (INSERT / UPDATE / DELETE)! Found {}",
1527                    keyword
1528                ),
1529            }
1530        } else {
1531            anyhow::bail!(
1532                "Expected a data modification statement (INSERT / UPDATE / DELETE), found {}",
1533                s.info()
1534            )
1535        })
1536    }
1537}
1538
1539impl Display for ModificationStatement {
1540    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1541        match self {
1542            Self::Insert(s) => s.fmt(f),
1543            Self::Update(s) => s.fmt(f),
1544            Self::Delete(s) => s.fmt(f),
1545        }
1546    }
1547}
1548
/// The kind of a `BATCH` statement; `Logged` is the default.
#[derive(Copy, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum BatchKind {
    /// Plain `BEGIN BATCH` (no extra keyword is written).
    Logged,
    /// `BEGIN UNLOGGED BATCH`
    Unlogged,
    /// `BEGIN COUNTER BATCH`
    Counter,
}
1555
1556impl Parse for BatchKind {
1557    type Output = Self;
1558    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1559        Ok(if s.parse::<Option<UNLOGGED>>()?.is_some() {
1560            BatchKind::Unlogged
1561        } else if s.parse::<Option<COUNTER>>()?.is_some() {
1562            BatchKind::Counter
1563        } else {
1564            BatchKind::Logged
1565        })
1566    }
1567}
1568
1569impl Default for BatchKind {
1570    fn default() -> Self {
1571        BatchKind::Logged
1572    }
1573}
1574
/// A `WHERE` clause: relations joined by `AND`. An empty list renders as nothing.
#[derive(Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct WhereClause {
    /// The relations, in source order.
    pub relations: Vec<Relation>,
}
1579
1580impl Parse for WhereClause {
1581    type Output = Self;
1582    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self> {
1583        let (_, relations) = s.parse_from::<(WHERE, List<Relation, AND>)>()?;
1584        Ok(WhereClause { relations })
1585    }
1586}
1587
1588impl Display for WhereClause {
1589    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1590        if self.relations.is_empty() {
1591            return Ok(());
1592        }
1593        write!(
1594            f,
1595            "WHERE {}",
1596            self.relations
1597                .iter()
1598                .map(|r| r.to_string())
1599                .collect::<Vec<_>>()
1600                .join(" AND ")
1601        )
1602    }
1603}
1604
1605impl From<Vec<Relation>> for WhereClause {
1606    fn from(relations: Vec<Relation>) -> Self {
1607        WhereClause { relations }
1608    }
1609}
1610
/// A `GROUP BY` clause: comma-separated column names. An empty list renders as nothing.
#[derive(Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct GroupByClause {
    /// The grouping columns, in source order.
    pub columns: Vec<Name>,
}
1615
1616impl Parse for GroupByClause {
1617    type Output = Self;
1618    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self> {
1619        let (_, _, columns) = s.parse_from::<(GROUP, BY, List<Name, Comma>)>()?;
1620        Ok(GroupByClause { columns })
1621    }
1622}
1623
1624impl Display for GroupByClause {
1625    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1626        if self.columns.is_empty() {
1627            return Ok(());
1628        }
1629        write!(
1630            f,
1631            "GROUP BY {}",
1632            self.columns
1633                .iter()
1634                .map(|c| c.to_string())
1635                .collect::<Vec<_>>()
1636                .join(", ")
1637        )
1638    }
1639}
1640
1641impl<T: Into<Name>> From<Vec<T>> for GroupByClause {
1642    fn from(columns: Vec<T>) -> Self {
1643        GroupByClause {
1644            columns: columns.into_iter().map(|c| c.into()).collect(),
1645        }
1646    }
1647}
1648
/// An `ORDER BY` clause: comma-separated column orderings. An empty list renders as nothing.
#[derive(Clone, Debug, ToTokens, PartialEq, Eq)]
pub struct OrderByClause {
    /// The column orderings, in source order.
    pub columns: Vec<ColumnOrder>,
}
1653
1654impl Parse for OrderByClause {
1655    type Output = Self;
1656    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self> {
1657        let (_, _, columns) = s.parse_from::<(ORDER, BY, List<ColumnOrder, Comma>)>()?;
1658        Ok(OrderByClause { columns })
1659    }
1660}
1661
1662impl Display for OrderByClause {
1663    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1664        if self.columns.is_empty() {
1665            return Ok(());
1666        }
1667        write!(
1668            f,
1669            "ORDER BY {}",
1670            self.columns
1671                .iter()
1672                .map(|c| c.to_string())
1673                .collect::<Vec<_>>()
1674                .join(", ")
1675        )
1676    }
1677}
1678
1679impl<T: Into<ColumnOrder>> From<Vec<T>> for OrderByClause {
1680    fn from(columns: Vec<T>) -> Self {
1681        OrderByClause {
1682            columns: columns.into_iter().map(|c| c.into()).collect(),
1683        }
1684    }
1685}
1686
/// A limit value: an integer literal or a bind marker.
#[derive(Clone, Debug, From, ToTokens, PartialEq, Eq)]
pub enum Limit {
    /// A literal row count.
    Literal(i32),
    /// A bind marker. It is excluded from the derived `From`; the separate
    /// `From<T: Into<BindMarker>>` impl covers this variant instead.
    #[from(ignore)]
    BindMarker(BindMarker),
}
1693
1694impl Parse for Limit {
1695    type Output = Self;
1696    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1697        if let Some(bind) = s.parse::<Option<BindMarker>>()? {
1698            Ok(Limit::BindMarker(bind))
1699        } else if let Some(n) = s.parse::<Option<i32>>()? {
1700            Ok(Limit::Literal(n))
1701        } else {
1702            anyhow::bail!("Expected an integer or bind marker (?), found {}", s.info())
1703        }
1704    }
1705}
1706
1707impl Display for Limit {
1708    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1709        match self {
1710            Limit::Literal(i) => i.fmt(f),
1711            Limit::BindMarker(b) => b.fmt(f),
1712        }
1713    }
1714}
1715
1716impl<T: Into<BindMarker>> From<T> for Limit {
1717    fn from(bind: T) -> Self {
1718        Limit::BindMarker(bind.into())
1719    }
1720}
1721
/// A column default keyword, `NULL` or `UNSET`.
///
/// It is passed to `InsertKind::json` in the tests, presumably for JSON inserts'
/// default handling — confirm against `InsertKind`.
#[derive(Copy, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum ColumnDefault {
    /// Rendered as `NULL`.
    Null,
    /// Rendered as `UNSET`.
    Unset,
}
1727
1728impl Parse for ColumnDefault {
1729    type Output = Self;
1730    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1731        if s.parse::<Option<NULL>>()?.is_some() {
1732            Ok(ColumnDefault::Null)
1733        } else if s.parse::<Option<UNSET>>()?.is_some() {
1734            Ok(ColumnDefault::Unset)
1735        } else {
1736            anyhow::bail!("Expected column default (NULL/UNSET), found {}", s.info())
1737        }
1738    }
1739}
1740
1741impl Display for ColumnDefault {
1742    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1743        match self {
1744            ColumnDefault::Null => write!(f, "NULL"),
1745            ColumnDefault::Unset => write!(f, "UNSET"),
1746        }
1747    }
1748}
1749
1750#[cfg(test)]
1751mod test {
1752    use super::*;
1753    use crate::{
1754        KeyspaceQualifyExt,
1755        Order,
1756    };
1757
1758    #[test]
1759    fn test_parse_select() {
1760        let mut builder = SelectStatementBuilder::default();
1761        builder.select_clause(vec![
1762            Selector::column("movie"),
1763            Selector::column("director").as_id("Movie Director"),
1764        ]);
1765        assert!(builder.build().is_err());
1766        builder.from("movies".dot("NerdMovies"));
1767        let statement = builder.build().unwrap().to_string();
1768        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1769        builder.where_clause(vec![
1770            Relation::normal("year", Operator::Equal, 2012_i32),
1771            Relation::tuple(
1772                vec!["main_actor"],
1773                Operator::In,
1774                vec![LitStr::from("Nathan Fillion"), LitStr::from("John O'Goodman")],
1775            ),
1776            Relation::token(
1777                vec!["director"],
1778                Operator::GreaterThan,
1779                FunctionCall::new("token", vec![LitStr::from("movie")]),
1780            ),
1781        ]);
1782        let statement = builder.build().unwrap().to_string();
1783        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1784        builder.distinct();
1785        let statement = builder.build().unwrap().to_string();
1786        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1787        builder.select_clause(SelectClause::All);
1788        let statement = builder.build().unwrap().to_string();
1789        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1790        builder.group_by_clause(vec!["director", "main_actor", "year", "movie"]);
1791        let statement = builder.build().unwrap().to_string();
1792        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1793        builder.order_by_clause(vec![("director", Order::Ascending), ("year", Order::Descending)]);
1794        let statement = builder.build().unwrap().to_string();
1795        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1796        builder.per_partition_limit(10);
1797        let statement = builder.build().unwrap().to_string();
1798        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1799        builder.limit(BindMarker::Anonymous);
1800        let statement = builder.build().unwrap().to_string();
1801        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1802        builder.limit("bind_marker");
1803        let statement = builder.build().unwrap().to_string();
1804        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1805        builder.allow_filtering().bypass_cache();
1806        let statement = builder.build().unwrap().to_string();
1807        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1808        builder.timeout(std::time::Duration::from_secs(10));
1809        let statement = builder.build().unwrap().to_string();
1810        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1811    }
1812
1813    #[test]
1814    fn test_parse_insert() {
1815        let mut builder = InsertStatementBuilder::default();
1816        builder.table("test");
1817        assert!(builder.build().is_err());
1818        builder.kind(
1819            InsertKind::name_value(
1820                vec!["movie".into(), "director".into(), "main_actor".into(), "year".into()],
1821                vec![
1822                    LitStr::from("Serenity").into(),
1823                    LitStr::from("Joss Whedon").into(),
1824                    LitStr::from("Nathan Fillion").into(),
1825                    2005_i32.into(),
1826                ],
1827            )
1828            .unwrap(),
1829        );
1830        let statement = builder.build().unwrap().to_string();
1831        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1832        builder.if_not_exists();
1833        let statement = builder.build().unwrap().to_string();
1834        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1835        builder.using(vec![
1836            UpdateParameter::ttl(86400),
1837            UpdateParameter::timestamp(1000),
1838            UpdateParameter::timeout(std::time::Duration::from_secs(60)),
1839        ]);
1840        let statement = builder.build().unwrap().to_string();
1841        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1842        builder.kind(InsertKind::json(
1843            r#"{
1844                "movie": "Serenity",
1845                "director": "Joss Whedon",
1846                "main_actor": "Nathan Fillion",
1847                "year": 2005
1848            }"#,
1849            ColumnDefault::Null,
1850        ));
1851        let statement = builder.build().unwrap().to_string();
1852        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1853    }
1854
1855    #[test]
1856    fn test_parse_update() {
1857        let mut builder = UpdateStatementBuilder::default();
1858        builder.table("test");
1859        assert!(builder.build().is_err());
1860        builder.set_clause(vec![
1861            Assignment::simple("director", LitStr::from("Joss Whedon")),
1862            Assignment::simple("main_actor", LitStr::from("Nathan Fillion")),
1863            Assignment::arithmetic("year", "year", ArithmeticOp::Add, 10_i32),
1864            Assignment::append("my_list", vec![LitStr::from("foo"), LitStr::from("bar")], "my_list"),
1865        ]);
1866        assert!(builder.build().is_err());
1867        builder.where_clause(vec![Relation::normal(
1868            "movie",
1869            Operator::Equal,
1870            LitStr::from("Serenity"),
1871        )]);
1872        let statement = builder.build().unwrap().to_string();
1873        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1874        builder.if_clause(IfClause::Exists);
1875        let statement = builder.build().unwrap().to_string();
1876        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1877        builder.if_clause(IfClause::conditions(vec![
1878            Condition::new("director", Operator::Equal, LitStr::from("Joss Whedon")),
1879            Condition::new(
1880                SimpleSelection::field("my_type", "my_field"),
1881                Operator::LessThan,
1882                100_i32,
1883            ),
1884            Condition::new(
1885                SimpleSelection::term("my_list", 0_i32),
1886                Operator::Like,
1887                LitStr::from("foo%"),
1888            ),
1889        ]));
1890        let statement = builder.build().unwrap().to_string();
1891        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1892        builder.using(vec![
1893            UpdateParameter::ttl(86400),
1894            UpdateParameter::timestamp(1000),
1895            UpdateParameter::timeout(std::time::Duration::from_secs(60)),
1896        ]);
1897        let statement = builder.build().unwrap().to_string();
1898        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1899    }
1900
1901    #[test]
1902    fn test_parse_delete() {
1903        let mut builder = DeleteStatementBuilder::default();
1904        builder.from("test");
1905        assert!(builder.build().is_err());
1906        builder.where_clause(vec![Relation::normal(
1907            "movie",
1908            Operator::Equal,
1909            LitStr::from("Serenity"),
1910        )]);
1911        let statement = builder.build().unwrap().to_string();
1912        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1913        builder.if_clause(IfClause::Exists);
1914        let statement = builder.build().unwrap().to_string();
1915        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1916        builder.if_clause(IfClause::conditions(vec![
1917            Condition::new("director", Operator::Equal, LitStr::from("Joss Whedon")),
1918            Condition::new(
1919                SimpleSelection::field("my_type", "my_field"),
1920                Operator::LessThan,
1921                100_i32,
1922            ),
1923            Condition::new(
1924                SimpleSelection::term("my_list", 0_i32),
1925                Operator::Like,
1926                LitStr::from("foo%"),
1927            ),
1928        ]));
1929        let statement = builder.build().unwrap().to_string();
1930        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1931    }
1932
1933    #[test]
1934    fn test_parse_batch() {
1935        let mut builder = BatchStatementBuilder::default();
1936        builder.using(vec![
1937            UpdateParameter::ttl(86400),
1938            UpdateParameter::timestamp(1000),
1939            UpdateParameter::timeout(std::time::Duration::from_secs(60)),
1940        ]);
1941        assert!(builder.build().is_err());
1942        builder.insert(
1943            InsertStatementBuilder::default()
1944                .table("NerdMovies")
1945                .kind(
1946                    InsertKind::name_value(
1947                        vec!["movie".into(), "director".into(), "main_actor".into(), "year".into()],
1948                        vec![
1949                            LitStr::from("Serenity").into(),
1950                            LitStr::from("Joss Whedon").into(),
1951                            LitStr::from("Nathan Fillion").into(),
1952                            2005_i32.into(),
1953                        ],
1954                    )
1955                    .unwrap(),
1956                )
1957                .if_not_exists()
1958                .using(vec![UpdateParameter::ttl(86400)])
1959                .build()
1960                .unwrap(),
1961        );
1962        let statement = builder.build().unwrap().to_string();
1963        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1964        builder.kind(BatchKind::Unlogged);
1965        let statement = builder.build().unwrap().to_string();
1966        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1967        builder.kind(BatchKind::Logged);
1968        let statement = builder.build().unwrap().to_string();
1969        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1970        builder.kind(BatchKind::Counter);
1971        let statement = builder.build().unwrap().to_string();
1972        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
1973        builder
1974            .update(
1975                UpdateStatementBuilder::default()
1976                    .table("NerdMovies")
1977                    .set_clause(vec![
1978                        Assignment::simple("director", LitStr::from("Joss Whedon")),
1979                        Assignment::simple("main_actor", LitStr::from("Nathan Fillion")),
1980                    ])
1981                    .where_clause(vec![Relation::normal(
1982                        "movie",
1983                        Operator::Equal,
1984                        LitStr::from("Serenity"),
1985                    )])
1986                    .if_clause(IfClause::Exists)
1987                    .build()
1988                    .unwrap(),
1989            )
1990            .delete(
1991                DeleteStatementBuilder::default()
1992                    .from("NerdMovies")
1993                    .where_clause(vec![Relation::normal(
1994                        "movie",
1995                        Operator::Equal,
1996                        LitStr::from("Serenity"),
1997                    )])
1998                    .if_clause(IfClause::Exists)
1999                    .build()
2000                    .unwrap(),
2001            );
2002        let statement = builder.build().unwrap().to_string();
2003        assert_eq!(builder.build().unwrap(), statement.parse().unwrap());
2004    }
2005}