steer_core/session/
state.rs

1use crate::error::Result;
2use chrono::{DateTime, Utc};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::path::PathBuf;
7use std::sync::Arc;
8
9use crate::api::Model;
10use crate::app::{Message, MessageData};
11use crate::config::LlmConfigProvider;
12use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};
13use steer_tools::tools::read_only_workspace_tools;
14use steer_tools::{ToolCall, result::ToolResult};
15
/// State of an MCP server connection
///
/// Stored in `McpServerInfo::state`. `SessionConfig::build_registry` records
/// `Connecting` before it tries to construct the backend and then replaces it
/// with `Connected` or `Failed` depending on the outcome.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Currently attempting to connect
    Connecting,
    /// Successfully connected
    Connected {
        /// Names of tools available from this server
        /// (as reported by the backend's `supported_tools()` at connect time)
        tool_names: Vec<String>,
    },
    /// Failed to connect
    Failed {
        /// Error message describing the failure
        /// (the `to_string()` rendering of the construction error)
        error: String,
    },
}
32
/// Information about an MCP server
///
/// `SessionConfig::build_registry` produces one of these per configured MCP
/// backend, keyed by `server_name`, whether or not the connection succeeded.
/// The type carries no serde derives; the `SessionState::mcp_servers` field
/// that holds it is skipped during (de)serialization.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// The configured server name
    pub server_name: String,
    /// The transport configuration
    pub transport: McpTransport,
    /// Current connection state
    pub state: McpConnectionState,
    /// Timestamp when this state was last updated
    pub last_updated: DateTime<Utc>,
}
45
/// Container runtime to use
///
/// Used by `WorkspaceConfig::Container`. Variant names are (de)serialized in
/// snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ContainerRuntime {
    /// Serialized as `"docker"`.
    Docker,
    /// Serialized as `"podman"`.
    Podman,
}
53
/// Defines the primary execution environment for a session's workspace
///
/// Serialized as an internally tagged enum: the variant is carried in a
/// `"type"` field whose value is the snake_case variant name (`"local"`,
/// `"remote"`, `"container"`), alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// Workspace on the local filesystem.
    Local {
        /// Filesystem path handed to the local workspace.
        path: PathBuf,
    },
    /// Workspace served by a remote agent.
    Remote {
        /// Address of the remote agent; forwarded as `address` by
        /// `to_workspace_config`.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
    /// Container-based workspace.
    ///
    /// NOTE: `to_workspace_config` does not support this variant yet and
    /// substitutes a local workspace at the current directory.
    Container {
        /// Container image name.
        image: String,
        /// Runtime used to run the container.
        runtime: ContainerRuntime,
    },
}
70
71impl WorkspaceConfig {
72    pub fn get_path(&self) -> Option<String> {
73        match self {
74            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
75            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
76            WorkspaceConfig::Container { .. } => None,
77        }
78    }
79
80    /// Convert to steer_workspace::WorkspaceConfig
81    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
82        match self {
83            WorkspaceConfig::Local { path } => {
84                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
85            }
86            WorkspaceConfig::Remote {
87                agent_address,
88                auth,
89            } => steer_workspace::WorkspaceConfig::Remote {
90                address: agent_address.clone(),
91                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
92            },
93            WorkspaceConfig::Container { .. } => {
94                // For now, container workspaces are not implemented in steer_workspace
95                steer_workspace::WorkspaceConfig::Local {
96                    path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
97                }
98            }
99        }
100    }
101}
102
103impl Default for WorkspaceConfig {
104    fn default() -> Self {
105        Self::Local {
106            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
107        }
108    }
109}
110
/// Complete session representation
///
/// Bundles the identity and timestamps of a session with its configuration
/// (`config`) and its mutable runtime state (`state`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Identifier supplied by the caller of `Session::new`.
    pub id: String,
    /// Set once, when the session is constructed.
    pub created_at: DateTime<Utc>,
    /// Initially equal to `created_at`; refreshed by `update_timestamp`.
    pub updated_at: DateTime<Utc>,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Mutable state; starts out as `SessionState::default()`.
    pub state: SessionState,
}
120
121impl Session {
122    pub fn new(id: String, config: SessionConfig) -> Self {
123        let now = Utc::now();
124        Self {
125            id,
126            created_at: now,
127            updated_at: now,
128            config,
129            state: SessionState::default(),
130        }
131    }
132
133    pub fn update_timestamp(&mut self) {
134        self.updated_at = Utc::now();
135    }
136
137    /// Check if session has any recent activity
138    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
139        let cutoff = Utc::now() - threshold;
140        self.updated_at > cutoff
141    }
142
143    /// Build a workspace from this session's configuration
144    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
145        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
146    }
147}
148
/// Session configuration - immutable once created
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Where the session's workspace lives (local, remote or container).
    pub workspace: WorkspaceConfig,
    /// Tool backends, visibility and approval settings for the session.
    pub tool_config: SessionToolConfig,
    /// Optional custom system prompt to use for the session. If `None`, Steer will
    /// fall back to its built-in default prompt.
    pub system_prompt: Option<String>,
    /// Free-form key/value metadata; copied into `SessionInfo::metadata` when a
    /// `SessionInfo` is built from a `Session`.
    pub metadata: HashMap<String, String>,
}
159
160impl SessionConfig {
161    /// Build a BackendRegistry from this configuration for external tools only.
162    /// Workspace tools are now handled directly by the Workspace.
163    /// Returns the registry and a map of MCP server connection states.
164    pub async fn build_registry(
165        &self,
166        llm_config_provider: Arc<LlmConfigProvider>,
167        workspace: Arc<dyn crate::workspace::Workspace>,
168    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
169        let mut registry = BackendRegistry::new();
170        let mut mcp_servers = HashMap::new();
171
172        // 1. Register all USER-DEFINED backends first.
173        // Their tool mappings may be overwritten by the more authoritative backends below.
174        for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
175            match backend_config {
176                BackendConfig::Local { tool_filter } => {
177                    let backend = match tool_filter {
178                        ToolFilter::All => {
179                            LocalBackend::full(llm_config_provider.clone(), workspace.clone())
180                        }
181                        ToolFilter::Include(tools) => LocalBackend::with_tools(
182                            tools.clone(),
183                            llm_config_provider.clone(),
184                            workspace.clone(),
185                        ),
186                        ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
187                            excluded.clone(),
188                            llm_config_provider.clone(),
189                            workspace.clone(),
190                        ),
191                    };
192                    registry
193                        .register(format!("user_local_{idx}"), Arc::new(backend))
194                        .await;
195                }
196                BackendConfig::Mcp {
197                    server_name,
198                    transport,
199                    tool_filter,
200                } => {
201                    tracing::info!(
202                        "Attempting to initialize MCP backend '{}' with transport: {:?}",
203                        server_name,
204                        transport
205                    );
206
207                    // Record that we're attempting to connect
208                    let mut server_info = McpServerInfo {
209                        server_name: server_name.clone(),
210                        transport: transport.clone(),
211                        state: McpConnectionState::Connecting,
212                        last_updated: Utc::now(),
213                    };
214
215                    match crate::tools::McpBackend::new(
216                        server_name.clone(),
217                        transport.clone(),
218                        tool_filter.clone(),
219                    )
220                    .await
221                    {
222                        Ok(mcp_backend) => {
223                            let tool_names = mcp_backend.supported_tools().await;
224                            let tool_count = tool_names.len();
225                            tracing::info!(
226                                "Successfully initialized MCP backend '{}' with {} tools",
227                                server_name,
228                                tool_count
229                            );
230                            server_info.state = McpConnectionState::Connected { tool_names };
231                            server_info.last_updated = Utc::now();
232                            registry
233                                .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
234                                .await;
235                        }
236                        Err(e) => {
237                            tracing::error!(
238                                "Failed to initialize MCP backend '{}': {}",
239                                server_name,
240                                e
241                            );
242                            server_info.state = McpConnectionState::Failed {
243                                error: e.to_string(),
244                            };
245                            server_info.last_updated = Utc::now();
246                        }
247                    }
248
249                    mcp_servers.insert(server_name.clone(), server_info);
250                }
251            }
252        }
253
254        // 2. Register SERVER tools (like dispatch_agent and web_fetch).
255        // These are external tools, not workspace tools.
256        let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
257        if !server_backend.supported_tools().await.is_empty() {
258            registry
259                .register("server".to_string(), Arc::new(server_backend))
260                .await;
261        }
262
263        // Note: Workspace tools are handled directly by the Workspace implementation.
264
265        Ok((registry, mcp_servers))
266    }
267
268    /// Filter tools based on visibility settings
269    pub fn filter_tools_by_visibility(
270        &self,
271        tools: Vec<steer_tools::ToolSchema>,
272    ) -> Vec<steer_tools::ToolSchema> {
273        match &self.tool_config.visibility {
274            ToolVisibility::All => tools,
275            ToolVisibility::ReadOnly => {
276                let read_only_names: HashSet<String> = read_only_workspace_tools()
277                    .iter()
278                    .map(|t| t.name().to_string())
279                    .collect();
280
281                tools
282                    .into_iter()
283                    .filter(|schema| read_only_names.contains(&schema.name))
284                    .collect()
285            }
286            ToolVisibility::Whitelist(allowed) => tools
287                .into_iter()
288                .filter(|schema| allowed.contains(&schema.name))
289                .collect(),
290            ToolVisibility::Blacklist(blocked) => tools
291                .into_iter()
292                .filter(|schema| !blocked.contains(&schema.name))
293                .collect(),
294        }
295    }
296
297    /// Minimal read-only configuration
298    pub fn read_only() -> Self {
299        Self {
300            workspace: WorkspaceConfig::Local {
301                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
302            },
303            tool_config: SessionToolConfig::read_only(),
304            system_prompt: None,
305            metadata: HashMap::new(),
306        }
307    }
308}
309
310/// Tool visibility configuration - controls which tools are shown to the AI agent
311#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
312#[serde(tag = "type", rename_all = "snake_case")]
313pub enum ToolVisibility {
314    /// Show all registered tools to the AI
315    All,
316
317    /// Only show read-only tools to the AI
318    ReadOnly,
319
320    /// Show only specific tools to the AI (whitelist)
321    Whitelist(HashSet<String>),
322
323    /// Hide specific tools from the AI (blacklist)
324    Blacklist(HashSet<String>),
325}
326
327impl Default for ToolVisibility {
328    fn default() -> Self {
329        Self::All
330    }
331}
332
/// Tool-specific configuration for bash
///
/// Wrapped by `ToolSpecificConfig::Bash`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct BashToolConfig {
    /// Command patterns that are pre-approved for execution
    ///
    /// Optional when deserializing; a missing field becomes an empty list.
    #[serde(default)]
    pub approved_patterns: Vec<String>,
}
340
/// Tool approval policy configuration
///
/// Serialized as an internally tagged enum: `"type"` holds the snake_case
/// variant name (`"always_ask"`, `"pre_approved"`, `"mixed"`) next to the
/// variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolApprovalPolicy {
    /// Always ask for approval before executing any tool
    AlwaysAsk,

    /// Pre-approved tools execute without asking
    ///
    /// Any tool not in `tools` still requires approval.
    PreApproved { tools: HashSet<String> },

    /// Mixed policy: some tools pre-approved, others require approval
    Mixed {
        /// Tools that are approved without asking.
        pre_approved: HashSet<String>,
        /// Whether tools outside `pre_approved` trigger an approval request.
        /// When `false`, such tools are neither approved nor asked about.
        ask_for_others: bool,
    },
}
357
358impl ToolApprovalPolicy {
359    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
360        match self {
361            ToolApprovalPolicy::AlwaysAsk => false,
362            ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
363            ToolApprovalPolicy::Mixed {
364                pre_approved,
365                ask_for_others: _,
366            } => pre_approved.contains(tool_name),
367        }
368    }
369
370    pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
371        match self {
372            ToolApprovalPolicy::AlwaysAsk => true,
373            ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
374            ToolApprovalPolicy::Mixed {
375                pre_approved,
376                ask_for_others,
377            } => !pre_approved.contains(tool_name) && *ask_for_others,
378        }
379    }
380}
381
/// Authentication configuration for remote backends
///
/// No serde renaming or tagging attributes are applied, so serde's default
/// externally tagged form with the variant names as written is used.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer-token authentication; maps to `steer_workspace::RemoteAuth::BearerToken`.
    Bearer { token: String },
    /// API-key authentication; maps to `steer_workspace::RemoteAuth::ApiKey`.
    ApiKey { key: String },
}
388
389impl RemoteAuth {
390    /// Convert to steer_workspace RemoteAuth type
391    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
392        match self {
393            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
394            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
395        }
396    }
397}
398
/// Tool filtering configuration for backends
///
/// Selects which of a backend's tools are exposed. For local backends,
/// `SessionConfig::build_registry` maps the variants to `LocalBackend::full`,
/// `LocalBackend::with_tools` and `LocalBackend::without_tools` respectively.
/// Variant names are (de)serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolFilter {
    /// Include all available tools
    All,
    /// Include only the specified tools
    Include(Vec<String>),
    /// Include all tools except the specified ones
    Exclude(Vec<String>),
}
410
411impl Default for ToolFilter {
412    fn default() -> Self {
413        Self::All
414    }
415}
416
/// Backend configuration for different tool execution environments
///
/// Serialized as an internally tagged enum: `"type"` is `"local"` or `"mcp"`,
/// next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// A local backend; registered by `SessionConfig::build_registry` under the
    /// name `user_local_{idx}`, where `idx` is its position in the backend list.
    Local {
        /// Tool filtering configuration for the local backend
        tool_filter: ToolFilter,
    },
    /// An MCP server backend; registered under the name `mcp_{server_name}`
    /// when it initializes successfully.
    Mcp {
        /// Name identifying the server; also the key of its entry in the MCP
        /// server status map.
        server_name: String,
        /// How to reach the server.
        transport: McpTransport,
        /// Tool filtering configuration passed to the MCP backend.
        tool_filter: ToolFilter,
    },
}
431
/// Tool configuration for the session
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// Backend configurations for this session
    ///
    /// Registered in order by `SessionConfig::build_registry`, ahead of the
    /// built-in server backend.
    pub backends: Vec<BackendConfig>,
    /// Tool visibility - controls which tools are shown to the AI agent
    pub visibility: ToolVisibility,
    /// Tool approval policy - controls when user approval is needed
    pub approval_policy: ToolApprovalPolicy,
    /// Additional metadata for tool configuration
    pub metadata: HashMap<String, String>,
    /// Tool-specific configurations
    ///
    /// Optional when deserializing; a missing field becomes an empty map.
    #[serde(default)]
    pub tools: HashMap<String, ToolSpecificConfig>,
}
447
/// Tool-specific configurations
///
/// Untagged: the serialized form is just the inner configuration, with no
/// variant marker.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum ToolSpecificConfig {
    /// Configuration for the bash tool
    Bash(BashToolConfig),
}
455
456impl Default for SessionToolConfig {
457    fn default() -> Self {
458        Self {
459            backends: Vec::new(),
460            visibility: ToolVisibility::All,
461            approval_policy: ToolApprovalPolicy::AlwaysAsk,
462            metadata: HashMap::new(),
463            tools: HashMap::new(),
464        }
465    }
466}
467
468impl SessionToolConfig {
469    /// Minimal read-only configuration
470    pub fn read_only() -> Self {
471        Self {
472            backends: Vec::new(), // Use default backends
473            visibility: ToolVisibility::ReadOnly,
474            approval_policy: ToolApprovalPolicy::AlwaysAsk,
475            metadata: HashMap::new(),
476            tools: HashMap::new(),
477        }
478    }
479}
480
/// Mutable session state that changes during execution
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation messages
    pub messages: Vec<Message>,

    /// Tool call tracking
    ///
    /// Keyed by tool call id.
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tools that have been approved for this session
    pub approved_tools: HashSet<String>,

    /// Bash commands that have been approved for this session (dynamically added)
    ///
    /// Optional when deserializing; a missing field becomes an empty set.
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Last processed event sequence number for replay
    ///
    /// NOTE(review): `apply_event` as written in this file does not update
    /// this field; presumably the caller maintains it — confirm.
    pub last_event_sequence: u64,

    /// Additional runtime metadata
    pub metadata: HashMap<String, String>,

    /// The ID of the currently active message (head of selected branch)
    /// None means use last message semantics for backward compatibility
    ///
    /// Omitted from the serialized form when `None`, and optional when
    /// deserializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Status of MCP server connections
    /// This is a transient field that is rebuilt on session activation
    ///
    /// Never serialized or deserialized; always starts out empty after a load.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
