1use crate::thread::message::Message;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};
11
/// A conversation thread: an ordered list of messages plus JSON state kept as
/// a base value (`state`) and an ordered log of patches (`patches`) on top of it.
///
/// The current state is not stored directly. It is derived by applying the
/// patch log to the base state (see [`Thread::rebuild_state`]), and
/// [`Thread::snapshot`] folds the log back into the base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Identifier of this thread.
    pub id: String,
    /// Optional identifier of a resource this thread is associated with.
    /// Omitted from the serialized output when `None`.
    // NOTE(review): tests use values like "user-123" / "org-42", which suggests
    // an owning user or organization — confirm the intended meaning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_id: Option<String>,
    /// Optional identifier of a parent thread.
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_thread_id: Option<String>,
    /// Message history in insertion order. Each message sits behind an `Arc`,
    /// so cloning a `Thread` shares the message allocations instead of copying
    /// them.
    pub messages: Vec<Arc<Message>>,
    /// Base state that the patch log is applied on top of. While `patches` is
    /// non-empty this is *not* the current state.
    pub state: Value,
    /// Ordered patch log; patches are applied to `state` in this order.
    pub patches: Vec<TrackedPatch>,
    /// Auxiliary metadata. Falls back to `ThreadMetadata::default()` when the
    /// field is absent during deserialization.
    #[serde(default)]
    pub metadata: ThreadMetadata,
}
36
/// Auxiliary metadata attached to a [`Thread`].
///
/// Every typed field is optional and omitted from the serialized output when
/// `None`. The `Default` value has all typed fields `None` and `extra` empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThreadMetadata {
    /// Optional creation timestamp.
    // NOTE(review): the unit/epoch is not established here (Unix seconds vs
    // milliseconds?) — confirm against the code that sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Optional last-update timestamp (presumably the same unit as
    /// `created_at` — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<u64>,
    /// Optional version number. [`Thread::snapshot`] carries it over unchanged.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<u64>,
    /// Optional timestamp associated with `version`.
    // NOTE(review): presumably the time `version` was last set — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_timestamp: Option<u64>,
    /// Additional, untyped metadata. Because of `#[serde(flatten)]`, these
    /// entries are written at the same level as the typed fields above, and any
    /// unrecognized keys are collected here on deserialization.
    // NOTE(review): inserting a key that matches a typed field name (e.g.
    // "version") would emit a duplicate key when serialized; avoid that.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, Value>,
}
56
57impl Thread {
58 pub fn new(id: impl Into<String>) -> Self {
60 Self {
61 id: id.into(),
62 resource_id: None,
63 parent_thread_id: None,
64 messages: Vec::new(),
65 state: Value::Object(serde_json::Map::new()),
66 patches: Vec::new(),
67 metadata: ThreadMetadata::default(),
68 }
69 }
70
71 pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
73 Self {
74 id: id.into(),
75 resource_id: None,
76 parent_thread_id: None,
77 messages: Vec::new(),
78 state,
79 patches: Vec::new(),
80 metadata: ThreadMetadata::default(),
81 }
82 }
83
84 #[must_use]
86 pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
87 self.resource_id = Some(resource_id.into());
88 self
89 }
90
91 #[must_use]
93 pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
94 self.parent_thread_id = Some(parent_thread_id.into());
95 self
96 }
97
98 #[must_use]
102 pub fn with_message(mut self, msg: Message) -> Self {
103 self.messages.push(Arc::new(msg));
104 self
105 }
106
107 #[must_use]
109 pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
110 let arcs: Vec<Arc<Message>> = msgs.into_iter().map(Arc::new).collect();
111 self.messages.extend(arcs);
112 self
113 }
114
115 #[must_use]
117 pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
118 self.patches.push(patch);
119 self
120 }
121
122 #[must_use]
124 pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
125 self.patches.extend(patches);
126 self
127 }
128
129 pub fn rebuild_state(&self) -> TireaResult<Value> {
131 if self.patches.is_empty() {
132 return Ok(self.state.clone());
133 }
134 apply_patches(&self.state, self.patches.iter().map(|p| p.patch()))
135 }
136
137 pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
145 if patch_index >= self.patches.len() {
146 return Err(TireaError::invalid_operation(format!(
147 "replay index {patch_index} out of bounds (history len: {})",
148 self.patches.len()
149 )));
150 }
151
152 apply_patches(
153 &self.state,
154 self.patches[..=patch_index].iter().map(|p| p.patch()),
155 )
156 }
157
158 pub fn snapshot(self) -> TireaResult<Self> {
162 let current_state = self.rebuild_state()?;
163 Ok(Self {
164 id: self.id,
165 resource_id: self.resource_id,
166 parent_thread_id: self.parent_thread_id,
167 messages: self.messages,
168 state: current_state,
169 patches: Vec::new(),
170 metadata: self.metadata,
171 })
172 }
173
174 pub fn needs_snapshot(&self, threshold: usize) -> bool {
176 self.patches.len() >= threshold
177 }
178
179 pub fn message_count(&self) -> usize {
181 self.messages.len()
182 }
183
184 pub fn patch_count(&self) -> usize {
186 self.patches.len()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    /// Shorthand for a single-op "set `key` to `value`" patch.
    fn set_patch(key: &str, value: Value) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(Op::set(path!(key), value)))
    }

    #[test]
    fn test_thread_with_messages_batch() {
        let msgs = vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ];
        let thread = Thread::new("t-1").with_messages(msgs);
        assert_eq!(thread.messages.len(), 3);
        assert_eq!(thread.messages[0].content, "a");
        assert_eq!(thread.messages[2].content, "c");
    }

    #[test]
    fn test_thread_with_patches_batch() {
        let thread = Thread::new("t-1").with_patches(vec![
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("b"), json!(2)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("c"), json!(3)))),
        ]);
        assert_eq!(thread.patches.len(), 3);
    }

    #[test]
    fn test_thread_new() {
        let thread = Thread::new("test-1");
        assert_eq!(thread.id, "test-1");
        assert!(thread.resource_id.is_none());
        assert!(thread.parent_thread_id.is_none());
        assert!(thread.messages.is_empty());
        assert!(thread.patches.is_empty());
        assert_eq!(thread.state, json!({}));
    }

    #[test]
    fn test_thread_with_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("user-123");
        assert_eq!(thread.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_parent_thread_id() {
        let thread = Thread::new("child").with_parent_thread_id("parent-1");
        assert_eq!(thread.parent_thread_id.as_deref(), Some("parent-1"));
        assert!(thread.resource_id.is_none());
    }

    #[test]
    fn test_thread_with_initial_state() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());
        assert_eq!(thread.state, state);
    }

    #[test]
    fn test_thread_with_message() {
        let thread = Thread::new("test-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));

        assert_eq!(thread.message_count(), 2);
        assert_eq!(thread.messages[0].content, "Hello");
        assert_eq!(thread.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let thread = Thread::new("test-1");
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1))));

        let thread = thread.with_patch(patch);
        assert_eq!(thread.patch_count(), 1);
    }

    #[test]
    fn test_thread_rebuild_state_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt, state);
    }

    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(1))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("name"), json!("test"))),
            ));

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt["counter"], 1);
        assert_eq!(rebuilt["name"], "test");
    }

    #[test]
    fn test_thread_snapshot() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state).with_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("counter"), json!(5))),
        ));

        assert_eq!(thread.patch_count(), 1);

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.patch_count(), 0);
        assert_eq!(snapshotted.state["counter"], 5);
    }

    #[test]
    fn test_thread_snapshot_preserves_identity_messages_and_metadata() {
        let mut thread = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_resource_id("user-1")
            .with_parent_thread_id("parent-1")
            .with_message(Message::user("Hello"))
            .with_patch(set_patch("counter", json!(1)));
        thread.metadata.version = Some(7);

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.id, "test-1");
        assert_eq!(snapshotted.resource_id.as_deref(), Some("user-1"));
        assert_eq!(snapshotted.parent_thread_id.as_deref(), Some("parent-1"));
        assert_eq!(snapshotted.message_count(), 1);
        assert_eq!(snapshotted.messages[0].content, "Hello");
        assert_eq!(snapshotted.metadata.version, Some(7));
        assert_eq!(snapshotted.state, json!({"counter": 1}));
        // After a snapshot the base state *is* the current state.
        assert_eq!(snapshotted.rebuild_state().unwrap(), json!({"counter": 1}));
    }

    #[test]
    fn test_thread_needs_snapshot() {
        let thread = Thread::new("test-1");
        assert!(!thread.needs_snapshot(10));

        let thread = (0..10).fold(thread, |s, i| {
            s.with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("field").key(i.to_string()), json!(i))),
            ))
        });

        assert!(thread.needs_snapshot(10));
        assert!(!thread.needs_snapshot(20));
    }

    #[test]
    fn test_thread_needs_snapshot_zero_threshold() {
        assert!(Thread::new("test-1").needs_snapshot(0));
    }

    #[test]
    fn test_thread_serialization() {
        let thread = Thread::new("test-1").with_message(Message::user("Hello"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.id, "test-1");
        assert_eq!(restored.message_count(), 1);
    }

    #[test]
    fn test_state_persists_after_serialization() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0})).with_patch(
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("counter"), json!(5)))),
        );

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        let rebuilt = restored.rebuild_state().unwrap();
        assert_eq!(
            rebuilt["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("org-42");
        let json_str = serde_json::to_string(&thread).unwrap();
        assert!(json_str.contains("org-42"));

        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.resource_id.as_deref(), Some("org-42"));
    }

    #[test]
    fn test_metadata_round_trips_typed_and_extra_fields() {
        let mut thread = Thread::new("test-1");
        thread.metadata.version = Some(3);
        thread
            .metadata
            .extra
            .insert("custom".to_string(), json!("value"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.metadata.version, Some(3));
        assert!(restored.metadata.created_at.is_none());
        assert_eq!(
            restored.metadata.extra.get("custom"),
            Some(&json!("value"))
        );
    }

    #[test]
    fn test_thread_replay_to() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(10))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(20))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(30))),
            ));

        let state_at_0 = thread.replay_to(0).unwrap();
        assert_eq!(state_at_0["counter"], 10);

        let state_at_1 = thread.replay_to(1).unwrap();
        assert_eq!(state_at_1["counter"], 20);

        let state_at_2 = thread.replay_to(2).unwrap();
        assert_eq!(state_at_2["counter"], 30);

        let err = thread.replay_to(100).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 100 out of bounds (history len: 3)"));
    }

    #[test]
    fn test_thread_replay_to_last_matches_rebuild_state() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(set_patch("counter", json!(1)))
            .with_patch(set_patch("name", json!("x")));

        let last = thread.patch_count() - 1;
        assert_eq!(
            thread.replay_to(last).unwrap(),
            thread.rebuild_state().unwrap()
        );
        // Replaying must never mutate the base state.
        assert_eq!(thread.state, json!({"counter": 0}));
    }

    #[test]
    fn test_thread_replay_to_max_index_is_error_not_panic() {
        let thread = Thread::new("test-1").with_patch(set_patch("a", json!(1)));
        assert!(thread.replay_to(usize::MAX).is_err());
    }

    #[test]
    fn test_thread_replay_to_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state);

        let err = thread.replay_to(0).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 0 out of bounds (history len: 0)"));
    }
}