sunset_sshwire_derive/
lib.rs

1//! Used in conjunction with `sshwire.rs` and `packets.rs`
2//!
3//! `SSHWIRE_DEBUG` environment variable can be set at build time
4//! to write generated files to the `target/` directory.
5#![expect(clippy::useless_format)]
6
7use std::collections::HashSet;
8use std::env;
9
10use proc_macro::Delimiter;
11use virtue::generate::FnSelfArg;
12use virtue::parse::{Attribute, AttributeLocation, EnumBody, StructBody};
13use virtue::prelude::*;
14use virtue::utils::{parse_tagged_attribute, ParsedAttribute};
15
/// Name of the build-time environment variable that, when set, makes the
/// derive macros export the generated code to a file for inspection.
const ENV_SSHWIRE_DEBUG: &str = "SSHWIRE_DEBUG";
17
18#[proc_macro_derive(SSHEncode, attributes(sshwire))]
19pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    encode_inner(input).unwrap_or_else(|e| e.into_token_stream())
21}
22
23#[proc_macro_derive(SSHDecode, attributes(sshwire))]
24pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25    decode_inner(input).unwrap_or_else(|e| e.into_token_stream())
26}
27
28fn encode_inner(input: TokenStream) -> Result<TokenStream> {
29    let parse = Parse::new(input)?;
30    let (mut gen, att, body) = parse.into_generator();
31    // println!("att {att:#?}");
32    match body {
33        Body::Struct(body) => {
34            encode_struct(&mut gen, body)?;
35        }
36        Body::Enum(body) => {
37            encode_enum(&mut gen, &att, body)?;
38        }
39    }
40    if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
41        gen.export_to_file("sshwire", "SSHEncode");
42    }
43    gen.finish()
44}
45
46fn decode_inner(input: TokenStream) -> Result<TokenStream> {
47    let parse = Parse::new(input)?;
48    let (mut gen, att, body) = parse.into_generator();
49    // println!("att {att:#?}");
50    match body {
51        Body::Struct(body) => {
52            decode_struct(&mut gen, body)?;
53        }
54        Body::Enum(body) => {
55            decode_enum(&mut gen, &att, body)?;
56        }
57    }
58    if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
59        gen.export_to_file("sshwire", "SSHDecode");
60    }
61    gen.finish()
62}
63
/// `#[sshwire(...)]` attributes that apply to a whole enum, as parsed by
/// `take_cont_atts`.
#[derive(Debug)]
enum ContainerAtt {
    /// The variant name string is encoded immediately before the enum's
    /// own contents, and read back first when decoding.
    /// `#[sshwire(variant_prefix)]`
    VariantPrefix,

    /// Don't generate SSHEncodeEnum. Can't be used with SSHDecode derive.
    /// `#[sshwire(no_variant_names)]`
    NoNames,
}
74
/// `#[sshwire(...)]` attributes that apply to a struct field or an enum
/// variant, as parsed by `take_field_atts`.
#[derive(Debug)]
enum FieldAtt {
    /// A variant method name will be encoded/decoded before the next field.
    /// The identifier names the enum field whose variant name is used,
    /// eg `#[sshwire(variant_name = ch)]` for `ChannelRequest`
    VariantName(Ident),

    /// Any unknown variant name should be recorded here.
    /// This variant can't be written out.
    /// `#[sshwire(unknown)]`
    CaptureUnknown,

