spring_axum_mybatis/
lib.rs

1use anyhow::Result;
2use once_cell::sync::OnceCell;
3use regex::Regex;
4use serde::{Deserialize, Serialize};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9use tracing::info;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum StatementKind {
13    Select,
14    Insert,
15    Update,
16    Delete,
17}
18
19#[derive(Debug, Clone)]
20pub struct Statement {
21    pub id: String,                // namespace.id
22    pub sql: String,               // raw sql text
23    pub kind: StatementKind,
24    pub param_names: Vec<String>,  // extracted from #{param}
25    pub nodes: Option<Vec<Node>>,  // parsed dynamic SQL nodes (if/foreach)
26}
27
28#[derive(Default, Debug, Clone)]
29pub struct SqlRegistry {
30    pub stmts: HashMap<String, Statement>,
31}
32
33impl SqlRegistry {
34    pub fn get(&self, id: &str) -> Option<&Statement> {
35        self.stmts.get(id)
36    }
37
38    pub fn prepare(&self, id: &str, params: &HashMap<String, String>) -> Option<(String, Vec<String>)> {
39        let stmt = self.get(id)?;
40        if let Some(nodes) = &stmt.nodes {
41            let json_params = json_from_hash(params);
42            let mut out_sql = String::new();
43            let mut out_params: Vec<String> = Vec::new();
44            render_nodes(nodes, &json_params, &mut out_sql, &mut out_params).ok()?;
45            Some((normalize_sql(&out_sql), out_params))
46        } else {
47            // Replace #{name} with ? and collect values in order
48            let re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
49            let mut ordered: Vec<String> = Vec::new();
50            let sql = re
51                .replace_all(&stmt.sql, |caps: &regex::Captures| {
52                    let name = caps.get(1).unwrap().as_str().to_string();
53                    ordered.push(params.get(&name).cloned().unwrap_or_default());
54                    "?".to_string()
55                })
56                .to_string();
57            Some((normalize_sql(&sql), ordered))
58        }
59    }
60
61    pub fn prepare_json(&self, id: &str, params: &JsonValue) -> Option<(String, Vec<String>)> {
62        let stmt = self.get(id)?;
63        if let Some(nodes) = &stmt.nodes {
64            let mut out_sql = String::new();
65            let mut out_params: Vec<String> = Vec::new();
66            render_nodes(nodes, params, &mut out_sql, &mut out_params).ok()?;
67            Some((normalize_sql(&out_sql), out_params))
68        } else {
69            // Fallback: static replacement using top-level string map
70            let map = params.as_object()?;
71            let sparams: HashMap<String, String> = map
72                .iter()
73                .map(|(k, v)| (k.clone(), json_to_string(v)))
74                .collect();
75            self.prepare(id, &sparams)
76        }
77    }
78}
79
80static GLOBAL_REGISTRY: OnceCell<SqlRegistry> = OnceCell::new();
81
82pub fn init_global_from_dir_once(dir: impl AsRef<Path>) -> Result<()> {
83    if GLOBAL_REGISTRY.get().is_some() {
84        return Ok(());
85    }
86    let reg = load_from_dir(dir)?;
87    let _ = GLOBAL_REGISTRY.set(reg);
88    Ok(())
89}
90
91pub fn global_registry() -> Option<&'static SqlRegistry> {
92    GLOBAL_REGISTRY.get()
93}
94
95pub fn load_from_dir(dir: impl AsRef<Path>) -> Result<SqlRegistry> {
96    let mut reg = SqlRegistry::default();
97    let dir = dir.as_ref();
98    let mut files: Vec<PathBuf> = Vec::new();
99    for entry in fs::read_dir(dir)? {
100        let entry = entry?;
101        let path = entry.path();
102        if path.extension().map(|e| e == "xml").unwrap_or(false) {
103            files.push(path);
104        }
105    }
106    for file in files {
107        let text = fs::read_to_string(&file)?;
108        let mut ns = String::new();
109        // Extract mapper namespace
110    if let Some(caps) = Regex::new(r#"<mapper[^>]*?namespace=\"([^\"]+)\"[^>]*>"#)
111            .unwrap()
112            .captures(&text)
113        {
114            ns = caps.get(1).unwrap().as_str().to_string();
115        }
116        // Extract statements without using backreferences (not supported by Rust regex)
117        // select
118        for caps in Regex::new(r#"(?is)<select\s+id=\"([^\"]+)\"[^>]*>(.*?)</select>"#).unwrap().captures_iter(&text) {
119            let id_local = caps.get(1).unwrap().as_str().to_string();
120            let body = caps.get(2).unwrap().as_str().to_string();
121            let kind = StatementKind::Select;
122            let param_re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
123            let mut names: Vec<String> = Vec::new();
124            for m in param_re.captures_iter(&body) {
125                names.push(m.get(1).unwrap().as_str().to_string());
126            }
127            let nodes = if body.contains("<if") || body.contains("<foreach") {
128                Some(parse_nodes_from_body(&body)?)
129            } else {
130                None
131            };
132            let full_id = if ns.is_empty() { id_local.clone() } else { format!("{}.{}", ns, id_local) };
133            reg.stmts.insert(
134                full_id.clone(),
135                Statement { id: full_id, sql: body, kind, param_names: names, nodes },
136            );
137        }
138        // insert
139        for caps in Regex::new(r#"(?is)<insert\s+id=\"([^\"]+)\"[^>]*>(.*?)</insert>"#).unwrap().captures_iter(&text) {
140            let id_local = caps.get(1).unwrap().as_str().to_string();
141            let body = caps.get(2).unwrap().as_str().to_string();
142            let kind = StatementKind::Insert;
143            let param_re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
144            let mut names: Vec<String> = Vec::new();
145            for m in param_re.captures_iter(&body) {
146                names.push(m.get(1).unwrap().as_str().to_string());
147            }
148            let nodes = if body.contains("<if") || body.contains("<foreach") {
149                Some(parse_nodes_from_body(&body)?)
150            } else {
151                None
152            };
153            let full_id = if ns.is_empty() { id_local.clone() } else { format!("{}.{}", ns, id_local) };
154            reg.stmts.insert(
155                full_id.clone(),
156                Statement { id: full_id, sql: body, kind, param_names: names, nodes },
157            );
158        }
159        // update
160        for caps in Regex::new(r#"(?is)<update\s+id=\"([^\"]+)\"[^>]*>(.*?)</update>"#).unwrap().captures_iter(&text) {
161            let id_local = caps.get(1).unwrap().as_str().to_string();
162            let body = caps.get(2).unwrap().as_str().to_string();
163            let kind = StatementKind::Update;
164            let param_re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
165            let mut names: Vec<String> = Vec::new();
166            for m in param_re.captures_iter(&body) {
167                names.push(m.get(1).unwrap().as_str().to_string());
168            }
169            let nodes = if body.contains("<if") || body.contains("<foreach") {
170                Some(parse_nodes_from_body(&body)?)
171            } else {
172                None
173            };
174            let full_id = if ns.is_empty() { id_local.clone() } else { format!("{}.{}", ns, id_local) };
175            reg.stmts.insert(
176                full_id.clone(),
177                Statement { id: full_id, sql: body, kind, param_names: names, nodes },
178            );
179        }
180        // delete
181        for caps in Regex::new(r#"(?is)<delete\s+id=\"([^\"]+)\"[^>]*>(.*?)</delete>"#).unwrap().captures_iter(&text) {
182            let id_local = caps.get(1).unwrap().as_str().to_string();
183            let body = caps.get(2).unwrap().as_str().to_string();
184            let kind = StatementKind::Delete;
185            let param_re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
186            let mut names: Vec<String> = Vec::new();
187            for m in param_re.captures_iter(&body) {
188                names.push(m.get(1).unwrap().as_str().to_string());
189            }
190            let nodes = if body.contains("<if") || body.contains("<foreach") {
191                Some(parse_nodes_from_body(&body)?)
192            } else {
193                None
194            };
195            let full_id = if ns.is_empty() { id_local.clone() } else { format!("{}.{}", ns, id_local) };
196            reg.stmts.insert(
197                full_id.clone(),
198                Statement { id: full_id, sql: body, kind, param_names: names, nodes },
199            );
200        }
201    }
202    info!(count = reg.stmts.len(), "Loaded XML SQL mappers");
203    Ok(reg)
204}
205
206#[derive(Debug, Clone, Default)]
207pub struct NoopExecutor;
208
209impl NoopExecutor {
210    pub fn execute(&self, sql: &str, params: &[String]) -> Result<()> {
211        info!(sql = %sql, params = ?params, "Noop execute");
212        Ok(())
213    }
214}
215
216// ---------------- Dynamic SQL AST & Rendering ----------------
217
/// One node of a parsed dynamic-SQL body, produced by `parse_nodes_from_body`
/// and rendered by `render_nodes`.
#[derive(Debug, Clone)]
pub enum Node {
    /// Literal SQL text, emitted verbatim.
    Text(String),
    /// A `#{name}` placeholder, rendered as `?` with the named value bound.
    Param(String),
    /// `<if test="...">`: `children` are rendered only when `test` holds.
    If {
        test: String,
        children: Vec<Node>,
    },
    /// `<foreach>`: `children` are rendered once per element of the array
    /// parameter named `collection`, with each element bound under `item`.
    /// The output is wrapped in `open`/`close` and joined by `separator`.
    Foreach {
        collection: String,
        item: String,
        open: Option<String>,
        close: Option<String>,
        separator: Option<String>,
        children: Vec<Node>,
    },
}
232
/// Collapses every run of whitespace (Unicode White_Space, including
/// newlines and tabs) into a single space and trims both ends.
///
/// Uses `split_whitespace` instead of compiling a regex on every call.
/// The result is identical, since regex `\s` and `char::is_whitespace`
/// both follow the Unicode White_Space property.
fn normalize_sql(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
238
239fn json_to_string(v: &JsonValue) -> String {
240    match v {
241        JsonValue::Null => "".to_string(),
242        JsonValue::Bool(b) => b.to_string(),
243        JsonValue::Number(n) => n.to_string(),
244        JsonValue::String(s) => s.clone(),
245        _ => v.to_string(),
246    }
247}
248
249fn json_from_hash(map: &HashMap<String, String>) -> JsonValue {
250    let m: serde_json::Map<String, JsonValue> = map
251        .iter()
252        .map(|(k, v)| (k.clone(), JsonValue::String(v.clone())))
253        .collect();
254    JsonValue::Object(m)
255}
256
257fn parse_nodes_from_body(body: &str) -> Result<Vec<Node>> {
258    let wrapped = format!("<root>{}</root>", body);
259    let mut reader = quick_xml::Reader::from_str(&wrapped);
260    reader.trim_text(true);
261    let mut buf = Vec::new();
262    let mut stack: Vec<Vec<Node>> = vec![Vec::new()];
263    // Hold current element attributes for if/foreach
264    #[derive(Debug, Clone)]
265    enum FrameKind { If { test: String }, Foreach { collection: String, item: String, open: Option<String>, close: Option<String>, separator: Option<String> } }
266    let mut frames: Vec<FrameKind> = Vec::new();
267
268    let param_re = Regex::new(r#"#\{([a-zA-Z_][a-zA-Z0-9_]*)\}"#).unwrap();
269
270    loop {
271        use quick_xml::events::Event;
272        match reader.read_event_into(&mut buf)? {
273            Event::Start(e) => {
274                let name = String::from_utf8_lossy(e.name().as_ref()).to_string();
275                if name == "if" {
276                    let mut test = String::new();
277                    for a in e.attributes().with_checks(false) {
278                        let a = a?;
279                        if a.key.as_ref() == b"test" {
280                            test = a.unescape_value()?.into_owned();
281                        }
282                    }
283                    frames.push(FrameKind::If { test });
284                    stack.push(Vec::new());
285                } else if name == "foreach" {
286                    let mut collection = String::new();
287                    let mut item = String::from("item");
288                    let mut open = None;
289                    let mut close = None;
290                    let mut separator = None;
291                    for a in e.attributes().with_checks(false) {
292                        let a = a?;
293                        match a.key.as_ref() {
294                            b"collection" => collection = a.unescape_value()?.into_owned(),
295                            b"item" => item = a.unescape_value()?.into_owned(),
296                            b"open" => open = Some(a.unescape_value()?.into_owned()),
297                            b"close" => close = Some(a.unescape_value()?.into_owned()),
298                            b"separator" => separator = Some(a.unescape_value()?.into_owned()),
299                            _ => {}
300                        }
301                    }
302                    frames.push(FrameKind::Foreach { collection, item, open, close, separator });
303                    stack.push(Vec::new());
304                } else {
305                    // Unknown tags: we don't create a frame; just start a group to collect children
306                    stack.push(Vec::new());
307                }
308            }
309            Event::Text(t) => {
310                let txt = t.unescape()?.to_string();
311                // Split into text/param nodes
312                let mut last = 0;
313                for caps in param_re.captures_iter(&txt) {
314                    let m = caps.get(0).unwrap();
315                    if m.start() > last {
316                        stack.last_mut().unwrap().push(Node::Text(txt[last..m.start()].to_string()));
317                    }
318                    let name = caps.get(1).unwrap().as_str().to_string();
319                    stack.last_mut().unwrap().push(Node::Param(name));
320                    last = m.end();
321                }
322                if last < txt.len() {
323                    stack.last_mut().unwrap().push(Node::Text(txt[last..].to_string()));
324                }
325            }
326            Event::End(e) => {
327                let name = String::from_utf8_lossy(e.name().as_ref()).to_string();
328                if let Some(frame) = frames.pop() {
329                    let children = stack.pop().unwrap();
330                    match frame {
331                        FrameKind::If { test } => {
332                            if name == "if" {
333                                stack.last_mut().unwrap().push(Node::If { test, children });
334                            } else {
335                                // unknown tag: inline children
336                                stack.last_mut().unwrap().extend(children);
337                            }
338                        }
339                        FrameKind::Foreach { collection, item, open, close, separator } => {
340                            if name == "foreach" {
341                                stack.last_mut().unwrap().push(Node::Foreach { collection, item, open, close, separator, children });
342                            } else {
343                                stack.last_mut().unwrap().extend(children);
344                            }
345                        }
346                    }
347                } else {
348                    // Unknown tag end: just inline collected children
349                    let children = stack.pop().unwrap_or_default();
350                    stack.last_mut().unwrap().extend(children);
351                }
352            }
353            Event::Eof => break,
354            _ => {}
355        }
356        buf.clear();
357    }
358    Ok(stack.pop().unwrap_or_default())
359}
360
361fn eval_test(expr: &str, params: &JsonValue) -> bool {
362    // Very simple evaluator: supports `a != null`, `a == null`, `a != ''`, `a == ''`, combined with `and`/`or`
363    let tokens: Vec<&str> = expr.split_whitespace().collect();
364    // Parse sequence: term (and|or term)*
365    fn term(tok: &[&str], params: &JsonValue) -> (bool, usize) {
366        if tok.len() < 3 { return (false, tok.len()); }
367        let left = tok[0];
368        let op = tok[1];
369        let right = tok[2];
370        let v = lookup(params, left);
371        let res = match (op, right) {
372            ("!=", "null") => !v.is_null(),
373            ("==", "null") => v.is_null(),
374            ("!=", "''") | ("!=", "\"\"") => json_to_string(&v) != "",
375            ("==", "''") | ("==", "\"\"") => json_to_string(&v) == "",
376            _ => false,
377        };
378        (res, 3)
379    }
380    let mut idx = 0;
381    let mut acc = false;
382    let mut first = true;
383    while idx < tokens.len() {
384        let (tval, consumed) = term(&tokens[idx..], params);
385        if first { acc = tval; first = false; } else { acc = tval; }
386        idx += consumed;
387        if idx < tokens.len() {
388            let op = tokens[idx];
389            idx += 1;
390            let (tval2, consumed2) = term(&tokens[idx..], params);
391            match op {
392                "and" => acc = acc && tval2,
393                "or" => acc = acc || tval2,
394                _ => {}
395            }
396            idx += consumed2;
397        }
398    }
399    acc
400}
401
402fn lookup<'a>(params: &'a JsonValue, path: &str) -> &'a JsonValue {
403    // only top-level keys supported for now
404    params.get(path).unwrap_or(&JsonValue::Null)
405}
406
407fn render_nodes(nodes: &[Node], params: &JsonValue, out_sql: &mut String, out_params: &mut Vec<String>) -> Result<()> {
408    for n in nodes {
409        match n {
410            Node::Text(t) => out_sql.push_str(t),
411            Node::Param(name) => {
412                out_sql.push_str("?");
413                out_params.push(json_to_string(lookup(params, name)));
414            }
415            Node::If { test, children } => {
416                if eval_test(test, params) {
417                    render_nodes(children, params, out_sql, out_params)?;
418                }
419            }
420            Node::Foreach { collection, item, open, close, separator, children } => {
421                let arr = lookup(params, collection);
422                if let JsonValue::Array(items) = arr {
423                    if let Some(o) = open { out_sql.push_str(o); }
424                    for (i, it) in items.iter().enumerate() {
425                        let mut local = params.clone();
426                        if let JsonValue::Object(mut obj) = local {
427                            obj.insert(item.clone(), it.clone());
428                            local = JsonValue::Object(obj);
429                        }
430                        render_nodes(children, &local, out_sql, out_params)?;
431                        if i + 1 < items.len() {
432                            if let Some(sep) = separator { out_sql.push_str(sep); }
433                        }
434                    }
435                    if let Some(c) = close { out_sql.push_str(c); }
436                }
437            }
438        }
439    }
440    Ok(())
441}