sunset_sshwire_derive/
lib.rs

1//! Used in conjunction with `sshwire.rs` and `packets.rs`
2//!
3//! `SSHWIRE_DEBUG` environment variable can be set at build time
4//! to write generated files to the `target/` directory.
5
6use std::collections::HashSet;
7use std::env;
8
9use proc_macro::Delimiter;
10use virtue::generate::FnSelfArg;
11use virtue::parse::{Attribute, AttributeLocation, EnumBody, StructBody};
12use virtue::utils::{parse_tagged_attribute, ParsedAttribute};
13use virtue::prelude::*;
14
/// Build-time environment variable checked by both derives: when
/// `env::var` succeeds for it, the generated code is also written out with
/// `Generator::export_to_file`.
const ENV_SSHWIRE_DEBUG: &str = "SSHWIRE_DEBUG";
16
17#[proc_macro_derive(SSHEncode, attributes(sshwire))]
18pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19    encode_inner(input).unwrap_or_else(|e| e.into_token_stream())
20}
21
22#[proc_macro_derive(SSHDecode, attributes(sshwire))]
23pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    decode_inner(input).unwrap_or_else(|e| e.into_token_stream())
25}
26
27fn encode_inner(input: TokenStream) -> Result<TokenStream> {
28    let parse = Parse::new(input)?;
29    let (mut gen, att, body) = parse.into_generator();
30    // println!("att {att:#?}");
31    match body {
32        Body::Struct(body) => {
33            encode_struct(&mut gen, body)?;
34        }
35        Body::Enum(body) => {
36            encode_enum(&mut gen, &att, body)?;
37        }
38    }
39    if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
40        gen.export_to_file("sshwire", "SSHEncode");
41    }
42    gen.finish()
43}
44
45fn decode_inner(input: TokenStream) -> Result<TokenStream> {
46    let parse = Parse::new(input)?;
47    let (mut gen, att, body) = parse.into_generator();
48    // println!("att {att:#?}");
49    match body {
50        Body::Struct(body) => {
51            decode_struct(&mut gen, body)?;
52        }
53        Body::Enum(body) => {
54            decode_enum(&mut gen, &att, body)?;
55        }
56    }
57    if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
58        gen.export_to_file("sshwire", "SSHDecode");
59    }
60    gen.finish()
61}
62
/// Container-level `#[sshwire(...)]` options, given on an enum.
#[derive(Debug)]
enum ContainerAtt {
    /// `#[sshwire(no_variant_names)]`: skip generating `SSHEncodeEnum` for the
    /// enum. The `SSHDecode` derive rejects this option.
    NoNames,

