1use crate::{DocCell, LatticeRegistry, Op, Patch, Path, TireaResult, TrackedPatch};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::{Arc, Mutex};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
22#[serde(rename_all = "snake_case")]
23pub enum StateScope {
24 Thread,
26 Run,
28 ToolCall,
30}
31
32type CollectHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
33
34pub struct PatchSink<'a> {
45 ops: Option<&'a Mutex<Vec<Op>>>,
46 on_collect: Option<CollectHook<'a>>,
47}
48
49impl<'a> PatchSink<'a> {
50 #[doc(hidden)]
52 pub fn new(ops: &'a Mutex<Vec<Op>>) -> Self {
53 Self {
54 ops: Some(ops),
55 on_collect: None,
56 }
57 }
58
59 #[doc(hidden)]
63 pub fn new_with_hook(ops: &'a Mutex<Vec<Op>>, hook: CollectHook<'a>) -> Self {
64 Self {
65 ops: Some(ops),
66 on_collect: Some(hook),
67 }
68 }
69
70 #[doc(hidden)]
74 pub fn child(&self) -> Self {
75 Self {
76 ops: self.ops,
77 on_collect: self.on_collect.clone(),
78 }
79 }
80
81 #[doc(hidden)]
85 pub fn read_only() -> Self {
86 Self {
87 ops: None,
88 on_collect: None,
89 }
90 }
91
92 #[inline]
94 pub fn collect(&self, op: Op) -> TireaResult<()> {
95 let ops = self.ops.ok_or_else(|| {
96 crate::TireaError::invalid_operation("write attempted on read-only state reference")
97 })?;
98 let mut guard = ops.lock().map_err(|_| {
99 crate::TireaError::invalid_operation("state operation collector mutex poisoned")
100 })?;
101 guard.push(op.clone());
102 drop(guard);
103 if let Some(hook) = &self.on_collect {
104 hook(&op)?;
105 }
106 Ok(())
107 }
108
109 #[doc(hidden)]
111 pub fn inner(&self) -> &'a Mutex<Vec<Op>> {
112 self.ops
113 .expect("PatchSink::inner called on read-only sink (programming error)")
114 }
115}
116
117pub struct StateContext<'a> {
119 doc: &'a DocCell,
120 ops: Mutex<Vec<Op>>,
121}
122
123impl<'a> StateContext<'a> {
124 pub fn new(doc: &'a DocCell) -> Self {
126 Self {
127 doc,
128 ops: Mutex::new(Vec::new()),
129 }
130 }
131
132 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
134 let base = parse_path(path);
135 let hook: CollectHook<'_> = Arc::new(|op: &Op| self.doc.apply(op));
136 T::state_ref(self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
137 }
138
139 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
144 assert!(
145 !T::PATH.is_empty(),
146 "State type has no bound path; use state::<T>(path) instead"
147 );
148 self.state::<T>(T::PATH)
149 }
150
151 pub fn take_patch(&self) -> Patch {
153 let ops = std::mem::take(&mut *self.ops.lock().unwrap());
154 Patch::with_ops(ops)
155 }
156
157 pub fn take_tracked_patch(&self, source: impl Into<String>) -> TrackedPatch {
159 TrackedPatch::new(self.take_patch()).with_source(source)
160 }
161
162 pub fn has_changes(&self) -> bool {
164 !self.ops.lock().unwrap().is_empty()
165 }
166
167 pub fn ops_count(&self) -> usize {
169 self.ops.lock().unwrap().len()
170 }
171}
172
173pub fn parse_path(path: &str) -> Path {
175 if path.is_empty() {
176 return Path::root();
177 }
178
179 let mut result = Path::root();
180 for segment in path.split('.') {
181 if !segment.is_empty() {
182 result = result.key(segment);
183 }
184 }
185 result
186}
187
188pub trait State: Sized {
213 type Ref<'a>;
215
216 const PATH: &'static str = "";
221
222 fn state_ref<'a>(doc: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a>;
230
231 fn from_value(value: &Value) -> TireaResult<Self>;
233
234 fn to_value(&self) -> TireaResult<Value>;
236
237 fn register_lattice(_registry: &mut LatticeRegistry) {}
242
243 fn lattice_keys() -> &'static [&'static str] {
249 &[]
250 }
251
252 fn diff_ops(old: &Self, new: &Self, base_path: &Path) -> TireaResult<Vec<Op>> {
262 let old_val = old.to_value()?;
263 let new_val = new.to_value()?;
264 if old_val == new_val {
265 return Ok(Vec::new());
266 }
267 let lattice_keys = Self::lattice_keys();
268 if lattice_keys.is_empty() {
269 return Ok(vec![Op::set(base_path.clone(), new_val)]);
270 }
271 Ok(diff_state_fields(
273 &old_val,
274 &new_val,
275 base_path,
276 lattice_keys,
277 ))
278 }
279
280 fn to_patch(&self) -> TireaResult<Patch> {
282 Ok(Patch::with_ops(vec![Op::set(
283 Path::root(),
284 self.to_value()?,
285 )]))
286 }
287}
288
289fn diff_state_fields(
295 old_value: &Value,
296 new_value: &Value,
297 base_path: &Path,
298 lattice_keys: &[&str],
299) -> Vec<Op> {
300 let empty_obj = serde_json::Map::new();
301 let old_obj = old_value.as_object().unwrap_or(&empty_obj);
302 let new_obj = new_value.as_object().unwrap_or(&empty_obj);
303
304 let mut ops = Vec::new();
305
306 for (key, new_val) in new_obj {
307 let old_val = old_obj.get(key);
308 if old_val == Some(new_val) {
309 continue;
310 }
311 let field_path = base_path.clone().key(key);
312 if lattice_keys.contains(&key.as_str()) {
313 ops.push(Op::lattice_merge(field_path, new_val.clone()));
314 } else {
315 ops.push(Op::set(field_path, new_val.clone()));
316 }
317 }
318
319 for key in old_obj.keys() {
320 if !new_obj.contains_key(key) {
321 ops.push(Op::delete(base_path.clone().key(key)));
322 }
323 }
324
325 ops
326}
327
328pub trait StateExt: State {
330 fn at_root<'a>(doc: &'a DocCell, sink: PatchSink<'a>) -> Self::Ref<'a> {
332 Self::state_ref(doc, Path::root(), sink)
333 }
334}
335
336impl<T: State> StateExt for T {}
337
338pub trait StateSpec: State + Clone + Sized + Send + 'static {
366 type Action: serde::Serialize + serde::de::DeserializeOwned + Send + 'static;
368
369 const SCOPE: StateScope = StateScope::Thread;
375
376 fn reduce(&mut self, action: Self::Action);
378}
379
#[cfg(test)]
mod tests {
    // Unit tests for `PatchSink`, `parse_path`, and `StateContext`.
    use super::*;
    use serde_json::json;

