wyre_derive/
lib.rs

1use proc_macro::TokenStream;
2use noco::*;
3
4#[proc_macro_derive(Pack, attributes(pack))]
5pub fn derive_pack(input: TokenStream) -> TokenStream {
6    let mut parser = TokenParser::new(input);
7    let mut tb = TokenBuilder::new();
8
9    parser.eat_attributes();
10    parser.eat_ident("pub");
11    if parser.eat_ident("struct") {
12        if let Some(name) = parser.eat_any_ident() {
13            let generic = parser.eat_generic();
14            let types = parser.eat_all_types();
15            let where_clause = parser.eat_where_clause(Some("Pack"));
16
17            tb.add("impl").stream(generic.clone());
18            tb.add("wyre::Pack for").ident(&name).stream(generic).stream(where_clause);
19            tb.add("{ fn pack<B: bytes::BufMut>(&self, buf: &mut B) {");
20
21            if let Some(types) = types {
22                for i in 0..types.len() {
23                    tb.add("self.").add(&i.to_string()).add(".pack(buf);");
24                }
25            } else if let Some(fields) = parser.eat_all_struct_fields() {
26                for field in fields {
27                    tb.add("self.").ident(&field.name).add(".pack(buf);");
28                }
29            } else {
30                return parser.unexpected();
31            }
32
33            tb.add("} } ;");
34            return tb.end();
35        }
36    } else if parser.eat_ident("enum") {
37        if let Some(name) = parser.eat_any_ident() {
38            let generic = parser.eat_generic();
39            let where_clause = parser.eat_where_clause(Some("Pack"));
40
41            tb.add("impl").stream(generic.clone());
42            tb.add("wyre::Pack for").ident(&name).stream(generic).stream(where_clause);
43            tb.add("{ fn pack<B: bytes::BufMut>(&self, buf: &mut B) {");
44            tb.add("match self {");
45
46            if !parser.open_brace() {
47                return parser.unexpected();
48            }
49
50            let mut variant_index = 0u32;
51            while !parser.eat_eot() {
52                parser.eat_attributes();
53                if let Some(variant) = parser.eat_any_ident() {
54                    if let Some(types) = parser.eat_all_types() {
55                        tb.add("Self::").ident(&variant).add("(");
56                        for i in 0..types.len() {
57                            tb.ident(&format!("v{}", i)).add(",");
58                        }
59                        tb.add(") => {");
60                        tb.add("wyre::write_varint(buf,").add(&variant_index.to_string()).add(");");
61                        for i in 0..types.len() {
62                            tb.ident(&format!("v{}", i)).add(".pack(buf);");
63                        }
64                        tb.add("}");
65                    } else if let Some(fields) = parser.eat_all_struct_fields() {
66                        tb.add("Self::").ident(&variant).add("{");
67                        for field in fields.iter() {
68                            tb.ident(&field.name).add(",");
69                        }
70                        tb.add("} => {");
71                        tb.add("wyre::write_varint(buf,").add(&variant_index.to_string()).add(");");
72                        for field in fields {
73                            tb.ident(&field.name).add(".pack(buf);");
74                        }
75                        tb.add("}");
76                    } else if parser.is_punct_alone(',') || parser.is_eot() {
77                        tb.add("Self::").ident(&variant).add("=> {");
78                        tb.add("wyre::write_varint(buf,").add(&variant_index.to_string()).add(");");
79                        tb.add("}");
80                    } else {
81                        return parser.unexpected();
82                    }
83                    variant_index += 1;
84                    parser.eat_punct_alone(',');
85                } else {
86                    return parser.unexpected();
87                }
88            }
89            tb.add("} } } ;");
90            return tb.end();
91        }
92    }
93    parser.unexpected()
94}
95
96#[proc_macro_derive(Unpack, attributes(unpack))]
97pub fn derive_unpack(input: TokenStream) -> TokenStream {
98    let mut parser = TokenParser::new(input);
99    let mut tb = TokenBuilder::new();
100
101    parser.eat_attributes();
102    parser.eat_ident("pub");
103    if parser.eat_ident("struct") {
104        if let Some(name) = parser.eat_any_ident() {
105            let generic = parser.eat_generic();
106            let types = parser.eat_all_types();
107            let where_clause = parser.eat_where_clause(Some("Unpack"));
108
109            tb.add("impl").stream(generic.clone());
110            tb.add("wyre::Unpack for").ident(&name).stream(generic).stream(where_clause);
111            tb.add("{ fn unpack(buf: &mut bytes::Bytes) -> Result<Self, wyre::WyreError> {");
112
113            if let Some(types) = types {
114                tb.add("Ok(Self(");
115                for _ in 0..types.len() {
116                    tb.add("Unpack::unpack(buf)?,");
117                }
118                tb.add("))");
119            } else if let Some(fields) = parser.eat_all_struct_fields() {
120                tb.add("Ok(Self {");
121                for field in fields {
122                    tb.ident(&field.name).add(": wyre::Unpack::unpack(buf)?,");
123                }
124                tb.add("})");
125            } else {
126                return parser.unexpected();
127            }
128
129            tb.add("} } ;");
130            return tb.end();
131        }
132    } else if parser.eat_ident("enum") {
133        if let Some(name) = parser.eat_any_ident() {
134            let generic = parser.eat_generic();
135            let where_clause = parser.eat_where_clause(Some("Unpack"));
136
137            tb.add("impl").stream(generic.clone());
138            tb.add("wyre::Unpack for").ident(&name).stream(generic).stream(where_clause);
139            tb.add("{ fn unpack(buf: &mut bytes::Bytes) -> Result<Self, wyre::WyreError> {");
140            tb.add("let variant_index = wyre::read_varint(buf)? as u32;");
141            tb.add("match variant_index {");
142
143            if !parser.open_brace() {
144                return parser.unexpected();
145            }
146
147            let mut variant_index = 0u32;
148            while !parser.eat_eot() {
149                parser.eat_attributes();
150                if let Some(variant) = parser.eat_any_ident() {
151                    tb.add(&variant_index.to_string()).add("=> {");
152                    if let Some(types) = parser.eat_all_types() {
153                        tb.add("Ok(Self::").ident(&variant).add("(");
154                        for _ in 0..types.len() {
155                            tb.add("wyre::Unpack::unpack(buf)?,");
156                        }
157                        tb.add("))");
158                    } else if let Some(fields) = parser.eat_all_struct_fields() {
159                        tb.add("Ok(Self::").ident(&variant).add("{");
160                        for field in fields {
161                            tb.ident(&field.name).add(": wyre::Unpack::unpack(buf)?,");
162                        }
163                        tb.add("})");
164                    } else if parser.is_punct_alone(',') || parser.is_eot() {
165                        tb.add("Ok(Self::").ident(&variant).add(")");
166                    } else {
167                        return parser.unexpected();
168                    }
169                    tb.add("}");
170                    variant_index += 1;
171                    parser.eat_punct_alone(',');
172                } else {
173                    return parser.unexpected();
174                }
175            }
176            tb.add("_ => Err(wyre::WyreError::InvalidFormat(\"Invalid enum variant\".into())),");
177            tb.add("} } } ;");
178            return tb.end();
179        }
180    }
181    parser.unexpected()
182}