Skip to main content

wdl_ast/v1/
decls.rs

1//! V1 AST representation for declarations.
2
3use std::fmt;
4
5use wdl_grammar::SyntaxTokenExt;
6
7use super::EnvKeyword;
8use super::Expr;
9use super::Plus;
10use super::QuestionMark;
11use crate::AstNode;
12use crate::AstToken;
13use crate::Comment;
14use crate::Documented;
15use crate::Ident;
16use crate::SyntaxKind;
17use crate::SyntaxNode;
18use crate::TreeNode;
19use crate::TreeToken;
20
21/// Represents a `Map` type.
22#[derive(Clone, Debug, Eq)]
23pub struct MapType<N: TreeNode = SyntaxNode>(N);
24
25impl<N: TreeNode> MapType<N> {
26    /// Gets the key and value types of the `Map`.
27    pub fn types(&self) -> (PrimitiveType<N>, Type<N>) {
28        let mut children = self.0.children().filter_map(Type::cast);
29        let key = children
30            .next()
31            .expect("map should have a key type")
32            .unwrap_primitive_type();
33        let value = children.next().expect("map should have a value type");
34        (key, value)
35    }
36
37    /// Determines if the type is optional.
38    pub fn is_optional(&self) -> bool {
39        matches!(
40            self.0.last_token().map(|t| t.kind()),
41            Some(SyntaxKind::QuestionMark)
42        )
43    }
44}
45
46impl<N: TreeNode> PartialEq for MapType<N> {
47    fn eq(&self, other: &Self) -> bool {
48        self.is_optional() == other.is_optional() && self.types() == other.types()
49    }
50}
51
52impl<N: TreeNode> AstNode<N> for MapType<N> {
53    fn can_cast(kind: SyntaxKind) -> bool {
54        kind == SyntaxKind::MapTypeNode
55    }
56
57    fn cast(inner: N) -> Option<Self> {
58        match inner.kind() {
59            SyntaxKind::MapTypeNode => Some(Self(inner)),
60            _ => None,
61        }
62    }
63
64    fn inner(&self) -> &N {
65        &self.0
66    }
67}
68
69impl fmt::Display for MapType {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        let (key, value) = self.types();
72        write!(
73            f,
74            "Map[{key}, {value}]{o}",
75            o = if self.is_optional() { "?" } else { "" }
76        )
77    }
78}
79
80/// Represents an `Array` type.
81#[derive(Clone, Debug, Eq)]
82pub struct ArrayType<N: TreeNode = SyntaxNode>(N);
83
84impl<N: TreeNode> ArrayType<N> {
85    /// Gets the element type of the array.
86    pub fn element_type(&self) -> Type<N> {
87        Type::child(&self.0).expect("array should have an element type")
88    }
89
90    /// Determines if the type has the "non-empty" qualifier.
91    pub fn is_non_empty(&self) -> bool {
92        self.token::<Plus<N::Token>>().is_some()
93    }
94
95    /// Determines if the type is optional.
96    pub fn is_optional(&self) -> bool {
97        self.last_token::<QuestionMark<N::Token>>().is_some()
98    }
99}
100
101impl<N: TreeNode> PartialEq for ArrayType<N> {
102    fn eq(&self, other: &Self) -> bool {
103        self.is_optional() == other.is_optional()
104            && self.is_non_empty() == other.is_non_empty()
105            && self.element_type() == other.element_type()
106    }
107}
108
109impl<N: TreeNode> AstNode<N> for ArrayType<N> {
110    fn can_cast(kind: SyntaxKind) -> bool {
111        kind == SyntaxKind::ArrayTypeNode
112    }
113
114    fn cast(inner: N) -> Option<Self> {
115        match inner.kind() {
116            SyntaxKind::ArrayTypeNode => Some(Self(inner)),
117            _ => None,
118        }
119    }
120
121    fn inner(&self) -> &N {
122        &self.0
123    }
124}
125
126impl fmt::Display for ArrayType {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(
129            f,
130            "Array[{ty}]{p}{o}",
131            ty = self.element_type(),
132            p = if self.is_non_empty() { "+" } else { "" },
133            o = if self.is_optional() { "?" } else { "" }
134        )
135    }
136}
137
138/// Represents a `Pair` type.
139#[derive(Clone, Debug, Eq)]
140pub struct PairType<N: TreeNode = SyntaxNode>(N);
141
142impl<N: TreeNode> PairType<N> {
143    /// Gets the first and second types of the `Pair`.
144    pub fn types(&self) -> (Type<N>, Type<N>) {
145        let mut children = self.0.children().filter_map(Type::cast);
146        let left = children.next().expect("pair should have a left type");
147        let right = children.next().expect("pair should have a right type");
148        (left, right)
149    }
150
151    /// Determines if the type is optional.
152    pub fn is_optional(&self) -> bool {
153        matches!(
154            self.0.last_token().map(|t| t.kind()),
155            Some(SyntaxKind::QuestionMark)
156        )
157    }
158}
159
160impl<N: TreeNode> PartialEq for PairType<N> {
161    fn eq(&self, other: &Self) -> bool {
162        self.is_optional() == other.is_optional() && self.types() == other.types()
163    }
164}
165
166impl<N: TreeNode> AstNode<N> for PairType<N> {
167    fn can_cast(kind: SyntaxKind) -> bool {
168        kind == SyntaxKind::PairTypeNode
169    }
170
171    fn cast(inner: N) -> Option<Self> {
172        match inner.kind() {
173            SyntaxKind::PairTypeNode => Some(Self(inner)),
174            _ => None,
175        }
176    }
177
178    fn inner(&self) -> &N {
179        &self.0
180    }
181}
182
183impl fmt::Display for PairType {
184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185        let (left, right) = self.types();
186        write!(
187            f,
188            "Pair[{left}, {right}]{o}",
189            o = if self.is_optional() { "?" } else { "" }
190        )
191    }
192}
193
194/// Represents a `Object` type.
195#[derive(Clone, Debug, Eq)]
196pub struct ObjectType<N: TreeNode = SyntaxNode>(N);
197
198impl<N: TreeNode> ObjectType<N> {
199    /// Determines if the type is optional.
200    pub fn is_optional(&self) -> bool {
201        matches!(
202            self.0.last_token().map(|t| t.kind()),
203            Some(SyntaxKind::QuestionMark)
204        )
205    }
206}
207
208impl<N: TreeNode> PartialEq for ObjectType<N> {
209    fn eq(&self, other: &Self) -> bool {
210        self.is_optional() == other.is_optional()
211    }
212}
213
214impl<N: TreeNode> AstNode<N> for ObjectType<N> {
215    fn can_cast(kind: SyntaxKind) -> bool {
216        kind == SyntaxKind::ObjectTypeNode
217    }
218
219    fn cast(inner: N) -> Option<Self> {
220        match inner.kind() {
221            SyntaxKind::ObjectTypeNode => Some(Self(inner)),
222            _ => None,
223        }
224    }
225
226    fn inner(&self) -> &N {
227        &self.0
228    }
229}
230
231impl fmt::Display for ObjectType {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        write!(
234            f,
235            "Object{o}",
236            o = if self.is_optional() { "?" } else { "" }
237        )
238    }
239}
240
241/// Represents a reference to a type.
242#[derive(Clone, Debug, Eq)]
243pub struct TypeRef<N: TreeNode = SyntaxNode>(N);
244
245impl<N: TreeNode> TypeRef<N> {
246    /// Gets the name of the type reference.
247    pub fn name(&self) -> Ident<N::Token> {
248        self.token().expect("type reference should have a name")
249    }
250
251    /// Determines if the type is optional.
252    pub fn is_optional(&self) -> bool {
253        matches!(
254            self.0.last_token().map(|t| t.kind()),
255            Some(SyntaxKind::QuestionMark)
256        )
257    }
258}
259
260impl<N: TreeNode> PartialEq for TypeRef<N> {
261    fn eq(&self, other: &Self) -> bool {
262        self.is_optional() == other.is_optional() && self.name().text() == other.name().text()
263    }
264}
265
266impl<N: TreeNode> AstNode<N> for TypeRef<N> {
267    fn can_cast(kind: SyntaxKind) -> bool {
268        kind == SyntaxKind::TypeRefNode
269    }
270
271    fn cast(inner: N) -> Option<Self> {
272        match inner.kind() {
273            SyntaxKind::TypeRefNode => Some(Self(inner)),
274            _ => None,
275        }
276    }
277
278    fn inner(&self) -> &N {
279        &self.0
280    }
281}
282
283impl fmt::Display for TypeRef {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        write!(
286            f,
287            "{n}{o}",
288            n = self.name().text(),
289            o = if self.is_optional() { "?" } else { "" }
290        )
291    }
292}
293
294/// Represents a kind of primitive type.
295#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
296pub enum PrimitiveTypeKind {
297    /// The primitive is a `Boolean`.
298    Boolean,
299    /// The primitive is an `Int`.
300    Integer,
301    /// The primitive is a `Float`.
302    Float,
303    /// The primitive is a `String`.
304    String,
305    /// The primitive is a `File`.
306    File,
307    /// The primitive is a `Directory`
308    Directory,
309}
310
311/// Represents a primitive type.
312#[derive(Clone, Debug, Eq)]
313pub struct PrimitiveType<N: TreeNode = SyntaxNode>(N);
314
315impl<N: TreeNode> PrimitiveType<N> {
316    /// Gets the kind of the primitive type.
317    pub fn kind(&self) -> PrimitiveTypeKind {
318        self.0
319            .children_with_tokens()
320            .find_map(|c| {
321                c.into_token().and_then(|t| match t.kind() {
322                    SyntaxKind::BooleanTypeKeyword => Some(PrimitiveTypeKind::Boolean),
323                    SyntaxKind::IntTypeKeyword => Some(PrimitiveTypeKind::Integer),
324                    SyntaxKind::FloatTypeKeyword => Some(PrimitiveTypeKind::Float),
325                    SyntaxKind::StringTypeKeyword => Some(PrimitiveTypeKind::String),
326                    SyntaxKind::FileTypeKeyword => Some(PrimitiveTypeKind::File),
327                    SyntaxKind::DirectoryTypeKeyword => Some(PrimitiveTypeKind::Directory),
328                    _ => None,
329                })
330            })
331            .expect("type should have a kind")
332    }
333
334    /// Determines if the type is optional.
335    pub fn is_optional(&self) -> bool {
336        matches!(
337            self.0.last_token().map(|t| t.kind()),
338            Some(SyntaxKind::QuestionMark)
339        )
340    }
341}
342
343impl<N: TreeNode> PartialEq for PrimitiveType<N> {
344    fn eq(&self, other: &Self) -> bool {
345        self.kind() == other.kind()
346    }
347}
348
349impl<N: TreeNode> AstNode<N> for PrimitiveType<N> {
350    fn can_cast(kind: SyntaxKind) -> bool {
351        kind == SyntaxKind::PrimitiveTypeNode
352    }
353
354    fn cast(inner: N) -> Option<Self> {
355        match inner.kind() {
356            SyntaxKind::PrimitiveTypeNode => Some(Self(inner)),
357            _ => None,
358        }
359    }
360
361    fn inner(&self) -> &N {
362        &self.0
363    }
364}
365
366impl fmt::Display for PrimitiveType {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        match self.kind() {
369            PrimitiveTypeKind::Boolean => write!(f, "Boolean")?,
370            PrimitiveTypeKind::Integer => write!(f, "Int")?,
371            PrimitiveTypeKind::Float => write!(f, "Float")?,
372            PrimitiveTypeKind::String => write!(f, "String")?,
373            PrimitiveTypeKind::File => write!(f, "File")?,
374            PrimitiveTypeKind::Directory => write!(f, "Directory")?,
375        }
376
377        if self.is_optional() {
378            write!(f, "?")
379        } else {
380            Ok(())
381        }
382    }
383}
384
385/// Represents a type.
386#[derive(Clone, Debug, Eq)]
387pub enum Type<N: TreeNode = SyntaxNode> {
388    /// The type is a map.
389    Map(MapType<N>),
390    /// The type is an array.
391    Array(ArrayType<N>),
392    /// The type is a pair.
393    Pair(PairType<N>),
394    /// The type is an object.
395    Object(ObjectType<N>),
396    /// The type is a reference to custom type.
397    Ref(TypeRef<N>),
398    /// The type is a primitive.
399    Primitive(PrimitiveType<N>),
400}
401
402impl<N: TreeNode> Type<N> {
403    //// Returns whether or not the given syntax kind can be cast to
404    /// [`Type`].
405    pub fn can_cast(kind: SyntaxKind) -> bool {
406        matches!(
407            kind,
408            SyntaxKind::MapTypeNode
409                | SyntaxKind::ArrayTypeNode
410                | SyntaxKind::PairTypeNode
411                | SyntaxKind::ObjectTypeNode
412                | SyntaxKind::TypeRefNode
413                | SyntaxKind::PrimitiveTypeNode
414        )
415    }
416
417    /// Casts the given node to [`Type`].
418    ///
419    /// Returns `None` if the node cannot be cast.
420    pub fn cast(inner: N) -> Option<Self> {
421        match inner.kind() {
422            SyntaxKind::MapTypeNode => {
423                Some(Self::Map(MapType::cast(inner).expect("map type to cast")))
424            }
425            SyntaxKind::ArrayTypeNode => Some(Self::Array(
426                ArrayType::cast(inner).expect("array type to cast"),
427            )),
428            SyntaxKind::PairTypeNode => Some(Self::Pair(
429                PairType::cast(inner).expect("pair type to cast"),
430            )),
431            SyntaxKind::ObjectTypeNode => Some(Self::Object(
432                ObjectType::cast(inner).expect("object type to cast"),
433            )),
434            SyntaxKind::TypeRefNode => {
435                Some(Self::Ref(TypeRef::cast(inner).expect("type ref to cast")))
436            }
437            SyntaxKind::PrimitiveTypeNode => Some(Self::Primitive(
438                PrimitiveType::cast(inner).expect("primitive type to cast"),
439            )),
440            _ => None,
441        }
442    }
443
444    /// Gets a reference to the inner node.
445    pub fn inner(&self) -> &N {
446        match self {
447            Self::Map(ty) => ty.inner(),
448            Self::Array(ty) => ty.inner(),
449            Self::Pair(ty) => ty.inner(),
450            Self::Object(ty) => ty.inner(),
451            Self::Ref(ty) => ty.inner(),
452            Self::Primitive(ty) => ty.inner(),
453        }
454    }
455
456    /// Determines if the type is optional.
457    pub fn is_optional(&self) -> bool {
458        match self {
459            Self::Map(m) => m.is_optional(),
460            Self::Array(a) => a.is_optional(),
461            Self::Pair(p) => p.is_optional(),
462            Self::Object(o) => o.is_optional(),
463            Self::Ref(r) => r.is_optional(),
464            Self::Primitive(p) => p.is_optional(),
465        }
466    }
467
468    /// Attempts to get a reference to the inner [`MapType`].
469    ///
470    /// * If `self` is a [`Type::Map`], then a reference to the inner
471    ///   [`MapType`] is returned wrapped in [`Some`].
472    /// * Else, [`None`] is returned.
473    pub fn as_map_type(&self) -> Option<&MapType<N>> {
474        match self {
475            Self::Map(ty) => Some(ty),
476            _ => None,
477        }
478    }
479
480    /// Consumes `self` and attempts to return the inner [`MapType`].
481    ///
482    /// * If `self` is a [`Type::Map`], then the inner [`MapType`] is returned
483    ///   wrapped in [`Some`].
484    /// * Else, [`None`] is returned.
485    pub fn into_map_type(self) -> Option<MapType<N>> {
486        match self {
487            Self::Map(ty) => Some(ty),
488            _ => None,
489        }
490    }
491
492    /// Unwraps the type into a map type.
493    ///
494    /// # Panics
495    ///
496    /// Panics if the type is not a map type.
497    pub fn unwrap_map_type(self) -> MapType<N> {
498        match self {
499            Self::Map(ty) => ty,
500            _ => panic!("not a map type"),
501        }
502    }
503
504    /// Attempts to get a reference to the inner [`ArrayType`].
505    ///
506    /// * If `self` is a [`Type::Array`], then a reference to the inner
507    ///   [`ArrayType`] is returned wrapped in [`Some`].
508    /// * Else, [`None`] is returned.
509    pub fn as_array_type(&self) -> Option<&ArrayType<N>> {
510        match self {
511            Self::Array(ty) => Some(ty),
512            _ => None,
513        }
514    }
515
516    /// Consumes `self` and attempts to return the inner [`ArrayType`].
517    ///
518    /// * If `self` is a [`Type::Array`], then the inner [`ArrayType`] is
519    ///   returned wrapped in [`Some`].
520    /// * Else, [`None`] is returned.
521    pub fn into_array_type(self) -> Option<ArrayType<N>> {
522        match self {
523            Self::Array(ty) => Some(ty),
524            _ => None,
525        }
526    }
527
528    /// Unwraps the type into an array type.
529    ///
530    /// # Panics
531    ///
532    /// Panics if the type is not an array type.
533    pub fn unwrap_array_type(self) -> ArrayType<N> {
534        match self {
535            Self::Array(ty) => ty,
536            _ => panic!("not an array type"),
537        }
538    }
539
540    /// Attempts to get a reference to the inner [`PairType`].
541    ///
542    /// * If `self` is a [`Type::Pair`], then a reference to the inner
543    ///   [`PairType`] is returned wrapped in [`Some`].
544    /// * Else, [`None`] is returned.
545    pub fn as_pair_type(&self) -> Option<&PairType<N>> {
546        match self {
547            Self::Pair(ty) => Some(ty),
548            _ => None,
549        }
550    }
551
552    /// Consumes `self` and attempts to return the inner [`PairType`].
553    ///
554    /// * If `self` is a [`Type::Pair`], then the inner [`PairType`] is returned
555    ///   wrapped in [`Some`].
556    /// * Else, [`None`] is returned.
557    pub fn into_pair_type(self) -> Option<PairType<N>> {
558        match self {
559            Self::Pair(ty) => Some(ty),
560            _ => None,
561        }
562    }
563
564    /// Unwraps the type into a pair type.
565    ///
566    /// # Panics
567    ///
568    /// Panics if the type is not a pair type.
569    pub fn unwrap_pair_type(self) -> PairType<N> {
570        match self {
571            Self::Pair(ty) => ty,
572            _ => panic!("not a pair type"),
573        }
574    }
575
576    /// Attempts to get a reference to the inner [`ObjectType`].
577    ///
578    /// * If `self` is a [`Type::Object`], then a reference to the inner
579    ///   [`ObjectType`] is returned wrapped in [`Some`].
580    /// * Else, [`None`] is returned.
581    pub fn as_object_type(&self) -> Option<&ObjectType<N>> {
582        match self {
583            Self::Object(ty) => Some(ty),
584            _ => None,
585        }
586    }
587
588    /// Consumes `self` and attempts to return the inner [`ObjectType`].
589    ///
590    /// * If `self` is a [`Type::Object`], then the inner [`ObjectType`] is
591    ///   returned wrapped in [`Some`].
592    /// * Else, [`None`] is returned.
593    pub fn into_object_type(self) -> Option<ObjectType<N>> {
594        match self {
595            Self::Object(ty) => Some(ty),
596            _ => None,
597        }
598    }
599
600    /// Unwraps the type into an object type.
601    ///
602    /// # Panics
603    ///
604    /// Panics if the type is not an object type.
605    pub fn unwrap_object_type(self) -> ObjectType<N> {
606        match self {
607            Self::Object(ty) => ty,
608            _ => panic!("not an object type"),
609        }
610    }
611
612    /// Attempts to get a reference to the inner [`TypeRef`].
613    ///
614    /// * If `self` is a [`Type::Ref`], then a reference to the inner
615    ///   [`TypeRef`] is returned wrapped in [`Some`].
616    /// * Else, [`None`] is returned.
617    pub fn as_type_ref(&self) -> Option<&TypeRef<N>> {
618        match self {
619            Self::Ref(ty) => Some(ty),
620            _ => None,
621        }
622    }
623
624    /// Consumes `self` and attempts to return the inner [`TypeRef`].
625    ///
626    /// * If `self` is a [`Type::Ref`], then the inner [`TypeRef`] is returned
627    ///   wrapped in [`Some`].
628    /// * Else, [`None`] is returned.
629    pub fn into_type_ref(self) -> Option<TypeRef<N>> {
630        match self {
631            Self::Ref(ty) => Some(ty),
632            _ => None,
633        }
634    }
635
636    /// Unwraps the type into a type reference.
637    ///
638    /// # Panics
639    ///
640    /// Panics if the type is not a type reference.
641    pub fn unwrap_type_ref(self) -> TypeRef<N> {
642        match self {
643            Self::Ref(ty) => ty,
644            _ => panic!("not a type reference"),
645        }
646    }
647
648    /// Attempts to get a reference to the inner [`PrimitiveType`].
649    ///
650    /// * If `self` is a [`Type::Primitive`], then a reference to the inner
651    ///   [`PrimitiveType`] is returned wrapped in [`Some`].
652    /// * Else, [`None`] is returned.
653    pub fn as_primitive_type(&self) -> Option<&PrimitiveType<N>> {
654        match self {
655            Self::Primitive(ty) => Some(ty),
656            _ => None,
657        }
658    }
659
660    /// Consumes `self` and attempts to return the inner [`PrimitiveType`].
661    ///
662    /// * If `self` is a [`Type::Primitive`], then the inner [`PrimitiveType`]
663    ///   is returned wrapped in [`Some`].
664    /// * Else, [`None`] is returned.
665    pub fn into_primitive_type(self) -> Option<PrimitiveType<N>> {
666        match self {
667            Self::Primitive(ty) => Some(ty),
668            _ => None,
669        }
670    }
671
672    /// Unwraps the type into a primitive type.
673    ///
674    /// # Panics
675    ///
676    /// Panics if the type is not a primitive type.
677    pub fn unwrap_primitive_type(self) -> PrimitiveType<N> {
678        match self {
679            Self::Primitive(ty) => ty,
680            _ => panic!("not a primitive type"),
681        }
682    }
683
684    /// Finds the first child that can be cast to a [`Type`].
685    pub fn child(node: &N) -> Option<Self> {
686        node.children().find_map(Self::cast)
687    }
688
689    /// Finds all children that can be cast to a [`Type`].
690    pub fn children(node: &N) -> impl Iterator<Item = Self> + use<'_, N> {
691        node.children().filter_map(Self::cast)
692    }
693}
694
695impl<N: TreeNode> PartialEq for Type<N> {
696    fn eq(&self, other: &Self) -> bool {
697        match (self, other) {
698            (Self::Map(l), Self::Map(r)) => l == r,
699            (Self::Array(l), Self::Array(r)) => l == r,
700            (Self::Pair(l), Self::Pair(r)) => l == r,
701            (Self::Object(l), Self::Object(r)) => l == r,
702            (Self::Ref(l), Self::Ref(r)) => l == r,
703            (Self::Primitive(l), Self::Primitive(r)) => l == r,
704            _ => false,
705        }
706    }
707}
708
709impl fmt::Display for Type {
710    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
711        match self {
712            Type::Map(m) => m.fmt(f),
713            Type::Array(a) => a.fmt(f),
714            Type::Pair(p) => p.fmt(f),
715            Type::Object(o) => o.fmt(f),
716            Type::Ref(r) => r.fmt(f),
717            Type::Primitive(p) => p.fmt(f),
718        }
719    }
720}
721
722/// Represents an unbound declaration.
723#[derive(Clone, Debug, PartialEq, Eq)]
724pub struct UnboundDecl<N: TreeNode = SyntaxNode>(N);
725
726impl<N: TreeNode> UnboundDecl<N> {
727    /// Gets the `env` token, if present.
728    ///
729    /// This may only return a token for task inputs (WDL 1.2+).
730    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
731        self.token()
732    }
733
734    /// Gets the type of the declaration.
735    pub fn ty(&self) -> Type<N> {
736        Type::child(&self.0).expect("unbound declaration should have a type")
737    }
738
739    /// Gets the name of the declaration.
740    pub fn name(&self) -> Ident<N::Token> {
741        self.token()
742            .expect("unbound declaration should have a name")
743    }
744}
745
746impl<N: TreeNode> AstNode<N> for UnboundDecl<N> {
747    fn can_cast(kind: SyntaxKind) -> bool {
748        kind == SyntaxKind::UnboundDeclNode
749    }
750
751    fn cast(inner: N) -> Option<Self> {
752        match inner.kind() {
753            SyntaxKind::UnboundDeclNode => Some(Self(inner)),
754            _ => None,
755        }
756    }
757
758    fn inner(&self) -> &N {
759        &self.0
760    }
761}
762
763impl Documented<SyntaxNode> for UnboundDecl<SyntaxNode> {
764    fn doc_comments(&self) -> Option<Vec<Comment<<SyntaxNode as TreeNode>::Token>>> {
765        let parent = self.inner().parent()?;
766        if !matches!(
767            parent.kind(),
768            SyntaxKind::StructDefinitionNode | SyntaxKind::InputSectionNode
769        ) {
770            return None;
771        }
772
773        Some(
774            crate::doc_comments::<SyntaxNode>(self.inner().first_token()?.preceding_trivia())
775                .collect(),
776        )
777    }
778}
779
780/// Represents a bound declaration in a task or workflow definition.
781#[derive(Clone, Debug, PartialEq, Eq)]
782pub struct BoundDecl<N: TreeNode = SyntaxNode>(N);
783
784impl<N: TreeNode> BoundDecl<N> {
785    /// Gets the `env` token, if present.
786    ///
787    /// This may only return a token for task inputs and private declarations
788    /// (WDL 1.2+).
789    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
790        self.token()
791    }
792
793    /// Gets the type of the declaration.
794    pub fn ty(&self) -> Type<N> {
795        Type::child(&self.0).expect("bound declaration should have a type")
796    }
797
798    /// Gets the name of the declaration.
799    pub fn name(&self) -> Ident<N::Token> {
800        self.token().expect("bound declaration should have a name")
801    }
802
803    /// Gets the expression the declaration is bound to.
804    pub fn expr(&self) -> Expr<N> {
805        Expr::child(&self.0).expect("bound declaration should have an expression")
806    }
807}
808
809impl<N: TreeNode> AstNode<N> for BoundDecl<N> {
810    fn can_cast(kind: SyntaxKind) -> bool {
811        kind == SyntaxKind::BoundDeclNode
812    }
813
814    fn cast(inner: N) -> Option<Self> {
815        match inner.kind() {
816            SyntaxKind::BoundDeclNode => Some(Self(inner)),
817            _ => None,
818        }
819    }
820
821    fn inner(&self) -> &N {
822        &self.0
823    }
824}
825
826impl Documented<SyntaxNode> for BoundDecl<SyntaxNode> {
827    fn doc_comments(&self) -> Option<Vec<Comment<<SyntaxNode as TreeNode>::Token>>> {
828        let parent = self.inner().parent()?;
829        if !matches!(
830            parent.kind(),
831            SyntaxKind::InputSectionNode | SyntaxKind::OutputSectionNode
832        ) {
833            return None;
834        }
835
836        Some(
837            crate::doc_comments::<SyntaxNode>(self.inner().first_token()?.preceding_trivia())
838                .collect(),
839        )
840    }
841}
842
843/// Represents a declaration in an input section.
844#[derive(Clone, Debug, PartialEq, Eq)]
845pub enum Decl<N: TreeNode = SyntaxNode> {
846    /// The declaration is bound.
847    Bound(BoundDecl<N>),
848    /// The declaration is unbound.
849    Unbound(UnboundDecl<N>),
850}
851
852impl<N: TreeNode> Decl<N> {
853    /// Returns whether or not the given syntax kind can be cast to
854    /// [`Decl`].
855    pub fn can_cast(kind: SyntaxKind) -> bool {
856        kind == SyntaxKind::BoundDeclNode || kind == SyntaxKind::UnboundDeclNode
857    }
858
859    /// Casts the given node to [`Decl`].
860    ///
861    /// Returns `None` if the node cannot be cast.
862    pub fn cast(inner: N) -> Option<Self> {
863        match inner.kind() {
864            SyntaxKind::BoundDeclNode => Some(Self::Bound(
865                BoundDecl::cast(inner).expect("bound decl to cast"),
866            )),
867            SyntaxKind::UnboundDeclNode => Some(Self::Unbound(
868                UnboundDecl::cast(inner).expect("unbound decl to cast"),
869            )),
870            _ => None,
871        }
872    }
873
874    /// Gets a reference to the inner node.
875    pub fn inner(&self) -> &N {
876        match self {
877            Self::Bound(d) => d.inner(),
878            Self::Unbound(d) => d.inner(),
879        }
880    }
881
882    /// Gets the `env` token, if present.
883    ///
884    /// This may only return a token for task inputs and private declarations
885    /// (WDL 1.2+).
886    pub fn env(&self) -> Option<EnvKeyword<N::Token>> {
887        match self {
888            Self::Bound(d) => d.env(),
889            Self::Unbound(d) => d.env(),
890        }
891    }
892
893    /// Gets the type of the declaration.
894    pub fn ty(&self) -> Type<N> {
895        match self {
896            Self::Bound(d) => d.ty(),
897            Self::Unbound(d) => d.ty(),
898        }
899    }
900
901    /// Gets the name of the declaration.
902    pub fn name(&self) -> Ident<N::Token> {
903        match self {
904            Self::Bound(d) => d.name(),
905            Self::Unbound(d) => d.name(),
906        }
907    }
908
909    /// Gets the expression of the declaration.
910    ///
911    /// Returns `None` for unbound declarations.
912    pub fn expr(&self) -> Option<Expr<N>> {
913        match self {
914            Self::Bound(d) => Some(d.expr()),
915            Self::Unbound(_) => None,
916        }
917    }
918
919    /// Attempts to get a reference to the inner [`BoundDecl`].
920    ///
921    /// * If `self` is a [`Decl::Bound`], then a reference to the inner
922    ///   [`BoundDecl`] is returned wrapped in [`Some`].
923    /// * Else, [`None`] is returned.
924    pub fn as_bound_decl(&self) -> Option<&BoundDecl<N>> {
925        match self {
926            Self::Bound(d) => Some(d),
927            _ => None,
928        }
929    }
930
931    /// Consumes `self` and attempts to return the inner [`BoundDecl`].
932    ///
933    /// * If `self` is a [`Decl::Bound`], then the inner [`BoundDecl`] is
934    ///   returned wrapped in [`Some`].
935    /// * Else, [`None`] is returned.
936    pub fn into_bound_decl(self) -> Option<BoundDecl<N>> {
937        match self {
938            Self::Bound(d) => Some(d),
939            _ => None,
940        }
941    }
942
943    /// Unwraps the declaration into a bound declaration.
944    ///
945    /// # Panics
946    ///
947    /// Panics if the declaration is not a bound declaration.
948    pub fn unwrap_bound_decl(self) -> BoundDecl<N> {
949        match self {
950            Self::Bound(d) => d,
951            _ => panic!("not a bound declaration"),
952        }
953    }
954
955    /// Attempts to get a reference to the inner [`UnboundDecl`].
956    ///
957    /// * If `self` is a [`Decl::Unbound`], then a reference to the inner
958    ///   [`UnboundDecl`] is returned wrapped in [`Some`].
959    /// * Else, [`None`] is returned.
960    pub fn as_unbound_decl(&self) -> Option<&UnboundDecl<N>> {
961        match self {
962            Self::Unbound(d) => Some(d),
963            _ => None,
964        }
965    }
966
967    /// Consumes `self` and attempts to return the inner [`UnboundDecl`].
968    ///
969    /// * If `self` is a [`Decl::Unbound`], then the inner [`UnboundDecl`] is
970    ///   returned wrapped in [`Some`].
971    /// * Else, [`None`] is returned.
972    pub fn into_unbound_decl(self) -> Option<UnboundDecl<N>> {
973        match self {
974            Self::Unbound(d) => Some(d),
975            _ => None,
976        }
977    }
978
979    /// Unwraps the declaration into an unbound declaration.
980    ///
981    /// # Panics
982    ///
983    /// Panics if the declaration is not an unbound declaration.
984    pub fn unwrap_unbound_decl(self) -> UnboundDecl<N> {
985        match self {
986            Self::Unbound(d) => d,
987            _ => panic!("not an unbound declaration"),
988        }
989    }
990
991    /// Finds the first child that can be cast to a [`Decl`].
992    pub fn child(node: &N) -> Option<Self> {
993        node.children().find_map(Self::cast)
994    }
995
996    /// Finds all children that can be cast to a [`Decl`].
997    pub fn children(node: &N) -> impl Iterator<Item = Self> + use<'_, N> {
998        node.children().filter_map(Self::cast)
999    }
1000}
1001
#[cfg(test)]
mod test {
    use super::*;
    use crate::Document;

