// typhoon_syn/constraints/mod.rs

1use syn::{
2    parse::{Parse, ParseStream},
3    Ident, Token,
4};
5
6mod address;
7mod assert;
8mod associated_token;
9mod bump;
10mod has_one;
11mod init;
12mod init_if_needed;
13mod mint;
14mod payer;
15mod program;
16mod seeded;
17mod seeds;
18mod space;
19mod token;
20
21pub use {
22    address::*, assert::*, associated_token::*, bump::*, has_one::*, init::*, init_if_needed::*,
23    mint::*, payer::*, program::*, seeded::*, seeds::*, space::*, token::*,
24};
25
/// Attribute path that carries constraints, i.e. the `constraint` in
/// `#[constraint(...)]`. Attributes with any other path are ignored.
pub const CONSTRAINT_IDENT_STR: &str = "constraint";
27
28//TODO rewrite it to add custom constraint for users
29#[derive(Clone)]
30pub enum Constraint {
31    Init(ConstraintInit),
32    Payer(ConstraintPayer),
33    Space(ConstraintSpace),
34    Seeded(ConstraintSeeded),
35    Seeds(ConstraintSeeds),
36    Bump(ConstraintBump),
37    HasOne(ConstraintHasOne),
38    Program(ConstraintProgram),
39    Token(ConstraintToken),
40    Mint(ConstraintMint),
41    AssociatedToken(ConstraintAssociatedToken),
42    InitIfNeeded(ConstraintInitIfNeeded),
43    Assert(ConstraintAssert),
44    Address(ConstraintAddress),
45}
46
47impl Constraint {
48    fn sort_order(&self) -> u8 {
49        match self {
50            Self::Init(_) => 0,
51            Self::InitIfNeeded(_) => 1,
52            Self::Space(_) => 2,
53            Self::Seeded(_) => 3,
54            Self::Seeds(_) => 4,
55            Self::Bump(_) => 5,
56            Self::Program(_) => 6,
57            Self::HasOne(_) => 7,
58            Self::Token(_) => 8,
59            Self::Mint(_) => 9,
60            Self::AssociatedToken(_) => 10,
61            Self::Payer(_) => 11,
62            Self::Assert(_) => 12,
63            Self::Address(_) => 13,
64        }
65    }
66}
67
68#[derive(Clone, Default)]
69pub struct Constraints(pub Vec<Constraint>);
70
71impl TryFrom<&[syn::Attribute]> for Constraints {
72    type Error = syn::Error;
73
74    fn try_from(value: &[syn::Attribute]) -> Result<Self, Self::Error> {
75        let mut constraints = value
76            .iter()
77            .filter(|attr| attr.path().is_ident(CONSTRAINT_IDENT_STR))
78            .map(|attr| attr.parse_args_with(parse_constraints))
79            .collect::<Result<Vec<Vec<Constraint>>, syn::Error>>()?
80            .into_iter()
81            .flatten()
82            .collect::<Vec<_>>();
83
84        constraints.sort_by_key(|c| c.sort_order());
85
86        Ok(Constraints(constraints))
87    }
88}
89
90pub fn parse_constraints(input: ParseStream) -> syn::Result<Vec<Constraint>> {
91    let mut constraints = Vec::new();
92
93    while !input.is_empty() {
94        let name = input.parse::<Ident>()?.to_string();
95        match name.as_str() {
96            "init" => constraints.push(Constraint::Init(ConstraintInit)),
97            "payer" => constraints.push(Constraint::Payer(ConstraintPayer::parse(input)?)),
98            "space" => constraints.push(Constraint::Space(ConstraintSpace::parse(input)?)),
99            "seeds" => constraints.push(Constraint::Seeds(ConstraintSeeds::parse(input)?)),
100            "bump" => constraints.push(Constraint::Bump(ConstraintBump::parse(input)?)),
101            "seeded" => constraints.push(Constraint::Seeded(ConstraintSeeded::parse(input)?)),
102            "has_one" => constraints.push(Constraint::HasOne(ConstraintHasOne::parse(input)?)),
103            "program" => constraints.push(Constraint::Program(ConstraintProgram::parse(input)?)),
104            "token" => constraints.push(Constraint::Token(ConstraintToken::parse(input)?)),
105            "mint" => constraints.push(Constraint::Mint(ConstraintMint::parse(input)?)),
106            "associated_token" => constraints.push(Constraint::AssociatedToken(
107                ConstraintAssociatedToken::parse(input)?,
108            )),
109            "init_if_needed" => constraints.push(Constraint::InitIfNeeded(ConstraintInitIfNeeded)),
110            "assert" => constraints.push(Constraint::Assert(ConstraintAssert::parse(input)?)),
111            "address" => constraints.push(Constraint::Address(ConstraintAddress::parse(input)?)),
112            _ => return Err(syn::Error::new(input.span(), "Unknown constraint.")),
113        }
114
115        if input.peek(Token![,]) {
116            input.parse::<Token![,]>()?;
117        }
118    }
119
120    Ok(constraints)
121}
122
#[cfg(test)]
mod tests {
    use {super::*, syn::parse_quote};

