Skip to main content

tirea_contract/thread/
model.rs

1//! Thread model and persistent history primitives.
2//!
3//! `Thread` (formerly `AgentState`) represents persisted agent state with
4//! message history and patches.
5
6use crate::thread::message::Message;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};
11
12/// Accumulated new messages and patches since the last `take_pending()`.
13///
14/// This buffer is populated automatically by `with_message`, `with_messages`,
15/// `with_patch`, and `with_patches`. Consumers call `take_pending()` to
16/// drain the buffer and build a `ThreadChangeSet` for storage.
17#[derive(Debug, Clone, Default)]
18pub struct PendingDelta {
19    pub messages: Vec<Arc<Message>>,
20    pub patches: Vec<TrackedPatch>,
21}
22
23impl PendingDelta {
24    /// Returns true if there are no pending messages or patches.
25    pub fn is_empty(&self) -> bool {
26        self.messages.is_empty() && self.patches.is_empty()
27    }
28}
29
/// Persisted thread state with messages and state history.
///
/// The current state is not stored directly. It is derived by applying every
/// entry of `patches`, in order, to the base `state` (see `rebuild_state()`).
/// `snapshot()` folds the accumulated patches back into `state`.
///
/// `Thread` uses an owned builder pattern: `with_*` methods consume `self`
/// and return a new `Thread` (e.g., `thread.with_message(msg)`).
///
/// Runtime field (`pending`) is transient — not serialized.
/// It exists for backward compatibility and will be removed
/// once all callers migrate to `RunContext`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Unique thread identifier.
    pub id: String,
    /// Owner/resource identifier (e.g., user_id, org_id).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_id: Option<String>,
    /// Parent thread identifier (links child → parent for sub-agent lineage).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_thread_id: Option<String>,
    /// Messages in append order (Arc-wrapped so cloning the thread is cheap).
    pub messages: Vec<Arc<Message>>,
    /// Initial/snapshot state: the base that `patches` are applied on top of.
    pub state: Value,
    /// Patches recorded since the last snapshot, in application order.
    pub patches: Vec<TrackedPatch>,
    /// Metadata. Falls back to `ThreadMetadata::default()` when the field is
    /// absent from serialized input.
    #[serde(default)]
    pub metadata: ThreadMetadata,
    /// Pending delta buffer — tracks new items since last `take_pending()`.
    /// Deprecated: use `RunContext.take_delta()` instead.
    ///
    /// Skipped by serde, so a deserialized thread always starts with an empty
    /// buffer. `Clone` copies the buffer, and each copy is drained independently.
    #[serde(skip)]
    pub(crate) pending: PendingDelta,
}
62
/// Thread metadata.
///
/// Each known field is serialized only when it is `Some`. Any additional keys
/// live in `extra`, which is flattened into the same JSON object.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThreadMetadata {
    /// Creation timestamp (unix millis).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Last update timestamp (unix millis).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<u64>,
    /// Persisted state cursor version.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<u64>,
    /// Timestamp of the latest committed version.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_timestamp: Option<u64>,
    /// Custom metadata.
    ///
    /// Because of `#[serde(flatten)]`, these entries are written at the top
    /// level of the metadata object. On deserialization, any key that is not
    /// one of the named fields above is collected here.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, Value>,
}
82
83impl Thread {
84    /// Create a new thread with the given ID.
85    pub fn new(id: impl Into<String>) -> Self {
86        Self {
87            id: id.into(),
88            resource_id: None,
89            parent_thread_id: None,
90            messages: Vec::new(),
91            state: Value::Object(serde_json::Map::new()),
92            patches: Vec::new(),
93            metadata: ThreadMetadata::default(),
94            pending: PendingDelta::default(),
95        }
96    }
97
98    /// Create a new thread with initial state.
99    pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
100        Self {
101            id: id.into(),
102            resource_id: None,
103            parent_thread_id: None,
104            messages: Vec::new(),
105            state,
106            patches: Vec::new(),
107            metadata: ThreadMetadata::default(),
108            pending: PendingDelta::default(),
109        }
110    }
111
112    /// Set the resource_id (pure function, returns new Thread).
113    #[must_use]
114    pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
115        self.resource_id = Some(resource_id.into());
116        self
117    }
118
119    /// Set the parent_thread_id (pure function, returns new Thread).
120    #[must_use]
121    pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
122        self.parent_thread_id = Some(parent_thread_id.into());
123        self
124    }
125
126    /// Add a message to the thread (pure function, returns new Thread).
127    ///
128    /// Messages are Arc-wrapped for efficient cloning during agent loops.
129    #[must_use]
130    pub fn with_message(mut self, msg: Message) -> Self {
131        let arc = Arc::new(msg);
132        self.pending.messages.push(arc.clone());
133        self.messages.push(arc);
134        self
135    }
136
137    /// Add multiple messages (pure function, returns new Thread).
138    #[must_use]
139    pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
140        let arcs: Vec<Arc<Message>> = msgs.into_iter().map(Arc::new).collect();
141        self.pending.messages.extend(arcs.iter().cloned());
142        self.messages.extend(arcs);
143        self
144    }
145
146    /// Add a patch to the thread (pure function, returns new Thread).
147    #[must_use]
148    pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
149        self.pending.patches.push(patch.clone());
150        self.patches.push(patch);
151        self
152    }
153
154    /// Add multiple patches (pure function, returns new Thread).
155    #[must_use]
156    pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
157        let patches: Vec<TrackedPatch> = patches.into_iter().collect();
158        self.pending.patches.extend(patches.iter().cloned());
159        self.patches.extend(patches);
160        self
161    }
162
163    /// Drain and return the pending delta buffer.
164    ///
165    /// After this call, the pending buffer is empty. The returned `PendingDelta`
166    /// contains all messages and patches added since the last `take_pending()`.
167    pub fn take_pending(&mut self) -> PendingDelta {
168        std::mem::take(&mut self.pending)
169    }
170
171    /// Rebuild the current state (base + thread patches).
172    pub fn rebuild_state(&self) -> TireaResult<Value> {
173        if self.patches.is_empty() {
174            return Ok(self.state.clone());
175        }
176        apply_patches(&self.state, self.patches.iter().map(|p| p.patch()))
177    }
178
179    /// Replay state to a specific patch index (0-based).
180    ///
181    /// - `patch_index = 0`: Returns state after applying the first patch only
182    /// - `patch_index = n`: Returns state after applying patches 0..=n
183    /// - `patch_index >= patch_count`: Returns error
184    ///
185    /// This enables time-travel debugging by accessing any historical state point.
186    pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
187        if patch_index >= self.patches.len() {
188            return Err(TireaError::invalid_operation(format!(
189                "replay index {patch_index} out of bounds (history len: {})",
190                self.patches.len()
191            )));
192        }
193
194        apply_patches(
195            &self.state,
196            self.patches[..=patch_index].iter().map(|p| p.patch()),
197        )
198    }
199
200    /// Create a snapshot, collapsing patches into the base state.
201    ///
202    /// Returns a new Thread with the current state as base and empty patches.
203    pub fn snapshot(self) -> TireaResult<Self> {
204        let current_state = self.rebuild_state()?;
205        Ok(Self {
206            id: self.id,
207            resource_id: self.resource_id,
208            parent_thread_id: self.parent_thread_id,
209            messages: self.messages,
210            state: current_state,
211            patches: Vec::new(),
212            metadata: self.metadata,
213            pending: self.pending,
214        })
215    }
216
217    /// Check if a snapshot is needed (e.g., too many patches).
218    pub fn needs_snapshot(&self, threshold: usize) -> bool {
219        self.patches.len() >= threshold
220    }
221
222    /// Get the number of messages.
223    pub fn message_count(&self) -> usize {
224        self.messages.len()
225    }
226
227    /// Get the number of patches.
228    pub fn patch_count(&self) -> usize {
229        self.patches.len()
230    }
231}
232
#[cfg(test)]
mod tests {
    // Unit tests for `Thread` and `PendingDelta`: pending-buffer tracking,
    // builder methods, state rebuild/replay/snapshot, and serde round-trips.
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    // Tests use Thread directly (the canonical name).

