1use crate::api::{Model, ToolCall};
2use crate::app::{
3 App, AppCommand, AppConfig, AppEvent, Conversation, Message as ConversationMessage, MessageData,
4};
5use crate::error::{Error, Result};
6use crate::events::StreamEvent;
7use crate::session::{
8 Session, SessionConfig, SessionFilter, SessionInfo, SessionState, SessionStore,
9 SessionStoreError, ToolCallUpdate,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use steer_tools::ToolResult;
14use thiserror::Error;
15use tokio::sync::{RwLock, mpsc};
16use tokio::task::JoinHandle;
17use tracing::{debug, error, info, warn};
18
/// Errors produced by [`SessionManager`] operations.
#[derive(Debug, Error)]
pub enum SessionManagerError {
    /// Creating another session would reach or exceed
    /// `SessionManagerConfig::max_concurrent_sessions`.
    #[error("Maximum session capacity reached ({current}/{max}). Cannot create new session.")]
    CapacityExceeded { current: usize, max: usize },

    /// The session is not in the active set. Also used when a lookup that falls
    /// back to the store finds nothing (e.g. resume / MCP status queries).
    #[error("Session not active: {session_id}")]
    SessionNotActive { session_id: String },

    /// The session's external event receiver has already been handed out.
    #[error("Session {session_id} already has an active listener")]
    SessionAlreadyHasListener { session_id: String },

    /// Building a managed session (or its workspace) failed; `message` carries the cause.
    #[error("Failed to create managed session: {message}")]
    CreationFailed { message: String },

    /// An error reported by the underlying session store.
    #[error(transparent)]
    Storage(#[from] SessionStoreError),
}
37
/// Tunables for a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionManagerConfig {
    /// Upper bound on simultaneously active sessions. `create_session` rejects new
    /// sessions at this limit; `resume_session` only logs a warning.
    pub max_concurrent_sessions: usize,
    /// Model passed to every `App` the manager builds.
    pub default_model: Model,
    /// When `true`, `update_session_state` writes the session back to the store
    /// after every in-memory update.
    pub auto_persist: bool,
}
48
/// An active session: its in-memory `Session` plus the running `App` actor and the
/// task that translates and persists the actor's events.
pub struct ManagedSession {
    /// In-memory copy of the session (config and state).
    pub session: Session,
    /// Sender for commands to the session's `App` actor.
    pub command_tx: mpsc::Sender<AppCommand>,
    /// Receiver for the forwarded `AppEvent` stream; `None` once taken via `take_event_rx`.
    pub event_rx: Option<mpsc::Receiver<AppEvent>>,
    /// Number of attached subscribers, maintained by the `SessionManager`.
    pub subscriber_count: usize,
    /// Time of the last `touch`; used together with `subscriber_count` for idle detection.
    pub last_activity: chrono::DateTime<chrono::Utc>,
    /// Handle of the spawned `app_actor_loop` task.
    pub app_task_handle: JoinHandle<()>,
    /// Handle of the spawned event-translation task.
    pub event_task_handle: JoinHandle<()>,
}
66
67impl ManagedSession {
68 pub async fn new(
70 mut session: Session,
71 app_config: AppConfig,
72 store: Arc<dyn SessionStore>,
73 default_model: Model,
74 conversation: Option<Conversation>,
75 ) -> Result<Self> {
76 let (app_event_tx, mut app_event_rx) = mpsc::channel(100);
78 let (app_command_tx, app_command_rx) = mpsc::channel::<AppCommand>(32);
79
80 let (external_event_tx, external_event_rx) = mpsc::channel(100);
82
83 crate::app::OpContext::init_command_tx(app_command_tx.clone());
85
86 let workspace = session.build_workspace().await?;
88
89 let (backend_registry, mcp_servers) = session
91 .config
92 .build_registry(
93 Arc::new(app_config.llm_config_provider.clone()),
94 workspace.clone(),
95 )
96 .await?;
97
98 session.state.mcp_servers = mcp_servers;
100
101 let tool_executor = Arc::new(crate::tools::ToolExecutor::with_all_components(
102 workspace.clone(),
103 Arc::new(backend_registry),
104 Arc::new(crate::app::validation::ValidatorRegistry::new()),
105 app_config.llm_config_provider.clone(),
106 ));
107
108 let app = if let Some(conv) = conversation {
110 App::new_with_conversation(
111 app_config,
112 app_event_tx,
113 default_model,
114 workspace.clone(),
115 tool_executor,
116 Some(session.config.clone()),
117 conv,
118 )
119 .await?
120 } else {
121 App::new(
122 app_config,
123 app_event_tx,
124 default_model,
125 workspace.clone(),
126 tool_executor,
127 Some(session.config.clone()),
128 )
129 .await?
130 };
131
132 let app_task_handle = tokio::spawn(crate::app::app_actor_loop(app, app_command_rx));
134
135 let session_id = session.id.clone();
137 let store_clone = store.clone();
138
139 let event_task_handle = tokio::spawn(async move {
140 while let Some(app_event) = app_event_rx.recv().await {
141 if let Err(e) = external_event_tx.try_send(app_event.clone()) {
143 warn!(session_id = %session_id, "Failed to send event to external consumer: {}", e);
144 }
145
146 if let AppEvent::ActiveMessageIdChanged { message_id } = &app_event {
148 if let Err(e) = store_clone
149 .update_active_message_id(&session_id, message_id.as_deref())
150 .await
151 {
152 error!(session_id = %session_id, error = %e, "Failed to update active message ID");
153 }
154 }
155
156 if let Some(stream_event) = translate_app_event(app_event) {
158 if let Ok(_sequence_num) =
160 store_clone.append_event(&session_id, &stream_event).await
161 {
162 if let Err(e) =
164 update_session_state_for_event(&store_clone, &session_id, &stream_event)
165 .await
166 {
167 error!(session_id = %session_id, error = %e, "Failed to update session state");
168 }
169 }
170 }
171 }
172 info!(session_id = %session_id, "Event translation loop ended");
173 });
174
175 Ok(Self {
176 session,
177 command_tx: app_command_tx,
178 event_rx: Some(external_event_rx),
179 subscriber_count: 0,
180 last_activity: chrono::Utc::now(),
181 app_task_handle,
182 event_task_handle,
183 })
184 }
185
186 pub fn take_event_rx(&mut self) -> Option<mpsc::Receiver<AppEvent>> {
188 self.event_rx.take()
189 }
190
191 pub fn touch(&mut self) {
193 self.last_activity = chrono::Utc::now();
194 }
195
196 pub fn is_inactive(&self, max_idle_time: chrono::Duration) -> bool {
198 self.subscriber_count == 0 && chrono::Utc::now() - self.last_activity > max_idle_time
199 }
200
201 pub async fn shutdown(self) {
203 let _ = self.command_tx.send(AppCommand::Shutdown).await;
205
206 let _ = self.app_task_handle.await;
208 let _ = self.event_task_handle.await;
209 }
210}
211
/// Owns the set of active sessions and coordinates them with the session store.
pub struct SessionManager {
    /// Active sessions keyed by session id.
    active_sessions: Arc<RwLock<HashMap<String, ManagedSession>>>,
    /// Persistent backing store for sessions, messages, events and tool calls.
    store: Arc<dyn SessionStore>,
    /// Limits and defaults applied to managed sessions.
    config: SessionManagerConfig,
}
221
222impl SessionManager {
223 pub fn new(store: Arc<dyn SessionStore>, config: SessionManagerConfig) -> Self {
225 Self {
226 active_sessions: Arc::new(RwLock::new(HashMap::new())),
227 store,
228 config,
229 }
230 }
231
232 pub async fn create_session(
234 &self,
235 config: SessionConfig,
236 app_config: AppConfig,
237 ) -> Result<(String, mpsc::Sender<AppCommand>)> {
238 let session_config = config;
239
240 let session = self.store.create_session(session_config).await?;
242 let session_id = session.id.clone();
243
244 info!(session_id = %session_id, "Creating new session");
245
246 {
248 let sessions = self.active_sessions.read().await;
249 if sessions.len() >= self.config.max_concurrent_sessions {
250 error!(
251 session_id = %session_id,
252 active_count = sessions.len(),
253 max_capacity = self.config.max_concurrent_sessions,
254 "Session creation rejected: at maximum capacity"
255 );
256 return Err(SessionManagerError::CapacityExceeded {
257 current: sessions.len(),
258 max: self.config.max_concurrent_sessions,
259 }
260 .into());
261 }
262 }
263
264 let managed_session = ManagedSession::new(
266 session.clone(),
267 app_config,
268 self.store.clone(),
269 self.config.default_model,
270 None,
271 )
272 .await
273 .map_err(|e| SessionManagerError::CreationFailed {
274 message: format!("Failed to create managed session: {e}"),
275 })?;
276
277 let command_tx = managed_session.command_tx.clone();
279
280 {
282 let mut sessions = self.active_sessions.write().await;
283 sessions.insert(session_id.clone(), managed_session);
284 }
285
286 let metadata = crate::events::SessionMetadata::from(&SessionInfo::from(&session));
288 let event = StreamEvent::SessionCreated {
289 session_id: session_id.clone(),
290 metadata,
291 };
292 self.emit_event(session_id.clone(), event).await;
293
294 info!(session_id = %session_id, "Session created and activated");
295 Ok((session_id, command_tx))
296 }
297
298 pub async fn take_event_receiver(&self, session_id: &str) -> Result<mpsc::Receiver<AppEvent>> {
300 let mut sessions = self.active_sessions.write().await;
301 match sessions.get_mut(session_id) {
302 Some(managed_session) => {
303 if let Some(receiver) = managed_session.take_event_rx() {
304 Ok(receiver)
305 } else {
306 Err(SessionManagerError::SessionAlreadyHasListener {
307 session_id: session_id.to_string(),
308 }
309 .into())
310 }
311 }
312 None => Err(SessionManagerError::SessionNotActive {
313 session_id: session_id.to_string(),
314 }
315 .into()),
316 }
317 }
318
319 pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionInfo>> {
321 {
323 let sessions = self.active_sessions.read().await;
324 if let Some(managed_session) = sessions.get(session_id) {
325 return Ok(Some(SessionInfo::from(&managed_session.session)));
326 }
327 }
328
329 if let Some(session) = self.store.get_session(session_id).await? {
331 Ok(Some(SessionInfo::from(&session)))
332 } else {
333 Ok(None)
334 }
335 }
336
337 pub async fn get_session_workspace(
339 &self,
340 session_id: &str,
341 ) -> Result<Option<Arc<dyn crate::workspace::Workspace>>> {
342 {
344 let active_sessions = self.active_sessions.read().await;
345 if let Some(managed_session) = active_sessions.get(session_id) {
346 return Ok(Some(
348 managed_session
349 .session
350 .build_workspace()
351 .await
352 .map_err(|e| SessionManagerError::CreationFailed {
353 message: format!("Failed to build workspace: {e}"),
354 })?,
355 ));
356 }
357 }
358
359 if let Some(session_info) = self.store.get_session(session_id).await? {
361 let session = session_info;
362 Ok(Some(session.build_workspace().await.map_err(|e| {
363 SessionManagerError::CreationFailed {
364 message: format!("Failed to build workspace: {e}"),
365 }
366 })?))
367 } else {
368 Ok(None)
369 }
370 }
371
372 pub async fn resume_session(
374 &self,
375 session_id: &str,
376 app_config: AppConfig,
377 ) -> Result<mpsc::Sender<AppCommand>> {
378 {
380 let sessions = self.active_sessions.read().await;
381 if let Some(managed_session) = sessions.get(session_id) {
382 debug!(session_id = %session_id, "Session already active");
383 return Ok(managed_session.command_tx.clone());
384 }
385 }
386
387 let session = match self
389 .store
390 .get_session(session_id)
391 .await
392 .map_err(SessionManagerError::Storage)?
393 {
394 Some(session) => session,
395 None => {
396 debug!(session_id = %session_id, "Session not found in store");
397 return Err(SessionManagerError::SessionNotActive {
398 session_id: session_id.to_string(),
399 }
400 .into());
401 }
402 };
403
404 info!(session_id = %session_id, "Resuming session from storage");
405
406 {
408 let sessions = self.active_sessions.read().await;
409 if sessions.len() >= self.config.max_concurrent_sessions {
410 warn!(
411 session_id = %session_id,
412 active_count = sessions.len(),
413 max_capacity = self.config.max_concurrent_sessions,
414 "At maximum session capacity for resume"
415 );
416 }
418 }
419
420 let conversation = Conversation {
422 messages: session.state.messages.clone(),
423 working_directory: session
424 .config
425 .workspace
426 .get_path()
427 .unwrap_or_default()
428 .into(),
429 active_message_id: session.state.active_message_id.clone(),
430 };
431
432 let managed_session = ManagedSession::new(
434 session.clone(),
435 app_config,
436 self.store.clone(),
437 self.config.default_model,
438 Some(conversation),
439 )
440 .await
441 .map_err(|e| SessionManagerError::CreationFailed {
442 message: format!("Failed to create managed session: {e}"),
443 })?;
444
445 let command_tx = managed_session.command_tx.clone();
447
448 if !session.state.messages.is_empty() || !session.state.approved_tools.is_empty() {
450 info!(
451 session_id = %session_id,
452 message_count = session.state.messages.len(),
453 tool_count = session.state.approved_tools.len(),
454 "Restoring conversation state"
455 );
456
457 command_tx
458 .send(AppCommand::RestoreConversation {
459 messages: session.state.messages.clone(),
460 approved_tools: session.state.approved_tools.clone(),
461 approved_bash_patterns: session.state.approved_bash_patterns.clone(),
462 active_message_id: session.state.active_message_id.clone(),
463 })
464 .await
465 .map_err(|_| SessionManagerError::CreationFailed {
466 message: "Failed to send restore command to App".to_string(),
467 })?;
468 }
469
470 {
472 let mut sessions = self.active_sessions.write().await;
473 sessions.insert(session_id.to_string(), managed_session);
474 }
475
476 let last_sequence = session.state.last_event_sequence;
478
479 let event = StreamEvent::SessionResumed {
481 session_id: session_id.to_string(),
482 event_offset: last_sequence,
483 };
484 self.emit_event(session_id.to_string(), event).await;
485
486 info!(session_id = %session_id, last_sequence = last_sequence, "Session resumed");
487 Ok(command_tx)
488 }
489
490 pub async fn suspend_session(&self, session_id: &str) -> Result<bool> {
492 let managed_session = {
493 let mut sessions = self.active_sessions.write().await;
494 sessions.remove(session_id)
495 };
496
497 let managed_session = match managed_session {
498 Some(session) => session,
499 None => {
500 debug!(session_id = %session_id, "Session not active, cannot suspend");
501 return Ok(false);
502 }
503 };
504
505 info!(session_id = %session_id, "Suspending session");
506
507 self.store.update_session(&managed_session.session).await?;
509
510 let event = StreamEvent::SessionSaved {
512 session_id: session_id.to_string(),
513 };
514 self.emit_event(session_id.to_string(), event).await;
515
516 info!(session_id = %session_id, "Session suspended and saved");
517 Ok(true)
518 }
519
520 pub async fn delete_session(&self, session_id: &str) -> Result<bool> {
522 {
524 let mut sessions = self.active_sessions.write().await;
525 sessions.remove(session_id);
526 }
527
528 self.store.delete_session(session_id).await?;
530
531 info!(session_id = %session_id, "Session deleted");
532 Ok(true)
533 }
534
535 pub async fn list_sessions(&self, filter: SessionFilter) -> Result<Vec<SessionInfo>> {
537 Ok(self.store.list_sessions(filter).await?)
538 }
539
540 pub async fn get_active_sessions(&self) -> Vec<String> {
542 let sessions = self.active_sessions.read().await;
543 sessions.keys().cloned().collect()
544 }
545
546 pub async fn is_session_active(&self, session_id: &str) -> bool {
548 let sessions = self.active_sessions.read().await;
549 sessions.contains_key(session_id)
550 }
551
552 pub async fn send_command(&self, session_id: &str, command: AppCommand) -> Result<()> {
554 let sessions = self.active_sessions.read().await;
555 if let Some(managed_session) = sessions.get(session_id) {
556 managed_session.command_tx.send(command).await.map_err(|_| {
557 Error::SessionManager(SessionManagerError::SessionNotActive {
558 session_id: session_id.to_string(),
559 })
560 })
561 } else {
562 Err(Error::SessionManager(
563 SessionManagerError::SessionNotActive {
564 session_id: session_id.to_string(),
565 },
566 ))
567 }
568 }
569
570 pub async fn update_session_state(
572 &self,
573 session_id: &str,
574 update_fn: impl FnOnce(&mut SessionState),
575 ) -> Result<()> {
576 {
577 let mut sessions = self.active_sessions.write().await;
578 if let Some(managed_session) = sessions.get_mut(session_id) {
579 managed_session.touch();
580 update_fn(&mut managed_session.session.state);
581 managed_session.session.update_timestamp();
582
583 if self.config.auto_persist {
585 self.store.update_session(&managed_session.session).await?;
586 }
587 } else {
588 return Err(SessionManagerError::SessionNotActive {
589 session_id: session_id.to_string(),
590 }
591 .into());
592 }
593 }
594
595 Ok(())
596 }
597
598 pub async fn emit_event(&self, session_id: String, event: StreamEvent) {
600 let sequence_num = match self.store.append_event(&session_id, &event).await {
602 Ok(seq) => seq,
603 Err(e) => {
604 error!(session_id = %session_id, error = %e, "Failed to persist event");
605 return;
606 }
607 };
608
609 if let Err(e) = self
611 .update_session_state(&session_id, |state| {
612 state.last_event_sequence = sequence_num;
613 })
614 .await
615 {
616 error!(session_id = %session_id, error = %e, "Failed to update session sequence number");
617 }
618 }
619
620 pub async fn cleanup_inactive_sessions(&self, max_idle_time: chrono::Duration) -> usize {
622 let mut to_suspend = Vec::new();
623
624 {
625 let sessions = self.active_sessions.read().await;
626 for (session_id, managed_session) in sessions.iter() {
627 if managed_session.is_inactive(max_idle_time) {
628 to_suspend.push(session_id.clone());
629 }
630 }
631 }
632
633 let mut suspended_count = 0;
634 for session_id in to_suspend {
635 if let Ok(true) = self.suspend_session(&session_id).await {
636 suspended_count += 1;
637 }
638 }
639
640 if suspended_count > 0 {
641 info!(
642 suspended_count = suspended_count,
643 "Cleaned up inactive sessions"
644 );
645 }
646
647 suspended_count
648 }
649
650 pub fn store(&self) -> &Arc<dyn SessionStore> {
652 &self.store
653 }
654
655 pub async fn increment_subscriber_count(&self, session_id: &str) -> Result<()> {
657 let mut sessions = self.active_sessions.write().await;
658 if let Some(managed_session) = sessions.get_mut(session_id) {
659 managed_session.subscriber_count += 1;
660 managed_session.touch();
661 debug!(
662 session_id = %session_id,
663 subscriber_count = managed_session.subscriber_count,
664 "Incremented subscriber count"
665 );
666 Ok(())
667 } else {
668 Err(SessionManagerError::SessionNotActive {
669 session_id: session_id.to_string(),
670 }
671 .into())
672 }
673 }
674
675 pub async fn decrement_subscriber_count(&self, session_id: &str) -> Result<()> {
677 let mut sessions = self.active_sessions.write().await;
678 if let Some(managed_session) = sessions.get_mut(session_id) {
679 managed_session.subscriber_count = managed_session.subscriber_count.saturating_sub(1);
680 managed_session.touch();
681 debug!(
682 session_id = %session_id,
683 subscriber_count = managed_session.subscriber_count,
684 "Decremented subscriber count"
685 );
686 Ok(())
687 } else {
688 debug!(session_id = %session_id, "Session not active when decrementing subscriber count");
690 Ok(())
691 }
692 }
693
694 pub async fn touch_session(&self, session_id: &str) -> Result<()> {
696 let mut sessions = self.active_sessions.write().await;
697 if let Some(managed_session) = sessions.get_mut(session_id) {
698 managed_session.touch();
699 Ok(())
700 } else {
701 Ok(())
703 }
704 }
705
706 pub async fn maybe_suspend_idle_session(&self, session_id: &str) -> Result<()> {
708 let should_suspend = {
710 let sessions = self.active_sessions.read().await;
711 if let Some(managed_session) = sessions.get(session_id) {
712 managed_session.subscriber_count == 0
713 } else {
714 false }
716 };
717
718 if should_suspend {
719 info!(session_id = %session_id, "No active subscribers, suspending session");
720 self.suspend_session(session_id).await?;
721 }
722
723 Ok(())
724 }
725
726 pub async fn get_session_state(
728 &self,
729 session_id: &str,
730 ) -> Result<Option<crate::session::SessionState>> {
731 info!("get_session_state called for session: {}", session_id);
732
733 match self.store.get_session(session_id).await {
736 Ok(Some(session)) => {
737 info!(
738 "Loaded session from store with {} messages",
739 session.state.messages.len()
740 );
741 Ok(Some(session.state))
742 }
743 Ok(None) => {
744 info!("Session not found in store: {}", session_id);
745 Ok(None)
746 }
747 Err(e) => {
748 error!("Error loading session from store: {}", e);
749 Err(SessionManagerError::Storage(e).into())
750 }
751 }
752 }
753
754 pub async fn get_mcp_statuses(
756 &self,
757 session_id: &str,
758 ) -> Result<Vec<crate::session::McpServerInfo>> {
759 {
761 let sessions = self.active_sessions.read().await;
762 if let Some(managed_session) = sessions.get(session_id) {
763 let servers: Vec<_> = managed_session
764 .session
765 .state
766 .mcp_servers
767 .values()
768 .cloned()
769 .collect();
770 return Ok(servers);
771 }
772 }
773
774 match self.store.get_session(session_id).await? {
776 Some(session) => {
777 let servers: Vec<_> = session.state.mcp_servers.values().cloned().collect();
778 Ok(servers)
779 }
780 None => Err(SessionManagerError::SessionNotActive {
781 session_id: session_id.to_string(),
782 }
783 .into()),
784 }
785 }
786}
787
788fn translate_app_event(app_event: AppEvent) -> Option<StreamEvent> {
790 match app_event {
791 AppEvent::MessageAdded { message, model } => Some(StreamEvent::MessageComplete {
792 message,
793 usage: None,
794 metadata: std::collections::HashMap::new(),
795 model,
796 }),
797
798 AppEvent::MessagePart { id, delta } => Some(StreamEvent::MessagePart {
799 content: delta,
800 message_id: id,
801 }),
802
803 AppEvent::ToolCallStarted {
804 name,
805 id,
806 parameters,
807 model,
808 } => {
809 let tool_call = ToolCall {
810 id: id.clone(),
811 name: name.clone(),
812 parameters,
813 };
814 Some(StreamEvent::ToolCallStarted {
815 tool_call,
816 metadata: std::collections::HashMap::new(),
817 model,
818 })
819 }
820
821 AppEvent::ToolCallCompleted {
822 name: _,
823 result,
824 id,
825 model,
826 } => Some(StreamEvent::ToolCallCompleted {
827 tool_call_id: id,
828 result,
829 metadata: std::collections::HashMap::new(),
830 model,
831 }),
832
833 AppEvent::ToolCallFailed {
834 name: _,
835 error,
836 id,
837 model,
838 } => Some(StreamEvent::ToolCallFailed {
839 tool_call_id: id,
840 error,
841 metadata: std::collections::HashMap::new(),
842 model,
843 }),
844
845 AppEvent::WorkspaceChanged => Some(StreamEvent::WorkspaceChanged),
846
847 AppEvent::WorkspaceFiles { files } => Some(StreamEvent::WorkspaceFiles {
848 files: files.clone(),
849 }),
850
851 AppEvent::Started { id, op } => Some(StreamEvent::OperationStarted {
852 operation_id: id,
853 operation: op,
854 }),
855 AppEvent::Finished { id, outcome } => Some(StreamEvent::OperationCompleted {
856 operation_id: id,
857 outcome,
858 }),
859 AppEvent::OperationCancelled { op_id, info } => {
860 let operation_id = op_id.unwrap_or_else(uuid::Uuid::new_v4);
862 Some(StreamEvent::OperationCancelled {
863 operation_id,
864 reason: info.to_string(), })
866 }
867
868 _ => None,
870 }
871}
872async fn update_session_state_for_event(
874 store: &Arc<dyn SessionStore>,
875 session_id: &str,
876 event: &StreamEvent,
877) -> Result<()> {
878 match event {
879 StreamEvent::MessageComplete { message, .. } => {
880 store.append_message(session_id, message).await?;
881
882 if let crate::app::conversation::MessageData::Tool {
884 tool_use_id,
885 result,
886 ..
887 } = &message.data
888 {
889 let stats = crate::session::ToolExecutionStats::success_typed(
890 serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
891 result.variant_name().to_string(),
892 0, );
894 let update = ToolCallUpdate::set_result(stats);
895 store.update_tool_call(tool_use_id, update).await?;
896 }
897 }
898 StreamEvent::ToolCallStarted { tool_call, .. } => {
899 store.create_tool_call(session_id, tool_call).await?;
900 }
901 StreamEvent::ToolCallCompleted {
902 tool_call_id,
903 result,
904 ..
905 } => {
906 let stats = crate::session::ToolExecutionStats::success_typed(
907 serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
908 result.variant_name().to_string(),
909 0,
910 );
911 let update = ToolCallUpdate::set_result(stats);
912 store.update_tool_call(tool_call_id, update).await?;
913
914 let messages = store.get_messages(session_id, None).await?;
917 let parent_id = messages.last().map(|m| m.id().to_string());
918
919 let tool_message = ConversationMessage {
920 data: crate::app::conversation::MessageData::Tool {
921 tool_use_id: tool_call_id.clone(),
922 result: result.clone(),
923 },
924 timestamp: std::time::SystemTime::now()
925 .duration_since(std::time::UNIX_EPOCH)
926 .expect("Time went backwards")
927 .as_secs(),
928 id: format!("tool_result_{tool_call_id}"),
929 parent_message_id: parent_id,
930 };
931 store.append_message(session_id, &tool_message).await?;
932 }
933 StreamEvent::ToolCallFailed {
934 tool_call_id,
935 error,
936 ..
937 } => {
938 let update = ToolCallUpdate::set_error(error.clone());
939 store.update_tool_call(tool_call_id, update).await?;
940
941 let messages = store.get_messages(session_id, None).await?;
944 let parent_id = messages.last().map(|m| m.id().to_string());
945
946 let tool_error = steer_tools::error::ToolError::Execution {
947 tool_name: "unknown".to_string(), message: error.clone(),
949 };
950 let tool_message = ConversationMessage {
951 data: MessageData::Tool {
952 tool_use_id: tool_call_id.clone(),
953 result: ToolResult::Error(tool_error),
954 },
955 timestamp: std::time::SystemTime::now()
956 .duration_since(std::time::UNIX_EPOCH)
957 .expect("Time went backwards")
958 .as_secs(),
959 id: format!("tool_result_{tool_call_id}"),
960 parent_message_id: parent_id,
961 };
962 store.append_message(session_id, &tool_message).await?;
963 }
964 _ => {}
966 }
967 Ok(())
968}
969
970#[cfg(test)]
971mod tests {
972 use super::*;
973 use crate::api::ToolCall;
974 use crate::app::MessageData;
975 use crate::app::conversation::{AssistantContent, Role, UserContent};
976 use crate::session::stores::sqlite::SqliteSessionStore;
977 use tempfile::TempDir;
978
979 async fn create_test_manager() -> (SessionManager, TempDir) {
980 let temp_dir = TempDir::new().unwrap();
981 let db_path = temp_dir.path().join("test.db");
982 let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());
983
984 let config = SessionManagerConfig {
985 max_concurrent_sessions: 100,
986 default_model: Model::default(),
987 auto_persist: true,
988 };
989 let manager = SessionManager::new(store, config);
990
991 (manager, temp_dir)
992 }
993
994 fn create_test_app_config() -> AppConfig {
995 crate::test_utils::test_app_config()
996 }
997
998 #[tokio::test]
999 async fn test_create_and_resume_session() {
1000 let (manager, temp) = create_test_manager().await;
1001 let app_config = create_test_app_config();
1002
1003 let session_config = SessionConfig {
1005 workspace: crate::session::state::WorkspaceConfig::Local {
1006 path: temp.path().to_path_buf(),
1007 },
1008 tool_config: crate::session::SessionToolConfig::default(),
1009 system_prompt: None,
1010 metadata: std::collections::HashMap::new(),
1011 };
1012 let (session_id, _command_tx) = manager
1013 .create_session(session_config, app_config.clone())
1014 .await
1015 .unwrap();
1016 assert!(!session_id.is_empty());
1017 assert!(manager.is_session_active(&session_id).await);
1018
1019 assert!(manager.suspend_session(&session_id).await.unwrap());
1021 assert!(!manager.is_session_active(&session_id).await);
1022
1023 let _command_tx = manager
1025 .resume_session(&session_id, app_config)
1026 .await
1027 .unwrap();
1028 assert!(manager.is_session_active(&session_id).await);
1029 }
1030
1031 #[tokio::test]
1032 async fn test_session_cleanup() {
1033 let (manager, temp) = create_test_manager().await;
1034 let app_config = create_test_app_config();
1035
1036 let session_config = SessionConfig {
1038 workspace: crate::session::state::WorkspaceConfig::Local {
1039 path: temp.path().to_path_buf(),
1040 },
1041 tool_config: crate::session::SessionToolConfig::default(),
1042 system_prompt: None,
1043 metadata: std::collections::HashMap::new(),
1044 };
1045 let (session_id, _command_tx) = manager
1046 .create_session(session_config, app_config)
1047 .await
1048 .unwrap();
1049
1050 {
1052 let mut sessions = manager.active_sessions.write().await;
1053 if let Some(session) = sessions.get_mut(&session_id) {
1054 session.last_activity = chrono::Utc::now() - chrono::Duration::hours(2);
1055 session.subscriber_count = 0;
1056 }
1057 }
1058
1059 let cleaned = manager
1061 .cleanup_inactive_sessions(chrono::Duration::hours(1))
1062 .await;
1063 assert_eq!(cleaned, 1);
1064 assert!(!manager.is_session_active(&session_id).await);
1065 }
1066
1067 #[tokio::test]
1068 async fn test_capacity_rejection() {
1069 let temp_dir = TempDir::new().unwrap();
1070 let temp = tempfile::TempDir::new().unwrap();
1071 let db_path = temp_dir.path().join("test.db");
1072 let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());
1073
1074 let config = SessionManagerConfig {
1075 max_concurrent_sessions: 1, default_model: Model::default(),
1077 auto_persist: true,
1078 };
1079 let manager = SessionManager::new(store, config);
1080 let app_config = create_test_app_config();
1081
1082 let tool_config = crate::session::SessionToolConfig {
1084 approval_policy: crate::session::ToolApprovalPolicy::AlwaysAsk,
1085 ..Default::default()
1086 };
1087
1088 let session_config = SessionConfig {
1089 workspace: crate::session::state::WorkspaceConfig::Local {
1090 path: temp.path().to_path_buf(),
1091 },
1092 tool_config,
1093 system_prompt: None,
1094 metadata: std::collections::HashMap::new(),
1095 };
1096 let (session_id1, _command_tx) = manager
1097 .create_session(session_config.clone(), app_config.clone())
1098 .await
1099 .unwrap();
1100 assert!(!session_id1.is_empty());
1101
1102 let result = manager.create_session(session_config, app_config).await;
1104
1105 assert!(result.is_err());
1106 assert!(matches!(
1107 result,
1108 Err(crate::error::Error::SessionManager(
1109 SessionManagerError::CapacityExceeded { .. }
1110 ))
1111 ));
1112 match result.unwrap_err() {
1113 crate::error::Error::SessionManager(SessionManagerError::CapacityExceeded {
1114 current,
1115 max,
1116 }) => {
1117 assert_eq!(current, 1);
1118 assert_eq!(max, 1);
1119 }
1120 _ => unreachable!(),
1121 }
1122 }
1123
1124 #[tokio::test]
1125 async fn test_tool_result_persistence_on_resume() {
1126 let (manager, temp) = create_test_manager().await;
1127 let app_config = create_test_app_config();
1128
1129 let session_config = SessionConfig {
1131 workspace: crate::session::state::WorkspaceConfig::Local {
1132 path: temp.path().to_path_buf(),
1133 },
1134 tool_config: crate::session::SessionToolConfig::default(),
1135 system_prompt: None,
1136 metadata: std::collections::HashMap::new(),
1137 };
1138 let (session_id, _command_tx) = manager
1139 .create_session(session_config, app_config.clone())
1140 .await
1141 .unwrap();
1142
1143 let user_message = ConversationMessage {
1146 data: crate::app::conversation::MessageData::User {
1147 content: vec![UserContent::Text {
1148 text: "Read the file test.txt".to_string(),
1149 }],
1150 },
1151 timestamp: 123456789,
1152 id: "user_1".to_string(),
1153 parent_message_id: None,
1154 };
1155 manager
1156 .store
1157 .append_message(&session_id, &user_message)
1158 .await
1159 .unwrap();
1160
1161 let assistant_message = ConversationMessage {
1163 data: crate::app::conversation::MessageData::Assistant {
1164 content: vec![
1165 AssistantContent::Text {
1166 text: "I'll read that file for you.".to_string(),
1167 },
1168 AssistantContent::ToolCall {
1169 tool_call: ToolCall {
1170 id: "tool_call_1".to_string(),
1171 name: "read_file".to_string(),
1172 parameters: serde_json::json!({"path": "test.txt"}),
1173 },
1174 },
1175 ],
1176 },
1177 timestamp: 123456790,
1178 id: "assistant_1".to_string(),
1179 parent_message_id: Some("user_1".to_string()),
1180 };
1181 manager
1182 .store
1183 .append_message(&session_id, &assistant_message)
1184 .await
1185 .unwrap();
1186
1187 let tool_call = ToolCall {
1190 id: "tool_call_1".to_string(),
1191 name: "read_file".to_string(),
1192 parameters: serde_json::json!({"path": "test.txt"}),
1193 };
1194 manager
1195 .store
1196 .create_tool_call(&session_id, &tool_call)
1197 .await
1198 .unwrap();
1199
1200 let stats = crate::session::ToolExecutionStats::success_typed(
1202 serde_json::json!({
1203 "content": "File contents: Hello, world!",
1204 "file_path": "test.txt",
1205 "line_count": 1,
1206 "truncated": false
1207 }),
1208 "FileContent".to_string(),
1209 0,
1210 );
1211 let update = ToolCallUpdate::set_result(stats);
1212 manager
1213 .store
1214 .update_tool_call("tool_call_1", update)
1215 .await
1216 .unwrap();
1217
1218 let tool_message = ConversationMessage {
1220 data: MessageData::Tool {
1221 tool_use_id: "tool_call_1".to_string(),
1222 result: ToolResult::FileContent(steer_tools::result::FileContentResult {
1223 content: "File contents: Hello, world!".to_string(),
1224 file_path: "test.txt".to_string(),
1225 line_count: 1,
1226 truncated: false,
1227 }),
1228 },
1229 timestamp: 123456790,
1230 id: "tool_result_tool_call_1".to_string(),
1231 parent_message_id: Some("assistant_1".to_string()),
1232 };
1233 manager
1234 .store
1235 .append_message(&session_id, &tool_message)
1236 .await
1237 .unwrap();
1238
1239 let followup_message = ConversationMessage {
1241 data: crate::app::conversation::MessageData::Assistant {
1242 content: vec![AssistantContent::Text {
1243 text: "The file contains: Hello, world!".to_string(),
1244 }],
1245 },
1246 timestamp: 123456791,
1247 id: "assistant_2".to_string(),
1248 parent_message_id: Some("assistant_1".to_string()),
1249 };
1250 manager
1251 .store
1252 .append_message(&session_id, &followup_message)
1253 .await
1254 .unwrap();
1255
1256 manager.suspend_session(&session_id).await.unwrap();
1258
1259 let loaded_session = manager
1261 .store
1262 .get_session(&session_id)
1263 .await
1264 .unwrap()
1265 .unwrap();
1266
1267 assert_eq!(loaded_session.state.messages.len(), 4);
1269
1270 let tool_result_msg = &loaded_session.state.messages[2];
1272 assert_eq!(tool_result_msg.role(), Role::Tool);
1273
1274 assert!(matches!(
1276 &tool_result_msg.data,
1277 crate::app::conversation::MessageData::Tool { .. }
1278 ));
1279 if let crate::app::conversation::MessageData::Tool {
1280 tool_use_id,
1281 result,
1282 ..
1283 } = &tool_result_msg.data
1284 {
1285 assert_eq!(tool_use_id, "tool_call_1");
1286 assert!(matches!(
1287 result,
1288 crate::app::conversation::ToolResult::FileContent(_)
1289 ));
1290 match result {
1291 crate::app::conversation::ToolResult::FileContent(content) => {
1292 assert!(content.content.contains("Hello, world!"));
1293 }
1294 _ => unreachable!(),
1295 }
1296 } else {
1297 panic!("Expected Tool message");
1298 }
1299
1300 let _command_tx = manager
1302 .resume_session(&session_id, app_config)
1303 .await
1304 .unwrap();
1305
1306 }
1310
1311 #[tokio::test]
1312 async fn test_active_message_id_persistence() {
1313 let (manager, temp) = create_test_manager().await;
1314 let app_config = create_test_app_config();
1315
1316 let session_config = SessionConfig {
1318 workspace: crate::session::state::WorkspaceConfig::Local {
1319 path: temp.path().to_path_buf(),
1320 },
1321 tool_config: crate::session::SessionToolConfig::default(),
1322 system_prompt: None,
1323 metadata: std::collections::HashMap::new(),
1324 };
1325 let (session_id, _command_tx) = manager
1326 .create_session(session_config, app_config.clone())
1327 .await
1328 .unwrap();
1329
1330 let msg1 = ConversationMessage {
1332 data: crate::app::conversation::MessageData::User {
1333 content: vec![UserContent::Text {
1334 text: "Hello".to_string(),
1335 }],
1336 },
1337 timestamp: 1000,
1338 id: "msg1".to_string(),
1339 parent_message_id: None,
1340 };
1341
1342 let msg2 = ConversationMessage {
1343 data: crate::app::conversation::MessageData::Assistant {
1344 content: vec![AssistantContent::Text {
1345 text: "Hi there!".to_string(),
1346 }],
1347 },
1348 timestamp: 2000,
1349 id: "msg2".to_string(),
1350 parent_message_id: Some("msg1".to_string()),
1351 };
1352
1353 let msg1_edited = ConversationMessage {
1355 data: crate::app::conversation::MessageData::User {
1356 content: vec![UserContent::Text {
1357 text: "Goodbye".to_string(),
1358 }],
1359 },
1360 timestamp: 3000,
1361 id: "msg1_edited".to_string(),
1362 parent_message_id: None, };
1364
1365 manager
1367 .store
1368 .append_message(&session_id, &msg1)
1369 .await
1370 .unwrap();
1371 manager
1372 .store
1373 .append_message(&session_id, &msg2)
1374 .await
1375 .unwrap();
1376 manager
1377 .store
1378 .append_message(&session_id, &msg1_edited)
1379 .await
1380 .unwrap();
1381
1382 manager
1384 .store
1385 .update_active_message_id(&session_id, Some("msg1_edited"))
1386 .await
1387 .unwrap();
1388
1389 manager.suspend_session(&session_id).await.unwrap();
1391
1392 let loaded_session = manager
1394 .store
1395 .get_session(&session_id)
1396 .await
1397 .unwrap()
1398 .unwrap();
1399
1400 assert_eq!(
1402 loaded_session.state.active_message_id,
1403 Some("msg1_edited".to_string())
1404 );
1405
1406 assert_eq!(loaded_session.state.messages.len(), 3);
1408
1409 let edited_msg = loaded_session
1411 .state
1412 .messages
1413 .iter()
1414 .find(|m| m.id() == "msg1_edited")
1415 .expect("Edited message should exist");
1416
1417 match &edited_msg.data {
1418 crate::app::conversation::MessageData::User { content, .. } => {
1419 if let Some(UserContent::Text { text }) = content.first() {
1420 assert_eq!(text, "Goodbye");
1421 } else {
1422 panic!("Expected text content");
1423 }
1424 }
1425 _ => panic!("Expected user message"),
1426 }
1427
1428 let _ = manager
1430 .resume_session(&session_id, app_config)
1431 .await
1432 .unwrap();
1433
1434 let state = manager
1436 .get_session_state(&session_id)
1437 .await
1438 .unwrap()
1439 .unwrap();
1440 assert_eq!(state.active_message_id, Some("msg1_edited".to_string()));
1441 }
1442}