1#![allow(dead_code)]
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41pub struct MacroDef {
42 pub name: String,
44
45 pub params: Vec<String>,
47
48 pub body: String,
50}
51
52impl MacroDef {
53 pub fn new(name: String, params: Vec<String>, body: String) -> Self {
55 Self { name, params, body }
56 }
57
58 pub fn arity(&self) -> usize {
60 self.params.len()
61 }
62
63 pub fn validate(&self) -> Result<()> {
65 if self.name.is_empty() {
66 return Err(anyhow!("Macro name cannot be empty"));
67 }
68
69 if !self
70 .name
71 .chars()
72 .next()
73 .expect("name non-empty after is_empty check above")
74 .is_alphabetic()
75 {
76 return Err(anyhow!(
77 "Macro name must start with a letter: {}",
78 self.name
79 ));
80 }
81
82 if self.params.is_empty() {
83 return Err(anyhow!("Macro must have at least one parameter"));
84 }
85
86 let mut seen = HashMap::new();
88 for (idx, param) in self.params.iter().enumerate() {
89 if let Some(prev_idx) = seen.insert(param, idx) {
90 return Err(anyhow!(
91 "Duplicate parameter '{}' at positions {} and {}",
92 param,
93 prev_idx,
94 idx
95 ));
96 }
97 }
98
99 if self.body.is_empty() {
100 return Err(anyhow!("Macro body cannot be empty"));
101 }
102
103 Ok(())
104 }
105
106 pub fn expand(&self, args: &[String]) -> Result<String> {
108 if args.len() != self.params.len() {
109 return Err(anyhow!(
110 "Macro {} expects {} arguments, got {}",
111 self.name,
112 self.params.len(),
113 args.len()
114 ));
115 }
116
117 let mut substitutions: HashMap<&str, &str> = HashMap::new();
119 for (param, arg) in self.params.iter().zip(args.iter()) {
120 substitutions.insert(param.as_str(), arg.as_str());
121 }
122
123 let mut result = self.body.clone();
125
126 let mut sorted_params: Vec<&String> = self.params.iter().collect();
128 sorted_params.sort_by_key(|p| std::cmp::Reverse(p.len()));
129
130 for param in sorted_params {
131 if let Some(arg) = substitutions.get(param.as_str()) {
132 result = replace_word(&result, param, arg);
134 }
135 }
136
137 Ok(result)
138 }
139}
140
/// Replaces every whole-word occurrence of `from` in `text` with `to`.
///
/// A word is a maximal run of alphanumeric characters and underscores, so
/// `from` is only replaced when it is not part of a longer identifier
/// (`"x"` does not match inside `"xyz"`). All other characters are copied
/// through unchanged.
///
/// An empty `from` never matches; the text is returned unchanged in that
/// case rather than inserting `to` between adjacent delimiters.
fn replace_word(text: &str, from: &str, to: &str) -> String {
    if from.is_empty() {
        return text.to_string();
    }

    let mut result = String::with_capacity(text.len());
    // Byte offset at which the word currently being scanned started.
    let mut word_start: Option<usize> = None;

    for (idx, ch) in text.char_indices() {
        if ch.is_alphanumeric() || ch == '_' {
            if word_start.is_none() {
                word_start = Some(idx);
            }
        } else {
            if let Some(start) = word_start.take() {
                let word = &text[start..idx];
                result.push_str(if word == from { to } else { word });
            }
            result.push(ch);
        }
    }

    // Flush a word that runs up to the end of the text.
    if let Some(start) = word_start {
        let word = &text[start..];
        result.push_str(if word == from { to } else { word });
    }

    result
}
169
170#[derive(Debug, Clone, Default, Serialize, Deserialize)]
172pub struct MacroRegistry {
173 macros: HashMap<String, MacroDef>,
175}
176
177impl MacroRegistry {
178 pub fn new() -> Self {
180 Self {
181 macros: HashMap::new(),
182 }
183 }
184
185 pub fn with_builtins() -> Self {
187 let mut registry = Self::new();
188
189 let builtins = vec![
191 MacroDef::new(
192 "transitive".to_string(),
193 vec!["R".to_string(), "x".to_string(), "z".to_string()],
194 "EXISTS y. (R(x, y) AND R(y, z))".to_string(),
195 ),
196 MacroDef::new(
197 "symmetric".to_string(),
198 vec!["R".to_string(), "x".to_string(), "y".to_string()],
199 "R(x, y) AND R(y, x)".to_string(),
200 ),
201 MacroDef::new(
202 "reflexive".to_string(),
203 vec!["R".to_string(), "x".to_string()],
204 "R(x, x)".to_string(),
205 ),
206 MacroDef::new(
207 "antisymmetric".to_string(),
208 vec!["R".to_string(), "x".to_string(), "y".to_string()],
209 "(R(x, y) AND R(y, x)) IMPLIES (x = y)".to_string(),
210 ),
211 MacroDef::new(
212 "total".to_string(),
213 vec!["R".to_string(), "x".to_string(), "y".to_string()],
214 "R(x, y) OR R(y, x)".to_string(),
215 ),
216 ];
217
218 for macro_def in builtins {
219 let _ = registry.define(macro_def);
220 }
221
222 registry
223 }
224
225 pub fn define(&mut self, macro_def: MacroDef) -> Result<()> {
227 macro_def.validate()?;
228 self.macros.insert(macro_def.name.clone(), macro_def);
229 Ok(())
230 }
231
232 pub fn get(&self, name: &str) -> Option<&MacroDef> {
234 self.macros.get(name)
235 }
236
237 pub fn contains(&self, name: &str) -> bool {
239 self.macros.contains_key(name)
240 }
241
242 pub fn undefine(&mut self, name: &str) -> Option<MacroDef> {
244 self.macros.remove(name)
245 }
246
247 pub fn list(&self) -> Vec<&MacroDef> {
249 self.macros.values().collect()
250 }
251
252 pub fn clear(&mut self) {
254 self.macros.clear();
255 }
256
257 pub fn len(&self) -> usize {
259 self.macros.len()
260 }
261
262 pub fn is_empty(&self) -> bool {
264 self.macros.is_empty()
265 }
266
267 pub fn expand(&self, name: &str, args: &[String]) -> Result<String> {
269 let macro_def = self
270 .get(name)
271 .ok_or_else(|| anyhow!("Undefined macro: {}", name))?;
272 macro_def.expand(args)
273 }
274
275 pub fn expand_all(&self, expr: &str) -> Result<String> {
277 let mut result = expr.to_string();
278 let mut changed = true;
279 let mut iterations = 0;
280 const MAX_ITERATIONS: usize = 100; while changed && iterations < MAX_ITERATIONS {
283 changed = false;
284 iterations += 1;
285
286 for (name, macro_def) in &self.macros {
288 if let Some(expanded) = self.try_expand_macro(&result, name, macro_def)? {
289 result = expanded;
290 changed = true;
291 break; }
293 }
294 }
295
296 if iterations >= MAX_ITERATIONS {
297 return Err(anyhow!(
298 "Macro expansion exceeded maximum iterations (possible circular definition)"
299 ));
300 }
301
302 Ok(result)
303 }
304
305 fn try_expand_macro(
307 &self,
308 expr: &str,
309 name: &str,
310 macro_def: &MacroDef,
311 ) -> Result<Option<String>> {
312 if let Some(pos) = expr.find(name) {
314 let after_name = pos + name.len();
316 if after_name < expr.len() && expr.chars().nth(after_name) == Some('(') {
317 if let Some(args) = self.extract_args(&expr[after_name..])? {
319 let expanded = macro_def.expand(&args)?;
320 let mut result = String::new();
321 result.push_str(&expr[..pos]);
322 result.push_str(&expanded);
323 result.push_str(
324 &expr[after_name + self.find_matching_paren(&expr[after_name..])? + 1..],
325 );
326 return Ok(Some(result));
327 }
328 }
329 }
330 Ok(None)
331 }
332
333 fn extract_args(&self, text: &str) -> Result<Option<Vec<String>>> {
335 if !text.starts_with('(') {
336 return Ok(None);
337 }
338
339 let closing = self.find_matching_paren(text)?;
340 let args_str = &text[1..closing];
341
342 if args_str.trim().is_empty() {
343 return Ok(Some(Vec::new()));
344 }
345
346 let mut args = Vec::new();
348 let mut current_arg = String::new();
349 let mut depth = 0;
350
351 for ch in args_str.chars() {
352 match ch {
353 '(' => {
354 depth += 1;
355 current_arg.push(ch);
356 }
357 ')' => {
358 depth -= 1;
359 current_arg.push(ch);
360 }
361 ',' if depth == 0 => {
362 args.push(current_arg.trim().to_string());
363 current_arg.clear();
364 }
365 _ => {
366 current_arg.push(ch);
367 }
368 }
369 }
370
371 if !current_arg.is_empty() {
372 args.push(current_arg.trim().to_string());
373 }
374
375 Ok(Some(args))
376 }
377
378 fn find_matching_paren(&self, text: &str) -> Result<usize> {
380 let mut depth = 0;
381 for (i, ch) in text.chars().enumerate() {
382 match ch {
383 '(' => depth += 1,
384 ')' => {
385 depth -= 1;
386 if depth == 0 {
387 return Ok(i);
388 }
389 }
390 _ => {}
391 }
392 }
393 Err(anyhow!("Unmatched parenthesis"))
394 }
395}
396
397pub fn parse_macro_definition(input: &str) -> Result<MacroDef> {
401 let input = input.trim();
402
403 if !input.starts_with("DEFINE MACRO") && !input.starts_with("MACRO") {
405 return Err(anyhow!(
406 "Macro definition must start with 'DEFINE MACRO' or 'MACRO'"
407 ));
408 }
409
410 let input = if let Some(stripped) = input.strip_prefix("DEFINE MACRO") {
411 stripped
412 } else if let Some(stripped) = input.strip_prefix("MACRO") {
413 stripped
414 } else {
415 unreachable!("Already checked for prefixes above")
416 }
417 .trim();
418
419 let eq_pos = input
421 .find('=')
422 .ok_or_else(|| anyhow!("Macro definition must contain '='"))?;
423
424 let signature = input[..eq_pos].trim();
425 let body = input[eq_pos + 1..].trim().to_string();
426
427 let open_paren = signature
429 .find('(')
430 .ok_or_else(|| anyhow!("Macro definition must have parameter list"))?;
431
432 let name = signature[..open_paren].trim().to_string();
433
434 let close_paren = signature
435 .rfind(')')
436 .ok_or_else(|| anyhow!("Unmatched parenthesis in macro signature"))?;
437
438 let params_str = &signature[open_paren + 1..close_paren];
439 let params: Vec<String> = if params_str.trim().is_empty() {
440 return Err(anyhow!("Macro must have at least one parameter"));
441 } else {
442 params_str
443 .split(',')
444 .map(|s| s.trim().to_string())
445 .collect()
446 };
447
448 Ok(MacroDef::new(name, params, body))
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned strings the macro API takes.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a `MacroDef` from string literals.
    fn make_macro(name: &str, params: &[&str], body: &str) -> MacroDef {
        MacroDef::new(name.to_string(), owned(params), body.to_string())
    }

    /// A freshly built definition exposes its name and parameter count.
    #[test]
    fn test_macro_def_creation() {
        let macro_def = make_macro("test", &["x", "y"], "pred(x, y)");

        assert_eq!(macro_def.name, "test");
        assert_eq!(macro_def.arity(), 2);
    }

    /// Validation accepts a well-formed macro and rejects an empty name and
    /// duplicated parameters.
    #[test]
    fn test_macro_validation() {
        let valid = make_macro("test", &["x"], "pred(x)");
        assert!(valid.validate().is_ok());

        let invalid_name = make_macro("", &["x"], "pred(x)");
        assert!(invalid_name.validate().is_err());

        let duplicate_params = make_macro("test", &["x", "x"], "pred(x)");
        assert!(duplicate_params.validate().is_err());
    }

    /// Every occurrence of each parameter in the body is substituted.
    #[test]
    fn test_macro_expansion() {
        let macro_def = make_macro("test", &["x", "y"], "pred(x, y) AND pred(y, x)");

        let expanded = macro_def
            .expand(&owned(&["a", "b"]))
            .expect("macro expansion should succeed");

        assert_eq!(expanded, "pred(a, b) AND pred(b, a)");
    }

    /// A defined macro can be looked up and expanded through the registry.
    #[test]
    fn test_macro_registry() {
        let mut registry = MacroRegistry::new();

        registry
            .define(make_macro("test", &["x"], "pred(x)"))
            .expect("macro define should succeed");
        assert!(registry.contains("test"));
        assert_eq!(registry.len(), 1);

        let expanded = registry
            .expand("test", &owned(&["a"]))
            .expect("macro expand should succeed");
        assert_eq!(expanded, "pred(a)");
    }

    /// The built-in registry contains the standard relation macros.
    #[test]
    fn test_builtin_macros() {
        let registry = MacroRegistry::with_builtins();

        for builtin in ["transitive", "symmetric", "reflexive"].iter() {
            assert!(registry.contains(builtin));
        }
    }

    /// The textual form is split into name, parameters and body.
    #[test]
    fn test_parse_macro_definition() {
        let macro_def = parse_macro_definition("DEFINE MACRO test(x, y) = pred(x, y)")
            .expect("macro definition should parse");

        assert_eq!(macro_def.name, "test");
        assert_eq!(macro_def.params, vec!["x", "y"]);
        assert_eq!(macro_def.body, "pred(x, y)");
    }

    /// Only whole words are replaced, never parts of longer identifiers.
    #[test]
    fn test_replace_word() {
        assert_eq!(replace_word("x + y", "x", "a"), "a + y");
        // `x` is only a prefix of the word `xyz`, so nothing is replaced.
        assert_eq!(replace_word("xyz", "x", "a"), "xyz");
        assert_eq!(replace_word("x(x, x)", "x", "a"), "a(a, a)");
    }

    /// `expand_all` finds a macro call inside an expression and expands it.
    #[test]
    fn test_macro_expansion_recursive() {
        let mut registry = MacroRegistry::new();
        registry
            .define(make_macro(
                "trans",
                &["R", "x", "z"],
                "EXISTS y. (R(x, y) AND R(y, z))",
            ))
            .expect("transitive macro define should succeed");

        let expanded = registry
            .expand_all("trans(friend, Alice, Bob)")
            .expect("macro expand_all should succeed");

        assert!(expanded.contains("EXISTS y"));
        assert!(expanded.contains("friend(Alice, y)"));
        assert!(expanded.contains("friend(y, Bob)"));
    }
}