1use crate::api::{Model, ToolCall};
2use crate::app::{
3 App, AppCommand, AppConfig, AppEvent, Conversation, Message as ConversationMessage, MessageData,
4};
5use crate::error::{Error, Result};
6use crate::events::StreamEvent;
7use crate::session::{
8 Session, SessionConfig, SessionFilter, SessionInfo, SessionState, SessionStore,
9 SessionStoreError, ToolCallUpdate,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use steer_tools::ToolResult;
14use thiserror::Error;
15use tokio::sync::{RwLock, mpsc};
16use tokio::task::JoinHandle;
17use tracing::{debug, error, info, warn};
18
19#[derive(Debug, Error)]
21pub enum SessionManagerError {
22 #[error("Maximum session capacity reached ({current}/{max}). Cannot create new session.")]
23 CapacityExceeded { current: usize, max: usize },
24
25 #[error("Session not active: {session_id}")]
26 SessionNotActive { session_id: String },
27
28 #[error("Session {session_id} already has an active listener")]
29 SessionAlreadyHasListener { session_id: String },
30
31 #[error("Failed to create managed session: {message}")]
32 CreationFailed { message: String },
33
34 #[error(transparent)]
35 Storage(#[from] SessionStoreError),
36}
37
38#[derive(Debug, Clone)]
40pub struct SessionManagerConfig {
41 pub max_concurrent_sessions: usize,
43 pub default_model: Model,
45 pub auto_persist: bool,
47}
48
49pub struct ManagedSession {
51 pub session: Session,
53 pub command_tx: mpsc::Sender<AppCommand>,
55 pub event_rx: Option<mpsc::Receiver<AppEvent>>,
57 pub subscriber_count: usize,
59 pub last_activity: chrono::DateTime<chrono::Utc>,
61 pub app_task_handle: JoinHandle<()>,
63 pub event_task_handle: JoinHandle<()>,
65}
66
67impl ManagedSession {
68 pub async fn new(
70 mut session: Session,
71 app_config: AppConfig,
72 store: Arc<dyn SessionStore>,
73 default_model: Model,
74 conversation: Option<Conversation>,
75 ) -> Result<Self> {
76 let (app_event_tx, mut app_event_rx) = mpsc::channel(100);
78 let (app_command_tx, app_command_rx) = mpsc::channel::<AppCommand>(32);
79
80 let (external_event_tx, external_event_rx) = mpsc::channel(100);
82
83 crate::app::OpContext::init_command_tx(app_command_tx.clone());
85
86 let workspace = session.build_workspace().await?;
88
89 let (backend_registry, mcp_servers) = session
91 .config
92 .build_registry(
93 Arc::new(app_config.llm_config_provider.clone()),
94 workspace.clone(),
95 )
96 .await?;
97
98 session.state.mcp_servers = mcp_servers;
100
101 let tool_executor = Arc::new(crate::tools::ToolExecutor::with_all_components(
102 workspace.clone(),
103 Arc::new(backend_registry),
104 Arc::new(crate::app::validation::ValidatorRegistry::new()),
105 app_config.llm_config_provider.clone(),
106 ));
107
108 let mut app = if let Some(conv) = conversation {
110 App::new_with_conversation(
111 app_config,
112 app_event_tx,
113 default_model,
114 workspace.clone(),
115 tool_executor,
116 Some(session.config.clone()),
117 conv,
118 )
119 .await?
120 } else {
121 App::new(
122 app_config,
123 app_event_tx,
124 default_model,
125 workspace.clone(),
126 tool_executor,
127 Some(session.config.clone()),
128 )
129 .await?
130 };
131
132 if let Some(model_str) = session.config.metadata.get("initial_model") {
134 if let Ok(model) = model_str.parse::<crate::api::Model>() {
135 let _ = app.set_model(model).await;
136 }
137 }
138
139 let app_task_handle = tokio::spawn(crate::app::app_actor_loop(app, app_command_rx));
141
142 let session_id = session.id.clone();
144 let store_clone = store.clone();
145
146 let event_task_handle = tokio::spawn(async move {
147 while let Some(app_event) = app_event_rx.recv().await {
148 if let Err(e) = external_event_tx.try_send(app_event.clone()) {
150 warn!(session_id = %session_id, "Failed to send event to external consumer: {}", e);
151 }
152
153 if let AppEvent::ActiveMessageIdChanged { message_id } = &app_event {
155 if let Err(e) = store_clone
156 .update_active_message_id(&session_id, message_id.as_deref())
157 .await
158 {
159 error!(session_id = %session_id, error = %e, "Failed to update active message ID");
160 }
161 }
162
163 if let Some(stream_event) = translate_app_event(app_event) {
165 if let Ok(_sequence_num) =
167 store_clone.append_event(&session_id, &stream_event).await
168 {
169 if let Err(e) =
171 update_session_state_for_event(&store_clone, &session_id, &stream_event)
172 .await
173 {
174 error!(session_id = %session_id, error = %e, "Failed to update session state");
175 }
176 }
177 }
178 }
179 info!(session_id = %session_id, "Event translation loop ended");
180 });
181
182 Ok(Self {
183 session,
184 command_tx: app_command_tx,
185 event_rx: Some(external_event_rx),
186 subscriber_count: 0,
187 last_activity: chrono::Utc::now(),
188 app_task_handle,
189 event_task_handle,
190 })
191 }
192
193 pub fn take_event_rx(&mut self) -> Option<mpsc::Receiver<AppEvent>> {
195 self.event_rx.take()
196 }
197
198 pub fn touch(&mut self) {
200 self.last_activity = chrono::Utc::now();
201 }
202
203 pub fn is_inactive(&self, max_idle_time: chrono::Duration) -> bool {
205 self.subscriber_count == 0 && chrono::Utc::now() - self.last_activity > max_idle_time
206 }
207
208 pub async fn shutdown(self) {
210 let _ = self.command_tx.send(AppCommand::Shutdown).await;
212
213 let _ = self.app_task_handle.await;
215 let _ = self.event_task_handle.await;
216 }
217}
218
219pub struct SessionManager {
221 active_sessions: Arc<RwLock<HashMap<String, ManagedSession>>>,
223 store: Arc<dyn SessionStore>,
225 config: SessionManagerConfig,
227}
228
229impl SessionManager {
230 pub fn new(store: Arc<dyn SessionStore>, config: SessionManagerConfig) -> Self {
232 Self {
233 active_sessions: Arc::new(RwLock::new(HashMap::new())),
234 store,
235 config,
236 }
237 }
238
239 pub async fn create_session(
241 &self,
242 config: SessionConfig,
243 app_config: AppConfig,
244 ) -> Result<(String, mpsc::Sender<AppCommand>)> {
245 let session_config = config;
246
247 let session = self.store.create_session(session_config).await?;
249 let session_id = session.id.clone();
250
251 info!(session_id = %session_id, "Creating new session");
252
253 {
255 let sessions = self.active_sessions.read().await;
256 if sessions.len() >= self.config.max_concurrent_sessions {
257 error!(
258 session_id = %session_id,
259 active_count = sessions.len(),
260 max_capacity = self.config.max_concurrent_sessions,
261 "Session creation rejected: at maximum capacity"
262 );
263 return Err(SessionManagerError::CapacityExceeded {
264 current: sessions.len(),
265 max: self.config.max_concurrent_sessions,
266 }
267 .into());
268 }
269 }
270
271 let managed_session = ManagedSession::new(
273 session.clone(),
274 app_config,
275 self.store.clone(),
276 self.config.default_model,
277 None,
278 )
279 .await
280 .map_err(|e| SessionManagerError::CreationFailed {
281 message: format!("Failed to create managed session: {e}"),
282 })?;
283
284 let command_tx = managed_session.command_tx.clone();
286
287 {
289 let mut sessions = self.active_sessions.write().await;
290 sessions.insert(session_id.clone(), managed_session);
291 }
292
293 let metadata = crate::events::SessionMetadata::from(&SessionInfo::from(&session));
295 let event = StreamEvent::SessionCreated {
296 session_id: session_id.clone(),
297 metadata,
298 };
299 self.emit_event(session_id.clone(), event).await;
300
301 info!(session_id = %session_id, "Session created and activated");
302 Ok((session_id, command_tx))
303 }
304
305 pub async fn take_event_receiver(&self, session_id: &str) -> Result<mpsc::Receiver<AppEvent>> {
307 let mut sessions = self.active_sessions.write().await;
308 match sessions.get_mut(session_id) {
309 Some(managed_session) => {
310 if let Some(receiver) = managed_session.take_event_rx() {
311 Ok(receiver)
312 } else {
313 Err(SessionManagerError::SessionAlreadyHasListener {
314 session_id: session_id.to_string(),
315 }
316 .into())
317 }
318 }
319 None => Err(SessionManagerError::SessionNotActive {
320 session_id: session_id.to_string(),
321 }
322 .into()),
323 }
324 }
325
326 pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionInfo>> {
328 {
330 let sessions = self.active_sessions.read().await;
331 if let Some(managed_session) = sessions.get(session_id) {
332 return Ok(Some(SessionInfo::from(&managed_session.session)));
333 }
334 }
335
336 if let Some(session) = self.store.get_session(session_id).await? {
338 Ok(Some(SessionInfo::from(&session)))
339 } else {
340 Ok(None)
341 }
342 }
343
344 pub async fn get_session_workspace(
346 &self,
347 session_id: &str,
348 ) -> Result<Option<Arc<dyn crate::workspace::Workspace>>> {
349 {
351 let active_sessions = self.active_sessions.read().await;
352 if let Some(managed_session) = active_sessions.get(session_id) {
353 return Ok(Some(
355 managed_session
356 .session
357 .build_workspace()
358 .await
359 .map_err(|e| SessionManagerError::CreationFailed {
360 message: format!("Failed to build workspace: {e}"),
361 })?,
362 ));
363 }
364 }
365
366 if let Some(session_info) = self.store.get_session(session_id).await? {
368 let session = session_info;
369 Ok(Some(session.build_workspace().await.map_err(|e| {
370 SessionManagerError::CreationFailed {
371 message: format!("Failed to build workspace: {e}"),
372 }
373 })?))
374 } else {
375 Ok(None)
376 }
377 }
378
379 pub async fn resume_session(
381 &self,
382 session_id: &str,
383 app_config: AppConfig,
384 ) -> Result<mpsc::Sender<AppCommand>> {
385 {
387 let sessions = self.active_sessions.read().await;
388 if let Some(managed_session) = sessions.get(session_id) {
389 debug!(session_id = %session_id, "Session already active");
390 return Ok(managed_session.command_tx.clone());
391 }
392 }
393
394 let session = match self
396 .store
397 .get_session(session_id)
398 .await
399 .map_err(SessionManagerError::Storage)?
400 {
401 Some(session) => session,
402 None => {
403 debug!(session_id = %session_id, "Session not found in store");
404 return Err(SessionManagerError::SessionNotActive {
405 session_id: session_id.to_string(),
406 }
407 .into());
408 }
409 };
410
411 info!(session_id = %session_id, "Resuming session from storage");
412
413 {
415 let sessions = self.active_sessions.read().await;
416 if sessions.len() >= self.config.max_concurrent_sessions {
417 warn!(
418 session_id = %session_id,
419 active_count = sessions.len(),
420 max_capacity = self.config.max_concurrent_sessions,
421 "At maximum session capacity for resume"
422 );
423 }
425 }
426
427 let conversation = Conversation {
429 messages: session.state.messages.clone(),
430 working_directory: session
431 .config
432 .workspace
433 .get_path()
434 .unwrap_or_default()
435 .into(),
436 active_message_id: session.state.active_message_id.clone(),
437 };
438
439 let managed_session = ManagedSession::new(
441 session.clone(),
442 app_config,
443 self.store.clone(),
444 self.config.default_model,
445 Some(conversation),
446 )
447 .await
448 .map_err(|e| SessionManagerError::CreationFailed {
449 message: format!("Failed to create managed session: {e}"),
450 })?;
451
452 let command_tx = managed_session.command_tx.clone();
454
455 if !session.state.messages.is_empty() || !session.state.approved_tools.is_empty() {
457 info!(
458 session_id = %session_id,
459 message_count = session.state.messages.len(),
460 tool_count = session.state.approved_tools.len(),
461 "Restoring conversation state"
462 );
463
464 command_tx
465 .send(AppCommand::RestoreConversation {
466 messages: session.state.messages.clone(),
467 approved_tools: session.state.approved_tools.clone(),
468 approved_bash_patterns: session.state.approved_bash_patterns.clone(),
469 active_message_id: session.state.active_message_id.clone(),
470 })
471 .await
472 .map_err(|_| SessionManagerError::CreationFailed {
473 message: "Failed to send restore command to App".to_string(),
474 })?;
475 }
476
477 {
479 let mut sessions = self.active_sessions.write().await;
480 sessions.insert(session_id.to_string(), managed_session);
481 }
482
483 let last_sequence = session.state.last_event_sequence;
485
486 let event = StreamEvent::SessionResumed {
488 session_id: session_id.to_string(),
489 event_offset: last_sequence,
490 };
491 self.emit_event(session_id.to_string(), event).await;
492
493 info!(session_id = %session_id, last_sequence = last_sequence, "Session resumed");
494 Ok(command_tx)
495 }
496
497 pub async fn suspend_session(&self, session_id: &str) -> Result<bool> {
499 let managed_session = {
500 let mut sessions = self.active_sessions.write().await;
501 sessions.remove(session_id)
502 };
503
504 let managed_session = match managed_session {
505 Some(session) => session,
506 None => {
507 debug!(session_id = %session_id, "Session not active, cannot suspend");
508 return Ok(false);
509 }
510 };
511
512 info!(session_id = %session_id, "Suspending session");
513
514 self.store.update_session(&managed_session.session).await?;
516
517 let event = StreamEvent::SessionSaved {
519 session_id: session_id.to_string(),
520 };
521 self.emit_event(session_id.to_string(), event).await;
522
523 info!(session_id = %session_id, "Session suspended and saved");
524 Ok(true)
525 }
526
527 pub async fn delete_session(&self, session_id: &str) -> Result<bool> {
529 {
531 let mut sessions = self.active_sessions.write().await;
532 sessions.remove(session_id);
533 }
534
535 self.store.delete_session(session_id).await?;
537
538 info!(session_id = %session_id, "Session deleted");
539 Ok(true)
540 }
541
542 pub async fn list_sessions(&self, filter: SessionFilter) -> Result<Vec<SessionInfo>> {
544 Ok(self.store.list_sessions(filter).await?)
545 }
546
547 pub async fn get_active_sessions(&self) -> Vec<String> {
549 let sessions = self.active_sessions.read().await;
550 sessions.keys().cloned().collect()
551 }
552
553 pub async fn is_session_active(&self, session_id: &str) -> bool {
555 let sessions = self.active_sessions.read().await;
556 sessions.contains_key(session_id)
557 }
558
559 pub async fn send_command(&self, session_id: &str, command: AppCommand) -> Result<()> {
561 let sessions = self.active_sessions.read().await;
562 if let Some(managed_session) = sessions.get(session_id) {
563 managed_session.command_tx.send(command).await.map_err(|_| {
564 Error::SessionManager(SessionManagerError::SessionNotActive {
565 session_id: session_id.to_string(),
566 })
567 })
568 } else {
569 Err(Error::SessionManager(
570 SessionManagerError::SessionNotActive {
571 session_id: session_id.to_string(),
572 },
573 ))
574 }
575 }
576
577 pub async fn update_session_state(
579 &self,
580 session_id: &str,
581 update_fn: impl FnOnce(&mut SessionState),
582 ) -> Result<()> {
583 {
584 let mut sessions = self.active_sessions.write().await;
585 if let Some(managed_session) = sessions.get_mut(session_id) {
586 managed_session.touch();
587 update_fn(&mut managed_session.session.state);
588 managed_session.session.update_timestamp();
589
590 if self.config.auto_persist {
592 self.store.update_session(&managed_session.session).await?;
593 }
594 } else {
595 return Err(SessionManagerError::SessionNotActive {
596 session_id: session_id.to_string(),
597 }
598 .into());
599 }
600 }
601
602 Ok(())
603 }
604
605 pub async fn emit_event(&self, session_id: String, event: StreamEvent) {
607 let sequence_num = match self.store.append_event(&session_id, &event).await {
609 Ok(seq) => seq,
610 Err(e) => {
611 error!(session_id = %session_id, error = %e, "Failed to persist event");
612 return;
613 }
614 };
615
616 if let Err(e) = self
618 .update_session_state(&session_id, |state| {
619 state.last_event_sequence = sequence_num;
620 })
621 .await
622 {
623 error!(session_id = %session_id, error = %e, "Failed to update session sequence number");
624 }
625 }
626
627 pub async fn cleanup_inactive_sessions(&self, max_idle_time: chrono::Duration) -> usize {
629 let mut to_suspend = Vec::new();
630
631 {
632 let sessions = self.active_sessions.read().await;
633 for (session_id, managed_session) in sessions.iter() {
634 if managed_session.is_inactive(max_idle_time) {
635 to_suspend.push(session_id.clone());
636 }
637 }
638 }
639
640 let mut suspended_count = 0;
641 for session_id in to_suspend {
642 if let Ok(true) = self.suspend_session(&session_id).await {
643 suspended_count += 1;
644 }
645 }
646
647 if suspended_count > 0 {
648 info!(
649 suspended_count = suspended_count,
650 "Cleaned up inactive sessions"
651 );
652 }
653
654 suspended_count
655 }
656
657 pub fn store(&self) -> &Arc<dyn SessionStore> {
659 &self.store
660 }
661
662 pub async fn increment_subscriber_count(&self, session_id: &str) -> Result<()> {
664 let mut sessions = self.active_sessions.write().await;
665 if let Some(managed_session) = sessions.get_mut(session_id) {
666 managed_session.subscriber_count += 1;
667 managed_session.touch();
668 debug!(
669 session_id = %session_id,
670 subscriber_count = managed_session.subscriber_count,
671 "Incremented subscriber count"
672 );
673 Ok(())
674 } else {
675 Err(SessionManagerError::SessionNotActive {
676 session_id: session_id.to_string(),
677 }
678 .into())
679 }
680 }
681
682 pub async fn decrement_subscriber_count(&self, session_id: &str) -> Result<()> {
684 let mut sessions = self.active_sessions.write().await;
685 if let Some(managed_session) = sessions.get_mut(session_id) {
686 managed_session.subscriber_count = managed_session.subscriber_count.saturating_sub(1);
687 managed_session.touch();
688 debug!(
689 session_id = %session_id,
690 subscriber_count = managed_session.subscriber_count,
691 "Decremented subscriber count"
692 );
693 Ok(())
694 } else {
695 debug!(session_id = %session_id, "Session not active when decrementing subscriber count");
697 Ok(())
698 }
699 }
700
701 pub async fn touch_session(&self, session_id: &str) -> Result<()> {
703 let mut sessions = self.active_sessions.write().await;
704 if let Some(managed_session) = sessions.get_mut(session_id) {
705 managed_session.touch();
706 Ok(())
707 } else {
708 Ok(())
710 }
711 }
712
713 pub async fn maybe_suspend_idle_session(&self, session_id: &str) -> Result<()> {
715 let should_suspend = {
717 let sessions = self.active_sessions.read().await;
718 if let Some(managed_session) = sessions.get(session_id) {
719 managed_session.subscriber_count == 0
720 } else {
721 false }
723 };
724
725 if should_suspend {
726 info!(session_id = %session_id, "No active subscribers, suspending session");
727 self.suspend_session(session_id).await?;
728 }
729
730 Ok(())
731 }
732
733 pub async fn get_session_state(
735 &self,
736 session_id: &str,
737 ) -> Result<Option<crate::session::SessionState>> {
738 info!("get_session_state called for session: {}", session_id);
739
740 match self.store.get_session(session_id).await {
743 Ok(Some(session)) => {
744 info!(
745 "Loaded session from store with {} messages",
746 session.state.messages.len()
747 );
748 Ok(Some(session.state))
749 }
750 Ok(None) => {
751 info!("Session not found in store: {}", session_id);
752 Ok(None)
753 }
754 Err(e) => {
755 error!("Error loading session from store: {}", e);
756 Err(SessionManagerError::Storage(e).into())
757 }
758 }
759 }
760
761 pub async fn get_mcp_statuses(
763 &self,
764 session_id: &str,
765 ) -> Result<Vec<crate::session::McpServerInfo>> {
766 {
768 let sessions = self.active_sessions.read().await;
769 if let Some(managed_session) = sessions.get(session_id) {
770 let servers: Vec<_> = managed_session
771 .session
772 .state
773 .mcp_servers
774 .values()
775 .cloned()
776 .collect();
777 return Ok(servers);
778 }
779 }
780
781 match self.store.get_session(session_id).await? {
783 Some(session) => {
784 let servers: Vec<_> = session.state.mcp_servers.values().cloned().collect();
785 Ok(servers)
786 }
787 None => Err(SessionManagerError::SessionNotActive {
788 session_id: session_id.to_string(),
789 }
790 .into()),
791 }
792 }
793}
794
795fn translate_app_event(app_event: AppEvent) -> Option<StreamEvent> {
797 match app_event {
798 AppEvent::MessageAdded { message, model } => Some(StreamEvent::MessageComplete {
799 message,
800 usage: None,
801 metadata: std::collections::HashMap::new(),
802 model,
803 }),
804
805 AppEvent::MessagePart { id, delta } => Some(StreamEvent::MessagePart {
806 content: delta,
807 message_id: id,
808 }),
809
810 AppEvent::ToolCallStarted {
811 name,
812 id,
813 parameters,
814 model,
815 } => {
816 let tool_call = ToolCall {
817 id: id.clone(),
818 name: name.clone(),
819 parameters,
820 };
821 Some(StreamEvent::ToolCallStarted {
822 tool_call,
823 metadata: std::collections::HashMap::new(),
824 model,
825 })
826 }
827
828 AppEvent::ToolCallCompleted {
829 name: _,
830 result,
831 id,
832 model,
833 } => Some(StreamEvent::ToolCallCompleted {
834 tool_call_id: id,
835 result,
836 metadata: std::collections::HashMap::new(),
837 model,
838 }),
839
840 AppEvent::ToolCallFailed {
841 name: _,
842 error,
843 id,
844 model,
845 } => Some(StreamEvent::ToolCallFailed {
846 tool_call_id: id,
847 error,
848 metadata: std::collections::HashMap::new(),
849 model,
850 }),
851
852 AppEvent::WorkspaceChanged => Some(StreamEvent::WorkspaceChanged),
853
854 AppEvent::WorkspaceFiles { files } => Some(StreamEvent::WorkspaceFiles {
855 files: files.clone(),
856 }),
857
858 AppEvent::Started { id, op } => Some(StreamEvent::OperationStarted {
859 operation_id: id,
860 operation: op,
861 }),
862 AppEvent::Finished { id, outcome } => Some(StreamEvent::OperationCompleted {
863 operation_id: id,
864 outcome,
865 }),
866 AppEvent::OperationCancelled { op_id, info } => {
867 let operation_id = op_id.unwrap_or_else(uuid::Uuid::new_v4);
869 Some(StreamEvent::OperationCancelled {
870 operation_id,
871 reason: info.to_string(), })
873 }
874
875 _ => None,
877 }
878}
879async fn update_session_state_for_event(
881 store: &Arc<dyn SessionStore>,
882 session_id: &str,
883 event: &StreamEvent,
884) -> Result<()> {
885 match event {
886 StreamEvent::MessageComplete { message, .. } => {
887 store.append_message(session_id, message).await?;
888
889 if let crate::app::conversation::MessageData::Tool {
891 tool_use_id,
892 result,
893 ..
894 } = &message.data
895 {
896 let stats = crate::session::ToolExecutionStats::success_typed(
897 serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
898 result.variant_name().to_string(),
899 0, );
901 let update = ToolCallUpdate::set_result(stats);
902 store.update_tool_call(tool_use_id, update).await?;
903 }
904 }
905 StreamEvent::ToolCallStarted { tool_call, .. } => {
906 store.create_tool_call(session_id, tool_call).await?;
907 }
908 StreamEvent::ToolCallCompleted {
909 tool_call_id,
910 result,
911 ..
912 } => {
913 let stats = crate::session::ToolExecutionStats::success_typed(
914 serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
915 result.variant_name().to_string(),
916 0,
917 );
918 let update = ToolCallUpdate::set_result(stats);
919 store.update_tool_call(tool_call_id, update).await?;
920
921 let messages = store.get_messages(session_id, None).await?;
924 let parent_id = messages.last().map(|m| m.id().to_string());
925
926 let tool_message = ConversationMessage {
927 data: crate::app::conversation::MessageData::Tool {
928 tool_use_id: tool_call_id.clone(),
929 result: result.clone(),
930 },
931 timestamp: std::time::SystemTime::now()
932 .duration_since(std::time::UNIX_EPOCH)
933 .expect("Time went backwards")
934 .as_secs(),
935 id: format!("tool_result_{tool_call_id}"),
936 parent_message_id: parent_id,
937 };
938 store.append_message(session_id, &tool_message).await?;
939 }
940 StreamEvent::ToolCallFailed {
941 tool_call_id,
942 error,
943 ..
944 } => {
945 let update = ToolCallUpdate::set_error(error.clone());
946 store.update_tool_call(tool_call_id, update).await?;
947
948 let messages = store.get_messages(session_id, None).await?;
951 let parent_id = messages.last().map(|m| m.id().to_string());
952
953 let tool_error = steer_tools::error::ToolError::Execution {
954 tool_name: "unknown".to_string(), message: error.clone(),
956 };
957 let tool_message = ConversationMessage {
958 data: MessageData::Tool {
959 tool_use_id: tool_call_id.clone(),
960 result: ToolResult::Error(tool_error),
961 },
962 timestamp: std::time::SystemTime::now()
963 .duration_since(std::time::UNIX_EPOCH)
964 .expect("Time went backwards")
965 .as_secs(),
966 id: format!("tool_result_{tool_call_id}"),
967 parent_message_id: parent_id,
968 };
969 store.append_message(session_id, &tool_message).await?;
970 }
971 _ => {}
973 }
974 Ok(())
975}
976
977#[cfg(test)]
978mod tests {
979 use super::*;
980 use crate::api::ToolCall;
981 use crate::app::MessageData;
982 use crate::app::conversation::{AssistantContent, Role, UserContent};
983 use crate::session::stores::sqlite::SqliteSessionStore;
984 use tempfile::TempDir;
985
986 async fn create_test_manager() -> (SessionManager, TempDir) {
987 let temp_dir = TempDir::new().unwrap();
988 let db_path = temp_dir.path().join("test.db");
989 let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());
990
991 let config = SessionManagerConfig {
992 max_concurrent_sessions: 100,
993 default_model: Model::default(),
994 auto_persist: true,
995 };
996 let manager = SessionManager::new(store, config);
997
998 (manager, temp_dir)
999 }
1000
1001 fn create_test_app_config() -> AppConfig {
1002 crate::test_utils::test_app_config()
1003 }
1004
1005 #[tokio::test]
1006 async fn test_create_and_resume_session() {
1007 let (manager, temp) = create_test_manager().await;
1008 let app_config = create_test_app_config();
1009
1010 let session_config = SessionConfig {
1012 workspace: crate::session::state::WorkspaceConfig::Local {
1013 path: temp.path().to_path_buf(),
1014 },
1015 tool_config: crate::session::SessionToolConfig::default(),
1016 system_prompt: None,
1017 metadata: std::collections::HashMap::new(),
1018 };
1019 let (session_id, _command_tx) = manager
1020 .create_session(session_config, app_config.clone())
1021 .await
1022 .unwrap();
1023 assert!(!session_id.is_empty());
1024 assert!(manager.is_session_active(&session_id).await);
1025
1026 assert!(manager.suspend_session(&session_id).await.unwrap());
1028 assert!(!manager.is_session_active(&session_id).await);
1029
1030 let _command_tx = manager
1032 .resume_session(&session_id, app_config)
1033 .await
1034 .unwrap();
1035 assert!(manager.is_session_active(&session_id).await);
1036 }
1037
1038 #[tokio::test]
1039 async fn test_session_cleanup() {
1040 let (manager, temp) = create_test_manager().await;
1041 let app_config = create_test_app_config();
1042
1043 let session_config = SessionConfig {
1045 workspace: crate::session::state::WorkspaceConfig::Local {
1046 path: temp.path().to_path_buf(),
1047 },
1048 tool_config: crate::session::SessionToolConfig::default(),
1049 system_prompt: None,
1050 metadata: std::collections::HashMap::new(),
1051 };
1052 let (session_id, _command_tx) = manager
1053 .create_session(session_config, app_config)
1054 .await
1055 .unwrap();
1056
1057 {
1059 let mut sessions = manager.active_sessions.write().await;
1060 if let Some(session) = sessions.get_mut(&session_id) {
1061 session.last_activity = chrono::Utc::now() - chrono::Duration::hours(2);
1062 session.subscriber_count = 0;
1063 }
1064 }
1065
1066 let cleaned = manager
1068 .cleanup_inactive_sessions(chrono::Duration::hours(1))
1069 .await;
1070 assert_eq!(cleaned, 1);
1071 assert!(!manager.is_session_active(&session_id).await);
1072 }
1073
1074 #[tokio::test]
1075 async fn test_capacity_rejection() {
1076 let temp_dir = TempDir::new().unwrap();
1077 let temp = tempfile::TempDir::new().unwrap();
1078 let db_path = temp_dir.path().join("test.db");
1079 let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());
1080
1081 let config = SessionManagerConfig {
1082 max_concurrent_sessions: 1, default_model: Model::default(),
1084 auto_persist: true,
1085 };
1086 let manager = SessionManager::new(store, config);
1087 let app_config = create_test_app_config();
1088
1089 let tool_config = crate::session::SessionToolConfig {
1091 approval_policy: crate::session::ToolApprovalPolicy::AlwaysAsk,
1092 ..Default::default()
1093 };
1094
1095 let session_config = SessionConfig {
1096 workspace: crate::session::state::WorkspaceConfig::Local {
1097 path: temp.path().to_path_buf(),
1098 },
1099 tool_config,
1100 system_prompt: None,
1101 metadata: std::collections::HashMap::new(),
1102 };
1103 let (session_id1, _command_tx) = manager
1104 .create_session(session_config.clone(), app_config.clone())
1105 .await
1106 .unwrap();
1107 assert!(!session_id1.is_empty());
1108
1109 let result = manager.create_session(session_config, app_config).await;
1111
1112 assert!(result.is_err());
1113 assert!(matches!(
1114 result,
1115 Err(crate::error::Error::SessionManager(
1116 SessionManagerError::CapacityExceeded { .. }
1117 ))
1118 ));
1119 match result.unwrap_err() {
1120 crate::error::Error::SessionManager(SessionManagerError::CapacityExceeded {
1121 current,
1122 max,
1123 }) => {
1124 assert_eq!(current, 1);
1125 assert_eq!(max, 1);
1126 }
1127 _ => unreachable!(),
1128 }
1129 }
1130
1131 #[tokio::test]
1132 async fn test_tool_result_persistence_on_resume() {
1133 let (manager, temp) = create_test_manager().await;
1134 let app_config = create_test_app_config();
1135
1136 let session_config = SessionConfig {
1138 workspace: crate::session::state::WorkspaceConfig::Local {
1139 path: temp.path().to_path_buf(),
1140 },
1141 tool_config: crate::session::SessionToolConfig::default(),
1142 system_prompt: None,
1143 metadata: std::collections::HashMap::new(),
1144 };
1145 let (session_id, _command_tx) = manager
1146 .create_session(session_config, app_config.clone())
1147 .await
1148 .unwrap();
1149
1150 let user_message = ConversationMessage {
1153 data: crate::app::conversation::MessageData::User {
1154 content: vec![UserContent::Text {
1155 text: "Read the file test.txt".to_string(),
1156 }],
1157 },
1158 timestamp: 123456789,
1159 id: "user_1".to_string(),
1160 parent_message_id: None,
1161 };
1162 manager
1163 .store
1164 .append_message(&session_id, &user_message)
1165 .await
1166 .unwrap();
1167
1168 let assistant_message = ConversationMessage {
1170 data: crate::app::conversation::MessageData::Assistant {
1171 content: vec![
1172 AssistantContent::Text {
1173 text: "I'll read that file for you.".to_string(),
1174 },
1175 AssistantContent::ToolCall {
1176 tool_call: ToolCall {
1177 id: "tool_call_1".to_string(),
1178 name: "read_file".to_string(),
1179 parameters: serde_json::json!({"path": "test.txt"}),
1180 },
1181 },
1182 ],
1183 },
1184 timestamp: 123456790,
1185 id: "assistant_1".to_string(),
1186 parent_message_id: Some("user_1".to_string()),
1187 };
1188 manager
1189 .store
1190 .append_message(&session_id, &assistant_message)
1191 .await
1192 .unwrap();
1193
1194 let tool_call = ToolCall {
1197 id: "tool_call_1".to_string(),
1198 name: "read_file".to_string(),
1199 parameters: serde_json::json!({"path": "test.txt"}),
1200 };
1201 manager
1202 .store
1203 .create_tool_call(&session_id, &tool_call)
1204 .await
1205 .unwrap();
1206
1207 let stats = crate::session::ToolExecutionStats::success_typed(
1209 serde_json::json!({
1210 "content": "File contents: Hello, world!",
1211 "file_path": "test.txt",
1212 "line_count": 1,
1213 "truncated": false
1214 }),
1215 "FileContent".to_string(),
1216 0,
1217 );
1218 let update = ToolCallUpdate::set_result(stats);
1219 manager
1220 .store
1221 .update_tool_call("tool_call_1", update)
1222 .await
1223 .unwrap();
1224
1225 let tool_message = ConversationMessage {
1227 data: MessageData::Tool {
1228 tool_use_id: "tool_call_1".to_string(),
1229 result: ToolResult::FileContent(steer_tools::result::FileContentResult {
1230 content: "File contents: Hello, world!".to_string(),
1231 file_path: "test.txt".to_string(),
1232 line_count: 1,
1233 truncated: false,
1234 }),
1235 },
1236 timestamp: 123456790,
1237 id: "tool_result_tool_call_1".to_string(),
1238 parent_message_id: Some("assistant_1".to_string()),
1239 };
1240 manager
1241 .store
1242 .append_message(&session_id, &tool_message)
1243 .await
1244 .unwrap();
1245
1246 let followup_message = ConversationMessage {
1248 data: crate::app::conversation::MessageData::Assistant {
1249 content: vec![AssistantContent::Text {
1250 text: "The file contains: Hello, world!".to_string(),
1251 }],
1252 },
1253 timestamp: 123456791,
1254 id: "assistant_2".to_string(),
1255 parent_message_id: Some("assistant_1".to_string()),
1256 };
1257 manager
1258 .store
1259 .append_message(&session_id, &followup_message)
1260 .await
1261 .unwrap();
1262
1263 manager.suspend_session(&session_id).await.unwrap();
1265
1266 let loaded_session = manager
1268 .store
1269 .get_session(&session_id)
1270 .await
1271 .unwrap()
1272 .unwrap();
1273
1274 assert_eq!(loaded_session.state.messages.len(), 4);
1276
1277 let tool_result_msg = &loaded_session.state.messages[2];
1279 assert_eq!(tool_result_msg.role(), Role::Tool);
1280
1281 assert!(matches!(
1283 &tool_result_msg.data,
1284 crate::app::conversation::MessageData::Tool { .. }
1285 ));
1286 if let crate::app::conversation::MessageData::Tool {
1287 tool_use_id,
1288 result,
1289 ..
1290 } = &tool_result_msg.data
1291 {
1292 assert_eq!(tool_use_id, "tool_call_1");
1293 assert!(matches!(
1294 result,
1295 crate::app::conversation::ToolResult::FileContent(_)
1296 ));
1297 match result {
1298 crate::app::conversation::ToolResult::FileContent(content) => {
1299 assert!(content.content.contains("Hello, world!"));
1300 }
1301 _ => unreachable!(),
1302 }
1303 } else {
1304 panic!("Expected Tool message");
1305 }
1306
1307 let _command_tx = manager
1309 .resume_session(&session_id, app_config)
1310 .await
1311 .unwrap();
1312
1313 }
1317
1318 #[tokio::test]
1319 async fn test_active_message_id_persistence() {
1320 let (manager, temp) = create_test_manager().await;
1321 let app_config = create_test_app_config();
1322
1323 let session_config = SessionConfig {
1325 workspace: crate::session::state::WorkspaceConfig::Local {
1326 path: temp.path().to_path_buf(),
1327 },
1328 tool_config: crate::session::SessionToolConfig::default(),
1329 system_prompt: None,
1330 metadata: std::collections::HashMap::new(),
1331 };
1332 let (session_id, _command_tx) = manager
1333 .create_session(session_config, app_config.clone())
1334 .await
1335 .unwrap();
1336
1337 let msg1 = ConversationMessage {
1339 data: crate::app::conversation::MessageData::User {
1340 content: vec![UserContent::Text {
1341 text: "Hello".to_string(),
1342 }],
1343 },
1344 timestamp: 1000,
1345 id: "msg1".to_string(),
1346 parent_message_id: None,
1347 };
1348
1349 let msg2 = ConversationMessage {
1350 data: crate::app::conversation::MessageData::Assistant {
1351 content: vec![AssistantContent::Text {
1352 text: "Hi there!".to_string(),
1353 }],
1354 },
1355 timestamp: 2000,
1356 id: "msg2".to_string(),
1357 parent_message_id: Some("msg1".to_string()),
1358 };
1359
1360 let msg1_edited = ConversationMessage {
1362 data: crate::app::conversation::MessageData::User {
1363 content: vec![UserContent::Text {
1364 text: "Goodbye".to_string(),
1365 }],
1366 },
1367 timestamp: 3000,
1368 id: "msg1_edited".to_string(),
1369 parent_message_id: None, };
1371
1372 manager
1374 .store
1375 .append_message(&session_id, &msg1)
1376 .await
1377 .unwrap();
1378 manager
1379 .store
1380 .append_message(&session_id, &msg2)
1381 .await
1382 .unwrap();
1383 manager
1384 .store
1385 .append_message(&session_id, &msg1_edited)
1386 .await
1387 .unwrap();
1388
1389 manager
1391 .store
1392 .update_active_message_id(&session_id, Some("msg1_edited"))
1393 .await
1394 .unwrap();
1395
1396 manager.suspend_session(&session_id).await.unwrap();
1398
1399 let loaded_session = manager
1401 .store
1402 .get_session(&session_id)
1403 .await
1404 .unwrap()
1405 .unwrap();
1406
1407 assert_eq!(
1409 loaded_session.state.active_message_id,
1410 Some("msg1_edited".to_string())
1411 );
1412
1413 assert_eq!(loaded_session.state.messages.len(), 3);
1415
1416 let edited_msg = loaded_session
1418 .state
1419 .messages
1420 .iter()
1421 .find(|m| m.id() == "msg1_edited")
1422 .expect("Edited message should exist");
1423
1424 match &edited_msg.data {
1425 crate::app::conversation::MessageData::User { content, .. } => {
1426 if let Some(UserContent::Text { text }) = content.first() {
1427 assert_eq!(text, "Goodbye");
1428 } else {
1429 panic!("Expected text content");
1430 }
1431 }
1432 _ => panic!("Expected user message"),
1433 }
1434
1435 let _ = manager
1437 .resume_session(&session_id, app_config)
1438 .await
1439 .unwrap();
1440
1441 let state = manager
1443 .get_session_state(&session_id)
1444 .await
1445 .unwrap()
1446 .unwrap();
1447 assert_eq!(state.active_message_id, Some("msg1_edited".to_string()));
1448 }
1449}