soap_service/
lib.rs

1mod codegen;
2mod parser;
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use syn::{parse_macro_input, ItemMod};
8
9#[proc_macro_attribute]
10pub fn service(args: TokenStream, input: TokenStream) -> TokenStream {
11    let config = match parser::parse_service_attributes(args.into()) {
12        Ok(config) => config,
13        Err(e) => return e.to_compile_error().into(),
14    };
15
16    let module = parse_macro_input!(input as ItemMod);
17
18    let operations = match parser::extract_soap_operations(&module) {
19        Ok(ops) => ops,
20        Err(e) => return e.to_compile_error().into(),
21    };
22
23    let enhanced_module = generate_enhanced_module(module, config, operations);
24    enhanced_module.into()
25}
26
27fn generate_enhanced_module(
28    mut module: ItemMod,
29    config: parser::ServiceConfig,
30    operations: Vec<parser::SoapOperation>,
31) -> TokenStream2 {
32    let bind_path = &config.bind_path;
33    let wsdl_path = format!("{}/wsdl", bind_path);
34    let namespace = &config.namespace;
35
36    // Collect type information
37    let types = match parser::collect_types_from_operations(&operations) {
38        Ok(types) => types,
39        Err(_) => std::collections::HashMap::new(),
40    };
41
42    // Generate WSDL content
43    let wsdl_content = codegen::generate_wsdl(&config, &operations, &types);
44
45    // Generate operation dispatcher
46    let operation_handlers = generate_operation_handlers(&operations, namespace);
47
48    let router_code = quote! {
49        use std::collections::HashMap;
50
51        pub fn router() -> axum::Router {
52            axum::Router::new()
53                .route(#bind_path, axum::routing::post(soap_handler))
54                .route(#wsdl_path, axum::routing::get(wsdl_handler))
55        }
56
57        async fn soap_handler(body: String) -> axum::response::Response {
58            match handle_soap_request(&body).await {
59                Ok(response) => {
60                    axum::response::Response::builder()
61                        .status(200)
62                        .header("Content-Type", "text/xml; charset=utf-8")
63                        .header("SOAPAction", "")
64                        .body(response.into())
65                        .unwrap()
66                }
67                Err(error) => {
68                    let fault = create_soap_fault(&error);
69                    axum::response::Response::builder()
70                        .status(500)
71                        .header("Content-Type", "text/xml; charset=utf-8")
72                        .body(fault.into())
73                        .unwrap()
74                }
75            }
76        }
77
78        async fn handle_soap_request(xml: &str) -> Result<String, String> {
79            // Parse SOAP envelope using proper XML parsing
80            let parsed_request = parse_soap_envelope(xml)?;
81            let operation = &parsed_request.operation;
82            let body_content = &parsed_request.body_xml;
83
84            #operation_handlers
85
86            Err(format!("Unknown operation: {}", operation))
87        }
88
89        #[derive(Debug)]
90        struct ParsedSoapRequest {
91            operation: String,
92            body_xml: String,
93            namespace: Option<String>,
94        }
95
96        fn parse_soap_envelope(xml: &str) -> Result<ParsedSoapRequest, String> {
97            // Handle different SOAP Body variations
98            let body_start_patterns = ["<soap:Body>", "<SOAP-ENV:Body>", "<Body>"];
99            let body_end_patterns = ["</soap:Body>", "</SOAP-ENV:Body>", "</Body>"];
100
101            let mut body_start_pos = None;
102            let mut body_end_pos = None;
103            let mut body_tag_len = 0;
104
105            // Find body start
106            for pattern in &body_start_patterns {
107                if let Some(pos) = xml.find(pattern) {
108                    body_start_pos = Some(pos);
109                    body_tag_len = pattern.len();
110                    break;
111                }
112            }
113
114            // Find body end
115            for pattern in &body_end_patterns {
116                if let Some(pos) = xml.find(pattern) {
117                    body_end_pos = Some(pos);
118                    break;
119                }
120            }
121
122            let body_start = body_start_pos.ok_or("SOAP Body start tag not found")?;
123            let body_end = body_end_pos.ok_or("SOAP Body end tag not found")?;
124
125            if body_start + body_tag_len >= body_end {
126                return Err("Invalid SOAP Body structure".to_string());
127            }
128
129            let body_content = &xml[body_start + body_tag_len..body_end];
130            let trimmed_body = body_content.trim();
131
132            // Extract operation name from first element in body
133            let operation = extract_first_element_name(trimmed_body)?;
134
135            Ok(ParsedSoapRequest {
136                operation,
137                body_xml: trimmed_body.to_string(),
138                namespace: extract_target_namespace(xml),
139            })
140        }
141
142        fn extract_first_element_name(xml: &str) -> Result<String, String> {
143            let xml = xml.trim();
144            if !xml.starts_with('<') {
145                return Err("No XML element found".to_string());
146            }
147
148            let after_bracket = &xml[1..];
149            let tag_end = after_bracket.find('>')
150                .ok_or("Invalid XML: no closing bracket found")?;
151
152            let tag_content = &after_bracket[..tag_end];
153
154            // Handle self-closing tags
155            let tag_name = if tag_content.ends_with('/') {
156                &tag_content[..tag_content.len() - 1]
157            } else {
158                tag_content
159            };
160
161            // Remove namespace prefix and attributes
162            let clean_name = tag_name.split_whitespace().next().unwrap_or(tag_name);
163            let operation = if clean_name.contains(':') {
164                clean_name.split(':').last().unwrap_or(clean_name)
165            } else {
166                clean_name
167            };
168
169            Ok(operation.to_string())
170        }
171
172        fn extract_target_namespace(xml: &str) -> Option<String> {
173            // Look for targetNamespace or xmlns attributes
174            if let Some(start) = xml.find("targetNamespace=\"") {
175                let after_start = &xml[start + 17..];
176                if let Some(end) = after_start.find('"') {
177                    return Some(after_start[..end].to_string());
178                }
179            }
180
181            // Fallback to default xmlns
182            if let Some(start) = xml.find("xmlns=\"") {
183                let after_start = &xml[start + 7..];
184                if let Some(end) = after_start.find('"') {
185                    return Some(after_start[..end].to_string());
186                }
187            }
188
189            None
190        }
191
192        fn create_simple_soap_response(content: &str, operation: &str, namespace: &str) -> String {
193            format!(
194                r#"<?xml version="1.0" encoding="UTF-8"?>
195<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/"
196               xmlns:tns="{}">
197    <soap:Body>
198        <tns:{}Response>
199            {}
200        </tns:{}Response>
201    </soap:Body>
202</soap:Envelope>"#,
203                namespace, operation, content, operation
204            )
205        }
206
207        fn extract_xml_value(xml: &str, tag_name: &str) -> Option<String> {
208            // Try multiple patterns to handle namespaces and variations
209            let patterns = [
210                format!("<{}>", tag_name),
211                format!("<{}:", tag_name),  // Handle namespace prefixes
212                format!("<tns:{}>", tag_name),
213                format!("<ns1:{}>", tag_name),
214            ];
215
216            for start_pattern in &patterns {
217                if let Some(start_pos) = xml.find(start_pattern) {
218                    // Find the actual end of the opening tag
219                    let tag_start = start_pos + start_pattern.len();
220                    let remaining = &xml[start_pos..];
221
222                    if let Some(close_bracket) = remaining.find('>') {
223                        let content_start = start_pos + close_bracket + 1;
224
225                        // Look for the closing tag
226                        let end_patterns = [
227                            format!("</{}>", tag_name),
228                            format!("</{}:", tag_name),
229                            format!("</tns:{}>", tag_name),
230                            format!("</ns1:{}>", tag_name),
231                        ];
232
233                        for end_pattern in &end_patterns {
234                            if let Some(end_pos) = xml[content_start..].find(end_pattern) {
235                                let actual_end = content_start + end_pos;
236                                if content_start <= actual_end {
237                                    let content = &xml[content_start..actual_end];
238                                    return Some(decode_xml_content(content.trim()));
239                                }
240                            }
241                        }
242
243                        // Handle self-closing tags like <tag/>
244                        if remaining[..close_bracket].ends_with('/') {
245                            return Some(String::new());
246                        }
247                    }
248                }
249            }
250            None
251        }
252
253        fn decode_xml_content(content: &str) -> String {
254            content
255                .replace("&lt;", "<")
256                .replace("&gt;", ">")
257                .replace("&amp;", "&")
258                .replace("&quot;", "\"")
259                .replace("&apos;", "'")
260        }
261
262        // Generic request parsing using serde_xml_rs directly on operation XML
263        fn parse_request_from_xml<T>(xml: &str) -> Result<T, String>
264        where
265            T: for<'de> ::serde::Deserialize<'de>,
266        {
267            // The xml parameter is already the operation content (e.g., "<Add><Operand1>123</Operand1><Operand2>456</Operand2></Add>")
268            // Use serde_xml_rs to deserialize it directly
269            ::serde_xml_rs::from_str(xml)
270                .map_err(|e| format!("XML deserialization error: {} for XML: {}", e, xml))
271        }
272
273
274        // Generic response serialization using serde_xml_rs
275        fn serialize_response_to_xml<T>(response: &T) -> Result<String, String>
276        where
277            T: ::serde::Serialize,
278        {
279            // Use serde_xml_rs for serialization
280            ::serde_xml_rs::to_string(response)
281                .map_err(|e| format!("XML serialization error: {}", e))
282        }
283
284
285        fn create_soap_fault(error: &str) -> String {
286            format!(
287                r#"<?xml version="1.0" encoding="UTF-8"?>
288<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
289    <soap:Body>
290        <soap:Fault>
291            <faultcode>Server</faultcode>
292            <faultstring>{}</faultstring>
293        </soap:Fault>
294    </soap:Body>
295</soap:Envelope>"#,
296                error
297            )
298        }
299
300        async fn wsdl_handler() -> axum::response::Response {
301            let wsdl = #wsdl_content;
302
303            axum::response::Response::builder()
304                .status(200)
305                .header("Content-Type", "text/xml; charset=utf-8")
306                .body(wsdl.into())
307                .unwrap()
308        }
309    };
310
311    // Add the router code to the module
312    if let Some((brace, ref mut items)) = module.content {
313        // Parse the router code as items and add them
314        let router_items: syn::File = syn::parse2(router_code).unwrap();
315        items.extend(router_items.items);
316        module.content = Some((brace, items.clone()));
317    }
318
319    quote! { #module }
320}
321
322fn generate_operation_handlers(
323    operations: &[parser::SoapOperation],
324    namespace: &str,
325) -> TokenStream2 {
326    let mut handlers = Vec::new();
327
328    for operation in operations {
329        let op_name = &operation.name;
330        let func_name = &operation.function_name;
331        let request_type = &operation.request_type;
332        let response_type = &operation.response_type;
333
334        handlers.push(quote! {
335            if operation == #op_name {
336                // Generic XML parsing using serde
337                let request_data: #request_type = match parse_request_from_xml(&body_content) {
338                    Ok(data) => data,
339                    Err(e) => return Err(format!("Failed to parse request: {}", e)),
340                };
341
342                let result: #response_type = #func_name(request_data).await
343                    .map_err(|e| format!("Operation failed: {}", e))?;
344
345                // Generic response serialization using serde
346                let response_xml = match serialize_response_to_xml(&result) {
347                    Ok(xml) => xml,
348                    Err(e) => return Err(format!("Failed to serialize response: {}", e)),
349                };
350
351                return Ok(create_simple_soap_response(&response_xml, #op_name, #namespace));
352            }
353        });
354    }
355
356    quote! {
357        #(#handlers)*
358    }
359}