1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::Parse, parse::ParseStream, parse_macro_input, FnArg, ItemFn, Pat, ReturnType, Type,
5};
6
/// Parsed arguments of `#[yewserverhook(...)]`.
struct MacroArgs {
    /// Route path, taken verbatim from the `path = "..."` argument.
    path: String,
    /// HTTP method name, already upper-cased by the parser.
    method: String,
}
12
13impl Parse for MacroArgs {
14 fn parse(input: ParseStream) -> syn::Result<Self> {
15 let mut path = None;
16 let mut method = None;
17
18 loop {
20 if input.is_empty() {
21 break;
22 }
23
24 let ident: syn::Ident = input.parse()?;
26 input.parse::<syn::Token![=]>()?;
27
28 if ident == "path" {
29 let path_lit: syn::LitStr = input.parse()?;
30 path = Some(path_lit.value());
31 } else if ident == "method" {
32 let method_lit: syn::LitStr = input.parse()?;
33 let method_value = method_lit.value().to_uppercase();
34
35 if !["GET", "POST", "PUT", "DELETE", "PATCH"].contains(&method_value.as_str()) {
37 return Err(syn::Error::new(
38 method_lit.span(),
39 "Invalid HTTP method. Must be one of: GET, POST, PUT, DELETE, PATCH",
40 ));
41 }
42 method = Some(method_value);
43 } else {
44 return Err(syn::Error::new(
45 ident.span(),
46 format!("Unknown argument '{}'. Expected 'path' or 'method'", ident),
47 ));
48 }
49
50 if input.peek(syn::Token![,]) {
52 input.parse::<syn::Token![,]>()?;
53 } else {
54 break;
55 }
56 }
57
58 let path =
60 path.ok_or_else(|| syn::Error::new(input.span(), "Missing required argument 'path'"))?;
61
62 let method = method.unwrap_or_else(|| "POST".to_string());
64
65 Ok(MacroArgs { path, method })
66 }
67}
68
69#[proc_macro_attribute]
76pub fn yewserverhook(args: TokenStream, input: TokenStream) -> TokenStream {
77 let input = parse_macro_input!(input as ItemFn);
78
79 let args = parse_macro_input!(args as MacroArgs);
81 let path = args.path;
82 let method = args.method;
83
84 let fn_name = &input.sig.ident;
86 let fn_vis = &input.vis;
87 let fn_block = &input.block;
88 let fn_inputs = &input.sig.inputs;
89 let fn_output = &input.sig.output;
90
91 let has_params = !fn_inputs.is_empty();
93
94 let (return_type, error_type) = extract_return_type(fn_output);
96 let error_type = error_type.unwrap_or_else(|| quote! { () });
97
98 let hook_name = format!("use_{}", fn_name.to_string());
100 let hook_ident = syn::Ident::new(&hook_name, fn_name.span());
101
102 let param_struct = if has_params {
104 generate_param_struct(fn_name, fn_inputs)
105 } else {
106 quote! {}
107 };
108
109 let server_handler = generate_server_handler(
111 fn_name,
112 fn_vis,
113 fn_block,
114 fn_inputs,
115 fn_output,
116 has_params,
117 &return_type,
118 &error_type,
119 &path,
120 &method,
121 );
122
123 let client_hook = generate_client_hook(
125 &hook_ident,
126 fn_vis,
127 &path,
128 &return_type,
129 has_params,
130 fn_name,
131 fn_inputs,
132 &method,
133 );
134
135 let client_function = generate_client_function(
137 fn_name,
138 fn_vis,
139 &path,
140 &return_type,
141 has_params,
142 fn_inputs,
143 &method,
144 );
145
146 let hook_wrapper = quote! {};
148
149 let expanded = quote! {
150
151 #[cfg(feature = "ssr")]
152 #input
153
154 #param_struct
155
156 #server_handler
157
158 #client_hook
159
160 #[cfg(not(feature = "ssr"))]
161 #client_function
162
163 #hook_wrapper
164 };
165
166 TokenStream::from(expanded)
167}
168
169fn extract_return_type(
170 output: &ReturnType,
171) -> (proc_macro2::TokenStream, Option<proc_macro2::TokenStream>) {
172 match output {
173 ReturnType::Default => (quote! { () }, None),
174 ReturnType::Type(_, ty) => {
175 if let Type::Path(type_path) = &**ty {
177 if let Some(segment) = type_path.path.segments.last() {
178 if segment.ident == "Result" {
179 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
180 if let (
181 Some(syn::GenericArgument::Type(ok_type)),
182 Some(syn::GenericArgument::Type(err_type)),
183 ) = (args.args.first(), args.args.iter().nth(1))
184 {
185 return (quote! { #ok_type }, Some(quote! { #err_type }));
186 }
187 }
188 }
189 }
190 }
191 (quote! { #ty }, None)
192 }
193 }
194}
195
196fn generate_param_struct(
197 fn_name: &syn::Ident,
198 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
199) -> proc_macro2::TokenStream {
200 let struct_name = syn::Ident::new(
201 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
202 fn_name.span(),
203 );
204
205 let mut fields = Vec::new();
206
207 for input in inputs {
208 if let FnArg::Typed(pat_type) = input {
209 if let Pat::Ident(pat_ident) = &*pat_type.pat {
210 let field_name = &pat_ident.ident;
211 let field_type = &pat_type.ty;
212 fields.push(quote! {
213 pub #field_name: #field_type
214 });
215 }
216 }
217 }
218
219 quote! {
220 #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
221 pub struct #struct_name {
222 #(#fields),*
223 }
224 }
225}
226
227fn generate_server_handler(
228 fn_name: &syn::Ident,
229 vis: &syn::Visibility,
230 block: &syn::Block,
231 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
232 _output: &ReturnType,
233 has_params: bool,
234 return_type: &proc_macro2::TokenStream,
235 error_type: &proc_macro2::TokenStream,
236 path: &str,
237 method: &str,
238) -> proc_macro2::TokenStream {
239 let fn_handler_name =
240 syn::Ident::new(&format!("{}_handler", fn_name.to_string()), fn_name.span());
241
242 let params_arg = if has_params {
243 let struct_name = syn::Ident::new(
244 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
245 fn_name.span(),
246 );
247 if method == "GET" {
249 quote! { axum::extract::Query(params): axum::extract::Query<#struct_name>, }
250 } else {
251 quote! { axum::Json(params): axum::Json<#struct_name>, }
252 }
253 } else {
254 quote! {}
255 };
256
257 let param_extraction = if has_params {
258 let mut field_names = Vec::new();
259 for input in inputs {
260 if let FnArg::Typed(pat_type) = input {
261 if let Pat::Ident(pat_ident) = &*pat_type.pat {
262 field_names.push(&pat_ident.ident);
263 }
264 }
265 }
266 let struct_name = syn::Ident::new(
267 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
268 fn_name.span(),
269 );
270 quote! {
271 let #struct_name { #(#field_names),* } = params;
272 }
273 } else {
274 quote! {}
275 };
276
277 let original_stmts = &block.stmts;
279 let modified_block = quote! {
280 {
281 #param_extraction
282
283 let result: Result<#return_type, #error_type> = async {
285 #(#original_stmts)*
286 }.await;
287
288 result.map(axum::Json)
290 }
291 };
292
293 let wrapper_fn_name = syn::Ident::new(
295 &format!("{}_wrapper", fn_handler_name),
296 fn_handler_name.span(),
297 );
298
299 let extract_and_call = if has_params {
301 let struct_name = syn::Ident::new(
302 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
303 fn_name.span(),
304 );
305
306 if method == "GET" {
307 quote! {
309 use ::axum::extract::FromRequestParts;
310
311 let (mut parts, _body) = req.into_parts();
312
313 ::yew_extra::provide_request_parts(parts.clone()).await;
315
316 let result = match ::axum::extract::Query::<#struct_name>::from_request_parts(&mut parts, &()).await {
317 Ok(::axum::extract::Query(params)) => {
318 let response = #fn_handler_name(::axum::extract::Query(params)).await;
319 response.into_response()
320 },
321 Err(e) => {
322 ::axum::http::Response::builder()
323 .status(::axum::http::StatusCode::BAD_REQUEST)
324 .body(::axum::body::Body::from(format!("Invalid query parameters: {}", e)))
325 .unwrap()
326 }
327 };
328
329 ::yew_extra::clear_request_parts().await;
331 result
332 }
333 } else {
334 quote! {
336 use ::axum::extract::FromRequest;
337
338 let (parts, body) = req.into_parts();
339
340 ::yew_extra::provide_request_parts(parts.clone()).await;
342
343 let req = ::axum::http::Request::from_parts(parts, body);
344
345 let result = match ::axum::Json::<#struct_name>::from_request(req, &()).await {
346 Ok(params) => {
347 let response = #fn_handler_name(params).await;
348 response.into_response()
349 },
350 Err(e) => {
351 ::axum::http::Response::builder()
352 .status(::axum::http::StatusCode::BAD_REQUEST)
353 .body(::axum::body::Body::from(format!("Invalid request: {}", e)))
354 .unwrap()
355 }
356 };
357
358 ::yew_extra::clear_request_parts().await;
360 result
361 }
362 }
363 } else {
364 quote! {
365 let (parts, _body) = req.into_parts();
367
368 ::yew_extra::provide_request_parts(parts).await;
370
371 let response = #fn_handler_name().await;
372
373 ::yew_extra::clear_request_parts().await;
375
376 response.into_response()
377 }
378 };
379
380 let method_ident = syn::Ident::new(&method, proc_macro2::Span::call_site());
382
383 let inventory_submission = quote! {
387 #[cfg(all(feature = "ssr", not(test)))]
389 fn #wrapper_fn_name(
390 req: ::axum::http::Request<::axum::body::Body>
391 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::axum::http::Response<::axum::body::Body>> + Send>> {
392 Box::pin(async move {
393 use ::axum::response::IntoResponse;
394 #extract_and_call
395 })
396 }
397
398 #[cfg(all(feature = "ssr", not(test)))]
399 ::inventory::submit! {
400 crate::route_registry::RouteInfo::new(
401 #path,
402 ::axum::http::Method::#method_ident,
403 #wrapper_fn_name
404 )
405 }
406 };
407
408 quote! {
409 #[cfg(feature = "ssr")]
410 #vis async fn #fn_handler_name(
411 #params_arg
412 ) -> Result<axum::Json<#return_type>, #error_type> #modified_block
414
415 #inventory_submission
416 }
417}
418
419fn generate_client_function(
420 fn_name: &syn::Ident,
421 vis: &syn::Visibility,
422 path: &str,
423 return_type: &proc_macro2::TokenStream,
424 has_params: bool,
425 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
426 method: &str,
427) -> proc_macro2::TokenStream {
428 let host_url = quote! { "" };
430
431 let func_params = if has_params {
433 let mut params = Vec::new();
434 for input in inputs {
435 if let FnArg::Typed(pat_type) = input {
436 if let Pat::Ident(pat_ident) = &*pat_type.pat {
437 let param_name = &pat_ident.ident;
438 let param_type = &pat_type.ty;
439 params.push(quote! { #param_name: #param_type });
440 }
441 }
442 }
443 quote! { #(#params),* }
444 } else {
445 quote! {}
446 };
447
448 let method_lower = method.to_lowercase();
450 let method_fn = syn::Ident::new(&method_lower, proc_macro2::Span::call_site());
451
452 let request_body = if has_params && method != "GET" {
454 let struct_name = syn::Ident::new(
455 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
456 fn_name.span(),
457 );
458 let mut field_names = Vec::new();
459 for input in inputs {
460 if let FnArg::Typed(pat_type) = input {
461 if let Pat::Ident(pat_ident) = &*pat_type.pat {
462 field_names.push(&pat_ident.ident);
463 }
464 }
465 }
466 quote! {
467 let params = #struct_name {
468 #(#field_names),*
469 };
470 let body = serde_json::to_string(¶ms)
471 .map_err(|e| format!("Failed to serialize parameters: {}", e))?;
472
473 let request = gloo_net::http::Request::#method_fn(&format!("{}{}", #host_url, #path))
474 .header("Content-Type", "application/json")
475 .body(body)
476 .map_err(|e| format!("Failed to create request: {}", e))?;
477 }
478 } else if has_params && method == "GET" {
479 let struct_name = syn::Ident::new(
481 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
482 fn_name.span(),
483 );
484 let mut field_names = Vec::new();
485 for input in inputs {
486 if let FnArg::Typed(pat_type) = input {
487 if let Pat::Ident(pat_ident) = &*pat_type.pat {
488 field_names.push(&pat_ident.ident);
489 }
490 }
491 }
492 quote! {
493 let params = #struct_name {
494 #(#field_names),*
495 };
496
497 let query_string = serde_urlencoded::to_string(¶ms)
499 .map_err(|e| format!("Failed to serialize query parameters: {}", e))?;
500
501 let url = format!("{}{}?{}", #host_url, #path, query_string);
502
503 let request = gloo_net::http::Request::#method_fn(&url)
504 .header("Content-Type", "application/json");
505 }
506 } else {
507 quote! {
508 let request = gloo_net::http::Request::#method_fn(&format!("{}{}", #host_url, #path))
509 .header("Content-Type", "application/json");
510 }
511 };
512
513 let async_fn_name = syn::Ident::new(&format!("{}", fn_name.to_string()), fn_name.span());
515
516 quote! {
517 #[cfg(not(feature = "ssr"))]
518 #vis async fn #async_fn_name(#func_params) -> Result<#return_type, String> {
519 #request_body
520
521 let response = request
522 .send()
523 .await
524 .map_err(|e| format!("Failed to fetch data: {}", e))?;
525
526 if response.ok() {
528 response
529 .json::<#return_type>()
530 .await
531 .map_err(|e| format!("Failed to parse response: {}", e))
532 } else {
533 let status = response.status();
535 let error_msg = match response.text().await {
536 Ok(text) => {
537 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
539 if let Some(msg) = json.get("error").and_then(|v| v.as_str()) {
540 msg.to_string()
541 } else if let Some(msg) = json.get("message").and_then(|v| v.as_str()) {
542 msg.to_string()
543 } else {
544 text
545 }
546 } else {
547 text
548 }
549 }
550 Err(_) => format!("Request failed with status {}", status)
551 };
552 Err(error_msg)
553 }
554 }
555 }
556}
557
558fn generate_client_hook(
559 hook_name: &syn::Ident,
560 vis: &syn::Visibility,
561 path: &str,
562 return_type: &proc_macro2::TokenStream,
563 has_params: bool,
564 fn_name: &syn::Ident,
565 inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
566 method: &str,
567) -> proc_macro2::TokenStream {
568 let host_url = quote! { "" };
570
571 let hook_params = if has_params {
572 let mut params = Vec::new();
573 for input in inputs {
574 if let FnArg::Typed(pat_type) = input {
575 if let Pat::Ident(pat_ident) = &*pat_type.pat {
576 let param_name = &pat_ident.ident;
577 let param_type = &pat_type.ty;
578 params.push(quote! { #param_name: #param_type });
579 }
580 }
581 }
582 quote! { #(#params),* }
583 } else {
584 quote! {}
585 };
586
587 let method_lower = method.to_lowercase();
589 let method_fn = syn::Ident::new(&method_lower, proc_macro2::Span::call_site());
590
591 let request_body = if has_params && method != "GET" {
592 let struct_name = syn::Ident::new(
593 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
594 fn_name.span(),
595 );
596 let mut field_names = Vec::new();
597 for input in inputs {
598 if let FnArg::Typed(pat_type) = input {
599 if let Pat::Ident(pat_ident) = &*pat_type.pat {
600 field_names.push(&pat_ident.ident);
601 }
602 }
603 }
604 quote! {
605 let params = #struct_name {
606 #(#field_names: #field_names.clone()),*
607 };
608 let body = serde_json::to_string(¶ms).unwrap();
609 let request = match gloo_net::http::Request::#method_fn(
610 &format!("{}{}", #host_url, #path)
611 )
612 .header("Content-Type", "application/json")
613 .body(body) {
614 Ok(req) => req,
615 Err(e) => {
616 state.set(DataState::Error(format!("Failed to create request: {}", e)));
617 return;
618 }
619 };
620 }
621 } else if has_params && method == "GET" {
622 let struct_name = syn::Ident::new(
624 &format!("{}Params", to_pascal_case(&fn_name.to_string())),
625 fn_name.span(),
626 );
627 let mut field_names = Vec::new();
628 for input in inputs {
629 if let FnArg::Typed(pat_type) = input {
630 if let Pat::Ident(pat_ident) = &*pat_type.pat {
631 field_names.push(&pat_ident.ident);
632 }
633 }
634 }
635 quote! {
636 let params = #struct_name {
637 #(#field_names: #field_names.clone()),*
638 };
639 let query_string = match serde_urlencoded::to_string(¶ms) {
640 Ok(qs) => qs,
641 Err(e) => {
642 state.set(DataState::Error(format!("Failed to serialize query parameters: {}", e)));
643 return;
644 }
645 };
646 let request = gloo_net::http::Request::#method_fn(
647 &format!("{}{}?{}", #host_url, #path, query_string)
648 )
649 .header("Content-Type", "application/json");
650 }
651 } else {
652 quote! {
653 let request = gloo_net::http::Request::#method_fn(
654 &format!("{}{}", #host_url, #path)
655 )
656 .header("Content-Type", "application/json");
657 }
658 };
659
660 let deps = if has_params {
661 let mut dep_names = Vec::new();
662 for input in inputs {
663 if let FnArg::Typed(pat_type) = input {
664 if let Pat::Ident(pat_ident) = &*pat_type.pat {
665 dep_names.push(&pat_ident.ident);
666 }
667 }
668 }
669 quote! { (#(#dep_names.clone()),*) }
670 } else {
671 quote! { () }
672 };
673
674 let is_vec = quote!(#return_type).to_string().contains("Vec");
676
677 let data_handling = if is_vec {
678 quote! {
679 if fetched_data.is_empty() {
680 state.set(DataState::Empty);
681 } else {
682 state.set(DataState::Data(fetched_data));
683 }
684 }
685 } else {
686 quote! {
687 state.set(DataState::Data(fetched_data));
688 }
689 };
690
691 quote! {
692
693 #[cfg(feature = "ssr")]
694 #[yew::hook]
695 #vis fn #hook_name(#hook_params) -> ApiHook<#return_type> {
696 let state = yew::use_state(|| DataState::<#return_type>::Loading);
697
698 let is_loading = yew::use_state(|| false);
699 let is_updating = yew::use_state(|| false);
700
701 ApiHook {
702 state: (*state).clone(),
703 is_loading: (*is_loading).clone(),
704 is_updating: (*is_updating).clone(),
705 }
706 }
707
708 #[cfg(not(feature = "ssr"))]
709 #[yew::hook]
710 #vis fn #hook_name(#hook_params) -> ApiHook<#return_type> {
711 let state = yew::use_state(|| DataState::<#return_type>::Loading);
712
713 let is_loading = yew::use_state(|| false);
714 let is_updating = yew::use_state(|| false);
715
716 {
717 let state = state.clone();
718 let is_loading = is_loading.clone();
719 let is_updating = is_updating.clone();
720
721 yew::use_effect_with(#deps, move |_| {
722 let is_first_load = matches!(*state, DataState::Loading);
724
725 if is_first_load {
727 is_loading.set(true);
728 is_updating.set(true);
729 } else {
730 is_updating.set(true);
731 }
732
733 wasm_bindgen_futures::spawn_local(async move {
734 #request_body
735
736 match request.send().await {
737 Ok(response) => {
738 if response.ok() {
740 match response.json::<#return_type>().await {
741 Ok(fetched_data) => {
742 #data_handling
743 }
744 Err(e) => {
745 state.set(DataState::Error(format!(
746 "Failed to parse response: {}",
747 e
748 )));
749 }
750 }
751 } else {
752 let status = response.status();
754 let error_msg = match response.text().await {
755 Ok(text) => {
756 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
758 if let Some(msg) = json.get("error").and_then(|v| v.as_str()) {
759 msg.to_string()
760 } else if let Some(msg) = json.get("message").and_then(|v| v.as_str()) {
761 msg.to_string()
762 } else {
763 text
764 }
765 } else {
766 text
767 }
768 }
769 Err(_) => format!("Request failed with status {}", status)
770 };
771 state.set(DataState::Error(error_msg));
772 }
773 }
774 Err(e) => {
775 state.set(DataState::Error(format!(
776 "Failed to fetch data: {}",
777 e
778 )));
779 }
780 }
781
782 is_loading.set(false);
784 is_updating.set(false);
785 });
786 || ()
787 });
788 }
789
790 ApiHook {
791 state: (*state).clone(),
792 is_loading: *is_loading,
793 is_updating: *is_updating,
794 }
795 }
796 }
797}
798
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Each `_`-separated word has its first character upper-cased (via
/// `char::to_uppercase`, which may expand to several characters) and the rest kept as-is.
/// Empty words from leading, trailing or repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}