Skip to main content

virtue_next/parse/
generics.rs

1use super::utils::*;
2use crate::generate::StreamBuilder;
3use crate::prelude::{Ident, TokenTree};
4use crate::{Error, Result};
5use std::iter::Peekable;
6use std::ops::{Deref, DerefMut};
7
8/// A generic parameter for a struct or enum.
9///
10/// ```
11/// use std::marker::PhantomData;
12/// use std::fmt::Display;
13///
14/// // Generic will be `Generic::Generic("F")`
15/// struct Foo<F> {
16///     f: PhantomData<F>
17/// }
18/// // Generics will be `Generic::Generic("F: Display")`
19/// struct Bar<F: Display> {
20///     f: PhantomData<F>
21/// }
22/// // Generics will be `[Generic::Lifetime("a"), Generic::Generic("F: Display")]`
23/// struct Baz<'a, F> {
24///     f: PhantomData<&'a F>
25/// }
26/// ```
27#[derive(Debug, Clone)]
28pub struct Generics(pub Vec<Generic>);
29
30impl Generics {
31    pub(crate) fn try_take(
32        input: &mut Peekable<impl Iterator<Item = TokenTree>>,
33    ) -> Result<Option<Generics>> {
34        let maybe_punct = input.peek();
35        if let Some(TokenTree::Punct(punct)) = maybe_punct {
36            if punct.as_char() == '<' {
37                let punct = assume_punct(input.next(), '<');
38                let mut result = Generics(Vec::new());
39                loop {
40                    match input.peek() {
41                        Some(TokenTree::Punct(punct)) if punct.as_char() == '\'' => {
42                            result.push(Lifetime::take(input)?.into());
43                            consume_punct_if(input, ',');
44                        }
45                        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {
46                            assume_punct(input.next(), '>');
47                            break;
48                        }
49                        Some(TokenTree::Ident(ident)) if ident_eq(ident, "const") => {
50                            result.push(ConstGeneric::take(input)?.into());
51                            consume_punct_if(input, ',');
52                        }
53                        Some(TokenTree::Ident(_)) => {
54                            result.push(SimpleGeneric::take(input)?.into());
55                            consume_punct_if(input, ',');
56                        }
57                        x => {
58                            return Err(Error::InvalidRustSyntax {
59                                span: x.map(|x| x.span()).unwrap_or_else(|| punct.span()),
60                                expected: format!("', > or an ident, got {:?}", x),
61                            });
62                        }
63                    }
64                }
65                return Ok(Some(result));
66            }
67        }
68        Ok(None)
69    }
70
71    /// Returns `true` if any of the generics is a [`Generic::Lifetime`]
72    pub fn has_lifetime(&self) -> bool {
73        self.iter().any(|lt| lt.is_lifetime())
74    }
75
76    /// Returns an iterator which contains only the simple type generics
77    pub fn iter_generics(&self) -> impl Iterator<Item = &SimpleGeneric> {
78        self.iter().filter_map(|g| match g {
79            Generic::Generic(s) => Some(s),
80            _ => None,
81        })
82    }
83
84    /// Returns an iterator which contains only the lifetimes
85    pub fn iter_lifetimes(&self) -> impl Iterator<Item = &Lifetime> {
86        self.iter().filter_map(|g| match g {
87            Generic::Lifetime(s) => Some(s),
88            _ => None,
89        })
90    }
91
92    /// Returns an iterator which contains only the const generics
93    pub fn iter_consts(&self) -> impl Iterator<Item = &ConstGeneric> {
94        self.iter().filter_map(|g| match g {
95            Generic::Const(s) => Some(s),
96            _ => None,
97        })
98    }
99
100    pub(crate) fn impl_generics(&self) -> StreamBuilder {
101        let mut result = StreamBuilder::new();
102        result.punct('<');
103
104        for (idx, generic) in self.iter().enumerate() {
105            if idx > 0 {
106                result.punct(',');
107            }
108
109            generic.append_to_result_with_constraints(&mut result);
110        }
111
112        result.punct('>');
113
114        result
115    }
116
117    pub(crate) fn impl_generics_with_additional(
118        &self,
119        lifetimes: &[String],
120        types: &[String],
121    ) -> StreamBuilder {
122        let mut result = StreamBuilder::new();
123        result.punct('<');
124        let mut is_first = true;
125        for lt in lifetimes.iter() {
126            if !is_first {
127                result.punct(',');
128            } else {
129                is_first = false;
130            }
131            result.lifetime_str(lt);
132        }
133
134        for generic in self.iter() {
135            if !is_first {
136                result.punct(',');
137            } else {
138                is_first = false;
139            }
140            generic.append_to_result_with_constraints(&mut result);
141        }
142        for ty in types {
143            if !is_first {
144                result.punct(',');
145            } else {
146                is_first = false;
147            }
148            result.ident_str(ty);
149        }
150
151        result.punct('>');
152
153        result
154    }
155
156    pub(crate) fn type_generics(&self) -> StreamBuilder {
157        let mut result = StreamBuilder::new();
158        result.punct('<');
159
160        for (idx, generic) in self.iter().enumerate() {
161            if idx > 0 {
162                result.punct(',');
163            }
164            if generic.is_lifetime() {
165                result.lifetime(generic.ident().clone());
166            } else {
167                result.ident(generic.ident().clone());
168            }
169        }
170
171        result.punct('>');
172        result
173    }
174}
175
176impl Deref for Generics {
177    type Target = Vec<Generic>;
178
179    fn deref(&self) -> &Self::Target {
180        &self.0
181    }
182}
183
184impl DerefMut for Generics {
185    fn deref_mut(&mut self) -> &mut Self::Target {
186        &mut self.0
187    }
188}
189
190/// A single generic argument on a type
191#[derive(Debug, Clone)]
192#[allow(clippy::enum_variant_names)]
193#[non_exhaustive]
194pub enum Generic {
195    /// A lifetime generic
196    ///
197    /// ```
198    /// # use std::marker::PhantomData;
199    /// struct Foo<'a> { // will be Generic::Lifetime("a")
200    /// #   a: PhantomData<&'a ()>,
201    /// }
202    /// ```
203    Lifetime(Lifetime),
204    /// A simple generic
205    ///
206    /// ```
207    /// # use std::marker::PhantomData;
208    /// struct Foo<F> { // will be Generic::Generic("F")
209    /// #   a: PhantomData<F>,
210    /// }
211    /// ```
212    Generic(SimpleGeneric),
213    /// A const generic
214    ///
215    /// ```
216    /// struct Foo<const N: usize> { // will be Generic::Const("N")
217    /// #   a: [u8; N],
218    /// }
219    /// ```
220    Const(ConstGeneric),
221}
222
223impl Generic {
224    fn is_lifetime(&self) -> bool {
225        matches!(self, Generic::Lifetime(_))
226    }
227
228    /// The ident of this generic
229    pub fn ident(&self) -> &Ident {
230        match self {
231            Self::Lifetime(lt) => &lt.ident,
232            Self::Generic(r#gen) => &r#gen.ident,
233            Self::Const(r#gen) => &r#gen.ident,
234        }
235    }
236
237    fn has_constraints(&self) -> bool {
238        match self {
239            Self::Lifetime(lt) => !lt.constraint.is_empty(),
240            Self::Generic(r#gen) => !r#gen.constraints.is_empty(),
241            Self::Const(_) => true, // const generics always have a constraint
242        }
243    }
244
245    fn constraints(&self) -> Vec<TokenTree> {
246        match self {
247            Self::Lifetime(lt) => lt.constraint.clone(),
248            Self::Generic(r#gen) => r#gen.constraints.clone(),
249            Self::Const(r#gen) => r#gen.constraints.clone(),
250        }
251    }
252
253    fn append_to_result_with_constraints(&self, builder: &mut StreamBuilder) {
254        match self {
255            Self::Lifetime(lt) => builder.lifetime(lt.ident.clone()),
256            Self::Generic(r#gen) => builder.ident(r#gen.ident.clone()),
257            Self::Const(r#gen) => {
258                builder.ident(r#gen.const_token.clone());
259                builder.ident(r#gen.ident.clone())
260            }
261        };
262        if self.has_constraints() {
263            builder.punct(':');
264            builder.extend(self.constraints());
265        }
266    }
267}
268
269impl From<Lifetime> for Generic {
270    fn from(lt: Lifetime) -> Self {
271        Self::Lifetime(lt)
272    }
273}
274
275impl From<SimpleGeneric> for Generic {
276    fn from(r#gen: SimpleGeneric) -> Self {
277        Self::Generic(r#gen)
278    }
279}
280
281impl From<ConstGeneric> for Generic {
282    fn from(r#gen: ConstGeneric) -> Self {
283        Self::Const(r#gen)
284    }
285}
286
/// Exercises [`Generics::try_take`] on inputs without generics, with lifetimes, bounds,
/// higher-ranked bounds, `Fn`-style bounds, defaults, and an invalid parameter.
#[test]
fn test_generics_try_take() {
    use crate::token_stream;

    // No leading `<`: nothing is parsed.
    assert!(Generics::try_take(&mut token_stream("")).unwrap().is_none());
    assert!(
        Generics::try_take(&mut token_stream("foo"))
            .unwrap()
            .is_none()
    );
    assert!(
        Generics::try_take(&mut token_stream("()"))
            .unwrap()
            .is_none()
    );

    let stream = &mut token_stream("struct Foo<'a, T>()");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Foo");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 2);
    assert_eq!(generics[0].ident(), "a");
    assert_eq!(generics[1].ident(), "T");

    let stream = &mut token_stream("struct Foo<A, B>()");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Foo");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 2);
    assert_eq!(generics[0].ident(), "A");
    assert_eq!(generics[1].ident(), "B");

    let stream = &mut token_stream("struct Foo<'a, T: Display>()");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Foo");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 2);
    assert_eq!(generics[0].ident(), "a");
    assert_eq!(generics[1].ident(), "T");

    // Higher-ranked bound followed by a lifetime bound.
    let stream = &mut token_stream("struct Foo<'a, T: for<'a> Bar<'a> + 'static>()");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Foo");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 2);
    assert_eq!(generics[0].ident(), "a");
    assert_eq!(generics[1].ident(), "T");

    // Nested higher-ranked bounds: the inner `,` and `>` must not end the parameter.
    let stream = &mut token_stream(
        "struct Baz<T: for<'a> Bar<'a, for<'b> Bar<'b, for<'c> Bar<'c, u32>>>> {}",
    );
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Baz");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 1);
    assert_eq!(generics[0].ident(), "T");

    // A group is not a valid generic parameter.
    let stream = &mut token_stream("struct Baz<()> {}");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Baz");
    assert!(
        Generics::try_take(stream)
            .unwrap_err()
            .is_invalid_rust_syntax()
    );

    // The `->` of an `Fn` bound must not be taken as the closing `>`.
    let stream = &mut token_stream("struct Bar<A: FnOnce(&'static str) -> SomeStruct, B>");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Bar");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 2);
    assert_eq!(generics[0].ident(), "A");
    assert_eq!(generics[1].ident(), "B");

    let stream = &mut token_stream("struct Bar<A = ()>");
    let (data_type, ident) = super::DataType::take(stream).unwrap();
    assert_eq!(data_type, super::DataType::Struct);
    assert_eq!(ident, "Bar");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 1);
    if let Generic::Generic(generic) = &generics[0] {
        assert_eq!(generic.ident, "A");
        assert_eq!(generic.default_value.len(), 1);
        assert_eq!(generic.default_value[0].to_string(), "()");
    } else {
        panic!("Expected simple generic, got {:?}", generics[0]);
    }
}
386
387/// a lifetime generic parameter, e.g. `struct Foo<'a> { ... }`
388#[derive(Debug, Clone)]
389pub struct Lifetime {
390    /// The ident of this lifetime
391    pub ident: Ident,
392    /// Any constraints that this lifetime may have
393    pub constraint: Vec<TokenTree>,
394}
395
396impl Lifetime {
397    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
398        let start = assume_punct(input.next(), '\'');
399        let ident = match input.peek() {
400            Some(TokenTree::Ident(_)) => assume_ident(input.next()),
401            Some(t) => return Err(Error::ExpectedIdent(t.span())),
402            None => return Err(Error::ExpectedIdent(start.span())),
403        };
404
405        let mut constraint = Vec::new();
406        if let Some(TokenTree::Punct(p)) = input.peek() {
407            if p.as_char() == ':' {
408                assume_punct(input.next(), ':');
409                constraint = read_tokens_until_punct(input, &[',', '>'])?;
410            }
411        }
412
413        Ok(Self { ident, constraint })
414    }
415
416    #[cfg(test)]
417    fn is_ident(&self, s: &str) -> bool {
418        self.ident == s
419    }
420}
421
/// Checks that [`Lifetime::take`] reads a lifetime's name and bounds, and that malformed
/// input panics.
#[test]
fn test_lifetime_take() {
    use crate::token_stream;
    use std::panic::catch_unwind;

    let simple = Lifetime::take(&mut token_stream("'a")).unwrap();
    assert!(simple.is_ident("a"));

    for malformed in ["'0", "'(", "')", "'0'"] {
        assert!(catch_unwind(|| Lifetime::take(&mut token_stream(malformed))).is_err());
    }

    let stream = &mut token_stream("'a: 'b>");
    let bounded = Lifetime::take(stream).unwrap();
    assert_eq!(bounded.ident, "a");
    // `'b` is two tokens: the `'` punct and the `b` ident.
    assert_eq!(bounded.constraint.len(), 2);
    // The terminating `>` is left in the stream.
    assume_punct(stream.next(), '>');
    assert!(stream.next().is_none());
}
443
444/// a simple generic parameter, e.g. `struct Foo<F> { .. }`
445#[derive(Debug, Clone)]
446#[non_exhaustive]
447pub struct SimpleGeneric {
448    /// The ident of this generic
449    pub ident: Ident,
450    /// The constraints of this generic, e.g. `F: SomeTrait`
451    pub constraints: Vec<TokenTree>,
452    /// The default value of this generic, e.g. `F = ()`
453    pub default_value: Vec<TokenTree>,
454}
455
456impl SimpleGeneric {
457    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
458        let ident = assume_ident(input.next());
459        let mut constraints = Vec::new();
460        let mut default_value = Vec::new();
461        if let Some(TokenTree::Punct(punct)) = input.peek() {
462            let punct_char = punct.as_char();
463            if punct_char == ':' {
464                assume_punct(input.next(), ':');
465                constraints = read_tokens_until_punct(input, &['>', ','])?;
466            }
467            if punct_char == '=' {
468                assume_punct(input.next(), '=');
469                default_value = read_tokens_until_punct(input, &['>', ','])?;
470            }
471        }
472        Ok(Self {
473            ident,
474            constraints,
475            default_value,
476        })
477    }
478
479    /// The name of this generic, e.g. `T`
480    pub fn name(&self) -> Ident {
481        self.ident.clone()
482    }
483}
484
485/// a const generic parameter, e.g. `struct Foo<const N: usize> { .. }`
486#[derive(Debug, Clone)]
487pub struct ConstGeneric {
488    /// The `const` token for this generic
489    pub const_token: Ident,
490    /// The ident of this generic
491    pub ident: Ident,
492    /// The "constraints" (type) of this generic, e.g. the `usize` from `const N: usize`
493    pub constraints: Vec<TokenTree>,
494}
495
496impl ConstGeneric {
497    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
498        let const_token = assume_ident(input.next());
499        let ident = assume_ident(input.next());
500        let mut constraints = Vec::new();
501        if let Some(TokenTree::Punct(punct)) = input.peek() {
502            if punct.as_char() == ':' {
503                assume_punct(input.next(), ':');
504                constraints = read_tokens_until_punct(input, &['>', ','])?;
505            }
506        }
507        Ok(Self {
508            const_token,
509            ident,
510            constraints,
511        })
512    }
513}
514
515/// Constraints on generic types.
516///
517/// ```
518/// # use std::marker::PhantomData;
519/// # use std::fmt::Display;
520///
521/// struct Foo<F>
522///     where F: Display // These are `GenericConstraints`
523/// {
524///     f: PhantomData<F>
525/// }
526#[derive(Debug, Clone, Default)]
527pub struct GenericConstraints {
528    constraints: Vec<TokenTree>,
529}
530
531impl GenericConstraints {
532    pub(crate) fn try_take(
533        input: &mut Peekable<impl Iterator<Item = TokenTree>>,
534    ) -> Result<Option<Self>> {
535        match input.peek() {
536            Some(TokenTree::Ident(ident)) => {
537                if !ident_eq(ident, "where") {
538                    return Ok(None);
539                }
540            }
541            _ => {
542                return Ok(None);
543            }
544        }
545        input.next();
546        let constraints = read_tokens_until_punct(input, &['{', '('])?;
547        Ok(Some(Self { constraints }))
548    }
549
550    pub(crate) fn where_clause(&self) -> StreamBuilder {
551        let mut result = StreamBuilder::new();
552        result.ident_str("where");
553        result.extend(self.constraints.clone());
554        result
555    }
556
557    /// Push the given constraint onto this stream.
558    ///
559    /// ```ignore
560    /// let mut generic_constraints = GenericConstraints::parse("T: Foo"); // imaginary function
561    /// let mut generic = SimpleGeneric::new("U"); // imaginary function
562    ///
563    /// generic_constraints.push_constraint(&generic, "Bar");
564    ///
565    /// // generic_constraints is now:
566    /// // `T: Foo, U: Bar`
567    /// ```
568    pub fn push_constraint(
569        &mut self,
570        generic: &SimpleGeneric,
571        constraint: impl AsRef<str>,
572    ) -> Result<()> {
573        let mut builder = StreamBuilder::new();
574        let last_constraint_was_comma = self
575            .constraints
576            .last()
577            .is_some_and(|l| matches!(l, TokenTree::Punct(c) if c.as_char() == ','));
578        if !self.constraints.is_empty() && !last_constraint_was_comma {
579            builder.punct(',');
580        }
581        builder.ident(generic.ident.clone());
582        builder.punct(':');
583        builder.push_parsed(constraint)?;
584        self.constraints.extend(builder.stream);
585
586        Ok(())
587    }
588
589    /// Push the given constraint onto this stream.
590    ///
591    /// ```ignore
592    /// let mut generic_constraints = GenericConstraints::parse("T: Foo"); // imaginary function
593    ///
594    /// generic_constraints.push_parsed_constraint("u32: SomeTrait");
595    ///
596    /// // generic_constraints is now:
597    /// // `T: Foo, u32: SomeTrait`
598    /// ```
599    pub fn push_parsed_constraint(&mut self, constraint: impl AsRef<str>) -> Result<()> {
600        let mut builder = StreamBuilder::new();
601        if !self.constraints.is_empty() {
602            builder.punct(',');
603        }
604        builder.push_parsed(constraint)?;
605        self.constraints.extend(builder.stream);
606
607        Ok(())
608    }
609
610    /// Clear the constraints
611    pub fn clear(&mut self) {
612        self.constraints.clear();
613    }
614}
615
/// Checks when [`GenericConstraints::try_take`] recognises a `where` clause, and that a
/// generic bound followed by an empty body parses end-to-end.
#[test]
fn test_generic_constraints_try_take() {
    use super::{DataType, StructBody, Visibility};
    use crate::parse::body::Fields;
    use crate::token_stream;

    // Each stream is positioned right after its `struct Foo` header.
    let after_header: [(&str, bool); 4] = [
        ("struct Foo where Foo: Bar { }", true),
        ("struct Foo { }", false),
        ("struct Foo where Foo: Bar(Foo)", true),
        ("struct Foo()", false),
    ];
    for (source, has_where) in after_header {
        let stream = &mut token_stream(source);
        DataType::take(stream).unwrap();
        assert_eq!(
            GenericConstraints::try_take(stream).unwrap().is_some(),
            has_where
        );
    }

    // Streams that do not start with `where` at all.
    for source in ["struct Foo()", "{}", ""] {
        assert!(
            GenericConstraints::try_take(&mut token_stream(source))
                .unwrap()
                .is_none()
        );
    }

    let stream = &mut token_stream("pub(crate) struct Test<T: Encode> {}");
    assert_eq!(Visibility::Pub, Visibility::try_take(stream).unwrap());
    let (data_type, ident) = DataType::take(stream).unwrap();
    assert_eq!(data_type, DataType::Struct);
    assert_eq!(ident, "Test");
    let generics = Generics::try_take(stream).unwrap().unwrap();
    assert_eq!(generics.len(), 1);
    assert_eq!(generics[0].ident(), "T");
    let body = StructBody::take(stream).unwrap();
    match &body.fields {
        Some(Fields::Struct(fields)) => assert!(fields.is_empty()),
        _ => panic!("wrong fields {:?}", body.fields),
    }
}
662
/// A full struct declaration whose `where` clause ends in a trailing comma must parse at
/// every step, including the body that follows the clause.
#[test]
fn test_generic_constraints_trailing_comma() {
    use crate::parse::{
        Attribute, AttributeLocation, DataType, GenericConstraints, Generics, StructBody,
        Visibility,
    };
    use crate::token_stream;

    let stream = &mut token_stream("pub struct MyStruct<T> where T: Clone, { }");

    Attribute::try_take(AttributeLocation::Container, stream).unwrap();
    Visibility::try_take(stream).unwrap();
    DataType::take(stream).unwrap();
    let generics = Generics::try_take(stream).unwrap();
    generics.unwrap();
    let where_clause = GenericConstraints::try_take(stream).unwrap();
    where_clause.unwrap();
    StructBody::take(stream).unwrap();
}