Skip to main content

tirea_agent_loop/runtime/
activity.rs

1//! Activity state manager and event emission.
2
3use crate::contracts::runtime::ActivityManager;
4use crate::contracts::AgentEvent;
5use serde_json::json;
6use std::collections::HashMap;
7use std::sync::Mutex;
8use tirea_state::{apply_patch, Op, Patch, Value};
9use tokio::sync::mpsc::UnboundedSender;
10
11#[derive(Debug, Clone)]
12struct ActivityEntry {
13    activity_type: String,
14    state: Value,
15}
16
17/// Activity manager that keeps per-stream activity state and emits events on updates.
18#[derive(Debug)]
19pub struct ActivityHub {
20    sender: UnboundedSender<AgentEvent>,
21    entries: Mutex<HashMap<String, ActivityEntry>>,
22}
23
24impl ActivityHub {
25    /// Create a new activity hub.
26    pub fn new(sender: UnboundedSender<AgentEvent>) -> Self {
27        Self {
28            sender,
29            entries: Mutex::new(HashMap::new()),
30        }
31    }
32
33    fn entry_for(&self, _stream_id: &str, activity_type: &str) -> ActivityEntry {
34        ActivityEntry {
35            activity_type: activity_type.to_string(),
36            state: json!({}),
37        }
38    }
39}
40
41impl ActivityManager for ActivityHub {
42    fn snapshot(&self, stream_id: &str) -> Value {
43        self.entries
44            .lock()
45            .unwrap()
46            .get(stream_id)
47            .map(|entry| entry.state.clone())
48            .unwrap_or_else(|| json!({}))
49    }
50
51    fn on_activity_op(&self, stream_id: &str, activity_type: &str, op: &Op) {
52        let mut entries = self.entries.lock().unwrap();
53        let entry = entries
54            .entry(stream_id.to_string())
55            .or_insert_with(|| self.entry_for(stream_id, activity_type));
56
57        if entry.activity_type.is_empty() {
58            entry.activity_type = activity_type.to_string();
59        }
60
61        let patch = Patch::with_ops(vec![op.clone()]);
62        if let Ok(updated) = apply_patch(&entry.state, &patch) {
63            entry.state = updated;
64        }
65
66        let _ = self.sender.send(AgentEvent::ActivitySnapshot {
67            message_id: stream_id.to_string(),
68            activity_type: entry.activity_type.clone(),
69            content: entry.state.clone(),
70            replace: Some(true),
71        });
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::AgentEvent;
    use serde_json::{json, Value};
    use std::sync::Arc;
    use tirea_state::{Op, Path};
    use tokio::sync::mpsc;

    /// Fields of an `AgentEvent::ActivitySnapshot`, unpacked for assertions.
    struct Snapshot {
        message_id: String,
        activity_type: String,
        content: Value,
        replace: Option<bool>,
    }

    /// Pop the next queued event and unpack it as an activity snapshot.
    ///
    /// Panics with `what` if no event is queued, or if the event is of
    /// another kind.
    fn next_snapshot(rx: &mut mpsc::UnboundedReceiver<AgentEvent>, what: &str) -> Snapshot {
        match rx.try_recv().expect(what) {
            AgentEvent::ActivitySnapshot {
                message_id,
                activity_type,
                content,
                replace,
            } => Snapshot {
                message_id,
                activity_type,
                content,
                replace,
            },
            _ => panic!("Expected ActivitySnapshot"),
        }
    }

    #[test]
    fn test_activity_hub_emits_snapshot_and_updates_state() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op = Op::set(Path::root().key("progress"), json!(0.2));
        hub.on_activity_op("stream_1", "progress", &op);

        let event = next_snapshot(&mut rx, "activity event");
        assert_eq!(event.message_id, "stream_1");
        assert_eq!(event.activity_type, "progress");
        assert_eq!(event.content["progress"], 0.2);
        assert_eq!(event.replace, Some(true));

        let snapshot = hub.snapshot("stream_1");
        assert_eq!(snapshot["progress"], 0.2);
    }

    #[test]
    fn test_activity_hub_accumulates_updates() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op1 = Op::set(Path::root().key("progress"), json!(0.5));
        hub.on_activity_op("stream_2", "progress", &op1);
        next_snapshot(&mut rx, "first event");

        let op2 = Op::set(Path::root().key("status"), json!("running"));
        hub.on_activity_op("stream_2", "progress", &op2);

        let event = next_snapshot(&mut rx, "second event");
        assert_eq!(event.content["progress"], 0.5);
        assert_eq!(event.content["status"], "running");
    }

    #[test]
    fn test_activity_hub_preserves_activity_type() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op1 = Op::set(Path::root().key("progress"), json!(0.3));
        hub.on_activity_op("stream_3", "progress", &op1);
        next_snapshot(&mut rx, "first event");

        let op2 = Op::set(Path::root().key("status"), json!("running"));
        hub.on_activity_op("stream_3", "other", &op2);

        let event = next_snapshot(&mut rx, "second event");
        assert_eq!(event.activity_type, "progress");
    }

    #[test]
    fn test_activity_hub_fills_in_missing_activity_type() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        // A stream first seen with an empty activity type...
        let op1 = Op::set(Path::root().key("progress"), json!(0.1));
        hub.on_activity_op("stream_untyped", "", &op1);
        assert_eq!(next_snapshot(&mut rx, "first event").activity_type, "");

        // ...adopts the first non-empty type it is given later.
        let op2 = Op::set(Path::root().key("progress"), json!(0.2));
        hub.on_activity_op("stream_untyped", "progress", &op2);
        let event = next_snapshot(&mut rx, "second event");
        assert_eq!(event.activity_type, "progress");
        assert_eq!(event.content["progress"], 0.2);
    }

    #[test]
    fn test_activity_hub_multiple_streams_isolated() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op1 = Op::set(Path::root().key("progress"), json!(0.1));
        hub.on_activity_op("stream_a", "progress", &op1);
        assert_eq!(next_snapshot(&mut rx, "event stream_a").message_id, "stream_a");

        let op2 = Op::set(Path::root().key("progress"), json!(0.9));
        hub.on_activity_op("stream_b", "progress", &op2);
        assert_eq!(next_snapshot(&mut rx, "event stream_b").message_id, "stream_b");

        let snapshot_a = hub.snapshot("stream_a");
        let snapshot_b = hub.snapshot("stream_b");

        assert_eq!(snapshot_a["progress"], 0.1);
        assert_eq!(snapshot_b["progress"], 0.9);
    }

    #[test]
    fn test_activity_hub_allows_scalar_state() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op = Op::set(Path::root(), json!("ok"));
        hub.on_activity_op("stream_scalar", "status", &op);

        let event = next_snapshot(&mut rx, "event scalar");
        assert_eq!(event.content, json!("ok"));

        let snapshot = hub.snapshot("stream_scalar");
        assert_eq!(snapshot, json!("ok"));
    }

    #[test]
    fn test_activity_hub_allows_array_root_state() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op = Op::set(Path::root(), json!([1, 2, 3]));
        hub.on_activity_op("stream_array", "list", &op);

        let event = next_snapshot(&mut rx, "event array");
        assert_eq!(event.content, json!([1, 2, 3]));

        let snapshot = hub.snapshot("stream_array");
        assert_eq!(snapshot, json!([1, 2, 3]));
    }

    #[test]
    fn test_activity_hub_invalid_op_keeps_state() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        // Incrementing a missing key fails to apply; the state stays empty
        // but a snapshot is still emitted.
        let op = Op::increment(Path::root().key("progress"), 1);
        hub.on_activity_op("stream_invalid", "progress", &op);

        let event = next_snapshot(&mut rx, "activity event");
        assert_eq!(event.content, json!({}));

        let snapshot = hub.snapshot("stream_invalid");
        assert_eq!(snapshot, json!({}));
    }

    #[test]
    fn test_activity_hub_emits_events_in_order() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op1 = Op::set(Path::root().key("progress"), json!(0.1));
        let op2 = Op::set(Path::root().key("progress"), json!(0.2));
        hub.on_activity_op("stream_order", "progress", &op1);
        hub.on_activity_op("stream_order", "progress", &op2);

        let first = next_snapshot(&mut rx, "first event");
        let second = next_snapshot(&mut rx, "second event");

        assert_eq!(first.content["progress"], 0.1);
        assert_eq!(second.content["progress"], 0.2);
    }

    #[test]
    fn test_activity_hub_emits_on_noop_update() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = ActivityHub::new(tx);

        let op = Op::set(Path::root().key("progress"), json!(0.5));
        hub.on_activity_op("stream_noop", "progress", &op);
        hub.on_activity_op("stream_noop", "progress", &op);

        next_snapshot(&mut rx, "first event");
        let second = next_snapshot(&mut rx, "second event");
        assert_eq!(second.content["progress"], 0.5);
    }

    #[tokio::test]
    async fn test_activity_hub_concurrent_updates_merge() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let hub = Arc::new(ActivityHub::new(tx));

        let mut handles = Vec::new();
        for i in 0..5 {
            let hub = hub.clone();
            handles.push(tokio::spawn(async move {
                let op = Op::set(Path::root().key(format!("k{}", i)), json!(i));
                hub.on_activity_op("stream_concurrent", "progress", &op);
            }));
        }

        for handle in handles {
            handle.await.expect("task");
        }

        // Every op must have produced exactly one snapshot for this stream.
        let mut emitted = 0;
        while rx.try_recv().is_ok() {
            emitted += 1;
        }
        assert_eq!(emitted, 5, "one snapshot per op");

        let snapshot = hub.snapshot("stream_concurrent");
        for i in 0..5 {
            assert_eq!(snapshot[format!("k{}", i)], i);
        }
    }
}
308}