1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use futures::{Stream, StreamExt, stream};
11use std::collections::VecDeque;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use crate::{
17 error::Result,
18 event::{AgentEvent, EventStream, HookManager, converter},
19 log::LogNormalizer,
20 types::{ExecutorType, ExitStatus},
21};
22
23pub type RawLogStream = Pin<Box<dyn Stream<Item = Vec<u8>> + Send>>;
27
28#[derive(Debug, Clone)]
30pub struct SessionMetadata {
31 pub session_id: String,
33 pub executor_type: ExecutorType,
35 pub created_at: DateTime<Utc>,
37 pub last_message_id: Option<String>,
39 pub working_dir: PathBuf,
41 pub context_window_override_tokens: Option<u32>,
43}
44
/// Request to resume a previously created session.
///
/// Both fields are plain strings, so the type also derives `PartialEq`/`Eq`,
/// which lets callers and tests compare resume requests directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionResume {
    /// Identifier of the session to resume.
    pub session_id: String,
    /// Optional message identifier to reset the session to.
    // NOTE(review): exact reset semantics are defined by the consumer of this
    // struct, which is not visible in this file — confirm before relying on it.
    pub reset_to_message: Option<String>,
}
53
54pub struct AgentSession {
60 pub session_id: String,
62 pub executor_type: ExecutorType,
64 pub working_dir: PathBuf,
66 pub created_at: DateTime<Utc>,
68 pub last_message_id: Option<String>,
70 pub context_window_override_tokens: Option<u32>,
72 lifecycle_controller: SessionControllerRef,
73}
74
75impl AgentSession {
76 pub fn new(
81 session_id: impl Into<String>,
82 executor_type: ExecutorType,
83 working_dir: impl Into<PathBuf>,
84 context_window_override_tokens: Option<u32>,
85 ) -> Self {
86 Self::from_parts(
87 SessionMetadata {
88 session_id: session_id.into(),
89 executor_type,
90 created_at: Utc::now(),
91 last_message_id: None,
92 working_dir: working_dir.into(),
93 context_window_override_tokens,
94 },
95 Arc::new(DetachedSessionLifecycleController),
96 )
97 }
98
99 pub fn from_metadata(metadata: SessionMetadata) -> Self {
101 Self::from_parts(metadata, Arc::new(DetachedSessionLifecycleController))
102 }
103
104 pub(crate) fn from_metadata_with_exit_status(
105 metadata: SessionMetadata,
106 exit_status: ExitStatus,
107 ) -> Self {
108 Self::from_parts(
109 metadata,
110 Arc::new(CompletedSessionLifecycleController { exit_status }),
111 )
112 }
113
114 fn from_parts(metadata: SessionMetadata, lifecycle_controller: SessionControllerRef) -> Self {
115 Self {
116 session_id: metadata.session_id,
117 executor_type: metadata.executor_type,
118 working_dir: metadata.working_dir,
119 created_at: metadata.created_at,
120 last_message_id: metadata.last_message_id,
121 context_window_override_tokens: metadata.context_window_override_tokens,
122 lifecycle_controller,
123 }
124 }
125
126 pub fn event_stream(
149 &self,
150 raw_logs: RawLogStream,
151 normalizer: Box<dyn LogNormalizer + Send>,
152 hooks: Option<Arc<HookManager>>,
153 ) -> EventStream {
154 let state = EventPipelineState {
155 session_id: self.session_id.clone(),
156 raw_logs,
157 normalizer,
158 hooks,
159 pending_events: VecDeque::new(),
160 emitted_started: false,
161 finished: false,
162 saw_error: false,
163 context_window_override_tokens: self.context_window_override_tokens,
164 };
165
166 let stream = stream::unfold(state, |mut state| async move {
167 loop {
168 if let Some(event) = state.pending_events.pop_front() {
169 if let Some(hook_manager) = &state.hooks {
170 hook_manager.trigger(&event).await;
171 }
172 return Some((event, state));
173 }
174
175 if !state.emitted_started {
176 state.emitted_started = true;
177 state.push_event(AgentEvent::SessionStarted {
178 session_id: state.session_id.clone(),
179 });
180 continue;
181 }
182
183 if state.finished {
184 return None;
185 }
186
187 match state.raw_logs.next().await {
188 Some(chunk) => {
189 let logs = state.normalizer.normalize(&chunk);
190 state.push_logs(logs);
191 }
192 None => {
193 let logs = state.normalizer.flush();
194 state.push_logs(logs);
195 state.push_event(AgentEvent::SessionCompleted {
196 exit_status: ExitStatus {
197 code: None,
198 success: !state.saw_error,
199 },
200 });
201 state.finished = true;
202 }
203 }
204 }
205 });
206
207 EventStream::new(Box::pin(stream))
208 }
209
210 pub fn metadata(&self) -> SessionMetadata {
212 SessionMetadata {
213 session_id: self.session_id.clone(),
214 executor_type: self.executor_type,
215 created_at: self.created_at,
216 last_message_id: self.last_message_id.clone(),
217 working_dir: self.working_dir.clone(),
218 context_window_override_tokens: self.context_window_override_tokens,
219 }
220 }
221
222 pub async fn wait(&mut self) -> Result<ExitStatus> {
227 self.lifecycle_controller.wait().await
228 }
229
230 pub async fn cancel(&mut self) -> Result<()> {
234 self.lifecycle_controller.cancel().await
235 }
236}
237
238#[async_trait]
239pub(crate) trait SessionLifecycleController: Send + Sync {
240 async fn wait(&self) -> Result<ExitStatus>;
241 async fn cancel(&self) -> Result<()>;
242}
243
244struct DetachedSessionLifecycleController;
245
246#[async_trait]
247impl SessionLifecycleController for DetachedSessionLifecycleController {
248 async fn wait(&self) -> Result<ExitStatus> {
249 Ok(ExitStatus {
250 code: None,
251 success: true,
252 })
253 }
254
255 async fn cancel(&self) -> Result<()> {
256 Ok(())
257 }
258}
259
260struct CompletedSessionLifecycleController {
261 exit_status: ExitStatus,
262}
263
264#[async_trait]
265impl SessionLifecycleController for CompletedSessionLifecycleController {
266 async fn wait(&self) -> Result<ExitStatus> {
267 Ok(self.exit_status)
268 }
269
270 async fn cancel(&self) -> Result<()> {
271 Ok(())
272 }
273}
274
275type SessionControllerRef = Arc<dyn SessionLifecycleController>;
276
277struct EventPipelineState {
278 session_id: String,
279 raw_logs: RawLogStream,
280 normalizer: Box<dyn LogNormalizer + Send>,
281 hooks: Option<Arc<HookManager>>,
282 pending_events: VecDeque<AgentEvent>,
283 emitted_started: bool,
284 finished: bool,
285 saw_error: bool,
286 context_window_override_tokens: Option<u32>,
287}
288
289impl EventPipelineState {
290 fn push_logs(&mut self, logs: Vec<crate::log::NormalizedLog>) {
291 for log in logs {
292 for event in converter::from_normalized_log_with_context_override(
293 log,
294 self.context_window_override_tokens,
295 ) {
296 self.push_event(event);
297 }
298 }
299 }
300
301 fn push_event(&mut self, event: AgentEvent) {
302 if matches!(event, AgentEvent::ErrorOccurred { .. }) {
303 self.saw_error = true;
304 }
305 self.pending_events.push_back(event);
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tokio::sync::Mutex;

    use crate::{
        event::EventType,
        log::{ActionType, NormalizedLog},
        types::{ContextUsageSource, Role, ToolStatus},
    };

    // ---- fixtures -------------------------------------------------------

    /// Wraps owned chunks in the boxed stream type `event_stream` expects.
    fn raw_logs_from(chunks: Vec<Vec<u8>>) -> RawLogStream {
        Box::pin(stream::iter(chunks))
    }

    /// Runs the session's event pipeline to completion and returns every event.
    async fn collect_events(
        session: &AgentSession,
        raw_logs: RawLogStream,
        normalizer: Box<dyn LogNormalizer + Send>,
        hooks: Option<Arc<HookManager>>,
    ) -> Vec<AgentEvent> {
        session
            .event_stream(raw_logs, normalizer, hooks)
            .collect::<Vec<_>>()
            .await
    }

    /// Metadata for a Claude Code session rooted at `.` with no override.
    fn claude_metadata(session_id: &str) -> SessionMetadata {
        SessionMetadata {
            session_id: session_id.to_string(),
            executor_type: ExecutorType::ClaudeCode,
            created_at: Utc::now(),
            last_message_id: None,
            working_dir: PathBuf::from("."),
            context_window_override_tokens: None,
        }
    }

    /// A `bash` tool call running `ls`, in the requested state.
    fn bash_ls_tool_call(status: ToolStatus) -> NormalizedLog {
        NormalizedLog::ToolCall {
            name: "bash".to_string(),
            args: json!({"cmd":"ls"}),
            status,
            action: ActionType::CommandRun {
                command: "ls".to_string(),
            },
        }
    }

    /// Maps a handful of well-known chunks to one log each; flushes a token usage
    /// report of 10 out of 100.
    struct TestNormalizer;

    impl LogNormalizer for TestNormalizer {
        fn normalize(&mut self, chunk: &[u8]) -> Vec<NormalizedLog> {
            let log = match chunk {
                b"message" => NormalizedLog::Message {
                    role: Role::Assistant,
                    content: "hello".to_string(),
                },
                b"tool-start" => bash_ls_tool_call(ToolStatus::Started),
                b"tool-done" => bash_ls_tool_call(ToolStatus::Completed),
                b"error" => NormalizedLog::Error {
                    error_type: "execution_failed".to_string(),
                    message: "boom".to_string(),
                },
                _ => return Vec::new(),
            };
            vec![log]
        }

        fn flush(&mut self) -> Vec<NormalizedLog> {
            vec![NormalizedLog::TokenUsage {
                total: 10,
                limit: 100,
            }]
        }
    }

    /// Ignores every chunk; flushes a token usage report with a limit of 0.
    struct UnknownLimitNormalizer;

    impl LogNormalizer for UnknownLimitNormalizer {
        fn normalize(&mut self, _chunk: &[u8]) -> Vec<NormalizedLog> {
            Vec::new()
        }

        fn flush(&mut self) -> Vec<NormalizedLog> {
            vec![NormalizedLog::TokenUsage {
                total: 15,
                limit: 0,
            }]
        }
    }

    /// Controller that records whether `cancel` was called.
    struct CancelProbeController {
        cancelled: Arc<AtomicBool>,
    }

    #[async_trait]
    impl SessionLifecycleController for CancelProbeController {
        async fn wait(&self) -> Result<ExitStatus> {
            Ok(ExitStatus {
                code: None,
                success: true,
            })
        }

        async fn cancel(&self) -> Result<()> {
            self.cancelled.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    // ---- event stream ---------------------------------------------------

    #[tokio::test]
    async fn session_event_stream_builds_pipeline_and_triggers_hooks() {
        let session = AgentSession::new("session-1", ExecutorType::Codex, PathBuf::from("."), None);

        // The hook copies every received message body into this shared list.
        let seen_by_hook = Arc::new(Mutex::new(Vec::<String>::new()));
        let hook_manager = Arc::new(HookManager::new());
        hook_manager.register(
            EventType::MessageReceived,
            Arc::new({
                let sink = Arc::clone(&seen_by_hook);
                move |event| {
                    let sink = Arc::clone(&sink);
                    let content = match event {
                        AgentEvent::MessageReceived { content, .. } => Some(content.clone()),
                        _ => None,
                    };
                    Box::pin(async move {
                        if let Some(content) = content {
                            sink.lock().await.push(content);
                        }
                    })
                }
            }),
        );

        let raw_logs = raw_logs_from(vec![
            b"message".to_vec(),
            b"tool-start".to_vec(),
            b"tool-done".to_vec(),
        ]);

        let events = collect_events(
            &session,
            raw_logs,
            Box::new(TestNormalizer),
            Some(hook_manager),
        )
        .await;

        let first = events.first();
        assert!(matches!(
            first,
            Some(AgentEvent::SessionStarted { session_id }) if session_id == "session-1"
        ));

        let has_message = events
            .iter()
            .any(|e| matches!(e, AgentEvent::MessageReceived { content, .. } if content == "hello"));
        assert!(has_message);

        let has_tool_start = events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallStarted { tool, .. } if tool == "bash"));
        assert!(has_tool_start);

        let has_tool_done = events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolCallCompleted { tool, .. } if tool == "bash"));
        assert!(has_tool_done);

        let has_usage = events.iter().any(|e| {
            matches!(
                e,
                AgentEvent::ContextUsageUpdated { usage }
                    if usage.used_tokens == 10
                        && usage.window_tokens == Some(100)
                        && usage.remaining_tokens == Some(90)
                        && usage.source == ContextUsageSource::ProviderReported
            )
        });
        assert!(has_usage);

        let last = events.last();
        assert!(matches!(
            last,
            Some(AgentEvent::SessionCompleted { exit_status }) if exit_status.success
        ));

        let captured = seen_by_hook.lock().await.clone();
        assert_eq!(captured, vec!["hello".to_string()]);
    }

    #[tokio::test]
    async fn session_event_stream_marks_completion_as_failed_when_errors_seen() {
        let session = AgentSession::new(
            "session-2",
            ExecutorType::ClaudeCode,
            PathBuf::from("."),
            None,
        );

        let raw_logs = raw_logs_from(vec![b"error".to_vec()]);
        let events = collect_events(&session, raw_logs, Box::new(TestNormalizer), None).await;

        let has_error = events
            .iter()
            .any(|e| matches!(e, AgentEvent::ErrorOccurred { error } if error.contains("boom")));
        assert!(has_error);

        let last = events.last();
        assert!(matches!(
            last,
            Some(AgentEvent::SessionCompleted { exit_status }) if !exit_status.success
        ));
    }

    #[tokio::test]
    async fn session_event_stream_applies_context_window_override() {
        let session = AgentSession::new(
            "session-3",
            ExecutorType::Codex,
            PathBuf::from("."),
            Some(60),
        );

        let raw_logs = raw_logs_from(Vec::<Vec<u8>>::new());
        let events =
            collect_events(&session, raw_logs, Box::new(UnknownLimitNormalizer), None).await;

        let has_overridden_usage = events.iter().any(|e| {
            matches!(
                e,
                AgentEvent::ContextUsageUpdated { usage }
                    if usage.used_tokens == 15
                        && usage.window_tokens == Some(60)
                        && usage.remaining_tokens == Some(45)
                        && usage.source == ContextUsageSource::ConfigOverride
            )
        });
        assert!(has_overridden_usage);
    }

    // ---- lifecycle ------------------------------------------------------

    #[tokio::test]
    async fn wait_defaults_to_completed_success_when_unmanaged() {
        let mut session = AgentSession::new(
            "session-unmanaged",
            ExecutorType::Codex,
            PathBuf::from("."),
            None,
        );

        let expected = ExitStatus {
            code: None,
            success: true,
        };
        let exit_status = session.wait().await.expect("wait should succeed");
        assert_eq!(exit_status, expected);
    }

    #[tokio::test]
    async fn wait_uses_session_lifecycle_controller() {
        let expected = ExitStatus {
            code: Some(17),
            success: false,
        };
        let mut session = AgentSession::from_metadata_with_exit_status(
            claude_metadata("session-managed"),
            ExitStatus {
                code: Some(17),
                success: false,
            },
        );

        let first = session.wait().await.expect("wait should use controller");
        assert_eq!(first, expected);

        let second = session
            .wait()
            .await
            .expect("second wait should remain stable");
        assert_eq!(second, expected);
    }

    #[tokio::test]
    async fn cancel_delegates_to_registered_lifecycle_controller() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let probe = CancelProbeController {
            cancelled: Arc::clone(&cancelled),
        };

        let mut session =
            AgentSession::from_parts(claude_metadata("session-cancel"), Arc::new(probe));

        session.cancel().await.expect("cancel should succeed");
        assert!(cancelled.load(Ordering::Relaxed));
    }
}