    // A plain sink records every collected op in the shared vector.
    #[test]
    fn test_patch_sink_collect() {
        let ops = Mutex::new(Vec::new());
        let sink = PatchSink::new(&ops);

        sink.collect(Op::set(Path::root().key("a"), Value::from(1)))
            .unwrap();
        sink.collect(Op::set(Path::root().key("b"), Value::from(2)))
            .unwrap();

        let collected = ops.lock().unwrap();
        assert_eq!(collected.len(), 2);
    }

    // With a hook, each collected op is both recorded and passed to the hook.
    #[test]
    fn test_patch_sink_collect_hook() {
        let ops = Mutex::new(Vec::new());
        // `seen` stores a Debug rendering of every op the hook observes.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_hook = seen.clone();
        let hook = Arc::new(move |op: &Op| {
            seen_hook.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        });
        let sink = PatchSink::new_with_hook(&ops, hook);

        sink.collect(Op::set(Path::root().key("a"), Value::from(1)))
            .unwrap();
        sink.collect(Op::delete(Path::root().key("b"))).unwrap();

        let collected = ops.lock().unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(seen.lock().unwrap().len(), 2);
    }

    // A child sink writes to the parent's collector and runs the parent's hook.
    #[test]
    fn test_patch_sink_child_preserves_collect_and_hook() {
        let ops = Mutex::new(Vec::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_hook = seen.clone();
        let hook = Arc::new(move |op: &Op| {
            seen_hook.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        });
        let sink = PatchSink::new_with_hook(&ops, hook);
        let child = sink.child();

        child
            .collect(Op::set(Path::root().key("nested"), Value::from(1)))
            .unwrap();

        assert_eq!(ops.lock().unwrap().len(), 1);
        assert_eq!(seen.lock().unwrap().len(), 1);
    }

    // A child of a read-only sink is itself read-only.
    #[test]
    fn test_patch_sink_read_only_child_collect_errors() {
        let sink = PatchSink::read_only();
        let child = sink.child();
        let err = child
            .collect(Op::set(Path::root().key("x"), Value::from(1)))
            .unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    // Writing through a read-only sink fails with `InvalidOperation`.
    #[test]
    fn test_patch_sink_read_only_collect_errors() {
        let sink = PatchSink::read_only();
        let err = sink
            .collect(Op::set(Path::root().key("x"), Value::from(1)))
            .unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    // `inner()` on a read-only sink panics with a message containing "read-only sink".
    #[test]
    #[should_panic(expected = "read-only sink")]
    fn test_patch_sink_read_only_inner_panics() {
        let sink = PatchSink::read_only();
        let _ = sink.inner();
    }

    // An empty path string parses to the (empty) root path.
    #[test]
    fn test_parse_path_empty() {
        let path = parse_path("");
        assert!(path.is_empty());
    }

    // Dotted input becomes nested keys; `Path`'s Display renders them with a `$.` prefix.
    #[test]
    fn test_parse_path_nested() {
        let path = parse_path("tool_calls.call_123.data");
        assert_eq!(path.to_string(), "$.tool_calls.call_123.data");
    }

    // End-to-end: a minimal `State` impl whose `Ref` writes through the sink handed out by
    // `StateContext::state`; the op is then visible via `has_changes`/`ops_count`/`take_patch`.
    #[test]
    fn test_state_context_collects_ops() {
        struct Counter;

        struct CounterRef<'a> {
            base: Path,
            sink: PatchSink<'a>,
        }

        impl<'a> CounterRef<'a> {
            fn set_value(&self, value: i64) -> TireaResult<()> {
                self.sink
                    .collect(Op::set(self.base.clone().key("value"), Value::from(value)))
            }
        }

        impl State for Counter {
            type Ref<'a> = CounterRef<'a>;

            fn state_ref<'a>(_: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a> {
                CounterRef { base, sink }
            }

            fn from_value(_: &Value) -> TireaResult<Self> {
                Ok(Counter)
            }

            fn to_value(&self) -> TireaResult<Value> {
                Ok(Value::Null)
            }
        }

        let doc = DocCell::new(json!({"counter": {"value": 1}}));
        let ctx = StateContext::new(&doc);
        let counter = ctx.state::<Counter>("counter");
        counter.set_value(2).unwrap();

        assert!(ctx.has_changes());
        assert_eq!(ctx.ops_count(), 1);
        assert_eq!(ctx.take_patch().len(), 1);
    }
}