Skip to main content

synapse_codegen/
server_gen.rs

1//! Server trait generation
2
3use crate::ServiceDef;
4use anyhow::Result;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8/// Generate server trait for a service
9///
10/// Example output:
11/// ```text
12/// #[async_trait::async_trait]
13/// pub trait UserService: Send + Sync {
14///     type Error: Into<synapse_rpc::ServiceError> + Send;
15///     async fn get_user(&self, request: GetUserRequest) -> Result<GetUserResponse, Self::Error>;
16///     async fn create_user(&self, request: CreateUserRequest) -> Result<CreateUserResponse, Self::Error>;
17/// }
18/// ```
19pub fn generate_server_trait(service: &ServiceDef) -> Result<TokenStream> {
20    let trait_name = format_ident!("{}", service.service_name);
21    let package = &service.package;
22
23    let methods: Vec<TokenStream> = service
24        .methods
25        .iter()
26        .map(|method| {
27            let method_name = format_ident!("{}", method.method_name_snake());
28            let input_type = method.input_type_path(package);
29            let output_type = method.output_type_path(package);
30
31            let comment = method.comment.as_ref().map(|c| {
32                let doc = format!(" {}", c);
33                quote! { #[doc = #doc] }
34            });
35
36            quote! {
37                #comment
38                async fn #method_name(
39                    &self,
40                    state: &synapse_sdk::RequestState<()>,
41                    request: #input_type
42                ) -> Result<#output_type, Self::Error>;
43            }
44        })
45        .collect();
46
47    Ok(quote! {
48        /// Generated server trait for #trait_name
49        #[async_trait::async_trait]
50        pub trait #trait_name: Send + Sync + 'static {
51            /// Error type for this service. Must be convertible to ServiceError.
52            type Error: Into<synapse_rpc::ServiceError> + Send;
53
54            #(#methods)*
55        }
56    })
57}
58
59/// Generate router implementation that converts trait impl to RpcHandler
60///
61/// Example output:
62/// ```text
63/// pub struct UserServiceRouter<T: UserService> {
64///     service: Arc<T>,
65/// }
66///
67/// impl<T: UserService> UserServiceRouter<T> {
68///     pub fn new(service: T) -> Arc<dyn synapse_rpc::RpcHandler> {
69///         // Creates MethodRouter with all methods
70///     }
71/// }
72/// ```
73pub fn generate_router_impl(service: &ServiceDef) -> Result<TokenStream> {
74    let trait_name = format_ident!("{}", service.service_name);
75    let router_name = format_ident!("{}Router", service.service_name);
76    let package = &service.package;
77
78    let interface_id = service.interface_id_expr();
79
80    // Generate handler for each method
81    let method_handlers: Vec<TokenStream> = service
82        .methods
83        .iter()
84        .map(|method| {
85            let method_name_snake = format_ident!("{}", method.method_name_snake());
86            let method_id = method.method_id_expr();
87            let input_type = method.input_type_path(package);
88            let _output_type = method.output_type_path(package);
89
90            quote! {
91                {
92                    let service = Arc::clone(&service);
93                    let handler = synapse_rpc::FunctionHandler::new(move |req: synapse_proto::RpcRequest| {
94                        let service = Arc::clone(&service);
95                        Box::pin(async move {
96                            // Create request state
97                            let state = synapse_sdk::RequestState::new(&req);
98
99                            // Deserialize request using prost
100                            let request: #input_type = match prost::Message::decode(&req.payload[..]) {
101                                Ok(r) => r,
102                                Err(e) => {
103                                    return synapse_rpc::error_response(
104                                        synapse_proto::RpcStatus::InvalidRequest,
105                                        400,
106                                        format!("Invalid request: {}", e),
107                                    );
108                                }
109                            };
110
111                            // Call service method with state
112                            match service.#method_name_snake(&state, request).await {
113                                Ok(response) => {
114                                    // Serialize response using prost
115                                    let payload = prost::Message::encode_to_vec(&response);
116                                    synapse_rpc::ok_response(bytes::Bytes::from(payload))
117                                }
118                                Err(err) => {
119                                    // Convert custom error to ServiceError
120                                    let service_err: synapse_rpc::ServiceError = err.into();
121                                    synapse_rpc::error_response(service_err.status, service_err.code, &service_err.message)
122                                }
123                            }
124                        })
125                    });
126                    router.method(#method_id, std::sync::Arc::new(handler))
127                }
128            }
129        })
130        .collect();
131
132    let method_ids: Vec<TokenStream> = service
133        .methods
134        .iter()
135        .map(|method| {
136            let method_id = method.method_id_expr();
137            quote! { #method_id }
138        })
139        .collect();
140
141    let method_names: Vec<&str> = service
142        .methods
143        .iter()
144        .map(|method| method.name.as_str())
145        .collect();
146
147    Ok(quote! {
148        /// Generated router for #trait_name
149        pub struct #router_name<T: #trait_name> {
150            _phantom: std::marker::PhantomData<T>,
151        }
152
153        impl<T: #trait_name> #router_name<T> {
154            /// Create a router from a service implementation
155            pub fn create(service: T) -> (synapse_rpc::InterfaceRegistration, std::sync::Arc<dyn synapse_rpc::RpcHandler>) {
156                use std::sync::Arc;
157                let service = Arc::new(service);
158
159                // Create method router
160                let mut router = synapse_rpc::MethodRouter::new();
161
162                // Add method handlers
163                #(
164                    router = #method_handlers;
165                )*
166
167                let router = router.build();
168
169                // Create registration
170                let registration = synapse_rpc::InterfaceRegistration {
171                    interface_id: #interface_id,
172                    interface_version: 1_000_000, // v1.0.0
173                    method_ids: [#(#method_ids),*].into_iter().collect(),
174                    method_names: vec![#(#method_names.to_string()),*],
175                    instance_id: synapse_primitives::InstanceId::new_random(),
176                    service_name: stringify!(#trait_name).to_string(),
177                    interface_name: concat!(#package, ".", stringify!(#trait_name)).to_string(),
178                };
179
180                (registration, router)
181            }
182        }
183    })
184}