1use std::collections::HashSet;
7use std::env;
8
9use proc_macro::Delimiter;
10use virtue::generate::FnSelfArg;
11use virtue::parse::{Attribute, AttributeLocation, EnumBody, StructBody};
12use virtue::utils::{parse_tagged_attribute, ParsedAttribute};
13use virtue::prelude::*;
14
/// Name of the environment variable that, when set to a valid Unicode value
/// (so `env::var` succeeds), makes the derives also export their generated
/// code with `Generator::export_to_file` for inspection.
const ENV_SSHWIRE_DEBUG: &str = "SSHWIRE_DEBUG";
16
17#[proc_macro_derive(SSHEncode, attributes(sshwire))]
18pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19 encode_inner(input).unwrap_or_else(|e| e.into_token_stream())
20}
21
22#[proc_macro_derive(SSHDecode, attributes(sshwire))]
23pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 decode_inner(input).unwrap_or_else(|e| e.into_token_stream())
25}
26
27fn encode_inner(input: TokenStream) -> Result<TokenStream> {
28 let parse = Parse::new(input)?;
29 let (mut gen, att, body) = parse.into_generator();
30 match body {
32 Body::Struct(body) => {
33 encode_struct(&mut gen, body)?;
34 }
35 Body::Enum(body) => {
36 encode_enum(&mut gen, &att, body)?;
37 }
38 }
39 if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
40 gen.export_to_file("sshwire", "SSHEncode");
41 }
42 gen.finish()
43}
44
45fn decode_inner(input: TokenStream) -> Result<TokenStream> {
46 let parse = Parse::new(input)?;
47 let (mut gen, att, body) = parse.into_generator();
48 match body {
50 Body::Struct(body) => {
51 decode_struct(&mut gen, body)?;
52 }
53 Body::Enum(body) => {
54 decode_enum(&mut gen, &att, body)?;
55 }
56 }
57 if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
58 gen.export_to_file("sshwire", "SSHDecode");
59 }
60 gen.finish()
61}
62
/// Container-level `#[sshwire(...)]` options on an enum, as parsed by
/// `take_cont_atts`.
#[derive(Debug)]
enum ContainerAtt {
    /// `#[sshwire(variant_prefix)]`: the encoding begins with the variant
    /// name. The derived `enc` writes `variant_name()` first, and a
    /// derived `SSHDecode` reads a `BinString` name and passes it to
    /// `SSHDecodeEnum::dec_enum`.
    VariantPrefix,

    /// `#[sshwire(no_variant_names)]`: don't derive `SSHEncodeEnum`.
    /// Rejected when deriving `SSHDecode`, which needs the variant names.
    NoNames,
}
73
/// Field- or variant-level `#[sshwire(...)]` options, as parsed by
/// `take_field_atts`.
#[derive(Debug)]
enum FieldAtt {
    /// `#[sshwire(variant_name = field)]` on a struct field. Just before this
    /// field, the variant name of the enum stored in `field` is encoded.
    /// When decoding, that name is read here and later used to decode
    /// `field` with `SSHDecodeEnum::dec_enum`.
    VariantName(Ident),

    /// `#[sshwire(unknown)]` on an enum variant. Decoding an unrecognised
    /// name produces this variant. `variant_name()` returns
    /// `WireError::UnknownVariant` for it, as does `enc` for a single-field
    /// variant.
    CaptureUnknown,

