1use crate::app::{Message, Operation, OperationOutcome};
2use crate::config::model::ModelId;
3use crate::session::SessionInfo;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use steer_tools::ToolCall;
8use steer_tools::ToolResult;
9
/// Token usage counts, carried optionally by [`StreamEvent::MessageComplete`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
}
16
/// An event emitted on a session's event stream.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is written to a `type` field next to the variant's own fields,
/// e.g. `{"type": "message_part", "content": "...", "message_id": "..."}`.
/// Unit variants such as `WorkspaceChanged` serialize as just the tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// A piece of content for the message identified by `message_id`.
    MessagePart {
        /// Content of this part.
        content: String,
        /// Id of the message this part belongs to.
        message_id: String,
    },
    /// A complete message, with optional token usage and its associated model.
    MessageComplete {
        /// The completed message.
        message: Message,
        /// Token usage, if known. Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
        /// Free-form string metadata; deserializes to an empty map when absent.
        #[serde(default)]
        metadata: HashMap<String, String>,
        /// Model associated with this message.
        model: ModelId,
    },

    /// A tool call has started.
    ToolCallStarted {
        /// The tool call being executed (its `id` is what `tool_call_id()` returns).
        tool_call: ToolCall,
        /// Free-form string metadata; deserializes to an empty map when absent.
        #[serde(default)]
        metadata: HashMap<String, String>,
        /// Model associated with this tool call.
        model: ModelId,
    },
    /// A tool call finished and produced `result`.
    ToolCallCompleted {
        /// Id of the tool call that completed.
        tool_call_id: String,
        /// Result returned by the tool.
        result: ToolResult,
        /// Free-form string metadata; deserializes to an empty map when absent.
        #[serde(default)]
        metadata: HashMap<String, String>,
        /// Model associated with this tool call.
        model: ModelId,
    },
    /// A tool call failed. Counted as an error by `StreamEvent::is_error`.
    ToolCallFailed {
        /// Id of the tool call that failed.
        tool_call_id: String,
        /// Error description.
        error: String,
        /// Free-form string metadata; deserializes to an empty map when absent.
        #[serde(default)]
        metadata: HashMap<String, String>,
        /// Model associated with this tool call.
        model: ModelId,
    },
    /// A tool call needs approval. Unlike the other tool variants it carries no `model`.
    ToolApprovalRequired {
        /// The tool call awaiting approval.
        tool_call: ToolCall,
        /// Optional timeout; the `_ms` suffix suggests milliseconds (confirm with producer).
        timeout_ms: Option<u64>,
        /// Free-form string metadata; deserializes to an empty map when absent.
        #[serde(default)]
        metadata: HashMap<String, String>,
    },

    /// A session was created.
    SessionCreated {
        /// Id of the new session.
        session_id: String,
        /// Summary metadata for the new session.
        metadata: SessionMetadata,
    },
    /// A session was resumed.
    SessionResumed {
        /// Id of the resumed session.
        session_id: String,
        /// Event offset at which the session resumed (presumably a sequence
        /// position; confirm against the producer).
        event_offset: u64,
    },
    /// A session was saved.
    SessionSaved {
        /// Id of the saved session.
        session_id: String,
    },

    /// An operation started.
    OperationStarted {
        /// Id of the operation.
        operation_id: uuid::Uuid,
        /// The operation that started.
        operation: Operation,
    },
    /// An operation completed with `outcome`.
    OperationCompleted {
        /// Id of the operation.
        operation_id: uuid::Uuid,
        /// Outcome of the operation.
        outcome: OperationOutcome,
    },
    /// An operation was cancelled.
    OperationCancelled {
        /// Id of the operation.
        operation_id: uuid::Uuid,
        /// Human-readable cancellation reason.
        reason: String,
    },

    /// A general error, categorized by `error_type`.
    Error {
        /// Error description.
        message: String,
        /// Category of the error.
        error_type: ErrorType,
    },

    /// The workspace changed; carries no payload.
    WorkspaceChanged,
    /// A list of workspace files.
    WorkspaceFiles {
        /// File names/paths (format not visible here; presumably workspace-relative — confirm).
        files: Vec<String>,
    },
}
102
/// A [`StreamEvent`] paired with its sequence number, timestamp, and session id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEventWithMetadata {
    /// Sequence number of the event; `EventFilter::after_sequence` filters on it.
    pub sequence_num: u64,
    /// Time the wrapper was created (`StreamEventWithMetadata::new` uses `Utc::now()`).
    pub timestamp: DateTime<Utc>,
    /// Id of the session the event belongs to; `EventFilter::for_sessions` filters on it.
    pub session_id: String,
    /// The wrapped event.
    pub event: StreamEvent,
}
111
112impl StreamEventWithMetadata {
113 pub fn new(sequence_num: u64, session_id: String, event: StreamEvent) -> Self {
114 Self {
115 sequence_num,
116 timestamp: Utc::now(),
117 session_id,
118 event,
119 }
120 }
121}
122
/// Summary metadata for a session, carried by [`StreamEvent::SessionCreated`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Model for the session (see `From<&SessionInfo>` for the fallback used).
    pub model: ModelId,
    /// Session creation time, in UTC.
    pub created_at: DateTime<Utc>,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
}
130
131impl From<&SessionInfo> for SessionMetadata {
132 fn from(session_info: &SessionInfo) -> Self {
133 Self {
134 model: session_info
135 .last_model
136 .clone()
137 .unwrap_or_else(crate::config::model::builtin::claude_sonnet_4_20250514),
138 created_at: session_info.created_at,
139 metadata: session_info.metadata.clone(),
140 }
141 }
142}
143
/// Category of failure carried by [`StreamEvent::Error`].
///
/// Serialized in snake_case, e.g. `ResourceLimit` becomes `"resource_limit"`.
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm against the code that produces these errors.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorType {
    /// Failure reported by an API call (presumably the model provider).
    Api,
    /// Failure related to tool execution.
    Tool,
    /// Failure related to session handling.
    Session,
    /// Failure in the storage layer.
    Storage,
    /// Authentication/authorization failure.
    Auth,
    /// Network failure.
    Network,
    /// Internal failure.
    Internal,
    /// Input validation failure.
    Validation,
    /// A resource limit was hit.
    ResourceLimit,
    /// Something timed out.
    Timeout,
}
169
170impl StreamEvent {
171 pub fn is_error(&self) -> bool {
173 matches!(
174 self,
175 StreamEvent::Error { .. } | StreamEvent::ToolCallFailed { .. }
176 )
177 }
178
179 pub fn operation_id(&self) -> Option<&uuid::Uuid> {
181 match self {
182 StreamEvent::OperationStarted { operation_id, .. }
183 | StreamEvent::OperationCompleted { operation_id, .. }
184 | StreamEvent::OperationCancelled { operation_id, .. } => Some(operation_id),
185 _ => None,
186 }
187 }
188
189 pub fn session_id(&self) -> Option<&str> {
191 match self {
192 StreamEvent::SessionCreated { session_id, .. }
193 | StreamEvent::SessionResumed { session_id, .. }
194 | StreamEvent::SessionSaved { session_id } => Some(session_id),
195 _ => None,
196 }
197 }
198
199 pub fn tool_call_id(&self) -> Option<&str> {
201 match self {
202 StreamEvent::ToolCallStarted { tool_call, .. } => Some(&tool_call.id),
203 StreamEvent::ToolCallCompleted { tool_call_id, .. } => Some(tool_call_id),
204 StreamEvent::ToolCallFailed { tool_call_id, .. } => Some(tool_call_id),
205 StreamEvent::ToolApprovalRequired { tool_call, .. } => Some(&tool_call.id),
206 _ => None,
207 }
208 }
209
210 pub fn message_id(&self) -> Option<&str> {
212 match self {
213 StreamEvent::MessagePart { message_id, .. } => Some(message_id),
214 StreamEvent::MessageComplete { message, .. } => Some(message.id()),
215 _ => None,
216 }
217 }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct EventFilter {
223 pub event_types: Option<Vec<String>>,
225 pub after_sequence: Option<u64>,
227 pub session_ids: Option<Vec<String>>,
229 pub operation_ids: Option<Vec<String>>,
231 pub tool_call_ids: Option<Vec<String>>,
233}
234
235impl EventFilter {
236 pub fn all() -> Self {
238 Self {
239 event_types: None,
240 after_sequence: None,
241 session_ids: None,
242 operation_ids: None,
243 tool_call_ids: None,
244 }
245 }
246
247 pub fn for_types(types: Vec<String>) -> Self {
249 Self {
250 event_types: Some(types),
251 after_sequence: None,
252 session_ids: None,
253 operation_ids: None,
254 tool_call_ids: None,
255 }
256 }
257
258 pub fn after_sequence(sequence: u64) -> Self {
260 Self {
261 event_types: None,
262 after_sequence: Some(sequence),
263 session_ids: None,
264 operation_ids: None,
265 tool_call_ids: None,
266 }
267 }
268
269 pub fn for_sessions(session_ids: Vec<String>) -> Self {
271 Self {
272 event_types: None,
273 after_sequence: None,
274 session_ids: Some(session_ids),
275 operation_ids: None,
276 tool_call_ids: None,
277 }
278 }
279
280 pub fn matches(&self, event_with_metadata: &StreamEventWithMetadata) -> bool {
282 if let Some(after_seq) = self.after_sequence {
284 if event_with_metadata.sequence_num <= after_seq {
285 return false;
286 }
287 }
288
289 if let Some(ref session_ids) = self.session_ids {
291 if !session_ids.contains(&event_with_metadata.session_id) {
292 return false;
293 }
294 }
295
296 if let Some(ref event_types) = self.event_types {
298 let event_type = match &event_with_metadata.event {
299 StreamEvent::MessagePart { .. } => "message_part",
300 StreamEvent::MessageComplete { .. } => "message_complete",
301 StreamEvent::ToolCallStarted { .. } => "tool_call_started",
302 StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
303 StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
304 StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
305 StreamEvent::SessionCreated { .. } => "session_created",
306 StreamEvent::SessionResumed { .. } => "session_resumed",
307 StreamEvent::SessionSaved { .. } => "session_saved",
308 StreamEvent::OperationStarted { .. } => "operation_started",
309 StreamEvent::OperationCompleted { .. } => "operation_completed",
310 StreamEvent::OperationCancelled { .. } => "operation_cancelled",
311 StreamEvent::Error { .. } => "error",
312 StreamEvent::WorkspaceChanged => "workspace_changed",
313 StreamEvent::WorkspaceFiles { .. } => "workspace_files",
314 };
315 if !event_types.contains(&event_type.to_string()) {
316 return false;
317 }
318 }
319
320 if let Some(ref operation_ids) = self.operation_ids {
322 if let Some(op_id) = event_with_metadata.event.operation_id() {
323 if !operation_ids.contains(&op_id.to_string()) {
324 return false;
325 }
326 } else {
327 return false;
328 }
329 }
330
331 if let Some(ref tool_call_ids) = self.tool_call_ids {
333 if let Some(tool_id) = event_with_metadata.event.tool_call_id() {
334 if !tool_call_ids.contains(&tool_id.to_string()) {
335 return false;
336 }
337 } else {
338 return false;
339 }
340 }
341
342 true
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::AssistantContent;
    use crate::app::{Message, MessageData};

    /// Model id used by fixtures; the specific model is irrelevant to these tests.
    fn test_model() -> ModelId {
        crate::config::model::builtin::claude_3_5_sonnet_20241022()
    }

    /// Builds a `ToolCallFailed` event with empty metadata.
    fn tool_failed_event(tool_call_id: &str, error: &str) -> StreamEvent {
        StreamEvent::ToolCallFailed {
            tool_call_id: tool_call_id.to_string(),
            error: error.to_string(),
            metadata: HashMap::new(),
            model: test_model(),
        }
    }

    #[test]
    fn test_stream_event_serialization() {
        let event = tool_failed_event("tool_123", "Failed to execute");

        let serialized = serde_json::to_string(&event).unwrap();
        let deserialized: StreamEvent = serde_json::from_str(&serialized).unwrap();

        match deserialized {
            StreamEvent::ToolCallFailed {
                tool_call_id,
                error,
                ..
            } => {
                assert_eq!(tool_call_id, "tool_123");
                assert_eq!(error, "Failed to execute");
            }
            other => panic!("expected ToolCallFailed, got {other:?}"),
        }
    }

    #[test]
    fn test_event_with_metadata() {
        let event = StreamEvent::MessagePart {
            message_id: "msg_123".to_string(),
            content: "Hello".to_string(),
        };
        let event_with_metadata = StreamEventWithMetadata::new(1, "session_123".to_string(), event);

        assert_eq!(event_with_metadata.sequence_num, 1);
        assert_eq!(event_with_metadata.session_id, "session_123");
        assert!(event_with_metadata.timestamp <= Utc::now());
    }

    #[test]
    fn test_event_type_checks() {
        let error_event = StreamEvent::Error {
            message: "Test error".to_string(),
            error_type: ErrorType::Api,
        };
        assert!(error_event.is_error());

        let tool_failed = tool_failed_event("tool_123", "Command failed");
        assert!(tool_failed.is_error());

        let tool_approval = StreamEvent::ToolApprovalRequired {
            tool_call: ToolCall {
                id: "tool_123".to_string(),
                name: "edit_file".to_string(),
                parameters: serde_json::json!({}),
            },
            timeout_ms: None,
            metadata: HashMap::new(),
        };
        assert!(!tool_approval.is_error());
    }

    #[test]
    fn test_event_id_extraction() {
        let tool_event = tool_failed_event("tool_123", "Failed");
        assert_eq!(tool_event.tool_call_id(), Some("tool_123"));

        let message_event = StreamEvent::MessagePart {
            message_id: "msg_123".to_string(),
            content: "Hello".to_string(),
        };
        assert_eq!(message_event.message_id(), Some("msg_123"));

        let op_id = uuid::Uuid::new_v4();
        let operation_event = StreamEvent::OperationStarted {
            operation_id: op_id,
            operation: crate::app::Operation::Bash {
                cmd: "echo hello".to_string(),
            },
        };
        assert_eq!(operation_event.operation_id(), Some(&op_id));

        let session_event = StreamEvent::SessionCreated {
            session_id: "session_123".to_string(),
            metadata: SessionMetadata {
                model: test_model(),
                created_at: Utc::now(),
                metadata: HashMap::new(),
            },
        };
        assert_eq!(session_event.session_id(), Some("session_123"));
    }

    #[test]
    fn test_event_filter() {
        let event = tool_failed_event("tool_123", "Failed");
        let event_with_metadata = StreamEventWithMetadata::new(5, "session_123".to_string(), event);

        let after_filter = EventFilter::after_sequence(3);
        assert!(after_filter.matches(&event_with_metadata));

        // `after_sequence` is exclusive: an event at exactly that sequence is rejected.
        let before_filter = EventFilter::after_sequence(5);
        assert!(!before_filter.matches(&event_with_metadata));

        let session_filter = EventFilter::for_sessions(vec!["session_123".to_string()]);
        assert!(session_filter.matches(&event_with_metadata));

        let wrong_session_filter = EventFilter::for_sessions(vec!["session_456".to_string()]);
        assert!(!wrong_session_filter.matches(&event_with_metadata));

        let type_filter = EventFilter::for_types(vec!["tool_call_failed".to_string()]);
        assert!(type_filter.matches(&event_with_metadata));

        let wrong_type_filter = EventFilter::for_types(vec!["message_part".to_string()]);
        assert!(!wrong_type_filter.matches(&event_with_metadata));
    }

    #[test]
    fn test_filter_all_matches_everything() {
        let event = StreamEventWithMetadata::new(
            0,
            "any_session".to_string(),
            StreamEvent::WorkspaceChanged,
        );
        assert!(EventFilter::all().matches(&event));
    }

    #[test]
    fn test_operation_id_filter() {
        let op_id = uuid::Uuid::new_v4();
        let operation_event = StreamEventWithMetadata::new(
            1,
            "session_123".to_string(),
            StreamEvent::OperationCancelled {
                operation_id: op_id,
                reason: "user request".to_string(),
            },
        );

        let filter = EventFilter {
            operation_ids: Some(vec![op_id.to_string()]),
            ..EventFilter::all()
        };
        assert!(filter.matches(&operation_event));

        let other_filter = EventFilter {
            operation_ids: Some(vec![uuid::Uuid::new_v4().to_string()]),
            ..EventFilter::all()
        };
        assert!(!other_filter.matches(&operation_event));

        // Events without an operation id never pass an operation-id filter.
        let unrelated = StreamEventWithMetadata::new(
            2,
            "session_123".to_string(),
            StreamEvent::WorkspaceChanged,
        );
        assert!(!filter.matches(&unrelated));
    }

    #[test]
    fn test_tool_call_id_filter() {
        let event = StreamEventWithMetadata::new(
            1,
            "session_123".to_string(),
            tool_failed_event("tool_123", "Failed"),
        );

        let filter = EventFilter {
            tool_call_ids: Some(vec!["tool_123".to_string()]),
            ..EventFilter::all()
        };
        assert!(filter.matches(&event));

        let wrong_filter = EventFilter {
            tool_call_ids: Some(vec!["tool_456".to_string()]),
            ..EventFilter::all()
        };
        assert!(!wrong_filter.matches(&event));

        // Events without a tool call id never pass a tool-call-id filter.
        let unrelated = StreamEventWithMetadata::new(
            2,
            "session_123".to_string(),
            StreamEvent::WorkspaceChanged,
        );
        assert!(!filter.matches(&unrelated));
    }

    #[test]
    fn test_event_type_filter_names_match_serde_tags() {
        // The filter's type names are hard-coded; make sure they agree with the
        // `type` tag serde actually writes for each variant.
        let events = [
            StreamEvent::WorkspaceChanged,
            StreamEvent::WorkspaceFiles {
                files: vec!["a.rs".to_string()],
            },
            StreamEvent::SessionSaved {
                session_id: "session_123".to_string(),
            },
            StreamEvent::SessionResumed {
                session_id: "session_123".to_string(),
                event_offset: 7,
            },
            StreamEvent::OperationCancelled {
                operation_id: uuid::Uuid::new_v4(),
                reason: "stop".to_string(),
            },
            StreamEvent::Error {
                message: "boom".to_string(),
                error_type: ErrorType::Network,
            },
            StreamEvent::MessagePart {
                message_id: "msg_123".to_string(),
                content: "Hello".to_string(),
            },
            tool_failed_event("tool_123", "Failed"),
        ];

        for (seq, event) in (0u64..).zip(events) {
            let tag = serde_json::to_value(&event).unwrap()["type"]
                .as_str()
                .expect("internally tagged events serialize a string `type` field")
                .to_string();
            let with_meta = StreamEventWithMetadata::new(seq, "session_123".to_string(), event);
            assert!(
                EventFilter::for_types(vec![tag.clone()]).matches(&with_meta),
                "type filter rejected its own serde tag `{tag}`"
            );
        }
    }

    #[test]
    fn test_message_complete_event() {
        let message = Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: "Hello world".to_string(),
                }],
            },
            timestamp: 0,
            id: "msg_123".to_string(),
            parent_message_id: None,
        };

        let event = StreamEvent::MessageComplete {
            message,
            usage: None,
            metadata: HashMap::new(),
            model: test_model(),
        };
        assert_eq!(event.message_id(), Some("msg_123"));
    }
}