    /// Parses a task input section that declares one input of each type
    /// listed below and checks the type, name and (for bound declarations)
    /// the initializer expression of every declaration.
    #[test]
    fn decls() {
        let (document, diagnostics) = Document::parse(
            r#"
version 1.1

task test {
    input {
        Boolean a
        Int b = 42
        Float? c = None
        String d
        File e = "foo.wdl"
        Map[Int, Int] f
        Array[String] g = []
        Pair[Boolean, Int] h
        Object i = object {}
        MyStruct j
        Directory k = "foo"
    }
}
"#,
        );

        // Show the diagnostics themselves on failure instead of a bare
        // `false`, so a parse regression is immediately diagnosable.
        assert!(
            diagnostics.is_empty(),
            "unexpected diagnostics: {diagnostics:?}"
        );
        let ast = document.ast();
        let ast = ast.as_v1().expect("should be a V1 AST");
        let tasks: Vec<_> = ast.tasks().collect();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].name().text(), "test");

        // Inputs
        let input = tasks[0].input().expect("task should have an input section");
        let decls: Vec<_> = input.declarations().collect();
        assert_eq!(decls.len(), 11);

        // The non-panicking accessors must agree with the declaration kind:
        // `a` is unbound, `b` is bound.
        assert!(decls[0].as_unbound_decl().is_some());
        assert!(decls[1].as_unbound_decl().is_none());
        assert!(decls[0].clone().into_unbound_decl().is_some());
        assert!(decls[1].clone().into_unbound_decl().is_none());

        // First input declaration
        let decl = decls[0].clone().unwrap_unbound_decl();
        assert_eq!(decl.ty().to_string(), "Boolean");
        assert_eq!(decl.name().text(), "a");

        // Second input declaration
        let decl = decls[1].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "Int");
        assert_eq!(decl.name().text(), "b");
        assert_eq!(
            decl.expr()
                .unwrap_literal()
                .unwrap_integer()
                .value()
                .unwrap(),
            42
        );

        // Third input declaration
        let decl = decls[2].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "Float?");
        assert_eq!(decl.name().text(), "c");
        decl.expr().unwrap_literal().unwrap_none();

        // Fourth input declaration
        let decl = decls[3].clone().unwrap_unbound_decl();
        assert_eq!(decl.ty().to_string(), "String");
        assert_eq!(decl.name().text(), "d");

        // Fifth input declaration
        let decl = decls[4].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "File");
        assert_eq!(decl.name().text(), "e");
        assert_eq!(
            decl.expr()
                .unwrap_literal()
                .unwrap_string()
                .text()
                .unwrap()
                .text(),
            "foo.wdl"
        );

        // Sixth input declaration
        let decl = decls[5].clone().unwrap_unbound_decl();
        assert_eq!(decl.ty().to_string(), "Map[Int, Int]");
        assert_eq!(decl.name().text(), "f");

        // Seventh input declaration
        let decl = decls[6].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "Array[String]");
        assert_eq!(decl.name().text(), "g");
        assert_eq!(
            decl.expr()
                .unwrap_literal()
                .unwrap_array()
                .elements()
                .count(),
            0
        );

        // Eighth input declaration
        let decl = decls[7].clone().unwrap_unbound_decl();
        assert_eq!(decl.ty().to_string(), "Pair[Boolean, Int]");
        assert_eq!(decl.name().text(), "h");

        // Ninth input declaration
        let decl = decls[8].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "Object");
        assert_eq!(decl.name().text(), "i");
        assert_eq!(
            decl.expr().unwrap_literal().unwrap_object().items().count(),
            0
        );

        // Tenth input declaration
        let decl = decls[9].clone().unwrap_unbound_decl();
        assert_eq!(decl.ty().to_string(), "MyStruct");
        assert_eq!(decl.name().text(), "j");

        // Eleventh input declaration
        let decl = decls[10].clone().unwrap_bound_decl();
        assert_eq!(decl.ty().to_string(), "Directory");
        assert_eq!(decl.name().text(), "k");
        assert_eq!(
            decl.expr()
                .unwrap_literal()
                .unwrap_string()
                .text()
                .unwrap()
                .text(),
            "foo"
        );
    }
}