tensorlogic_cli/
macros.rs

1//! Macro system for defining reusable logical patterns
2//!
3//! This module provides a powerful macro system that allows users to define
4//! parameterized logical patterns that can be reused throughout their expressions.
5//!
6//! # Syntax
7//!
8//! Macros are defined using the following syntax:
9//!
10//! ```text
11//! DEFINE MACRO name(param1, param2, ...) = expression
12//! ```
13//!
14//! # Examples
15//!
16//! ```text
17//! // Define a transitive relation macro
18//! DEFINE MACRO transitive(R, x, z) = EXISTS y. (R(x, y) AND R(y, z))
19//!
20//! // Define a symmetric relation macro
21//! DEFINE MACRO symmetric(R, x, y) = R(x, y) AND R(y, x)
22//!
23//! // Define a reflexive relation macro
24//! DEFINE MACRO reflexive(R, x) = R(x, x)
25//!
26//! // Define an equivalence relation macro
27//! DEFINE MACRO equivalence(R, x, y) = reflexive(R, x) AND reflexive(R, y) AND symmetric(R, x, y)
28//!
29//! // Use macros in expressions
30//! transitive(friend, Alice, Bob)
31//! ```
32
33#![allow(dead_code)]
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39/// A macro definition with parameters and body
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41pub struct MacroDef {
42    /// Macro name
43    pub name: String,
44
45    /// Parameter names
46    pub params: Vec<String>,
47
48    /// Macro body (as expression string)
49    pub body: String,
50}
51
52impl MacroDef {
53    /// Create a new macro definition
54    pub fn new(name: String, params: Vec<String>, body: String) -> Self {
55        Self { name, params, body }
56    }
57
58    /// Get the arity (number of parameters)
59    pub fn arity(&self) -> usize {
60        self.params.len()
61    }
62
63    /// Validate that the macro definition is well-formed
64    pub fn validate(&self) -> Result<()> {
65        if self.name.is_empty() {
66            return Err(anyhow!("Macro name cannot be empty"));
67        }
68
69        if !self.name.chars().next().unwrap().is_alphabetic() {
70            return Err(anyhow!(
71                "Macro name must start with a letter: {}",
72                self.name
73            ));
74        }
75
76        if self.params.is_empty() {
77            return Err(anyhow!("Macro must have at least one parameter"));
78        }
79
80        // Check for duplicate parameters
81        let mut seen = HashMap::new();
82        for (idx, param) in self.params.iter().enumerate() {
83            if let Some(prev_idx) = seen.insert(param, idx) {
84                return Err(anyhow!(
85                    "Duplicate parameter '{}' at positions {} and {}",
86                    param,
87                    prev_idx,
88                    idx
89                ));
90            }
91        }
92
93        if self.body.is_empty() {
94            return Err(anyhow!("Macro body cannot be empty"));
95        }
96
97        Ok(())
98    }
99
100    /// Expand the macro with the given arguments
101    pub fn expand(&self, args: &[String]) -> Result<String> {
102        if args.len() != self.params.len() {
103            return Err(anyhow!(
104                "Macro {} expects {} arguments, got {}",
105                self.name,
106                self.params.len(),
107                args.len()
108            ));
109        }
110
111        // Create substitution map
112        let mut substitutions: HashMap<&str, &str> = HashMap::new();
113        for (param, arg) in self.params.iter().zip(args.iter()) {
114            substitutions.insert(param.as_str(), arg.as_str());
115        }
116
117        // Perform substitution in the body
118        let mut result = self.body.clone();
119
120        // Sort parameters by length (descending) to handle overlapping names correctly
121        let mut sorted_params: Vec<&String> = self.params.iter().collect();
122        sorted_params.sort_by_key(|p| std::cmp::Reverse(p.len()));
123
124        for param in sorted_params {
125            if let Some(arg) = substitutions.get(param.as_str()) {
126                // Use word boundaries to avoid partial replacements
127                result = replace_word(&result, param, arg);
128            }
129        }
130
131        Ok(result)
132    }
133}
134
/// Replace every whole-word occurrence of `from` in `text` with `to`.
///
/// A "word" is a maximal run of alphanumeric characters and underscores, so
/// `from` only matches when bounded by non-word characters or the ends of
/// `text`. Replacing `x` leaves `xyz` and `x_1` untouched.
///
/// An empty `from` never matches and `text` is returned unchanged. Without
/// this guard, the empty "word" between two adjacent non-word characters
/// (as in `", "`) would compare equal to `from` and be replaced.
fn replace_word(text: &str, from: &str, to: &str) -> String {
    if from.is_empty() {
        return text.to_string();
    }

    let mut result = String::with_capacity(text.len());
    let mut current_word = String::new();

    for ch in text.chars() {
        if ch.is_alphanumeric() || ch == '_' {
            current_word.push(ch);
            continue;
        }
        result.push_str(if current_word == from {
            to
        } else {
            current_word.as_str()
        });
        current_word.clear();
        result.push(ch);
    }

    // The text may end in the middle of a word.
    result.push_str(if current_word == from {
        to
    } else {
        current_word.as_str()
    });

    result
}
163
164/// Registry for managing macro definitions
165#[derive(Debug, Clone, Default, Serialize, Deserialize)]
166pub struct MacroRegistry {
167    /// Map of macro name to definition
168    macros: HashMap<String, MacroDef>,
169}
170
171impl MacroRegistry {
172    /// Create a new empty macro registry
173    pub fn new() -> Self {
174        Self {
175            macros: HashMap::new(),
176        }
177    }
178
179    /// Create a registry with built-in macros
180    pub fn with_builtins() -> Self {
181        let mut registry = Self::new();
182
183        // Add common built-in macros
184        let builtins = vec![
185            MacroDef::new(
186                "transitive".to_string(),
187                vec!["R".to_string(), "x".to_string(), "z".to_string()],
188                "EXISTS y. (R(x, y) AND R(y, z))".to_string(),
189            ),
190            MacroDef::new(
191                "symmetric".to_string(),
192                vec!["R".to_string(), "x".to_string(), "y".to_string()],
193                "R(x, y) AND R(y, x)".to_string(),
194            ),
195            MacroDef::new(
196                "reflexive".to_string(),
197                vec!["R".to_string(), "x".to_string()],
198                "R(x, x)".to_string(),
199            ),
200            MacroDef::new(
201                "antisymmetric".to_string(),
202                vec!["R".to_string(), "x".to_string(), "y".to_string()],
203                "(R(x, y) AND R(y, x)) IMPLIES (x = y)".to_string(),
204            ),
205            MacroDef::new(
206                "total".to_string(),
207                vec!["R".to_string(), "x".to_string(), "y".to_string()],
208                "R(x, y) OR R(y, x)".to_string(),
209            ),
210        ];
211
212        for macro_def in builtins {
213            let _ = registry.define(macro_def);
214        }
215
216        registry
217    }
218
219    /// Define a new macro
220    pub fn define(&mut self, macro_def: MacroDef) -> Result<()> {
221        macro_def.validate()?;
222        self.macros.insert(macro_def.name.clone(), macro_def);
223        Ok(())
224    }
225
226    /// Get a macro definition by name
227    pub fn get(&self, name: &str) -> Option<&MacroDef> {
228        self.macros.get(name)
229    }
230
231    /// Check if a macro is defined
232    pub fn contains(&self, name: &str) -> bool {
233        self.macros.contains_key(name)
234    }
235
236    /// Remove a macro definition
237    pub fn undefine(&mut self, name: &str) -> Option<MacroDef> {
238        self.macros.remove(name)
239    }
240
241    /// List all defined macros
242    pub fn list(&self) -> Vec<&MacroDef> {
243        self.macros.values().collect()
244    }
245
246    /// Clear all macros
247    pub fn clear(&mut self) {
248        self.macros.clear();
249    }
250
251    /// Get the number of defined macros
252    pub fn len(&self) -> usize {
253        self.macros.len()
254    }
255
256    /// Check if the registry is empty
257    pub fn is_empty(&self) -> bool {
258        self.macros.is_empty()
259    }
260
261    /// Expand a macro call
262    pub fn expand(&self, name: &str, args: &[String]) -> Result<String> {
263        let macro_def = self
264            .get(name)
265            .ok_or_else(|| anyhow!("Undefined macro: {}", name))?;
266        macro_def.expand(args)
267    }
268
269    /// Recursively expand all macros in an expression
270    pub fn expand_all(&self, expr: &str) -> Result<String> {
271        let mut result = expr.to_string();
272        let mut changed = true;
273        let mut iterations = 0;
274        const MAX_ITERATIONS: usize = 100; // Prevent infinite loops
275
276        while changed && iterations < MAX_ITERATIONS {
277            changed = false;
278            iterations += 1;
279
280            // Try to find and expand macro calls
281            for (name, macro_def) in &self.macros {
282                if let Some(expanded) = self.try_expand_macro(&result, name, macro_def)? {
283                    result = expanded;
284                    changed = true;
285                    break; // Start over to ensure proper nesting
286                }
287            }
288        }
289
290        if iterations >= MAX_ITERATIONS {
291            return Err(anyhow!(
292                "Macro expansion exceeded maximum iterations (possible circular definition)"
293            ));
294        }
295
296        Ok(result)
297    }
298
299    /// Try to expand a specific macro in the expression
300    fn try_expand_macro(
301        &self,
302        expr: &str,
303        name: &str,
304        macro_def: &MacroDef,
305    ) -> Result<Option<String>> {
306        // Simple pattern matching for macro calls: name(arg1, arg2, ...)
307        if let Some(pos) = expr.find(name) {
308            // Check if this is actually a macro call (followed by '(')
309            let after_name = pos + name.len();
310            if after_name < expr.len() && expr.chars().nth(after_name) == Some('(') {
311                // Extract arguments
312                if let Some(args) = self.extract_args(&expr[after_name..])? {
313                    let expanded = macro_def.expand(&args)?;
314                    let mut result = String::new();
315                    result.push_str(&expr[..pos]);
316                    result.push_str(&expanded);
317                    result.push_str(
318                        &expr[after_name + self.find_matching_paren(&expr[after_name..])? + 1..],
319                    );
320                    return Ok(Some(result));
321                }
322            }
323        }
324        Ok(None)
325    }
326
327    /// Extract arguments from a function call
328    fn extract_args(&self, text: &str) -> Result<Option<Vec<String>>> {
329        if !text.starts_with('(') {
330            return Ok(None);
331        }
332
333        let closing = self.find_matching_paren(text)?;
334        let args_str = &text[1..closing];
335
336        if args_str.trim().is_empty() {
337            return Ok(Some(Vec::new()));
338        }
339
340        // Split by commas (respecting nested parentheses)
341        let mut args = Vec::new();
342        let mut current_arg = String::new();
343        let mut depth = 0;
344
345        for ch in args_str.chars() {
346            match ch {
347                '(' => {
348                    depth += 1;
349                    current_arg.push(ch);
350                }
351                ')' => {
352                    depth -= 1;
353                    current_arg.push(ch);
354                }
355                ',' if depth == 0 => {
356                    args.push(current_arg.trim().to_string());
357                    current_arg.clear();
358                }
359                _ => {
360                    current_arg.push(ch);
361                }
362            }
363        }
364
365        if !current_arg.is_empty() {
366            args.push(current_arg.trim().to_string());
367        }
368
369        Ok(Some(args))
370    }
371
372    /// Find the position of the matching closing parenthesis
373    fn find_matching_paren(&self, text: &str) -> Result<usize> {
374        let mut depth = 0;
375        for (i, ch) in text.chars().enumerate() {
376            match ch {
377                '(' => depth += 1,
378                ')' => {
379                    depth -= 1;
380                    if depth == 0 {
381                        return Ok(i);
382                    }
383                }
384                _ => {}
385            }
386        }
387        Err(anyhow!("Unmatched parenthesis"))
388    }
389}
390
391/// Parse a macro definition from a string
392///
393/// Expected format: `DEFINE MACRO name(param1, param2, ...) = body`
394pub fn parse_macro_definition(input: &str) -> Result<MacroDef> {
395    let input = input.trim();
396
397    // Check for DEFINE MACRO prefix
398    if !input.starts_with("DEFINE MACRO") && !input.starts_with("MACRO") {
399        return Err(anyhow!(
400            "Macro definition must start with 'DEFINE MACRO' or 'MACRO'"
401        ));
402    }
403
404    let input = if let Some(stripped) = input.strip_prefix("DEFINE MACRO") {
405        stripped
406    } else if let Some(stripped) = input.strip_prefix("MACRO") {
407        stripped
408    } else {
409        unreachable!("Already checked for prefixes above")
410    }
411    .trim();
412
413    // Find the equals sign
414    let eq_pos = input
415        .find('=')
416        .ok_or_else(|| anyhow!("Macro definition must contain '='"))?;
417
418    let signature = input[..eq_pos].trim();
419    let body = input[eq_pos + 1..].trim().to_string();
420
421    // Parse signature: name(param1, param2, ...)
422    let open_paren = signature
423        .find('(')
424        .ok_or_else(|| anyhow!("Macro definition must have parameter list"))?;
425
426    let name = signature[..open_paren].trim().to_string();
427
428    let close_paren = signature
429        .rfind(')')
430        .ok_or_else(|| anyhow!("Unmatched parenthesis in macro signature"))?;
431
432    let params_str = &signature[open_paren + 1..close_paren];
433    let params: Vec<String> = if params_str.trim().is_empty() {
434        return Err(anyhow!("Macro must have at least one parameter"));
435    } else {
436        params_str
437            .split(',')
438            .map(|s| s.trim().to_string())
439            .collect()
440    };
441
442    Ok(MacroDef::new(name, params, body))
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn a slice of string literals into owned `String`s.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_macro_def_creation() {
        let def = MacroDef::new("test".into(), owned(&["x", "y"]), "pred(x, y)".into());
        assert_eq!(def.name, "test");
        assert_eq!(def.arity(), 2);
    }

    #[test]
    fn test_macro_validation() {
        let well_formed = MacroDef::new("test".into(), owned(&["x"]), "pred(x)".into());
        assert!(well_formed.validate().is_ok());

        let nameless = MacroDef::new(String::new(), owned(&["x"]), "pred(x)".into());
        assert!(nameless.validate().is_err());

        let repeated_param =
            MacroDef::new("test".into(), owned(&["x", "x"]), "pred(x)".into());
        assert!(repeated_param.validate().is_err());
    }

    #[test]
    fn test_macro_expansion() {
        let def = MacroDef::new(
            "test".into(),
            owned(&["x", "y"]),
            "pred(x, y) AND pred(y, x)".into(),
        );
        let expanded = def.expand(&owned(&["a", "b"])).unwrap();
        assert_eq!(expanded, "pred(a, b) AND pred(b, a)");
    }

    #[test]
    fn test_macro_registry() {
        let mut registry = MacroRegistry::new();
        registry
            .define(MacroDef::new("test".into(), owned(&["x"]), "pred(x)".into()))
            .unwrap();

        assert!(registry.contains("test"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.expand("test", &owned(&["a"])).unwrap(), "pred(a)");
    }

    #[test]
    fn test_builtin_macros() {
        let registry = MacroRegistry::with_builtins();
        for name in ["transitive", "symmetric", "reflexive"].iter() {
            assert!(registry.contains(name), "missing built-in macro `{}`", name);
        }
    }

    #[test]
    fn test_parse_macro_definition() {
        let def = parse_macro_definition("DEFINE MACRO test(x, y) = pred(x, y)").unwrap();
        assert_eq!(def.name, "test");
        assert_eq!(def.params, vec!["x", "y"]);
        assert_eq!(def.body, "pred(x, y)");
    }

    #[test]
    fn test_replace_word() {
        let cases = [
            ("x + y", "a + y"),
            // `x` is only a prefix of `xyz`, so nothing may be replaced.
            ("xyz", "xyz"),
            ("x(x, x)", "a(a, a)"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(replace_word(input, "x", "a"), *expected);
        }
    }

    #[test]
    fn test_macro_expansion_recursive() {
        let mut registry = MacroRegistry::new();
        registry
            .define(MacroDef::new(
                "trans".into(),
                owned(&["R", "x", "z"]),
                "EXISTS y. (R(x, y) AND R(y, z))".into(),
            ))
            .unwrap();

        let expanded = registry.expand_all("trans(friend, Alice, Bob)").unwrap();
        for fragment in ["EXISTS y", "friend(Alice, y)", "friend(y, Bob)"].iter() {
            assert!(expanded.contains(*fragment));
        }
    }
}