Skip to main content

typeway_macros/
lib.rs

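//! `typeway-macros` — proc macros for the Typeway web framework.
//!
//! Provides `typeway_path!` for ergonomic path type construction,
//! `typeway_api!` for defining complete API types with inline routes,
//! `path!` for referencing path markers that were already generated,
//! `endpoint!` for builder-style endpoint types,
//! `#[handler]` for validating handler functions at the definition site, and
//! `#[api_description]` for defining an API as a trait.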
6
7use proc_macro::TokenStream;
8use proc_macro2::{Span, TokenStream as TokenStream2};
9use quote::{format_ident, quote};
10use syn::parse::{Parse, ParseStream};
11use syn::{Ident, LitStr, Token, Type};
12
13// ---------------------------------------------------------------------------
14// Path segment parsing (shared between macros)
15// ---------------------------------------------------------------------------
16
17enum PathSegment {
18    Literal(String),
19    Capture(Type),
20}
21
22fn parse_one_segment(input: ParseStream) -> syn::Result<PathSegment> {
23    if input.peek(LitStr) {
24        let lit: LitStr = input.parse()?;
25        let value = lit.value();
26        if value.is_empty() {
27            return Err(syn::Error::new(lit.span(), "path literal cannot be empty"));
28        }
29        if value.contains('/') {
30            return Err(syn::Error::new(
31                lit.span(),
32                "path literal cannot contain '/'; use separate segments",
33            ));
34        }
35        Ok(PathSegment::Literal(value))
36    } else {
37        let ty: Type = input.parse()?;
38        Ok(PathSegment::Capture(ty))
39    }
40}
41
42fn parse_path_segments(input: ParseStream) -> syn::Result<Vec<PathSegment>> {
43    let mut segments = Vec::new();
44    if input.is_empty() || input.peek(Token![;]) {
45        return Ok(segments);
46    }
47    segments.push(parse_one_segment(input)?);
48    while input.peek(Token![/]) {
49        input.parse::<Token![/]>()?;
50        segments.push(parse_one_segment(input)?);
51    }
52    Ok(segments)
53}
54
55/// Build the HCons type chain from path segments.
56/// `mod_path` is the module path prefix for marker types (e.g., `__m::` ).
57fn build_hlist_type(segments: &[PathSegment], mod_path: &TokenStream2) -> TokenStream2 {
58    if segments.is_empty() {
59        return quote! { ::typeway_core::HNil };
60    }
61
62    let head = match &segments[0] {
63        PathSegment::Literal(s) => {
64            let marker = lit_marker_ident(s);
65            quote! { ::typeway_core::Lit<#mod_path #marker> }
66        }
67        PathSegment::Capture(ty) => {
68            quote! { ::typeway_core::Capture<#ty> }
69        }
70    };
71
72    let tail = build_hlist_type(&segments[1..], mod_path);
73    quote! { ::typeway_core::HCons<#head, #tail> }
74}
75
76fn lit_marker_ident(s: &str) -> Ident {
77    format_ident!("__lit_{}", s)
78}
79
80/// Collect unique marker type definitions for all literal segments.
81fn collect_marker_defs(
82    segments: &[PathSegment],
83    seen: &mut std::collections::HashSet<String>,
84) -> Vec<TokenStream2> {
85    let mut defs = Vec::new();
86    for seg in segments {
87        if let PathSegment::Literal(s) = seg {
88            if seen.insert(s.clone()) {
89                let marker = lit_marker_ident(s);
90                let value = s.as_str();
91                defs.push(quote! {
92                    #[allow(non_camel_case_types)]
93                    pub struct #marker;
94                    impl ::typeway_core::LitSegment for #marker {
95                        const VALUE: &'static str = #value;
96                    }
97                });
98            }
99        }
100    }
101    defs
102}
103
104// ---------------------------------------------------------------------------
105// typeway_path! macro
106// ---------------------------------------------------------------------------
107
108/// Defines a path type with auto-generated literal segment markers.
109///
110/// Markers are scoped in a private module to avoid name collisions.
111///
112/// # Syntax
113///
114/// ```ignore
115/// typeway_path!(type UserPath = "users" / u32);
116/// ```
117///
118/// Expands to:
119///
120/// ```ignore
121/// mod __wp_UserPath {
122///     pub struct __lit_users;
123///     impl typeway_core::LitSegment for __lit_users { ... }
124/// }
125/// type UserPath = HCons<Lit<__wp_UserPath::__lit_users>, HCons<Capture<u32>, HNil>>;
126/// ```
127#[proc_macro]
128pub fn typeway_path(input: TokenStream) -> TokenStream {
129    let input = syn::parse_macro_input!(input as TypewayPathInput);
130    let name = &input.name;
131    let vis = &input.vis;
132    let mod_name = format_ident!("__wp_{}", name);
133
134    let mut seen = std::collections::HashSet::new();
135    let marker_defs = collect_marker_defs(&input.segments, &mut seen);
136
137    let mod_path: TokenStream2 = quote! { #mod_name:: };
138    let hlist_type = build_hlist_type(&input.segments, &mod_path);
139
140    quote! {
141        #[doc(hidden)]
142        #[allow(non_snake_case)]
143        mod #mod_name {
144            #(#marker_defs)*
145        }
146        #vis type #name = #hlist_type;
147    }
148    .into()
149}
150
151struct TypewayPathInput {
152    vis: syn::Visibility,
153    name: Ident,
154    segments: Vec<PathSegment>,
155}
156
157impl Parse for TypewayPathInput {
158    fn parse(input: ParseStream) -> syn::Result<Self> {
159        let vis: syn::Visibility = input.parse()?;
160        input.parse::<Token![type]>()?;
161        let name: Ident = input.parse()?;
162        input.parse::<Token![=]>()?;
163        let segments = parse_path_segments(input)?;
164        if input.peek(Token![;]) {
165            input.parse::<Token![;]>()?;
166        }
167        Ok(TypewayPathInput {
168            vis,
169            name,
170            segments,
171        })
172    }
173}
174
175// ---------------------------------------------------------------------------
176// typeway_api! macro
177// ---------------------------------------------------------------------------
178
179/// Defines a complete API type with inline route definitions.
180///
181/// # Syntax
182///
183/// ```ignore
184/// typeway_api! {
185///     type MyAPI = {
186///         GET "users" => Json<Vec<User>>,
187///         GET "users" / u32 => Json<User>,
188///         POST "users" [Json<CreateUser>] => Json<User>,
189///         DELETE "users" / u32 => StatusCode,
190///     };
191/// }
192/// ```
193///
194/// - Methods: `GET`, `POST`, `PUT`, `DELETE`, `PATCH`, `HEAD`, `OPTIONS`
195/// - Request body is specified in `[brackets]` (optional, only for POST/PUT/PATCH)
196/// - Response type follows `=>`
197#[proc_macro]
198pub fn typeway_api(input: TokenStream) -> TokenStream {
199    let input = syn::parse_macro_input!(input as TypewayApiInput);
200    let name = &input.name;
201    let vis = &input.vis;
202    let mod_name = format_ident!("__wa_{}", name);
203
204    let mut seen = std::collections::HashSet::new();
205    let mut all_marker_defs = Vec::new();
206    for route in &input.routes {
207        all_marker_defs.extend(collect_marker_defs(&route.path, &mut seen));
208    }
209
210    let mod_path: TokenStream2 = quote! { #mod_name:: };
211    let mut endpoint_types = Vec::new();
212    for route in &input.routes {
213        let path_type = build_hlist_type(&route.path, &mod_path);
214        let method = method_type_ident(&route.method);
215        let res_type = &route.response;
216
217        let endpoint = if let Some(ref req) = route.request {
218            quote! { ::typeway_core::Endpoint<::typeway_core::#method, #path_type, #req, #res_type> }
219        } else {
220            quote! { ::typeway_core::Endpoint<::typeway_core::#method, #path_type, ::typeway_core::NoBody, #res_type> }
221        };
222        endpoint_types.push(endpoint);
223    }
224
225    quote! {
226        #[doc(hidden)]
227        #[allow(non_snake_case)]
228        mod #mod_name {
229            #(#all_marker_defs)*
230        }
231        #vis type #name = (#(#endpoint_types,)*);
232    }
233    .into()
234}
235
236fn method_type_ident(method: &str) -> Ident {
237    let s = match method.to_uppercase().as_str() {
238        "GET" => "Get",
239        "POST" => "Post",
240        "PUT" => "Put",
241        "DELETE" => "Delete",
242        "PATCH" => "Patch",
243        "HEAD" => "Head",
244        "OPTIONS" => "Options",
245        other => panic!("unknown HTTP method: {other}"),
246    };
247    Ident::new(s, Span::call_site())
248}
249
250struct ApiRoute {
251    method: String,
252    path: Vec<PathSegment>,
253    request: Option<Type>,
254    response: Type,
255}
256
257struct TypewayApiInput {
258    vis: syn::Visibility,
259    name: Ident,
260    routes: Vec<ApiRoute>,
261}
262
263impl Parse for TypewayApiInput {
264    fn parse(input: ParseStream) -> syn::Result<Self> {
265        let vis: syn::Visibility = input.parse()?;
266        input.parse::<Token![type]>()?;
267        let name: Ident = input.parse()?;
268        input.parse::<Token![=]>()?;
269
270        let content;
271        syn::braced!(content in input);
272
273        let mut routes = Vec::new();
274        while !content.is_empty() {
275            routes.push(parse_api_route(&content)?);
276            if content.peek(Token![,]) {
277                content.parse::<Token![,]>()?;
278            }
279        }
280
281        if input.peek(Token![;]) {
282            input.parse::<Token![;]>()?;
283        }
284
285        Ok(TypewayApiInput { vis, name, routes })
286    }
287}
288
289fn parse_api_route(input: ParseStream) -> syn::Result<ApiRoute> {
290    let method_ident: Ident = input.parse()?;
291    let method = method_ident.to_string();
292
293    let mut path = Vec::new();
294    while !input.peek(Token![=>]) && !input.peek(syn::token::Bracket) {
295        if !path.is_empty() {
296            input.parse::<Token![/]>()?;
297        }
298        path.push(parse_one_segment(input)?);
299    }
300
301    let request = if input.peek(syn::token::Bracket) {
302        let bracket_content;
303        syn::bracketed!(bracket_content in input);
304        Some(bracket_content.parse::<Type>()?)
305    } else {
306        None
307    };
308
309    input.parse::<Token![=>]>()?;
310    let response: Type = input.parse()?;
311
312    Ok(ApiRoute {
313        method,
314        path,
315        request,
316        response,
317    })
318}
319
320// ---------------------------------------------------------------------------
321// path! — lightweight type-position macro (for use in type aliases/binds)
322// ---------------------------------------------------------------------------
323
324/// Constructs a path type expression. Unlike `typeway_path!`, this does NOT
325/// generate marker types — it references markers that were already defined
326/// by a `typeway_path!` or `typeway_api!` invocation.
327///
328/// Not recommended for direct use — prefer `typeway_path!` which handles
329/// both marker generation and type definition.
330#[proc_macro]
331pub fn path(input: TokenStream) -> TokenStream {
332    let input = syn::parse_macro_input!(input as PathRefInput);
333    let empty_mod = quote! {};
334    let hlist = build_hlist_type(&input.segments, &empty_mod);
335    hlist.into()
336}
337
338struct PathRefInput {
339    segments: Vec<PathSegment>,
340}
341
342impl Parse for PathRefInput {
343    fn parse(input: ParseStream) -> syn::Result<Self> {
344        let segments = parse_path_segments(input)?;
345        Ok(PathRefInput { segments })
346    }
347}
348
349// ---------------------------------------------------------------------------
350// #[handler] attribute macro
351// ---------------------------------------------------------------------------
352
353/// Validates a handler function at its definition site.
354///
355/// Checks that:
356/// - The function is `async`
357/// - All arguments (except the last) implement `FromRequestParts`
358/// - The last argument implements either `FromRequestParts` or `FromRequest`
359/// - The return type implements `IntoResponse`
360///
361/// The function is emitted unchanged. It already works with `bind` and
362/// the blanket `Handler<Args>` impls. This macro exists purely for early,
363/// readable compile errors instead of cryptic trait-resolution failures at
364/// the `Server::new` call site.
365///
366/// # Example
367///
368/// ```ignore
369/// #[handler]
370/// async fn get_user(path: Path<UserByIdPath>, state: State<AppState>) -> Json<User> {
371///     // ...
372/// }
373/// ```
374///
375/// # Compile errors
376///
377/// ```ignore
378/// #[handler]
379/// fn not_async() -> String { "hello".to_string() }
380/// // error: handler functions must be async
381///
382/// #[handler]
383/// async fn bad_return() -> NotAResponse { NotAResponse }
384/// // error: `NotAResponse` does not implement `IntoResponse`
385/// ```
386#[proc_macro_attribute]
387pub fn handler(attr: TokenStream, item: TokenStream) -> TokenStream {
388    let _ = attr; // no attributes expected
389    let func = match syn::parse::<syn::ItemFn>(item.clone()) {
390        Ok(f) => f,
391        Err(e) => return e.to_compile_error().into(),
392    };
393
394    // Must be async.
395    if func.sig.asyncness.is_none() {
396        return syn::Error::new_spanned(func.sig.fn_token, "handler functions must be async")
397            .to_compile_error()
398            .into();
399    }
400
401    let fn_name = &func.sig.ident;
402    let check_mod = format_ident!("__typeway_check_{}", fn_name);
403
404    // Collect typed arguments (skip self).
405    let typed_args: Vec<&syn::PatType> = func
406        .sig
407        .inputs
408        .iter()
409        .filter_map(|arg| match arg {
410            syn::FnArg::Typed(pt) => Some(pt),
411            _ => None,
412        })
413        .collect();
414
415    // Generate FromRequestParts assertions for all-but-last args,
416    // and a FromRequestParts-or-FromRequest check for the last arg.
417    let mut parts_checks = Vec::new();
418
419    for (i, arg) in typed_args.iter().enumerate() {
420        let ty = &arg.ty;
421        if i < typed_args.len() - 1 {
422            // Non-last args must be FromRequestParts.
423            let assert_fn = format_ident!("__assert_parts_{}", i);
424            let call_fn = format_ident!("__call_parts_{}", i);
425            parts_checks.push(quote! {
426                fn #assert_fn<T: ::typeway_server::FromRequestParts>() {}
427                fn #call_fn() { #assert_fn::<#ty>(); }
428            });
429        }
430        // Last arg: could be FromRequestParts or FromRequest.
431        // We can't express an "or" bound, so the blanket Handler
432        // impl catches type mismatches for the last argument.
433    }
434
435    // Return type must implement IntoResponse.
436    let ret_ty = match &func.sig.output {
437        syn::ReturnType::Default => quote! { () },
438        syn::ReturnType::Type(_, ty) => quote! { #ty },
439    };
440
441    let expanded = quote! {
442        #func
443
444        #[doc(hidden)]
445        #[allow(non_snake_case, unused, dead_code, unreachable_code)]
446        mod #check_mod {
447            use super::*;
448
449            fn __check_response<T: ::typeway_server::IntoResponse>() {}
450
451            fn __check_response_call() {
452                __check_response::<#ret_ty>();
453            }
454
455            #(#parts_checks)*
456        }
457    };
458
459    expanded.into()
460}
461
462// ---------------------------------------------------------------------------
463// #[api_description] trait macro
464// ---------------------------------------------------------------------------
465
466/// Defines an API as a trait, generating endpoint types and a `Serves` bridge.
467///
468/// Each method in the trait is annotated with an HTTP method attribute (`#[get(...)]`,
469/// `#[post(...)]`, etc.) that specifies the path. The macro generates:
470///
471/// 1. A type alias `<TraitName>Spec` — a tuple of endpoint types
472/// 2. The original trait with async method signatures
473/// 3. An `into_handlers` method that produces a handler tuple for `Server::new`
474///
475/// # Example
476///
477/// ```ignore
478/// #[api_description]
479/// trait UserAPI {
480///     #[get("users" / u32)]
481///     async fn get_user(path: Path<UserByIdPath>) -> Json<User>;
482///
483///     #[post("users")]
484///     async fn create_user(body: Json<CreateUser>) -> Json<User>;
485/// }
486///
487/// struct MyImpl { db: DbPool }
488/// impl UserAPI for MyImpl {
489///     async fn get_user(path: Path<UserByIdPath>) -> Json<User> {
490///         let user = User { id: path.id, name: "Alice".into() };
491///         Json(user)
492///     }
493///     async fn create_user(body: Json<CreateUser>) -> Json<User> {
494///         let user = User { id: 1, name: body.0.name.clone() };
495///         Json(user)
496///     }
497/// }
498///
499/// // Use: serve_user_api() bridges the trait impl to Server::new
500/// Server::<UserAPISpec>::new(serve_user_api(MyImpl));
501/// ```
502#[proc_macro_attribute]
503pub fn api_description(attr: TokenStream, item: TokenStream) -> TokenStream {
504    let _ = attr;
505    let trait_def = match syn::parse::<syn::ItemTrait>(item) {
506        Ok(t) => t,
507        Err(e) => return e.to_compile_error().into(),
508    };
509
510    match api_description_impl(trait_def) {
511        Ok(ts) => ts.into(),
512        Err(e) => e.to_compile_error().into(),
513    }
514}
515
516fn api_description_impl(trait_def: syn::ItemTrait) -> syn::Result<TokenStream2> {
517    let trait_name = &trait_def.ident;
518    let trait_vis = &trait_def.vis;
519    let spec_name = format_ident!("{}Spec", trait_name);
520    let handlers_fn = format_ident!("serve_{}", to_snake_case(&trait_name.to_string()));
521    let markers_mod = format_ident!("__wa_desc_{}", trait_name);
522
523    // Parse each method and its route attribute.
524    let mut routes = Vec::new();
525    let mut clean_methods = Vec::new();
526
527    for item in &trait_def.items {
528        let method = match item {
529            syn::TraitItem::Fn(m) => m,
530            other => {
531                clean_methods.push(quote! { #other });
532                continue;
533            }
534        };
535
536        // Find and parse the route attribute (#[get(...)], #[post(...)], etc.).
537        let (http_method, path_segments, remaining_attrs) = parse_route_attr(method)?;
538
539        let sig = &method.sig;
540        if sig.asyncness.is_none() {
541            return Err(syn::Error::new_spanned(
542                sig.fn_token,
543                "api_description methods must be async",
544            ));
545        }
546
547        // Emit clean method (without the route attribute).
548        // - Inject `&self` as the first parameter if not already present.
549        // - Desugar `async fn` to `fn -> impl Future<Output = T> + Send`
550        //   so the trait is object-safe and futures are Send.
551        let default_body = &method.default;
552        let mut clean_sig = method.sig.clone();
553        let has_self = clean_sig
554            .inputs
555            .iter()
556            .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
557        if !has_self {
558            clean_sig.inputs.insert(0, syn::parse_quote! { &self });
559        }
560        // Desugar async fn to fn -> impl Future + Send.
561        if clean_sig.asyncness.is_some() {
562            clean_sig.asyncness = None;
563            let ret_ty = match &clean_sig.output {
564                syn::ReturnType::Default => quote! { () },
565                syn::ReturnType::Type(_, ty) => quote! { #ty },
566            };
567            clean_sig.output = syn::parse_quote! {
568                -> impl ::std::future::Future<Output = #ret_ty> + Send
569            };
570        }
571        let semi = if default_body.is_none() {
572            quote! { ; }
573        } else {
574            quote! {}
575        };
576        clean_methods.push(quote! {
577            #(#remaining_attrs)*
578            #clean_sig #default_body #semi
579        });
580
581        routes.push(ParsedRoute {
582            method_name: sig.ident.clone(),
583            http_method,
584            path_segments,
585            sig: sig.clone(),
586        });
587    }
588
589    // Collect all literal marker types.
590    let mut seen = std::collections::HashSet::new();
591    let mut all_marker_defs = Vec::new();
592    for route in &routes {
593        all_marker_defs.extend(collect_marker_defs(&route.path_segments, &mut seen));
594    }
595
596    let mod_path: TokenStream2 = quote! { #markers_mod:: };
597
598    // Build endpoint types and path type aliases for each route.
599    let mut endpoint_types = Vec::new();
600    let mut path_type_aliases = Vec::new();
601    for route in &routes {
602        let path_type = build_hlist_type(&route.path_segments, &mod_path);
603        let method_type = method_type_ident(&route.http_method);
604
605        // Generate a path type alias named after the method (e.g., get_user -> GetUserPath).
606        let path_alias = format_ident!("{}Path", to_pascal_case(&route.method_name.to_string()));
607        path_type_aliases.push(quote! {
608            #trait_vis type #path_alias = #path_type;
609        });
610
611        // Extract request body type and response type from the signature.
612        let (req_type, res_type) = extract_req_res_types(&route.sig)?;
613
614        let endpoint = match req_type {
615            Some(req) => {
616                quote! { ::typeway_core::Endpoint<::typeway_core::#method_type, #path_type, #req, #res_type> }
617            }
618            None => {
619                quote! { ::typeway_core::Endpoint<::typeway_core::#method_type, #path_type, ::typeway_core::NoBody, #res_type> }
620            }
621        };
622        endpoint_types.push(endpoint);
623    }
624
625    // Generate the into_handlers function.
626    // For each route, create a closure that calls the trait method.
627    let impl_clones: Vec<TokenStream2> = routes
628        .iter()
629        .enumerate()
630        .map(|(i, _)| {
631            let clone_name = format_ident!("__impl_{}", i);
632            quote! { let #clone_name = __impl.clone(); }
633        })
634        .collect();
635
636    let handler_binds: Vec<TokenStream2> = routes
637        .iter()
638        .enumerate()
639        .map(|(i, route)| {
640            let method_name = &route.method_name;
641            let ep_type = &endpoint_types[i];
642            let clone_name = format_ident!("__impl_{}", i);
643
644            // Collect typed arguments (skip &self receivers).
645            let args: Vec<&syn::PatType> = route
646                .sig
647                .inputs
648                .iter()
649                .filter_map(|arg| match arg {
650                    syn::FnArg::Typed(pt) => Some(pt),
651                    _ => None,
652                })
653                .collect();
654
655            let arg_pats: Vec<&syn::Pat> = args.iter().map(|a| a.pat.as_ref()).collect();
656            let arg_types: Vec<&syn::Type> = args.iter().map(|a| a.ty.as_ref()).collect();
657
658            quote! {
659                ::typeway_server::bind::<#ep_type, _, _>(
660                    move |#(#arg_pats: #arg_types),*| {
661                        let __self = #clone_name.clone();
662                        async move {
663                            __self.#method_name(#(#arg_pats),*).await
664                        }
665                    }
666                )
667            }
668        })
669        .collect();
670
671    // Supertraits of the original trait.
672    let supertraits = &trait_def.supertraits;
673    let colon_token = &trait_def.colon_token;
674
675    let expanded = quote! {
676        // Marker types for literal path segments.
677        #[doc(hidden)]
678        #[allow(non_snake_case, non_camel_case_types)]
679        mod #markers_mod {
680            #(#all_marker_defs)*
681        }
682
683        // Path type aliases for each route (e.g., GetUserPath, CreateUserPath).
684        #(#path_type_aliases)*
685
686        // The API spec type alias.
687        #trait_vis type #spec_name = (#(#endpoint_types,)*);
688
689        // The trait itself (with route attributes stripped).
690        #trait_vis trait #trait_name #colon_token #supertraits {
691            #(#clean_methods)*
692        }
693
694        // Bridge: convert a trait impl into bound handlers for Server::new.
695        //
696        // Usage: `Server::<UserAPISpec>::new(serve_user_api(my_impl))`
697        #trait_vis fn #handlers_fn<T>(
698            __impl: T,
699        ) -> (#(::typeway_server::BoundHandler<#endpoint_types>,)*)
700        where
701            T: #trait_name + Clone + Send + Sync + 'static,
702        {
703            #(#impl_clones)*
704            (#(#handler_binds,)*)
705        }
706    };
707
708    Ok(expanded)
709}
710
/// Converts a snake_case identifier to PascalCase: `"get_user"` -> `"GetUser"`.
///
/// Each `_`-separated word gets its first character uppercased (Unicode-aware)
/// and keeps the rest as written. Empty words from repeated or leading
/// underscores disappear.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            out.extend(first.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
723
/// Converts a PascalCase (or camelCase) identifier to snake_case and keeps
/// acronyms together: `"UserAPI"` -> `"user_api"`, `"HTMLParser"` -> `"html_parser"`.
///
/// An underscore goes before an uppercase letter when the previous character
/// is not uppercase, or when the letter ends an acronym and starts a new word
/// (i.e. the next character is lowercase). Uppercase letters are lowercased
/// with full Unicode rules, so `"ÜberAPI"` becomes `"über_api"` rather than
/// keeping the capital `Ü`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &ch) in chars.iter().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                let prev_upper = chars[i - 1].is_uppercase();
                let next_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
                // Insert underscore before: a new uppercase word, or the last letter
                // of an acronym followed by a lowercase letter.
                if !prev_upper || next_lower {
                    result.push('_');
                }
            }
            // `is_uppercase` is Unicode-aware, so lowercasing must be as well.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
747
748struct ParsedRoute {
749    method_name: Ident,
750    http_method: String,
751    path_segments: Vec<PathSegment>,
752    sig: syn::Signature,
753}
754
755/// Parse the `#[get(...)]`, `#[post(...)]`, etc. attribute from a trait method.
756/// Returns (http_method, path_segments, remaining_attrs).
757fn parse_route_attr(
758    method: &syn::TraitItemFn,
759) -> syn::Result<(String, Vec<PathSegment>, Vec<syn::Attribute>)> {
760    let route_methods = ["get", "post", "put", "delete", "patch", "head", "options"];
761    let mut http_method = None;
762    let mut path_segments = None;
763    let mut remaining_attrs = Vec::new();
764
765    for attr in &method.attrs {
766        let ident = attr.path().get_ident();
767        if let Some(id) = ident {
768            let name = id.to_string();
769            if route_methods.contains(&name.as_str()) {
770                if http_method.is_some() {
771                    return Err(syn::Error::new_spanned(
772                        attr,
773                        "only one route attribute per method",
774                    ));
775                }
776                http_method = Some(name.to_uppercase());
777                // Parse the attribute arguments as path segments.
778                let segments: Vec<PathSegment> = attr.parse_args_with(parse_path_segments)?;
779                path_segments = Some(segments);
780                continue;
781            }
782        }
783        remaining_attrs.push(attr.clone());
784    }
785
786    match (http_method, path_segments) {
787        (Some(m), Some(p)) => Ok((m, p, remaining_attrs)),
788        _ => Err(syn::Error::new_spanned(
789            &method.sig.ident,
790            "api_description methods must have a route attribute: #[get(...)], #[post(...)], etc.",
791        )),
792    }
793}
794
795/// Extract request body type and response type from a method signature.
796///
797/// The response type is the return type. The request body type is the last
798/// argument if the HTTP method supports a body (POST, PUT, PATCH) and the
799/// last argument looks like a body extractor (Json<T>, Bytes, String).
800/// For simplicity, we always treat the return type as the response and
801/// don't try to extract the body type from the signature — that's determined
802/// by the endpoint type parameters and the Handler impls at the server level.
803fn extract_req_res_types(
804    sig: &syn::Signature,
805) -> syn::Result<(Option<TokenStream2>, TokenStream2)> {
806    let res_type = match &sig.output {
807        syn::ReturnType::Default => quote! { () },
808        syn::ReturnType::Type(_, ty) => quote! { #ty },
809    };
810
811    // For the request type, we don't infer it from args — it's NoBody by default.
812    // Users should specify body types via the endpoint type if needed.
813    // The macro primarily provides ergonomic trait-based API definitions.
814    Ok((None, res_type))
815}
816
817// ---------------------------------------------------------------------------
818// endpoint! — builder-style endpoint type macro
819// ---------------------------------------------------------------------------
820
821/// Defines an endpoint type with builder-style options.
822///
823/// Desugars nested wrappers (`Protected`, `Validated`, `Strict`, etc.)
824/// into a single readable declaration.
825///
826/// # Syntax
827///
828/// ```ignore
829/// endpoint! {
830///     GET "users" / u32 => Json<User>,
831///     auth: AuthUser,
832///     errors: JsonError,
833///     strict: true,
834/// }
835///
836/// endpoint! {
837///     POST "users",
838///     body: CreateUser => Json<User>,
839///     auth: AuthUser,
840///     validate: CreateUserValidator,
841///     content_type: json,
842///     errors: JsonError,
843///     version: V1,
844/// }
845/// ```
846///
847/// # Fields
848///
849/// - Method + path + `=>` response (required)
850/// - `body:` request body type (for POST/PUT/PATCH)
851/// - `auth:` wraps in `Protected<Auth, _>`
852/// - `validate:` wraps in `Validated<V, _>`
853/// - `content_type:` wraps in `ContentType<C, _>` (`json` or a type)
854/// - `errors:` sets the `Err` type parameter
855/// - `version:` wraps in `Versioned<V, _>`
856/// - `strict:` wraps in `Strict<_>` (if `true`)
857/// - `rate_limit:` wraps in `RateLimited<R, _>`
858#[proc_macro]
859pub fn endpoint(input: TokenStream) -> TokenStream {
860    let input = syn::parse_macro_input!(input as EndpointInput);
861    match endpoint_impl(input) {
862        Ok(ts) => ts.into(),
863        Err(e) => e.to_compile_error().into(),
864    }
865}
866
867struct EndpointInput {
868    method: String,
869    path_type: Type,
870    body_type: Option<Type>,
871    response_type: Type,
872    auth: Option<Type>,
873    validate: Option<Type>,
874    content_type: Option<Type>,
875    errors: Option<Type>,
876    version: Option<Type>,
877    strict: bool,
878    rate_limit: Option<Type>,
879}
880
881impl Parse for EndpointInput {
882    fn parse(input: ParseStream) -> syn::Result<Self> {
883        // Parse method
884        let method_ident: Ident = input.parse()?;
885        let method = method_ident.to_string().to_uppercase();
886
887        // Parse path type (a named type, not segments)
888        let path_type: Type = input.parse()?;
889
890        // Parse => Response or , body: ... => Response
891        let mut body_type = None;
892        let response_type;
893
894        if input.peek(Token![=>]) {
895            // GET PathType => Response
896            input.parse::<Token![=>]>()?;
897            response_type = input.parse::<Type>()?;
898        } else if input.peek(Token![,]) {
899            input.parse::<Token![,]>()?;
900            // Look for body: ... => Response
901            let key: Ident = input.parse()?;
902            if key != "body" {
903                return Err(syn::Error::new(
904                    key.span(),
905                    "expected `=> Response` or `body: Type => Response`",
906                ));
907            }
908            input.parse::<Token![:]>()?;
909            body_type = Some(input.parse::<Type>()?);
910            input.parse::<Token![=>]>()?;
911            response_type = input.parse::<Type>()?;
912        } else {
913            return Err(input.error("expected `=>` or `,`"));
914        }
915
916        // Consume trailing comma
917        if input.peek(Token![,]) {
918            input.parse::<Token![,]>()?;
919        }
920
921        // Parse optional key: value fields
922        let mut auth = None;
923        let mut validate = None;
924        let mut content_type = None;
925        let mut errors = None;
926        let mut version = None;
927        let mut strict = false;
928        let mut rate_limit = None;
929
930        while !input.is_empty() {
931            let key: Ident = input.parse()?;
932            input.parse::<Token![:]>()?;
933
934            match key.to_string().as_str() {
935                "auth" => auth = Some(input.parse::<Type>()?),
936                "validate" => validate = Some(input.parse::<Type>()?),
937                "content_type" => {
938                    if input.peek(Ident) {
939                        let ct: Ident = input.parse()?;
940                        content_type = Some(match ct.to_string().as_str() {
941                            "json" => syn::parse_quote! { ::typeway_server::typed::JsonContent },
942                            "form" => syn::parse_quote! { ::typeway_server::typed::FormContent },
943                            _ => {
944                                return Err(syn::Error::new(
945                                    ct.span(),
946                                    "expected `json`, `form`, or a type",
947                                ))
948                            }
949                        });
950                    } else {
951                        content_type = Some(input.parse::<Type>()?);
952                    }
953                }
954                "errors" => errors = Some(input.parse::<Type>()?),
955                "version" => version = Some(input.parse::<Type>()?),
956                "strict" => {
957                    let v: syn::LitBool = input.parse()?;
958                    strict = v.value;
959                }
960                "rate_limit" => rate_limit = Some(input.parse::<Type>()?),
961                other => {
962                    return Err(syn::Error::new(
963                        key.span(),
964                        format!("unknown field `{other}`"),
965                    ))
966                }
967            }
968
969            if input.peek(Token![,]) {
970                input.parse::<Token![,]>()?;
971            }
972        }
973
974        Ok(EndpointInput {
975            method,
976            path_type,
977            body_type,
978            response_type,
979            auth,
980            validate,
981            content_type,
982            errors,
983            version,
984            strict,
985            rate_limit,
986        })
987    }
988}
989
990fn endpoint_impl(input: EndpointInput) -> syn::Result<TokenStream2> {
991    let path_type = &input.path_type;
992    let method_type = method_type_ident(&input.method);
993    let response_type = &input.response_type;
994
995    let (req_type, q_type, err_type) = {
996        let req = match &input.body_type {
997            Some(t) => quote! { #t },
998            None => quote! { ::typeway_core::NoBody },
999        };
1000        let q = quote! { () };
1001        let err = match &input.errors {
1002            Some(t) => quote! { #t },
1003            None => quote! { () },
1004        };
1005        (req, q, err)
1006    };
1007
1008    let mut result = quote! {
1009        ::typeway_core::Endpoint<
1010            ::typeway_core::#method_type,
1011            #path_type,
1012            #req_type,
1013            #response_type,
1014            #q_type,
1015            #err_type
1016        >
1017    };
1018
1019    // Apply wrappers inside-out:
1020    // strict → content_type → validate → rate_limit → version → auth
1021    // (auth is outermost so it's checked first at runtime)
1022
1023    if input.strict {
1024        result = quote! { ::typeway_server::typed_response::Strict<#result> };
1025    }
1026
1027    if let Some(ref ct) = input.content_type {
1028        result = quote! { ::typeway_server::typed::ContentType<#ct, #result> };
1029    }
1030
1031    if let Some(ref v) = input.validate {
1032        result = quote! { ::typeway_server::typed::Validated<#v, #result> };
1033    }
1034
1035    if let Some(ref r) = input.rate_limit {
1036        result = quote! { ::typeway_server::typed::RateLimited<#r, #result> };
1037    }
1038
1039    if let Some(ref v) = input.version {
1040        result = quote! { ::typeway_server::typed::Versioned<#v, #result> };
1041    }
1042
1043    if let Some(ref auth) = input.auth {
1044        result = quote! { ::typeway_server::auth::Protected<#auth, #result> };
1045    }
1046
1047    Ok(result)
1048}
1049
1050// ---------------------------------------------------------------------------
1051// #[documented_handler] attribute macro
1052// ---------------------------------------------------------------------------
1053
1054/// Extracts doc comments from a handler function and generates a companion
1055/// `const` of type `HandlerDoc` (from `typeway_core`) containing the
1056/// summary, description, operation ID, and tags.
1057///
1058/// The first line of the doc comment becomes the `summary`. All subsequent
1059/// non-empty lines become the `description`. The function name becomes the
1060/// `operation_id`. Tags can be specified via the attribute parameter.
1061///
1062/// # Generated output
1063///
1064/// For a function named `list_users`, the macro generates a constant named
1065/// `LIST_USERS_DOC` of type `typeway_core::HandlerDoc`.
1066///
1067/// # Example
1068///
1069/// ```ignore
1070/// /// List all users.
1071/// ///
1072/// /// Returns a paginated list of users with optional filtering.
1073/// #[documented_handler(tags = "users")]
1074/// async fn list_users(state: State<Db>) -> Json<Vec<User>> {
1075///     // ...
1076/// }
1077///
1078/// // Generated:
1079/// // pub const LIST_USERS_DOC: typeway_core::HandlerDoc = typeway_core::HandlerDoc {
1080/// //     summary: "List all users.",
1081/// //     description: "Returns a paginated list of users with optional filtering.",
1082/// //     operation_id: "list_users",
1083/// //     tags: &["users"],
1084/// // };
1085/// ```
1086///
1087/// # Tags
1088///
1089/// Multiple tags can be comma-separated:
1090///
1091/// ```ignore
1092/// #[documented_handler(tags = "users, admin")]
1093/// ```
1094#[proc_macro_attribute]
1095pub fn documented_handler(attr: TokenStream, item: TokenStream) -> TokenStream {
1096    let func = match syn::parse::<syn::ItemFn>(item.clone()) {
1097        Ok(f) => f,
1098        Err(e) => return e.to_compile_error().into(),
1099    };
1100
1101    let tags = parse_documented_handler_tags(attr.into());
1102
1103    // Extract doc comment lines from #[doc = "..."] attributes.
1104    let doc_lines: Vec<String> = func
1105        .attrs
1106        .iter()
1107        .filter_map(|attr| {
1108            if !attr.path().is_ident("doc") {
1109                return None;
1110            }
1111            if let syn::Meta::NameValue(nv) = &attr.meta {
1112                if let syn::Expr::Lit(lit) = &nv.value {
1113                    if let syn::Lit::Str(s) = &lit.lit {
1114                        return Some(s.value());
1115                    }
1116                }
1117            }
1118            None
1119        })
1120        .collect();
1121
1122    // First non-empty trimmed line is summary, rest is description.
1123    let trimmed: Vec<String> = doc_lines.iter().map(|l| l.trim().to_string()).collect();
1124
1125    let summary = trimmed
1126        .iter()
1127        .find(|l| !l.is_empty())
1128        .cloned()
1129        .unwrap_or_default();
1130
1131    // Description: everything after the first non-empty line, with leading
1132    // blank lines stripped, then joined with newlines.
1133    let description = {
1134        let after_summary: Vec<&str> = trimmed
1135            .iter()
1136            .skip_while(|l| l.is_empty()) // skip leading blanks
1137            .skip(1) // skip the summary line
1138            .map(|s| s.as_str())
1139            .collect();
1140        // Trim leading and trailing empty lines from the description.
1141        let start = after_summary.iter().position(|l| !l.is_empty());
1142        let end = after_summary.iter().rposition(|l| !l.is_empty());
1143        match (start, end) {
1144            (Some(s), Some(e)) => after_summary[s..=e].join("\n"),
1145            _ => String::new(),
1146        }
1147    };
1148
1149    let fn_name = &func.sig.ident;
1150    let const_name = format_ident!("{}_DOC", to_screaming_snake(&fn_name.to_string()));
1151    let operation_id = fn_name.to_string();
1152
1153    let tags_tokens: Vec<TokenStream2> = tags.iter().map(|t| quote! { #t }).collect();
1154    let tags_array = if tags_tokens.is_empty() {
1155        quote! { &[] }
1156    } else {
1157        quote! { &[#(#tags_tokens),*] }
1158    };
1159
1160    let expanded = quote! {
1161        #func
1162
1163        /// Auto-generated handler documentation metadata.
1164        pub const #const_name: ::typeway_core::HandlerDoc = ::typeway_core::HandlerDoc {
1165            summary: #summary,
1166            description: #description,
1167            operation_id: #operation_id,
1168            tags: #tags_array,
1169        };
1170    };
1171
1172    expanded.into()
1173}
1174
1175/// Parse `tags = "foo, bar"` from the attribute arguments.
1176fn parse_documented_handler_tags(attr: TokenStream2) -> Vec<String> {
1177    // Try to parse as `tags = "..."`.
1178    struct TagsAttr {
1179        tags: Vec<String>,
1180    }
1181
1182    impl Parse for TagsAttr {
1183        fn parse(input: ParseStream) -> syn::Result<Self> {
1184            if input.is_empty() {
1185                return Ok(TagsAttr { tags: Vec::new() });
1186            }
1187            let key: Ident = input.parse()?;
1188            if key != "tags" {
1189                return Err(syn::Error::new(key.span(), "expected `tags = \"...\"`"));
1190            }
1191            input.parse::<Token![=]>()?;
1192            let value: LitStr = input.parse()?;
1193            let tags = value
1194                .value()
1195                .split(',')
1196                .map(|s| s.trim().to_string())
1197                .filter(|s| !s.is_empty())
1198                .collect();
1199            Ok(TagsAttr { tags })
1200        }
1201    }
1202
1203    syn::parse2::<TagsAttr>(attr)
1204        .map(|t| t.tags)
1205        .unwrap_or_default()
1206}
1207
/// Upper-case an identifier for use as a constant name.
///
/// A `snake_case` input such as `list_users` becomes `LIST_USERS`. No
/// underscores are inserted, so other casings are only upper-cased.
fn to_screaming_snake(s: &str) -> String {
    s.chars().flat_map(char::to_uppercase).collect()
}
1212
1213// ---------------------------------------------------------------------------
1214// #[derive(TypewaySchema)] — OpenAPI schema derivation from struct definitions
1215// ---------------------------------------------------------------------------
1216
1217/// Derives a `ToSchema` implementation for a struct with named fields.
1218///
1219/// Struct-level and field-level doc comments become `description` values in the
1220/// generated OpenAPI schema. Supports `#[serde(rename_all = "...")]` on the
1221/// struct and `#[serde(rename = "...")]` on individual fields.
1222///
1223/// # Example
1224///
1225/// ```ignore
1226/// /// A user account.
1227/// #[derive(TypewaySchema)]
1228/// struct User {
1229///     /// The unique user identifier.
1230///     id: u32,
1231///     /// The user's display name.
1232///     name: String,
1233/// }
1234/// ```
1235///
1236/// Generates an `impl typeway_openapi::ToSchema for User` that returns an
1237/// object schema with `id` and `name` properties, each carrying its doc
1238/// comment as a description.
1239///
1240/// # Serde rename support
1241///
1242/// ```ignore
1243/// #[derive(TypewaySchema)]
1244/// #[serde(rename_all = "camelCase")]
1245/// struct Article {
1246///     article_title: String,
1247///     tag_list: Vec<String>,
1248/// }
1249/// ```
1250///
1251/// The generated schema uses `articleTitle` and `tagList` as property names.
1252/// Per-field `#[serde(rename = "...")]` overrides `rename_all`.
1253///
1254/// Supported rename strategies: `camelCase`, `snake_case`, `PascalCase`,
1255/// `SCREAMING_SNAKE_CASE`, `kebab-case`.
1256#[proc_macro_derive(TypewaySchema)]
1257pub fn derive_typeway_schema(input: TokenStream) -> TokenStream {
1258    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1259    match derive_typeway_schema_impl(input) {
1260        Ok(ts) => ts.into(),
1261        Err(e) => e.to_compile_error().into(),
1262    }
1263}
1264
1265fn derive_typeway_schema_impl(input: syn::DeriveInput) -> syn::Result<TokenStream2> {
1266    let name = &input.ident;
1267    let name_str = name.to_string();
1268
1269    // Extract container-level doc comment.
1270    let struct_doc = extract_doc_string(&input.attrs);
1271
1272    // Check for #[serde(rename_all = "...")].
1273    let rename_all = extract_serde_rename_all(&input.attrs);
1274
1275    let schema_body = match &input.data {
1276        syn::Data::Struct(data) => match &data.fields {
1277            syn::Fields::Named(named) => {
1278                derive_struct_schema(&named.named, &rename_all, &struct_doc)
1279            }
1280            _ => {
1281                return Err(syn::Error::new_spanned(
1282                    name,
1283                    "TypewaySchema only supports structs with named fields",
1284                ));
1285            }
1286        },
1287        syn::Data::Enum(data) => {
1288            derive_enum_schema(name, data, &input.attrs, &rename_all, &struct_doc)?
1289        }
1290        _ => {
1291            return Err(syn::Error::new_spanned(
1292                name,
1293                "TypewaySchema supports structs and enums",
1294            ));
1295        }
1296    };
1297
1298    let expanded = quote! {
1299        impl ::typeway_openapi::ToSchema for #name {
1300            fn schema() -> ::typeway_openapi::spec::Schema {
1301                use ::typeway_openapi::spec::Schema as __Schema;
1302                #schema_body
1303            }
1304
1305            fn type_name() -> &'static str {
1306                #name_str
1307            }
1308        }
1309    };
1310
1311    Ok(expanded)
1312}
1313
1314fn derive_struct_schema(
1315    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
1316    rename_all: &Option<String>,
1317    struct_doc: &Option<String>,
1318) -> TokenStream2 {
1319    let property_entries: Vec<TokenStream2> = fields
1320        .iter()
1321        .map(|field| {
1322            let field_ident = field.ident.as_ref().unwrap();
1323            let field_type = &field.ty;
1324
1325            let field_name_str = if let Some(rename) = extract_serde_field_rename(&field.attrs) {
1326                rename
1327            } else if let Some(strategy) = rename_all {
1328                apply_rename_strategy(&field_ident.to_string(), strategy)
1329            } else {
1330                field_ident.to_string()
1331            };
1332
1333            let field_doc = extract_doc_string(&field.attrs);
1334
1335            match field_doc {
1336                Some(doc) => quote! {
1337                    (#field_name_str, <#field_type as ::typeway_openapi::ToSchema>::schema()
1338                        .with_description(#doc))
1339                },
1340                None => quote! {
1341                    (#field_name_str, <#field_type as ::typeway_openapi::ToSchema>::schema())
1342                },
1343            }
1344        })
1345        .collect();
1346
1347    let struct_description = match struct_doc {
1348        Some(doc) => quote! { Some(#doc) },
1349        None => quote! { None },
1350    };
1351
1352    quote! {
1353        __Schema::object_with_properties(
1354            vec![#(#property_entries),*],
1355            #struct_description,
1356        )
1357    }
1358}
1359
/// Enum-level serde tagging options collected from `#[serde(...)]`
/// attributes by `extract_serde_enum_tagging`.
#[derive(Default)]
struct EnumTagging {
    /// Value of `tag = "..."`, if present.
    tag: Option<String>,
    /// Value of `content = "..."`, if present (used together with `tag`).
    content: Option<String>,
    /// `true` when `untagged` is present.
    untagged: bool,
}
1366
1367fn extract_serde_enum_tagging(attrs: &[syn::Attribute]) -> EnumTagging {
1368    let mut t = EnumTagging::default();
1369    for attr in attrs {
1370        if !attr.path().is_ident("serde") {
1371            continue;
1372        }
1373        let _ = attr.parse_nested_meta(|meta| {
1374            if meta.path.is_ident("tag") {
1375                let value = meta.value()?;
1376                let lit: LitStr = value.parse()?;
1377                t.tag = Some(lit.value());
1378            } else if meta.path.is_ident("content") {
1379                let value = meta.value()?;
1380                let lit: LitStr = value.parse()?;
1381                t.content = Some(lit.value());
1382            } else if meta.path.is_ident("untagged") {
1383                t.untagged = true;
1384            }
1385            Ok(())
1386        });
1387    }
1388    t
1389}
1390
1391fn derive_enum_schema(
1392    name: &syn::Ident,
1393    data: &syn::DataEnum,
1394    attrs: &[syn::Attribute],
1395    rename_all: &Option<String>,
1396    struct_doc: &Option<String>,
1397) -> syn::Result<TokenStream2> {
1398    let tagging = extract_serde_enum_tagging(attrs);
1399
1400    // Collect each variant's serialized name plus its kind.
1401    let variants: Vec<(String, &syn::Variant)> = data
1402        .variants
1403        .iter()
1404        .map(|v| {
1405            let serialized = if let Some(rename) = extract_serde_field_rename(&v.attrs) {
1406                rename
1407            } else if let Some(strategy) = rename_all {
1408                apply_rename_strategy(&v.ident.to_string(), strategy)
1409            } else {
1410                v.ident.to_string()
1411            };
1412            (serialized, v)
1413        })
1414        .collect();
1415
1416    let all_unit = variants
1417        .iter()
1418        .all(|(_, v)| matches!(v.fields, syn::Fields::Unit));
1419
1420    let description_setter = match struct_doc {
1421        Some(doc) => quote! { __sch.description = Some(#doc.to_string()); },
1422        None => quote! {},
1423    };
1424
1425    // All-unit + plain (no tag/content/untagged) → string enum.
1426    if all_unit && tagging.tag.is_none() && !tagging.untagged {
1427        let names = variants.iter().map(|(n, _)| n.clone()).collect::<Vec<_>>();
1428        return Ok(quote! {{
1429            let mut __sch = __Schema::string_enum(vec![#(#names),*]);
1430            #description_setter
1431            __sch
1432        }});
1433    }
1434
1435    // General oneOf path.
1436    let variant_schemas: Vec<TokenStream2> = variants
1437        .iter()
1438        .map(|(serialized, v)| build_variant_schema(serialized, v, &tagging))
1439        .collect::<syn::Result<Vec<_>>>()?;
1440
1441    let discriminator = match (&tagging.tag, tagging.untagged) {
1442        (Some(tag), false) => {
1443            quote! {
1444                Some(::typeway_openapi::spec::Discriminator {
1445                    property_name: #tag.to_string(),
1446                    mapping: None,
1447                })
1448            }
1449        }
1450        _ => quote! { None },
1451    };
1452
1453    let _ = name;
1454
1455    Ok(quote! {{
1456        let mut __sch = __Schema::one_of(
1457            vec![#(#variant_schemas),*],
1458            #discriminator,
1459        );
1460        #description_setter
1461        __sch
1462    }})
1463}
1464
/// Build the schema expression for a single enum variant.
///
/// `serialized` is the variant name as serde emits it (the caller has already
/// applied `rename` / `rename_all`). The shape follows the enum's tagging mode:
///
/// - untagged: the bare payload schema;
/// - adjacently tagged (`tag` + `content`): an object holding the tag property
///   and, for non-unit variants, the content property with the payload;
/// - internally tagged (`tag` only): an object of the tag property plus the
///   variant's named fields; unit variants get just the tag, and
///   tuple/newtype variants fall back to the bare payload schema;
/// - externally tagged (default): a one-value string enum for unit variants,
///   otherwise an object `{ "<variant>": <payload> }`.
///
/// The variant's doc comment, if any, becomes the schema description.
///
/// # Errors
///
/// Propagates errors from `bare_payload_schema` and `named_field_entries`.
fn build_variant_schema(
    serialized: &str,
    v: &syn::Variant,
    tagging: &EnumTagging,
) -> syn::Result<TokenStream2> {
    let variant_doc = extract_doc_string(&v.attrs);
    let desc_setter = match &variant_doc {
        Some(doc) => quote! { __vsch.description = Some(#doc.to_string()); },
        None => quote! {},
    };

    // Build a Schema expression for the bare variant payload.
    // Computed up front; some arms below do not use it.
    let payload_schema = bare_payload_schema(&v.fields)?;

    // Arm order matters: within each tagging mode the `Fields::Unit` (or
    // `Fields::Named`) arm must precede that mode's `_` catch-all arm.
    let body = match (&tagging.tag, &tagging.content, tagging.untagged, &v.fields) {
        // Untagged: just emit the payload schema.
        (_, _, true, _) => quote! { #payload_schema },

        // Adjacently tagged: { <tag>: <name>, <content>: <payload> } (or just tag for unit).
        (Some(tag), Some(_), false, syn::Fields::Unit) => {
            quote! {{
                let __tag_schema = __Schema::string_enum(vec![#serialized]);
                __Schema::object_with_properties(
                    vec![(#tag, __tag_schema)],
                    None,
                )
            }}
        }
        (Some(tag), Some(content), false, _) => {
            quote! {{
                let __tag_schema = __Schema::string_enum(vec![#serialized]);
                __Schema::object_with_properties(
                    vec![
                        (#tag, __tag_schema),
                        (#content, #payload_schema),
                    ],
                    None,
                )
            }}
        }

        // Internally tagged: merge tag into the payload object's properties.
        (Some(tag), None, false, syn::Fields::Unit) => {
            quote! {{
                let __tag_schema = __Schema::string_enum(vec![#serialized]);
                __Schema::object_with_properties(
                    vec![(#tag, __tag_schema)],
                    None,
                )
            }}
        }
        (Some(tag), None, false, syn::Fields::Named(_)) => {
            // Build the object from named fields plus the tag property.
            let named_entries = named_field_entries(&v.fields)?;
            quote! {{
                let __tag_schema = __Schema::string_enum(vec![#serialized]);
                let mut __entries: Vec<(&str, __Schema)> = vec![(#tag, __tag_schema)];
                #(__entries.push(#named_entries);)*
                __Schema::object_with_properties(__entries, None)
            }}
        }
        (Some(_), None, false, _) => {
            // Internal tagging on tuple/newtype variants is not representable as a
            // single tag-merged object. Fall back to the bare payload schema.
            quote! { #payload_schema }
        }

        // Externally tagged (default).
        (None, _, false, syn::Fields::Unit) => {
            // Serde serializes unit variants as bare strings: "Foo"
            quote! { __Schema::string_enum(vec![#serialized]) }
        }
        (None, _, false, _) => {
            // {"<variant>": <payload>}
            quote! {{
                __Schema::object_with_properties(
                    vec![(#serialized, #payload_schema)],
                    None,
                )
            }}
        }
    };

    // Wrap so the optional description can be attached to the result.
    Ok(quote! {{
        let mut __vsch: __Schema = #body;
        #desc_setter
        __vsch
    }})
}
1554
1555fn bare_payload_schema(fields: &syn::Fields) -> syn::Result<TokenStream2> {
1556    Ok(match fields {
1557        syn::Fields::Unit => quote! { __Schema::object() },
1558        syn::Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => {
1559            let ty = &unnamed.unnamed.first().unwrap().ty;
1560            quote! { <#ty as ::typeway_openapi::ToSchema>::schema() }
1561        }
1562        syn::Fields::Unnamed(unnamed) => {
1563            // Tuple variant: array of mixed types is not expressible cleanly;
1564            // emit an array schema whose items are the first element's schema.
1565            // This is a reasonable approximation that documents the shape.
1566            let first_ty = &unnamed.unnamed.first().unwrap().ty;
1567            quote! {
1568                __Schema::array(<#first_ty as ::typeway_openapi::ToSchema>::schema())
1569            }
1570        }
1571        syn::Fields::Named(named) => {
1572            let entries: Vec<TokenStream2> = named
1573                .named
1574                .iter()
1575                .map(|field| {
1576                    let field_ident = field.ident.as_ref().unwrap();
1577                    let field_type = &field.ty;
1578                    let field_name = if let Some(rename) = extract_serde_field_rename(&field.attrs)
1579                    {
1580                        rename
1581                    } else {
1582                        field_ident.to_string()
1583                    };
1584                    quote! {
1585                        (#field_name, <#field_type as ::typeway_openapi::ToSchema>::schema())
1586                    }
1587                })
1588                .collect();
1589            quote! {
1590                __Schema::object_with_properties(vec![#(#entries),*], None)
1591            }
1592        }
1593    })
1594}
1595
1596fn named_field_entries(fields: &syn::Fields) -> syn::Result<Vec<TokenStream2>> {
1597    let named = match fields {
1598        syn::Fields::Named(named) => &named.named,
1599        _ => {
1600            return Err(syn::Error::new_spanned(fields, "expected named fields"));
1601        }
1602    };
1603    Ok(named
1604        .iter()
1605        .map(|field| {
1606            let field_ident = field.ident.as_ref().unwrap();
1607            let field_type = &field.ty;
1608            let field_name = if let Some(rename) = extract_serde_field_rename(&field.attrs) {
1609                rename
1610            } else {
1611                field_ident.to_string()
1612            };
1613            quote! {
1614                (#field_name, <#field_type as ::typeway_openapi::ToSchema>::schema())
1615            }
1616        })
1617        .collect())
1618}
1619
1620/// Extract the combined doc comment string from `#[doc = "..."]` attributes.
1621fn extract_doc_string(attrs: &[syn::Attribute]) -> Option<String> {
1622    let doc_lines: Vec<String> = attrs
1623        .iter()
1624        .filter_map(|attr| {
1625            if !attr.path().is_ident("doc") {
1626                return None;
1627            }
1628            if let syn::Meta::NameValue(nv) = &attr.meta {
1629                if let syn::Expr::Lit(lit) = &nv.value {
1630                    if let syn::Lit::Str(s) = &lit.lit {
1631                        return Some(s.value().trim().to_string());
1632                    }
1633                }
1634            }
1635            None
1636        })
1637        .filter(|s| !s.is_empty())
1638        .collect();
1639
1640    if doc_lines.is_empty() {
1641        None
1642    } else {
1643        Some(doc_lines.join("\n"))
1644    }
1645}
1646
1647/// Extract `rename_all` value from `#[serde(rename_all = "...")]`.
1648fn extract_serde_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
1649    for attr in attrs {
1650        if !attr.path().is_ident("serde") {
1651            continue;
1652        }
1653        let mut result = None;
1654        let _ = attr.parse_nested_meta(|meta| {
1655            if meta.path.is_ident("rename_all") {
1656                let value = meta.value()?;
1657                let lit: LitStr = value.parse()?;
1658                result = Some(lit.value());
1659            }
1660            Ok(())
1661        });
1662        if result.is_some() {
1663            return result;
1664        }
1665    }
1666    None
1667}
1668
1669/// Extract `rename` value from `#[serde(rename = "...")]` on a field.
1670fn extract_serde_field_rename(attrs: &[syn::Attribute]) -> Option<String> {
1671    for attr in attrs {
1672        if !attr.path().is_ident("serde") {
1673            continue;
1674        }
1675        let mut result = None;
1676        let _ = attr.parse_nested_meta(|meta| {
1677            if meta.path.is_ident("rename") {
1678                let value = meta.value()?;
1679                let lit: LitStr = value.parse()?;
1680                result = Some(lit.value());
1681            }
1682            Ok(())
1683        });
1684        if result.is_some() {
1685            return result;
1686        }
1687    }
1688    None
1689}
1690
1691// ---------------------------------------------------------------------------
1692// #[derive(ToProtoType)] — protobuf message derivation from struct definitions
1693// ---------------------------------------------------------------------------
1694
1695/// Derives a `ToProtoType` implementation for a struct with named fields.
1696///
1697/// Each field is mapped to a `ProtoField` entry (from `typeway_grpc`).
1698/// Field tags can be specified explicitly with `#[proto(tag = N)]`; fields
1699/// without an explicit tag are auto-numbered based on their 1-indexed position.
1700///
1701/// `Option<T>` fields produce `optional` proto fields. `Vec<T>` fields produce
1702/// `repeated` proto fields (except `Vec<u8>`, which maps to `bytes`).
1703///
1704/// # Example
1705///
1706/// ```ignore
1707/// #[derive(ToProtoType)]
1708/// struct User {
1709///     #[proto(tag = 1)]
1710///     id: u32,
1711///     #[proto(tag = 2)]
1712///     name: String,
1713///     #[proto(tag = 3)]
1714///     bio: Option<String>,
1715/// }
1716/// ```
1717#[proc_macro_derive(ToProtoType, attributes(proto))]
1718pub fn derive_to_proto_type(input: TokenStream) -> TokenStream {
1719    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1720    match derive_to_proto_type_impl(input) {
1721        Ok(ts) => ts.into(),
1722        Err(e) => e.to_compile_error().into(),
1723    }
1724}
1725
1726fn derive_to_proto_type_impl(input: syn::DeriveInput) -> syn::Result<TokenStream2> {
1727    let name = &input.ident;
1728    let name_str = name.to_string();
1729
1730    match &input.data {
1731        syn::Data::Struct(data) => match &data.fields {
1732            syn::Fields::Named(named) => derive_to_proto_type_struct(name, name_str, &named.named),
1733            _ => Err(syn::Error::new_spanned(
1734                name,
1735                "ToProtoType only supports structs with named fields or enums",
1736            )),
1737        },
1738        syn::Data::Enum(data) => {
1739            let is_simple = data.variants.iter().all(|v| v.fields.is_empty());
1740            if is_simple {
1741                derive_to_proto_type_simple_enum(name, name_str, data)
1742            } else {
1743                derive_to_proto_type_oneof_enum(name, name_str, data)
1744            }
1745        }
1746        syn::Data::Union(_) => Err(syn::Error::new_spanned(
1747            name,
1748            "ToProtoType does not support unions",
1749        )),
1750    }
1751}
1752
1753fn derive_to_proto_type_struct(
1754    name: &Ident,
1755    name_str: String,
1756    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
1757) -> syn::Result<TokenStream2> {
1758    let mut field_entries = Vec::new();
1759    let mut collect_stmts = Vec::new();
1760
1761    for (i, field) in fields.iter().enumerate() {
1762        let field_ident = field.ident.as_ref().unwrap();
1763        let field_name_str = field_ident.to_string();
1764        let field_ty = &field.ty;
1765        let tag = extract_proto_tag(&field.attrs).unwrap_or((i as u32) + 1);
1766        let field_doc = extract_doc_string(&field.attrs);
1767
1768        // Detect Option<T>, Vec<T>, and HashMap<K,V>/BTreeMap<K,V> to set
1769        // optional/repeated/map and use the appropriate inner types.
1770        let (proto_type_ty, optional, repeated, is_map_field) =
1771            if let Some(inner) = is_option_type(field_ty) {
1772                (inner.clone(), true, false, false)
1773            } else if is_vec_u8(field_ty) {
1774                // Vec<u8> maps to bytes — use Vec<u8> directly, not repeated.
1775                (field_ty.clone(), false, false, false)
1776            } else if let Some(inner) = is_vec_type(field_ty) {
1777                (inner.clone(), false, true, false)
1778            } else if is_map_type(field_ty).is_some() {
1779                // Map types — use the original type so ToProtoType dispatches correctly.
1780                (field_ty.clone(), false, false, true)
1781            } else {
1782                (field_ty.clone(), false, false, false)
1783            };
1784
1785        let doc_expr = match &field_doc {
1786            Some(doc) => quote! { ::core::option::Option::Some(#doc.to_string()) },
1787            None => quote! { ::core::option::Option::None },
1788        };
1789
1790        let field_entry = if is_map_field {
1791            let (key_ty, val_ty) = is_map_type(field_ty).unwrap();
1792            quote! {
1793                ::typeway_grpc::ProtoField {
1794                    name: #field_name_str.to_string(),
1795                    proto_type: "map".to_string(),
1796                    tag: #tag,
1797                    repeated: false,
1798                    optional: false,
1799                    is_map: true,
1800                    map_key_type: ::core::option::Option::Some(
1801                        <#key_ty as ::typeway_grpc::ToProtoType>::proto_type_name().to_string()
1802                    ),
1803                    map_value_type: ::core::option::Option::Some(
1804                        <#val_ty as ::typeway_grpc::ToProtoType>::proto_type_name().to_string()
1805                    ),
1806                    doc: #doc_expr,
1807                }
1808            }
1809        } else {
1810            quote! {
1811                ::typeway_grpc::ProtoField {
1812                    name: #field_name_str.to_string(),
1813                    proto_type: <#proto_type_ty as ::typeway_grpc::ToProtoType>::proto_type_name().to_string(),
1814                    tag: #tag,
1815                    repeated: #repeated,
1816                    optional: #optional,
1817                    is_map: false,
1818                    map_key_type: ::core::option::Option::None,
1819                    map_value_type: ::core::option::Option::None,
1820                    doc: #doc_expr,
1821                }
1822            }
1823        };
1824        field_entries.push(field_entry);
1825
1826        collect_stmts.push(quote! {
1827            msgs.extend(<#proto_type_ty as ::typeway_grpc::ToProtoType>::collect_messages());
1828        });
1829    }
1830
1831    let expanded = quote! {
1832        impl ::typeway_grpc::ToProtoType for #name {
1833            fn proto_type_name() -> &'static str {
1834                #name_str
1835            }
1836
1837            fn is_message() -> bool {
1838                true
1839            }
1840
1841            fn message_definition() -> ::core::option::Option<::std::string::String> {
1842                ::core::option::Option::Some(::typeway_grpc::build_message(#name_str, &[
1843                    #(#field_entries),*
1844                ]))
1845            }
1846
1847            fn collect_messages() -> ::std::vec::Vec<::std::string::String> {
1848                let mut msgs = ::std::vec::Vec::new();
1849                #(#collect_stmts)*
1850                if let ::core::option::Option::Some(def) = Self::message_definition() {
1851                    msgs.push(def);
1852                }
1853                msgs
1854            }
1855
1856            fn proto_fields() -> ::std::vec::Vec<::typeway_grpc::ProtoField> {
1857                ::std::vec![#(#field_entries),*]
1858            }
1859        }
1860    };
1861
1862    Ok(expanded)
1863}
1864
1865/// Generate a `ToProtoType` impl for a simple (fieldless) enum as a protobuf enum.
1866fn derive_to_proto_type_simple_enum(
1867    name: &Ident,
1868    name_str: String,
1869    data: &syn::DataEnum,
1870) -> syn::Result<TokenStream2> {
1871    let mut variant_names = Vec::new();
1872    let mut variant_tags = Vec::new();
1873
1874    for (i, variant) in data.variants.iter().enumerate() {
1875        let tag = extract_proto_tag(&variant.attrs).unwrap_or(i as u32);
1876        let proto_name = to_screaming_snake(&variant.ident.to_string());
1877        variant_names.push(proto_name);
1878        variant_tags.push(tag);
1879    }
1880
1881    let expanded = quote! {
1882        impl ::typeway_grpc::ToProtoType for #name {
1883            fn proto_type_name() -> &'static str {
1884                #name_str
1885            }
1886
1887            fn is_message() -> bool {
1888                true
1889            }
1890
1891            fn message_definition() -> ::core::option::Option<::std::string::String> {
1892                let mut lines = ::std::vec![::std::format!("enum {} {{", #name_str)];
1893                #(
1894                    lines.push(::std::format!("  {} = {};", #variant_names, #variant_tags));
1895                )*
1896                lines.push("}".to_string());
1897                ::core::option::Option::Some(lines.join("\n"))
1898            }
1899
1900            fn collect_messages() -> ::std::vec::Vec<::std::string::String> {
1901                let mut msgs = ::std::vec::Vec::new();
1902                if let ::core::option::Option::Some(def) = Self::message_definition() {
1903                    msgs.push(def);
1904                }
1905                msgs
1906            }
1907        }
1908    };
1909
1910    Ok(expanded)
1911}
1912
1913/// Generate a `ToProtoType` impl for a tagged enum as a protobuf `oneof` in a
1914/// wrapper message.
1915fn derive_to_proto_type_oneof_enum(
1916    name: &Ident,
1917    name_str: String,
1918    data: &syn::DataEnum,
1919) -> syn::Result<TokenStream2> {
1920    let oneof_name = to_snake_case(&name_str);
1921
1922    let mut variant_field_names = Vec::new();
1923    let mut variant_types: Vec<syn::Type> = Vec::new();
1924    let mut variant_tags = Vec::new();
1925    let mut collect_stmts = Vec::new();
1926
1927    for (i, variant) in data.variants.iter().enumerate() {
1928        let tag = extract_proto_tag(&variant.attrs).unwrap_or((i + 1) as u32);
1929        let field_name = to_snake_case(&variant.ident.to_string());
1930        variant_field_names.push(field_name);
1931        variant_tags.push(tag);
1932
1933        match &variant.fields {
1934            syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
1935                let ty = fields.unnamed[0].ty.clone();
1936                collect_stmts.push(quote! {
1937                    msgs.extend(<#ty as ::typeway_grpc::ToProtoType>::collect_messages());
1938                });
1939                variant_types.push(ty);
1940            }
1941            syn::Fields::Unnamed(_) => {
1942                return Err(syn::Error::new_spanned(
1943                    &variant.ident,
1944                    "ToProtoType oneof variants must have exactly one field",
1945                ));
1946            }
1947            syn::Fields::Named(_) => {
1948                return Err(syn::Error::new_spanned(
1949                    &variant.ident,
1950                    "ToProtoType oneof variants must use tuple syntax, e.g., Variant(Type)",
1951                ));
1952            }
1953            syn::Fields::Unit => {
1954                return Err(syn::Error::new_spanned(
1955                    &variant.ident,
1956                    "mixed unit and data variants are not supported; \
1957                     all variants must have fields for oneof generation",
1958                ));
1959            }
1960        }
1961    }
1962
1963    let expanded = quote! {
1964        impl ::typeway_grpc::ToProtoType for #name {
1965            fn proto_type_name() -> &'static str {
1966                #name_str
1967            }
1968
1969            fn is_message() -> bool {
1970                true
1971            }
1972
1973            fn message_definition() -> ::core::option::Option<::std::string::String> {
1974                let mut lines = ::std::vec![::std::format!("message {} {{", #name_str)];
1975                lines.push(::std::format!("  oneof {} {{", #oneof_name));
1976                #(
1977                    lines.push(::std::format!("    {} {} = {};",
1978                        <#variant_types as ::typeway_grpc::ToProtoType>::proto_type_name(),
1979                        #variant_field_names,
1980                        #variant_tags,
1981                    ));
1982                )*
1983                lines.push("  }".to_string());
1984                lines.push("}".to_string());
1985                ::core::option::Option::Some(lines.join("\n"))
1986            }
1987
1988            fn collect_messages() -> ::std::vec::Vec<::std::string::String> {
1989                let mut msgs = ::std::vec::Vec::new();
1990                #(#collect_stmts)*
1991                if let ::core::option::Option::Some(def) = Self::message_definition() {
1992                    msgs.push(def);
1993                }
1994                msgs
1995            }
1996        }
1997    };
1998
1999    Ok(expanded)
2000}
2001
2002/// Extract a `#[proto(tag = N)]` attribute value from field attributes.
2003fn extract_proto_tag(attrs: &[syn::Attribute]) -> Option<u32> {
2004    for attr in attrs {
2005        if attr.path().is_ident("proto") {
2006            if let Ok(meta) = attr.parse_args::<syn::MetaNameValue>() {
2007                if meta.path.is_ident("tag") {
2008                    if let syn::Expr::Lit(lit) = &meta.value {
2009                        if let syn::Lit::Int(int) = &lit.lit {
2010                            return int.base10_parse().ok();
2011                        }
2012                    }
2013                }
2014            }
2015        }
2016    }
2017    None
2018}
2019
2020/// If the type is `Option<T>`, return `Some(T)`.
2021fn is_option_type(ty: &syn::Type) -> Option<&syn::Type> {
2022    if let syn::Type::Path(path) = ty {
2023        if let Some(seg) = path.path.segments.last() {
2024            if seg.ident == "Option" {
2025                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
2026                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
2027                        return Some(inner);
2028                    }
2029                }
2030            }
2031        }
2032    }
2033    None
2034}
2035
2036/// If the type is `Vec<T>`, return `Some(T)`.
2037fn is_vec_type(ty: &syn::Type) -> Option<&syn::Type> {
2038    if let syn::Type::Path(path) = ty {
2039        if let Some(seg) = path.path.segments.last() {
2040            if seg.ident == "Vec" {
2041                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
2042                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
2043                        return Some(inner);
2044                    }
2045                }
2046            }
2047        }
2048    }
2049    None
2050}
2051
2052/// Check if the type is `Vec<u8>` (which maps to protobuf `bytes`).
2053fn is_vec_u8(ty: &syn::Type) -> bool {
2054    if let syn::Type::Path(path) = ty {
2055        if let Some(seg) = path.path.segments.last() {
2056            if seg.ident == "Vec" {
2057                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
2058                    if let Some(syn::GenericArgument::Type(syn::Type::Path(inner_path))) =
2059                        args.args.first()
2060                    {
2061                        if let Some(inner_seg) = inner_path.path.segments.last() {
2062                            return inner_seg.ident == "u8" && inner_seg.arguments.is_none();
2063                        }
2064                    }
2065                }
2066            }
2067        }
2068    }
2069    false
2070}
2071
2072/// If the type is `HashMap<K, V>` or `BTreeMap<K, V>`, return `Some((K, V))`.
2073fn is_map_type(ty: &syn::Type) -> Option<(syn::Type, syn::Type)> {
2074    if let syn::Type::Path(path) = ty {
2075        if let Some(seg) = path.path.segments.last() {
2076            if seg.ident == "HashMap" || seg.ident == "BTreeMap" {
2077                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
2078                    let mut types = args.args.iter().filter_map(|a| {
2079                        if let syn::GenericArgument::Type(t) = a {
2080                            Some(t)
2081                        } else {
2082                            None
2083                        }
2084                    });
2085                    if let (Some(k), Some(v)) = (types.next(), types.next()) {
2086                        return Some((k.clone(), v.clone()));
2087                    }
2088                }
2089            }
2090        }
2091    }
2092    None
2093}
2094
2095/// Apply a serde rename strategy to a snake_case field name.
2096fn apply_rename_strategy(name: &str, strategy: &str) -> String {
2097    // Normalize input (which may be snake_case from struct fields or PascalCase
2098    // from enum variants) to a snake_case canonical form, then transform.
2099    let snake = to_snake_case(name);
2100    match strategy {
2101        "camelCase" => {
2102            let mut result = String::new();
2103            let mut capitalize_next = false;
2104            for c in snake.chars() {
2105                if c == '_' {
2106                    capitalize_next = true;
2107                } else if capitalize_next {
2108                    result.extend(c.to_uppercase());
2109                    capitalize_next = false;
2110                } else {
2111                    result.push(c);
2112                }
2113            }
2114            result
2115        }
2116        "snake_case" => snake,
2117        "PascalCase" => snake
2118            .split('_')
2119            .map(|w| {
2120                let mut c = w.chars();
2121                match c.next() {
2122                    Some(f) => f.to_uppercase().to_string() + &c.collect::<String>(),
2123                    None => String::new(),
2124                }
2125            })
2126            .collect(),
2127        "SCREAMING_SNAKE_CASE" => snake.to_uppercase(),
2128        "kebab-case" => snake.replace('_', "-"),
2129        "lowercase" => snake.replace('_', ""),
2130        "UPPERCASE" => snake.replace('_', "").to_uppercase(),
2131        _ => name.to_string(),
2132    }
2133}
2134
2135// ---------------------------------------------------------------------------
2136// #[derive(TypewayCodec)] — compile-time specialized protobuf encode/decode
2137// ---------------------------------------------------------------------------
2138
2139/// Derive `TypewayEncode` and `TypewayDecode` for a struct.
2140///
2141/// Generates specialized encode/decode functions with no runtime dispatch.
2142/// Each field is encoded/decoded directly based on its Rust type and proto
2143/// tag number.
2144///
2145/// # Attributes
2146///
2147/// - `#[proto(tag = N)]` — Set the protobuf field tag number (default: 1-indexed position)
2148///
2149/// # Supported field types
2150///
2151/// | Rust type | Proto type | Wire type |
2152/// |-----------|-----------|-----------|
2153/// | `u32` | uint32 | varint (0) |
2154/// | `u64` | uint64 | varint (0) |
2155/// | `i32` | int32 | varint (0) |
2156/// | `i64` | int64 | varint (0) |
2157/// | `bool` | bool | varint (0) |
2158/// | `f32` | float | 32-bit (5) |
2159/// | `f64` | double | 64-bit (1) |
2160/// | `String` | string | len-delimited (2) |
2161/// | `Vec<u8>` | bytes | len-delimited (2) |
2162/// | `Vec<T>` | repeated T | (varies) |
2163/// | `Option<T>` | optional T | (varies) |
2164///
2165/// # Example
2166///
2167/// ```ignore
2168/// #[derive(TypewayCodec)]
2169/// struct User {
2170///     #[proto(tag = 1)]
2171///     id: u32,
2172///     #[proto(tag = 2)]
2173///     name: String,
2174///     #[proto(tag = 3)]
2175///     active: bool,
2176/// }
2177/// ```
2178#[proc_macro_derive(TypewayCodec, attributes(proto))]
2179pub fn derive_typeway_codec(input: TokenStream) -> TokenStream {
2180    let input = syn::parse_macro_input!(input as syn::DeriveInput);
2181    match derive_typeway_codec_impl(input) {
2182        Ok(ts) => ts.into(),
2183        Err(e) => e.to_compile_error().into(),
2184    }
2185}
2186
2187/// Derive a typestate builder for a struct.
2188///
2189/// Fields marked `#[required]` must be set before `.build()` is available.
2190/// Optional fields (those not marked `#[required]`) can be set in any order.
2191///
2192/// ```ignore
2193/// #[derive(TypestateBuilder)]
2194/// struct User {
2195///     #[required]
2196///     id: u32,
2197///     #[required]
2198///     name: String,
2199///     email: Option<String>,
2200/// }
2201///
2202/// let user = User::builder().id(42).name("Alice".into()).build();
2203/// ```
2204#[proc_macro_derive(TypestateBuilder, attributes(required))]
2205pub fn derive_typestate_builder(input: TokenStream) -> TokenStream {
2206    let input = syn::parse_macro_input!(input as syn::DeriveInput);
2207    match derive_typestate_builder_impl(input) {
2208        Ok(ts) => ts.into(),
2209        Err(e) => e.to_compile_error().into(),
2210    }
2211}
2212
2213fn derive_typestate_builder_impl(input: syn::DeriveInput) -> syn::Result<TokenStream2> {
2214    let name = &input.ident;
2215
2216    let fields = match &input.data {
2217        syn::Data::Struct(data) => match &data.fields {
2218            syn::Fields::Named(named) => &named.named,
2219            _ => {
2220                return Err(syn::Error::new_spanned(
2221                    name,
2222                    "TypestateBuilder requires named fields",
2223                ))
2224            }
2225        },
2226        _ => {
2227            return Err(syn::Error::new_spanned(
2228                name,
2229                "TypestateBuilder only supports structs",
2230            ))
2231        }
2232    };
2233
2234    // Classify fields as required or optional.
2235    let mut required_fields = Vec::new();
2236    let mut optional_fields = Vec::new();
2237
2238    for field in fields.iter() {
2239        let ident = field.ident.as_ref().unwrap();
2240        let ty = &field.ty;
2241        let is_required = field.attrs.iter().any(|a| a.path().is_ident("required"));
2242        if is_required {
2243            required_fields.push((ident.clone(), ty.clone()));
2244        } else {
2245            optional_fields.push((ident.clone(), ty.clone()));
2246        }
2247    }
2248
2249    // Generate type parameter names for required fields.
2250    let req_type_params: Vec<Ident> = required_fields
2251        .iter()
2252        .map(|(ident, _)| {
2253            Ident::new(
2254                &format!("__{}", ident.to_string().to_uppercase()),
2255                ident.span(),
2256            )
2257        })
2258        .collect();
2259
2260    let builder_name = Ident::new(&format!("{}Builder", name), name.span());
2261
2262    // Builder struct: carries Option<T> for each field + phantom type params.
2263    let builder_fields: Vec<TokenStream2> = fields
2264        .iter()
2265        .map(|f| {
2266            let ident = f.ident.as_ref().unwrap();
2267            let ty = &f.ty;
2268            quote! { #ident: ::core::option::Option<#ty>, }
2269        })
2270        .collect();
2271
2272    let phantom_fields: Vec<TokenStream2> = req_type_params
2273        .iter()
2274        .map(|p| quote! { #p: ::core::marker::PhantomData<#p>, })
2275        .collect();
2276
2277    // Default builder: all fields None, all type params = Missing.
2278    let missing_type_args: Vec<TokenStream2> = req_type_params
2279        .iter()
2280        .map(|_| quote! { ::typeway_protobuf::builder::Missing })
2281        .collect();
2282
2283    let set_type = quote! { ::typeway_protobuf::builder::Set };
2284
2285    let field_none_inits: Vec<TokenStream2> = fields
2286        .iter()
2287        .map(|f| {
2288            let ident = f.ident.as_ref().unwrap();
2289            quote! { #ident: ::core::option::Option::None, }
2290        })
2291        .collect();
2292
2293    let phantom_none_inits: Vec<TokenStream2> = req_type_params
2294        .iter()
2295        .map(|p| quote! { #p: ::core::marker::PhantomData, })
2296        .collect();
2297
2298    // Setter methods for required fields — each transitions one type param from Missing to Set.
2299    let required_setters: Vec<TokenStream2> = required_fields
2300        .iter()
2301        .zip(req_type_params.iter())
2302        .enumerate()
2303        .map(|(idx, ((ident, ty), param))| {
2304            // Build the return type: same type params but this one is Set.
2305            let mut ret_params: Vec<TokenStream2> = req_type_params
2306                .iter()
2307                .enumerate()
2308                .map(|(i, p)| {
2309                    if i == idx {
2310                        set_type.clone()
2311                    } else {
2312                        quote! { #p }
2313                    }
2314                })
2315                .collect();
2316            let _ = &mut ret_params; // suppress unused
2317
2318            let ret_type_params: Vec<TokenStream2> = req_type_params
2319                .iter()
2320                .enumerate()
2321                .map(|(i, p)| {
2322                    if i == idx {
2323                        set_type.clone()
2324                    } else {
2325                        quote! { #p }
2326                    }
2327                })
2328                .collect();
2329
2330            let other_params: Vec<&Ident> = req_type_params
2331                .iter()
2332                .enumerate()
2333                .filter(|(i, _)| *i != idx)
2334                .map(|(_, p)| p)
2335                .collect();
2336
2337            // Copy all fields + phantoms, replacing this field's phantom.
2338            let field_copies: Vec<TokenStream2> = fields
2339                .iter()
2340                .map(|f| {
2341                    let fi = f.ident.as_ref().unwrap();
2342                    if fi == ident {
2343                        quote! { #fi: ::core::option::Option::Some(value), }
2344                    } else {
2345                        quote! { #fi: self.#fi, }
2346                    }
2347                })
2348                .collect();
2349
2350            let phantom_copies: Vec<TokenStream2> = req_type_params
2351                .iter()
2352                .map(|p| quote! { #p: ::core::marker::PhantomData, })
2353                .collect();
2354
2355            let _ = &other_params;
2356            let _ = param;
2357
2358            quote! {
2359                /// Set the `#ident` field (required).
2360                pub fn #ident(self, value: #ty) -> #builder_name<#(#ret_type_params),*> {
2361                    #builder_name {
2362                        #(#field_copies)*
2363                        #(#phantom_copies)*
2364                    }
2365                }
2366            }
2367        })
2368        .collect();
2369
2370    // Setter methods for optional fields — available on any builder state.
2371    let optional_setters: Vec<TokenStream2> = optional_fields
2372        .iter()
2373        .map(|(ident, ty)| {
2374            // For Option<T> fields, accept T directly.
2375            let inner_ty = is_option_type(ty);
2376            let (param_ty, wrap) = if let Some(inner) = inner_ty {
2377                (
2378                    inner.clone(),
2379                    quote! { ::core::option::Option::Some(::core::option::Option::Some(value)) },
2380                )
2381            } else {
2382                (ty.clone(), quote! { ::core::option::Option::Some(value) })
2383            };
2384            quote! {
2385                /// Set the `#ident` field (optional).
2386                pub fn #ident(mut self, value: #param_ty) -> Self {
2387                    self.#ident = #wrap;
2388                    self
2389                }
2390            }
2391        })
2392        .collect();
2393
2394    // Build method: only available when all type params are Set.
2395    let all_set: Vec<TokenStream2> = req_type_params.iter().map(|_| set_type.clone()).collect();
2396
2397    let build_fields: Vec<TokenStream2> = fields
2398        .iter()
2399        .map(|f| {
2400            let ident = f.ident.as_ref().unwrap();
2401            let is_required = f.attrs.iter().any(|a| a.path().is_ident("required"));
2402            if is_required {
2403                // Required field: unwrap is safe because the typestate guarantees it's set.
2404                quote! { #ident: self.#ident.unwrap(), }
2405            } else if is_option_type(&f.ty).is_some() {
2406                // Optional<T> field: flatten Option<Option<T>> → Option<T>.
2407                quote! { #ident: self.#ident.flatten(), }
2408            } else {
2409                // Non-optional, non-required: use default.
2410                quote! { #ident: self.#ident.unwrap_or_default(), }
2411            }
2412        })
2413        .collect();
2414
2415    let expanded = quote! {
2416        /// Typestate builder for [`#name`].
2417        pub struct #builder_name<#(#req_type_params = ::typeway_protobuf::builder::Missing),*> {
2418            #(#builder_fields)*
2419            #(#phantom_fields)*
2420        }
2421
2422        impl #name {
2423            /// Create a typestate builder. Required fields must be set before `.build()`.
2424            pub fn builder() -> #builder_name<#(#missing_type_args),*> {
2425                #builder_name {
2426                    #(#field_none_inits)*
2427                    #(#phantom_none_inits)*
2428                }
2429            }
2430        }
2431
2432        impl<#(#req_type_params),*> #builder_name<#(#req_type_params),*> {
2433            #(#optional_setters)*
2434        }
2435
2436        impl<#(#req_type_params),*> #builder_name<#(#req_type_params),*> {
2437            #(#required_setters)*
2438        }
2439
2440        impl #builder_name<#(#all_set),*> {
2441            /// Build the message. Only available when all required fields are set.
2442            pub fn build(self) -> #name {
2443                #name {
2444                    #(#build_fields)*
2445                }
2446            }
2447        }
2448
2449        impl ::typeway_protobuf::builder::HasBuilder for #name {
2450            type Builder = #builder_name;
2451            fn builder() -> Self::Builder {
2452                #name::builder()
2453            }
2454        }
2455    };
2456
2457    Ok(expanded)
2458}
2459
2460fn derive_typeway_codec_impl(input: syn::DeriveInput) -> syn::Result<TokenStream2> {
2461    let name = &input.ident;
2462
2463    match &input.data {
2464        syn::Data::Struct(data) => match &data.fields {
2465            syn::Fields::Named(named) => derive_typeway_codec_struct(name, &named.named),
2466            _ => Err(syn::Error::new_spanned(
2467                name,
2468                "TypewayCodec only supports structs with named fields",
2469            )),
2470        },
2471        syn::Data::Enum(data) => {
2472            let is_simple = data.variants.iter().all(|v| v.fields.is_empty());
2473            if is_simple {
2474                derive_typeway_codec_simple_enum(name, data)
2475            } else {
2476                derive_typeway_codec_oneof_enum(name, data)
2477            }
2478        }
2479        _ => Err(syn::Error::new_spanned(
2480            name,
2481            "TypewayCodec does not support unions",
2482        )),
2483    }
2484}
2485
2486/// Generate TypewayEncode/TypewayDecode for a simple (fieldless) enum.
2487///
2488/// Encodes as a varint (i32). Each variant maps to its proto tag value.
2489/// Default variant (tag 0) is the first variant.
2490fn derive_typeway_codec_simple_enum(
2491    name: &Ident,
2492    data: &syn::DataEnum,
2493) -> syn::Result<TokenStream2> {
2494    let mut variant_idents = Vec::new();
2495    let mut variant_tags: Vec<u32> = Vec::new();
2496
2497    for (i, variant) in data.variants.iter().enumerate() {
2498        let tag = extract_proto_tag(&variant.attrs).unwrap_or(i as u32);
2499        variant_idents.push(&variant.ident);
2500        variant_tags.push(tag);
2501    }
2502
2503    let first_variant = &variant_idents[0];
2504
2505    // Encode: match variant → tag value, encode as varint.
2506    let encode_arms: Vec<TokenStream2> = variant_idents
2507        .iter()
2508        .zip(variant_tags.iter())
2509        .map(|(ident, tag)| {
2510            let tag_u64 = *tag as u64;
2511            quote! { #name::#ident => #tag_u64, }
2512        })
2513        .collect();
2514
2515    let len_arms: Vec<TokenStream2> = variant_idents
2516        .iter()
2517        .zip(variant_tags.iter())
2518        .map(|(ident, tag)| {
2519            let tag_u64 = *tag as u64;
2520            quote! { #name::#ident => ::typeway_protobuf::tw_varint_len(#tag_u64), }
2521        })
2522        .collect();
2523
2524    // Decode: read varint, match tag → variant.
2525    let decode_arms: Vec<TokenStream2> = variant_idents
2526        .iter()
2527        .zip(variant_tags.iter())
2528        .map(|(ident, tag)| {
2529            let tag_u32 = *tag;
2530            quote! { #tag_u32 => #name::#ident, }
2531        })
2532        .collect();
2533
2534    Ok(quote! {
2535        impl ::typeway_protobuf::TypewayEncode for #name {
2536            fn encoded_len(&self) -> usize {
2537                match self {
2538                    #(#len_arms)*
2539                }
2540            }
2541
2542            fn encode_to(&self, buf: &mut ::std::vec::Vec<u8>) {
2543                let val: u64 = match self {
2544                    #(#encode_arms)*
2545                };
2546                ::typeway_protobuf::tw_encode_varint(buf, val);
2547            }
2548        }
2549
2550        impl ::typeway_protobuf::TypewayDecode for #name {
2551            fn typeway_decode(
2552                bytes: &[u8],
2553            ) -> ::core::result::Result<Self, ::typeway_protobuf::TypewayDecodeError> {
2554                if bytes.is_empty() {
2555                    return Ok(#name::#first_variant);
2556                }
2557                let (val, _consumed) = ::typeway_protobuf::tw_decode_varint(bytes)?;
2558                let val = val as u32;
2559                Ok(match val {
2560                    #(#decode_arms)*
2561                    _ => #name::#first_variant,
2562                })
2563            }
2564        }
2565    })
2566}
2567
2568/// Generate TypewayEncode/TypewayDecode for a tagged enum (protobuf oneof).
2569///
2570/// Each variant is encoded as a tagged field in a message: tag + wire_type + value.
2571/// Only one variant is present at a time.
2572fn derive_typeway_codec_oneof_enum(
2573    name: &Ident,
2574    data: &syn::DataEnum,
2575) -> syn::Result<TokenStream2> {
2576    let mut variant_idents = Vec::new();
2577    let mut variant_tags: Vec<u32> = Vec::new();
2578    let mut variant_types: Vec<syn::Type> = Vec::new();
2579
2580    for (i, variant) in data.variants.iter().enumerate() {
2581        let tag = extract_proto_tag(&variant.attrs).unwrap_or((i + 1) as u32);
2582        variant_idents.push(&variant.ident);
2583        variant_tags.push(tag);
2584
2585        match &variant.fields {
2586            syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2587                variant_types.push(fields.unnamed[0].ty.clone());
2588            }
2589            syn::Fields::Unnamed(_) => {
2590                return Err(syn::Error::new_spanned(
2591                    &variant.ident,
2592                    "TypewayCodec oneof variants must have exactly one field",
2593                ));
2594            }
2595            syn::Fields::Named(_) => {
2596                return Err(syn::Error::new_spanned(
2597                    &variant.ident,
2598                    "TypewayCodec oneof variants must use tuple syntax, e.g., Variant(Type)",
2599                ));
2600            }
2601            syn::Fields::Unit => {
2602                return Err(syn::Error::new_spanned(
2603                    &variant.ident,
2604                    "oneof variants must have a field; use a simple enum for fieldless variants",
2605                ));
2606            }
2607        }
2608    }
2609
2610    // Encode, encoded_len, decode arms — dispatch on CodecKind for correctness.
2611    let mut encode_arms = Vec::new();
2612    let mut len_arms = Vec::new();
2613    let mut decode_arms = Vec::new();
2614
2615    for ((ident, tag), ty) in variant_idents
2616        .iter()
2617        .zip(variant_tags.iter())
2618        .zip(variant_types.iter())
2619    {
2620        let kind = oneof_codec_kind(ty);
2621        let wt = wire_type_for_kind(&kind);
2622        let tp = emit_tag_push(*tag, wt);
2623        let tag_len = if *tag < 16 {
2624            1usize
2625        } else if *tag < 2048 {
2626            2
2627        } else {
2628            3
2629        };
2630
2631        let (enc, len_expr, dec) = match &kind {
2632            CodecKind::Varint => (
2633                quote! { #name::#ident(ref val) => { #tp ::typeway_protobuf::tw_encode_varint(buf, *val as u64); } },
2634                quote! { #name::#ident(ref val) => { #tag_len + ::typeway_protobuf::tw_varint_len(*val as u64) } },
2635                quote! { #tag => { let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?; offset += consumed; result = #name::#ident(val as _); } },
2636            ),
2637            CodecKind::Bool => (
2638                quote! { #name::#ident(ref val) => { #tp buf.push(if *val { 1 } else { 0 }); } },
2639                quote! { #name::#ident(_) => { #tag_len + 1 } },
2640                quote! { #tag => { let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?; offset += consumed; result = #name::#ident(val != 0); } },
2641            ),
2642            CodecKind::Fixed32 => (
2643                quote! { #name::#ident(ref val) => { #tp buf.extend_from_slice(&val.to_le_bytes()); } },
2644                quote! { #name::#ident(_) => { #tag_len + 4 } },
2645                quote! { #tag => {
2646                    if offset + 4 > bytes.len() { return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof); }
2647                    let val = f32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]);
2648                    offset += 4; result = #name::#ident(val);
2649                } },
2650            ),
2651            CodecKind::Fixed64 => (
2652                quote! { #name::#ident(ref val) => { #tp buf.extend_from_slice(&val.to_le_bytes()); } },
2653                quote! { #name::#ident(_) => { #tag_len + 8 } },
2654                quote! { #tag => {
2655                    if offset + 8 > bytes.len() { return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof); }
2656                    let val = f64::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3], bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7]]);
2657                    offset += 8; result = #name::#ident(val);
2658                } },
2659            ),
2660            CodecKind::LenString => (
2661                quote! { #name::#ident(ref val) => { #tp ::typeway_protobuf::tw_encode_varint(buf, val.len() as u64); buf.extend_from_slice(val.as_bytes()); } },
2662                quote! { #name::#ident(ref val) => { #tag_len + ::typeway_protobuf::tw_varint_len(val.len() as u64) + val.len() } },
2663                quote! { #tag => {
2664                    let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2665                    offset += consumed; let str_len = str_len as usize;
2666                    if offset + str_len > bytes.len() { return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof); }
2667                    let slice = &bytes[offset..offset + str_len];
2668                    ::core::str::from_utf8(slice).map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8("oneof"))?;
2669                    result = #name::#ident(unsafe { String::from_utf8_unchecked(slice.to_vec()) });
2670                    offset += str_len;
2671                } },
2672            ),
2673            CodecKind::LenBytesStr => (
2674                quote! { #name::#ident(ref val) => { #tp ::typeway_protobuf::tw_encode_varint(buf, val.len() as u64); buf.extend_from_slice(val.as_bytes()); } },
2675                quote! { #name::#ident(ref val) => { #tag_len + ::typeway_protobuf::tw_varint_len(val.len() as u64) + val.len() } },
2676                quote! { #tag => {
2677                    let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2678                    offset += consumed; let str_len = str_len as usize;
2679                    if offset + str_len > bytes.len() { return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof); }
2680                    let slice = &bytes[offset..offset + str_len];
2681                    ::core::str::from_utf8(slice).map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8("oneof"))?;
2682                    result = #name::#ident(::typeway_protobuf::BytesStr::from(unsafe { String::from_utf8_unchecked(slice.to_vec()) }));
2683                    offset += str_len;
2684                } },
2685            ),
2686            // Message type (nested struct implementing TypewayEncode/TypewayDecode)
2687            _ => (
2688                quote! { #name::#ident(ref val) => {
2689                    #tp
2690                    let nested = ::typeway_protobuf::TypewayEncode::encode_to_vec(val);
2691                    ::typeway_protobuf::tw_encode_varint(buf, nested.len() as u64);
2692                    buf.extend_from_slice(&nested);
2693                } },
2694                quote! { #name::#ident(ref val) => {
2695                    let nested_len = ::typeway_protobuf::TypewayEncode::encoded_len(val);
2696                    #tag_len + ::typeway_protobuf::tw_varint_len(nested_len as u64) + nested_len
2697                } },
2698                quote! { #tag => {
2699                    let (msg_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2700                    offset += consumed; let msg_len = msg_len as usize;
2701                    if offset + msg_len > bytes.len() { return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof); }
2702                    result = #name::#ident(::typeway_protobuf::TypewayDecode::typeway_decode(&bytes[offset..offset+msg_len])?);
2703                    offset += msg_len;
2704                } },
2705            ),
2706        };
2707
2708        encode_arms.push(enc);
2709        len_arms.push(len_expr);
2710        decode_arms.push(dec);
2711    }
2712
2713    let first_variant = &variant_idents[0];
2714    let first_type = &variant_types[0];
2715
2716    Ok(quote! {
2717        impl ::typeway_protobuf::TypewayEncode for #name {
2718            fn encoded_len(&self) -> usize {
2719                match self {
2720                    #(#len_arms)*
2721                }
2722            }
2723
2724            fn encode_to(&self, buf: &mut ::std::vec::Vec<u8>) {
2725                match self {
2726                    #(#encode_arms)*
2727                }
2728            }
2729        }
2730
2731        impl ::typeway_protobuf::TypewayDecode for #name {
2732            fn typeway_decode(
2733                bytes: &[u8],
2734            ) -> ::core::result::Result<Self, ::typeway_protobuf::TypewayDecodeError> {
2735                let mut result = #name::#first_variant(
2736                    <#first_type as ::core::default::Default>::default()
2737                );
2738                let mut offset: usize = 0;
2739                while offset < bytes.len() {
2740                    let (tag_wire, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2741                    offset += consumed;
2742                    let field_number = (tag_wire >> 3) as u32;
2743                    let wire_type = (tag_wire & 0x07) as u8;
2744                    match field_number {
2745                        #(#decode_arms)*
2746                        _ => {
2747                            let skipped = ::typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
2748                            offset += skipped;
2749                        }
2750                    }
2751                }
2752                Ok(result)
2753            }
2754        }
2755    })
2756}
2757
2758/// Classify a Rust type for oneof codec generation.
2759fn oneof_codec_kind(ty: &syn::Type) -> CodecKind {
2760    classify_scalar(ty)
2761}
2762
2763/// Information about a single field for codec generation.
2764struct CodecField {
2765    ident: Ident,
2766    ty: syn::Type,
2767    tag: u32,
2768    codec_kind: CodecKind,
2769}
2770
/// How a field (or a oneof payload) is laid out on the protobuf wire.
enum CodecKind {
    /// Varint (wire type 0): `u32`, `u64`, `i32`, `i64`. Signed values are cast
    /// with `as u64`, i.e. sign-extended (protobuf `int32`/`int64`, not zigzag).
    Varint,
    /// Bool (wire type 0), written as a single `0`/`1` value.
    Bool,
    /// Fixed 32-bit little-endian (wire type 5): `f32`.
    Fixed32,
    /// Fixed 64-bit little-endian (wire type 1): `f64`.
    Fixed64,
    /// Length-delimited UTF-8 `String` (wire type 2).
    LenString,
    /// Length-delimited `BytesStr` (wire type 2); presumably intended for
    /// zero-copy decoding via `typeway_decode_bytes` — confirm in the decode arms.
    LenBytesStr,
    /// Length-delimited raw bytes, i.e. a `Vec<u8>` field (wire type 2).
    LenBytes,
    /// `Option<T>` of a non-message `T`; holds the kind of `T`.
    Optional(Box<CodecKind>),
    /// `Vec<T>` of a non-message `T`; holds only the element's kind, not its type.
    Repeated(Box<CodecKind>),
    /// Nested message that also implements TypewayEncode/TypewayDecode.
    Message,
    /// `Option<M>` where `M` is a nested message.
    OptionalMessage,
    /// `Vec<M>` where `M` is a nested message.
    RepeatedMessage,
}
2798
2799fn classify_type(ty: &syn::Type) -> CodecKind {
2800    if let Some(inner) = is_option_type(ty) {
2801        let inner_kind = classify_type(inner);
2802        match inner_kind {
2803            CodecKind::Message => CodecKind::OptionalMessage,
2804            other => CodecKind::Optional(Box::new(other)),
2805        }
2806    } else if is_vec_u8(ty) {
2807        CodecKind::LenBytes
2808    } else if let Some(inner) = is_vec_type(ty) {
2809        let inner_kind = classify_type(inner);
2810        match inner_kind {
2811            CodecKind::Message => CodecKind::RepeatedMessage,
2812            other => CodecKind::Repeated(Box::new(other)),
2813        }
2814    } else {
2815        classify_scalar(ty)
2816    }
2817}
2818
2819fn classify_scalar(ty: &syn::Type) -> CodecKind {
2820    let ty_str = quote!(#ty).to_string().replace(' ', "");
2821    match ty_str.as_str() {
2822        "u32" | "u64" | "i32" | "i64" => CodecKind::Varint,
2823        "bool" => CodecKind::Bool,
2824        "f32" => CodecKind::Fixed32,
2825        "f64" => CodecKind::Fixed64,
2826        "String" => CodecKind::LenString,
2827        "BytesStr" | "typeway_protobuf::BytesStr" => CodecKind::LenBytesStr,
2828        _ => CodecKind::Message,
2829    }
2830}
2831
2832fn wire_type_for_kind(kind: &CodecKind) -> u8 {
2833    match kind {
2834        CodecKind::Varint | CodecKind::Bool => 0,
2835        CodecKind::Fixed64 => 1,
2836        CodecKind::LenString
2837        | CodecKind::LenBytesStr
2838        | CodecKind::LenBytes
2839        | CodecKind::Message => 2,
2840        CodecKind::Fixed32 => 5,
2841        CodecKind::Optional(inner) | CodecKind::Repeated(inner) => wire_type_for_kind(inner),
2842        CodecKind::OptionalMessage | CodecKind::RepeatedMessage => 2,
2843    }
2844}
2845
2846fn derive_typeway_codec_struct(
2847    name: &Ident,
2848    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
2849) -> syn::Result<TokenStream2> {
2850    // Parse fields and validate tags.
2851    let mut codec_fields = Vec::new();
2852    let mut seen_tags = std::collections::HashMap::new();
2853    for (i, field) in fields.iter().enumerate() {
2854        let ident = field.ident.clone().unwrap();
2855        let tag = extract_proto_tag(&field.attrs).unwrap_or((i as u32) + 1);
2856
2857        if tag == 0 {
2858            return Err(syn::Error::new_spanned(
2859                &ident,
2860                "proto tag 0 is reserved; tags must be >= 1",
2861            ));
2862        }
2863
2864        if let Some(prev_ident) = seen_tags.get(&tag) {
2865            return Err(syn::Error::new_spanned(
2866                &ident,
2867                format!("duplicate proto tag {tag}: already used by field `{prev_ident}`"),
2868            ));
2869        }
2870        seen_tags.insert(tag, ident.to_string());
2871
2872        let codec_kind = classify_type(&field.ty);
2873        codec_fields.push(CodecField {
2874            ident,
2875            ty: field.ty.clone(),
2876            tag,
2877            codec_kind,
2878        });
2879    }
2880
2881    // Generate encode_to body.
2882    let encode_stmts: Vec<TokenStream2> = codec_fields.iter().map(gen_encode_field).collect();
2883
2884    // Generate encoded_len body.
2885    let len_stmts: Vec<TokenStream2> = codec_fields.iter().map(gen_encoded_len_field).collect();
2886
2887    // Generate decode body.
2888    let field_defaults: Vec<TokenStream2> = codec_fields
2889        .iter()
2890        .map(|f| {
2891            let ident = &f.ident;
2892            let ty = &f.ty;
2893            quote! { let mut #ident: #ty = ::core::default::Default::default(); }
2894        })
2895        .collect();
2896
2897    let decode_arms: Vec<TokenStream2> = codec_fields.iter().map(gen_decode_arm).collect();
2898
2899    let decode_bytes_arms: Vec<TokenStream2> =
2900        codec_fields.iter().map(gen_decode_bytes_arm).collect();
2901
2902    let field_names: Vec<&Ident> = codec_fields.iter().map(|f| &f.ident).collect();
2903
2904    Ok(quote! {
2905        impl ::typeway_protobuf::TypewayEncode for #name {
2906            fn encoded_len(&self) -> usize {
2907                let mut len: usize = 0;
2908                #(#len_stmts)*
2909                len
2910            }
2911
2912            fn encode_to(&self, buf: &mut ::std::vec::Vec<u8>) {
2913                #(#encode_stmts)*
2914            }
2915        }
2916
2917        impl ::typeway_protobuf::TypewayDecode for #name {
2918            fn typeway_decode(
2919                bytes: &[u8],
2920            ) -> ::core::result::Result<Self, ::typeway_protobuf::TypewayDecodeError> {
2921                #(#field_defaults)*
2922                let mut offset: usize = 0;
2923
2924                while offset < bytes.len() {
2925                    let (tag_wire, consumed) =
2926                        ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2927                    offset += consumed;
2928                    let field_number = (tag_wire >> 3) as u32;
2929                    let wire_type = (tag_wire & 0x07) as u8;
2930
2931                    match field_number {
2932                        #(#decode_arms)*
2933                        _ => {
2934                            let skipped =
2935                                ::typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
2936                            offset += skipped;
2937                        }
2938                    }
2939                }
2940
2941                Ok(#name { #(#field_names),* })
2942            }
2943
2944            fn typeway_decode_bytes(
2945                input: ::bytes::Bytes,
2946            ) -> ::core::result::Result<Self, ::typeway_protobuf::TypewayDecodeError> {
2947                let bytes = &input[..];
2948                #(#field_defaults)*
2949                let mut offset: usize = 0;
2950
2951                while offset < bytes.len() {
2952                    let (tag_wire, consumed) =
2953                        ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
2954                    offset += consumed;
2955                    let field_number = (tag_wire >> 3) as u32;
2956                    let wire_type = (tag_wire & 0x07) as u8;
2957
2958                    match field_number {
2959                        #(#decode_bytes_arms)*
2960                        _ => {
2961                            let skipped =
2962                                ::typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
2963                            offset += skipped;
2964                        }
2965                    }
2966                }
2967
2968                Ok(#name { #(#field_names),* })
2969            }
2970        }
2971    })
2972}
2973
/// Compute the one-byte protobuf key `(field_number << 3) | wire_type` at macro
/// expansion time.
///
/// Only meaningful for field numbers 1..=15, whose key fits in a single varint
/// byte. Larger field numbers are silently truncated to the low 8 bits, so callers
/// must use a multi-byte tag encoder for them.
fn precompute_tag_byte(field_number: u32, wire_type: u8) -> u8 {
    let shifted = field_number << 3;
    let key = shifted | u32::from(wire_type);
    key as u8
}
2978
2979/// Emit code to write a precomputed tag byte (fields 1-15).
2980fn emit_tag_push(tag: u32, wt: u8) -> TokenStream2 {
2981    let byte = precompute_tag_byte(tag, wt);
2982    if tag < 16 {
2983        // Single byte — just push the constant.
2984        quote! { buf.push(#byte); }
2985    } else {
2986        // Multi-byte tag — use the varint encoder.
2987        quote! { ::typeway_protobuf::tw_encode_tag(buf, #tag, #wt); }
2988    }
2989}
2990
2991fn gen_encode_field(f: &CodecField) -> TokenStream2 {
2992    let ident = &f.ident;
2993    let tag = f.tag;
2994    let wt = wire_type_for_kind(&f.codec_kind);
2995    let tag_push = emit_tag_push(tag, wt);
2996
2997    match &f.codec_kind {
2998        CodecKind::Varint => quote! {
2999            if self.#ident != 0 {
3000                #tag_push
3001                ::typeway_protobuf::tw_encode_varint(buf, self.#ident as u64);
3002            }
3003        },
3004        CodecKind::Bool => {
3005            // Tag + value as two bytes pushed together.
3006            let tag_byte = precompute_tag_byte(tag, wt);
3007            quote! {
3008                if self.#ident {
3009                    buf.extend_from_slice(&[#tag_byte, 1]);
3010                }
3011            }
3012        }
3013        CodecKind::Fixed32 => quote! {
3014            if self.#ident != 0.0 {
3015                #tag_push
3016                buf.extend_from_slice(&self.#ident.to_le_bytes());
3017            }
3018        },
3019        CodecKind::Fixed64 => quote! {
3020            if self.#ident != 0.0 {
3021                #tag_push
3022                buf.extend_from_slice(&self.#ident.to_le_bytes());
3023            }
3024        },
3025        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3026            if !self.#ident.is_empty() {
3027                #tag_push
3028                ::typeway_protobuf::tw_encode_varint(buf, self.#ident.len() as u64);
3029                buf.extend_from_slice(self.#ident.as_bytes());
3030            }
3031        },
3032        CodecKind::LenBytes => quote! {
3033            if !self.#ident.is_empty() {
3034                #tag_push
3035                ::typeway_protobuf::tw_encode_varint(buf, self.#ident.len() as u64);
3036                buf.extend_from_slice(&self.#ident);
3037            }
3038        },
3039        CodecKind::Message => quote! {
3040            {
3041                let nested = ::typeway_protobuf::TypewayEncode::encode_to_vec(&self.#ident);
3042                if !nested.is_empty() {
3043                    #tag_push
3044                    ::typeway_protobuf::tw_encode_varint(buf, nested.len() as u64);
3045                    buf.extend_from_slice(&nested);
3046                }
3047            }
3048        },
3049        CodecKind::Optional(inner) => {
3050            let inner_encode = gen_encode_optional_inner(tag, wt, inner);
3051            quote! {
3052                if let Some(ref val) = self.#ident {
3053                    #inner_encode
3054                }
3055            }
3056        }
3057        CodecKind::OptionalMessage => quote! {
3058            if let Some(ref val) = self.#ident {
3059                let nested = ::typeway_protobuf::TypewayEncode::encode_to_vec(val);
3060                #tag_push
3061                ::typeway_protobuf::tw_encode_varint(buf, nested.len() as u64);
3062                buf.extend_from_slice(&nested);
3063            }
3064        },
3065        CodecKind::Repeated(inner) => {
3066            if is_packable(inner) {
3067                let item_write = gen_packed_item_write(inner);
3068                let is_varint = matches!(inner.as_ref(), CodecKind::Varint);
3069                let packed_tag_push = emit_tag_push(tag, 2);
3070                if is_varint {
3071                    quote! {
3072                        if !self.#ident.is_empty() {
3073                            #packed_tag_push
3074                            let len_pos = buf.len();
3075                            buf.push(0); // placeholder for length
3076                            let data_start = buf.len();
3077                            buf.reserve(self.#ident.len() * 10);
3078                            // Batch unsafe write: ONE set_len for all varints.
3079                            unsafe {
3080                                let base = buf.as_mut_ptr();
3081                                let mut pos = data_start;
3082                                for item in &self.#ident {
3083                                    let mut v = *item as u64;
3084                                    while v >= 0x80 {
3085                                        *base.add(pos) = (v as u8 & 0x7F) | 0x80;
3086                                        v >>= 7;
3087                                        pos += 1;
3088                                    }
3089                                    *base.add(pos) = v as u8;
3090                                    pos += 1;
3091                                }
3092                                buf.set_len(pos);
3093                            }
3094                            let packed_len = buf.len() - data_start;
3095                            if packed_len < 0x80 {
3096                                buf[len_pos] = packed_len as u8;
3097                            } else {
3098                                let data = buf[data_start..].to_vec();
3099                                buf.truncate(len_pos);
3100                                ::typeway_protobuf::tw_encode_varint(buf, packed_len as u64);
3101                                buf.extend_from_slice(&data);
3102                            }
3103                        }
3104                    }
3105                } else {
3106                    // Fixed-size types: length is known without iterating.
3107                    let packed_len_expr = match inner.as_ref() {
3108                        CodecKind::Fixed32 => quote! { self.#ident.len() * 4 },
3109                        CodecKind::Fixed64 => quote! { self.#ident.len() * 8 },
3110                        CodecKind::Bool => quote! { self.#ident.len() },
3111                        _ => unreachable!(),
3112                    };
3113                    // For Fixed32/Fixed64: bulk memcpy instead of per-element writes.
3114                    let bulk_write = match inner.as_ref() {
3115                        CodecKind::Fixed64 => quote! {
3116                            // Safety: f64 is 8 bytes, same layout as [u8; 8] on LE.
3117                            // On little-endian (most modern CPUs), IEEE 754 f64 is
3118                            // already in protobuf wire order.
3119                            #[cfg(target_endian = "little")]
3120                            {
3121                                let slice_bytes = unsafe {
3122                                    ::core::slice::from_raw_parts(
3123                                        self.#ident.as_ptr() as *const u8,
3124                                        self.#ident.len() * 8,
3125                                    )
3126                                };
3127                                buf.extend_from_slice(slice_bytes);
3128                            }
3129                            #[cfg(not(target_endian = "little"))]
3130                            {
3131                                for item in &self.#ident {
3132                                    buf.extend_from_slice(&item.to_le_bytes());
3133                                }
3134                            }
3135                        },
3136                        CodecKind::Fixed32 => quote! {
3137                            #[cfg(target_endian = "little")]
3138                            {
3139                                let slice_bytes = unsafe {
3140                                    ::core::slice::from_raw_parts(
3141                                        self.#ident.as_ptr() as *const u8,
3142                                        self.#ident.len() * 4,
3143                                    )
3144                                };
3145                                buf.extend_from_slice(slice_bytes);
3146                            }
3147                            #[cfg(not(target_endian = "little"))]
3148                            {
3149                                for item in &self.#ident {
3150                                    buf.extend_from_slice(&item.to_le_bytes());
3151                                }
3152                            }
3153                        },
3154                        _ => quote! {
3155                            for item in &self.#ident {
3156                                #item_write
3157                            }
3158                        },
3159                    };
3160                    quote! {
3161                        if !self.#ident.is_empty() {
3162                            let packed_len = #packed_len_expr;
3163                            #packed_tag_push
3164                            ::typeway_protobuf::tw_encode_varint(buf, packed_len as u64);
3165                            #bulk_write
3166                        }
3167                    }
3168                }
3169            } else {
3170                // Non-packable (strings, messages): per-element tag.
3171                let item_encode = gen_encode_repeated_item(tag, wt, inner);
3172                quote! {
3173                    for item in &self.#ident {
3174                        #item_encode
3175                    }
3176                }
3177            }
3178        }
3179        CodecKind::RepeatedMessage => quote! {
3180            for item in &self.#ident {
3181                let nested = ::typeway_protobuf::TypewayEncode::encode_to_vec(item);
3182                #tag_push
3183                ::typeway_protobuf::tw_encode_varint(buf, nested.len() as u64);
3184                buf.extend_from_slice(&nested);
3185            }
3186        },
3187    }
3188}
3189
3190/// Returns true if the inner type can use packed encoding (scalars only).
3191fn is_packable(kind: &CodecKind) -> bool {
3192    matches!(
3193        kind,
3194        CodecKind::Varint | CodecKind::Bool | CodecKind::Fixed32 | CodecKind::Fixed64
3195    )
3196}
3197
3198/// Generate the per-item write for packed encoding (no tag per item).
3199/// For varints, uses the unchecked variant (caller pre-reserves capacity).
3200fn gen_packed_item_write(kind: &CodecKind) -> TokenStream2 {
3201    match kind {
3202        CodecKind::Varint => quote! {
3203            unsafe { ::typeway_protobuf::tw_encode_varint_unchecked(buf, *item as u64); }
3204        },
3205        CodecKind::Bool => quote! {
3206            buf.push(if *item { 1 } else { 0 });
3207        },
3208        CodecKind::Fixed32 => quote! {
3209            buf.extend_from_slice(&item.to_le_bytes());
3210        },
3211        CodecKind::Fixed64 => quote! {
3212            buf.extend_from_slice(&item.to_le_bytes());
3213        },
3214        _ => quote! {},
3215    }
3216}
3217
3218/// Generate the per-item length for packed encoding.
3219fn gen_packed_item_len(kind: &CodecKind) -> TokenStream2 {
3220    match kind {
3221        CodecKind::Varint => quote! {
3222            ::typeway_protobuf::tw_varint_len(*item as u64)
3223        },
3224        CodecKind::Bool => quote! { 1 },
3225        CodecKind::Fixed32 => quote! { 4 },
3226        CodecKind::Fixed64 => quote! { 8 },
3227        _ => quote! { 0 },
3228    }
3229}
3230
3231/// Generate per-item read for packed decoding.
3232fn gen_packed_item_read(ident: &Ident, kind: &CodecKind) -> TokenStream2 {
3233    match kind {
3234        CodecKind::Varint => quote! {
3235            let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3236            offset += consumed;
3237            #ident.push(val as _);
3238        },
3239        CodecKind::Bool => quote! {
3240            #ident.push(bytes[offset] != 0);
3241            offset += 1;
3242        },
3243        CodecKind::Fixed32 => quote! {
3244            if offset + 4 > bytes.len() {
3245                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3246            }
3247            #ident.push(f32::from_le_bytes([
3248                bytes[offset], bytes[offset + 1],
3249                bytes[offset + 2], bytes[offset + 3],
3250            ]));
3251            offset += 4;
3252        },
3253        CodecKind::Fixed64 => quote! {
3254            if offset + 8 > bytes.len() {
3255                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3256            }
3257            #ident.push(f64::from_le_bytes([
3258                bytes[offset], bytes[offset + 1],
3259                bytes[offset + 2], bytes[offset + 3],
3260                bytes[offset + 4], bytes[offset + 5],
3261                bytes[offset + 6], bytes[offset + 7],
3262            ]));
3263            offset += 8;
3264        },
3265        _ => quote! {},
3266    }
3267}
3268
3269fn gen_encode_optional_inner(tag: u32, wt: u8, kind: &CodecKind) -> TokenStream2 {
3270    let tp = emit_tag_push(tag, wt);
3271    match kind {
3272        CodecKind::Varint => quote! {
3273            #tp
3274            ::typeway_protobuf::tw_encode_varint(buf, *val as u64);
3275        },
3276        CodecKind::Bool => quote! {
3277            #tp
3278            buf.push(if *val { 1 } else { 0 });
3279        },
3280        CodecKind::Fixed32 => quote! {
3281            #tp
3282            buf.extend_from_slice(&val.to_le_bytes());
3283        },
3284        CodecKind::Fixed64 => quote! {
3285            #tp
3286            buf.extend_from_slice(&val.to_le_bytes());
3287        },
3288        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3289            #tp
3290            ::typeway_protobuf::tw_encode_varint(buf, val.len() as u64);
3291            buf.extend_from_slice(val.as_bytes());
3292        },
3293        CodecKind::LenBytes => quote! {
3294            #tp
3295            ::typeway_protobuf::tw_encode_varint(buf, val.len() as u64);
3296            buf.extend_from_slice(val);
3297        },
3298        _ => quote! {},
3299    }
3300}
3301
3302fn gen_encode_repeated_item(tag: u32, wt: u8, kind: &CodecKind) -> TokenStream2 {
3303    let tp = emit_tag_push(tag, wt);
3304    match kind {
3305        CodecKind::Varint => quote! {
3306            #tp
3307            ::typeway_protobuf::tw_encode_varint(buf, *item as u64);
3308        },
3309        CodecKind::Fixed32 => quote! {
3310            #tp
3311            buf.extend_from_slice(&item.to_le_bytes());
3312        },
3313        CodecKind::Fixed64 => quote! {
3314            #tp
3315            buf.extend_from_slice(&item.to_le_bytes());
3316        },
3317        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3318            #tp
3319            ::typeway_protobuf::tw_encode_varint(buf, item.len() as u64);
3320            buf.extend_from_slice(item.as_bytes());
3321        },
3322        _ => quote! {},
3323    }
3324}
3325
3326fn gen_encoded_len_field(f: &CodecField) -> TokenStream2 {
3327    let ident = &f.ident;
3328    let tag = f.tag;
3329    // Precompute tag length at macro expansion time.
3330    let wt = wire_type_for_kind(&f.codec_kind);
3331    let tag_byte_count = if tag < 16 {
3332        1usize
3333    } else if tag < 2048 {
3334        2
3335    } else {
3336        3
3337    };
3338    let tag_len_expr = quote! { #tag_byte_count };
3339    let _ = wt; // used in computation above conceptually
3340
3341    match &f.codec_kind {
3342        CodecKind::Varint => quote! {
3343            if self.#ident != 0 {
3344                len += #tag_len_expr + ::typeway_protobuf::tw_varint_len(self.#ident as u64);
3345            }
3346        },
3347        CodecKind::Bool => quote! {
3348            if self.#ident {
3349                len += #tag_len_expr + 1;
3350            }
3351        },
3352        CodecKind::Fixed32 => quote! {
3353            if self.#ident != 0.0 {
3354                len += #tag_len_expr + 4;
3355            }
3356        },
3357        CodecKind::Fixed64 => quote! {
3358            if self.#ident != 0.0 {
3359                len += #tag_len_expr + 8;
3360            }
3361        },
3362        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3363            if !self.#ident.is_empty() {
3364                len += #tag_len_expr
3365                    + ::typeway_protobuf::tw_varint_len(self.#ident.len() as u64)
3366                    + self.#ident.len();
3367            }
3368        },
3369        CodecKind::LenBytes => quote! {
3370            if !self.#ident.is_empty() {
3371                len += #tag_len_expr
3372                    + ::typeway_protobuf::tw_varint_len(self.#ident.len() as u64)
3373                    + self.#ident.len();
3374            }
3375        },
3376        CodecKind::Message => quote! {
3377            {
3378                let nested_len = ::typeway_protobuf::TypewayEncode::encoded_len(&self.#ident);
3379                if nested_len > 0 {
3380                    len += #tag_len_expr
3381                        + ::typeway_protobuf::tw_varint_len(nested_len as u64)
3382                        + nested_len;
3383                }
3384            }
3385        },
3386        CodecKind::Optional(inner) => {
3387            let inner_len = gen_encoded_len_optional_inner(tag, inner);
3388            quote! {
3389                if let Some(ref val) = self.#ident {
3390                    #inner_len
3391                }
3392            }
3393        }
3394        CodecKind::OptionalMessage => quote! {
3395            if let Some(ref val) = self.#ident {
3396                let nested_len = ::typeway_protobuf::TypewayEncode::encoded_len(val);
3397                len += #tag_len_expr
3398                    + ::typeway_protobuf::tw_varint_len(nested_len as u64)
3399                    + nested_len;
3400            }
3401        },
3402        CodecKind::Repeated(inner) => {
3403            if is_packable(inner) {
3404                let item_len = gen_packed_item_len(inner);
3405                quote! {
3406                    if !self.#ident.is_empty() {
3407                        let mut packed_len: usize = 0;
3408                        for item in &self.#ident {
3409                            packed_len += #item_len;
3410                        }
3411                        // tag + length varint + packed data
3412                        len += #tag_len_expr
3413                            + ::typeway_protobuf::tw_varint_len(packed_len as u64)
3414                            + packed_len;
3415                    }
3416                }
3417            } else {
3418                let item_len = gen_encoded_len_repeated_item(tag, inner);
3419                quote! {
3420                    for item in &self.#ident {
3421                        #item_len
3422                    }
3423                }
3424            }
3425        }
3426        CodecKind::RepeatedMessage => quote! {
3427            for item in &self.#ident {
3428                let nested_len = ::typeway_protobuf::TypewayEncode::encoded_len(item);
3429                len += #tag_len_expr
3430                    + ::typeway_protobuf::tw_varint_len(nested_len as u64)
3431                    + nested_len;
3432            }
3433        },
3434    }
3435}
3436
3437fn gen_encoded_len_optional_inner(tag: u32, kind: &CodecKind) -> TokenStream2 {
3438    let tl = if tag < 16 {
3439        1usize
3440    } else if tag < 2048 {
3441        2
3442    } else {
3443        3
3444    };
3445    let tag_len_expr = quote! { #tl };
3446    match kind {
3447        CodecKind::Varint => quote! {
3448            len += #tag_len_expr + ::typeway_protobuf::tw_varint_len(*val as u64);
3449        },
3450        CodecKind::Bool => quote! { len += #tag_len_expr + 1; },
3451        CodecKind::Fixed32 => quote! { len += #tag_len_expr + 4; },
3452        CodecKind::Fixed64 => quote! { len += #tag_len_expr + 8; },
3453        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3454            len += #tag_len_expr
3455                + ::typeway_protobuf::tw_varint_len(val.len() as u64)
3456                + val.len();
3457        },
3458        _ => quote! {},
3459    }
3460}
3461
3462fn gen_encoded_len_repeated_item(tag: u32, kind: &CodecKind) -> TokenStream2 {
3463    let tl = if tag < 16 {
3464        1usize
3465    } else if tag < 2048 {
3466        2
3467    } else {
3468        3
3469    };
3470    let tag_len_expr = quote! { #tl };
3471    match kind {
3472        CodecKind::Varint => quote! {
3473            len += #tag_len_expr + ::typeway_protobuf::tw_varint_len(*item as u64);
3474        },
3475        CodecKind::Fixed32 => quote! { len += #tag_len_expr + 4; },
3476        CodecKind::Fixed64 => quote! { len += #tag_len_expr + 8; },
3477        CodecKind::LenString | CodecKind::LenBytesStr => quote! {
3478            len += #tag_len_expr
3479                + ::typeway_protobuf::tw_varint_len(item.len() as u64)
3480                + item.len();
3481        },
3482        _ => quote! {},
3483    }
3484}
3485
3486/// Generate a decode arm for `typeway_decode_bytes` — uses `Bytes::slice()`
3487/// for `BytesStr` fields (zero-copy), delegates to `gen_decode_arm` for others.
3488fn gen_decode_bytes_arm(f: &CodecField) -> TokenStream2 {
3489    let ident = &f.ident;
3490    let tag = f.tag;
3491    let ident_str = ident.to_string();
3492
3493    // For BytesStr fields, use Bytes::slice() — zero-copy.
3494    if matches!(&f.codec_kind, CodecKind::LenBytesStr) {
3495        return quote! {
3496            #tag => {
3497                let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3498                offset += consumed;
3499                let str_len = str_len as usize;
3500                if offset + str_len > bytes.len() {
3501                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3502                }
3503                // Zero-copy: validate UTF-8, then slice the Bytes (refcount increment, no copy).
3504                ::core::str::from_utf8(&bytes[offset..offset + str_len])
3505                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3506                #ident = unsafe {
3507                    ::typeway_protobuf::BytesStr::from_utf8_unchecked(
3508                        input.slice(offset..offset + str_len)
3509                    )
3510                };
3511                offset += str_len;
3512            }
3513        };
3514    }
3515
3516    // For all other field types, use the same logic as typeway_decode.
3517    gen_decode_arm(f)
3518}
3519
3520fn gen_decode_arm(f: &CodecField) -> TokenStream2 {
3521    let ident = &f.ident;
3522    let tag = f.tag;
3523    let ident_str = ident.to_string();
3524
3525    match &f.codec_kind {
3526        CodecKind::Varint => quote! {
3527            #tag => {
3528                let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3529                offset += consumed;
3530                #ident = val as _;
3531            }
3532        },
3533        CodecKind::Bool => quote! {
3534            #tag => {
3535                let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3536                offset += consumed;
3537                #ident = val != 0;
3538            }
3539        },
3540        CodecKind::Fixed32 => quote! {
3541            #tag => {
3542                if offset + 4 > bytes.len() {
3543                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3544                }
3545                #ident = f32::from_le_bytes([
3546                    bytes[offset], bytes[offset + 1],
3547                    bytes[offset + 2], bytes[offset + 3],
3548                ]);
3549                offset += 4;
3550            }
3551        },
3552        CodecKind::Fixed64 => quote! {
3553            #tag => {
3554                if offset + 8 > bytes.len() {
3555                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3556                }
3557                #ident = f64::from_le_bytes([
3558                    bytes[offset], bytes[offset + 1],
3559                    bytes[offset + 2], bytes[offset + 3],
3560                    bytes[offset + 4], bytes[offset + 5],
3561                    bytes[offset + 6], bytes[offset + 7],
3562                ]);
3563                offset += 8;
3564            }
3565        },
3566        CodecKind::LenString => quote! {
3567            #tag => {
3568                let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3569                offset += consumed;
3570                let str_len = str_len as usize;
3571                if offset + str_len > bytes.len() {
3572                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3573                }
3574                let slice = &bytes[offset..offset + str_len];
3575                ::core::str::from_utf8(slice)
3576                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3577                #ident = unsafe { String::from_utf8_unchecked(slice.to_vec()) };
3578                offset += str_len;
3579            }
3580        },
3581        CodecKind::LenBytesStr => quote! {
3582            #tag => {
3583                let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3584                offset += consumed;
3585                let str_len = str_len as usize;
3586                if offset + str_len > bytes.len() {
3587                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3588                }
3589                let slice = &bytes[offset..offset + str_len];
3590                ::core::str::from_utf8(slice)
3591                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3592                #ident = ::typeway_protobuf::BytesStr::from(
3593                    unsafe { String::from_utf8_unchecked(slice.to_vec()) }
3594                );
3595                offset += str_len;
3596            }
3597        },
3598        CodecKind::LenBytes => quote! {
3599            #tag => {
3600                let (byte_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3601                offset += consumed;
3602                let byte_len = byte_len as usize;
3603                if offset + byte_len > bytes.len() {
3604                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3605                }
3606                #ident = bytes[offset..offset + byte_len].to_vec();
3607                offset += byte_len;
3608            }
3609        },
3610        CodecKind::Message => quote! {
3611            #tag => {
3612                let (msg_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3613                offset += consumed;
3614                let msg_len = msg_len as usize;
3615                if offset + msg_len > bytes.len() {
3616                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3617                }
3618                #ident = ::typeway_protobuf::TypewayDecode::typeway_decode(
3619                    &bytes[offset..offset + msg_len]
3620                )?;
3621                offset += msg_len;
3622            }
3623        },
3624        CodecKind::Optional(inner) => {
3625            let inner_decode = gen_decode_optional_inner(ident, &ident_str, inner);
3626            quote! { #tag => { #inner_decode } }
3627        }
3628        CodecKind::OptionalMessage => quote! {
3629            #tag => {
3630                let (msg_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3631                offset += consumed;
3632                let msg_len = msg_len as usize;
3633                if offset + msg_len > bytes.len() {
3634                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3635                }
3636                #ident = Some(::typeway_protobuf::TypewayDecode::typeway_decode(
3637                    &bytes[offset..offset + msg_len]
3638                )?);
3639                offset += msg_len;
3640            }
3641        },
3642        CodecKind::Repeated(inner) => {
3643            if is_packable(inner) {
3644                let is_varint = matches!(inner.as_ref(), CodecKind::Varint | CodecKind::Bool);
3645                let is_bool = matches!(inner.as_ref(), CodecKind::Bool);
3646                if is_varint {
3647                    // Optimized packed varint decode: inline 1-byte fast path,
3648                    // pre-reserve Vec capacity.
3649                    // For bool: use `!= 0` conversion. For integers: use `as _`.
3650                    let push_packed_fast = if is_bool {
3651                        quote! { #ident.push(b != 0); }
3652                    } else {
3653                        quote! { #ident.push(b as _); }
3654                    };
3655                    let push_packed_slow = if is_bool {
3656                        quote! { #ident.push(val != 0); }
3657                    } else {
3658                        quote! { #ident.push(val as _); }
3659                    };
3660                    let push_unpacked = push_packed_slow.clone();
3661                    quote! {
3662                        #tag => {
3663                            if wire_type == 2 {
3664                                let (packed_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3665                                offset += consumed;
3666                                let packed_len = packed_len as usize;
3667                                let packed_end = offset + packed_len;
3668                                if packed_end > bytes.len() {
3669                                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3670                                }
3671                                // Reserve worst case: at least 1 element per byte.
3672                                #ident.reserve(packed_len);
3673                                while offset < packed_end {
3674                                    // Inline 1-byte fast path (most common for small values).
3675                                    let b = bytes[offset];
3676                                    if b < 0x80 {
3677                                        #push_packed_fast
3678                                        offset += 1;
3679                                    } else {
3680                                        let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3681                                        offset += consumed;
3682                                        #push_packed_slow
3683                                    }
3684                                }
3685                            } else {
3686                                let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3687                                offset += consumed;
3688                                #push_unpacked
3689                            }
3690                        }
3691                    }
3692                } else {
3693                    let item_read = gen_packed_item_read(ident, inner);
3694                    quote! {
3695                        #tag => {
3696                            if wire_type == 2 {
3697                                let (packed_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3698                                offset += consumed;
3699                                let packed_len = packed_len as usize;
3700                                let packed_end = offset + packed_len;
3701                                if packed_end > bytes.len() {
3702                                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3703                                }
3704                                while offset < packed_end {
3705                                    #item_read
3706                                }
3707                            } else {
3708                                let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3709                                offset += consumed;
3710                                #ident.push(val as _);
3711                            }
3712                        }
3713                    }
3714                }
3715            } else {
3716                let item_decode = gen_decode_repeated_item(ident, &ident_str, inner);
3717                quote! { #tag => { #item_decode } }
3718            }
3719        }
3720        CodecKind::RepeatedMessage => quote! {
3721            #tag => {
3722                let (msg_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3723                offset += consumed;
3724                let msg_len = msg_len as usize;
3725                if offset + msg_len > bytes.len() {
3726                    return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3727                }
3728                #ident.push(::typeway_protobuf::TypewayDecode::typeway_decode(
3729                    &bytes[offset..offset + msg_len]
3730                )?);
3731                offset += msg_len;
3732            }
3733        },
3734    }
3735}
3736
3737fn gen_decode_optional_inner(ident: &Ident, ident_str: &str, kind: &CodecKind) -> TokenStream2 {
3738    match kind {
3739        CodecKind::Varint => quote! {
3740            let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3741            offset += consumed;
3742            #ident = Some(val as _);
3743        },
3744        CodecKind::Bool => quote! {
3745            let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3746            offset += consumed;
3747            #ident = Some(val != 0);
3748        },
3749        CodecKind::Fixed32 => quote! {
3750            if offset + 4 > bytes.len() {
3751                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3752            }
3753            #ident = Some(f32::from_le_bytes([
3754                bytes[offset], bytes[offset + 1],
3755                bytes[offset + 2], bytes[offset + 3],
3756            ]));
3757            offset += 4;
3758        },
3759        CodecKind::Fixed64 => quote! {
3760            if offset + 8 > bytes.len() {
3761                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3762            }
3763            #ident = Some(f64::from_le_bytes([
3764                bytes[offset], bytes[offset + 1],
3765                bytes[offset + 2], bytes[offset + 3],
3766                bytes[offset + 4], bytes[offset + 5],
3767                bytes[offset + 6], bytes[offset + 7],
3768            ]));
3769            offset += 8;
3770        },
3771        CodecKind::LenString => quote! {
3772            let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3773            offset += consumed;
3774            let str_len = str_len as usize;
3775            if offset + str_len > bytes.len() {
3776                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3777            }
3778            {
3779                let slice = &bytes[offset..offset + str_len];
3780                ::core::str::from_utf8(slice)
3781                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3782                #ident = Some(unsafe { String::from_utf8_unchecked(slice.to_vec()) });
3783            }
3784            offset += str_len;
3785        },
3786        CodecKind::LenBytesStr => quote! {
3787            let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3788            offset += consumed;
3789            let str_len = str_len as usize;
3790            if offset + str_len > bytes.len() {
3791                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3792            }
3793            {
3794                let slice = &bytes[offset..offset + str_len];
3795                ::core::str::from_utf8(slice)
3796                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3797                #ident = Some(::typeway_protobuf::BytesStr::from(
3798                    unsafe { String::from_utf8_unchecked(slice.to_vec()) }
3799                ));
3800            }
3801            offset += str_len;
3802        },
3803        _ => quote! {
3804            let skipped = ::typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
3805            offset += skipped;
3806        },
3807    }
3808}
3809
3810fn gen_decode_repeated_item(ident: &Ident, ident_str: &str, kind: &CodecKind) -> TokenStream2 {
3811    match kind {
3812        CodecKind::Varint => quote! {
3813            let (val, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3814            offset += consumed;
3815            #ident.push(val as _);
3816        },
3817        CodecKind::Fixed32 => quote! {
3818            if offset + 4 > bytes.len() {
3819                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3820            }
3821            #ident.push(f32::from_le_bytes([
3822                bytes[offset], bytes[offset + 1],
3823                bytes[offset + 2], bytes[offset + 3],
3824            ]));
3825            offset += 4;
3826        },
3827        CodecKind::Fixed64 => quote! {
3828            if offset + 8 > bytes.len() {
3829                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3830            }
3831            #ident.push(f64::from_le_bytes([
3832                bytes[offset], bytes[offset + 1],
3833                bytes[offset + 2], bytes[offset + 3],
3834                bytes[offset + 4], bytes[offset + 5],
3835                bytes[offset + 6], bytes[offset + 7],
3836            ]));
3837            offset += 8;
3838        },
3839        CodecKind::LenString => quote! {
3840            let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3841            offset += consumed;
3842            let str_len = str_len as usize;
3843            if offset + str_len > bytes.len() {
3844                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3845            }
3846            {
3847                let slice = &bytes[offset..offset + str_len];
3848                ::core::str::from_utf8(slice)
3849                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3850                #ident.push(unsafe { String::from_utf8_unchecked(slice.to_vec()) });
3851            }
3852            offset += str_len;
3853        },
3854        CodecKind::LenBytesStr => quote! {
3855            let (str_len, consumed) = ::typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
3856            offset += consumed;
3857            let str_len = str_len as usize;
3858            if offset + str_len > bytes.len() {
3859                return Err(::typeway_protobuf::TypewayDecodeError::UnexpectedEof);
3860            }
3861            {
3862                let slice = &bytes[offset..offset + str_len];
3863                ::core::str::from_utf8(slice)
3864                    .map_err(|_| ::typeway_protobuf::TypewayDecodeError::InvalidUtf8(#ident_str))?;
3865                #ident.push(::typeway_protobuf::BytesStr::from(
3866                    unsafe { String::from_utf8_unchecked(slice.to_vec()) }
3867                ));
3868            }
3869            offset += str_len;
3870        },
3871        _ => quote! {
3872            let skipped = ::typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
3873            offset += skipped;
3874        },
3875    }
3876}