steer_core/events/
mod.rs

1use crate::app::{Message, Operation, OperationOutcome};
2use crate::config::model::ModelId;
3use crate::session::SessionInfo;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use steer_tools::ToolCall;
8use steer_tools::ToolResult;
9
10/// Token usage information
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Usage {
13    pub input_tokens: u32,
14    pub output_tokens: u32,
15}
16
/// Unified event type for external consumers.
///
/// Serialized with serde's internally tagged representation: a `"type"` field
/// holds the variant name in snake_case (e.g. `"tool_call_failed"`), next to
/// the variant's own fields. Fields marked `#[serde(default)]` may be omitted
/// by producers and deserialize to an empty map.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    // Message events
    /// A piece of content belonging to the message identified by `message_id`.
    /// NOTE(review): presumably an incremental streaming chunk — confirm
    /// against the producer.
    MessagePart {
        content: String,
        message_id: String,
    },
    /// A finished message plus the model that produced it.
    ///
    /// `usage` is left out of the serialized form when it is `None`.
    MessageComplete {
        message: Message,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
        #[serde(default)]
        metadata: HashMap<String, String>,
        model: ModelId,
    },

    // Tool events
    /// A tool call has started.
    ToolCallStarted {
        tool_call: ToolCall,
        #[serde(default)]
        metadata: HashMap<String, String>,
        model: ModelId,
    },
    /// The tool call identified by `tool_call_id` produced `result`.
    ToolCallCompleted {
        tool_call_id: String,
        result: ToolResult,
        #[serde(default)]
        metadata: HashMap<String, String>,
        model: ModelId,
    },
    /// The tool call identified by `tool_call_id` failed with `error`.
    /// Reported as an error by [`StreamEvent::is_error`].
    ToolCallFailed {
        tool_call_id: String,
        error: String,
        #[serde(default)]
        metadata: HashMap<String, String>,
        model: ModelId,
    },
    /// A tool call is waiting for approval. Unlike the other tool events it
    /// carries no `model`.
    /// NOTE(review): `timeout_ms` is presumably the approval wait limit in
    /// milliseconds (`None` = no limit) — confirm with the approval flow.
    ToolApprovalRequired {
        tool_call: ToolCall,
        timeout_ms: Option<u64>,
        #[serde(default)]
        metadata: HashMap<String, String>,
    },

    // Session events
    /// A session was created; `metadata` holds its model and creation time.
    SessionCreated {
        session_id: String,
        metadata: SessionMetadata,
    },
    /// A session was resumed.
    /// NOTE(review): `event_offset` looks like the sequence number replay
    /// continues from — confirm.
    SessionResumed {
        session_id: String,
        event_offset: u64,
    },
    /// A session was saved.
    SessionSaved {
        session_id: String,
    },

    // Operation events
    /// The operation identified by `operation_id` started.
    OperationStarted {
        operation_id: uuid::Uuid,
        operation: Operation,
    },
    /// The operation identified by `operation_id` finished with `outcome`.
    OperationCompleted {
        operation_id: uuid::Uuid,
        outcome: OperationOutcome,
    },
    /// The operation identified by `operation_id` was cancelled for `reason`.
    OperationCancelled {
        operation_id: uuid::Uuid,
        reason: String,
    },

    // System events
    /// A system-level error, categorised by `error_type`. Reported as an
    /// error by [`StreamEvent::is_error`].
    Error {
        message: String,
        error_type: ErrorType,
    },

    // Workspace events
    /// The workspace changed. Has no payload, so it serializes as just the
    /// `"type"` tag.
    WorkspaceChanged,
    /// A list of workspace file paths.
    /// NOTE(review): presumably workspace-relative paths — confirm.
    WorkspaceFiles {
        files: Vec<String>,
    },
}
102
/// Event with metadata for persistence and replay.
///
/// Wraps a [`StreamEvent`] with the session it belongs to, its position in
/// the stream and the time it was recorded. [`EventFilter::matches`] reads
/// `sequence_num` and `session_id` from this wrapper.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEventWithMetadata {
    /// Position of the event in its stream. `EventFilter::after_sequence`
    /// keeps only events whose value is strictly greater than its threshold.
    pub sequence_num: u64,
    /// When the event was wrapped (`StreamEventWithMetadata::new` uses
    /// `Utc::now()`).
    pub timestamp: DateTime<Utc>,
    /// Identifier of the session this event belongs to.
    pub session_id: String,
    /// The wrapped event payload.
    pub event: StreamEvent,
}
111
112impl StreamEventWithMetadata {
113    pub fn new(sequence_num: u64, session_id: String, event: StreamEvent) -> Self {
114        Self {
115            sequence_num,
116            timestamp: Utc::now(),
117            session_id,
118            event,
119        }
120    }
121}
122
/// Session metadata for events.
///
/// Carried by [`StreamEvent::SessionCreated`]. Can be built from a
/// [`SessionInfo`] via `From`, which substitutes a built-in default model
/// when the session has no recorded `last_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Model associated with the session.
    pub model: ModelId,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// Arbitrary string key/value pairs (the `From<&SessionInfo>` conversion
    /// copies the session's own map).
    pub metadata: HashMap<String, String>,
}
130
131impl From<&SessionInfo> for SessionMetadata {
132    fn from(session_info: &SessionInfo) -> Self {
133        Self {
134            model: session_info
135                .last_model
136                .clone()
137                .unwrap_or_else(crate::config::model::builtin::claude_sonnet_4_20250514),
138            created_at: session_info.created_at,
139            metadata: session_info.metadata.clone(),
140        }
141    }
142}
143
144/// Error types for system events
145#[derive(Debug, Clone, Serialize, Deserialize)]
146#[serde(rename_all = "snake_case")]
147pub enum ErrorType {
148    /// API error (OpenAI, Anthropic, etc.)
149    Api,
150    /// Tool execution error
151    Tool,
152    /// Session management error
153    Session,
154    /// Persistence/storage error
155    Storage,
156    /// Authentication/authorization error
157    Auth,
158    /// Network/transport error
159    Network,
160    /// Internal server error
161    Internal,
162    /// Validation error
163    Validation,
164    /// Resource limit exceeded
165    ResourceLimit,
166    /// Operation timeout
167    Timeout,
168}
169
170impl StreamEvent {
171    /// Check if this event indicates an error condition
172    pub fn is_error(&self) -> bool {
173        matches!(
174            self,
175            StreamEvent::Error { .. } | StreamEvent::ToolCallFailed { .. }
176        )
177    }
178
179    /// Get the operation ID if this event relates to an operation
180    pub fn operation_id(&self) -> Option<&uuid::Uuid> {
181        match self {
182            StreamEvent::OperationStarted { operation_id, .. }
183            | StreamEvent::OperationCompleted { operation_id, .. }
184            | StreamEvent::OperationCancelled { operation_id, .. } => Some(operation_id),
185            _ => None,
186        }
187    }
188
189    /// Get the session ID if this event relates to a session
190    pub fn session_id(&self) -> Option<&str> {
191        match self {
192            StreamEvent::SessionCreated { session_id, .. }
193            | StreamEvent::SessionResumed { session_id, .. }
194            | StreamEvent::SessionSaved { session_id } => Some(session_id),
195            _ => None,
196        }
197    }
198
199    /// Get the tool call ID if this event relates to a tool call
200    pub fn tool_call_id(&self) -> Option<&str> {
201        match self {
202            StreamEvent::ToolCallStarted { tool_call, .. } => Some(&tool_call.id),
203            StreamEvent::ToolCallCompleted { tool_call_id, .. } => Some(tool_call_id),
204            StreamEvent::ToolCallFailed { tool_call_id, .. } => Some(tool_call_id),
205            StreamEvent::ToolApprovalRequired { tool_call, .. } => Some(&tool_call.id),
206            _ => None,
207        }
208    }
209
210    /// Get the message ID if this event relates to a message
211    pub fn message_id(&self) -> Option<&str> {
212        match self {
213            StreamEvent::MessagePart { message_id, .. } => Some(message_id),
214            StreamEvent::MessageComplete { message, .. } => Some(message.id()),
215            _ => None,
216        }
217    }
218}
219
220/// Event filter for client subscriptions
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct EventFilter {
223    /// Only events matching these types
224    pub event_types: Option<Vec<String>>,
225    /// Only events after this sequence number
226    pub after_sequence: Option<u64>,
227    /// Only events for these sessions
228    pub session_ids: Option<Vec<String>>,
229    /// Only events for these operations
230    pub operation_ids: Option<Vec<String>>,
231    /// Only events for these tool calls
232    pub tool_call_ids: Option<Vec<String>>,
233}
234
235impl EventFilter {
236    /// Create an empty filter (matches all events)
237    pub fn all() -> Self {
238        Self {
239            event_types: None,
240            after_sequence: None,
241            session_ids: None,
242            operation_ids: None,
243            tool_call_ids: None,
244        }
245    }
246
247    /// Create a filter for specific event types
248    pub fn for_types(types: Vec<String>) -> Self {
249        Self {
250            event_types: Some(types),
251            after_sequence: None,
252            session_ids: None,
253            operation_ids: None,
254            tool_call_ids: None,
255        }
256    }
257
258    /// Create a filter for events after a sequence number
259    pub fn after_sequence(sequence: u64) -> Self {
260        Self {
261            event_types: None,
262            after_sequence: Some(sequence),
263            session_ids: None,
264            operation_ids: None,
265            tool_call_ids: None,
266        }
267    }
268
269    /// Create a filter for specific sessions
270    pub fn for_sessions(session_ids: Vec<String>) -> Self {
271        Self {
272            event_types: None,
273            after_sequence: None,
274            session_ids: Some(session_ids),
275            operation_ids: None,
276            tool_call_ids: None,
277        }
278    }
279
280    /// Check if an event matches this filter
281    pub fn matches(&self, event_with_metadata: &StreamEventWithMetadata) -> bool {
282        // Check sequence number
283        if let Some(after_seq) = self.after_sequence {
284            if event_with_metadata.sequence_num <= after_seq {
285                return false;
286            }
287        }
288
289        // Check session ID
290        if let Some(ref session_ids) = self.session_ids {
291            if !session_ids.contains(&event_with_metadata.session_id) {
292                return false;
293            }
294        }
295
296        // Check event type
297        if let Some(ref event_types) = self.event_types {
298            let event_type = match &event_with_metadata.event {
299                StreamEvent::MessagePart { .. } => "message_part",
300                StreamEvent::MessageComplete { .. } => "message_complete",
301                StreamEvent::ToolCallStarted { .. } => "tool_call_started",
302                StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
303                StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
304                StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
305                StreamEvent::SessionCreated { .. } => "session_created",
306                StreamEvent::SessionResumed { .. } => "session_resumed",
307                StreamEvent::SessionSaved { .. } => "session_saved",
308                StreamEvent::OperationStarted { .. } => "operation_started",
309                StreamEvent::OperationCompleted { .. } => "operation_completed",
310                StreamEvent::OperationCancelled { .. } => "operation_cancelled",
311                StreamEvent::Error { .. } => "error",
312                StreamEvent::WorkspaceChanged => "workspace_changed",
313                StreamEvent::WorkspaceFiles { .. } => "workspace_files",
314            };
315            if !event_types.contains(&event_type.to_string()) {
316                return false;
317            }
318        }
319
320        // Check operation ID
321        if let Some(ref operation_ids) = self.operation_ids {
322            if let Some(op_id) = event_with_metadata.event.operation_id() {
323                if !operation_ids.contains(&op_id.to_string()) {
324                    return false;
325                }
326            } else {
327                return false;
328            }
329        }
330
331        // Check tool call ID
332        if let Some(ref tool_call_ids) = self.tool_call_ids {
333            if let Some(tool_id) = event_with_metadata.event.tool_call_id() {
334                if !tool_call_ids.contains(&tool_id.to_string()) {
335                    return false;
336                }
337            } else {
338                return false;
339            }
340        }
341
342        true
343    }
344}
345
#[cfg(test)]
mod tests {
    // Unit tests for event serialization, id accessors and `EventFilter`.
    use super::*;
    use crate::app::conversation::AssistantContent;
    use crate::app::{Message, MessageData};

