1use anyhow::Result;
11use oxidized_state::{CommitId, MemoryRecord, SurrealHandle};
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VectorStoreDelta {
17 pub only_in_a: Vec<MemoryRecord>,
19 pub only_in_b: Vec<MemoryRecord>,
21 pub conflicts: Vec<MemoryConflict>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct MemoryConflict {
28 pub key: String,
30 pub memory_a: MemoryRecord,
32 pub memory_b: MemoryRecord,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct AutoResolvedValue {
39 pub value: serde_json::Value,
41 pub favored_branch: Option<String>,
43 pub reasoning: String,
45 pub confidence: f32,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct MergeResult {
52 pub merge_commit_id: CommitId,
54 pub auto_resolved: usize,
56 pub manual_conflicts: Vec<MemoryConflict>,
58 pub summary: String,
60}
61
62pub async fn diff_memory_vectors(
66 handle: &SurrealHandle,
67 commit_a: &str,
68 commit_b: &str,
69) -> Result<VectorStoreDelta> {
70 let memories_a = handle.get_memories(commit_a).await?;
71 let memories_b = handle.get_memories(commit_b).await?;
72
73 let keys_a: std::collections::HashSet<_> = memories_a.iter().map(|m| &m.key).collect();
74 let keys_b: std::collections::HashSet<_> = memories_b.iter().map(|m| &m.key).collect();
75
76 let only_in_a: Vec<_> = memories_a
77 .iter()
78 .filter(|m| !keys_b.contains(&m.key))
79 .cloned()
80 .collect();
81
82 let only_in_b: Vec<_> = memories_b
83 .iter()
84 .filter(|m| !keys_a.contains(&m.key))
85 .cloned()
86 .collect();
87
88 let mut conflicts = Vec::new();
90 for mem_a in &memories_a {
91 if let Some(mem_b) = memories_b.iter().find(|m| m.key == mem_a.key) {
92 if mem_a.content != mem_b.content {
93 conflicts.push(MemoryConflict {
94 key: mem_a.key.clone(),
95 memory_a: mem_a.clone(),
96 memory_b: mem_b.clone(),
97 });
98 }
99 }
100 }
101
102 Ok(VectorStoreDelta {
103 only_in_a,
104 only_in_b,
105 conflicts,
106 })
107}
108
109pub async fn resolve_conflict_state(
113 _trace_a: &[serde_json::Value],
114 _trace_b: &[serde_json::Value],
115 conflict: &MemoryConflict,
116) -> Result<AutoResolvedValue> {
117 let (value, favored, reasoning) =
121 if conflict.memory_a.content.len() >= conflict.memory_b.content.len() {
122 (
123 serde_json::json!({"content": conflict.memory_a.content}),
124 Some("A".to_string()),
125 "Chose branch A: more detailed content".to_string(),
126 )
127 } else {
128 (
129 serde_json::json!({"content": conflict.memory_b.content}),
130 Some("B".to_string()),
131 "Chose branch B: more detailed content".to_string(),
132 )
133 };
134
135 Ok(AutoResolvedValue {
136 value,
137 favored_branch: favored,
138 reasoning,
139 confidence: 0.6, })
141}
142
143pub async fn synthesize_memory(
147 handle: &SurrealHandle,
148 commit_a: &str,
149 commit_b: &str,
150 new_commit_id: &str,
151) -> Result<Vec<MemoryRecord>> {
152 let delta = diff_memory_vectors(handle, commit_a, commit_b).await?;
153
154 let mut merged_memories = Vec::new();
155
156 for mut mem in delta.only_in_a {
158 mem.commit_id = new_commit_id.to_string();
159 mem.id = None;
160 merged_memories.push(mem);
161 }
162
163 for mut mem in delta.only_in_b {
165 mem.commit_id = new_commit_id.to_string();
166 mem.id = None;
167 merged_memories.push(mem);
168 }
169
170 for conflict in delta.conflicts {
172 let resolved = resolve_conflict_state(&[], &[], &conflict).await?;
173 let merged_mem = MemoryRecord::new(
174 new_commit_id,
175 &conflict.key,
176 resolved
177 .value
178 .get("content")
179 .and_then(|v| v.as_str())
180 .unwrap_or(&conflict.memory_a.content),
181 )
182 .with_metadata(serde_json::json!({
183 "merged_from": [commit_a, commit_b],
184 "resolution": resolved.reasoning,
185 "confidence": resolved.confidence,
186 }));
187 merged_memories.push(merged_mem);
188 }
189
190 Ok(merged_memories)
191}
192
193pub async fn semantic_merge(
195 handle: &SurrealHandle,
196 commit_a: &str,
197 commit_b: &str,
198 message: &str,
199 author: &str,
200) -> Result<MergeResult> {
201 let state_data = format!("merge:{}:{}", commit_a, commit_b);
203 let merge_commit_id = CommitId::from_state(state_data.as_bytes());
204
205 let merged_memories =
207 synthesize_memory(handle, commit_a, commit_b, &merge_commit_id.hash).await?;
208
209 for mem in &merged_memories {
211 handle.save_memory(mem).await?;
212 }
213
214 let commit = oxidized_state::CommitRecord::new(
216 merge_commit_id.clone(),
217 vec![commit_a.to_string(), commit_b.to_string()],
218 message,
219 author,
220 );
221 handle.save_commit(&commit).await?;
222
223 handle
225 .save_commit_graph_edge(&merge_commit_id.hash, commit_a)
226 .await?;
227 handle
228 .save_commit_graph_edge(&merge_commit_id.hash, commit_b)
229 .await?;
230
231 let delta = diff_memory_vectors(handle, commit_a, commit_b).await?;
233
234 Ok(MergeResult {
235 merge_commit_id,
236 auto_resolved: delta.conflicts.len(),
237 manual_conflicts: vec![], summary: format!(
239 "Merged {} memories from A, {} from B, resolved {} conflicts",
240 delta.only_in_a.len(),
241 delta.only_in_b.len(),
242 delta.conflicts.len()
243 ),
244 })
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys unique to one side land in `only_in_*`; identical shared keys are not reported.
    #[tokio::test]
    async fn test_memory_diff_shows_only_new_vectors() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let mem_a1 = MemoryRecord::new("commit-a", "shared-key", "shared content");
        let mem_a2 = MemoryRecord::new("commit-a", "only-a-key", "only in A");
        handle.save_memory(&mem_a1).await.unwrap();
        handle.save_memory(&mem_a2).await.unwrap();

        let mem_b1 = MemoryRecord::new("commit-b", "shared-key", "shared content");
        let mem_b2 = MemoryRecord::new("commit-b", "only-b-key", "only in B");
        handle.save_memory(&mem_b1).await.unwrap();
        handle.save_memory(&mem_b2).await.unwrap();

        let delta = diff_memory_vectors(&handle, "commit-a", "commit-b")
            .await
            .unwrap();

        assert_eq!(delta.only_in_a.len(), 1);
        assert_eq!(delta.only_in_a[0].key, "only-a-key");

        assert_eq!(delta.only_in_b.len(), 1);
        assert_eq!(delta.only_in_b[0].key, "only-b-key");

        assert_eq!(delta.conflicts.len(), 0);
    }

    /// The same key with different content on each side is reported as one conflict.
    #[tokio::test]
    async fn test_memory_diff_detects_conflicts() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let mem_a = MemoryRecord::new("commit-a", "conflict-key", "content version A");
        let mem_b = MemoryRecord::new("commit-b", "conflict-key", "content version B");
        handle.save_memory(&mem_a).await.unwrap();
        handle.save_memory(&mem_b).await.unwrap();

        let delta = diff_memory_vectors(&handle, "commit-a", "commit-b")
            .await
            .unwrap();

        assert_eq!(delta.conflicts.len(), 1);
        assert_eq!(delta.conflicts[0].key, "conflict-key");
        assert_eq!(delta.conflicts[0].memory_a.content, "content version A");
        assert_eq!(delta.conflicts[0].memory_b.content, "content version B");
        assert!(delta.only_in_a.is_empty());
        assert!(delta.only_in_b.is_empty());
    }

    /// The longer content wins, and the result reports which branch it came from.
    #[tokio::test]
    async fn test_arbiter_resolves_value_conflict_based_on_cot() {
        let conflict = MemoryConflict {
            key: "test-key".to_string(),
            memory_a: MemoryRecord::new("a", "test-key", "short"),
            memory_b: MemoryRecord::new("b", "test-key", "longer content here"),
        };

        let resolved = resolve_conflict_state(&[], &[], &conflict).await.unwrap();

        assert!(resolved.confidence > 0.0);
        assert_eq!(resolved.favored_branch.as_deref(), Some("B"));
        assert_eq!(
            resolved.value.get("content").and_then(|v| v.as_str()),
            Some("longer content here")
        );
        assert!(!resolved.reasoning.is_empty());
    }

    /// Equal-length content is a tie, and ties favor branch A.
    #[tokio::test]
    async fn test_arbiter_prefers_branch_a_on_tie() {
        let conflict = MemoryConflict {
            key: "tie-key".to_string(),
            memory_a: MemoryRecord::new("a", "tie-key", "abcd"),
            memory_b: MemoryRecord::new("b", "tie-key", "wxyz"),
        };

        let resolved = resolve_conflict_state(&[], &[], &conflict).await.unwrap();

        assert_eq!(resolved.favored_branch.as_deref(), Some("A"));
        assert_eq!(
            resolved.value.get("content").and_then(|v| v.as_str()),
            Some("abcd")
        );
    }

    /// A full merge copies one-sided memories and stores the resolved conflict under the
    /// new commit.
    #[tokio::test]
    async fn test_merge_synthesizes_two_memories_into_one_new_commit() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let commit_id_a = oxidized_state::CommitId::from_state(b"branch-a");
        let commit_id_b = oxidized_state::CommitId::from_state(b"branch-b");

        let commit_a = oxidized_state::CommitRecord::new(
            commit_id_a.clone(),
            vec![],
            "Branch A commit",
            "agent-a",
        );
        handle.save_commit(&commit_a).await.unwrap();

        let commit_b = oxidized_state::CommitRecord::new(
            commit_id_b.clone(),
            vec![],
            "Branch B commit",
            "agent-b",
        );
        handle.save_commit(&commit_b).await.unwrap();

        let mem_a_only =
            MemoryRecord::new(&commit_id_a.hash, "learned-from-a", "Strategy A knowledge");
        let mem_b_only =
            MemoryRecord::new(&commit_id_b.hash, "learned-from-b", "Strategy B knowledge");
        let mem_conflict_a = MemoryRecord::new(&commit_id_a.hash, "shared-key", "short");
        let mem_conflict_b = MemoryRecord::new(
            &commit_id_b.hash,
            "shared-key",
            "longer and more detailed content",
        );

        handle.save_memory(&mem_a_only).await.unwrap();
        handle.save_memory(&mem_b_only).await.unwrap();
        handle.save_memory(&mem_conflict_a).await.unwrap();
        handle.save_memory(&mem_conflict_b).await.unwrap();

        let result = semantic_merge(
            &handle,
            &commit_id_a.hash,
            &commit_id_b.hash,
            "Merge A and B",
            "agent-git",
        )
        .await
        .unwrap();

        assert!(!result.merge_commit_id.hash.is_empty());
        assert_eq!(result.auto_resolved, 1);
        assert!(result.manual_conflicts.is_empty());

        let merged_memories = handle
            .get_memories(&result.merge_commit_id.hash)
            .await
            .unwrap();

        assert_eq!(merged_memories.len(), 3, "Expected 3 merged memories");

        let keys: Vec<_> = merged_memories.iter().map(|m| m.key.as_str()).collect();
        assert!(
            keys.contains(&"learned-from-a"),
            "Missing memory from branch A"
        );
        assert!(
            keys.contains(&"learned-from-b"),
            "Missing memory from branch B"
        );
        assert!(keys.contains(&"shared-key"), "Missing resolved conflict");

        let resolved = merged_memories
            .iter()
            .find(|m| m.key == "shared-key")
            .unwrap();
        assert_eq!(
            resolved.content, "longer and more detailed content",
            "Conflict resolution should favor more detailed content"
        );

        assert_eq!(
            result.summary,
            "Merged 1 memories from A, 1 from B, resolved 1 conflicts"
        );
    }
}