Skip to main content

split_brain_harness/
rag.rs

//! Operator-configurable context corpus for the transformer RAG layer.
//!
//! Loads `[[docs]]` entries from TOML files and renders them as a
//! `<context_pack>` block to be injected into the system prompt before the
//! model call. The embedded default corpus (`context/default.toml`) ships
//! with the binary; operators may extend or replace it via `SBH_CONTEXT_PATH`.
7use anyhow::{Context, Result};
8use serde::Deserialize;
9use std::fs;
10use std::path::Path;
11
/// Raw TOML text of the default corpus, baked into the binary at compile time.
///
/// `include_str!` resolves the path relative to this source file, so
/// `../context/default.toml` must exist at build time. The text is parsed by
/// `ContextCorpus::embedded`.
const DEFAULT_CONTEXT: &str = include_str!("../context/default.toml");
13
14/// One context document: a titled, tagged block of text.
15#[derive(Debug, Clone, Deserialize)]
16pub struct ContextDoc {
17    pub id: String,
18    pub title: String,
19    pub text: String,
20    #[serde(default)]
21    pub tags: Vec<String>,
22}
23
24#[derive(Deserialize, Default)]
25struct TomlCorpus {
26    #[serde(default)]
27    docs: Vec<ContextDoc>,
28}
29
30/// A collection of context documents injected into the transformer prompt.
31#[derive(Debug, Default, Clone)]
32pub struct ContextCorpus {
33    pub docs: Vec<ContextDoc>,
34}
35
36impl ContextCorpus {
37    /// Returns the embedded default corpus compiled into the binary.
38    pub fn embedded() -> Self {
39        let parsed: TomlCorpus = toml::from_str(DEFAULT_CONTEXT)
40            .expect("embedded context/default.toml is invalid TOML — build error");
41        Self { docs: parsed.docs }
42    }
43
44    /// Load a single TOML file.
45    pub fn load_file(file_path: &str) -> Result<Self> {
46        let raw = fs::read_to_string(file_path)
47            .with_context(|| format!("cannot read context file: {file_path}"))?;
48        let parsed: TomlCorpus = toml::from_str(&raw)
49            .with_context(|| format!("invalid TOML in context file: {file_path}"))?;
50        Ok(Self { docs: parsed.docs })
51    }
52
53    /// Load all `*.toml` files from a directory, merging them.
54    pub fn load_dir(dir_path: &str) -> Result<Self> {
55        let mut corpus = Self::default();
56        let dir = Path::new(dir_path);
57        let entries = fs::read_dir(dir)
58            .with_context(|| format!("cannot read context directory: {dir_path}"))?;
59        for entry in entries {
60            let path = entry?.path();
61            if path.extension().and_then(|e| e.to_str()) == Some("toml") {
62                let path_str = path.to_string_lossy();
63                let loaded = Self::load_file(&path_str)?;
64                corpus.merge(loaded);
65            }
66        }
67        Ok(corpus)
68    }
69
70    /// Load from a path: if it's a directory, load all TOML files; if a file, load it.
71    pub fn load(path: &str) -> Result<Self> {
72        if Path::new(path).is_dir() {
73            Self::load_dir(path)
74        } else {
75            Self::load_file(path)
76        }
77    }
78
79    /// Merge another corpus into this one (appends docs).
80    pub fn merge(&mut self, other: Self) {
81        self.docs.extend(other.docs);
82    }
83
84    /// Render as a `<context_pack>` block, truncating to `max_chars` total.
85    /// Whole docs are dropped (not split) to avoid broken context.
86    pub fn render(&self, max_chars: usize) -> String {
87        if self.docs.is_empty() {
88            return String::new();
89        }
90        let mut buf = String::from("<context_pack>\n");
91        for doc in &self.docs {
92            let block = format!("## {}\n{}\n\n", doc.title, doc.text.trim());
93            if buf.len() + block.len() + "</context_pack>".len() > max_chars {
94                break;
95            }
96            buf.push_str(&block);
97        }
98        buf.push_str("</context_pack>");
99        buf
100    }
101
102    pub fn len(&self) -> usize {
103        self.docs.len()
104    }
105
106    pub fn is_empty(&self) -> bool {
107        self.docs.is_empty()
108    }
109}
110
111// ---------------------------------------------------------------------------
112// Tests
113// ---------------------------------------------------------------------------
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an untagged in-memory doc for render/merge tests.
    fn doc(id: &str, title: &str, text: &str) -> ContextDoc {
        ContextDoc {
            id: id.into(),
            title: title.into(),
            text: text.into(),
            tags: vec![],
        }
    }

    #[test]
    fn embedded_corpus_loads_and_has_docs() {
        let corpus = ContextCorpus::embedded();
        assert!(!corpus.is_empty(), "embedded corpus must have at least one doc");
        assert!(
            corpus.len() >= 4,
            "expected at least 4 default docs, got {}",
            corpus.len()
        );
    }

    #[test]
    fn embedded_corpus_has_expected_ids() {
        let corpus = ContextCorpus::embedded();
        let ids: Vec<&str> = corpus.docs.iter().map(|d| d.id.as_str()).collect();
        for expected in [
            "schema.telemetry",
            "threat.prompt_injection",
            "threat.social_engineering",
            "threat.adversarial_probing",
        ] {
            assert!(ids.contains(&expected), "missing {expected}");
        }
    }

    #[test]
    fn render_produces_context_pack_tags() {
        let corpus = ContextCorpus::embedded();
        let rendered = corpus.render(usize::MAX);
        assert!(rendered.starts_with("<context_pack>"), "must start with opening tag");
        assert!(rendered.ends_with("</context_pack>"), "must end with closing tag");
        assert!(rendered.contains("## "), "must contain doc headers");
    }

    #[test]
    fn render_empty_corpus_is_empty_string() {
        let corpus = ContextCorpus::default();
        assert_eq!(corpus.render(usize::MAX), "");
    }

    #[test]
    fn render_respects_max_chars() {
        let corpus = ContextCorpus::embedded();
        // 200 bytes is well above the 30-byte bare wrapper, so the entire
        // output, wrapper included, must fit inside the budget.
        let tiny_limit = 200;
        let rendered = corpus.render(tiny_limit);
        assert!(rendered.len() <= tiny_limit, "rendered {} bytes", rendered.len());
        assert!(rendered.ends_with("</context_pack>"));
    }

    #[test]
    fn render_drops_whole_docs_that_do_not_fit() {
        let corpus = ContextCorpus {
            docs: vec![doc("a", "Doc A", "alpha"), doc("b", "Doc B", "beta")],
        };
        let full = corpus.render(usize::MAX);
        // One byte short of the full render: Doc B must be dropped entirely,
        // never cut mid-text.
        let truncated = corpus.render(full.len() - 1);
        assert_eq!(truncated, "<context_pack>\n## Doc A\nalpha\n\n</context_pack>");
    }

    #[test]
    fn render_trims_doc_text() {
        let corpus = ContextCorpus {
            docs: vec![doc("t", "T", "  padded  \n")],
        };
        assert_eq!(
            corpus.render(usize::MAX),
            "<context_pack>\n## T\npadded\n\n</context_pack>"
        );
    }

    #[test]
    fn merge_combines_docs() {
        let mut a = ContextCorpus {
            docs: vec![doc("a", "Doc A", "text a")],
        };
        let b = ContextCorpus {
            docs: vec![doc("b", "Doc B", "text b")],
        };
        a.merge(b);
        let ids: Vec<&str> = a.docs.iter().map(|d| d.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"], "merge must append in order");
    }

    #[test]
    fn load_file_returns_err_on_missing_file() {
        let result = ContextCorpus::load_file("/nonexistent/path/does_not_exist.toml");
        assert!(result.is_err());
    }

    #[test]
    fn load_file_returns_err_on_invalid_toml() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.toml");
        std::fs::write(&path, "not valid [[toml").unwrap();
        let result = ContextCorpus::load_file(path.to_str().unwrap());
        assert!(result.is_err());
    }

    #[test]
    fn load_file_without_docs_yields_empty_corpus() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("empty.toml");
        std::fs::write(&path, "").unwrap();
        let corpus = ContextCorpus::load_file(path.to_str().unwrap()).unwrap();
        assert!(corpus.is_empty());
    }

    #[test]
    fn load_dir_reads_all_toml_files() {
        let dir = tempfile::tempdir().unwrap();
        let toml_a = "[[docs]]\nid=\"a\"\ntitle=\"A\"\ntext=\"text a\"\n";
        let toml_b = "[[docs]]\nid=\"b\"\ntitle=\"B\"\ntext=\"text b\"\n";
        std::fs::write(dir.path().join("a.toml"), toml_a).unwrap();
        std::fs::write(dir.path().join("b.toml"), toml_b).unwrap();
        std::fs::write(dir.path().join("ignore.txt"), "not toml").unwrap();
        let corpus = ContextCorpus::load_dir(dir.path().to_str().unwrap()).unwrap();
        assert_eq!(corpus.len(), 2, "must load exactly 2 TOML docs, not the .txt file");
    }

    #[test]
    fn load_dispatches_file_vs_dir() {
        // file path
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("single.toml");
        std::fs::write(&path, "[[docs]]\nid=\"x\"\ntitle=\"X\"\ntext=\"t\"\n").unwrap();
        let corpus = ContextCorpus::load(path.to_str().unwrap()).unwrap();
        assert_eq!(corpus.len(), 1);
        assert!(corpus.docs[0].tags.is_empty(), "omitted tags must default to empty");

        // dir path
        let corpus2 = ContextCorpus::load(dir.path().to_str().unwrap()).unwrap();
        assert_eq!(corpus2.len(), 1);
    }
}