wdl_ast/v1/
decls.rs

1//! V1 AST representation for declarations.
2
3use std::fmt;
4
5use super::EnvKeyword;
6use super::Expr;
7use super::Plus;
8use super::QuestionMark;
9use crate::AstNode;
10use crate::AstToken;
11use crate::Ident;
12use crate::SyntaxKind;
13use crate::SyntaxNode;
14use crate::TreeNode;
15use crate::TreeToken;
16
17/// Represents a `Map` type.
18#[derive(Clone, Debug, Eq)]
19pub struct MapType<N: TreeNode = SyntaxNode>(N);
20
21impl<N: TreeNode> MapType<N> {
22    /// Gets the key and value types of the `Map`.
23    pub fn types(&self) -> (PrimitiveType<N>, Type<N>) {
24        let mut children = self.0.children().filter_map(Type::cast);
25        let key = children
26            .next()
27            .expect("map should have a key type")
28            .unwrap_primitive_type();
29        let value = children.next().expect("map should have a value type");
30        (key, value)
31    }
32
33    /// Determines if the type is optional.
34    pub fn is_optional(&self) -> bool {
35        matches!(
36            self.0.last_token().map(|t| t.kind()),
37            Some(SyntaxKind::QuestionMark)
38        )
39    }
40}
41
42impl<N: TreeNode> PartialEq for MapType<N> {
43    fn eq(&self, other: &Self) -> bool {
44        self.is_optional() == other.is_optional() && self.types() == other.types()
45    }
46}
47
48impl<N: TreeNode> AstNode<N> for MapType<N> {
49    fn can_cast(kind: SyntaxKind) -> bool {
50        kind == SyntaxKind::MapTypeNode
51    }
52
53    fn cast(inner: N) -> Option<Self> {
54        match inner.kind() {
55            SyntaxKind::MapTypeNode => Some(Self(inner)),
56            _ => None,
57        }
58    }
59
60    fn inner(&self) -> &N {
61        &self.0
62    }
63}
64
65impl fmt::Display for MapType {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        let (key, value) = self.types();
68        write!(
69            f,
70            "Map[{key}, {value}]{o}",
71            o = if self.is_optional() { "?" } else { "" }
72        )
73    }
74}
75
76/// Represents an `Array` type.
77#[derive(Clone, Debug, Eq)]
78pub struct ArrayType<N: TreeNode = SyntaxNode>(N);
79
80impl<N: TreeNode> ArrayType<N> {
81    /// Gets the element type of the array.
82    pub fn element_type(&self) -> Type<N> {
83        Type::child(&self.0).expect("array should have an element type")
84    }
85
86    /// Determines if the type has the "non-empty" qualifier.
87    pub fn is_non_empty(&self) -> bool {
88        self.token::<Plus<N::Token>>().is_some()
89    }
90
91    /// Determines if the type is optional.
92    pub fn is_optional(&self) -> bool {
93        self.last_token::<QuestionMark<N::Token>>().is_some()
94    }
95}
96
97impl<N: TreeNode> PartialEq for ArrayType<N> {
98    fn eq(&self, other: &Self) -> bool {
99        self.is_optional() == other.is_optional()
100            && self.is_non_empty() == other.is_non_empty()
101            && self.element_type() == other.element_type()
102    }
103}
104
105impl<N: TreeNode> AstNode<N> for ArrayType<N> {
106    fn can_cast(kind: SyntaxKind) -> bool {
107        kind == SyntaxKind::ArrayTypeNode
108    }
109
110    fn cast(inner: N) -> Option<Self> {
111        match inner.kind() {
112            SyntaxKind::ArrayTypeNode => Some(Self(inner)),
113            _ => None,
114        }
115    }
116
117    fn inner(&self) -> &N {
118        &self.0
119    }
120}
121
122impl fmt::Display for ArrayType {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(
125            f,
126            "Array[{ty}]{p}{o}",
127            ty = self.element_type(),
128            p = if self.is_non_empty() { "+" } else { "" },
129            o = if self.is_optional() { "?" } else { "" }
130        )
131    }
132}
133
134/// Represents a `Pair` type.
135#[derive(Clone, Debug, Eq)]
136pub struct PairType<N: TreeNode = SyntaxNode>(N);
137
138impl<N: TreeNode> PairType<N> {
139    /// Gets the first and second types of the `Pair`.
140    pub fn types(&self) -> (Type<N>, Type<N>) {
141        let mut children = self.0.children().filter_map(Type::cast);
142        let left = children.next().expect("pair should have a left type");
143        let right = children.next().expect("pair should have a right type");
144        (left, right)
145    }
146
147    /// Determines if the type is optional.
148    pub fn is_optional(&self) -> bool {
149        matches!(
150            self.0.last_token().map(|t| t.kind()),
151            Some(SyntaxKind::QuestionMark)
152        )
153    }
154}
155
156impl<N: TreeNode> PartialEq for PairType<N> {
157    fn eq(&self, other: &Self) -> bool {
158        self.is_optional() == other.is_optional() && self.types() == other.types()
159    }
160}
161
162impl<N: TreeNode> AstNode<N> for PairType<N> {
163    fn can_cast(kind: SyntaxKind) -> bool {
164        kind == SyntaxKind::PairTypeNode
165    }
166
167    fn cast(inner: N) -> Option<Self> {
168        match inner.kind() {
169            SyntaxKind::PairTypeNode => Some(Self(inner)),
170            _ => None,
171        }
172    }
173
174    fn inner(&self) -> &N {
175        &self.0
176    }
177}
178
179impl fmt::Display for PairType {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        let (left, right) = self.types();
182        write!(
183            f,
184            "Pair[{left}, {right}]{o}",
185            o = if self.is_optional() { "?" } else { "" }
186        )
187    }
188}
189
190/// Represents a `Object` type.
191#[derive(Clone, Debug, Eq)]
192pub struct ObjectType<N: TreeNode = SyntaxNode>(N);
193
194impl<N: TreeNode> ObjectType<N> {
195    /// Determines if the type is optional.
196    pub fn is_optional(&self) -> bool {
197        matches!(
198            self.0.last_token().map(|t| t.kind()),
199            Some(SyntaxKind::QuestionMark)
200        )
201    }
202}
203
204impl<N: TreeNode> PartialEq for ObjectType<N> {
205    fn eq(&self, other: &Self) -> bool {
206        self.is_optional() == other.is_optional()
207    }
208}
209
210impl<N: TreeNode> AstNode<N> for ObjectType<N> {
211    fn can_cast(kind: SyntaxKind) -> bool {
212        kind == SyntaxKind::ObjectTypeNode
213    }
214
215    fn cast(inner: N) -> Option<Self> {
216        match inner.kind() {
217            SyntaxKind::ObjectTypeNode => Some(Self(inner)),
218            _ => None,
219        }
220    }
221
222    fn inner(&self) -> &N {
223        &self.0
224    }
225}
226
227impl fmt::Display for ObjectType {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        write!(
230            f,
231            "Object{o}",
232            o = if self.is_optional() { "?" } else { "" }
233        )
234    }
235}
236
237/// Represents a reference to a type.
238#[derive(Clone, Debug, Eq)]
239pub struct TypeRef<N: TreeNode = SyntaxNode>(N);
240
241impl<N: TreeNode> TypeRef<N> {
242    /// Gets the name of the type reference.
243    pub fn name(&self) -> Ident<N::Token> {
244        self.token().expect("type reference should have a name")
245    }
246
247    /// Determines if the type is optional.
248    pub fn is_optional(&self) -> bool {
249        matches!(
250            self.0.last_token().map(|t| t.kind()),
251            Some(SyntaxKind::QuestionMark)
252        )
253    }
254}
255
256impl<N: TreeNode> PartialEq for TypeRef<N> {
257    fn eq(&self, other: &Self) -> bool {
258        self.is_optional() == other.is_optional() && self.name().text() == other.name().text()
259    }
260}
261
262impl<N: TreeNode> AstNode<N> for TypeRef<N> {
263    fn can_cast(kind: SyntaxKind) -> bool {
264        kind == SyntaxKind::TypeRefNode
265    }
266
267    fn cast(inner: N) -> Option<Self> {
268        match inner.kind() {
269            SyntaxKind::TypeRefNode => Some(Self(inner)),
270            _ => None,
271        }
272    }
273
274    fn inner(&self) -> &N {
275        &self.0
276    }
277}
278
279impl fmt::Display for TypeRef {
280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281        write!(
282            f,
283            "{n}{o}",
284            n = self.name().text(),
285            o = if self.is_optional() { "?" } else { "" }
286        )
287    }
288}
289
/// Represents a kind of primitive type.
///
/// The derived ordering follows the declaration order of the variants.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrimitiveTypeKind {
    /// A `Boolean` primitive.
    Boolean,
    /// An `Int` primitive.
    Integer,
    /// A `Float` primitive.
    Float,
    /// A `String` primitive.
    String,
    /// A `File` primitive.
    File,
    /// A `Directory` primitive.
    Directory,
}
306
307/// Represents a primitive type.
308#[derive(Clone, Debug, Eq)]
309pub struct PrimitiveType<N: TreeNode = SyntaxNode>(N);
310
311impl<N: TreeNode> PrimitiveType<N> {
312    /// Gets the kind of the primitive type.
313    pub fn kind(&self) -> PrimitiveTypeKind {
314        self.0
315            .children_with_tokens()
316            .find_map(|c| {
317                c.into_token().and_then(|t| match t.kind() {
318                    SyntaxKind::BooleanTypeKeyword => Some(PrimitiveTypeKind::Boolean),
319                    SyntaxKind::IntTypeKeyword => Some(PrimitiveTypeKind::Integer),
320                    SyntaxKind::FloatTypeKeyword => Some(PrimitiveTypeKind::Float),
321                    SyntaxKind::StringTypeKeyword => Some(PrimitiveTypeKind::String),
322                    SyntaxKind::FileTypeKeyword => Some(PrimitiveTypeKind::File),
323                    SyntaxKind::DirectoryTypeKeyword => Some(PrimitiveTypeKind::Directory),
324                    _ => None,
325                })
326            })
327            .expect("type should have a kind")
328    }
329
330    /// Determines if the type is optional.
331    pub fn is_optional(&self) -> bool {
332        matches!(
333            self.0.last_token().map(|t| t.kind()),
334            Some(SyntaxKind::QuestionMark)
335        )
336    }
337}
338
339impl<N: TreeNode> PartialEq for PrimitiveType<N> {
340    fn eq(&self, other: &Self) -> bool {
341        self.kind() == other.kind()
342    }
343}
344
345impl<N: TreeNode> AstNode<N> for PrimitiveType<N> {
346    fn can_cast(kind: SyntaxKind) -> bool {
347        kind == SyntaxKind::PrimitiveTypeNode
348    }
349
350    fn cast(inner: N) -> Option<Self> {
351        match inner.kind() {
352            SyntaxKind::PrimitiveTypeNode => Some(Self(inner)),
353            _ => None,
354        }
355    }
356
357    fn inner(&self) -> &N {
358        &self.0
359    }
360}
361
362impl fmt::Display for PrimitiveType {
363    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364        match self.kind() {
365            PrimitiveTypeKind::Boolean => write!(f, "Boolean")?,
366            PrimitiveTypeKind::Integer => write!(f, "Int")?,
367            PrimitiveTypeKind::Float => write!(f, "Float")?,
368            PrimitiveTypeKind::String => write!(f, "String")?,
369            PrimitiveTypeKind::File => write!(f, "File")?,
370            PrimitiveTypeKind::Directory => write!(f, "Directory")?,
371        }
372
373        if self.is_optional() {
374            write!(f, "?")
375        } else {
376            Ok(())
377        }
378    }
379}
380
381/// Represents a type.
382#[derive(Clone, Debug, Eq)]
383pub enum Type<N: TreeNode = SyntaxNode> {
384    /// The type is a map.
385    Map(MapType<N>),
386    /// The type is an array.
387    Array(ArrayType<N>),
388    /// The type is a pair.
389    Pair(PairType<N>),
390    /// The type is an object.
391    Object(ObjectType<N>),
392    /// The type is a reference to custom type.
393    Ref(TypeRef<N>),
394    /// The type is a primitive.
395    Primitive(PrimitiveType<N>),
396}
397
398impl<N: TreeNode> Type<N> {
399    //// Returns whether or not the given syntax kind can be cast to
400    /// [`Type`].
401    pub fn can_cast(kind: SyntaxKind) -> bool {
402        matches!(
403            kind,
404            SyntaxKind::MapTypeNode
405                | SyntaxKind::ArrayTypeNode
406                | SyntaxKind::PairTypeNode
407                | SyntaxKind::ObjectTypeNode
408                | SyntaxKind::TypeRefNode
409                | SyntaxKind::PrimitiveTypeNode
410        )
411    }
412
413    /// Casts the given node to [`Type`].
414    ///
415    /// Returns `None` if the node cannot be cast.
416    pub fn cast(inner: N) -> Option<Self> {
417        match inner.kind() {
418            SyntaxKind::MapTypeNode => {
419                Some(Self::Map(MapType::cast(inner).expect("map type to cast")))
420            }
421            SyntaxKind::ArrayTypeNode => Some(Self::Array(
422                ArrayType::cast(inner).expect("array type to cast"),
423            )),
424            SyntaxKind::PairTypeNode => Some(Self::Pair(
425                PairType::cast(inner).expect("pair type to cast"),
426            )),
427            SyntaxKind::ObjectTypeNode => Some(Self::Object(
428                ObjectType::cast(inner).expect("object type to cast"),
429            )),
430            SyntaxKind::TypeRefNode => {
431                Some(Self::Ref(TypeRef::cast(inner).expect("type ref to cast")))
432            }
433            SyntaxKind::PrimitiveTypeNode => Some(Self::Primitive(
434                PrimitiveType::cast(inner).expect("primitive type to cast"),
435            )),
436            _ => None,
437        }
438    }
439
440    /// Gets a reference to the inner node.
441    pub fn inner(&self) -> &N {
442        match self {
443            Self::Map(ty) => ty.inner(),
444            Self::Array(ty) => ty.inner(),
445            Self::Pair(ty) => ty.inner(),
446            Self::Object(ty) => ty.inner(),
447            Self::Ref(ty) => ty.inner(),
448            Self::Primitive(ty) => ty.inner(),
449        }
450    }
451
452    /// Determines if the type is optional.
453    pub fn is_optional(&self) -> bool {
454        match self {
455            Self::Map(m) => m.is_optional(),
456            Self::Array(a) => a.is_optional(),
457            Self::Pair(p) => p.is_optional(),
458            Self::Object(o) => o.is_optional(),
459            Self::Ref(r) => r.is_optional(),
460            Self::Primitive(p) => p.is_optional(),
461        }
462    }
463
464    /// Attempts to get a reference to the inner [`MapType`].
465    ///
466    /// * If `self` is a [`Type::Map`], then a reference to the inner
467    ///   [`MapType`] is returned wrapped in [`Some`].
468    /// * Else, [`None`] is returned.
469    pub fn as_map_type(&self) -> Option<&MapType<N>> {
470        match self {
471            Self::Map(ty) => Some(ty),
472            _ => None,
473        }
474    }
475
476    /// Consumes `self` and attempts to return the inner [`MapType`].
477    ///
478    /// * If `self` is a [`Type::Map`], then the inner [`MapType`] is returned
479    ///   wrapped in [`Some`].
480    /// * Else, [`None`] is returned.
481    pub fn into_map_type(self) -> Option<MapType<N>> {
482        match self {
483            Self::Map(ty) => Some(ty),
484            _ => None,
485        }
486    }
487
488    /// Unwraps the type into a map type.
489    ///
490    /// # Panics
491    ///
492    /// Panics if the type is not a map type.
493    pub fn unwrap_map_type(self) -> MapType<N> {
494        match self {
495            Self::Map(ty) => ty,
496            _ => panic!("not a map type"),
497        }
498    }
499
500    /// Attempts to get a reference to the inner [`ArrayType`].
501    ///
502    /// * If `self` is a [`Type::Array`], then a reference to the inner
503    ///   [`ArrayType`] is returned wrapped in [`Some`].
504    /// * Else, [`None`] is returned.
505    pub fn as_array_type(&self) -> Option<&ArrayType<N>> {
506        match self {
507            Self::Array(ty) => Some(ty),
508            _ => None,
509        }
510    }
511
512    /// Consumes `self` and attempts to return the inner [`ArrayType`].
513    ///
514    /// * If `self` is a [`Type::Array`], then the inner [`ArrayType`] is
515    ///   returned wrapped in [`Some`].
516    /// * Else, [`None`] is returned.
517    pub fn into_array_type(self) -> Option<ArrayType<N>> {
518        match self {
519            Self::Array(ty) => Some(ty),
520            _ => None,
521        }
522    }
523
524    /// Unwraps the type into an array type.
525    ///
526    /// # Panics
527    ///
528    /// Panics if the type is not an array type.
529    pub fn unwrap_array_type(self) -> ArrayType<N> {
530        match self {
531            Self::Array(ty) => ty,
532            _ => panic!("not an array type"),
533        }
534    }
535
536    /// Attempts to get a reference to the inner [`PairType`].
537    ///
538    /// * If `self` is a [`Type::Pair`], then a reference to the inner
539    ///   [`PairType`] is returned wrapped in [`Some`].
540    /// * Else, [`None`] is returned.
541    pub fn as_pair_type(&self) -> Option<&PairType<N>> {
542        match self {
543            Self::Pair(ty) => Some(ty),
544            _ => None,
545        }
546    }
547
548    /// Consumes `self` and attempts to return the inner [`PairType`].
549    ///
550    /// * If `self` is a [`Type::Pair`], then the inner [`PairType`] is returned
551    ///   wrapped in [`Some`].
552    /// * Else, [`None`] is returned.
553    pub fn into_pair_type(self) -> Option<PairType<N>> {
554        match self {
555            Self::Pair(ty) => Some(ty),
556            _ => None,
557        }
558    }
559
560    /// Unwraps the type into a pair type.
561    ///
562    /// # Panics
563    ///
564    /// Panics if the type is not a pair type.
565    pub fn unwrap_pair_type(self) -> PairType<N> {
566        match self {
567            Self::Pair(ty) => ty,
568            _ => panic!("not a pair type"),
569        }
570    }
571
572    /// Attempts to get a reference to the inner [`ObjectType`].
573    ///
574    /// * If `self` is a [`Type::Object`], then a reference to the inner
575    ///   [`ObjectType`] is returned wrapped in [`Some`].
576    /// * Else, [`None`] is returned.
577    pub fn as_object_type(&self) -> Option<&ObjectType<N>> {
578        match self {
579            Self::Object(ty) => Some(ty),
580            _ => None,
581        }
582    }
583
584    /// Consumes `self` and attempts to return the inner [`ObjectType`].
585    ///
586    /// * If `self` is a [`Type::Object`], then the inner [`ObjectType`] is
587    ///   returned wrapped in [`Some`].
588    /// * Else, [`None`] is returned.
589    pub fn into_object_type(self) -> Option<ObjectType<N>> {
590        match self {
591            Self::Object(ty) => Some(ty),
592            _ => None,
593        }
594    }
595
596    /// Unwraps the type into an object type.
597    ///
598    /// # Panics
599    ///
600    /// Panics if the type is not an object type.
601    pub fn unwrap_object_type(self) -> ObjectType<N> {
602        match self {
603            Self::Object(ty) => ty,
604            _ => panic!("not an object type"),
605        }
606    }
607
608    /// Attempts to get a reference to the inner [`TypeRef`].
609    ///
610    /// * If `self` is a [`Type::Ref`], then a reference to the inner
611    ///   [`TypeRef`] is returned wrapped in [`Some`].
612    /// * Else, [`None`] is returned.
613    pub fn as_type_ref(&self) -> Option<&TypeRef<N>> {
614        match self {
615            Self::Ref(ty) => Some(ty),
616            _ => None,
617        }
618    }
619
620    /// Consumes `self` and attempts to return the inner [`TypeRef`].
621    ///
622    /// * If `self` is a [`Type::Ref`], then the inner [`TypeRef`] is returned
623    ///   wrapped in [`Some`].
624    /// * Else, [`None`] is returned.
625    pub fn into_type_ref(self) -> Option<TypeRef<N>> {
626        match self {
627            Self::Ref(ty) => Some(ty),
628            _ => None,
629        }
630    }
631
632    /// Unwraps the type into a type reference.
633    ///
634    /// # Panics
635    ///
636    /// Panics if the type is not a type reference.
637    pub fn unwrap_type_ref(self) -> TypeRef<N> {
638        match self {
639            Self::Ref(ty) => ty,
640            _ => panic!("not a type reference"),
641        }
642    }
643
644    /// Attempts to get a reference to the inner [`PrimitiveType`].
645    ///
646    /// * If `self` is a [`Type::Primitive`], then a reference to the inner
647    ///   [`PrimitiveType`] is returned wrapped in [`Some`].
648    /// * Else, [`None`] is returned.
649    pub fn as_primitive_type(&self) -> Option<&PrimitiveType<N>> {
650        match self {
651            Self::Primitive(ty) => Some(ty),
652            _ => None,
653        }
654    }
655
656    /// Consumes `self` and attempts to return the inner [`PrimitiveType`].
657    ///
658    /// * If `self` is a [`Type::Primitive`], then the inner [`PrimitiveType`]
659    ///   is returned wrapped in [`Some`].
660    /// * Else, [`None`] is returned.
661    pub fn into_primitive_type(self) -> Option<PrimitiveType<N>> {
662        match self {
663            Self::Primitive(ty) => Some(ty),
664            _ => None,
665        }
666    }
667
668    /// Unwraps the type into a primitive type.
669    ///
670    /// # Panics
671    ///
672    /// Panics if the type is not a primitive type.
673    pub fn unwrap_primitive_type(self) -> PrimitiveType<N> {
674        match self {
675            Self::Primitive(ty) => ty,
676            _ => panic!("not a primitive type"),
677        }
678    }
679
680    /// Finds the first child that can be cast to a [`Type`].
681    pub fn child(node: &N) -> Option<Self> {
682        node.children().find_map(Self::cast)
683    }
684
685    /// Finds all children that can be cast to a [`Type`].
686    pub fn children(node: &N) -> impl Iterator<Item = Self> + use<'_, N> {
687        node.children().filter_map(Self::cast)
688    }
689}
690
691impl<N: TreeNode> PartialEq for Type<N> {
692    fn eq(&self, other: &Self) -> bool {
693        match (self, other) {
694            (Self::Map(l), Self::Map(r)) => l == r,
695            (Self::Array(l), Self::Array(r)) => l == r,
696            (Self::Pair(l), Self::Pair(r)) => l == r,
697            (Self::Object(l), Self::Object(r)) => l == r,
698            (Self::Ref(l), Self::Ref(r)) => l == r,
699            (Self::Primitive(l), Self::Primitive(r)) => l == r,
700            _ => false,
701        }
702    }
703}
704
705impl fmt::Display for Type {
706    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
707        match self {
708            Type::Map(m) => m.fmt(f),
709            Type::Array(a) => a.fmt(f),
710            Type::Pair(p) => p.fmt(f),
711            Type::Object(o) => o.fmt(f),
712            Type::Ref(r) => r.fmt(f),
713            Type::Primitive(p) => p.fmt(f),
714        }
715    }
716}
717
718/// Represents an unbound declaration.
719#[derive(Clone, Debug, PartialEq, Eq)]
720pub struct UnboundDecl<N: TreeNode = SyntaxNode>(N);
721
722impl<N: TreeNode> UnboundDecl<N> {
723    /// Gets the `env` token, if present.
724    ///
725    /// This may only return a token for task inputs (WDL 1.2+).
726    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
727        self.token()
728    }
729
730    /// Gets the type of the declaration.
731    pub fn ty(&self) -> Type<N> {
732        Type::child(&self.0).expect("unbound declaration should have a type")
733    }
734
735    /// Gets the name of the declaration.
736    pub fn name(&self) -> Ident<N::Token> {
737        self.token()
738            .expect("unbound declaration should have a name")
739    }
740}
741
742impl<N: TreeNode> AstNode<N> for UnboundDecl<N> {
743    fn can_cast(kind: SyntaxKind) -> bool {
744        kind == SyntaxKind::UnboundDeclNode
745    }
746
747    fn cast(inner: N) -> Option<Self> {
748        match inner.kind() {
749            SyntaxKind::UnboundDeclNode => Some(Self(inner)),
750            _ => None,
751        }
752    }
753
754    fn inner(&self) -> &N {
755        &self.0
756    }
757}
758
759/// Represents a bound declaration in a task or workflow definition.
760#[derive(Clone, Debug, PartialEq, Eq)]
761pub struct BoundDecl<N: TreeNode = SyntaxNode>(N);
762
763impl<N: TreeNode> BoundDecl<N> {
764    /// Gets the `env` token, if present.
765    ///
766    /// This may only return a token for task inputs and private declarations
767    /// (WDL 1.2+).
768    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
769        self.token()
770    }
771
772    /// Gets the type of the declaration.
773    pub fn ty(&self) -> Type<N> {
774        Type::child(&self.0).expect("bound declaration should have a type")
775    }
776
777    /// Gets the name of the declaration.
778    pub fn name(&self) -> Ident<N::Token> {
779        self.token().expect("bound declaration should have a name")
780    }
781
782    /// Gets the expression the declaration is bound to.
783    pub fn expr(&self) -> Expr<N> {
784        Expr::child(&self.0).expect("bound declaration should have an expression")
785    }
786}
787
788impl<N: TreeNode> AstNode<N> for BoundDecl<N> {
789    fn can_cast(kind: SyntaxKind) -> bool {
790        kind == SyntaxKind::BoundDeclNode
791    }
792
793    fn cast(inner: N) -> Option<Self> {
794        match inner.kind() {
795            SyntaxKind::BoundDeclNode => Some(Self(inner)),
796            _ => None,
797        }
798    }
799
800    fn inner(&self) -> &N {
801        &self.0
802    }
803}
804
805/// Represents a declaration in an input section.
806#[derive(Clone, Debug, PartialEq, Eq)]
807pub enum Decl<N: TreeNode = SyntaxNode> {
808    /// The declaration is bound.
809    Bound(BoundDecl<N>),
810    /// The declaration is unbound.
811    Unbound(UnboundDecl<N>),
812}
813
814impl<N: TreeNode> Decl<N> {
815    /// Returns whether or not the given syntax kind can be cast to
816    /// [`Decl`].
817    pub fn can_cast(kind: SyntaxKind) -> bool {
818        kind == SyntaxKind::BoundDeclNode || kind == SyntaxKind::UnboundDeclNode
819    }
820
821    /// Casts the given node to [`Decl`].
822    ///
823    /// Returns `None` if the node cannot be cast.
824    pub fn cast(inner: N) -> Option<Self> {
825        match inner.kind() {
826            SyntaxKind::BoundDeclNode => Some(Self::Bound(
827                BoundDecl::cast(inner).expect("bound decl to cast"),
828            )),
829            SyntaxKind::UnboundDeclNode => Some(Self::Unbound(
830                UnboundDecl::cast(inner).expect("unbound decl to cast"),
831            )),
832            _ => None,
833        }
834    }
835
836    /// Gets a reference to the inner node.
837    pub fn inner(&self) -> &N {
838        match self {
839            Self::Bound(d) => d.inner(),
840            Self::Unbound(d) => d.inner(),
841        }
842    }
843
844    /// Gets the `env` token, if present.
845    ///
846    /// This may only return a token for task inputs and private declarations
847    /// (WDL 1.2+).
848    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
849        match self {
850            Self::Bound(d) => d.env(),
851            Self::Unbound(d) => d.env(),
852        }
853    }
854
855    /// Gets the type of the declaration.
856    pub fn ty(&self) -> Type<N> {
857        match self {
858            Self::Bound(d) => d.ty(),
859            Self::Unbound(d) => d.ty(),
860        }
861    }
862
863    /// Gets the name of the declaration.
864    pub fn name(&self) -> Ident<N::Token> {
865        match self {
866            Self::Bound(d) => d.name(),
867            Self::Unbound(d) => d.name(),
868        }
869    }
870
871    /// Gets the expression of the declaration.
872    ///
873    /// Returns `None` for unbound declarations.
874    pub fn expr(&self) -> Option<Expr<N>> {
875        match self {
876            Self::Bound(d) => Some(d.expr()),
877            Self::Unbound(_) => None,
878        }
879    }
880
881    /// Attempts to get a reference to the inner [`BoundDecl`].
882    ///
883    /// * If `self` is a [`Decl::Bound`], then a reference to the inner
884    ///   [`BoundDecl`] is returned wrapped in [`Some`].
885    /// * Else, [`None`] is returned.
886    pub fn as_bound_decl(&self) -> Option<&BoundDecl<N>> {
887        match self {
888            Self::Bound(d) => Some(d),
889            _ => None,
890        }
891    }
892
893    /// Consumes `self` and attempts to return the inner [`BoundDecl`].
894    ///
895    /// * If `self` is a [`Decl::Bound`], then the inner [`BoundDecl`] is
896    ///   returned wrapped in [`Some`].
897    /// * Else, [`None`] is returned.
898    pub fn into_bound_decl(self) -> Option<BoundDecl<N>> {
899        match self {
900            Self::Bound(d) => Some(d),
901            _ => None,
902        }
903    }
904
905    /// Unwraps the declaration into a bound declaration.
906    ///
907    /// # Panics
908    ///
909    /// Panics if the declaration is not a bound declaration.
910    pub fn unwrap_bound_decl(self) -> BoundDecl<N> {
911        match self {
912            Self::Bound(d) => d,
913            _ => panic!("not a bound declaration"),
914        }
915    }
916
917    /// Attempts to get a reference to the inner [`UnboundDecl`].
918    ///
919    /// * If `self` is a [`Decl::Unbound`], then a reference to the inner
920    ///   [`UnboundDecl`] is returned wrapped in [`Some`].
921    /// * Else, [`None`] is returned.
922    pub fn as_unbound_decl(&self) -> Option<&UnboundDecl<N>> {
923        match self {
924            Self::Unbound(d) => Some(d),
925            _ => None,
926        }
927    }
928
929    /// Consumes `self` and attempts to return the inner [`UnboundDecl`].
930    ///
931    /// * If `self` is a [`Decl::Unbound`], then the inner [`UnboundDecl`] is
932    ///   returned wrapped in [`Some`].
933    /// * Else, [`None`] is returned.
934    pub fn into_unbound_decl(self) -> Option<UnboundDecl<N>> {
935        match self {
936            Self::Unbound(d) => Some(d),
937            _ => None,
938        }
939    }
940
941    /// Unwraps the declaration into an unbound declaration.
942    ///
943    /// # Panics
944    ///
945    /// Panics if the declaration is not an unbound declaration.
946    pub fn unwrap_unbound_decl(self) -> UnboundDecl<N> {
947        match self {
948            Self::Unbound(d) => d,
949            _ => panic!("not an unbound declaration"),
950        }
951    }
952
953    /// Finds the first child that can be cast to a [`Decl`].
954    pub fn child(node: &N) -> Option<Self> {
955        node.children().find_map(Self::cast)
956    }
957
958    /// Finds all children that can be cast to a [`Decl`].
959    pub fn children(node: &N) -> impl Iterator<Item = Self> + use<'_, N> {
960        node.children().filter_map(Self::cast)
961    }
962}
963
#[cfg(test)]
mod test {
    use super::*;
    use crate::Document;

