Skip to main content

tirea_state/
state.rs

1//! State trait for typed state access.
2//!
3//! The `State` trait provides a unified interface for typed access to JSON documents.
4//! It is typically implemented via the derive macro `#[derive(State)]`.
5
6use crate::{DocCell, LatticeRegistry, Op, Patch, Path, TireaResult, TrackedPatch};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::{Arc, Mutex};
10
11/// Lifecycle scope of a [`StateSpec`] type.
12///
13/// Determines when the framework automatically resets the state:
14///
15/// - **`Thread`** — persists across runs, never automatically cleaned.
16///   Examples: reminders, permission overrides, delegation records.
17/// - **`Run`** — reset at the start of each run (in `prepare_run`).
18///   Examples: run lifecycle state, per-run token counters.
19/// - **`ToolCall`** — scoped to a single tool call, cleaned up after execution.
20///   Examples: tool-call progress, suspended-call state.
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
22#[serde(rename_all = "snake_case")]
23pub enum StateScope {
24    /// State that persists across runs (thread lifetime).
25    Thread,
26    /// State that is reset at the start of each run.
27    Run,
28    /// State that is scoped to a single tool call.
29    ToolCall,
30}
31
32type CollectHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
33
34/// Collector for patch operations.
35///
36/// `PatchSink` collects operations that will be combined into a `Patch`.
37/// It is used internally by `StateRef` types to automatically collect
38/// all state modifications.
39///
40/// # Thread Safety
41///
42/// `PatchSink` uses a `Mutex` internally to support async contexts.
43/// In single-threaded usage, the lock overhead is minimal.
44pub struct PatchSink<'a> {
45    ops: Option<&'a Mutex<Vec<Op>>>,
46    on_collect: Option<CollectHook<'a>>,
47}
48
49impl<'a> PatchSink<'a> {
50    /// Create a new PatchSink wrapping a Mutex.
51    #[doc(hidden)]
52    pub fn new(ops: &'a Mutex<Vec<Op>>) -> Self {
53        Self {
54            ops: Some(ops),
55            on_collect: None,
56        }
57    }
58
59    /// Create a new PatchSink with a collect hook.
60    ///
61    /// The hook is invoked after each operation is collected.
62    #[doc(hidden)]
63    pub fn new_with_hook(ops: &'a Mutex<Vec<Op>>, hook: CollectHook<'a>) -> Self {
64        Self {
65            ops: Some(ops),
66            on_collect: Some(hook),
67        }
68    }
69
70    /// Create a child sink that shares the same collector and hook.
71    ///
72    /// Nested state refs use this so write-through behavior is preserved.
73    #[doc(hidden)]
74    pub fn child(&self) -> Self {
75        Self {
76            ops: self.ops,
77            on_collect: self.on_collect.clone(),
78        }
79    }
80
81    /// Create a read-only PatchSink that errors on collect.
82    ///
83    /// Used for `SealedState::get()` where writes are a programming error.
84    #[doc(hidden)]
85    pub fn read_only() -> Self {
86        Self {
87            ops: None,
88            on_collect: None,
89        }
90    }
91
92    /// Collect an operation.
93    #[inline]
94    pub fn collect(&self, op: Op) -> TireaResult<()> {
95        let ops = self.ops.ok_or_else(|| {
96            crate::TireaError::invalid_operation("write attempted on read-only state reference")
97        })?;
98        let mut guard = ops.lock().map_err(|_| {
99            crate::TireaError::invalid_operation("state operation collector mutex poisoned")
100        })?;
101        guard.push(op.clone());
102        drop(guard);
103        if let Some(hook) = &self.on_collect {
104            hook(&op)?;
105        }
106        Ok(())
107    }
108
109    /// Get the inner Mutex reference (for creating nested PatchSinks).
110    #[doc(hidden)]
111    pub fn inner(&self) -> &'a Mutex<Vec<Op>> {
112        self.ops
113            .expect("PatchSink::inner called on read-only sink (programming error)")
114    }
115}
116
117/// Pure state context with automatic patch collection.
118pub struct StateContext<'a> {
119    doc: &'a DocCell,
120    ops: Mutex<Vec<Op>>,
121}
122
123impl<'a> StateContext<'a> {
124    /// Create a new pure state context.
125    pub fn new(doc: &'a DocCell) -> Self {
126        Self {
127            doc,
128            ops: Mutex::new(Vec::new()),
129        }
130    }
131
132    /// Get a typed state reference at the specified path.
133    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
134        let base = parse_path(path);
135        let hook: CollectHook<'_> = Arc::new(|op: &Op| self.doc.apply(op));
136        T::state_ref(self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
137    }
138
139    /// Get a typed state reference at the type's canonical path.
140    ///
141    /// Requires `T` to have `#[tirea(path = "...")]` set.
142    /// Panics if `T::PATH` is empty.
143    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
144        assert!(
145            !T::PATH.is_empty(),
146            "State type has no bound path; use state::<T>(path) instead"
147        );
148        self.state::<T>(T::PATH)
149    }
150
151    /// Extract collected operations as a plain patch.
152    pub fn take_patch(&self) -> Patch {
153        let ops = std::mem::take(&mut *self.ops.lock().unwrap());
154        Patch::with_ops(ops)
155    }
156
157    /// Extract collected operations as a tracked patch with a source.
158    pub fn take_tracked_patch(&self, source: impl Into<String>) -> TrackedPatch {
159        TrackedPatch::new(self.take_patch()).with_source(source)
160    }
161
162    /// Check if any operations have been collected.
163    pub fn has_changes(&self) -> bool {
164        !self.ops.lock().unwrap().is_empty()
165    }
166
167    /// Get the number of operations collected.
168    pub fn ops_count(&self) -> usize {
169        self.ops.lock().unwrap().len()
170    }
171}
172
173/// Parse a dot-separated path string into a `Path`.
174pub fn parse_path(path: &str) -> Path {
175    if path.is_empty() {
176        return Path::root();
177    }
178
179    let mut result = Path::root();
180    for segment in path.split('.') {
181        if !segment.is_empty() {
182            result = result.key(segment);
183        }
184    }
185    result
186}
187
188/// Trait for types that can create typed state references.
189///
190/// This trait is typically derived using `#[derive(State)]`.
191/// It provides the interface for creating `StateRef` types that
192/// allow typed read/write access to JSON documents.
193///
194/// # Example
195///
196/// ```ignore
197/// use tirea_state::State;
198/// use tirea_state_derive::State;
199///
200/// #[derive(State)]
201/// struct User {
202///     pub name: String,
203///     pub age: i64,
204/// }
205///
206/// // In a StateContext:
207/// let user = ctx.state::<User>("users.alice");
208/// let name = user.name()?;
209/// user.set_name("Alice");
210/// user.set_age(30);
211/// ```
212pub trait State: Sized {
213    /// The reference type that provides typed access.
214    type Ref<'a>;
215
216    /// Canonical JSON path for this state type.
217    ///
218    /// When set via `#[tirea(path = "...")]`, enables `state_of::<T>()` access
219    /// without an explicit path argument. Empty string means no bound path.
220    const PATH: &'static str = "";
221
222    /// Create a state reference at the specified path.
223    ///
224    /// # Arguments
225    ///
226    /// * `doc` - The JSON document to read from
227    /// * `base` - The base path for this state
228    /// * `sink` - The operation collector
229    fn state_ref<'a>(doc: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a>;
230
231    /// Deserialize this type from a JSON value.
232    fn from_value(value: &Value) -> TireaResult<Self>;
233
234    /// Serialize this type to a JSON value.
235    fn to_value(&self) -> TireaResult<Value>;
236
237    /// Register lattice fields into the given registry.
238    ///
239    /// Auto-generated by `#[derive(State)]` for structs with `#[tirea(lattice)]`
240    /// fields. The default implementation is a no-op (no lattice fields).
241    fn register_lattice(_registry: &mut LatticeRegistry) {}
242
243    /// Return the JSON keys of fields annotated with `#[tirea(lattice)]`.
244    ///
245    /// Used by the reducer pipeline to emit `Op::LatticeMerge` (instead of
246    /// `Op::set`) for CRDT fields, enabling proper conflict suppression.
247    /// The default implementation returns an empty slice (no lattice fields).
248    fn lattice_keys() -> &'static [&'static str] {
249        &[]
250    }
251
252    /// Compare two instances and emit minimal ops for changed fields.
253    ///
254    /// The derive macro generates an optimized version that does typed
255    /// per-field comparison and only serializes changed fields. Lattice
256    /// fields emit `Op::LatticeMerge`; regular fields emit `Op::Set`.
257    ///
258    /// The default implementation serializes both values and diffs at JSON
259    /// level. When `lattice_keys()` is non-empty, changed lattice fields
260    /// emit `Op::LatticeMerge`; all others emit `Op::Set`.
261    fn diff_ops(old: &Self, new: &Self, base_path: &Path) -> TireaResult<Vec<Op>> {
262        let old_val = old.to_value()?;
263        let new_val = new.to_value()?;
264        if old_val == new_val {
265            return Ok(Vec::new());
266        }
267        let lattice_keys = Self::lattice_keys();
268        if lattice_keys.is_empty() {
269            return Ok(vec![Op::set(base_path.clone(), new_val)]);
270        }
271        // Per-field diff with LatticeMerge for lattice fields
272        Ok(diff_state_fields(
273            &old_val,
274            &new_val,
275            base_path,
276            lattice_keys,
277        ))
278    }
279
280    /// Create a patch that sets this value at the root.
281    fn to_patch(&self) -> TireaResult<Patch> {
282        Ok(Patch::with_ops(vec![Op::set(
283            Path::root(),
284            self.to_value()?,
285        )]))
286    }
287}
288
289/// Diff two JSON objects field-by-field, emitting `Op::LatticeMerge` for
290/// lattice-annotated fields and `Op::Set` / `Op::Delete` for the rest.
291///
292/// Used by the default `State::diff_ops` implementation for manual impls
293/// that declare `lattice_keys()`.
294fn diff_state_fields(
295    old_value: &Value,
296    new_value: &Value,
297    base_path: &Path,
298    lattice_keys: &[&str],
299) -> Vec<Op> {
300    let empty_obj = serde_json::Map::new();
301    let old_obj = old_value.as_object().unwrap_or(&empty_obj);
302    let new_obj = new_value.as_object().unwrap_or(&empty_obj);
303
304    let mut ops = Vec::new();
305
306    for (key, new_val) in new_obj {
307        let old_val = old_obj.get(key);
308        if old_val == Some(new_val) {
309            continue;
310        }
311        let field_path = base_path.clone().key(key);
312        if lattice_keys.contains(&key.as_str()) {
313            ops.push(Op::lattice_merge(field_path, new_val.clone()));
314        } else {
315            ops.push(Op::set(field_path, new_val.clone()));
316        }
317    }
318
319    for key in old_obj.keys() {
320        if !new_obj.contains_key(key) {
321            ops.push(Op::delete(base_path.clone().key(key)));
322        }
323    }
324
325    ops
326}
327
328/// Extension trait providing convenience methods for State types.
329pub trait StateExt: State {
330    /// Create a state reference at the document root.
331    fn at_root<'a>(doc: &'a DocCell, sink: PatchSink<'a>) -> Self::Ref<'a> {
332        Self::state_ref(doc, Path::root(), sink)
333    }
334}
335
336impl<T: State> StateExt for T {}
337
338/// Extends [`State`] with a typed action and a pure reducer.
339///
340/// Implementors define what actions their state accepts and how the state
341/// transitions in response. The kernel applies actions via type-erased
342/// `AnyStateAction` without knowing the concrete types.
343///
344/// Scope (Run vs ToolCall) is determined at the call site — `AnyStateAction::new()`
345/// for run-scoped state, `AnyStateAction::new_for_call()` for tool-call-scoped state —
346/// rather than being encoded in the trait, keeping business semantics out of `tirea-state`.
347///
348/// # Usage
349///
350/// Typically generated by `#[derive(State)]` with `#[tirea(action = "...")]`:
351///
352/// ```ignore
353/// #[derive(State, Clone, Serialize, Deserialize)]
354/// #[tirea(path = "counters.main", action = "CounterAction")]
355/// struct Counter { value: i64 }
356///
357/// impl Counter {
358///     fn reduce(&mut self, action: CounterAction) {
359///         match action {
360///             CounterAction::Increment(n) => self.value += n,
361///         }
362///     }
363/// }
364/// ```
365pub trait StateSpec: State + Clone + Sized + Send + 'static {
366    /// The action type accepted by this state.
367    type Action: serde::Serialize + serde::de::DeserializeOwned + Send + 'static;
368
369    /// Lifecycle scope of this state type.
370    ///
371    /// Defaults to `Thread` (never automatically cleaned) for backward
372    /// compatibility. Override via `#[tirea(scope = "run")]` or
373    /// `#[tirea(scope = "tool_call")]` in the derive macro.
374    const SCOPE: StateScope = StateScope::Thread;
375
376    /// Pure reducer: apply an action to produce the next state.
377    fn reduce(&mut self, action: Self::Action);
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Hook that appends the `Debug` rendering of every op it sees to `seen`.
    fn recording_hook<'a>(seen: &Arc<Mutex<Vec<String>>>) -> CollectHook<'a> {
        let log = Arc::clone(seen);
        Arc::new(move |op: &Op| -> TireaResult<()> {
            log.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        })
    }

    /// `Op::set` of an integer at the top-level key `key`.
    fn set_op(key: &str, value: i64) -> Op {
        Op::set(Path::root().key(key), Value::from(value))
    }

    #[test]
    fn test_patch_sink_collect() {
        let collector = Mutex::new(Vec::new());
        let sink = PatchSink::new(&collector);

        for (key, value) in [("a", 1), ("b", 2)] {
            sink.collect(set_op(key, value)).unwrap();
        }

        assert_eq!(collector.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_patch_sink_collect_hook() {
        let collector = Mutex::new(Vec::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let sink = PatchSink::new_with_hook(&collector, recording_hook(&seen));

        sink.collect(set_op("a", 1)).unwrap();
        sink.collect(Op::delete(Path::root().key("b"))).unwrap();

        assert_eq!(collector.lock().unwrap().len(), 2);
        assert_eq!(seen.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_patch_sink_child_preserves_collect_and_hook() {
        let collector = Mutex::new(Vec::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let parent = PatchSink::new_with_hook(&collector, recording_hook(&seen));

        parent.child().collect(set_op("nested", 1)).unwrap();

        assert_eq!(collector.lock().unwrap().len(), 1);
        assert_eq!(seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_patch_sink_read_only_child_collect_errors() {
        let child = PatchSink::read_only().child();
        let err = child.collect(set_op("x", 1)).unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    #[test]
    fn test_patch_sink_read_only_collect_errors() {
        let err = PatchSink::read_only()
            .collect(set_op("x", 1))
            .unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    #[test]
    #[should_panic(expected = "read-only sink")]
    fn test_patch_sink_read_only_inner_panics() {
        let _ = PatchSink::read_only().inner();
    }

    #[test]
    fn test_parse_path_empty() {
        assert!(parse_path("").is_empty());
    }

    #[test]
    fn test_parse_path_nested() {
        let rendered = parse_path("tool_calls.call_123.data").to_string();
        assert_eq!(rendered, "$.tool_calls.call_123.data");
    }

    #[test]
    fn test_state_context_collects_ops() {
        struct Counter;

        struct CounterRef<'a> {
            base: Path,
            sink: PatchSink<'a>,
        }

        impl CounterRef<'_> {
            fn set_value(&self, value: i64) -> TireaResult<()> {
                let target = self.base.clone().key("value");
                self.sink.collect(Op::set(target, Value::from(value)))
            }
        }

        impl State for Counter {
            type Ref<'a> = CounterRef<'a>;

            fn state_ref<'a>(
                _doc: &'a DocCell,
                base: Path,
                sink: PatchSink<'a>,
            ) -> Self::Ref<'a> {
                CounterRef { base, sink }
            }

            fn from_value(_value: &Value) -> TireaResult<Self> {
                Ok(Counter)
            }

            fn to_value(&self) -> TireaResult<Value> {
                Ok(Value::Null)
            }
        }

        let doc = DocCell::new(json!({"counter": {"value": 1}}));
        let ctx = StateContext::new(&doc);
        ctx.state::<Counter>("counter").set_value(2).unwrap();

        assert!(ctx.has_changes());
        assert_eq!(ctx.ops_count(), 1);
        assert_eq!(ctx.take_patch().len(), 1);
    }
}