split_brain_harness/
rag.rs1use anyhow::{Context, Result};
8use serde::Deserialize;
9use std::fs;
10use std::path::Path;
11
/// Contents of `../context/default.toml`, embedded into the binary at compile time via
/// `include_str!`; parsed by `ContextCorpus::embedded`.
12const DEFAULT_CONTEXT: &str = include_str!("../context/default.toml");
13
14#[derive(Debug, Clone, Deserialize)]
16pub struct ContextDoc {
17 pub id: String,
18 pub title: String,
19 pub text: String,
20 #[serde(default)]
21 pub tags: Vec<String>,
22}
23
24#[derive(Deserialize, Default)]
25struct TomlCorpus {
26 #[serde(default)]
27 docs: Vec<ContextDoc>,
28}
29
30#[derive(Debug, Default, Clone)]
32pub struct ContextCorpus {
33 pub docs: Vec<ContextDoc>,
34}
35
36impl ContextCorpus {
37 pub fn embedded() -> Self {
39 let parsed: TomlCorpus = toml::from_str(DEFAULT_CONTEXT)
40 .expect("embedded context/default.toml is invalid TOML — build error");
41 Self { docs: parsed.docs }
42 }
43
44 pub fn load_file(file_path: &str) -> Result<Self> {
46 let raw = fs::read_to_string(file_path)
47 .with_context(|| format!("cannot read context file: {file_path}"))?;
48 let parsed: TomlCorpus = toml::from_str(&raw)
49 .with_context(|| format!("invalid TOML in context file: {file_path}"))?;
50 Ok(Self { docs: parsed.docs })
51 }
52
53 pub fn load_dir(dir_path: &str) -> Result<Self> {
55 let mut corpus = Self::default();
56 let dir = Path::new(dir_path);
57 let entries = fs::read_dir(dir)
58 .with_context(|| format!("cannot read context directory: {dir_path}"))?;
59 for entry in entries {
60 let path = entry?.path();
61 if path.extension().and_then(|e| e.to_str()) == Some("toml") {
62 let path_str = path.to_string_lossy();
63 let loaded = Self::load_file(&path_str)?;
64 corpus.merge(loaded);
65 }
66 }
67 Ok(corpus)
68 }
69
70 pub fn load(path: &str) -> Result<Self> {
72 if Path::new(path).is_dir() {
73 Self::load_dir(path)
74 } else {
75 Self::load_file(path)
76 }
77 }
78
79 pub fn merge(&mut self, other: Self) {
81 self.docs.extend(other.docs);
82 }
83
84 pub fn render(&self, max_chars: usize) -> String {
87 if self.docs.is_empty() {
88 return String::new();
89 }
90 let mut buf = String::from("<context_pack>\n");
91 for doc in &self.docs {
92 let block = format!("## {}\n{}\n\n", doc.title, doc.text.trim());
93 if buf.len() + block.len() + "</context_pack>".len() > max_chars {
94 break;
95 }
96 buf.push_str(&block);
97 }
98 buf.push_str("</context_pack>");
99 buf
100 }
101
102 pub fn len(&self) -> usize {
103 self.docs.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
107 self.docs.is_empty()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tag-less document fixture.
    fn plain_doc(id: &str, title: &str, text: &str) -> ContextDoc {
        ContextDoc {
            id: id.into(),
            title: title.into(),
            text: text.into(),
            tags: vec![],
        }
    }

    #[test]
    fn embedded_corpus_loads_and_has_docs() {
        let embedded = ContextCorpus::embedded();
        assert!(!embedded.is_empty(), "embedded corpus must have at least one doc");
        assert!(embedded.len() >= 4, "expected 4 default docs");
    }

    #[test]
    fn embedded_corpus_has_expected_ids() {
        let embedded = ContextCorpus::embedded();
        let ids: Vec<&str> = embedded.docs.iter().map(|doc| doc.id.as_str()).collect();
        // Each id the default corpus is expected to ship with.
        for expected in [
            "schema.telemetry",
            "threat.prompt_injection",
            "threat.social_engineering",
            "threat.adversarial_probing",
        ] {
            assert!(ids.contains(&expected), "missing {expected}");
        }
    }

    #[test]
    fn render_produces_context_pack_tags() {
        let output = ContextCorpus::embedded().render(usize::MAX);
        assert!(output.starts_with("<context_pack>"), "must start with opening tag");
        assert!(output.ends_with("</context_pack>"), "must end with closing tag");
        assert!(output.contains("## "), "must contain doc headers");
    }

    #[test]
    fn render_empty_corpus_is_empty_string() {
        let empty = ContextCorpus::default();
        assert_eq!(empty.render(usize::MAX), "");
    }

    #[test]
    fn render_respects_max_chars() {
        let tiny_limit = 200;
        let output = ContextCorpus::embedded().render(tiny_limit);
        // The bare wrapper is the accepted fallback when no document fits the budget.
        let wrapper_only = output == "<context_pack>\n</context_pack>";
        assert!(output.len() <= tiny_limit || wrapper_only);
    }

    #[test]
    fn merge_combines_docs() {
        let mut left = ContextCorpus {
            docs: vec![plain_doc("a", "Doc A", "text a")],
        };
        let right = ContextCorpus {
            docs: vec![plain_doc("b", "Doc B", "text b")],
        };
        left.merge(right);
        assert_eq!(left.len(), 2);
    }

    #[test]
    fn load_file_returns_err_on_missing_file() {
        let outcome = ContextCorpus::load_file("/nonexistent/path/does_not_exist.toml");
        assert!(outcome.is_err());
    }

    #[test]
    fn load_file_returns_err_on_invalid_toml() {
        let tmp = tempfile::tempdir().unwrap();
        let bad_file = tmp.path().join("bad.toml");
        std::fs::write(&bad_file, "not valid [[toml").unwrap();
        let outcome = ContextCorpus::load_file(bad_file.to_str().unwrap());
        assert!(outcome.is_err());
    }

    #[test]
    fn load_dir_reads_all_toml_files() {
        let tmp = tempfile::tempdir().unwrap();
        let fixtures = [
            ("a.toml", "[[docs]]\nid=\"a\"\ntitle=\"A\"\ntext=\"text a\"\n"),
            ("b.toml", "[[docs]]\nid=\"b\"\ntitle=\"B\"\ntext=\"text b\"\n"),
            ("ignore.txt", "not toml"),
        ];
        for (file_name, contents) in fixtures {
            std::fs::write(tmp.path().join(file_name), contents).unwrap();
        }
        let loaded = ContextCorpus::load_dir(tmp.path().to_str().unwrap()).unwrap();
        assert_eq!(loaded.len(), 2, "must load exactly 2 TOML docs, not the .txt file");
    }

    #[test]
    fn load_dispatches_file_vs_dir() {
        let tmp = tempfile::tempdir().unwrap();

        // A file path goes through the single-file loader.
        let single = tmp.path().join("single.toml");
        std::fs::write(&single, "[[docs]]\nid=\"x\"\ntitle=\"X\"\ntext=\"t\"\n").unwrap();
        let from_file = ContextCorpus::load(single.to_str().unwrap()).unwrap();
        assert_eq!(from_file.len(), 1);

        // The containing directory goes through the directory loader and finds that file.
        let from_dir = ContextCorpus::load(tmp.path().to_str().unwrap()).unwrap();
        assert_eq!(from_dir.len(), 1);
    }
}