    /// `with_message` records each message in both the history and the
    /// pending buffer; `take_pending` drains only the buffer.
    #[test]
    fn test_pending_delta_tracks_messages() {
        let mut thread = Thread::new("t-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));

        assert_eq!(thread.pending.messages.len(), 2);
        assert_eq!(thread.messages.len(), 2);

        let pending = thread.take_pending();
        assert_eq!(pending.messages.len(), 2);
        assert_eq!(pending.messages[0].content, "Hello");
        assert_eq!(pending.messages[1].content, "Hi!");
        assert!(thread.pending.is_empty());
    }

    /// Single and batch patch appends are both tracked in the pending buffer.
    #[test]
    fn test_pending_delta_tracks_patches() {
        let mut thread = Thread::new("t-1")
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("a"), json!(1))),
            ))
            .with_patches(vec![
                TrackedPatch::new(Patch::new().with_op(Op::set(path!("b"), json!(2)))),
                TrackedPatch::new(Patch::new().with_op(Op::set(path!("c"), json!(3)))),
            ]);

        assert_eq!(thread.pending.patches.len(), 3);
        assert_eq!(thread.patches.len(), 3);

        let pending = thread.take_pending();
        assert_eq!(pending.patches.len(), 3);
        assert!(thread.pending.is_empty());
    }

    /// Each `take_pending` returns only what was added since the previous call.
    #[test]
    fn test_take_pending_resets_buffer() {
        let mut thread = Thread::new("t-1").with_message(Message::user("first"));
        let p1 = thread.take_pending();
        assert_eq!(p1.messages.len(), 1);

        // Second take should be empty
        let p2 = thread.take_pending();
        assert!(p2.is_empty());

        // Add more, take again
        thread = thread.with_message(Message::user("second"));
        let p3 = thread.take_pending();
        assert_eq!(p3.messages.len(), 1);
        assert_eq!(p3.messages[0].content, "second");
    }

    /// `pending` is `#[serde(skip)]`, so a round-trip yields an empty buffer
    /// while the persisted messages survive.
    #[test]
    fn test_pending_delta_not_serialized() {
        let thread = Thread::new("t-1").with_message(Message::user("Hello"));
        assert_eq!(thread.pending.messages.len(), 1);

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert!(
            restored.pending.is_empty(),
            "pending should not survive serialization"
        );
        assert_eq!(restored.messages.len(), 1);
    }

    /// Cloning copies the pending buffer; draining one copy leaves the other intact.
    #[test]
    fn test_pending_clone_is_independent() {
        let thread = Thread::new("t-1").with_message(Message::user("first"));
        assert_eq!(thread.pending.messages.len(), 1);

        // Clone carries the same pending state
        let mut cloned = thread.clone();
        assert_eq!(cloned.pending.messages.len(), 1);

        // Draining the clone does not affect the original
        let pending = cloned.take_pending();
        assert_eq!(pending.messages.len(), 1);
        assert!(cloned.pending.is_empty());

        // Original still has its pending
        assert_eq!(thread.pending.messages.len(), 1);
    }

    /// `with_messages` preserves input order in the pending buffer.
    #[test]
    fn test_pending_with_messages_batch() {
        let msgs = vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ];
        let mut thread = Thread::new("t-1").with_messages(msgs);
        assert_eq!(thread.messages.len(), 3);
        assert_eq!(thread.pending.messages.len(), 3);

        let pending = thread.take_pending();
        assert_eq!(pending.messages.len(), 3);
        assert_eq!(pending.messages[0].content, "a");
        assert_eq!(pending.messages[2].content, "c");
    }

    /// Mixed message/patch appends are tracked separately and fully drained.
    #[test]
    fn test_pending_interleaved_messages_and_patches() {
        let mut thread = Thread::new("t-1")
            .with_message(Message::user("hello"))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("a"), json!(1))),
            ))
            .with_message(Message::assistant("hi"))
            .with_patches(vec![
                TrackedPatch::new(Patch::new().with_op(Op::set(path!("b"), json!(2)))),
                TrackedPatch::new(Patch::new().with_op(Op::set(path!("c"), json!(3)))),
            ]);

        let pending = thread.take_pending();
        assert_eq!(pending.messages.len(), 2);
        assert_eq!(pending.patches.len(), 3);
        assert!(thread.pending.is_empty());

        // Main arrays still have everything
        assert_eq!(thread.messages.len(), 2);
        assert_eq!(thread.patches.len(), 3);
    }

    /// `is_empty` is true only when both messages and patches are empty.
    #[test]
    fn test_pending_is_empty() {
        let delta = PendingDelta::default();
        assert!(delta.is_empty());

        let delta = PendingDelta {
            messages: vec![Arc::new(Message::user("hi"))],
            patches: vec![],
        };
        assert!(!delta.is_empty());

        let delta = PendingDelta {
            messages: vec![],
            patches: vec![TrackedPatch::new(Patch::new())],
        };
        assert!(!delta.is_empty());
    }

    /// A fresh thread has the given id and no resource id, messages, or patches.
    #[test]
    fn test_thread_new() {
        let thread = Thread::new("test-1");
        assert_eq!(thread.id, "test-1");
        assert!(thread.resource_id.is_none());
        assert!(thread.messages.is_empty());
        assert!(thread.patches.is_empty());
    }

    #[test]
    fn test_thread_with_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("user-123");
        assert_eq!(thread.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_initial_state() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());
        assert_eq!(thread.state, state);
    }

    /// Messages are kept in append order.
    #[test]
    fn test_thread_with_message() {
        let thread = Thread::new("test-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));

        assert_eq!(thread.message_count(), 2);
        assert_eq!(thread.messages[0].content, "Hello");
        assert_eq!(thread.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let thread = Thread::new("test-1");
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1))));

        let thread = thread.with_patch(patch);
        assert_eq!(thread.patch_count(), 1);
    }

    /// With no patches, the rebuilt state equals the base state.
    #[test]
    fn test_thread_rebuild_state_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt, state);
    }

    /// Patches overwrite existing keys and add new ones on top of the base.
    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(1))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("name"), json!("test"))),
            ));

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt["counter"], 1);
        assert_eq!(rebuilt["name"], "test");
    }

    /// A snapshot folds patches into the base state and clears the patch list.
    #[test]
    fn test_thread_snapshot() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state).with_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("counter"), json!(5))),
        ));

        assert_eq!(thread.patch_count(), 1);

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.patch_count(), 0);
        assert_eq!(snapshotted.state["counter"], 5);
    }

    /// The threshold is inclusive: 10 patches need a snapshot at threshold 10.
    #[test]
    fn test_thread_needs_snapshot() {
        let thread = Thread::new("test-1");
        assert!(!thread.needs_snapshot(10));

        let thread = (0..10).fold(thread, |s, i| {
            s.with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("field").key(i.to_string()), json!(i))),
            ))
        });

        assert!(thread.needs_snapshot(10));
        assert!(!thread.needs_snapshot(20));
    }

    #[test]
    fn test_thread_serialization() {
        let thread = Thread::new("test-1").with_message(Message::user("Hello"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.id, "test-1");
        assert_eq!(restored.message_count(), 1);
    }

    /// Base state and patches are both persisted, so the rebuilt state survives a round-trip.
    #[test]
    fn test_state_persists_after_serialization() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0})).with_patch(
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("counter"), json!(5)))),
        );

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        let rebuilt = restored.rebuild_state().unwrap();
        assert_eq!(
            rebuilt["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("org-42");
        let json_str = serde_json::to_string(&thread).unwrap();
        assert!(json_str.contains("org-42"));

        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.resource_id.as_deref(), Some("org-42"));
    }

    /// `replay_to(i)` applies patches 0..=i; out-of-range indices report the history length.
    #[test]
    fn test_thread_replay_to() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(10))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(20))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(30))),
            ));

        let state_at_0 = thread.replay_to(0).unwrap();
        assert_eq!(state_at_0["counter"], 10);

        let state_at_1 = thread.replay_to(1).unwrap();
        assert_eq!(state_at_1["counter"], 20);

        let state_at_2 = thread.replay_to(2).unwrap();
        assert_eq!(state_at_2["counter"], 30);

        let err = thread.replay_to(100).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 100 out of bounds (history len: 3)"));
    }

    /// With no patches, even index 0 is out of bounds.
    #[test]
    fn test_thread_replay_to_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());

        let err = thread.replay_to(0).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 0 out of bounds (history len: 0)"));
    }
}