Skip to main content

spool/
rules.rs

1//! User-defined rules that influence memory extraction, injection,
2//! and classification behavior. Stored at `~/.spool/rules.toml`.
3
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::{Path, PathBuf};
7
/// File name of the rules file; [`rules_path`] joins it onto the spool root.
const RULES_FILE_NAME: &str = "rules.toml";
9
/// Top-level contents of the rules file: three independent lists of rules.
///
/// Every list is marked `#[serde(default)]`, so a section that is absent from
/// the file deserializes to an empty `Vec` instead of failing. The derived
/// `Default` is the "no rules at all" value with all three lists empty.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UserRules {
    /// Rules from the `extraction` array of tables (`[[extraction]]`).
    #[serde(default)]
    pub extraction: Vec<ExtractionRule>,
    /// Rules from the `context` array of tables.
    #[serde(default)]
    pub context: Vec<ContextRule>,
    /// Rules from the `suppress` array of tables.
    #[serde(default)]
    pub suppress: Vec<SuppressRule>,
}
19
/// A single entry of the `extraction` rule list.
///
/// Only `trigger` is required when deserializing; the other fields fall back
/// to their serde defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionRule {
    /// Text that activates the rule (required).
    // NOTE(review): how the trigger is matched (substring, keyword, pattern)
    // is decided by the consumer, which is not visible in this file.
    pub trigger: String,
    /// Memory type assigned by the rule; `"preference"` when omitted.
    #[serde(default = "default_memory_type")]
    pub memory_type: String,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
28
/// A single entry of the `context` rule list.
///
/// `always_include` carries no serde default, so it must be present in every
/// entry; `scope` and `description` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextRule {
    /// Scope the rule applies to; `"project"` when omitted.
    // NOTE(review): the set of accepted scope values is not defined in this
    // file; confirm against the consumer.
    #[serde(default = "default_scope")]
    pub scope: String,
    /// Items that should always be included for this scope (required).
    pub always_include: Vec<String>,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
37
/// A single entry of the `suppress` rule list.
///
/// Only `pattern` is required when deserializing; the other fields fall back
/// to their serde defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuppressRule {
    /// Pattern identifying the content the rule applies to (required).
    // NOTE(review): presumably a regular expression (the tests in this file
    // use `临时.*测试`); confirm against the code that evaluates it.
    pub pattern: String,
    /// Action taken on a match; `"skip"` when omitted.
    #[serde(default = "default_action")]
    pub action: String,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
46
/// Serde fallback used for `ExtractionRule::memory_type` when the key is absent.
fn default_memory_type() -> String {
    String::from("preference")
}
/// Serde fallback used for `ContextRule::scope` when the key is absent.
fn default_scope() -> String {
    String::from("project")
}
/// Serde fallback used for `SuppressRule::action` when the key is absent.
fn default_action() -> String {
    String::from("skip")
}
56
57pub fn rules_path(spool_root: &Path) -> PathBuf {
58    spool_root.join(RULES_FILE_NAME)
59}
60
61pub fn load(spool_root: &Path) -> UserRules {
62    let path = rules_path(spool_root);
63    match fs::read_to_string(&path) {
64        Ok(content) => toml::from_str(&content).unwrap_or_default(),
65        Err(_) => UserRules::default(),
66    }
67}
68
69pub fn save(spool_root: &Path, rules: &UserRules) -> anyhow::Result<()> {
70    let path = rules_path(spool_root);
71    if let Some(parent) = path.parent() {
72        fs::create_dir_all(parent)?;
73    }
74    let content = toml::to_string_pretty(rules)?;
75    fs::write(&path, content)?;
76    Ok(())
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A spool root without a rules file yields three empty rule lists.
    #[test]
    fn load_returns_default_when_file_missing() {
        let root = tempdir().unwrap();
        let UserRules {
            extraction,
            context,
            suppress,
        } = load(root.path());
        assert!(extraction.is_empty());
        assert!(context.is_empty());
        assert!(suppress.is_empty());
    }

    /// Rules written by `save` come back from `load`, non-ASCII text included.
    #[test]
    fn save_and_load_roundtrip() {
        let root = tempdir().unwrap();

        let extraction_rule = ExtractionRule {
            trigger: String::from("技术选型"),
            memory_type: String::from("decision"),
            description: String::from("技术选型相关决策"),
        };
        let context_rule = ContextRule {
            scope: String::from("project"),
            always_include: vec![String::from("架构约束")],
            description: String::new(),
        };
        let suppress_rule = SuppressRule {
            pattern: String::from("临时.*测试"),
            action: String::from("skip"),
            description: String::from("跳过临时测试内容"),
        };
        let original = UserRules {
            extraction: vec![extraction_rule],
            context: vec![context_rule],
            suppress: vec![suppress_rule],
        };
        save(root.path(), &original).unwrap();

        let reloaded = load(root.path());
        assert_eq!(reloaded.extraction.len(), 1);
        assert_eq!(reloaded.extraction[0].trigger, "技术选型");
        assert_eq!(reloaded.context[0].always_include[0], "架构约束");
        assert_eq!(reloaded.suppress[0].pattern, "临时.*测试");
    }

    /// A file that only has an `[[extraction]]` section still loads; the
    /// sections it omits come back empty.
    #[test]
    fn load_handles_partial_toml() {
        let root = tempdir().unwrap();
        let rules_file = root.path().join(RULES_FILE_NAME);
        fs::write(rules_file, "[[extraction]]\ntrigger = \"test\"\n").unwrap();

        let rules = load(root.path());
        assert_eq!(rules.extraction.len(), 1);
        assert!(rules.context.is_empty());
    }
}