    /// A single task whose input section declares one input for each form of
    /// type, mixing bound and unbound declarations.
    const SOURCE: &str = r#"
version 1.1

task test {
    input {
        Boolean a
        Int b = 42
        Float? c = None
        String d
        File e = "foo.wdl"
        Map[Int, Int] f
        Array[String] g = []
        Pair[Boolean, Int] h
        Object i = object {}
        MyStruct j
        Directory k = "foo"
    }
}
"#;

    #[test]
    fn decls() {
        let (document, diagnostics) = Document::parse(SOURCE);
        assert!(diagnostics.is_empty());

        let root = document.ast();
        let v1 = root.as_v1().expect("should be a V1 AST");
        let tasks: Vec<_> = v1.tasks().collect();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].name().text(), "test");

        let input = tasks[0].input().expect("task should have an input section");
        let decls: Vec<_> = input.declarations().collect();

        // (is bound, rendered type, name) for each input, in source order.
        let expected = [
            (false, "Boolean", "a"),
            (true, "Int", "b"),
            (true, "Float?", "c"),
            (false, "String", "d"),
            (true, "File", "e"),
            (false, "Map[Int, Int]", "f"),
            (true, "Array[String]", "g"),
            (false, "Pair[Boolean, Int]", "h"),
            (true, "Object", "i"),
            (false, "MyStruct", "j"),
            (true, "Directory", "k"),
        ];
        assert_eq!(decls.len(), expected.len());

        for (decl, (bound, ty, name)) in decls.iter().zip(expected) {
            assert_eq!(decl.as_bound_decl().is_some(), bound, "declaration `{name}`");
            assert_eq!(decl.ty().to_string(), ty);
            assert_eq!(decl.name().text(), name);
        }

        // Bound declarations: check the expressions they are bound to.
        let expr = |index: usize| {
            decls[index]
                .expr()
                .expect("declaration should be bound")
        };

        assert_eq!(
            expr(1)
                .unwrap_literal()
                .unwrap_integer()
                .value()
                .unwrap(),
            42
        );
        expr(2).unwrap_literal().unwrap_none();
        assert_eq!(
            expr(4)
                .unwrap_literal()
                .unwrap_string()
                .text()
                .unwrap()
                .text(),
            "foo.wdl"
        );
        assert_eq!(
            expr(6)
                .unwrap_literal()
                .unwrap_array()
                .elements()
                .count(),
            0
        );
        assert_eq!(
            expr(8).unwrap_literal().unwrap_object().items().count(),
            0
        );
        assert_eq!(
            expr(10)
                .unwrap_literal()
                .unwrap_string()
                .text()
                .unwrap()
                .text(),
            "foo"
        );
    }
}