    /// `#[sshwire(variant_prefix)]`: the variant's name string is placed on the
    /// wire ahead of the variant's contents. The `SSHEncode` derive writes it
    /// first, and the `SSHDecode` derive generates an `SSHDecode` impl that
    /// reads it first.
    VariantPrefix,
}
73
74#[derive(Debug)]
75enum FieldAtt {
76    /// A variant method name will be encoded/decoded before the next field.
77    /// eg `#[sshwire(variant_name = ch)]` for `ChannelRequest`
78    VariantName(Ident),
79
80    /// Any unknown variant name should be recorded here.
81    /// This variant can't be written out.
82    /// `#[sshwire(unknown))]`
83    CaptureUnknown,
84
85    /// The name of a variant, used by the parent struct
86    /// `#[sshwire(variant = "exit-signal"))]`
87    /// or
88    /// `#[sshwire(variant = SSH_NAME_IDENT))]`
89    Variant(TokenTree),
90}
91
92fn take_cont_atts(atts: &[Attribute]) -> Result<Vec<ContainerAtt>> {
93    let x = atts.iter()
94        .filter_map(|a| {
95            parse_tagged_attribute(&a.tokens, "sshwire")
96            .transpose()
97        });
98
99    let mut ret = vec![];
100    // flatten the lists
101    for a in x {
102        for a in a? {
103            let l = match a {
104                ParsedAttribute::Tag(l) if l.to_string() == "no_variant_names" => Ok(ContainerAtt::NoNames),
105                ParsedAttribute::Tag(l) if l.to_string() == "variant_prefix" => Ok(ContainerAtt::VariantPrefix),
106                _ => Err(Error::Custom {
107                    error: "Unknown sshwire atttribute".into(),
108                    span: None,
109                }),
110            }?;
111            ret.push(l);
112        }
113    }
114    Ok(ret)
115}
116
117// TODO: we could use virtue parse_tagged_attribute() though it doesn't support Literals
118fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> {
119    atts.iter()
120        .filter_map(|a| {
121            match a.location {
122                AttributeLocation::Field | AttributeLocation::Variant => {
123                    let mut s = a.tokens.stream().into_iter();
124                    if &s.next().expect("missing attribute name").to_string()
125                        != "sshwire"
126                    {
127                        // skip attributes other than "sshwire"
128                        return None;
129                    }
130                    Some(if let Some(TokenTree::Group(g)) = s.next() {
131                        let mut g = g.stream().into_iter();
132                        let f = match g.next() {
133                            Some(TokenTree::Ident(l))
134                                if l.to_string() == "variant_name" =>
135                            {
136                                // check for '='
137                                match g.next() {
138                                    Some(TokenTree::Punct(p)) if p == '=' => (),
139                                    _ => {
140                                        return Some(Err(Error::Custom {
141                                            error: "Missing '='".into(),
142                                            span: Some(a.tokens.span()),
143                                        }))
144                                    }
145                                }
146                                match g.next() {
147                                    Some(TokenTree::Ident(i)) => {
148                                        Ok(FieldAtt::VariantName(i))
149                                    }
150                                    _ => Err(Error::ExpectedIdent(a.tokens.span())),
151                                }
152                            }
153
154                            Some(TokenTree::Ident(l))
155                                if l.to_string() == "unknown" =>
156                            {
157                                Ok(FieldAtt::CaptureUnknown)
158                            }
159
160                            Some(TokenTree::Ident(l))
161                                if l.to_string() == "variant" =>
162                            {
163                                // check for '='
164                                match g.next() {
165                                    Some(TokenTree::Punct(p)) if p == '=' => (),
166                                    _ => {
167                                        return Some(Err(Error::Custom {
168                                            error: "Missing '='".into(),
169                                            span: Some(a.tokens.span()),
170                                        }))
171                                    }
172                                }
173                                if let Some(t) = g.next() {
174                                    Ok(FieldAtt::Variant(t))
175                                } else {
176                                    Err(Error::Custom {
177                                        error: "Missing expression".into(),
178                                        span: Some(a.tokens.span()),
179                                    })
180                                }
181                            }
182
183                            _ => Err(Error::Custom {
184                                error: "Unknown sshwire atttribute".into(),
185                                span: Some(a.tokens.span()),
186                            }),
187                        };
188
189                        if g.next().is_some() {
190                            Err(Error::Custom {
191                                error: "Extra unhandled parts".into(),
192                                span: Some(a.tokens.span()),
193                            })
194                        } else {
195                            f
196                        }
197                    } else {
198                        Err(Error::Custom {
199                            error: "#[sshwire(...)] attribute is missing (...) part"
200                                .into(),
201                            span: Some(a.tokens.span()),
202                        })
203                    })
204                }
205                _ => panic!("Non-field attribute for field: {a:#?}"),
206            }
207        })
208        .collect()
209}
210
211fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
212    gen.impl_for("crate::sshwire::SSHEncode")
213        .generate_fn("enc")
214        .with_self_arg(FnSelfArg::RefSelf)
215        .with_arg("s", "&mut dyn crate::sshwire::SSHSink")
216        .with_return_type("crate::sshwire::WireResult<()>")
217        .body(|fn_body| {
218            match &body.fields {
219                Some(Fields::Tuple(v)) => {
220                    for (fname, f) in v.iter().enumerate() {
221                        // we're only using single elements for newtype, don't bother with atts for now
222                        if !f.attributes.is_empty() {
223                            return Err(Error::Custom { error: "Attributes aren't allowed for tuple structs".into(), span: Some(f.span()) })
224                        }
225                        fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
226                    }
227                }
228                Some(Fields::Struct(v)) => {
229                    for f in v {
230                        let fname = &f.0;
231                        let atts = take_field_atts(&f.1.attributes)?;
232                        for a in atts {
233                            if let FieldAtt::VariantName(enum_field) = a {
234                                // encode an enum field's variant name before this field
235                                fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{enum_field}.variant_name()?, s)?;"))?;
236                            }
237                        }
238                        fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
239                    }
240
241                }
242                None => {
243                    // nothing to do.
244                    // either an empty braced struct or a unit struct.
245                }
246
247            }
248            fn_body.push_parsed("Ok(())")?;
249            Ok(())
250        })?;
251    Ok(())
252}
253
254fn encode_enum(
255    gen: &mut Generator,
256    atts: &[Attribute],
257    body: EnumBody,
258) -> Result<()> {
259
260    let cont_atts = take_cont_atts(atts)?;
261
262    gen.impl_for("crate::sshwire::SSHEncode")
263        .generate_fn("enc")
264        .with_self_arg(FnSelfArg::RefSelf)
265        .with_arg("s", "&mut dyn crate::sshwire::SSHSink")
266        .with_return_type("crate::sshwire::WireResult<()>")
267        .body(|fn_body| {
268            if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
269                fn_body.push_parsed("crate::sshwire::SSHEncode::enc(&self.variant_name()?, s)?;")?;
270            }
271
272            fn_body.ident_str("match");
273            fn_body.puncts("*");
274            fn_body.ident_str("self");
275            fn_body.group(Delimiter::Brace, |match_arm| {
276                for var in &body.variants {
277                    match_arm.ident_str("Self");
278                    match_arm.puncts("::");
279                    match_arm.ident(var.name.clone());
280
281                    let atts = take_field_atts(&var.attributes)?;
282
283                    let mut rhs = StreamBuilder::new();
284                    match var.fields {
285                        None => {
286                            // Unit enum
287                        }
288                        Some(Fields::Tuple(ref f)) if f.len() == 1 => {
289                            match_arm.group(Delimiter::Parenthesis, |item| {
290                                item.ident_str("ref");
291                                item.ident_str("i");
292                                Ok(())
293                            })?;
294                            if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
295                                rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?;
296                            } else {
297                                rhs.push_parsed(format!("crate::sshwire::SSHEncode::enc(i, s)?;"))?;
298                            }
299
300                        }
301                        _ => return Err(Error::Custom { error: "SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
302                    }
303
304                    match_arm.puncts("=>");
305                    match_arm.group(Delimiter::Brace, |var_body| {
306                        var_body.append(rhs);
307                        Ok(())
308                    })?;
309                }
310                Ok(())
311            })?;
312            // an enum with only an Unknown variant will always return an earlier error
313            fn_body.push_parsed("#[allow(unreachable_code)]")?;
314            fn_body.push_parsed("Ok(())")?;
315            Ok(())
316        })?;
317
318    if !cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
319        encode_enum_names(gen, atts, body)?;
320    }
321    Ok(())
322}
323
324fn field_att_var_names(name: &Ident, mut atts: Vec<FieldAtt>) -> Result<TokenTree> {
325    let mut v = vec![];
326    while let Some(p) = atts.pop() {
327        if let FieldAtt::Variant(t) = p {
328            v.push(t);
329        }
330    }
331    if v.len() != 1 {
332        return Err(Error::Custom { error: format!("One #[sshwire(variant = ...)] attribute is required for each enum field, missing for {:?}", name), span: None});
333    }
334    Ok(v.pop().unwrap())
335}
336
337fn encode_enum_names(
338    gen: &mut Generator,
339    _atts: &[Attribute],
340    body: EnumBody,
341) -> Result<()> {
342    gen.impl_for("crate::sshwire::SSHEncodeEnum")
343        .generate_fn("variant_name")
344        .with_self_arg(FnSelfArg::RefSelf)
345        .with_return_type("crate::sshwire::WireResult<&'static str>")
346        .body(|fn_body| {
347            fn_body.push_parsed("let r = match self")?;
348            fn_body.group(Delimiter::Brace, |match_arm| {
349                for var in &body.variants {
350                    match_arm.ident_str("Self");
351                    match_arm.puncts("::");
352                    match_arm.ident(var.name.clone());
353
354                    let mut rhs = StreamBuilder::new();
355                    let atts = take_field_atts(&var.attributes)?;
356                    if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
357                        rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?;
358                    } else {
359                        rhs.push(field_att_var_names(&var.name, atts)?);
360                    }
361
362                    match var.fields {
363                        None => {
364                            // nothing to do
365                        }
366                        Some(Fields::Tuple(ref f)) if f.len() == 1 => {
367                            match_arm.group(Delimiter::Parenthesis, |item| {
368                                item.ident_str("_");
369                                Ok(())
370                            })?;
371
372                        }
373                        _ => return Err(Error::Custom { error: "SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
374                    }
375
376                    match_arm.puncts("=>");
377                    match_arm.group(Delimiter::Brace, |var_body| {
378                        var_body.append(rhs);
379                        Ok(())
380                    })?;
381                }
382                Ok(())
383            })?;
384            fn_body.push_parsed(";")?;
385            // an enum with only an Unknown variant will always return an earlier error
386            fn_body.push_parsed("#[allow(unreachable_code)]")?;
387            fn_body.push_parsed("Ok(r)")?;
388
389            Ok(())
390        })?;
391
392    Ok(())
393}
394
395fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
396    gen.impl_for_with_lifetimes("crate::sshwire::SSHDecode", ["de"])
397        .modify_generic_constraints(|generics, where_constraints| {
398            for lt in generics.iter_lifetimes() {
399                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
400            }
401            Ok(())
402        })?
403        .generate_fn("dec")
404        .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
405        .with_arg("s", "&mut S")
406        .with_return_type("crate::sshwire::WireResult<Self>")
407        .body(|fn_body| {
408            let mut named_enums = HashSet::new();
409            if let Some(Fields::Struct(v)) = &body.fields {
410                for f in v {
411                    let atts = take_field_atts(&f.1.attributes)?;
412                    for a in atts {
413                        if let FieldAtt::VariantName(enum_field) = a {
414                            // Read the extra field on the wire that isn't directly included in the struct
415                            named_enums.insert(enum_field.to_string());
416                            fn_body.push_parsed(format!("let enum_name_{enum_field}: BinString = crate::sshwire::SSHDecode::dec(s)?;"))?;
417                        }
418                    }
419                    let fname = &f.0;
420                    if named_enums.contains(&fname.to_string()) {
421                        fn_body.push_parsed(format!("let field_{fname} =  crate::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname}.0)?;"))?;
422                    } else {
423                        fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecode::dec(s)?;"))?;
424                    }
425                }
426            }
427            fn_body.ident_str("Ok");
428            fn_body.group(Delimiter::Parenthesis, |fn_body| {
429                match &body.fields {
430                    Some(Fields::Tuple(f)) => {
431                        // we don't handle attributes for Tuple Structs - only use as newtype
432                        fn_body.ident_str("Self");
433                        fn_body.group(Delimiter::Parenthesis, |args| {
434                            for _ in f.iter() {
435                                args.push_parsed(format!("crate::sshwire::SSHDecode::dec(s)?,"))?;
436                            }
437                            Ok(())
438                        })?;
439                    }
440                    Some(Fields::Struct(v)) => {
441                        fn_body.ident_str("Self");
442                        fn_body.group(Delimiter::Brace, |args| {
443                            for f in v {
444                                let fname = &f.0;
445                                args.push_parsed(format!("{fname}: field_{fname},"))?;
446                            }
447                            Ok(())
448                        })?;
449                    }
450                    None => {
451                        // An empty struct (or unit or empty tuple-struct)
452                        fn_body.ident_str("Self");
453                        fn_body.group(Delimiter::Brace, |_| Ok(()))?;
454                    }
455                }
456                Ok(())
457            })?;
458            Ok(())
459        })?;
460    Ok(())
461}
462
463fn decode_enum(
464    gen: &mut Generator,
465    atts: &[Attribute],
466    body: EnumBody,
467) -> Result<()> {
468    let cont_atts = take_cont_atts(atts)?;
469
470    if cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
471        return Err(Error::Custom {
472            error:
473                "SSHDecode derive can't be used with #[sshwire(no_variant_names)]"
474                    .into(),
475            span: None,
476        });
477    }
478
479    // SSHDecode trait if it is self describing
480    if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
481        decode_enum_variant_prefix(gen, atts, &body)?;
482    }
483
484    decode_enum_names(gen, atts, &body)?;
485    Ok(())
486}
487
488fn decode_enum_variant_prefix(
489    gen: &mut Generator,
490    _atts: &[Attribute],
491    _body: &EnumBody,
492) -> Result<()> {
493    gen.impl_for_with_lifetimes("crate::sshwire::SSHDecode", ["de"])
494        .modify_generic_constraints(|generics, where_constraints| {
495            for lt in generics.iter_lifetimes() {
496                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
497            }
498            Ok(())
499        })?
500        .generate_fn("dec")
501        .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
502        .with_arg("s", "&mut S")
503        .with_return_type("crate::sshwire::WireResult<Self>")
504        .body(|fn_body| {
505            fn_body
506                .push_parsed("let variant: crate::sshwire::BinString = crate::sshwire::SSHDecode::dec(s)?;")?;
507            fn_body.push_parsed(
508                "crate::sshwire::SSHDecodeEnum::dec_enum(s, variant.0)",
509            )?;
510            Ok(())
511        })
512}
513
514fn decode_enum_names(
515    gen: &mut Generator,
516    _atts: &[Attribute],
517    body: &EnumBody,
518) -> Result<()> {
519    gen.impl_for_with_lifetimes("crate::sshwire::SSHDecodeEnum", ["de"])
520        .modify_generic_constraints(|generics, where_constraints| {
521            for lt in generics.iter_lifetimes() {
522                where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
523            }
524            Ok(())
525        })?
526        .generate_fn("dec_enum")
527        .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
528        .with_arg("s", "&mut S")
529        .with_arg("variant", "&'de [u8]")
530        .with_return_type("crate::sshwire::WireResult<Self>")
531        .body(|fn_body| {
532            // Some(ascii_string), or None
533            fn_body.push_parsed("let var_str = crate::sshwire::try_as_ascii_str(variant).ok();")?;
534
535            fn_body.push_parsed("let r = match var_str")?;
536            fn_body.group(Delimiter::Brace, |match_arm| {
537                let mut unknown_arm = None;
538                for var in &body.variants {
539                    let atts = take_field_atts(&var.attributes)?;
540                    if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
541                        // create the Unknown fallthrough but it will be at the end of the match list
542                        let mut m = StreamBuilder::new();
543                        m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown::new(variant))}}", var.name))?;
544                        if unknown_arm.replace(m).is_some() {
545                            return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None})
546                        }
547                    } else {
548                        let var_name = field_att_var_names(&var.name, atts)?;
549                        match_arm.push_parsed(format!("Some({}) => ", var_name))?;
550                        match_arm.group(Delimiter::Brace, |var_body| {
551                            match var.fields {
552                                None => {
553                                    var_body.push_parsed(format!("Self::{}", var.name))?;
554                                }
555                                Some(Fields::Tuple(ref f)) if f.len() == 1 => {
556                                    var_body.push_parsed(format!("Self::{}(crate::sshwire::SSHDecode::dec(s)?)", var.name))?;
557                                }
558                            _ => return Err(Error::Custom { error: "SSHDecode currently only implements Unit or single value enum variants. ".into(), span: None})
559                            }
560                            Ok(())
561                        })?;
562
563                    }
564                    if let Some(unk) = unknown_arm.take() {
565                        match_arm.append(unk);
566                    }
567                }
568                Ok(())
569            })?;
570            fn_body.push_parsed("; Ok(r)")?;
571            Ok(())
572        })?;
573    Ok(())
574}