1use crate::thread::message::Message;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};
11
/// Changes recorded on a [`Thread`] since the last call to
/// [`Thread::take_pending`].
///
/// The `with_message`/`with_messages`/`with_patch`/`with_patches` builders push
/// every new item both into the thread's full history and into this buffer, so
/// a caller can drain only the newly added items. NOTE(review): presumably
/// intended for incremental persistence — confirm against callers.
///
/// This type does not derive `Serialize`; it lives only in memory.
#[derive(Debug, Clone, Default)]
pub struct PendingDelta {
    /// Messages appended since the buffer was last drained. These `Arc`s are
    /// the same allocations stored in `Thread::messages`, so messages are not
    /// duplicated in memory.
    pub messages: Vec<Arc<Message>>,
    /// Patches appended since the buffer was last drained (clones of the
    /// entries pushed onto `Thread::patches`).
    pub patches: Vec<TrackedPatch>,
}
22
23impl PendingDelta {
24 pub fn is_empty(&self) -> bool {
26 self.messages.is_empty() && self.patches.is_empty()
27 }
28}
29
/// A conversation thread: its message history plus a JSON state expressed as
/// a base value and an ordered list of patches.
///
/// The current state is not stored directly. It is `state` with every entry
/// of `patches` applied in order (see [`Thread::rebuild_state`]), and
/// [`Thread::snapshot`] folds the patches into `state` to shorten the history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Identifier of this thread.
    pub id: String,
    /// Optional id of an associated resource (the tests use values such as
    /// `"user-123"` / `"org-42"`). Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_id: Option<String>,
    /// Optional id of a parent thread. Omitted from serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_thread_id: Option<String>,
    /// Full message history, in insertion order.
    pub messages: Vec<Arc<Message>>,
    /// Base state that `patches` are applied on top of.
    pub state: Value,
    /// Patch history; applied to `state` in order to obtain the current state.
    pub patches: Vec<TrackedPatch>,
    /// Bookkeeping metadata; falls back to `ThreadMetadata::default()` when the
    /// field is absent from serialized input.
    #[serde(default)]
    pub metadata: ThreadMetadata,
    /// In-memory buffer of changes not yet drained via `take_pending`.
    /// Never serialized; a deserialized thread starts with an empty buffer.
    #[serde(skip)]
    pub(crate) pending: PendingDelta,
}
62
/// Optional bookkeeping attached to a [`Thread`].
///
/// Each named field is omitted from serialized output when `None`. Any other
/// keys present in the input are collected into `extra` and written back out
/// at the same level (via `#[serde(flatten)]`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThreadMetadata {
    /// Creation timestamp. NOTE(review): unit and epoch are not visible here
    /// (likely Unix seconds or milliseconds) — confirm with the writer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Last-update timestamp (same unit caveat as `created_at`). The `Thread`
    /// builder methods do not update this field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<u64>,
    /// Version number. NOTE(review): presumably maintained by a storage layer
    /// for versioning / concurrency control — verify.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<u64>,
    /// Timestamp associated with `version`; unit unconfirmed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_timestamp: Option<u64>,
    /// Catch-all for metadata keys not covered by the fields above; flattened
    /// into the parent JSON object on (de)serialization.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, Value>,
}
82
83impl Thread {
84 pub fn new(id: impl Into<String>) -> Self {
86 Self {
87 id: id.into(),
88 resource_id: None,
89 parent_thread_id: None,
90 messages: Vec::new(),
91 state: Value::Object(serde_json::Map::new()),
92 patches: Vec::new(),
93 metadata: ThreadMetadata::default(),
94 pending: PendingDelta::default(),
95 }
96 }
97
98 pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
100 Self {
101 id: id.into(),
102 resource_id: None,
103 parent_thread_id: None,
104 messages: Vec::new(),
105 state,
106 patches: Vec::new(),
107 metadata: ThreadMetadata::default(),
108 pending: PendingDelta::default(),
109 }
110 }
111
112 #[must_use]
114 pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
115 self.resource_id = Some(resource_id.into());
116 self
117 }
118
119 #[must_use]
121 pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
122 self.parent_thread_id = Some(parent_thread_id.into());
123 self
124 }
125
126 #[must_use]
130 pub fn with_message(mut self, msg: Message) -> Self {
131 let arc = Arc::new(msg);
132 self.pending.messages.push(arc.clone());
133 self.messages.push(arc);
134 self
135 }
136
137 #[must_use]
139 pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
140 let arcs: Vec<Arc<Message>> = msgs.into_iter().map(Arc::new).collect();
141 self.pending.messages.extend(arcs.iter().cloned());
142 self.messages.extend(arcs);
143 self
144 }
145
146 #[must_use]
148 pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
149 self.pending.patches.push(patch.clone());
150 self.patches.push(patch);
151 self
152 }
153
154 #[must_use]
156 pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
157 let patches: Vec<TrackedPatch> = patches.into_iter().collect();
158 self.pending.patches.extend(patches.iter().cloned());
159 self.patches.extend(patches);
160 self
161 }
162
163 pub fn take_pending(&mut self) -> PendingDelta {
168 std::mem::take(&mut self.pending)
169 }
170
171 pub fn rebuild_state(&self) -> TireaResult<Value> {
173 if self.patches.is_empty() {
174 return Ok(self.state.clone());
175 }
176 apply_patches(&self.state, self.patches.iter().map(|p| p.patch()))
177 }
178
179 pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
187 if patch_index >= self.patches.len() {
188 return Err(TireaError::invalid_operation(format!(
189 "replay index {patch_index} out of bounds (history len: {})",
190 self.patches.len()
191 )));
192 }
193
194 apply_patches(
195 &self.state,
196 self.patches[..=patch_index].iter().map(|p| p.patch()),
197 )
198 }
199
200 pub fn snapshot(self) -> TireaResult<Self> {
204 let current_state = self.rebuild_state()?;
205 Ok(Self {
206 id: self.id,
207 resource_id: self.resource_id,
208 parent_thread_id: self.parent_thread_id,
209 messages: self.messages,
210 state: current_state,
211 patches: Vec::new(),
212 metadata: self.metadata,
213 pending: self.pending,
214 })
215 }
216
217 pub fn needs_snapshot(&self, threshold: usize) -> bool {
219 self.patches.len() >= threshold
220 }
221
222 pub fn message_count(&self) -> usize {
224 self.messages.len()
225 }
226
227 pub fn patch_count(&self) -> usize {
229 self.patches.len()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    /// Wraps a single operation into a tracked patch.
    fn tracked(op: Op) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(op))
    }

    // ---- pending delta ----

    #[test]
    fn test_pending_delta_tracks_messages() {
        let mut t = Thread::new("t-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));
        assert_eq!(t.messages.len(), 2);
        assert_eq!(t.pending.messages.len(), 2);

        let drained = t.take_pending();
        assert!(t.pending.is_empty());
        assert_eq!(drained.messages.len(), 2);
        assert_eq!(drained.messages[0].content, "Hello");
        assert_eq!(drained.messages[1].content, "Hi!");
    }

    #[test]
    fn test_pending_delta_tracks_patches() {
        let mut t = Thread::new("t-1")
            .with_patch(tracked(Op::set(path!("a"), json!(1))))
            .with_patches(vec![
                tracked(Op::set(path!("b"), json!(2))),
                tracked(Op::set(path!("c"), json!(3))),
            ]);
        assert_eq!(t.patches.len(), 3);
        assert_eq!(t.pending.patches.len(), 3);

        let drained = t.take_pending();
        assert!(t.pending.is_empty());
        assert_eq!(drained.patches.len(), 3);
    }

    #[test]
    fn test_take_pending_resets_buffer() {
        let mut t = Thread::new("t-1").with_message(Message::user("first"));
        assert_eq!(t.take_pending().messages.len(), 1);

        // Nothing was added in between, so the second drain is empty.
        assert!(t.take_pending().is_empty());

        // Items added after a drain are buffered again.
        t = t.with_message(Message::user("second"));
        let latest = t.take_pending();
        assert_eq!(latest.messages.len(), 1);
        assert_eq!(latest.messages[0].content, "second");
    }

    #[test]
    fn test_pending_delta_not_serialized() {
        let t = Thread::new("t-1").with_message(Message::user("Hello"));
        assert_eq!(t.pending.messages.len(), 1);

        let encoded = serde_json::to_string(&t).unwrap();
        let decoded: Thread = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.messages.len(), 1);
        assert!(
            decoded.pending.is_empty(),
            "pending should not survive serialization"
        );
    }

    #[test]
    fn test_pending_clone_is_independent() {
        let source = Thread::new("t-1").with_message(Message::user("first"));
        assert_eq!(source.pending.messages.len(), 1);

        let mut copy = source.clone();
        assert_eq!(copy.pending.messages.len(), 1);
        assert_eq!(copy.take_pending().messages.len(), 1);
        assert!(copy.pending.is_empty());

        // Draining the clone must not touch the original's buffer.
        assert_eq!(source.pending.messages.len(), 1);
    }

    #[test]
    fn test_pending_with_messages_batch() {
        let mut t = Thread::new("t-1").with_messages(vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ]);
        assert_eq!(t.messages.len(), 3);
        assert_eq!(t.pending.messages.len(), 3);

        let drained = t.take_pending();
        assert_eq!(drained.messages.len(), 3);
        assert_eq!(drained.messages[0].content, "a");
        assert_eq!(drained.messages[2].content, "c");
    }

    #[test]
    fn test_pending_interleaved_messages_and_patches() {
        let mut t = Thread::new("t-1")
            .with_message(Message::user("hello"))
            .with_patch(tracked(Op::set(path!("a"), json!(1))))
            .with_message(Message::assistant("hi"))
            .with_patches(vec![
                tracked(Op::set(path!("b"), json!(2))),
                tracked(Op::set(path!("c"), json!(3))),
            ]);

        let drained = t.take_pending();
        assert_eq!((drained.messages.len(), drained.patches.len()), (2, 3));
        assert!(t.pending.is_empty());

        // Draining leaves the full histories intact.
        assert_eq!((t.messages.len(), t.patches.len()), (2, 3));
    }

    #[test]
    fn test_pending_is_empty() {
        assert!(PendingDelta::default().is_empty());

        let only_messages = PendingDelta {
            messages: vec![Arc::new(Message::user("hi"))],
            patches: Vec::new(),
        };
        assert!(!only_messages.is_empty());

        let only_patches = PendingDelta {
            messages: Vec::new(),
            patches: vec![TrackedPatch::new(Patch::new())],
        };
        assert!(!only_patches.is_empty());
    }

    // ---- construction and builders ----

    #[test]
    fn test_thread_new() {
        let t = Thread::new("test-1");
        assert_eq!(t.id, "test-1");
        assert!(t.resource_id.is_none());
        assert!(t.messages.is_empty());
        assert!(t.patches.is_empty());
    }

    #[test]
    fn test_thread_with_resource_id() {
        let t = Thread::new("t-1").with_resource_id("user-123");
        assert_eq!(t.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_initial_state() {
        let initial = json!({"counter": 0});
        let t = Thread::with_initial_state("test-1", initial.clone());
        assert_eq!(t.state, initial);
    }

    #[test]
    fn test_thread_with_message() {
        let t = Thread::new("test-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));
        assert_eq!(t.message_count(), 2);
        assert_eq!(t.messages[0].content, "Hello");
        assert_eq!(t.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let t = Thread::new("test-1").with_patch(tracked(Op::set(path!("a"), json!(1))));
        assert_eq!(t.patch_count(), 1);
    }

    // ---- state reconstruction ----

    #[test]
    fn test_thread_rebuild_state_empty() {
        let initial = json!({"counter": 0});
        let t = Thread::with_initial_state("test-1", initial.clone());
        assert_eq!(t.rebuild_state().unwrap(), initial);
    }

    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let t = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(1))))
            .with_patch(tracked(Op::set(path!("name"), json!("test"))));

        let current = t.rebuild_state().unwrap();
        assert_eq!(current["counter"], 1);
        assert_eq!(current["name"], "test");
    }

    #[test]
    fn test_thread_snapshot() {
        let t = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(5))));
        assert_eq!(t.patch_count(), 1);

        let compacted = t.snapshot().unwrap();
        assert_eq!(compacted.patch_count(), 0);
        assert_eq!(compacted.state["counter"], 5);
    }

    #[test]
    fn test_thread_needs_snapshot() {
        let mut t = Thread::new("test-1");
        assert!(!t.needs_snapshot(10));

        for i in 0..10 {
            t = t.with_patch(tracked(Op::set(path!("field").key(i.to_string()), json!(i))));
        }
        assert!(t.needs_snapshot(10));
        assert!(!t.needs_snapshot(20));
    }

    // ---- serialization ----

    #[test]
    fn test_thread_serialization() {
        let t = Thread::new("test-1").with_message(Message::user("Hello"));

        let encoded = serde_json::to_string(&t).unwrap();
        let decoded: Thread = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "test-1");
        assert_eq!(decoded.message_count(), 1);
    }

    #[test]
    fn test_state_persists_after_serialization() {
        let t = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(5))));

        let encoded = serde_json::to_string(&t).unwrap();
        let decoded: Thread = serde_json::from_str(&encoded).unwrap();
        let current = decoded.rebuild_state().unwrap();
        assert_eq!(
            current["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let t = Thread::new("t-1").with_resource_id("org-42");
        let encoded = serde_json::to_string(&t).unwrap();
        assert!(encoded.contains("org-42"));

        let decoded: Thread = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.resource_id.as_deref(), Some("org-42"));
    }

    // ---- replay ----

    #[test]
    fn test_thread_replay_to() {
        let steps = [10, 20, 30];
        let t = steps.iter().fold(
            Thread::with_initial_state("test-1", json!({"counter": 0})),
            |acc, &n| acc.with_patch(tracked(Op::set(path!("counter"), json!(n)))),
        );

        // Replaying up to patch `idx` yields the value that patch wrote.
        for (idx, &expected) in steps.iter().enumerate() {
            let replayed = t.replay_to(idx).unwrap();
            assert_eq!(replayed["counter"], expected);
        }

        let err = t.replay_to(100).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 100 out of bounds (history len: 3)"));
    }

    #[test]
    fn test_thread_replay_to_empty() {
        let t = Thread::with_initial_state("test-1", json!({"counter": 0}));

        let err = t.replay_to(0).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 0 out of bounds (history len: 0)"));
    }
}
555}