    /// The name of a variant, used by the parent struct
    /// `#[sshwire(variant = "exit-signal")]`
    /// or
    /// `#[sshwire(variant = SSH_NAME_IDENT)]`
    Variant(TokenTree),
}
92
93fn take_cont_atts(atts: &[Attribute]) -> Result<Vec<ContainerAtt>> {
94    let x = atts
95        .iter()
96        .filter_map(|a| parse_tagged_attribute(&a.tokens, "sshwire").transpose());
97
98    let mut ret = vec![];
99    // flatten the lists
100    for a in x {
101        for a in a? {
102            let l = match a {
103                ParsedAttribute::Tag(l) if l.to_string() == "no_variant_names" => {
104                    Ok(ContainerAtt::NoNames)
105                }
106                ParsedAttribute::Tag(l) if l.to_string() == "variant_prefix" => {
107                    Ok(ContainerAtt::VariantPrefix)
108                }
109                _ => Err(Error::Custom {
110                    error: "Unknown sshwire atttribute".into(),
111                    span: None,
112                }),
113            }?;
114            ret.push(l);
115        }
116    }
117    Ok(ret)
118}
119
120// TODO: we could use virtue parse_tagged_attribute() though it doesn't support Literals
121fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> {
122    atts.iter()
123        .filter_map(|a| {
124            match a.location {
125                AttributeLocation::Field | AttributeLocation::Variant => {
126                    let mut s = a.tokens.stream().into_iter();
127                    if &s.next().expect("missing attribute name").to_string()
128                        != "sshwire"
129                    {
130                        // skip attributes other than "sshwire"
131                        return None;
132                    }
133                    Some(if let Some(TokenTree::Group(g)) = s.next() {
134                        let mut g = g.stream().into_iter();
135                        let f = match g.next() {
136                            Some(TokenTree::Ident(l))
137                                if l.to_string() == "variant_name" =>
138                            {
139                                // check for '='
140                                match g.next() {
141                                    Some(TokenTree::Punct(p)) if p == '=' => (),
142                                    _ => {
143                                        return Some(Err(Error::Custom {
144                                            error: "Missing '='".into(),
145                                            span: Some(a.tokens.span()),
146                                        }))
147                                    }
148                                }
149                                match g.next() {
150                                    Some(TokenTree::Ident(i)) => {
151                                        Ok(FieldAtt::VariantName(i))
152                                    }
153                                    _ => Err(Error::ExpectedIdent(a.tokens.span())),
154                                }
155                            }
156
157                            Some(TokenTree::Ident(l))
158                                if l.to_string() == "unknown" =>
159                            {
160                                Ok(FieldAtt::CaptureUnknown)
161                            }
162
163                            Some(TokenTree::Ident(l))
164                                if l.to_string() == "variant" =>
165                            {
166                                // check for '='
167                                match g.next() {
168                                    Some(TokenTree::Punct(p)) if p == '=' => (),
169                                    _ => {
170                                        return Some(Err(Error::Custom {
171                                            error: "Missing '='".into(),
172                                            span: Some(a.tokens.span()),
173                                        }))
174                                    }
175                                }
176                                if let Some(t) = g.next() {
177                                    Ok(FieldAtt::Variant(t))
178                                } else {
179                                    Err(Error::Custom {
180                                        error: "Missing expression".into(),
181                                        span: Some(a.tokens.span()),
182                                    })
183                                }
184                            }
185
186                            _ => Err(Error::Custom {
187                                error: "Unknown sshwire atttribute".into(),
188                                span: Some(a.tokens.span()),
189                            }),
190                        };
191
192                        if g.next().is_some() {
193                            Err(Error::Custom {
194                                error: "Extra unhandled parts".into(),
195                                span: Some(a.tokens.span()),
196                            })
197                        } else {
198                            f
199                        }
200                    } else {
201                        Err(Error::Custom {
202                            error: "#[sshwire(...)] attribute is missing (...) part"
203                                .into(),
204                            span: Some(a.tokens.span()),
205                        })
206                    })
207                }
208                _ => panic!("Non-field attribute for field: {a:#?}"),
209            }
210        })
211        .collect()
212}
213
214fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
215    gen.impl_for("::sunset::sshwire::SSHEncode")
216        .generate_fn("enc")
217        .with_self_arg(FnSelfArg::RefSelf)
218        .with_arg("s", "&mut dyn ::sunset::sshwire::SSHSink")
219        .with_return_type("::sunset::sshwire::WireResult<()>")
220        .body(|fn_body| {
221            match &body.fields {
222                Some(Fields::Tuple(v)) => {
223                    for (fname, f) in v.iter().enumerate() {
224                        // we're only using single elements for newtype, don't bother with atts for now
225                        if !f.attributes.is_empty() {
226                            return Err(Error::Custom { error: "Attributes aren't allowed for tuple structs".into(), span: Some(f.span()) })
227                        }
228                        fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
229                    }
230                }
231                Some(Fields::Struct(v)) => {
232                    for f in v {
233                        let fname = &f.0;
234                        let atts = take_field_atts(&f.1.attributes)?;
235                        for a in atts {
236                            if let FieldAtt::VariantName(enum_field) = a {
237                                // encode an enum field's variant name before this field
238                                fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{enum_field}.variant_name()?, s)?;"))?;
239                            }
240                        }
241                        fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
242                    }
243
244                }
245                None => {
246                    // nothing to do.
247                    // either an empty braced struct or a unit struct.
248                }
249
250            }
251            fn_body.push_parsed("Ok(())")?;
252            Ok(())
253        })?;
254    Ok(())
255}
256
257fn encode_enum(
258    gen: &mut Generator,
259    atts: &[Attribute],
260    body: EnumBody,
261) -> Result<()> {
262    let cont_atts = take_cont_atts(atts)?;
263
264    gen.impl_for("::sunset::sshwire::SSHEncode")
265        .generate_fn("enc")
266        .with_self_arg(FnSelfArg::RefSelf)
267        .with_arg("s", "&mut dyn ::sunset::sshwire::SSHSink")
268        .with_return_type("::sunset::sshwire::WireResult<()>")
269        .body(|fn_body| {
270            if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
271                fn_body.push_parsed("::sunset::sshwire::SSHEncode::enc(&self.variant_name()?, s)?;")?;
272            }
273
274            fn_body.ident_str("match");
275            fn_body.puncts("*");
276            fn_body.ident_str("self");
277            fn_body.group(Delimiter::Brace, |match_arm| {
278                for var in &body.variants {
279                    match_arm.ident_str("Self");
280                    match_arm.puncts("::");
281                    match_arm.ident(var.name.clone());
282
283                    let atts = take_field_atts(&var.attributes)?;
284
285                    let mut rhs = StreamBuilder::new();
286                    if let Some(val) = &var.value {
287                        // Avoid users expecting enum values to be encoded.
288                        // Could be implemented if needed.
289                        return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently does not encode enum discriminants.".into(), span: Some(val.span())})
290                    }
291                    match var.fields {
292                        None => {
293                            // Unit enum
294                        }
295                        Some(Fields::Tuple(ref f)) if f.len() == 1 => {
296                            match_arm.group(Delimiter::Parenthesis, |item| {
297                                item.ident_str("ref");
298                                item.ident_str("i");
299                                Ok(())
300                            })?;
301                            if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
302                                rhs.push_parsed("return Err(::sunset::sshwire::WireError::UnknownVariant)")?;
303                            } else {
304                                rhs.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(i, s)?;"))?;
305                            }
306
307                        }
308                        _ => return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
309                    }
310
311                    match_arm.puncts("=>");
312                    match_arm.group(Delimiter::Brace, |var_body| {
313                        var_body.append(rhs);
314                        Ok(())
315                    })?;
316                }
317                Ok(())
318            })?;
319            // an enum with only an Unknown variant will always return an earlier error
320            fn_body.push_parsed("#[allow(unreachable_code)]")?;
321            fn_body.push_parsed("Ok(())")?;
322            Ok(())
323        })?;
324
325    if !cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
326        encode_enum_names(gen, atts, body)?;
327    }
328    Ok(())
329}
330
331fn field_att_var_names(name: &Ident, mut atts: Vec<FieldAtt>) -> Result<TokenTree> {
332    let mut v = vec![];
333    while let Some(p) = atts.pop() {
334        if let FieldAtt::Variant(t) = p {
335            v.push(t);
336        }
337    }
338    if v.len() != 1 {
339        return Err(Error::Custom { error: format!("One #[sshwire(variant = ...)] attribute is required for each enum field, missing for {:?}", name), span: None});
340    }
341    Ok(v.pop().unwrap())
342}
343
344fn encode_enum_names(
345    gen: &mut Generator,
346    _atts: &[Attribute],
347    body: EnumBody,
348) -> Result<()> {
349    gen.impl_for("::sunset::sshwire::SSHEncodeEnum")
350        .generate_fn("variant_name")
351        .with_self_arg(FnSelfArg::RefSelf)
352        .with_return_type("::sunset::sshwire::WireResult<&'static str>")
353        .body(|fn_body| {
354            fn_body.push_parsed("let r = match self")?;
355            fn_body.group(Delimiter::Brace, |match_arm| {
356                for var in &body.variants {
357                    match_arm.ident_str("Self");
358                    match_arm.puncts("::");
359                    match_arm.ident(var.name.clone());
360
361                    let mut rhs = StreamBuilder::new();
362                    let atts = take_field_atts(&var.attributes)?;
363                    if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
364                        rhs.push_parsed("return Err(::sunset::sshwire::WireError::UnknownVariant)")?;
365                    } else {
366                        rhs.push(field_att_var_names(&var.name, atts)?);
367                    }
368
369                    match var.fields {
370                        None => {
371                            // nothing to do
372                        }
373                        Some(Fields::Tuple(ref f)) if f.len() == 1 => {
374                            match_arm.group(Delimiter::Parenthesis, |item| {
375                                item.ident_str("_");
376                                Ok(())
377                            })?;
378
379                        }
380                        _ => return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
381                    }
382
383                    match_arm.puncts("=>");
384                    match_arm.group(Delimiter::Brace, |var_body| {
385                        var_body.append(rhs);
386                        Ok(())
387                    })?;
388                }
389                Ok(())
390            })?;
391            fn_body.push_parsed(";")?;
392            // an enum with only an Unknown variant will always return an earlier error
393            fn_body.push_parsed("#[allow(unreachable_code)]")?;
394            fn_body.push_parsed("Ok(r)")?;
395
396            Ok(())
397        })?;
398
399    Ok(())
400}
401
402fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
403    gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecode", ["de"])
404        .modify_generic_constraints(|generics, where_constraints| {
405            for lt in generics.iter_lifetimes() {
406                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
407            }
408            Ok(())
409        })?
410        .generate_fn("dec")
411        .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
412        .with_arg("s", "&mut S")
413        .with_return_type("::sunset::sshwire::WireResult<Self>")
414        .body(|fn_body| {
415            let mut named_enums = HashSet::new();
416            if let Some(Fields::Struct(v)) = &body.fields {
417                for f in v {
418                    let atts = take_field_atts(&f.1.attributes)?;
419                    for a in atts {
420                        if let FieldAtt::VariantName(enum_field) = a {
421                            // Read the extra field on the wire that isn't directly included in the struct
422                            named_enums.insert(enum_field.to_string());
423                            fn_body.push_parsed(format!("let enum_name_{enum_field}: BinString = ::sunset::sshwire::SSHDecode::dec(s)?;"))?;
424                        }
425                    }
426                    let fname = &f.0;
427                    if named_enums.contains(&fname.to_string()) {
428                        fn_body.push_parsed(format!("let field_{fname} =  ::sunset::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname}.0)?;"))?;
429                    } else {
430                        fn_body.push_parsed(format!("let field_{fname} = ::sunset::sshwire::SSHDecode::dec(s)?;"))?;
431                    }
432                }
433            }
434            fn_body.ident_str("Ok");
435            fn_body.group(Delimiter::Parenthesis, |fn_body| {
436                match &body.fields {
437                    Some(Fields::Tuple(f)) => {
438                        // we don't handle attributes for Tuple Structs - only use as newtype
439                        fn_body.ident_str("Self");
440                        fn_body.group(Delimiter::Parenthesis, |args| {
441                            for _ in f.iter() {
442                                args.push_parsed(format!("::sunset::sshwire::SSHDecode::dec(s)?,"))?;
443                            }
444                            Ok(())
445                        })?;
446                    }
447                    Some(Fields::Struct(v)) => {
448                        fn_body.ident_str("Self");
449                        fn_body.group(Delimiter::Brace, |args| {
450                            for f in v {
451                                let fname = &f.0;
452                                args.push_parsed(format!("{fname}: field_{fname},"))?;
453                            }
454                            Ok(())
455                        })?;
456                    }
457                    None => {
458                        // An empty struct (or unit or empty tuple-struct)
459                        fn_body.ident_str("Self");
460                        fn_body.group(Delimiter::Brace, |_| Ok(()))?;
461                    }
462                }
463                Ok(())
464            })?;
465            Ok(())
466        })?;
467    Ok(())
468}
469
470fn decode_enum(
471    gen: &mut Generator,
472    atts: &[Attribute],
473    body: EnumBody,
474) -> Result<()> {
475    let cont_atts = take_cont_atts(atts)?;
476
477    if cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
478        return Err(Error::Custom {
479            error:
480                "SSHDecode derive can't be used with #[sshwire(no_variant_names)]"
481                    .into(),
482            span: None,
483        });
484    }
485
486    // SSHDecode trait if it is self describing
487    if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
488        decode_enum_variant_prefix(gen, atts, &body)?;
489    }
490
491    decode_enum_names(gen, atts, &body)?;
492    Ok(())
493}
494
495fn decode_enum_variant_prefix(
496    gen: &mut Generator,
497    _atts: &[Attribute],
498    _body: &EnumBody,
499) -> Result<()> {
500    gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecode", ["de"])
501        .modify_generic_constraints(|generics, where_constraints| {
502            for lt in generics.iter_lifetimes() {
503                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
504            }
505            Ok(())
506        })?
507        .generate_fn("dec")
508        .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
509        .with_arg("s", "&mut S")
510        .with_return_type("::sunset::sshwire::WireResult<Self>")
511        .body(|fn_body| {
512            fn_body
513                .push_parsed("let variant: ::sunset::sshwire::BinString = ::sunset::sshwire::SSHDecode::dec(s)?;")?;
514            fn_body.push_parsed(
515                "::sunset::sshwire::SSHDecodeEnum::dec_enum(s, variant.0)",
516            )?;
517            Ok(())
518        })
519}
520
521fn decode_enum_names(
522    gen: &mut Generator,
523    _atts: &[Attribute],
524    body: &EnumBody,
525) -> Result<()> {
526    gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecodeEnum", ["de"])
527        .modify_generic_constraints(|generics, where_constraints| {
528            for lt in generics.iter_lifetimes() {
529                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
530            }
531            Ok(())
532        })?
533        .generate_fn("dec_enum")
534        .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
535        .with_arg("s", "&mut S")
536        .with_arg("variant", "&'de [u8]")
537        .with_return_type("::sunset::sshwire::WireResult<Self>")
538        .body(|fn_body| {
539            // Some(ascii_string), or None
540            fn_body.push_parsed("let var_str = ::sunset::sshwire::try_as_ascii_str(variant).ok();")?;
541
542            fn_body.push_parsed("let r = match var_str")?;
543            fn_body.group(Delimiter::Brace, |match_arm| {
544                let mut unknown_arm = None;
545                for var in &body.variants {
546                    let atts = take_field_atts(&var.attributes)?;
547                    if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
548                        // create the Unknown fallthrough but it will be at the end of the match list
549                        let mut m = StreamBuilder::new();
550                        m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown::new(variant))}}", var.name))?;
551                        if unknown_arm.replace(m).is_some() {
552                            return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None})
553                        }
554                    } else {
555                        let var_name = field_att_var_names(&var.name, atts)?;
556                        match_arm.push_parsed(format!("Some({}) => ", var_name))?;
557                        match_arm.group(Delimiter::Brace, |var_body| {
558                            match var.fields {
559                                None => {
560                                    var_body.push_parsed(format!("Self::{}", var.name))?;
561                                }
562                                Some(Fields::Tuple(ref f)) if f.len() == 1 => {
563                                    var_body.push_parsed(format!("Self::{}(::sunset::sshwire::SSHDecode::dec(s)?)", var.name))?;
564                                }
565                            _ => return Err(Error::Custom { error: "SSHDecode currently only implements Unit or single value enum variants. ".into(), span: None})
566                            }
567                            Ok(())
568                        })?;
569
570                    }
571                    if let Some(unk) = unknown_arm.take() {
572                        match_arm.append(unk);
573                    }
574                }
575                Ok(())
576            })?;
577            fn_body.push_parsed("; Ok(r)")?;
578            Ok(())
579        })?;
580    Ok(())
581}