scylladb_parse/
data_types.rs

1use super::{
2    format_cql_f32,
3    format_cql_f64,
4    keywords::*,
5    Alpha,
6    Angles,
7    BindMarker,
8    Braces,
9    Brackets,
10    CustomToTokens,
11    Float,
12    FunctionCall,
13    Hex,
14    KeyspaceQualifiedName,
15    List,
16    LitStr,
17    Name,
18    Parens,
19    Parse,
20    SignedNumber,
21    StatementStream,
22    Tag,
23    TokenWrapper,
24};
25use chrono::{
26    Datelike,
27    NaiveDate,
28    NaiveDateTime,
29    NaiveTime,
30    Timelike,
31};
32use derive_builder::Builder;
33use derive_more::{
34    From,
35    TryInto,
36};
37use scylladb_parse_macros::{
38    ParseFromStr,
39    ToTokens,
40};
41use std::{
42    collections::{
43        BTreeMap,
44        BTreeSet,
45        HashMap,
46        HashSet,
47    },
48    convert::{
49        TryFrom,
50        TryInto,
51    },
52    fmt::{
53        Display,
54        Formatter,
55    },
56    str::FromStr,
57};
58use uuid::Uuid;
59
/// An arithmetic operator that can appear inside a [`Term`].
///
/// Each variant maps to a single character; the mapping is defined by the
/// `Display` and `TryFrom<char>` implementations for this type.
#[derive(ParseFromStr, Copy, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum ArithmeticOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction (or negation when used without a left operand), written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Remainder, written `%`.
    Mod,
}
68
69impl Display for ArithmeticOp {
70    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "{}",
74            match self {
75                ArithmeticOp::Add => "+",
76                ArithmeticOp::Sub => "-",
77                ArithmeticOp::Mul => "*",
78                ArithmeticOp::Div => "/",
79                ArithmeticOp::Mod => "%",
80            }
81        )
82    }
83}
84
85impl TryFrom<char> for ArithmeticOp {
86    type Error = anyhow::Error;
87
88    fn try_from(value: char) -> Result<Self, Self::Error> {
89        match value {
90            '+' => Ok(ArithmeticOp::Add),
91            '-' => Ok(ArithmeticOp::Sub),
92            '*' => Ok(ArithmeticOp::Mul),
93            '/' => Ok(ArithmeticOp::Div),
94            '%' => Ok(ArithmeticOp::Mod),
95            _ => anyhow::bail!("Invalid arithmetic operator: {}", value),
96        }
97    }
98}
99
100impl Parse for ArithmeticOp {
101    type Output = Self;
102    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
103        s.parse::<char>()?.try_into()
104    }
105}
106
/// A comparison or membership operator.
///
/// The textual form of each variant is defined by the `Display`
/// implementation for this type.
#[derive(ParseFromStr, Copy, Clone, Debug, ToTokens, PartialEq, Eq)]
pub enum Operator {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `IN`
    In,
    /// `CONTAINS`
    Contains,
    /// `CONTAINS KEY`
    ContainsKey,
    /// `LIKE`
    Like,
}
120
121impl Display for Operator {
122    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
123        write!(
124            f,
125            "{}",
126            match self {
127                Operator::Equal => "=",
128                Operator::NotEqual => "!=",
129                Operator::GreaterThan => ">",
130                Operator::GreaterThanOrEqual => ">=",
131                Operator::LessThan => "<",
132                Operator::LessThanOrEqual => "<=",
133                Operator::In => "IN",
134                Operator::Contains => "CONTAINS",
135                Operator::ContainsKey => "CONTAINS KEY",
136                Operator::Like => "LIKE",
137            }
138        )
139    }
140}
141
142impl Parse for Operator {
143    type Output = Self;
144    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self> {
145        if s.parse::<Option<(CONTAINS, KEY)>>()?.is_some() {
146            Ok(Operator::ContainsKey)
147        } else if s.parse::<Option<CONTAINS>>()?.is_some() {
148            Ok(Operator::Contains)
149        } else if s.parse::<Option<IN>>()?.is_some() {
150            Ok(Operator::In)
151        } else if s.parse::<Option<LIKE>>()?.is_some() {
152            Ok(Operator::Like)
153        } else if let (Some(first), second) = (s.next(), s.peek()) {
154            Ok(match (first, second) {
155                ('=', _) => Operator::Equal,
156                ('!', Some('=')) => {
157                    s.next();
158                    Operator::NotEqual
159                }
160                ('>', Some('=')) => {
161                    s.next();
162                    Operator::GreaterThanOrEqual
163                }
164                ('<', Some('=')) => {
165                    s.next();
166                    Operator::LessThanOrEqual
167                }
168                ('>', _) => Operator::GreaterThan,
169                ('<', _) => Operator::LessThan,
170                _ => anyhow::bail!(
171                    "Invalid operator: {}",
172                    if let Some(second) = second {
173                        format!("{}{}", first, second)
174                    } else {
175                        first.to_string()
176                    }
177                ),
178            })
179        } else {
180            anyhow::bail!("Expected operator, found {}", s.info())
181        }
182    }
183}
184
/// A unit of time identified by a short textual suffix.
///
/// The suffix for each variant is defined by the `FromStr` and `Display`
/// implementations for this type.
///
/// `PartialEq`, `Eq` and `Hash` are derived so that units can be compared
/// and used as set/map keys, in line with the other operator enums here.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum TimeUnit {
    /// Nanoseconds, `ns`.
    Nanos,
    /// Microseconds, `us` (also accepted as `µs` on input).
    Micros,
    /// Milliseconds, `ms`.
    Millis,
    /// Seconds, `s`.
    Seconds,
    /// Minutes, `m`.
    Minutes,
    /// Hours, `h`.
    Hours,
    /// Days, `d`.
    Days,
    /// Weeks, `w`.
    Weeks,
    /// Months, `mo`.
    Months,
    /// Years, `y`.
    Years,
}
198
199impl FromStr for TimeUnit {
200    type Err = anyhow::Error;
201
202    fn from_str(s: &str) -> Result<Self, Self::Err> {
203        match s {
204            "ns" => Ok(TimeUnit::Nanos),
205            "us" | "µs" => Ok(TimeUnit::Micros),
206            "ms" => Ok(TimeUnit::Millis),
207            "s" => Ok(TimeUnit::Seconds),
208            "m" => Ok(TimeUnit::Minutes),
209            "h" => Ok(TimeUnit::Hours),
210            "d" => Ok(TimeUnit::Days),
211            "w" => Ok(TimeUnit::Weeks),
212            "mo" => Ok(TimeUnit::Months),
213            "y" => Ok(TimeUnit::Years),
214            _ => anyhow::bail!("Invalid time unit: {}", s),
215        }
216    }
217}
218
219impl Display for TimeUnit {
220    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
221        write!(
222            f,
223            "{}",
224            match self {
225                TimeUnit::Nanos => "ns",
226                TimeUnit::Micros => "us",
227                TimeUnit::Millis => "ms",
228                TimeUnit::Seconds => "s",
229                TimeUnit::Minutes => "m",
230                TimeUnit::Hours => "h",
231                TimeUnit::Days => "d",
232                TimeUnit::Weeks => "w",
233                TimeUnit::Months => "mo",
234                TimeUnit::Years => "y",
235            }
236        )
237    }
238}
239
240impl Parse for TimeUnit {
241    type Output = Self;
242    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
243        if let (Some(first), second) = (s.next(), s.peek()) {
244            Ok(match (first, second) {
245                ('n', Some('s')) => {
246                    s.next();
247                    TimeUnit::Nanos
248                }
249                ('u', Some('s')) | ('µ', Some('s')) => {
250                    s.next();
251                    TimeUnit::Micros
252                }
253                ('m', Some('s')) => {
254                    s.next();
255                    TimeUnit::Millis
256                }
257                ('m', Some('o')) => {
258                    s.next();
259                    TimeUnit::Months
260                }
261                ('s', _) => TimeUnit::Seconds,
262                ('m', _) => TimeUnit::Minutes,
263                ('h', _) => TimeUnit::Hours,
264                ('d', _) => TimeUnit::Days,
265                ('w', _) => TimeUnit::Weeks,
266                ('y', _) => TimeUnit::Years,
267                _ => anyhow::bail!(
268                    "Invalid time unit: {}",
269                    if let Some(second) = second {
270                        format!("{}{}", first, second)
271                    } else {
272                        first.to_string()
273                    }
274                ),
275            })
276        } else {
277            anyhow::bail!("Expected time unit, found {}", s.info())
278        }
279    }
280}
281
/// A term: a constant, a literal, a function call, an arithmetic
/// expression over terms, a type-hinted name, or a bind marker.
///
/// `From` is derived for the single-field variants; the struct-like variants
/// opt out via `#[from(ignore)]` and are built with [`Term::arithmetic_op`],
/// [`Term::negative`] and [`Term::type_hint`].
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens, From)]
pub enum Term {
    /// A scalar constant.
    Constant(Constant),
    /// A collection, user-defined-type or tuple literal.
    Literal(Literal),
    /// A function invocation.
    FunctionCall(FunctionCall),
    /// An arithmetic expression; rendered as `lhs op rhs` without spaces.
    #[from(ignore)]
    ArithmeticOp {
        /// Left operand; `None` for a prefix operator applied to `rhs` only.
        lhs: Option<Box<Term>>,
        /// The operator.
        op: ArithmeticOp,
        /// Right operand.
        rhs: Box<Term>,
    },
    /// A type followed by a name; rendered as `hint ident`.
    #[from(ignore)]
    TypeHint {
        /// The hinted type.
        hint: CqlType,
        /// The name the hint applies to.
        ident: Name,
    },
    /// A bind marker.
    BindMarker(BindMarker),
}
300
301impl Term {
302    pub fn constant<T: Into<Constant>>(value: T) -> Self {
303        Term::Constant(value.into())
304    }
305
306    pub fn literal<T: Into<Literal>>(value: T) -> Self {
307        Term::Literal(value.into())
308    }
309
310    pub fn function_call<T: Into<FunctionCall>>(value: T) -> Self {
311        Term::FunctionCall(value.into())
312    }
313
314    pub fn negative<T: Into<Term>>(t: T) -> Self {
315        Term::ArithmeticOp {
316            lhs: None,
317            op: ArithmeticOp::Sub,
318            rhs: Box::new(t.into()),
319        }
320    }
321
322    pub fn arithmetic_op<LT: Into<Term>, RT: Into<Term>>(lhs: LT, op: ArithmeticOp, rhs: RT) -> Self {
323        Term::ArithmeticOp {
324            lhs: Some(Box::new(lhs.into())),
325            op,
326            rhs: Box::new(rhs.into()),
327        }
328    }
329
330    pub fn type_hint<T: Into<CqlType>, N: Into<Name>>(hint: T, ident: N) -> Self {
331        Term::TypeHint {
332            hint: hint.into(),
333            ident: ident.into(),
334        }
335    }
336
337    pub fn bind_marker<T: Into<BindMarker>>(marker: T) -> Self {
338        Term::BindMarker(marker.into())
339    }
340}
341
342impl Parse for Term {
343    type Output = Self;
344    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self> {
345        Ok(if let Some(c) = s.parse()? {
346            if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
347                Self::ArithmeticOp {
348                    lhs: Some(Box::new(Self::Constant(c))),
349                    op,
350                    rhs: Box::new(rhs),
351                }
352            } else {
353                Self::Constant(c)
354            }
355        } else if let Some(lit) = s.parse()? {
356            if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
357                Self::ArithmeticOp {
358                    lhs: Some(Box::new(Self::Literal(lit))),
359                    op,
360                    rhs: Box::new(rhs),
361                }
362            } else {
363                Self::Literal(lit)
364            }
365        } else if let Some(f) = s.parse()? {
366            if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
367                Self::ArithmeticOp {
368                    lhs: Some(Box::new(Self::FunctionCall(f))),
369                    op,
370                    rhs: Box::new(rhs),
371                }
372            } else {
373                Self::FunctionCall(f)
374            }
375        } else if let Some((hint, ident)) = s.parse()? {
376            if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
377                Self::ArithmeticOp {
378                    lhs: Some(Box::new(Self::TypeHint { hint, ident })),
379                    op,
380                    rhs: Box::new(rhs),
381                }
382            } else {
383                Self::TypeHint { hint, ident }
384            }
385        } else if let Some(b) = s.parse()? {
386            if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
387                Self::ArithmeticOp {
388                    lhs: Some(Box::new(Self::BindMarker(b))),
389                    op,
390                    rhs: Box::new(rhs),
391                }
392            } else {
393                Self::BindMarker(b)
394            }
395        } else if let Some((op, rhs)) = s.parse::<Option<(ArithmeticOp, Term)>>()? {
396            Self::ArithmeticOp {
397                lhs: None,
398                op,
399                rhs: Box::new(rhs),
400            }
401        } else {
402            anyhow::bail!("Expected term, found {}", s.info())
403        })
404    }
405}
406
407impl Display for Term {
408    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
409        match self {
410            Self::Constant(c) => c.fmt(f),
411            Self::Literal(l) => l.fmt(f),
412            Self::FunctionCall(fc) => fc.fmt(f),
413            Self::ArithmeticOp { lhs, op, rhs } => match lhs {
414                Some(lhs) => write!(f, "{}{}{}", lhs, op, rhs),
415                None => write!(f, "{}{}", op, rhs),
416            },
417            Self::TypeHint { hint, ident } => write!(f, "{} {}", hint, ident),
418            Self::BindMarker(b) => b.fmt(f),
419        }
420    }
421}
422
423impl TryInto<LitStr> for Term {
424    type Error = anyhow::Error;
425
426    fn try_into(self) -> anyhow::Result<LitStr> {
427        if let Self::Constant(c) = self {
428            c.try_into()
429        } else {
430            Err(anyhow::anyhow!("Expected constant, found {}", self))
431        }
432    }
433}
434
435impl TryInto<i32> for Term {
436    type Error = anyhow::Error;
437
438    fn try_into(self) -> anyhow::Result<i32> {
439        if let Self::Constant(c) = self {
440            c.try_into()
441        } else {
442            Err(anyhow::anyhow!("Expected constant, found {}", self))
443        }
444    }
445}
446
447impl TryInto<i64> for Term {
448    type Error = anyhow::Error;
449
450    fn try_into(self) -> anyhow::Result<i64> {
451        if let Self::Constant(c) = self {
452            c.try_into()
453        } else {
454            Err(anyhow::anyhow!("Expected constant, found {}", self))
455        }
456    }
457}
458
459impl TryInto<f32> for Term {
460    type Error = anyhow::Error;
461
462    fn try_into(self) -> anyhow::Result<f32> {
463        if let Self::Constant(c) = self {
464            c.try_into()
465        } else {
466            Err(anyhow::anyhow!("Expected constant, found {}", self))
467        }
468    }
469}
470
471impl TryInto<f64> for Term {
472    type Error = anyhow::Error;
473
474    fn try_into(self) -> anyhow::Result<f64> {
475        if let Self::Constant(c) = self {
476            c.try_into()
477        } else {
478            Err(anyhow::anyhow!("Expected constant, found {}", self))
479        }
480    }
481}
482
483impl TryInto<bool> for Term {
484    type Error = anyhow::Error;
485
486    fn try_into(self) -> anyhow::Result<bool> {
487        if let Self::Constant(c) = self {
488            c.try_into()
489        } else {
490            Err(anyhow::anyhow!("Expected constant, found {}", self))
491        }
492    }
493}
494
495impl TryInto<Uuid> for Term {
496    type Error = anyhow::Error;
497
498    fn try_into(self) -> anyhow::Result<Uuid> {
499        if let Self::Constant(c) = self {
500            c.try_into()
501        } else {
502            Err(anyhow::anyhow!("Expected constant, found {}", self))
503        }
504    }
505}
506
507macro_rules! impl_from_constant_to_term {
508    ($t:ty) => {
509        impl From<$t> for Term {
510            fn from(t: $t) -> Self {
511                Self::Constant(t.into())
512            }
513        }
514    };
515}
516
517impl_from_constant_to_term!(LitStr);
518impl_from_constant_to_term!(i32);
519impl_from_constant_to_term!(i64);
520impl_from_constant_to_term!(f32);
521impl_from_constant_to_term!(f64);
522impl_from_constant_to_term!(bool);
523impl_from_constant_to_term!(Uuid);
524
/// A scalar constant.
///
/// Numeric variants keep the textual form of the number and are only
/// converted to a Rust numeric type on request (see the `TryInto` impls).
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum Constant {
    /// The null value; rendered as `NULL`.
    Null,
    /// A string literal.
    String(LitStr),
    /// An integer, stored as text.
    Integer(String),
    /// A floating-point number, stored as text.
    Float(String),
    /// A boolean; rendered in upper case.
    Boolean(bool),
    /// A UUID.
    Uuid(#[wrap] Uuid),
    /// Raw bytes rendered as bare hex digits (no prefix).
    Hex(Vec<u8>),
    /// Raw bytes rendered as hex digits with a `0x` prefix.
    Blob(Vec<u8>),
}
536
537impl Constant {
538    pub fn string(s: &str) -> Self {
539        Self::String(s.into())
540    }
541
542    pub fn integer(s: &str) -> anyhow::Result<Self> {
543        Ok(Self::Integer(StatementStream::new(s).parse_from::<SignedNumber>()?))
544    }
545
546    pub fn float(s: &str) -> anyhow::Result<Self> {
547        Ok(Self::Float(StatementStream::new(s).parse_from::<Float>()?))
548    }
549
550    pub fn bool(b: bool) -> Self {
551        Self::Boolean(b)
552    }
553
554    pub fn uuid(u: Uuid) -> Self {
555        Self::Uuid(u)
556    }
557
558    pub fn hex(s: &str) -> anyhow::Result<Self> {
559        Ok(Self::Hex(StatementStream::new(s).parse_from::<Hex>()?))
560    }
561
562    pub fn blob(b: Vec<u8>) -> Self {
563        Self::Blob(b)
564    }
565}
566
567impl Parse for Constant {
568    type Output = Self;
569    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
570        Ok(if s.parse::<Option<NULL>>()?.is_some() {
571            Constant::Null
572        } else if let Some(ss) = s.parse()? {
573            Constant::String(ss)
574        } else if let Some(f) = s.parse_from::<Option<Float>>()? {
575            Constant::Float(f)
576        } else if let Some(i) = s.parse_from::<Option<SignedNumber>>()? {
577            Constant::Integer(i)
578        } else if let Some(b) = s.parse()? {
579            Constant::Boolean(b)
580        } else if let Some(u) = s.parse()? {
581            Constant::Uuid(u)
582        } else if s.peekn(2).map(|s| s.to_lowercase().as_str() == "0x").unwrap_or(false) {
583            s.nextn(2);
584            Constant::Blob(s.parse_from::<Hex>()?)
585        } else if let Some(h) = s.parse_from::<Option<Hex>>()? {
586            Constant::Hex(h)
587        } else {
588            anyhow::bail!("Expected constant, found {}", s.info())
589        })
590    }
591}
592
593impl Display for Constant {
594    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
595        match self {
596            Self::Null => write!(f, "NULL"),
597            Self::String(s) => s.fmt(f),
598            Self::Integer(s) => s.fmt(f),
599            Self::Float(s) => s.fmt(f),
600            Self::Boolean(b) => b.to_string().to_uppercase().fmt(f),
601            Self::Uuid(u) => u.fmt(f),
602            Self::Hex(h) => hex::encode(h).fmt(f),
603            Self::Blob(b) => write!(f, "0x{}", hex::encode(b)),
604        }
605    }
606}
607
608impl TryInto<LitStr> for Constant {
609    type Error = anyhow::Error;
610
611    fn try_into(self) -> anyhow::Result<LitStr> {
612        if let Self::String(s) = self {
613            Ok(s)
614        } else {
615            Err(anyhow::anyhow!("Expected string constant, found {}", self))
616        }
617    }
618}
619
620impl TryInto<i32> for Constant {
621    type Error = anyhow::Error;
622
623    fn try_into(self) -> anyhow::Result<i32> {
624        if let Self::Integer(i) = self {
625            Ok(i.parse()?)
626        } else {
627            Err(anyhow::anyhow!("Expected integer constant, found {}", self))
628        }
629    }
630}
631
632impl TryInto<i64> for Constant {
633    type Error = anyhow::Error;
634
635    fn try_into(self) -> anyhow::Result<i64> {
636        if let Self::Integer(i) = self {
637            Ok(i.parse()?)
638        } else {
639            Err(anyhow::anyhow!("Expected integer constant, found {}", self))
640        }
641    }
642}
643
644impl TryInto<f32> for Constant {
645    type Error = anyhow::Error;
646
647    fn try_into(self) -> anyhow::Result<f32> {
648        if let Self::Float(f) = self {
649            Ok(f.parse()?)
650        } else {
651            Err(anyhow::anyhow!("Expected float constant, found {}", self))
652        }
653    }
654}
655
656impl TryInto<f64> for Constant {
657    type Error = anyhow::Error;
658
659    fn try_into(self) -> anyhow::Result<f64> {
660        if let Self::Float(f) = self {
661            Ok(f.parse()?)
662        } else {
663            Err(anyhow::anyhow!("Expected float constant, found {}", self))
664        }
665    }
666}
667
668impl TryInto<bool> for Constant {
669    type Error = anyhow::Error;
670
671    fn try_into(self) -> anyhow::Result<bool> {
672        if let Self::Boolean(b) = self {
673            Ok(b)
674        } else {
675            Err(anyhow::anyhow!("Expected boolean constant, found {}", self))
676        }
677    }
678}
679
680impl TryInto<Uuid> for Constant {
681    type Error = anyhow::Error;
682
683    fn try_into(self) -> anyhow::Result<Uuid> {
684        if let Self::Uuid(u) = self {
685            Ok(u)
686        } else {
687            Err(anyhow::anyhow!("Expected UUID constant, found {}", self))
688        }
689    }
690}
691
692impl From<LitStr> for Constant {
693    fn from(s: LitStr) -> Self {
694        Self::String(s)
695    }
696}
697
698impl From<i32> for Constant {
699    fn from(i: i32) -> Self {
700        Self::Integer(i.to_string())
701    }
702}
703
704impl From<i64> for Constant {
705    fn from(i: i64) -> Self {
706        Self::Integer(i.to_string())
707    }
708}
709
710impl From<f32> for Constant {
711    fn from(f: f32) -> Self {
712        Self::Float(format_cql_f32(f))
713    }
714}
715
716impl From<f64> for Constant {
717    fn from(f: f64) -> Self {
718        Self::Float(format_cql_f64(f))
719    }
720}
721
722impl From<bool> for Constant {
723    fn from(b: bool) -> Self {
724        Self::Boolean(b)
725    }
726}
727
728impl From<Uuid> for Constant {
729    fn from(u: Uuid) -> Self {
730        Self::Uuid(u)
731    }
732}
733
/// A non-scalar literal: a collection, a user-defined-type value or a tuple.
///
/// `From` and `TryInto` conversions to and from the inner types are derived.
#[derive(ParseFromStr, Clone, Debug, TryInto, From, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum Literal {
    /// A list, set or map literal.
    Collection(CollectionTypeLiteral),
    /// A user-defined-type literal.
    UserDefined(UserDefinedTypeLiteral),
    /// A tuple literal.
    Tuple(TupleLiteral),
}
740
741impl Parse for Literal {
742    type Output = Self;
743    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
744        Ok(if let Some(c) = s.parse()? {
745            Self::Collection(c)
746        } else if let Some(u) = s.parse()? {
747            Self::UserDefined(u)
748        } else if let Some(t) = s.parse()? {
749            Self::Tuple(t)
750        } else {
751            anyhow::bail!("Expected CQL literal, found {}", s.info())
752        })
753    }
754}
755
756impl Display for Literal {
757    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
758        match self {
759            Self::Collection(c) => c.fmt(f),
760            Self::UserDefined(u) => u.fmt(f),
761            Self::Tuple(t) => t.fmt(f),
762        }
763    }
764}
765
/// A data type: native, collection, user-defined, tuple or custom.
///
/// `From` is derived for every variant except `Collection`, which has a
/// hand-written `From<CollectionType>` that boxes the value.
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, From, ToTokens)]
pub enum CqlType {
    /// One of the built-in scalar types.
    Native(NativeType),
    /// A list, set or map type; boxed because it contains `CqlType` itself.
    #[from(ignore)]
    Collection(Box<CollectionType>),
    /// A user-defined type.
    UserDefined(UserDefinedType),
    /// A tuple type; rendered as `TUPLE<a, b, ...>`.
    Tuple(Vec<CqlType>),
    /// A custom type given as a string literal.
    Custom(LitStr),
}
775
776impl Parse for CqlType {
777    type Output = Self;
778    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
779    where
780        Self: Sized,
781    {
782        Ok(if let Some(c) = s.parse()? {
783            Self::Collection(Box::new(c))
784        } else if s.parse::<Option<TUPLE>>()?.is_some() {
785            Self::Tuple(s.parse_from::<Angles<List<CqlType, Comma>>>()?)
786        } else if let Some(n) = s.parse()? {
787            Self::Native(n)
788        } else if let Some(udt) = s.parse()? {
789            Self::UserDefined(udt)
790        } else if let Some(c) = s.parse()? {
791            Self::Custom(c)
792        } else {
793            anyhow::bail!("Expected CQL Type, found {}", s.info())
794        })
795    }
796}
797
798impl Display for CqlType {
799    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
800        match self {
801            Self::Native(n) => n.fmt(f),
802            Self::Collection(c) => c.fmt(f),
803            Self::UserDefined(u) => u.fmt(f),
804            Self::Tuple(t) => write!(
805                f,
806                "TUPLE<{}>",
807                t.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(", ")
808            ),
809            Self::Custom(c) => c.fmt(f),
810        }
811    }
812}
813
814impl From<CollectionType> for CqlType {
815    fn from(c: CollectionType) -> Self {
816        Self::Collection(Box::new(c))
817    }
818}
819
/// The built-in scalar types.
///
/// Each variant is rendered as its upper-case keyword (see `Display`) and is
/// recognised case-insensitively on input (see `FromStr`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum NativeType {
    /// `ASCII`
    Ascii,
    /// `BIGINT`
    Bigint,
    /// `BLOB`
    Blob,
    /// `BOOLEAN`
    Boolean,
    /// `COUNTER`
    Counter,
    /// `DATE`
    Date,
    /// `DECIMAL`
    Decimal,
    /// `DOUBLE`
    Double,
    /// `DURATION`
    Duration,
    /// `FLOAT`
    Float,
    /// `INET`
    Inet,
    /// `INT`
    Int,
    /// `SMALLINT`
    Smallint,
    /// `TEXT`
    Text,
    /// `TIME`
    Time,
    /// `TIMESTAMP`
    Timestamp,
    /// `TIMEUUID`
    Timeuuid,
    /// `TINYINT`
    Tinyint,
    /// `UUID`
    Uuid,
    /// `VARCHAR`
    Varchar,
    /// `VARINT`
    Varint,
}
844
845impl Display for NativeType {
846    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
847        write!(
848            f,
849            "{}",
850            match self {
851                NativeType::Ascii => "ASCII",
852                NativeType::Bigint => "BIGINT",
853                NativeType::Blob => "BLOB",
854                NativeType::Boolean => "BOOLEAN",
855                NativeType::Counter => "COUNTER",
856                NativeType::Date => "DATE",
857                NativeType::Decimal => "DECIMAL",
858                NativeType::Double => "DOUBLE",
859                NativeType::Duration => "DURATION",
860                NativeType::Float => "FLOAT",
861                NativeType::Inet => "INET",
862                NativeType::Int => "INT",
863                NativeType::Smallint => "SMALLINT",
864                NativeType::Text => "TEXT",
865                NativeType::Time => "TIME",
866                NativeType::Timestamp => "TIMESTAMP",
867                NativeType::Timeuuid => "TIMEUUID",
868                NativeType::Tinyint => "TINYINT",
869                NativeType::Uuid => "UUID",
870                NativeType::Varchar => "VARCHAR",
871                NativeType::Varint => "VARINT",
872            }
873        )
874    }
875}
876
877impl FromStr for NativeType {
878    type Err = anyhow::Error;
879
880    fn from_str(s: &str) -> Result<Self, Self::Err> {
881        Ok(match s.to_uppercase().as_str() {
882            "ASCII" => NativeType::Ascii,
883            "BIGINT" => NativeType::Bigint,
884            "BLOB" => NativeType::Blob,
885            "BOOLEAN" => NativeType::Boolean,
886            "COUNTER" => NativeType::Counter,
887            "DATE" => NativeType::Date,
888            "DECIMAL" => NativeType::Decimal,
889            "DOUBLE" => NativeType::Double,
890            "DURATION" => NativeType::Duration,
891            "FLOAT" => NativeType::Float,
892            "INET" => NativeType::Inet,
893            "INT" => NativeType::Int,
894            "SMALLINT" => NativeType::Smallint,
895            "TEXT" => NativeType::Text,
896            "TIME" => NativeType::Time,
897            "TIMESTAMP" => NativeType::Timestamp,
898            "TIMEUUID" => NativeType::Timeuuid,
899            "TINYINT" => NativeType::Tinyint,
900            "UUID" => NativeType::Uuid,
901            "VARCHAR" => NativeType::Varchar,
902            "VARINT" => NativeType::Varint,
903            _ => anyhow::bail!("Invalid native type: {}", s),
904        })
905    }
906}
907
908impl Parse for NativeType {
909    type Output = Self;
910    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
911    where
912        Self: Sized,
913    {
914        let token = s.parse_from::<Alpha>()?;
915        NativeType::from_str(&token)
916    }
917}
918
/// A collection literal: list, set or map.
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum CollectionTypeLiteral {
    /// A list literal.
    List(ListLiteral),
    /// A set literal.
    Set(SetLiteral),
    /// A map literal.
    Map(MapLiteral),
}
925
926impl Parse for CollectionTypeLiteral {
927    type Output = Self;
928    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
929        Ok(if let Some(l) = s.parse()? {
930            Self::List(l)
931        } else if let Some(s) = s.parse()? {
932            Self::Set(s)
933        } else if let Some(m) = s.parse()? {
934            Self::Map(m)
935        } else {
936            anyhow::bail!("Expected collection literal, found {}", s.info())
937        })
938    }
939}
940
941impl Display for CollectionTypeLiteral {
942    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
943        match self {
944            Self::List(l) => l.fmt(f),
945            Self::Set(s) => s.fmt(f),
946            Self::Map(m) => m.fmt(f),
947        }
948    }
949}
950
/// A collection type, parameterised by its element type(s).
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub enum CollectionType {
    /// `LIST<element>`
    List(CqlType),
    /// `SET<element>`
    Set(CqlType),
    /// `MAP<key, value>`
    Map(CqlType, CqlType),
}
957
958impl CollectionType {
959    pub fn list<T: Into<CqlType>>(t: T) -> Self {
960        Self::List(t.into())
961    }
962
963    pub fn set<T: Into<CqlType>>(t: T) -> Self {
964        Self::Set(t.into())
965    }
966
967    pub fn map<K: Into<CqlType>, V: Into<CqlType>>(k: K, v: V) -> Self {
968        Self::Map(k.into(), v.into())
969    }
970}
971
972impl Parse for CollectionType {
973    type Output = Self;
974    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
975    where
976        Self: Sized,
977    {
978        Ok(if s.parse::<Option<MAP>>()?.is_some() {
979            let (t1, _, t2) = s.parse_from::<Angles<(CqlType, Comma, CqlType)>>()?;
980            Self::Map(t1, t2)
981        } else if s.parse::<Option<SET>>()?.is_some() {
982            Self::Set(s.parse_from::<Angles<CqlType>>()?)
983        } else if s.parse::<Option<LIST>>()?.is_some() {
984            Self::List(s.parse_from::<Angles<CqlType>>()?)
985        } else {
986            anyhow::bail!("Expected collection type, found {}", s.info())
987        })
988    }
989}
990
991impl Display for CollectionType {
992    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
993        match self {
994            Self::List(e) => write!(f, "LIST<{}>", e),
995            Self::Set(e) => write!(f, "SET<{}>", e),
996            Self::Map(k, v) => write!(f, "MAP<{}, {}>", k, v),
997        }
998    }
999}
1000
/// A map literal with unique keys, rendered as `{k: v, ...}`.
///
/// The derive is configured with `parse_via(TaggedMapLiteral)`; the
/// `TryFrom<TaggedMapLiteral>` impl rejects duplicate keys.
/// NOTE(review): presumably `parse_via` makes string parsing go through
/// `TaggedMapLiteral` and that conversion — confirm against the macro.
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
#[parse_via(TaggedMapLiteral)]
pub struct MapLiteral {
    /// Entries of the map, ordered by key.
    pub elements: BTreeMap<Term, Term>,
}
1006
1007impl TryFrom<TaggedMapLiteral> for MapLiteral {
1008    type Error = anyhow::Error;
1009
1010    fn try_from(value: TaggedMapLiteral) -> Result<Self, Self::Error> {
1011        let mut elements = BTreeMap::new();
1012        for (k, v) in value.elements {
1013            if elements.insert(k.into_value()?, v.into_value()?).is_some() {
1014                anyhow::bail!("Duplicate key in map literal");
1015            }
1016        }
1017        Ok(Self { elements })
1018    }
1019}
1020
/// A map literal whose keys and values are each wrapped in [`Tag`].
///
/// Parsed from a brace-delimited, comma-separated list of `key: value`
/// pairs and convertible into a [`MapLiteral`] via `TryFrom`.
#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
pub struct TaggedMapLiteral {
    /// Entries of the map, ordered by (tagged) key.
    pub elements: BTreeMap<Tag<Term>, Tag<Term>>,
}
1025
1026impl Parse for TaggedMapLiteral {
1027    type Output = Self;
1028    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
1029    where
1030        Self: Sized,
1031    {
1032        Ok(Self {
1033            elements: s
1034                .parse_from::<Braces<List<(Tag<Term>, Colon, Tag<Term>), Comma>>>()?
1035                .into_iter()
1036                .map(|(k, _, v)| (k, v))
1037                .collect(),
1038        })
1039    }
1040}
1041
1042impl Display for MapLiteral {
1043    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1044        write!(
1045            f,
1046            "{{{}}}",
1047            self.elements
1048                .iter()
1049                .map(|(k, v)| format!("{}: {}", k, v))
1050                .collect::<Vec<_>>()
1051                .join(", ")
1052        )
1053    }
1054}
1055
1056impl Display for TaggedMapLiteral {
1057    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1058        write!(
1059            f,
1060            "{{{}}}",
1061            self.elements
1062                .iter()
1063                .map(|(k, v)| format!("{}: {}", k, v))
1064                .collect::<Vec<_>>()
1065                .join(", ")
1066        )
1067    }
1068}
1069
1070impl<T: Into<Term>> From<HashMap<T, T>> for MapLiteral {
1071    fn from(m: HashMap<T, T>) -> Self {
1072        Self {
1073            elements: m.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
1074        }
1075    }
1076}
1077
1078impl<T: Into<Term>> From<BTreeMap<T, T>> for MapLiteral {
1079    fn from(m: BTreeMap<T, T>) -> Self {
1080        Self {
1081            elements: m.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
1082        }
1083    }
1084}
1085
1086#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
1087pub struct TupleLiteral {
1088    pub elements: Vec<Term>,
1089}
1090
1091impl Parse for TupleLiteral {
1092    type Output = Self;
1093    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
1094    where
1095        Self: Sized,
1096    {
1097        Ok(Self {
1098            elements: s.parse_from::<Parens<List<Term, Comma>>>()?,
1099        })
1100    }
1101}
1102
1103impl Display for TupleLiteral {
1104    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1105        write!(
1106            f,
1107            "({})",
1108            self.elements
1109                .iter()
1110                .map(|t| t.to_string())
1111                .collect::<Vec<_>>()
1112                .join(", ")
1113        )
1114    }
1115}
1116
1117impl<T: Into<Term>> From<Vec<T>> for TupleLiteral {
1118    fn from(elements: Vec<T>) -> Self {
1119        Self {
1120            elements: elements.into_iter().map(|t| t.into()).collect(),
1121        }
1122    }
1123}
1124
1125#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
1126pub struct SetLiteral {
1127    pub elements: BTreeSet<Term>,
1128}
1129
1130impl Parse for SetLiteral {
1131    type Output = Self;
1132    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1133        let v = s.parse_from::<Braces<List<Term, Comma>>>()?;
1134        let mut elements = BTreeSet::new();
1135        for e in v {
1136            if elements.contains(&e) {
1137                anyhow::bail!("Duplicate element in set: {}", e);
1138            }
1139            elements.insert(e);
1140        }
1141        Ok(Self { elements })
1142    }
1143}
1144
1145impl Display for SetLiteral {
1146    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1147        write!(
1148            f,
1149            "{{{}}}",
1150            self.elements
1151                .iter()
1152                .map(|t| t.to_string())
1153                .collect::<Vec<_>>()
1154                .join(", ")
1155        )
1156    }
1157}
1158
1159impl<T: Into<Term>> From<HashSet<T>> for SetLiteral {
1160    fn from(elements: HashSet<T>) -> Self {
1161        Self {
1162            elements: elements.into_iter().map(|t| t.into()).collect(),
1163        }
1164    }
1165}
1166
1167impl<T: Into<Term>> From<BTreeSet<T>> for SetLiteral {
1168    fn from(elements: BTreeSet<T>) -> Self {
1169        Self {
1170            elements: elements.into_iter().map(|t| t.into()).collect(),
1171        }
1172    }
1173}
1174
1175#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
1176pub struct ListLiteral {
1177    pub elements: Vec<Term>,
1178}
1179
1180impl Parse for ListLiteral {
1181    type Output = Self;
1182    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1183        Ok(Self {
1184            elements: s.parse_from::<Brackets<List<Term, Comma>>>()?,
1185        })
1186    }
1187}
1188
1189impl Display for ListLiteral {
1190    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1191        write!(
1192            f,
1193            "[{}]",
1194            self.elements
1195                .iter()
1196                .map(|t| t.to_string())
1197                .collect::<Vec<_>>()
1198                .join(", ")
1199        )
1200    }
1201}
1202
1203impl<T: Into<Term>> From<Vec<T>> for ListLiteral {
1204    fn from(elements: Vec<T>) -> Self {
1205        Self {
1206            elements: elements.into_iter().map(|t| t.into()).collect(),
1207        }
1208    }
1209}
1210
1211#[derive(Clone, Debug, Default)]
1212pub struct TimestampLiteral(i64);
1213impl Parse for TimestampLiteral {
1214    type Output = Self;
1215    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1216        if let Some(ts) = s.parse::<Option<LitStr>>()? {
1217            Ok(Self(
1218                ts.value
1219                    .parse::<NaiveDateTime>()
1220                    .map_err(|e| anyhow::anyhow!(e))?
1221                    .timestamp_millis(),
1222            ))
1223        } else {
1224            Ok(Self(s.parse::<u64>()? as i64))
1225        }
1226    }
1227}
1228
1229#[derive(Clone, Debug, Default)]
1230pub struct DateLiteral(u32);
1231impl Parse for DateLiteral {
1232    type Output = Self;
1233    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1234        if let Some(d) = s.parse::<Option<LitStr>>()? {
1235            let dur = d.value.parse::<NaiveDate>().map_err(|e| anyhow::anyhow!(e))?
1236                - NaiveDate::from_ymd_opt(1970, 1, 1).ok_or(anyhow::anyhow!("Out of range ymd"))?;
1237            Ok(Self(dur.num_days() as u32 + (1u32 << 31)))
1238        } else {
1239            Ok(Self(s.parse::<u32>()?))
1240        }
1241    }
1242}
1243
1244#[derive(Clone, Debug, Default)]
1245pub struct TimeLiteral(i64);
1246impl Parse for TimeLiteral {
1247    type Output = Self;
1248    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1249        if let Some(t) = s.parse::<Option<LitStr>>()? {
1250            let t = t.value.parse::<NaiveTime>().map_err(|e| anyhow::anyhow!(e))?
1251                - NaiveTime::from_hms_opt(0, 0, 0).ok_or(anyhow::anyhow!("Out of range hms"))?;
1252            Ok(Self(
1253                t.num_nanoseconds()
1254                    .ok_or_else(|| anyhow::anyhow!("Invalid time literal!"))?,
1255            ))
1256        } else {
1257            Ok(Self(s.parse::<u64>()? as i64))
1258        }
1259    }
1260}
1261
/// Which textual form of duration literal is being parsed.
enum DurationLiteralKind {
    /// Number/unit pairs such as `89h4m48s`; chosen when the literal does not start with `P`.
    QuantityUnit,
    /// A literal starting with `P`, such as `P3Y6M4DT12H30M5S` or `P0003-06-04T12:30:05`.
    ISO8601,
}
1266
1267#[derive(Builder, Clone, Debug, Default)]
1268#[builder(default, build_fn(validate = "Self::validate"))]
1269struct ISO8601 {
1270    years: i64,
1271    months: i64,
1272    days: i64,
1273    hours: i64,
1274    minutes: i64,
1275    seconds: i64,
1276    weeks: i64,
1277}
1278
1279impl ISO8601Builder {
1280    fn validate(&self) -> Result<(), String> {
1281        if self.weeks.is_some()
1282            && (self.years.is_some()
1283                || self.months.is_some()
1284                || self.days.is_some()
1285                || self.hours.is_some()
1286                || self.minutes.is_some()
1287                || self.seconds.is_some())
1288        {
1289            return Err("ISO8601 duration cannot have weeks and other units".to_string());
1290        }
1291        Ok(())
1292    }
1293}
1294
1295#[derive(ParseFromStr, Clone, Debug, Default, ToTokens, PartialEq, Eq)]
1296pub struct DurationLiteral {
1297    pub months: i32,
1298    pub days: i32,
1299    pub nanos: i64,
1300}
1301
1302impl DurationLiteral {
1303    pub fn ns(mut self, ns: i64) -> Self {
1304        self.nanos += ns;
1305        self
1306    }
1307
1308    pub fn us(self, us: i64) -> Self {
1309        self.ns(us * 1000)
1310    }
1311
1312    pub fn ms(self, ms: i64) -> Self {
1313        self.us(ms * 1000)
1314    }
1315
1316    pub fn s(self, s: i64) -> Self {
1317        self.ms(s * 1000)
1318    }
1319
1320    pub fn m(self, m: i32) -> Self {
1321        self.s(m as i64 * 60)
1322    }
1323
1324    pub fn h(self, h: i32) -> Self {
1325        self.m(h * 60)
1326    }
1327
1328    pub fn d(mut self, d: i32) -> Self {
1329        self.days += d;
1330        self
1331    }
1332
1333    pub fn w(self, w: i32) -> Self {
1334        self.d(w * 7)
1335    }
1336
1337    pub fn mo(mut self, mo: i32) -> Self {
1338        self.months += mo;
1339        self
1340    }
1341
1342    pub fn y(self, y: i32) -> Self {
1343        self.mo(y * 12)
1344    }
1345}
1346
1347impl Parse for DurationLiteral {
1348    type Output = Self;
1349    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output> {
1350        let kind = match s.peek() {
1351            Some(c) => match c {
1352                'P' => {
1353                    s.next();
1354                    DurationLiteralKind::ISO8601
1355                }
1356                _ => DurationLiteralKind::QuantityUnit,
1357            },
1358            None => anyhow::bail!("End of statement!"),
1359        };
1360        match kind {
1361            DurationLiteralKind::ISO8601 => {
1362                let mut iso = ISO8601Builder::default();
1363                let mut ty = 'Y';
1364                let mut num = None;
1365                let mut time = false;
1366                let mut res = None;
1367                let mut alternative = None;
1368                while let Some(c) = s.peek() {
1369                    if c == 'P' {
1370                        anyhow::bail!("Invalid ISO8601 duration literal: Too many date specifiers");
1371                    } else if c == 'T' {
1372                        if time {
1373                            anyhow::bail!("Invalid ISO8601 duration literal: Too many time specifiers");
1374                        }
1375                        match ty {
1376                            'Y' => {
1377                                ty = 'h';
1378                            }
1379                            'D' => {
1380                                let num = num
1381                                    .take()
1382                                    .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration: Missing days"))?;
1383                                if num < 1 || num > 31 {
1384                                    anyhow::bail!("Invalid ISO8601 duration: Day out of range");
1385                                }
1386                                iso.days(num);
1387                                ty = 'h';
1388                            }
1389                            _ => {
1390                                panic!("Duration `ty` variable got set improperly to {}. This is a bug!", ty);
1391                            }
1392                        }
1393                        s.next();
1394                        time = true;
1395                    } else if c == '-' {
1396                        match alternative {
1397                            Some(true) => (),
1398                            Some(false) => anyhow::bail!("Invalid ISO8601 duration literal: Invalid '-' character"),
1399                            None => alternative = Some(true),
1400                        }
1401                        if time {
1402                            anyhow::bail!("Invalid ISO8601 duration literal: Date separator outside of date");
1403                        }
1404                        match ty {
1405                            'Y' => {
1406                                iso.years(
1407                                    num.take()
1408                                        .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration: Missing years"))?,
1409                                );
1410                                ty = 'M';
1411                            }
1412                            'M' => {
1413                                let num = num
1414                                    .take()
1415                                    .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration: Missing months"))?;
1416                                if num < 1 || num > 12 {
1417                                    anyhow::bail!("Invalid ISO8601 duration: Month out of range");
1418                                }
1419                                iso.months(num);
1420                                ty = 'D';
1421                            }
1422                            _ => {
1423                                panic!("Duration `ty` variable got set improperly to {}. This is a bug!", ty);
1424                            }
1425                        }
1426                        s.next();
1427                    } else if c == ':' {
1428                        match alternative {
1429                            Some(true) => (),
1430                            Some(false) => anyhow::bail!("Invalid ISO8601 duration literal: Invalid '-' character"),
1431                            None => alternative = Some(true),
1432                        }
1433                        if !time {
1434                            anyhow::bail!("Invalid ISO8601 duration: Time separator outside of time");
1435                        }
1436                        match ty {
1437                            'h' => {
1438                                let num = num
1439                                    .take()
1440                                    .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration: Missing hours"))?;
1441                                if num > 24 {
1442                                    anyhow::bail!("Invalid ISO8601 duration: Hour out of range");
1443                                }
1444                                iso.hours(num);
1445                                ty = 'm';
1446                            }
1447                            'm' => {
1448                                let num = num
1449                                    .take()
1450                                    .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration: Missing minutes"))?;
1451                                if num > 59 {
1452                                    anyhow::bail!("Invalid ISO8601 duration: Minutes out of range");
1453                                }
1454                                iso.minutes(num);
1455                                ty = 's';
1456                            }
1457                            _ => {
1458                                panic!("Duration `ty` variable got set improperly to {}. This is a bug!", ty);
1459                            }
1460                        }
1461                        s.next();
1462                    } else if c.is_alphabetic() {
1463                        match alternative {
1464                            Some(false) => (),
1465                            Some(true) => anyhow::bail!("Invalid ISO8601 duration literal: Invalid unit specifier character in alternative format"),
1466                            None => alternative = Some(false),
1467                        }
1468                        match c {
1469                            'Y' => {
1470                                if iso.years.is_some() {
1471                                    anyhow::bail!("Invalid ISO8601 duration: Duplicate year specifiers");
1472                                }
1473                                iso.years(num.take().ok_or_else(|| {
1474                                    anyhow::anyhow!("Invalid ISO8601 duration: Missing number preceeding years unit")
1475                                })?);
1476                            }
1477                            'M' => {
1478                                if !time {
1479                                    if iso.months.is_some() {
1480                                        anyhow::bail!("Invalid ISO8601 duration: Duplicate month specifiers");
1481                                    }
1482                                    iso.months(num.take().ok_or_else(|| {
1483                                        anyhow::anyhow!(
1484                                            "Invalid ISO8601 duration: Missing number preceeding months unit"
1485                                        )
1486                                    })?);
1487                                } else {
1488                                    if iso.minutes.is_some() {
1489                                        anyhow::bail!("Invalid ISO8601 duration: Duplicate minute specifiers");
1490                                    }
1491                                    iso.minutes(num.take().ok_or_else(|| {
1492                                        anyhow::anyhow!(
1493                                            "Invalid ISO8601 duration: Missing number preceeding minutes unit"
1494                                        )
1495                                    })?);
1496                                }
1497                            }
1498                            'D' => {
1499                                if iso.days.is_some() {
1500                                    anyhow::bail!("Invalid ISO8601 duration: Duplicate day specifiers");
1501                                }
1502                                iso.days(num.take().ok_or_else(|| {
1503                                    anyhow::anyhow!("Invalid ISO8601 duration: Missing number preceeding days unit")
1504                                })?);
1505                            }
1506                            'H' => {
1507                                if iso.hours.is_some() {
1508                                    anyhow::bail!("Invalid ISO8601 duration: Duplicate hour specifiers");
1509                                }
1510                                iso.hours(num.take().ok_or_else(|| {
1511                                    anyhow::anyhow!("Invalid ISO8601 duration: Missing number preceeding hours unit")
1512                                })?);
1513                            }
1514                            'S' => {
1515                                if iso.seconds.is_some() {
1516                                    anyhow::bail!("Invalid ISO8601 duration: Duplicate second specifiers");
1517                                }
1518                                iso.seconds(num.take().ok_or_else(|| {
1519                                    anyhow::anyhow!("Invalid ISO8601 duration: Missing number preceeding seconds unit")
1520                                })?);
1521                            }
1522                            'W' => {
1523                                if iso.weeks.is_some() {
1524                                    anyhow::bail!("Invalid ISO8601 duration: Duplicate week specifiers");
1525                                }
1526                                iso.weeks(num.take().ok_or_else(|| {
1527                                    anyhow::anyhow!("Invalid ISO8601 duration: Missing number preceeding weeks unit")
1528                                })?);
1529                            }
1530                            _ => {
1531                                anyhow::bail!(
1532                                    "Invalid ISO8601 duration: Expected P, Y, M, W, D, T, H, M, or S, found {}",
1533                                    c
1534                                );
1535                            }
1536                        }
1537                        s.next();
1538                    } else if c.is_numeric() {
1539                        num = Some(s.parse::<u64>()? as i64);
1540                    } else {
1541                        break;
1542                    }
1543                }
1544                let alternative = alternative
1545                    .ok_or_else(|| anyhow::anyhow!("Invalid ISO8601 duration literal: Unable to determine format"))?;
1546                if let Some(num) = num {
1547                    if !alternative {
1548                        anyhow::bail!("Invalid ISO8601 duration: Trailing number");
1549                    }
1550                    if !time {
1551                        if ty != 'D' {
1552                            anyhow::bail!("Invalid ISO8601 duration: Trailing number");
1553                        }
1554                        if num < 1 || num > 31 {
1555                            anyhow::bail!("Invalid ISO8601 duration: Day out of range");
1556                        }
1557                        iso.days(num);
1558                    } else {
1559                        if ty != 's' {
1560                            anyhow::bail!("Invalid ISO8601 duration: Trailing number");
1561                        }
1562                        if num > 59 {
1563                            anyhow::bail!("Invalid ISO8601 duration: Seconds out of range");
1564                        }
1565                        iso.seconds(num);
1566                    }
1567                    if iso.years.is_none() && iso.months.is_none() && iso.days.is_none()
1568                        || iso.hours.is_none() && iso.minutes.is_none() && iso.seconds.is_none()
1569                    {
1570                        anyhow::bail!("Invalid ISO8601 duration: Missing required unit for alternative format");
1571                    }
1572                    res = Some(iso);
1573                } else if !alternative {
1574                    res = Some(iso);
1575                }
1576                if let Some(iso) = res {
1577                    Ok(iso.build().map_err(|e| anyhow::anyhow!(e))?.into())
1578                } else {
1579                    anyhow::bail!("End of statement!");
1580                }
1581            }
1582            DurationLiteralKind::QuantityUnit => {
1583                let mut res = DurationLiteral::default();
1584                let mut num = None;
1585                while let Some(c) = s.peek() {
1586                    if c.is_numeric() {
1587                        num = Some(s.parse_from::<u64>()? as i64);
1588                    } else if c.is_alphabetic() {
1589                        if let Some(num) = num.take() {
1590                            match s.parse::<TimeUnit>()? {
1591                                TimeUnit::Nanos => res.nanos += num,
1592                                TimeUnit::Micros => res.nanos += num * 1000,
1593                                TimeUnit::Millis => res.nanos += num * 1_000_000,
1594                                TimeUnit::Seconds => res.nanos += num * 1_000_000_000,
1595                                TimeUnit::Minutes => res.nanos += num * 60_000_000_000,
1596                                TimeUnit::Hours => res.nanos += num * 3_600_000_000_000,
1597                                TimeUnit::Days => res.days += num as i32,
1598                                TimeUnit::Weeks => res.days += num as i32 * 7,
1599                                TimeUnit::Months => res.months += num as i32,
1600                                TimeUnit::Years => res.months += num as i32 * 12,
1601                            }
1602                        } else {
1603                            anyhow::bail!("Invalid ISO8601 duration: Missing number preceeding unit specifier");
1604                        }
1605                    } else {
1606                        break;
1607                    }
1608                }
1609                Ok(res)
1610            }
1611        }
1612    }
1613}
1614
1615impl Display for DurationLiteral {
1616    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1617        if self.months == 0 && self.days == 0 && self.nanos == 0 {
1618            write!(f, "0ns")
1619        } else {
1620            if self.months > 0 {
1621                write!(f, "{}mo", self.months)?;
1622            }
1623            if self.days > 0 {
1624                write!(f, "{}d", self.days)?;
1625            }
1626            if self.nanos > 0 {
1627                write!(f, "{}ns", self.nanos)?;
1628            }
1629            Ok(())
1630        }
1631    }
1632}
1633
1634impl From<NaiveDateTime> for DurationLiteral {
1635    fn from(dt: NaiveDateTime) -> Self {
1636        let mut res = DurationLiteral::default();
1637        res.months = dt.year() * 12 + dt.month() as i32;
1638        res.days = dt.day() as i32;
1639        res.nanos = dt.hour() as i64 * 3_600_000_000_000
1640            + dt.minute() as i64 * 60_000_000_000
1641            + dt.second() as i64 * 1_000_000_000;
1642        res
1643    }
1644}
1645
1646impl From<std::time::Duration> for DurationLiteral {
1647    fn from(d: std::time::Duration) -> Self {
1648        let mut res = DurationLiteral::default();
1649        let mut s = d.as_secs();
1650        res.months = (s / (60 * 60 * 24 * 30)) as i32;
1651        s %= 60 * 60 * 24 * 30;
1652        res.days = (s / (60 * 60 * 24)) as i32;
1653        s %= 60 * 60 * 24;
1654        res.nanos = (s * 1_000_000_000 + d.subsec_nanos() as u64) as i64;
1655        res
1656    }
1657}
1658
1659impl From<ISO8601> for DurationLiteral {
1660    fn from(iso: ISO8601) -> Self {
1661        let mut res = DurationLiteral::default();
1662        res.months = (iso.years * 12 + iso.months) as i32;
1663        res.days = (iso.weeks * 7 + iso.days) as i32;
1664        res.nanos = iso.hours * 3_600_000_000_000 + iso.minutes * 60_000_000_000 + iso.seconds * 1_000_000_000;
1665        res
1666    }
1667}
1668
1669#[derive(ParseFromStr, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd, ToTokens)]
1670pub struct UserDefinedTypeLiteral {
1671    pub fields: BTreeMap<Name, Term>,
1672}
1673
1674impl Parse for UserDefinedTypeLiteral {
1675    type Output = Self;
1676    fn parse(s: &mut StatementStream<'_>) -> anyhow::Result<Self::Output>
1677    where
1678        Self: Sized,
1679    {
1680        Ok(Self {
1681            fields: s
1682                .parse_from::<Braces<List<(Name, Colon, Term), Comma>>>()?
1683                .into_iter()
1684                .fold(BTreeMap::new(), |mut acc, (k, _, v)| {
1685                    acc.insert(k, v);
1686                    acc
1687                }),
1688        })
1689    }
1690}
1691
1692impl Display for UserDefinedTypeLiteral {
1693    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1694        write!(
1695            f,
1696            "{{{}}}",
1697            self.fields
1698                .iter()
1699                .map(|(k, v)| format!("{}: {}", k, v))
1700                .collect::<Vec<_>>()
1701                .join(", ")
1702        )
1703    }
1704}
1705
/// A user-defined type is referred to by its (optionally keyspace-qualified) name.
pub type UserDefinedType = KeyspaceQualifiedName;
1707
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the accepted duration literal forms against the builder API,
    /// and that malformed literals are rejected.
    #[test]
    fn test_duration_literals() {
        let zero = DurationLiteral::default;

        // (input, expected value)
        let accepted = vec![
            ("P3Y6M4DT12H30M5S", zero().y(3).mo(6).d(4).h(12).m(30).s(5)),
            ("P23DT23H", zero().d(23).h(23)),
            ("P4Y", zero().y(4)),
            ("PT0S", zero()),
            ("P1M", zero().mo(1)),
            ("PT1M", zero().m(1)),
            ("PT36H", zero().h(36)),
            ("P1DT12H", zero().d(1).h(12)),
            ("P0003-06-04T12:30:05", zero().y(3).mo(6).d(4).h(12).m(30).s(5)),
            ("89h4m48s", zero().h(89).m(4).s(48)),
            ("89d4w48ns2us15ms", zero().d(89).w(4).ns(48).us(2).ms(15)),
        ];
        for (input, expected) in accepted {
            assert_eq!(input.parse::<DurationLiteral>().unwrap(), expected);
        }

        let rejected = vec![
            "P",
            "T",
            "PT",
            "P1",
            "P10Y3",
            "P0003-06-04",
            "T11:30:05",
            "PT11:30:05",
            "P0003-06-04T25:30:05",
            "P0003-06-04T12:70:05",
            "P0003-06-04T12:30:70",
            "P0003-06-80T12:30:05",
            "P0003-13-04T12:30:05",
            "2w6y8mo96ns4u",
            "2w6b8mo96ns4us",
        ];
        for input in rejected {
            assert!(input.parse::<DurationLiteral>().is_err());
        }
    }
}