    #[test]
    fn test_stream_event_serialization() {
        let event = StreamEvent::ToolCallFailed {
            tool_call_id: "tool_123".to_string(),
            error: "Failed to execute".to_string(),
            metadata: HashMap::new(),
            model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
        };

        let serialized = serde_json::to_string(&event).unwrap();
        let deserialized: StreamEvent = serde_json::from_str(&serialized).unwrap();

        assert!(matches!(deserialized, StreamEvent::ToolCallFailed { .. }));
        match deserialized {
            StreamEvent::ToolCallFailed {
                tool_call_id,
                error,
                ..
            } => {
                assert_eq!(tool_call_id, "tool_123");
                assert_eq!(error, "Failed to execute");
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_event_with_metadata() {
        let event = StreamEvent::MessagePart {
            message_id: "msg_123".to_string(),
            content: "Hello".to_string(),
        };
        let event_with_metadata = StreamEventWithMetadata::new(1, "session_123".to_string(), event);

        assert_eq!(event_with_metadata.sequence_num, 1);
        assert_eq!(event_with_metadata.session_id, "session_123");
        assert!(event_with_metadata.timestamp <= Utc::now());
    }

    #[test]
    fn test_event_type_checks() {
        let error_event = StreamEvent::Error {
            message: "Test error".to_string(),
            error_type: ErrorType::Api,
        };
        assert!(error_event.is_error());

        let tool_failed = StreamEvent::ToolCallFailed {
            tool_call_id: "tool_123".to_string(),
            error: "Command failed".to_string(),
            metadata: HashMap::new(),
            model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
        };
        assert!(tool_failed.is_error());

        let tool_approval = StreamEvent::ToolApprovalRequired {
            tool_call: ToolCall {
                id: "tool_123".to_string(),
                name: "edit_file".to_string(),
                parameters: serde_json::json!({}),
            },
            timeout_ms: None,
            metadata: HashMap::new(),
        };
        assert!(!tool_approval.is_error());
    }

    #[test]
    fn test_event_id_extraction() {
        let tool_event = StreamEvent::ToolCallFailed {
            tool_call_id: "tool_123".to_string(),
            error: "Failed".to_string(),
            metadata: HashMap::new(),
            model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
        };
        assert_eq!(tool_event.tool_call_id(), Some("tool_123"));

        let message_event = StreamEvent::MessagePart {
            message_id: "msg_123".to_string(),
            content: "Hello".to_string(),
        };
        assert_eq!(message_event.message_id(), Some("msg_123"));

        let op_id = uuid::Uuid::new_v4();
        let operation_event = StreamEvent::OperationStarted {
            operation_id: op_id,
            operation: crate::app::Operation::Bash {
                cmd: "echo hello".to_string(),
            },
        };
        assert_eq!(operation_event.operation_id(), Some(&op_id));

        let session_event = StreamEvent::SessionCreated {
            session_id: "session_123".to_string(),
            metadata: SessionMetadata {
                model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
                created_at: Utc::now(),
                metadata: HashMap::new(),
            },
        };
        assert_eq!(session_event.session_id(), Some("session_123"));
    }

    #[test]
    fn test_event_filter() {
        let event = StreamEvent::ToolCallFailed {
            tool_call_id: "tool_123".to_string(),
            error: "Failed".to_string(),
            metadata: HashMap::new(),
            model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
        };
        let event_with_metadata = StreamEventWithMetadata::new(5, "session_123".to_string(), event);

        // Test sequence filter
        let after_filter = EventFilter::after_sequence(3);
        assert!(after_filter.matches(&event_with_metadata));

        let before_filter = EventFilter::after_sequence(5);
        assert!(!before_filter.matches(&event_with_metadata));

        // Test session filter
        let session_filter = EventFilter::for_sessions(vec!["session_123".to_string()]);
        assert!(session_filter.matches(&event_with_metadata));

        let wrong_session_filter = EventFilter::for_sessions(vec!["session_456".to_string()]);
        assert!(!wrong_session_filter.matches(&event_with_metadata));

        // Test type filter
        let type_filter = EventFilter::for_types(vec!["tool_call_failed".to_string()]);
        assert!(type_filter.matches(&event_with_metadata));

        let wrong_type_filter = EventFilter::for_types(vec!["message_part".to_string()]);
        assert!(!wrong_type_filter.matches(&event_with_metadata));
    }

    #[test]
    fn test_event_filter_by_operation_and_tool_call_ids() {
        let op_id = uuid::Uuid::new_v4();
        let operation_event = StreamEventWithMetadata::new(
            1,
            "session_123".to_string(),
            StreamEvent::OperationCancelled {
                operation_id: op_id,
                reason: "user requested".to_string(),
            },
        );
        let tool_event = StreamEventWithMetadata::new(
            2,
            "session_123".to_string(),
            StreamEvent::ToolCallFailed {
                tool_call_id: "tool_123".to_string(),
                error: "Failed".to_string(),
                metadata: HashMap::new(),
                model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
            },
        );

        // An empty filter lets everything through.
        assert!(EventFilter::all().matches(&operation_event));
        assert!(EventFilter::all().matches(&tool_event));

        // Operation-id filter: the matching id passes; other ids and events
        // without an operation id are rejected.
        let op_filter = EventFilter {
            operation_ids: Some(vec![op_id.to_string()]),
            ..EventFilter::all()
        };
        assert!(op_filter.matches(&operation_event));
        assert!(!op_filter.matches(&tool_event));

        let other_op_filter = EventFilter {
            operation_ids: Some(vec![uuid::Uuid::nil().to_string()]),
            ..EventFilter::all()
        };
        assert!(!other_op_filter.matches(&operation_event));

        // Tool-call-id filter: the matching id passes; other ids and events
        // without a tool call id are rejected.
        let tool_filter = EventFilter {
            tool_call_ids: Some(vec!["tool_123".to_string()]),
            ..EventFilter::all()
        };
        assert!(tool_filter.matches(&tool_event));
        assert!(!tool_filter.matches(&operation_event));

        let other_tool_filter = EventFilter {
            tool_call_ids: Some(vec!["tool_456".to_string()]),
            ..EventFilter::all()
        };
        assert!(!other_tool_filter.matches(&tool_event));
    }

    #[test]
    fn test_message_complete_event() {
        let message = Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: "Hello world".to_string(),
                }],
            },
            timestamp: 0,
            id: "msg_123".to_string(),
            parent_message_id: None,
        };

        let event = StreamEvent::MessageComplete {
            message,
            usage: None,
            metadata: HashMap::new(),
            model: crate::config::model::builtin::claude_3_5_sonnet_20241022(),
        };
        assert_eq!(event.message_id(), Some("msg_123"));
    }
}