Skip to main content

tensorlogic_cli/
macros.rs

1//! Macro system for defining reusable logical patterns
2//!
3//! This module provides a powerful macro system that allows users to define
4//! parameterized logical patterns that can be reused throughout their expressions.
5//!
6//! # Syntax
7//!
8//! Macros are defined using the following syntax:
9//!
10//! ```text
11//! DEFINE MACRO name(param1, param2, ...) = expression
12//! ```
13//!
14//! # Examples
15//!
16//! ```text
17//! // Define a transitive relation macro
18//! DEFINE MACRO transitive(R, x, z) = EXISTS y. (R(x, y) AND R(y, z))
19//!
20//! // Define a symmetric relation macro
21//! DEFINE MACRO symmetric(R, x, y) = R(x, y) AND R(y, x)
22//!
23//! // Define a reflexive relation macro
24//! DEFINE MACRO reflexive(R, x) = R(x, x)
25//!
26//! // Define an equivalence relation macro
27//! DEFINE MACRO equivalence(R, x, y) = reflexive(R, x) AND reflexive(R, y) AND symmetric(R, x, y)
28//!
29//! // Use macros in expressions
30//! transitive(friend, Alice, Bob)
31//! ```
32
33#![allow(dead_code)]
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39/// A macro definition with parameters and body
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41pub struct MacroDef {
42    /// Macro name
43    pub name: String,
44
45    /// Parameter names
46    pub params: Vec<String>,
47
48    /// Macro body (as expression string)
49    pub body: String,
50}
51
52impl MacroDef {
53    /// Create a new macro definition
54    pub fn new(name: String, params: Vec<String>, body: String) -> Self {
55        Self { name, params, body }
56    }
57
58    /// Get the arity (number of parameters)
59    pub fn arity(&self) -> usize {
60        self.params.len()
61    }
62
63    /// Validate that the macro definition is well-formed
64    pub fn validate(&self) -> Result<()> {
65        if self.name.is_empty() {
66            return Err(anyhow!("Macro name cannot be empty"));
67        }
68
69        if !self
70            .name
71            .chars()
72            .next()
73            .expect("name non-empty after is_empty check above")
74            .is_alphabetic()
75        {
76            return Err(anyhow!(
77                "Macro name must start with a letter: {}",
78                self.name
79            ));
80        }
81
82        if self.params.is_empty() {
83            return Err(anyhow!("Macro must have at least one parameter"));
84        }
85
86        // Check for duplicate parameters
87        let mut seen = HashMap::new();
88        for (idx, param) in self.params.iter().enumerate() {
89            if let Some(prev_idx) = seen.insert(param, idx) {
90                return Err(anyhow!(
91                    "Duplicate parameter '{}' at positions {} and {}",
92                    param,
93                    prev_idx,
94                    idx
95                ));
96            }
97        }
98
99        if self.body.is_empty() {
100            return Err(anyhow!("Macro body cannot be empty"));
101        }
102
103        Ok(())
104    }
105
106    /// Expand the macro with the given arguments
107    pub fn expand(&self, args: &[String]) -> Result<String> {
108        if args.len() != self.params.len() {
109            return Err(anyhow!(
110                "Macro {} expects {} arguments, got {}",
111                self.name,
112                self.params.len(),
113                args.len()
114            ));
115        }
116
117        // Create substitution map
118        let mut substitutions: HashMap<&str, &str> = HashMap::new();
119        for (param, arg) in self.params.iter().zip(args.iter()) {
120            substitutions.insert(param.as_str(), arg.as_str());
121        }
122
123        // Perform substitution in the body
124        let mut result = self.body.clone();
125
126        // Sort parameters by length (descending) to handle overlapping names correctly
127        let mut sorted_params: Vec<&String> = self.params.iter().collect();
128        sorted_params.sort_by_key(|p| std::cmp::Reverse(p.len()));
129
130        for param in sorted_params {
131            if let Some(arg) = substitutions.get(param.as_str()) {
132                // Use word boundaries to avoid partial replacements
133                result = replace_word(&result, param, arg);
134            }
135        }
136
137        Ok(result)
138    }
139}
140
/// Replace whole words only (not substrings).
///
/// `text` is treated as words separated by single separator characters. A
/// word is a (possibly empty) run of alphanumeric characters or `_`; every
/// other character is a separator. Each word equal to `from` is replaced by
/// `to`. Separators and all other words are copied through unchanged.
fn replace_word(text: &str, from: &str, to: &str) -> String {
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

    let mut output = String::with_capacity(text.len());
    let mut rest = text;
    loop {
        // Take the next word (possibly empty) up to the next separator.
        let word_end = rest
            .find(|c: char| !is_word_char(c))
            .unwrap_or(rest.len());
        let (word, tail) = rest.split_at(word_end);
        output.push_str(if word == from { to } else { word });

        // Copy the separator, if any, and continue after it. Otherwise, done.
        let mut tail_chars = tail.chars();
        match tail_chars.next() {
            Some(separator) => {
                output.push(separator);
                rest = tail_chars.as_str();
            }
            None => break,
        }
    }
    output
}
169
170/// Registry for managing macro definitions
171#[derive(Debug, Clone, Default, Serialize, Deserialize)]
172pub struct MacroRegistry {
173    /// Map of macro name to definition
174    macros: HashMap<String, MacroDef>,
175}
176
177impl MacroRegistry {
178    /// Create a new empty macro registry
179    pub fn new() -> Self {
180        Self {
181            macros: HashMap::new(),
182        }
183    }
184
185    /// Create a registry with built-in macros
186    pub fn with_builtins() -> Self {
187        let mut registry = Self::new();
188
189        // Add common built-in macros
190        let builtins = vec![
191            MacroDef::new(
192                "transitive".to_string(),
193                vec!["R".to_string(), "x".to_string(), "z".to_string()],
194                "EXISTS y. (R(x, y) AND R(y, z))".to_string(),
195            ),
196            MacroDef::new(
197                "symmetric".to_string(),
198                vec!["R".to_string(), "x".to_string(), "y".to_string()],
199                "R(x, y) AND R(y, x)".to_string(),
200            ),
201            MacroDef::new(
202                "reflexive".to_string(),
203                vec!["R".to_string(), "x".to_string()],
204                "R(x, x)".to_string(),
205            ),
206            MacroDef::new(
207                "antisymmetric".to_string(),
208                vec!["R".to_string(), "x".to_string(), "y".to_string()],
209                "(R(x, y) AND R(y, x)) IMPLIES (x = y)".to_string(),
210            ),
211            MacroDef::new(
212                "total".to_string(),
213                vec!["R".to_string(), "x".to_string(), "y".to_string()],
214                "R(x, y) OR R(y, x)".to_string(),
215            ),
216        ];
217
218        for macro_def in builtins {
219            let _ = registry.define(macro_def);
220        }
221
222        registry
223    }
224
225    /// Define a new macro
226    pub fn define(&mut self, macro_def: MacroDef) -> Result<()> {
227        macro_def.validate()?;
228        self.macros.insert(macro_def.name.clone(), macro_def);
229        Ok(())
230    }
231
232    /// Get a macro definition by name
233    pub fn get(&self, name: &str) -> Option<&MacroDef> {
234        self.macros.get(name)
235    }
236
237    /// Check if a macro is defined
238    pub fn contains(&self, name: &str) -> bool {
239        self.macros.contains_key(name)
240    }
241
242    /// Remove a macro definition
243    pub fn undefine(&mut self, name: &str) -> Option<MacroDef> {
244        self.macros.remove(name)
245    }
246
247    /// List all defined macros
248    pub fn list(&self) -> Vec<&MacroDef> {
249        self.macros.values().collect()
250    }
251
252    /// Clear all macros
253    pub fn clear(&mut self) {
254        self.macros.clear();
255    }
256
257    /// Get the number of defined macros
258    pub fn len(&self) -> usize {
259        self.macros.len()
260    }
261
262    /// Check if the registry is empty
263    pub fn is_empty(&self) -> bool {
264        self.macros.is_empty()
265    }
266
267    /// Expand a macro call
268    pub fn expand(&self, name: &str, args: &[String]) -> Result<String> {
269        let macro_def = self
270            .get(name)
271            .ok_or_else(|| anyhow!("Undefined macro: {}", name))?;
272        macro_def.expand(args)
273    }
274
275    /// Recursively expand all macros in an expression
276    pub fn expand_all(&self, expr: &str) -> Result<String> {
277        let mut result = expr.to_string();
278        let mut changed = true;
279        let mut iterations = 0;
280        const MAX_ITERATIONS: usize = 100; // Prevent infinite loops
281
282        while changed && iterations < MAX_ITERATIONS {
283            changed = false;
284            iterations += 1;
285
286            // Try to find and expand macro calls
287            for (name, macro_def) in &self.macros {
288                if let Some(expanded) = self.try_expand_macro(&result, name, macro_def)? {
289                    result = expanded;
290                    changed = true;
291                    break; // Start over to ensure proper nesting
292                }
293            }
294        }
295
296        if iterations >= MAX_ITERATIONS {
297            return Err(anyhow!(
298                "Macro expansion exceeded maximum iterations (possible circular definition)"
299            ));
300        }
301
302        Ok(result)
303    }
304
305    /// Try to expand a specific macro in the expression
306    fn try_expand_macro(
307        &self,
308        expr: &str,
309        name: &str,
310        macro_def: &MacroDef,
311    ) -> Result<Option<String>> {
312        // Simple pattern matching for macro calls: name(arg1, arg2, ...)
313        if let Some(pos) = expr.find(name) {
314            // Check if this is actually a macro call (followed by '(')
315            let after_name = pos + name.len();
316            if after_name < expr.len() && expr.chars().nth(after_name) == Some('(') {
317                // Extract arguments
318                if let Some(args) = self.extract_args(&expr[after_name..])? {
319                    let expanded = macro_def.expand(&args)?;
320                    let mut result = String::new();
321                    result.push_str(&expr[..pos]);
322                    result.push_str(&expanded);
323                    result.push_str(
324                        &expr[after_name + self.find_matching_paren(&expr[after_name..])? + 1..],
325                    );
326                    return Ok(Some(result));
327                }
328            }
329        }
330        Ok(None)
331    }
332
333    /// Extract arguments from a function call
334    fn extract_args(&self, text: &str) -> Result<Option<Vec<String>>> {
335        if !text.starts_with('(') {
336            return Ok(None);
337        }
338
339        let closing = self.find_matching_paren(text)?;
340        let args_str = &text[1..closing];
341
342        if args_str.trim().is_empty() {
343            return Ok(Some(Vec::new()));
344        }
345
346        // Split by commas (respecting nested parentheses)
347        let mut args = Vec::new();
348        let mut current_arg = String::new();
349        let mut depth = 0;
350
351        for ch in args_str.chars() {
352            match ch {
353                '(' => {
354                    depth += 1;
355                    current_arg.push(ch);
356                }
357                ')' => {
358                    depth -= 1;
359                    current_arg.push(ch);
360                }
361                ',' if depth == 0 => {
362                    args.push(current_arg.trim().to_string());
363                    current_arg.clear();
364                }
365                _ => {
366                    current_arg.push(ch);
367                }
368            }
369        }
370
371        if !current_arg.is_empty() {
372            args.push(current_arg.trim().to_string());
373        }
374
375        Ok(Some(args))
376    }
377
378    /// Find the position of the matching closing parenthesis
379    fn find_matching_paren(&self, text: &str) -> Result<usize> {
380        let mut depth = 0;
381        for (i, ch) in text.chars().enumerate() {
382            match ch {
383                '(' => depth += 1,
384                ')' => {
385                    depth -= 1;
386                    if depth == 0 {
387                        return Ok(i);
388                    }
389                }
390                _ => {}
391            }
392        }
393        Err(anyhow!("Unmatched parenthesis"))
394    }
395}
396
397/// Parse a macro definition from a string
398///
399/// Expected format: `DEFINE MACRO name(param1, param2, ...) = body`
400pub fn parse_macro_definition(input: &str) -> Result<MacroDef> {
401    let input = input.trim();
402
403    // Check for DEFINE MACRO prefix
404    if !input.starts_with("DEFINE MACRO") && !input.starts_with("MACRO") {
405        return Err(anyhow!(
406            "Macro definition must start with 'DEFINE MACRO' or 'MACRO'"
407        ));
408    }
409
410    let input = if let Some(stripped) = input.strip_prefix("DEFINE MACRO") {
411        stripped
412    } else if let Some(stripped) = input.strip_prefix("MACRO") {
413        stripped
414    } else {
415        unreachable!("Already checked for prefixes above")
416    }
417    .trim();
418
419    // Find the equals sign
420    let eq_pos = input
421        .find('=')
422        .ok_or_else(|| anyhow!("Macro definition must contain '='"))?;
423
424    let signature = input[..eq_pos].trim();
425    let body = input[eq_pos + 1..].trim().to_string();
426
427    // Parse signature: name(param1, param2, ...)
428    let open_paren = signature
429        .find('(')
430        .ok_or_else(|| anyhow!("Macro definition must have parameter list"))?;
431
432    let name = signature[..open_paren].trim().to_string();
433
434    let close_paren = signature
435        .rfind(')')
436        .ok_or_else(|| anyhow!("Unmatched parenthesis in macro signature"))?;
437
438    let params_str = &signature[open_paren + 1..close_paren];
439    let params: Vec<String> = if params_str.trim().is_empty() {
440        return Err(anyhow!("Macro must have at least one parameter"));
441    } else {
442        params_str
443            .split(',')
444            .map(|s| s.trim().to_string())
445            .collect()
446    };
447
448    Ok(MacroDef::new(name, params, body))
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_macro_def_creation() {
        let def = MacroDef::new("test".into(), strings(&["x", "y"]), "pred(x, y)".into());
        assert_eq!(def.name, "test");
        assert_eq!(def.arity(), 2);
    }

    #[test]
    fn test_macro_validation() {
        let well_formed = MacroDef::new("test".into(), strings(&["x"]), "pred(x)".into());
        assert!(well_formed.validate().is_ok());

        let unnamed = MacroDef::new(String::new(), strings(&["x"]), "pred(x)".into());
        assert!(unnamed.validate().is_err());

        let repeated_param =
            MacroDef::new("test".into(), strings(&["x", "x"]), "pred(x)".into());
        assert!(repeated_param.validate().is_err());
    }

    #[test]
    fn test_macro_expansion() {
        let def = MacroDef::new(
            "test".into(),
            strings(&["x", "y"]),
            "pred(x, y) AND pred(y, x)".into(),
        );

        let result = def
            .expand(&strings(&["a", "b"]))
            .expect("macro expansion should succeed");
        assert_eq!(result, "pred(a, b) AND pred(b, a)");
    }

    #[test]
    fn test_macro_registry() {
        let mut registry = MacroRegistry::new();
        registry
            .define(MacroDef::new("test".into(), strings(&["x"]), "pred(x)".into()))
            .expect("macro define should succeed");

        assert!(registry.contains("test"));
        assert_eq!(registry.len(), 1);

        let result = registry
            .expand("test", &strings(&["a"]))
            .expect("macro expand should succeed");
        assert_eq!(result, "pred(a)");
    }

    #[test]
    fn test_builtin_macros() {
        let registry = MacroRegistry::with_builtins();
        for name in &["transitive", "symmetric", "reflexive"] {
            assert!(registry.contains(name));
        }
    }

    #[test]
    fn test_parse_macro_definition() {
        let parsed = parse_macro_definition("DEFINE MACRO test(x, y) = pred(x, y)")
            .expect("macro definition should parse");
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.params, vec!["x", "y"]);
        assert_eq!(parsed.body, "pred(x, y)");
    }

    #[test]
    fn test_replace_word() {
        // (text, from, to, expected); the middle case must NOT replace inside "xyz".
        let cases = [
            ("x + y", "x", "a", "a + y"),
            ("xyz", "x", "a", "xyz"),
            ("x(x, x)", "x", "a", "a(a, a)"),
        ];
        for &(text, from, to, expected) in cases.iter() {
            assert_eq!(replace_word(text, from, to), expected);
        }
    }

    #[test]
    fn test_macro_expansion_recursive() {
        let mut registry = MacroRegistry::new();
        registry
            .define(MacroDef::new(
                "trans".into(),
                strings(&["R", "x", "z"]),
                "EXISTS y. (R(x, y) AND R(y, z))".into(),
            ))
            .expect("transitive macro define should succeed");

        let result = registry
            .expand_all("trans(friend, Alice, Bob)")
            .expect("macro expand_all should succeed");
        for fragment in &["EXISTS y", "friend(Alice, y)", "friend(y, Bob)"] {
            assert!(result.contains(*fragment));
        }
    }
}