    /// Attributes covering several constraint kinds, written deliberately out
    /// of `sort_order` order so the sorting step is observable.
    fn sample_attributes() -> Vec<syn::Attribute> {
        parse_quote! {
            #[constraint(
                has_one = account,
                seeds = [
                    b"seed".as_ref(),
                ],
                bump = counter.data()?.bump,
                token::mint = mint,
                token::owner = authority,
                mint::decimals = args.decimals,
                mint::authority = escrow.key(),
                mint::freeze_authority = freeze_authority.key(),
                init_if_needed
            )]
        }
    }

    #[test]
    fn test_parse_constraints() {
        let attributes = sample_attributes();

        let constraints = Constraints::try_from(attributes.as_slice()).unwrap();

        assert_eq!(constraints.0.len(), 9);
    }

    #[test]
    fn test_constraints_are_sorted_by_kind() {
        let attributes = sample_attributes();

        let constraints = Constraints::try_from(attributes.as_slice()).unwrap();
        let c = &constraints.0;

        assert_eq!(c.len(), 9);
        assert!(matches!(c[0], Constraint::InitIfNeeded(_)));
        assert!(matches!(c[1], Constraint::Seeds(_)));
        assert!(matches!(c[2], Constraint::Bump(_)));
        assert!(matches!(c[3], Constraint::HasOne(_)));
        assert!(c[4..6].iter().all(|x| matches!(x, Constraint::Token(_))));
        assert!(c[6..9].iter().all(|x| matches!(x, Constraint::Mint(_))));
    }

    #[test]
    fn test_ignores_other_attributes_and_merges_multiple() {
        let attributes: Vec<syn::Attribute> = parse_quote! {
            #[doc = "not a constraint"]
            #[constraint(has_one = account)]
            #[other(init)]
            #[constraint(init)]
        };

        let constraints = Constraints::try_from(attributes.as_slice()).unwrap();

        assert_eq!(constraints.0.len(), 2);
        assert!(matches!(constraints.0[0], Constraint::Init(_)));
        assert!(matches!(constraints.0[1], Constraint::HasOne(_)));
    }

    #[test]
    fn test_unknown_constraint_is_rejected() {
        let attributes: Vec<syn::Attribute> = parse_quote! {
            #[constraint(not_a_constraint)]
        };

        assert!(Constraints::try_from(attributes.as_slice()).is_err());
    }

    #[test]
    fn test_no_constraint_attributes() {
        let attributes: Vec<syn::Attribute> = Vec::new();

        let constraints = Constraints::try_from(attributes.as_slice()).unwrap();

        assert!(constraints.0.is_empty());
    }
}