1#![allow(dead_code)]
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
/// A named, parameterized text macro.
///
/// [`MacroDef::expand`] produces the macro's text by substituting arguments for the parameter
/// names that occur in `body`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MacroDef {
    /// Name of the macro.
    pub name: String,

    /// Parameter names, in the order arguments are matched against them.
    pub params: Vec<String>,

    /// Body text in which parameter names are replaced during expansion.
    pub body: String,
}
51
52impl MacroDef {
53 pub fn new(name: String, params: Vec<String>, body: String) -> Self {
55 Self { name, params, body }
56 }
57
58 pub fn arity(&self) -> usize {
60 self.params.len()
61 }
62
63 pub fn validate(&self) -> Result<()> {
65 if self.name.is_empty() {
66 return Err(anyhow!("Macro name cannot be empty"));
67 }
68
69 if !self.name.chars().next().unwrap().is_alphabetic() {
70 return Err(anyhow!(
71 "Macro name must start with a letter: {}",
72 self.name
73 ));
74 }
75
76 if self.params.is_empty() {
77 return Err(anyhow!("Macro must have at least one parameter"));
78 }
79
80 let mut seen = HashMap::new();
82 for (idx, param) in self.params.iter().enumerate() {
83 if let Some(prev_idx) = seen.insert(param, idx) {
84 return Err(anyhow!(
85 "Duplicate parameter '{}' at positions {} and {}",
86 param,
87 prev_idx,
88 idx
89 ));
90 }
91 }
92
93 if self.body.is_empty() {
94 return Err(anyhow!("Macro body cannot be empty"));
95 }
96
97 Ok(())
98 }
99
100 pub fn expand(&self, args: &[String]) -> Result<String> {
102 if args.len() != self.params.len() {
103 return Err(anyhow!(
104 "Macro {} expects {} arguments, got {}",
105 self.name,
106 self.params.len(),
107 args.len()
108 ));
109 }
110
111 let mut substitutions: HashMap<&str, &str> = HashMap::new();
113 for (param, arg) in self.params.iter().zip(args.iter()) {
114 substitutions.insert(param.as_str(), arg.as_str());
115 }
116
117 let mut result = self.body.clone();
119
120 let mut sorted_params: Vec<&String> = self.params.iter().collect();
122 sorted_params.sort_by_key(|p| std::cmp::Reverse(p.len()));
123
124 for param in sorted_params {
125 if let Some(arg) = substitutions.get(param.as_str()) {
126 result = replace_word(&result, param, arg);
128 }
129 }
130
131 Ok(result)
132 }
133}
134
/// Replaces every whole-word occurrence of `from` in `text` with `to`.
///
/// A word is a maximal run of alphanumeric characters and underscores, so `x` is replaced in
/// `f(x)` but not inside `xyz` or `x_1`. An empty `from` can never be a word; in that case
/// `text` is returned unchanged.
fn replace_word(text: &str, from: &str, to: &str) -> String {
    if from.is_empty() {
        return text.to_string();
    }

    let mut result = String::with_capacity(text.len());
    // Byte offset at which the word currently being scanned started, if any.
    let mut word_start: Option<usize> = None;

    for (idx, ch) in text.char_indices() {
        if ch.is_alphanumeric() || ch == '_' {
            if word_start.is_none() {
                word_start = Some(idx);
            }
        } else {
            if let Some(start) = word_start.take() {
                let word = &text[start..idx];
                result.push_str(if word == from { to } else { word });
            }
            result.push(ch);
        }
    }

    // Flush a word that runs to the end of the text.
    if let Some(start) = word_start {
        let word = &text[start..];
        result.push_str(if word == from { to } else { word });
    }

    result
}
163
/// A collection of macro definitions keyed by macro name.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MacroRegistry {
    /// Registered macros, keyed by [`MacroDef::name`].
    macros: HashMap<String, MacroDef>,
}
170
171impl MacroRegistry {
172 pub fn new() -> Self {
174 Self {
175 macros: HashMap::new(),
176 }
177 }
178
179 pub fn with_builtins() -> Self {
181 let mut registry = Self::new();
182
183 let builtins = vec![
185 MacroDef::new(
186 "transitive".to_string(),
187 vec!["R".to_string(), "x".to_string(), "z".to_string()],
188 "EXISTS y. (R(x, y) AND R(y, z))".to_string(),
189 ),
190 MacroDef::new(
191 "symmetric".to_string(),
192 vec!["R".to_string(), "x".to_string(), "y".to_string()],
193 "R(x, y) AND R(y, x)".to_string(),
194 ),
195 MacroDef::new(
196 "reflexive".to_string(),
197 vec!["R".to_string(), "x".to_string()],
198 "R(x, x)".to_string(),
199 ),
200 MacroDef::new(
201 "antisymmetric".to_string(),
202 vec!["R".to_string(), "x".to_string(), "y".to_string()],
203 "(R(x, y) AND R(y, x)) IMPLIES (x = y)".to_string(),
204 ),
205 MacroDef::new(
206 "total".to_string(),
207 vec!["R".to_string(), "x".to_string(), "y".to_string()],
208 "R(x, y) OR R(y, x)".to_string(),
209 ),
210 ];
211
212 for macro_def in builtins {
213 let _ = registry.define(macro_def);
214 }
215
216 registry
217 }
218
219 pub fn define(&mut self, macro_def: MacroDef) -> Result<()> {
221 macro_def.validate()?;
222 self.macros.insert(macro_def.name.clone(), macro_def);
223 Ok(())
224 }
225
226 pub fn get(&self, name: &str) -> Option<&MacroDef> {
228 self.macros.get(name)
229 }
230
231 pub fn contains(&self, name: &str) -> bool {
233 self.macros.contains_key(name)
234 }
235
236 pub fn undefine(&mut self, name: &str) -> Option<MacroDef> {
238 self.macros.remove(name)
239 }
240
241 pub fn list(&self) -> Vec<&MacroDef> {
243 self.macros.values().collect()
244 }
245
246 pub fn clear(&mut self) {
248 self.macros.clear();
249 }
250
251 pub fn len(&self) -> usize {
253 self.macros.len()
254 }
255
256 pub fn is_empty(&self) -> bool {
258 self.macros.is_empty()
259 }
260
261 pub fn expand(&self, name: &str, args: &[String]) -> Result<String> {
263 let macro_def = self
264 .get(name)
265 .ok_or_else(|| anyhow!("Undefined macro: {}", name))?;
266 macro_def.expand(args)
267 }
268
269 pub fn expand_all(&self, expr: &str) -> Result<String> {
271 let mut result = expr.to_string();
272 let mut changed = true;
273 let mut iterations = 0;
274 const MAX_ITERATIONS: usize = 100; while changed && iterations < MAX_ITERATIONS {
277 changed = false;
278 iterations += 1;
279
280 for (name, macro_def) in &self.macros {
282 if let Some(expanded) = self.try_expand_macro(&result, name, macro_def)? {
283 result = expanded;
284 changed = true;
285 break; }
287 }
288 }
289
290 if iterations >= MAX_ITERATIONS {
291 return Err(anyhow!(
292 "Macro expansion exceeded maximum iterations (possible circular definition)"
293 ));
294 }
295
296 Ok(result)
297 }
298
299 fn try_expand_macro(
301 &self,
302 expr: &str,
303 name: &str,
304 macro_def: &MacroDef,
305 ) -> Result<Option<String>> {
306 if let Some(pos) = expr.find(name) {
308 let after_name = pos + name.len();
310 if after_name < expr.len() && expr.chars().nth(after_name) == Some('(') {
311 if let Some(args) = self.extract_args(&expr[after_name..])? {
313 let expanded = macro_def.expand(&args)?;
314 let mut result = String::new();
315 result.push_str(&expr[..pos]);
316 result.push_str(&expanded);
317 result.push_str(
318 &expr[after_name + self.find_matching_paren(&expr[after_name..])? + 1..],
319 );
320 return Ok(Some(result));
321 }
322 }
323 }
324 Ok(None)
325 }
326
327 fn extract_args(&self, text: &str) -> Result<Option<Vec<String>>> {
329 if !text.starts_with('(') {
330 return Ok(None);
331 }
332
333 let closing = self.find_matching_paren(text)?;
334 let args_str = &text[1..closing];
335
336 if args_str.trim().is_empty() {
337 return Ok(Some(Vec::new()));
338 }
339
340 let mut args = Vec::new();
342 let mut current_arg = String::new();
343 let mut depth = 0;
344
345 for ch in args_str.chars() {
346 match ch {
347 '(' => {
348 depth += 1;
349 current_arg.push(ch);
350 }
351 ')' => {
352 depth -= 1;
353 current_arg.push(ch);
354 }
355 ',' if depth == 0 => {
356 args.push(current_arg.trim().to_string());
357 current_arg.clear();
358 }
359 _ => {
360 current_arg.push(ch);
361 }
362 }
363 }
364
365 if !current_arg.is_empty() {
366 args.push(current_arg.trim().to_string());
367 }
368
369 Ok(Some(args))
370 }
371
372 fn find_matching_paren(&self, text: &str) -> Result<usize> {
374 let mut depth = 0;
375 for (i, ch) in text.chars().enumerate() {
376 match ch {
377 '(' => depth += 1,
378 ')' => {
379 depth -= 1;
380 if depth == 0 {
381 return Ok(i);
382 }
383 }
384 _ => {}
385 }
386 }
387 Err(anyhow!("Unmatched parenthesis"))
388 }
389}
390
391pub fn parse_macro_definition(input: &str) -> Result<MacroDef> {
395 let input = input.trim();
396
397 if !input.starts_with("DEFINE MACRO") && !input.starts_with("MACRO") {
399 return Err(anyhow!(
400 "Macro definition must start with 'DEFINE MACRO' or 'MACRO'"
401 ));
402 }
403
404 let input = if let Some(stripped) = input.strip_prefix("DEFINE MACRO") {
405 stripped
406 } else if let Some(stripped) = input.strip_prefix("MACRO") {
407 stripped
408 } else {
409 unreachable!("Already checked for prefixes above")
410 }
411 .trim();
412
413 let eq_pos = input
415 .find('=')
416 .ok_or_else(|| anyhow!("Macro definition must contain '='"))?;
417
418 let signature = input[..eq_pos].trim();
419 let body = input[eq_pos + 1..].trim().to_string();
420
421 let open_paren = signature
423 .find('(')
424 .ok_or_else(|| anyhow!("Macro definition must have parameter list"))?;
425
426 let name = signature[..open_paren].trim().to_string();
427
428 let close_paren = signature
429 .rfind(')')
430 .ok_or_else(|| anyhow!("Unmatched parenthesis in macro signature"))?;
431
432 let params_str = &signature[open_paren + 1..close_paren];
433 let params: Vec<String> = if params_str.trim().is_empty() {
434 return Err(anyhow!("Macro must have at least one parameter"));
435 } else {
436 params_str
437 .split(',')
438 .map(|s| s.trim().to_string())
439 .collect()
440 };
441
442 Ok(MacroDef::new(name, params, body))
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned strings the macro API expects.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a macro definition from string literals.
    fn make_macro(name: &str, params: &[&str], body: &str) -> MacroDef {
        MacroDef::new(name.to_string(), owned(params), body.to_string())
    }

    #[test]
    fn test_macro_def_creation() {
        let def = make_macro("test", &["x", "y"], "pred(x, y)");

        assert_eq!(def.name, "test");
        assert_eq!(def.arity(), 2);
    }

    #[test]
    fn test_macro_validation() {
        // A well-formed definition passes.
        let well_formed = make_macro("test", &["x"], "pred(x)");
        assert!(well_formed.validate().is_ok());

        // An empty name is rejected.
        let unnamed = make_macro("", &["x"], "pred(x)");
        assert!(unnamed.validate().is_err());

        // A repeated parameter name is rejected.
        let repeated = make_macro("test", &["x", "x"], "pred(x)");
        assert!(repeated.validate().is_err());
    }

    #[test]
    fn test_macro_expansion() {
        let def = make_macro("test", &["x", "y"], "pred(x, y) AND pred(y, x)");

        let output = def.expand(&owned(&["a", "b"])).unwrap();

        assert_eq!(output, "pred(a, b) AND pred(b, a)");
    }

    #[test]
    fn test_macro_registry() {
        let mut reg = MacroRegistry::new();
        reg.define(make_macro("test", &["x"], "pred(x)")).unwrap();

        assert!(reg.contains("test"));
        assert_eq!(reg.len(), 1);

        let output = reg.expand("test", &owned(&["a"])).unwrap();
        assert_eq!(output, "pred(a)");
    }

    #[test]
    fn test_builtin_macros() {
        let reg = MacroRegistry::with_builtins();

        assert!(reg.contains("transitive"));
        assert!(reg.contains("symmetric"));
        assert!(reg.contains("reflexive"));
    }

    #[test]
    fn test_parse_macro_definition() {
        let parsed = parse_macro_definition("DEFINE MACRO test(x, y) = pred(x, y)").unwrap();

        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.params, vec!["x", "y"]);
        assert_eq!(parsed.body, "pred(x, y)");
    }

    #[test]
    fn test_replace_word() {
        assert_eq!(replace_word("x + y", "x", "a"), "a + y");
        // A parameter name embedded in a longer word is left alone.
        assert_eq!(replace_word("xyz", "x", "a"), "xyz");
        assert_eq!(replace_word("x(x, x)", "x", "a"), "a(a, a)");
    }

    #[test]
    fn test_macro_expansion_recursive() {
        let mut reg = MacroRegistry::new();
        reg.define(make_macro(
            "trans",
            &["R", "x", "z"],
            "EXISTS y. (R(x, y) AND R(y, z))",
        ))
        .unwrap();

        let output = reg.expand_all("trans(friend, Alice, Bob)").unwrap();

        assert!(output.contains("EXISTS y"));
        assert!(output.contains("friend(Alice, y)"));
        assert!(output.contains("friend(y, Bob)"));
    }
}