513
514impl SessionState {
515    /// Add a message to the conversation
516    pub fn add_message(&mut self, message: Message) {
517        self.messages.push(message);
518    }
519
520    /// Get the number of messages in the conversation
521    pub fn message_count(&self) -> usize {
522        self.messages.len()
523    }
524
525    /// Get the last message in the conversation
526    pub fn last_message(&self) -> Option<&Message> {
527        self.messages.last()
528    }
529
530    /// Add a tool call to tracking
531    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
532        let state = ToolCallState {
533            tool_call: tool_call.clone(),
534            status: ToolCallStatus::PendingApproval,
535            started_at: None,
536            completed_at: None,
537            result: None,
538        };
539        self.tool_calls.insert(tool_call.id, state);
540    }
541
542    /// Update tool call status
543    pub fn update_tool_call_status(
544        &mut self,
545        tool_call_id: &str,
546        status: ToolCallStatus,
547    ) -> std::result::Result<(), String> {
548        let tool_call = self
549            .tool_calls
550            .get_mut(tool_call_id)
551            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
552
553        // Update timestamps based on status changes
554        match (&tool_call.status, &status) {
555            (_, ToolCallStatus::Executing) => {
556                tool_call.started_at = Some(Utc::now());
557            }
558            (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
559                tool_call.completed_at = Some(Utc::now());
560            }
561            _ => {}
562        }
563
564        tool_call.status = status;
565        Ok(())
566    }
567
568    /// Approve a tool for future use
569    pub fn approve_tool(&mut self, tool_name: String) {
570        self.approved_tools.insert(tool_name);
571    }
572
573    /// Check if a tool is approved
574    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
575        self.approved_tools.contains(tool_name)
576    }
577
578    /// Validate internal consistency
579    pub fn validate(&self) -> std::result::Result<(), String> {
580        // Check that all tool calls referenced in messages exist
581        for message in &self.messages {
582            let tool_calls = self.extract_tool_calls_from_message(message);
583            if !tool_calls.is_empty() {
584                for tool_call_id in tool_calls {
585                    if !self.tool_calls.contains_key(&tool_call_id) {
586                        return Err(format!(
587                            "Message references unknown tool call: {tool_call_id}"
588                        ));
589                    }
590                }
591            }
592        }
593
594        Ok(())
595    }
596
597    /// Extract tool call IDs from a message
598    fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
599        let mut tool_call_ids = Vec::new();
600
601        match &message.data {
602            MessageData::Assistant { content, .. } => {
603                for c in content {
604                    if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
605                        tool_call_ids.push(tool_call.id.clone());
606                    }
607                }
608            }
609            MessageData::Tool { tool_use_id, .. } => {
610                tool_call_ids.push(tool_use_id.clone());
611            }
612            _ => {}
613        }
614
615        tool_call_ids
616    }
617
618    /// Apply an event to the session state
619    pub fn apply_event(
620        &mut self,
621        event: &crate::events::StreamEvent,
622    ) -> std::result::Result<(), String> {
623        use crate::events::StreamEvent;
624
625        match event {
626            StreamEvent::MessageComplete { message, .. } => {
627                self.add_message(message.clone());
628            }
629            StreamEvent::ToolCallStarted { tool_call, .. } => {
630                self.add_tool_call(tool_call.clone());
631            }
632            StreamEvent::ToolCallCompleted {
633                tool_call_id,
634                result,
635                ..
636            } => {
637                self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
638                if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
639                    tool_call_state.result = Some(result.clone());
640                }
641            }
642            StreamEvent::ToolCallFailed {
643                tool_call_id,
644                error,
645                ..
646            } => {
647                self.update_tool_call_status(
648                    tool_call_id,
649                    ToolCallStatus::Failed {
650                        error: error.clone(),
651                    },
652                )?;
653            }
654            StreamEvent::ToolApprovalRequired { tool_call, .. } => {
655                // Tool call should already be added with PendingApproval status
656                if !self.tool_calls.contains_key(&tool_call.id) {
657                    self.add_tool_call(tool_call.clone());
658                }
659            }
660            // Other events don't modify state directly
661            _ => {}
662        }
663
664        Ok(())
665    }
666}
667
/// Tool call state tracking
///
/// One entry per tool call in `SessionState::tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tool call being tracked.
    pub tool_call: ToolCall,
    /// Current status; starts as `PendingApproval`.
    pub status: ToolCallStatus,
    /// Set when the status changes to `Executing`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set when the status changes to `Completed` or `Failed`.
    pub completed_at: Option<DateTime<Utc>>,
    /// Result stored when a `ToolCallCompleted` event is applied.
    pub result: Option<ToolResult>,
}
677
678impl ToolCallState {
679    pub fn is_pending(&self) -> bool {
680        matches!(self.status, ToolCallStatus::PendingApproval)
681    }
682
683    pub fn is_complete(&self) -> bool {
684        matches!(
685            self.status,
686            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
687        )
688    }
689
690    pub fn duration(&self) -> Option<chrono::Duration> {
691        match (self.started_at, self.completed_at) {
692            (Some(start), Some(end)) => Some(end - start),
693            _ => None,
694        }
695    }
696}
697
/// Tool call execution status
///
/// Serialized as an internally tagged enum: `"status"` holds the snake_case
/// variant name, next to the variant's own fields (only `Failed` has any).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Initial status assigned by `SessionState::add_tool_call`.
    PendingApproval,
    /// Approved but not (yet) executing.
    Approved,
    /// Denied; treated as terminal by `is_terminal`.
    Denied,
    /// Running; entering this status records `started_at`.
    Executing,
    /// Finished successfully; entering this status records `completed_at`.
    Completed,
    /// Finished with an error; entering this status records `completed_at`.
    Failed { error: String },
}
709
710impl ToolCallStatus {
711    pub fn is_terminal(&self) -> bool {
712        matches!(
713            self,
714            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
715        )
716    }
717}
718
/// Tool execution statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Legacy string output; for failures this holds the error text.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>, // Legacy string output
    /// Typed JSON output. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>, // Typed JSON output
    /// Type name of the JSON output (e.g., "SearchResult").
    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
    /// Whether the execution succeeded.
    pub success: bool,
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Extra key/value pairs; see `with_metadata`.
    pub metadata: HashMap<String, String>,
}
731
732impl ToolExecutionStats {
733    pub fn success(output: String, execution_time_ms: u64) -> Self {
734        Self {
735            output: Some(output),
736            json_output: None,
737            result_type: None,
738            success: true,
739            execution_time_ms,
740            metadata: HashMap::new(),
741        }
742    }
743
744    pub fn success_typed(
745        json_output: serde_json::Value,
746        result_type: String,
747        execution_time_ms: u64,
748    ) -> Self {
749        Self {
750            output: None,
751            json_output: Some(json_output),
752            result_type: Some(result_type),
753            success: true,
754            execution_time_ms,
755            metadata: HashMap::new(),
756        }
757    }
758
759    pub fn failure(error: String, execution_time_ms: u64) -> Self {
760        Self {
761            output: Some(error),
762            json_output: None,
763            result_type: None,
764            success: false,
765            execution_time_ms,
766            metadata: HashMap::new(),
767        }
768    }
769
770    pub fn with_metadata(mut self, key: String, value: String) -> Self {
771        self.metadata.insert(key, value);
772        self
773    }
774}
775
/// Session metadata for listing and filtering
///
/// A lightweight summary of a `Session`; see the `From<&Session>` impl for how
/// each field is derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last updated.
    pub updated_at: DateTime<Utc>,
    /// The last known model used in this session
    ///
    /// Currently always `None` when built from a `Session` (not tracked yet).
    pub last_model: Option<Model>,
    /// Number of messages in the session's conversation.
    pub message_count: usize,
    /// Copy of the session configuration's metadata.
    pub metadata: HashMap<String, String>,
}
787
788impl From<&Session> for SessionInfo {
789    fn from(session: &Session) -> Self {
790        Self {
791            id: session.id.clone(),
792            created_at: session.created_at,
793            updated_at: session.updated_at,
794            last_model: None, // TODO: Track last model used from events
795            message_count: session.state.message_count(),
796            metadata: session.config.metadata.clone(),
797        }
798    }
799}
800
801#[cfg(test)]
802mod tests {
803    use super::*;
804    use crate::app::conversation::{Message, MessageData, UserContent};
805    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
806
807    #[test]
808    fn test_session_creation() {
809        let config = SessionConfig {
810            workspace: WorkspaceConfig::Local {
811                path: PathBuf::from("/test/path"),
812            },
813            tool_config: SessionToolConfig::default(),
814            system_prompt: None,
815            metadata: HashMap::new(),
816        };
817        let session = Session::new("test-session".to_string(), config.clone());
818
819        assert_eq!(session.id, "test-session");
820        assert!(
821            session
822                .config
823                .tool_config
824                .approval_policy
825                .should_ask_for_approval("any_tool")
826        );
827        assert_eq!(session.state.message_count(), 0);
828    }
829
830    #[test]
831    fn test_tool_approval_policy() {
832        let policy = ToolApprovalPolicy::PreApproved {
833            tools: ["read_file", "list_files"]
834                .iter()
835                .map(|s| s.to_string())
836                .collect(),
837        };
838
839        assert!(policy.is_tool_approved("read_file"));
840        assert!(!policy.is_tool_approved("write_file"));
841        assert!(!policy.should_ask_for_approval("read_file"));
842        assert!(policy.should_ask_for_approval("write_file"));
843    }
844
845    #[test]
846    fn test_session_state_validation() {
847        let mut state = SessionState::default();
848
849        // Valid empty state
850        assert!(state.validate().is_ok());
851
852        // Add a message
853        let message = Message {
854            data: MessageData::User {
855                content: vec![UserContent::Text {
856                    text: "Hello".to_string(),
857                }],
858            },
859            timestamp: 123456789,
860            id: "msg1".to_string(),
861            parent_message_id: None,
862        };
863        state.add_message(message);
864
865        assert!(state.validate().is_ok());
866        assert_eq!(state.message_count(), 1);
867    }
868
869    #[test]
870    fn test_tool_call_state_tracking() {
871        let mut state = SessionState::default();
872
873        let tool_call = ToolCall {
874            id: "tool1".to_string(),
875            name: "read_file".to_string(),
876            parameters: serde_json::json!({"path": "/test.txt"}),
877        };
878
879        state.add_tool_call(tool_call.clone());
880        assert!(state.tool_calls.get("tool1").unwrap().is_pending());
881
882        state
883            .update_tool_call_status("tool1", ToolCallStatus::Executing)
884            .unwrap();
885        let tool_state = state.tool_calls.get("tool1").unwrap();
886        assert!(tool_state.started_at.is_some());
887        assert!(!tool_state.is_complete());
888
889        state
890            .update_tool_call_status("tool1", ToolCallStatus::Completed)
891            .unwrap();
892        let tool_state = state.tool_calls.get("tool1").unwrap();
893        assert!(tool_state.completed_at.is_some());
894        assert!(tool_state.is_complete());
895    }
896
897    #[test]
898    fn test_session_tool_config_default() {
899        let config = SessionToolConfig::default();
900        assert!(config.backends.is_empty());
901    }
902
903    #[test]
904    fn test_tool_filter_exclude() {
905        // Test that we can exclude specific tools
906        let config = SessionToolConfig {
907            backends: vec![BackendConfig::Local {
908                tool_filter: ToolFilter::Exclude(vec![
909                    BASH_TOOL_NAME.to_string(),
910                    EDIT_TOOL_NAME.to_string(),
911                ]),
912            }],
913            visibility: ToolVisibility::All,
914            approval_policy: ToolApprovalPolicy::AlwaysAsk,
915            metadata: HashMap::new(),
916            tools: HashMap::new(),
917        };
918
919        assert!(matches!(config.backends[0], BackendConfig::Local { .. }));
920        if let BackendConfig::Local { tool_filter } = &config.backends[0] {
921            assert!(matches!(tool_filter, ToolFilter::Exclude(_)));
922            if let ToolFilter::Exclude(excluded_tools) = tool_filter {
923                assert_eq!(excluded_tools.len(), 2);
924                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
925                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
926            }
927        }
928    }
929
930    #[test]
931    fn test_session_tool_config_read_only() {
932        let config = SessionToolConfig::read_only();
933        assert_eq!(config.backends.len(), 0); // Empty backends means use defaults
934        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
935        assert!(matches!(
936            config.approval_policy,
937            ToolApprovalPolicy::AlwaysAsk
938        ));
939    }
940
941    #[tokio::test]
942    async fn test_session_config_build_registry_server_tools() {
943        use crate::auth::DefaultAuthStorage;
944        use crate::config::LlmConfigProvider;
945
946        // Test that server tools are properly registered
947        let config = SessionConfig {
948            workspace: WorkspaceConfig::Local {
949                path: PathBuf::from("/test/path"),
950            },
951            tool_config: SessionToolConfig::default(),
952            system_prompt: None,
953            metadata: HashMap::new(),
954        };
955
956        // For tests, we'll just unwrap since it's a test environment
957        let auth_storage =
958            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
959        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));
960
961        // Create a test workspace
962        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
963            .await
964            .unwrap();
965
966        let (registry, _mcp_servers) = config
967            .build_registry(llm_config_provider, workspace)
968            .await
969            .unwrap();
970        let schemas = registry.get_tool_schemas().await;
971        let tool_names: Vec<String> = schemas.iter().map(|s| s.name.clone()).collect();
972
973        // Only server tools should be in the registry
974        assert!(tool_names.contains(&"dispatch_agent".to_string()));
975        assert!(tool_names.contains(&"web_fetch".to_string()));
976
977        // Verify workspace tools are NOT in the registry (they're handled by Workspace)
978        let workspace_tool_names = vec!["bash", "grep", "glob", "ls", "read", "write", "edit"];
979        for tool_name in workspace_tool_names {
980            assert!(
981                !tool_names.contains(&tool_name.to_string()),
982                "Workspace tool {tool_name} should not be in registry"
983            );
984        }
985    }
986
987    // Test removed: workspace tools are no longer in the registry
988
989    // Test removed: tool visibility filtering for workspace tools happens at the Workspace level
990
991    // Test removed: workspace backend no longer exists in the registry
992
993    #[test]
994    fn test_mcp_status_tracking() {
995        // Test that MCP server info is properly tracked in session state
996        let mut session_state = SessionState::default();
997
998        // Add some MCP server info
999        let mcp_info = McpServerInfo {
1000            server_name: "test-server".to_string(),
1001            transport: crate::tools::McpTransport::Stdio {
1002                command: "python".to_string(),
1003                args: vec!["-m".to_string(), "test_server".to_string()],
1004            },
1005            state: McpConnectionState::Connected {
1006                tool_names: vec![
1007                    "tool1".to_string(),
1008                    "tool2".to_string(),
1009                    "tool3".to_string(),
1010                    "tool4".to_string(),
1011                    "tool5".to_string(),
1012                ],
1013            },
1014            last_updated: Utc::now(),
1015        };
1016
1017        session_state
1018            .mcp_servers
1019            .insert("test-server".to_string(), mcp_info.clone());
1020
1021        // Verify it's stored
1022        assert_eq!(session_state.mcp_servers.len(), 1);
1023        let stored = session_state.mcp_servers.get("test-server").unwrap();
1024        assert_eq!(stored.server_name, "test-server");
1025        assert!(matches!(
1026            stored.state,
1027            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
1028        ));
1029
1030        // Test failed connection
1031        let failed_info = McpServerInfo {
1032            server_name: "failed-server".to_string(),
1033            transport: crate::tools::McpTransport::Tcp {
1034                host: "localhost".to_string(),
1035                port: 9999,
1036            },
1037            state: McpConnectionState::Failed {
1038                error: "Connection refused".to_string(),
1039            },
1040            last_updated: Utc::now(),
1041        };
1042
1043        session_state
1044            .mcp_servers
1045            .insert("failed-server".to_string(), failed_info);
1046        assert_eq!(session_state.mcp_servers.len(), 2);
1047    }
1048
1049    #[tokio::test]
1050    async fn test_mcp_server_tracking_in_build_registry() {
1051        use crate::auth::DefaultAuthStorage;
1052        use crate::config::LlmConfigProvider;
1053
1054        // Create a session config with both good and bad MCP servers
1055        let mut config = SessionConfig::read_only();
1056
1057        // This one should fail (invalid transport)
1058        config.tool_config.backends.push(BackendConfig::Mcp {
1059            server_name: "bad-server".to_string(),
1060            transport: crate::tools::McpTransport::Tcp {
1061                host: "nonexistent.invalid".to_string(),
1062                port: 12345,
1063            },
1064            tool_filter: ToolFilter::All,
1065        });
1066
1067        // This one would succeed if we had a real server running
1068        config.tool_config.backends.push(BackendConfig::Mcp {
1069            server_name: "good-server".to_string(),
1070            transport: crate::tools::McpTransport::Stdio {
1071                command: "echo".to_string(),
1072                args: vec!["test".to_string()],
1073            },
1074            tool_filter: ToolFilter::All,
1075        });
1076
1077        let auth_storage =
1078            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
1079        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));
1080        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
1081            .await
1082            .unwrap();
1083
1084        let (_registry, mcp_servers) = config
1085            .build_registry(llm_config_provider, workspace)
1086            .await
1087            .unwrap();
1088
1089        // Should have tracked both servers
1090        assert_eq!(mcp_servers.len(), 2);
1091
1092        // Check the bad server
1093        let bad_server = mcp_servers.get("bad-server").unwrap();
1094        assert_eq!(bad_server.server_name, "bad-server");
1095        assert!(matches!(
1096            bad_server.state,
1097            McpConnectionState::Failed { .. }
1098        ));
1099
1100        // Check the good server (will also fail in tests since echo isn't an MCP server)
1101        let good_server = mcp_servers.get("good-server").unwrap();
1102        assert_eq!(good_server.server_name, "good-server");
1103        assert!(matches!(
1104            good_server.state,
1105            McpConnectionState::Failed { .. }
1106        ));
1107    }
1108
1109    #[test]
1110    fn test_backend_config_variants() {
1111        // Test Local variant
1112        let local_config = BackendConfig::Local {
1113            tool_filter: ToolFilter::Include(vec![
1114                VIEW_TOOL_NAME.to_string(),
1115                LS_TOOL_NAME.to_string(),
1116            ]),
1117        };
1118
1119        assert!(matches!(local_config, BackendConfig::Local { .. }));
1120        if let BackendConfig::Local { tool_filter } = local_config {
1121            assert!(matches!(tool_filter, ToolFilter::Include(_)));
1122            if let ToolFilter::Include(tools) = tool_filter {
1123                assert_eq!(tools.len(), 2);
1124            }
1125        }
1126
1127        // Container backend has been removed - tests removed
1128
1129        // Test Mcp variant
1130        let mcp_config = BackendConfig::Mcp {
1131            server_name: "test-mcp".to_string(),
1132            transport: crate::tools::McpTransport::Stdio {
1133                command: "python".to_string(),
1134                args: vec!["-m".to_string(), "test_server".to_string()],
1135            },
1136            tool_filter: ToolFilter::All,
1137        };
1138
1139        assert!(matches!(mcp_config, BackendConfig::Mcp { .. }));
1140        if let BackendConfig::Mcp {
1141            server_name,
1142            transport,
1143            ..
1144        } = mcp_config
1145        {
1146            assert_eq!(server_name, "test-mcp");
1147            assert!(matches!(
1148                transport,
1149                crate::tools::McpTransport::Stdio { .. }
1150            ));
1151            if let crate::tools::McpTransport::Stdio { command, args } = transport {
1152                assert_eq!(command, "python");
1153                assert_eq!(args.len(), 2);
1154            }
1155        }
1156    }
1157}