    /// `#[sshwire(variant = ...)]` on an enum variant: the single token used
    /// as that variant's wire name. It is emitted by `variant_name()` and
    /// matched as `Some(<token>)` when decoding.
    Variant(TokenTree),
}
91
92fn take_cont_atts(atts: &[Attribute]) -> Result<Vec<ContainerAtt>> {
93 let x = atts.iter()
94 .filter_map(|a| {
95 parse_tagged_attribute(&a.tokens, "sshwire")
96 .transpose()
97 });
98
99 let mut ret = vec![];
100 for a in x {
102 for a in a? {
103 let l = match a {
104 ParsedAttribute::Tag(l) if l.to_string() == "no_variant_names" => Ok(ContainerAtt::NoNames),
105 ParsedAttribute::Tag(l) if l.to_string() == "variant_prefix" => Ok(ContainerAtt::VariantPrefix),
106 _ => Err(Error::Custom {
107 error: "Unknown sshwire atttribute".into(),
108 span: None,
109 }),
110 }?;
111 ret.push(l);
112 }
113 }
114 Ok(ret)
115}
116
117fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> {
119 atts.iter()
120 .filter_map(|a| {
121 match a.location {
122 AttributeLocation::Field | AttributeLocation::Variant => {
123 let mut s = a.tokens.stream().into_iter();
124 if &s.next().expect("missing attribute name").to_string()
125 != "sshwire"
126 {
127 return None;
129 }
130 Some(if let Some(TokenTree::Group(g)) = s.next() {
131 let mut g = g.stream().into_iter();
132 let f = match g.next() {
133 Some(TokenTree::Ident(l))
134 if l.to_string() == "variant_name" =>
135 {
136 match g.next() {
138 Some(TokenTree::Punct(p)) if p == '=' => (),
139 _ => {
140 return Some(Err(Error::Custom {
141 error: "Missing '='".into(),
142 span: Some(a.tokens.span()),
143 }))
144 }
145 }
146 match g.next() {
147 Some(TokenTree::Ident(i)) => {
148 Ok(FieldAtt::VariantName(i))
149 }
150 _ => Err(Error::ExpectedIdent(a.tokens.span())),
151 }
152 }
153
154 Some(TokenTree::Ident(l))
155 if l.to_string() == "unknown" =>
156 {
157 Ok(FieldAtt::CaptureUnknown)
158 }
159
160 Some(TokenTree::Ident(l))
161 if l.to_string() == "variant" =>
162 {
163 match g.next() {
165 Some(TokenTree::Punct(p)) if p == '=' => (),
166 _ => {
167 return Some(Err(Error::Custom {
168 error: "Missing '='".into(),
169 span: Some(a.tokens.span()),
170 }))
171 }
172 }
173 if let Some(t) = g.next() {
174 Ok(FieldAtt::Variant(t))
175 } else {
176 Err(Error::Custom {
177 error: "Missing expression".into(),
178 span: Some(a.tokens.span()),
179 })
180 }
181 }
182
183 _ => Err(Error::Custom {
184 error: "Unknown sshwire atttribute".into(),
185 span: Some(a.tokens.span()),
186 }),
187 };
188
189 if g.next().is_some() {
190 Err(Error::Custom {
191 error: "Extra unhandled parts".into(),
192 span: Some(a.tokens.span()),
193 })
194 } else {
195 f
196 }
197 } else {
198 Err(Error::Custom {
199 error: "#[sshwire(...)] attribute is missing (...) part"
200 .into(),
201 span: Some(a.tokens.span()),
202 })
203 })
204 }
205 _ => panic!("Non-field attribute for field: {a:#?}"),
206 }
207 })
208 .collect()
209}
210
211fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
212 gen.impl_for("crate::sshwire::SSHEncode")
213 .generate_fn("enc")
214 .with_self_arg(FnSelfArg::RefSelf)
215 .with_arg("s", "&mut dyn crate::sshwire::SSHSink")
216 .with_return_type("crate::sshwire::WireResult<()>")
217 .body(|fn_body| {
218 match &body.fields {
219 Some(Fields::Tuple(v)) => {
220 for (fname, f) in v.iter().enumerate() {
221 if !f.attributes.is_empty() {
223 return Err(Error::Custom { error: "Attributes aren't allowed for tuple structs".into(), span: Some(f.span()) })
224 }
225 fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
226 }
227 }
228 Some(Fields::Struct(v)) => {
229 for f in v {
230 let fname = &f.0;
231 let atts = take_field_atts(&f.1.attributes)?;
232 for a in atts {
233 if let FieldAtt::VariantName(enum_field) = a {
234 fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{enum_field}.variant_name()?, s)?;"))?;
236 }
237 }
238 fn_body.push_parsed(format!("crate::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
239 }
240
241 }
242 None => {
243 }
246
247 }
248 fn_body.push_parsed("Ok(())")?;
249 Ok(())
250 })?;
251 Ok(())
252}
253
254fn encode_enum(
255 gen: &mut Generator,
256 atts: &[Attribute],
257 body: EnumBody,
258) -> Result<()> {
259
260 let cont_atts = take_cont_atts(atts)?;
261
262 gen.impl_for("crate::sshwire::SSHEncode")
263 .generate_fn("enc")
264 .with_self_arg(FnSelfArg::RefSelf)
265 .with_arg("s", "&mut dyn crate::sshwire::SSHSink")
266 .with_return_type("crate::sshwire::WireResult<()>")
267 .body(|fn_body| {
268 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
269 fn_body.push_parsed("crate::sshwire::SSHEncode::enc(&self.variant_name()?, s)?;")?;
270 }
271
272 fn_body.ident_str("match");
273 fn_body.puncts("*");
274 fn_body.ident_str("self");
275 fn_body.group(Delimiter::Brace, |match_arm| {
276 for var in &body.variants {
277 match_arm.ident_str("Self");
278 match_arm.puncts("::");
279 match_arm.ident(var.name.clone());
280
281 let atts = take_field_atts(&var.attributes)?;
282
283 let mut rhs = StreamBuilder::new();
284 match var.fields {
285 None => {
286 }
288 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
289 match_arm.group(Delimiter::Parenthesis, |item| {
290 item.ident_str("ref");
291 item.ident_str("i");
292 Ok(())
293 })?;
294 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
295 rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?;
296 } else {
297 rhs.push_parsed(format!("crate::sshwire::SSHEncode::enc(i, s)?;"))?;
298 }
299
300 }
301 _ => return Err(Error::Custom { error: "SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
302 }
303
304 match_arm.puncts("=>");
305 match_arm.group(Delimiter::Brace, |var_body| {
306 var_body.append(rhs);
307 Ok(())
308 })?;
309 }
310 Ok(())
311 })?;
312 fn_body.push_parsed("#[allow(unreachable_code)]")?;
314 fn_body.push_parsed("Ok(())")?;
315 Ok(())
316 })?;
317
318 if !cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
319 encode_enum_names(gen, atts, body)?;
320 }
321 Ok(())
322}
323
324fn field_att_var_names(name: &Ident, mut atts: Vec<FieldAtt>) -> Result<TokenTree> {
325 let mut v = vec![];
326 while let Some(p) = atts.pop() {
327 if let FieldAtt::Variant(t) = p {
328 v.push(t);
329 }
330 }
331 if v.len() != 1 {
332 return Err(Error::Custom { error: format!("One #[sshwire(variant = ...)] attribute is required for each enum field, missing for {:?}", name), span: None});
333 }
334 Ok(v.pop().unwrap())
335}
336
337fn encode_enum_names(
338 gen: &mut Generator,
339 _atts: &[Attribute],
340 body: EnumBody,
341) -> Result<()> {
342 gen.impl_for("crate::sshwire::SSHEncodeEnum")
343 .generate_fn("variant_name")
344 .with_self_arg(FnSelfArg::RefSelf)
345 .with_return_type("crate::sshwire::WireResult<&'static str>")
346 .body(|fn_body| {
347 fn_body.push_parsed("let r = match self")?;
348 fn_body.group(Delimiter::Brace, |match_arm| {
349 for var in &body.variants {
350 match_arm.ident_str("Self");
351 match_arm.puncts("::");
352 match_arm.ident(var.name.clone());
353
354 let mut rhs = StreamBuilder::new();
355 let atts = take_field_atts(&var.attributes)?;
356 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
357 rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?;
358 } else {
359 rhs.push(field_att_var_names(&var.name, atts)?);
360 }
361
362 match var.fields {
363 None => {
364 }
366 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
367 match_arm.group(Delimiter::Parenthesis, |item| {
368 item.ident_str("_");
369 Ok(())
370 })?;
371
372 }
373 _ => return Err(Error::Custom { error: "SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
374 }
375
376 match_arm.puncts("=>");
377 match_arm.group(Delimiter::Brace, |var_body| {
378 var_body.append(rhs);
379 Ok(())
380 })?;
381 }
382 Ok(())
383 })?;
384 fn_body.push_parsed(";")?;
385 fn_body.push_parsed("#[allow(unreachable_code)]")?;
387 fn_body.push_parsed("Ok(r)")?;
388
389 Ok(())
390 })?;
391
392 Ok(())
393}
394
395fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
396 gen.impl_for_with_lifetimes("crate::sshwire::SSHDecode", ["de"])
397 .modify_generic_constraints(|generics, where_constraints| {
398 for lt in generics.iter_lifetimes() {
399 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
400 }
401 Ok(())
402 })?
403 .generate_fn("dec")
404 .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
405 .with_arg("s", "&mut S")
406 .with_return_type("crate::sshwire::WireResult<Self>")
407 .body(|fn_body| {
408 let mut named_enums = HashSet::new();
409 if let Some(Fields::Struct(v)) = &body.fields {
410 for f in v {
411 let atts = take_field_atts(&f.1.attributes)?;
412 for a in atts {
413 if let FieldAtt::VariantName(enum_field) = a {
414 named_enums.insert(enum_field.to_string());
416 fn_body.push_parsed(format!("let enum_name_{enum_field}: BinString = crate::sshwire::SSHDecode::dec(s)?;"))?;
417 }
418 }
419 let fname = &f.0;
420 if named_enums.contains(&fname.to_string()) {
421 fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname}.0)?;"))?;
422 } else {
423 fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecode::dec(s)?;"))?;
424 }
425 }
426 }
427 fn_body.ident_str("Ok");
428 fn_body.group(Delimiter::Parenthesis, |fn_body| {
429 match &body.fields {
430 Some(Fields::Tuple(f)) => {
431 fn_body.ident_str("Self");
433 fn_body.group(Delimiter::Parenthesis, |args| {
434 for _ in f.iter() {
435 args.push_parsed(format!("crate::sshwire::SSHDecode::dec(s)?,"))?;
436 }
437 Ok(())
438 })?;
439 }
440 Some(Fields::Struct(v)) => {
441 fn_body.ident_str("Self");
442 fn_body.group(Delimiter::Brace, |args| {
443 for f in v {
444 let fname = &f.0;
445 args.push_parsed(format!("{fname}: field_{fname},"))?;
446 }
447 Ok(())
448 })?;
449 }
450 None => {
451 fn_body.ident_str("Self");
453 fn_body.group(Delimiter::Brace, |_| Ok(()))?;
454 }
455 }
456 Ok(())
457 })?;
458 Ok(())
459 })?;
460 Ok(())
461}
462
463fn decode_enum(
464 gen: &mut Generator,
465 atts: &[Attribute],
466 body: EnumBody,
467) -> Result<()> {
468 let cont_atts = take_cont_atts(atts)?;
469
470 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
471 return Err(Error::Custom {
472 error:
473 "SSHDecode derive can't be used with #[sshwire(no_variant_names)]"
474 .into(),
475 span: None,
476 });
477 }
478
479 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
481 decode_enum_variant_prefix(gen, atts, &body)?;
482 }
483
484 decode_enum_names(gen, atts, &body)?;
485 Ok(())
486}
487
488fn decode_enum_variant_prefix(
489 gen: &mut Generator,
490 _atts: &[Attribute],
491 _body: &EnumBody,
492) -> Result<()> {
493 gen.impl_for_with_lifetimes("crate::sshwire::SSHDecode", ["de"])
494 .modify_generic_constraints(|generics, where_constraints| {
495 for lt in generics.iter_lifetimes() {
496 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
497 }
498 Ok(())
499 })?
500 .generate_fn("dec")
501 .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
502 .with_arg("s", "&mut S")
503 .with_return_type("crate::sshwire::WireResult<Self>")
504 .body(|fn_body| {
505 fn_body
506 .push_parsed("let variant: crate::sshwire::BinString = crate::sshwire::SSHDecode::dec(s)?;")?;
507 fn_body.push_parsed(
508 "crate::sshwire::SSHDecodeEnum::dec_enum(s, variant.0)",
509 )?;
510 Ok(())
511 })
512}
513
514fn decode_enum_names(
515 gen: &mut Generator,
516 _atts: &[Attribute],
517 body: &EnumBody,
518) -> Result<()> {
519 gen.impl_for_with_lifetimes("crate::sshwire::SSHDecodeEnum", ["de"])
520 .modify_generic_constraints(|generics, where_constraints| {
521 for lt in generics.iter_lifetimes() {
522 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
523 }
524 Ok(())
525 })?
526 .generate_fn("dec_enum")
527 .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"])
528 .with_arg("s", "&mut S")
529 .with_arg("variant", "&'de [u8]")
530 .with_return_type("crate::sshwire::WireResult<Self>")
531 .body(|fn_body| {
532 fn_body.push_parsed("let var_str = crate::sshwire::try_as_ascii_str(variant).ok();")?;
534
535 fn_body.push_parsed("let r = match var_str")?;
536 fn_body.group(Delimiter::Brace, |match_arm| {
537 let mut unknown_arm = None;
538 for var in &body.variants {
539 let atts = take_field_atts(&var.attributes)?;
540 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
541 let mut m = StreamBuilder::new();
543 m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown::new(variant))}}", var.name))?;
544 if unknown_arm.replace(m).is_some() {
545 return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None})
546 }
547 } else {
548 let var_name = field_att_var_names(&var.name, atts)?;
549 match_arm.push_parsed(format!("Some({}) => ", var_name))?;
550 match_arm.group(Delimiter::Brace, |var_body| {
551 match var.fields {
552 None => {
553 var_body.push_parsed(format!("Self::{}", var.name))?;
554 }
555 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
556 var_body.push_parsed(format!("Self::{}(crate::sshwire::SSHDecode::dec(s)?)", var.name))?;
557 }
558 _ => return Err(Error::Custom { error: "SSHDecode currently only implements Unit or single value enum variants. ".into(), span: None})
559 }
560 Ok(())
561 })?;
562
563 }
564 if let Some(unk) = unknown_arm.take() {
565 match_arm.append(unk);
566 }
567 }
568 Ok(())
569 })?;
570 fn_body.push_parsed("; Ok(r)")?;
571 Ok(())
572 })?;
573 Ok(())
574}