steer_core/session/
state.rs

1use crate::config::model::ModelId;
2use crate::error::Result;
3use chrono::{DateTime, Utc};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use crate::app::{Message, MessageData};
11use crate::config::LlmConfigProvider;
12use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};
13use steer_tools::tools::read_only_workspace_tools;
14use steer_tools::{ToolCall, result::ToolResult};
15
/// Lifecycle of a single MCP server connection attempt.
#[derive(Clone, Debug)]
pub enum McpConnectionState {
    /// A connection attempt is in progress.
    Connecting,
    /// The server was reached and its backend was created.
    Connected {
        /// Names of the tools the backend reports as supported.
        tool_names: Vec<String>,
    },
    /// The connection attempt did not succeed.
    Failed {
        /// Human-readable reason for the failure.
        error: String,
    },
}
32
33/// Information about an MCP server
34#[derive(Debug, Clone)]
35pub struct McpServerInfo {
36    /// The configured server name
37    pub server_name: String,
38    /// The transport configuration
39    pub transport: McpTransport,
40    /// Current connection state
41    pub state: McpConnectionState,
42    /// Timestamp when this state was last updated
43    pub last_updated: DateTime<Utc>,
44}
45
/// Defines the primary execution environment for a session's workspace
// Serialized internally tagged on "type" with snake_case tags, i.e.
// `{"type": "local", "path": ...}` or
// `{"type": "remote", "agent_address": ..., "auth": ...}`.
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    // Workspace rooted at a directory on this machine.
    Local {
        path: PathBuf,
    },
    // Workspace reached through the agent at `agent_address`; `auth` is optional and is
    // converted via `RemoteAuth::to_workspace_auth` when building the workspace config.
    Remote {
        agent_address: String,
        auth: Option<RemoteAuth>,
    },
}
58
59impl WorkspaceConfig {
60    pub fn get_path(&self) -> Option<String> {
61        match self {
62            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
63            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
64        }
65    }
66
67    /// Convert to steer_workspace::WorkspaceConfig
68    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
69        match self {
70            WorkspaceConfig::Local { path } => {
71                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
72            }
73            WorkspaceConfig::Remote {
74                agent_address,
75                auth,
76            } => steer_workspace::WorkspaceConfig::Remote {
77                address: agent_address.clone(),
78                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
79            },
80        }
81    }
82}
83
84impl Default for WorkspaceConfig {
85    fn default() -> Self {
86        Self::Local {
87            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
88        }
89    }
90}
91
/// Complete session representation
///
/// Pairs the immutable [`SessionConfig`] with the mutable [`SessionState`], plus the
/// session id and its creation / last-update timestamps (UTC).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// When the session was created (set by `Session::new`).
    pub created_at: DateTime<Utc>,
    /// When the session was last touched; refreshed by `Session::update_timestamp`.
    pub updated_at: DateTime<Utc>,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Runtime state: messages, tool calls, approvals, etc.
    pub state: SessionState,
}
101
102impl Session {
103    pub fn new(id: String, config: SessionConfig) -> Self {
104        let now = Utc::now();
105        Self {
106            id,
107            created_at: now,
108            updated_at: now,
109            config,
110            state: SessionState::default(),
111        }
112    }
113
114    pub fn update_timestamp(&mut self) {
115        self.updated_at = Utc::now();
116    }
117
118    /// Check if session has any recent activity
119    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
120        let cutoff = Utc::now() - threshold;
121        self.updated_at > cutoff
122    }
123
124    /// Build a workspace from this session's configuration
125    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
126        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
127    }
128}
129
/// Session configuration - immutable once created
// The `///` docs in this type are kept verbatim because the JsonSchema derive turns them
// into schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    // Where the session's workspace lives (local directory or remote agent).
    pub workspace: WorkspaceConfig,
    // Tool backends, visibility and approval settings for this session.
    pub tool_config: SessionToolConfig,
    /// Optional custom system prompt to use for the session. If `None`, Steer will
    /// fall back to its built-in default prompt.
    pub system_prompt: Option<String>,
    // Free-form key/value pairs; copied into `SessionInfo::metadata` when sessions are listed.
    pub metadata: HashMap<String, String>,
}
140
141impl SessionConfig {
142    /// Build a BackendRegistry from this configuration for external tools only.
143    /// Workspace tools are now handled directly by the Workspace.
144    /// Returns the registry and a map of MCP server connection states.
145    pub async fn build_registry(
146        &self,
147        llm_config_provider: Arc<LlmConfigProvider>,
148        workspace: Arc<dyn crate::workspace::Workspace>,
149    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
150        let mut registry = BackendRegistry::new();
151        let mut mcp_servers = HashMap::new();
152
153        // 1. Register all USER-DEFINED backends first.
154        // Their tool mappings may be overwritten by the more authoritative backends below.
155        for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
156            match backend_config {
157                BackendConfig::Local { tool_filter } => {
158                    let backend = match tool_filter {
159                        ToolFilter::All => {
160                            LocalBackend::full(llm_config_provider.clone(), workspace.clone())
161                        }
162                        ToolFilter::Include(tools) => LocalBackend::with_tools(
163                            tools.clone(),
164                            llm_config_provider.clone(),
165                            workspace.clone(),
166                        ),
167                        ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
168                            excluded.clone(),
169                            llm_config_provider.clone(),
170                            workspace.clone(),
171                        ),
172                    };
173                    registry
174                        .register(format!("user_local_{idx}"), Arc::new(backend))
175                        .await;
176                }
177                BackendConfig::Mcp {
178                    server_name,
179                    transport,
180                    tool_filter,
181                } => {
182                    tracing::info!(
183                        "Attempting to initialize MCP backend '{}' with transport: {:?}",
184                        server_name,
185                        transport
186                    );
187
188                    // Record that we're attempting to connect
189                    let mut server_info = McpServerInfo {
190                        server_name: server_name.clone(),
191                        transport: transport.clone(),
192                        state: McpConnectionState::Connecting,
193                        last_updated: Utc::now(),
194                    };
195
196                    match crate::tools::McpBackend::new(
197                        server_name.clone(),
198                        transport.clone(),
199                        tool_filter.clone(),
200                    )
201                    .await
202                    {
203                        Ok(mcp_backend) => {
204                            let tool_names = mcp_backend.supported_tools().await;
205                            let tool_count = tool_names.len();
206                            tracing::info!(
207                                "Successfully initialized MCP backend '{}' with {} tools",
208                                server_name,
209                                tool_count
210                            );
211                            server_info.state = McpConnectionState::Connected { tool_names };
212                            server_info.last_updated = Utc::now();
213                            registry
214                                .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
215                                .await;
216                        }
217                        Err(e) => {
218                            tracing::error!(
219                                "Failed to initialize MCP backend '{}': {}",
220                                server_name,
221                                e
222                            );
223                            server_info.state = McpConnectionState::Failed {
224                                error: e.to_string(),
225                            };
226                            server_info.last_updated = Utc::now();
227                        }
228                    }
229
230                    mcp_servers.insert(server_name.clone(), server_info);
231                }
232            }
233        }
234
235        // 2. Register SERVER tools (like dispatch_agent and web_fetch).
236        // These are external tools, not workspace tools.
237        let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
238        if !server_backend.supported_tools().await.is_empty() {
239            registry
240                .register("server".to_string(), Arc::new(server_backend))
241                .await;
242        }
243
244        // Note: Workspace tools are handled directly by the Workspace implementation.
245
246        Ok((registry, mcp_servers))
247    }
248
249    /// Filter tools based on visibility settings
250    pub fn filter_tools_by_visibility(
251        &self,
252        tools: Vec<steer_tools::ToolSchema>,
253    ) -> Vec<steer_tools::ToolSchema> {
254        match &self.tool_config.visibility {
255            ToolVisibility::All => tools,
256            ToolVisibility::ReadOnly => {
257                let read_only_names: HashSet<String> = read_only_workspace_tools()
258                    .iter()
259                    .map(|t| t.name().to_string())
260                    .collect();
261
262                tools
263                    .into_iter()
264                    .filter(|schema| read_only_names.contains(&schema.name))
265                    .collect()
266            }
267            ToolVisibility::Whitelist(allowed) => tools
268                .into_iter()
269                .filter(|schema| allowed.contains(&schema.name))
270                .collect(),
271            ToolVisibility::Blacklist(blocked) => tools
272                .into_iter()
273                .filter(|schema| !blocked.contains(&schema.name))
274                .collect(),
275        }
276    }
277
278    /// Minimal read-only configuration
279    pub fn read_only() -> Self {
280        Self {
281            workspace: WorkspaceConfig::Local {
282                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
283            },
284            tool_config: SessionToolConfig::read_only(),
285            system_prompt: None,
286            metadata: HashMap::new(),
287        }
288    }
289}
290
291/// Tool visibility configuration - controls which tools are shown to the AI agent
292#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
293#[serde(tag = "type", rename_all = "snake_case")]
294pub enum ToolVisibility {
295    /// Show all registered tools to the AI
296    All,
297
298    /// Only show read-only tools to the AI
299    ReadOnly,
300
301    /// Show only specific tools to the AI (whitelist)
302    Whitelist(HashSet<String>),
303
304    /// Hide specific tools from the AI (blacklist)
305    Blacklist(HashSet<String>),
306}
307
308impl Default for ToolVisibility {
309    fn default() -> Self {
310        Self::All
311    }
312}
313
/// Tool-specific configuration for bash
// Stored in `SessionToolConfig::tools` via `ToolSpecificConfig::Bash`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct BashToolConfig {
    /// Command patterns that are pre-approved for execution
    // `#[serde(default)]`: an omitted list deserializes as empty.
    // NOTE(review): the pattern syntax (glob / regex / prefix) is interpreted by the bash
    // tool, not here — confirm before documenting it.
    #[serde(default)]
    pub approved_patterns: Vec<String>,
}
321
/// Tool approval policy configuration
// Serialized internally tagged on "type" with snake_case tags:
// `{"type": "always_ask"}`, `{"type": "pre_approved", "tools": [...]}`,
// `{"type": "mixed", "pre_approved": [...], "ask_for_others": bool}`.
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolApprovalPolicy {
    /// Always ask for approval before executing any tool
    AlwaysAsk,

    /// Pre-approved tools execute without asking
    PreApproved { tools: HashSet<String> },

    /// Mixed policy: some tools pre-approved, others require approval
    // `ask_for_others` is consulted only by `should_ask_for_approval`;
    // `is_tool_approved` ignores it.
    Mixed {
        pre_approved: HashSet<String>,
        ask_for_others: bool,
    },
}
338
339impl ToolApprovalPolicy {
340    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
341        match self {
342            ToolApprovalPolicy::AlwaysAsk => false,
343            ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
344            ToolApprovalPolicy::Mixed {
345                pre_approved,
346                ask_for_others: _,
347            } => pre_approved.contains(tool_name),
348        }
349    }
350
351    pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
352        match self {
353            ToolApprovalPolicy::AlwaysAsk => true,
354            ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
355            ToolApprovalPolicy::Mixed {
356                pre_approved,
357                ask_for_others,
358            } => !pre_approved.contains(tool_name) && *ask_for_others,
359        }
360    }
361}
362
363/// Authentication configuration for remote backends
364#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
365pub enum RemoteAuth {
366    Bearer { token: String },
367    ApiKey { key: String },
368}
369
370impl RemoteAuth {
371    /// Convert to steer_workspace RemoteAuth type
372    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
373        match self {
374            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
375            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
376        }
377    }
378}
379
/// Tool filtering configuration for backends
// No `tag` attribute, so serde uses the externally tagged form with snake_case names:
// `"all"`, `{"include": [...]}`, `{"exclude": [...]}`.
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolFilter {
    /// Include all available tools
    All,
    /// Include only the specified tools
    Include(Vec<String>),
    /// Include all tools except the specified ones
    Exclude(Vec<String>),
}
391
392impl Default for ToolFilter {
393    fn default() -> Self {
394        Self::All
395    }
396}
397
/// Backend configuration for different tool execution environments
// Serialized internally tagged on "type": `{"type": "local", ...}` / `{"type": "mcp", ...}`.
// `SessionConfig::build_registry` registers `Local` entries as `user_local_{index}` and
// successfully connected `Mcp` entries as `mcp_{server_name}`.
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    Local {
        /// Tool filtering configuration for the local backend
        tool_filter: ToolFilter,
    },
    // An external MCP server; `server_name` also keys its status in the returned
    // MCP server map.
    Mcp {
        server_name: String,
        transport: McpTransport,
        tool_filter: ToolFilter,
    },
}
412
/// Tool configuration for the session
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// Backend configurations for this session
    pub backends: Vec<BackendConfig>,
    /// Tool visibility - controls which tools are shown to the AI agent
    pub visibility: ToolVisibility,
    /// Tool approval policy - controls when user approval is needed
    pub approval_policy: ToolApprovalPolicy,
    /// Additional metadata for tool configuration
    // NOTE(review): unlike `tools`, this field has no `#[serde(default)]`, so it must be
    // present when deserializing.
    pub metadata: HashMap<String, String>,
    /// Tool-specific configurations
    // Keyed by tool name (presumably, e.g. the bash tool's name — confirm with callers).
    #[serde(default)]
    pub tools: HashMap<String, ToolSpecificConfig>,
}
428
/// Tool-specific configurations
// `untagged`: serde tries each variant in order and keeps the first that deserializes.
// NOTE(review): `BashToolConfig` has a single defaulted field and no
// `deny_unknown_fields`, so any JSON object appears to match `Bash` — revisit if more
// variants are added.
// The `///` docs here are kept verbatim because the JsonSchema derive turns them into
// schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum ToolSpecificConfig {
    /// Configuration for the bash tool
    Bash(BashToolConfig),
}
436
437impl Default for SessionToolConfig {
438    fn default() -> Self {
439        Self {
440            backends: Vec::new(),
441            visibility: ToolVisibility::All,
442            approval_policy: ToolApprovalPolicy::AlwaysAsk,
443            metadata: HashMap::new(),
444            tools: HashMap::new(),
445        }
446    }
447}
448
449impl SessionToolConfig {
450    /// Minimal read-only configuration
451    pub fn read_only() -> Self {
452        Self {
453            backends: Vec::new(), // Use default backends
454            visibility: ToolVisibility::ReadOnly,
455            approval_policy: ToolApprovalPolicy::AlwaysAsk,
456            metadata: HashMap::new(),
457            tools: HashMap::new(),
458        }
459    }
460}
461
/// Mutable session state that changes during execution
///
/// Updated incrementally via `SessionState::apply_event` and the helper methods on this
/// type; `Default` yields an empty state.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation messages
    /// (appended in order by `add_message`).
    pub messages: Vec<Message>,

    /// Tool call tracking
    /// (keyed by the tool call's `id`, as inserted by `add_tool_call`).
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tools that have been approved for this session
    pub approved_tools: HashSet<String>,

    /// Bash commands that have been approved for this session (dynamically added)
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Last processed event sequence number for replay
    pub last_event_sequence: u64,

    /// Additional runtime metadata
    pub metadata: HashMap<String, String>,

    /// The ID of the currently active message (head of selected branch)
    /// None means use last message semantics for backward compatibility
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Status of MCP server connections
    /// This is a transient field that is rebuilt on session activation
    /// (never serialized; always empty after deserialization).
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
494
495impl SessionState {
496    /// Add a message to the conversation
497    pub fn add_message(&mut self, message: Message) {
498        self.messages.push(message);
499    }
500
501    /// Get the number of messages in the conversation
502    pub fn message_count(&self) -> usize {
503        self.messages.len()
504    }
505
506    /// Get the last message in the conversation
507    pub fn last_message(&self) -> Option<&Message> {
508        self.messages.last()
509    }
510
511    /// Add a tool call to tracking
512    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
513        let state = ToolCallState {
514            tool_call: tool_call.clone(),
515            status: ToolCallStatus::PendingApproval,
516            started_at: None,
517            completed_at: None,
518            result: None,
519        };
520        self.tool_calls.insert(tool_call.id, state);
521    }
522
523    /// Update tool call status
524    pub fn update_tool_call_status(
525        &mut self,
526        tool_call_id: &str,
527        status: ToolCallStatus,
528    ) -> std::result::Result<(), String> {
529        let tool_call = self
530            .tool_calls
531            .get_mut(tool_call_id)
532            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
533
534        // Update timestamps based on status changes
535        match (&tool_call.status, &status) {
536            (_, ToolCallStatus::Executing) => {
537                tool_call.started_at = Some(Utc::now());
538            }
539            (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
540                tool_call.completed_at = Some(Utc::now());
541            }
542            _ => {}
543        }
544
545        tool_call.status = status;
546        Ok(())
547    }
548
549    /// Approve a tool for future use
550    pub fn approve_tool(&mut self, tool_name: String) {
551        self.approved_tools.insert(tool_name);
552    }
553
554    /// Check if a tool is approved
555    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
556        self.approved_tools.contains(tool_name)
557    }
558
559    /// Validate internal consistency
560    pub fn validate(&self) -> std::result::Result<(), String> {
561        // Check that all tool calls referenced in messages exist
562        for message in &self.messages {
563            let tool_calls = self.extract_tool_calls_from_message(message);
564            if !tool_calls.is_empty() {
565                for tool_call_id in tool_calls {
566                    if !self.tool_calls.contains_key(&tool_call_id) {
567                        return Err(format!(
568                            "Message references unknown tool call: {tool_call_id}"
569                        ));
570                    }
571                }
572            }
573        }
574
575        Ok(())
576    }
577
578    /// Extract tool call IDs from a message
579    fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
580        let mut tool_call_ids = Vec::new();
581
582        match &message.data {
583            MessageData::Assistant { content, .. } => {
584                for c in content {
585                    if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
586                        tool_call_ids.push(tool_call.id.clone());
587                    }
588                }
589            }
590            MessageData::Tool { tool_use_id, .. } => {
591                tool_call_ids.push(tool_use_id.clone());
592            }
593            _ => {}
594        }
595
596        tool_call_ids
597    }
598
599    /// Apply an event to the session state
600    pub fn apply_event(
601        &mut self,
602        event: &crate::events::StreamEvent,
603    ) -> std::result::Result<(), String> {
604        use crate::events::StreamEvent;
605
606        match event {
607            StreamEvent::MessageComplete { message, .. } => {
608                self.add_message(message.clone());
609            }
610            StreamEvent::ToolCallStarted { tool_call, .. } => {
611                self.add_tool_call(tool_call.clone());
612            }
613            StreamEvent::ToolCallCompleted {
614                tool_call_id,
615                result,
616                ..
617            } => {
618                self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
619                if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
620                    tool_call_state.result = Some(result.clone());
621                }
622            }
623            StreamEvent::ToolCallFailed {
624                tool_call_id,
625                error,
626                ..
627            } => {
628                self.update_tool_call_status(
629                    tool_call_id,
630                    ToolCallStatus::Failed {
631                        error: error.clone(),
632                    },
633                )?;
634            }
635            StreamEvent::ToolApprovalRequired { tool_call, .. } => {
636                // Tool call should already be added with PendingApproval status
637                if !self.tool_calls.contains_key(&tool_call.id) {
638                    self.add_tool_call(tool_call.clone());
639                }
640            }
641            // Other events don't modify state directly
642            _ => {}
643        }
644
645        Ok(())
646    }
647}
648
/// Tool call state tracking
///
/// One entry per tool call, stored in `SessionState::tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The call as issued (name, parameters, id).
    pub tool_call: ToolCall,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Set when the status moves to `Executing`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set when the status moves to `Completed` or `Failed`.
    pub completed_at: Option<DateTime<Utc>>,
    /// Result recorded when a `ToolCallCompleted` event is applied.
    pub result: Option<ToolResult>,
}
658
659impl ToolCallState {
660    pub fn is_pending(&self) -> bool {
661        matches!(self.status, ToolCallStatus::PendingApproval)
662    }
663
664    pub fn is_complete(&self) -> bool {
665        matches!(
666            self.status,
667            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
668        )
669    }
670
671    pub fn duration(&self) -> Option<chrono::Duration> {
672        match (self.started_at, self.completed_at) {
673            (Some(start), Some(end)) => Some(end - start),
674            _ => None,
675        }
676    }
677}
678
/// Tool call execution status
///
/// Serialized internally tagged on "status" with snake_case names, e.g.
/// `{"status": "pending_approval"}` or `{"status": "failed", "error": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Tracked but not yet approved (initial state set by `add_tool_call`).
    PendingApproval,
    /// Approved to run.
    Approved,
    /// Rejected; terminal.
    Denied,
    /// Currently running.
    Executing,
    /// Finished successfully; terminal.
    Completed,
    /// Finished with an error; terminal.
    Failed { error: String },
}
690
691impl ToolCallStatus {
692    pub fn is_terminal(&self) -> bool {
693        matches!(
694            self,
695            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
696        )
697    }
698}
699
/// Tool execution statistics
///
/// Either plain-text `output` or typed `json_output` (+ `result_type`) is populated by
/// the constructors; failures store their error text in `output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>, // Legacy string output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>, // Typed JSON output
    // NOTE(review): unlike the two fields above, `result_type` has no
    // `skip_serializing_if`, so `None` is serialized as `null`.
    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
    pub success: bool,
    pub execution_time_ms: u64,
    pub metadata: HashMap<String, String>,
}
712
713impl ToolExecutionStats {
714    pub fn success(output: String, execution_time_ms: u64) -> Self {
715        Self {
716            output: Some(output),
717            json_output: None,
718            result_type: None,
719            success: true,
720            execution_time_ms,
721            metadata: HashMap::new(),
722        }
723    }
724
725    pub fn success_typed(
726        json_output: serde_json::Value,
727        result_type: String,
728        execution_time_ms: u64,
729    ) -> Self {
730        Self {
731            output: None,
732            json_output: Some(json_output),
733            result_type: Some(result_type),
734            success: true,
735            execution_time_ms,
736            metadata: HashMap::new(),
737        }
738    }
739
740    pub fn failure(error: String, execution_time_ms: u64) -> Self {
741        Self {
742            output: Some(error),
743            json_output: None,
744            result_type: None,
745            success: false,
746            execution_time_ms,
747            metadata: HashMap::new(),
748        }
749    }
750
751    pub fn with_metadata(mut self, key: String, value: String) -> Self {
752        self.metadata.insert(key, value);
753        self
754    }
755}
756
/// Session metadata for listing and filtering
///
/// A lightweight summary of a [`Session`]; see the `From<&Session>` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// Copied from `Session::created_at`.
    pub created_at: DateTime<Utc>,
    /// Copied from `Session::updated_at`.
    pub updated_at: DateTime<Utc>,
    /// The last known model used in this session
    /// (currently always `None` when built via `From<&Session>`).
    pub last_model: Option<ModelId>,
    /// Number of messages in the session's conversation.
    pub message_count: usize,
    /// Copied from `SessionConfig::metadata`.
    pub metadata: HashMap<String, String>,
}
768
769impl From<&Session> for SessionInfo {
770    fn from(session: &Session) -> Self {
771        Self {
772            id: session.id.clone(),
773            created_at: session.created_at,
774            updated_at: session.updated_at,
775            last_model: None, // TODO: Track last model used from events
776            message_count: session.state.message_count(),
777            metadata: session.config.metadata.clone(),
778        }
779    }
780}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785    use crate::app::conversation::{Message, MessageData, UserContent};
786    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
787
788    #[test]
789    fn test_session_creation() {
790        let config = SessionConfig {
791            workspace: WorkspaceConfig::Local {
792                path: PathBuf::from("/test/path"),
793            },
794            tool_config: SessionToolConfig::default(),
795            system_prompt: None,
796            metadata: HashMap::new(),
797        };
798        let session = Session::new("test-session".to_string(), config.clone());
799
800        assert_eq!(session.id, "test-session");
801        assert!(
802            session
803                .config
804                .tool_config
805                .approval_policy
806                .should_ask_for_approval("any_tool")
807        );
808        assert_eq!(session.state.message_count(), 0);
809    }
810
811    #[test]
812    fn test_tool_approval_policy() {
813        let policy = ToolApprovalPolicy::PreApproved {
814            tools: ["read_file", "list_files"]
815                .iter()
816                .map(|s| s.to_string())
817                .collect(),
818        };
819
820        assert!(policy.is_tool_approved("read_file"));
821        assert!(!policy.is_tool_approved("write_file"));
822        assert!(!policy.should_ask_for_approval("read_file"));
823        assert!(policy.should_ask_for_approval("write_file"));
824    }
825
826    #[test]
827    fn test_session_state_validation() {
828        let mut state = SessionState::default();
829
830        // Valid empty state
831        assert!(state.validate().is_ok());
832
833        // Add a message
834        let message = Message {
835            data: MessageData::User {
836                content: vec![UserContent::Text {
837                    text: "Hello".to_string(),
838                }],
839            },
840            timestamp: 123456789,
841            id: "msg1".to_string(),
842            parent_message_id: None,
843        };
844        state.add_message(message);
845
846        assert!(state.validate().is_ok());
847        assert_eq!(state.message_count(), 1);
848    }
849
850    #[test]
851    fn test_tool_call_state_tracking() {
852        let mut state = SessionState::default();
853
854        let tool_call = ToolCall {
855            id: "tool1".to_string(),
856            name: "read_file".to_string(),
857            parameters: serde_json::json!({"path": "/test.txt"}),
858        };
859
860        state.add_tool_call(tool_call.clone());
861        assert!(state.tool_calls.get("tool1").unwrap().is_pending());
862
863        state
864            .update_tool_call_status("tool1", ToolCallStatus::Executing)
865            .unwrap();
866        let tool_state = state.tool_calls.get("tool1").unwrap();
867        assert!(tool_state.started_at.is_some());
868        assert!(!tool_state.is_complete());
869
870        state
871            .update_tool_call_status("tool1", ToolCallStatus::Completed)
872            .unwrap();
873        let tool_state = state.tool_calls.get("tool1").unwrap();
874        assert!(tool_state.completed_at.is_some());
875        assert!(tool_state.is_complete());
876    }
877
878    #[test]
879    fn test_session_tool_config_default() {
880        let config = SessionToolConfig::default();
881        assert!(config.backends.is_empty());
882    }
883
884    #[test]
885    fn test_tool_filter_exclude() {
886        // Test that we can exclude specific tools
887        let config = SessionToolConfig {
888            backends: vec![BackendConfig::Local {
889                tool_filter: ToolFilter::Exclude(vec![
890                    BASH_TOOL_NAME.to_string(),
891                    EDIT_TOOL_NAME.to_string(),
892                ]),
893            }],
894            visibility: ToolVisibility::All,
895            approval_policy: ToolApprovalPolicy::AlwaysAsk,
896            metadata: HashMap::new(),
897            tools: HashMap::new(),
898        };
899
900        assert!(matches!(config.backends[0], BackendConfig::Local { .. }));
901        if let BackendConfig::Local { tool_filter } = &config.backends[0] {
902            assert!(matches!(tool_filter, ToolFilter::Exclude(_)));
903            if let ToolFilter::Exclude(excluded_tools) = tool_filter {
904                assert_eq!(excluded_tools.len(), 2);
905                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
906                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
907            }
908        }
909    }
910
911    #[test]
912    fn test_session_tool_config_read_only() {
913        let config = SessionToolConfig::read_only();
914        assert_eq!(config.backends.len(), 0); // Empty backends means use defaults
915        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
916        assert!(matches!(
917            config.approval_policy,
918            ToolApprovalPolicy::AlwaysAsk
919        ));
920    }
921
922    #[tokio::test]
923    async fn test_session_config_build_registry_server_tools() {
924        use crate::auth::DefaultAuthStorage;
925        use crate::config::LlmConfigProvider;
926
927        // Test that server tools are properly registered
928        let config = SessionConfig {
929            workspace: WorkspaceConfig::Local {
930                path: PathBuf::from("/test/path"),
931            },
932            tool_config: SessionToolConfig::default(),
933            system_prompt: None,
934            metadata: HashMap::new(),
935        };
936
937        // For tests, we'll just unwrap since it's a test environment
938        let auth_storage =
939            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
940        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));
941
942        // Create a test workspace
943        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
944            .await
945            .unwrap();
946
947        let (registry, _mcp_servers) = config
948            .build_registry(llm_config_provider, workspace)
949            .await
950            .unwrap();
951        let schemas = registry.get_tool_schemas().await;
952        let tool_names: Vec<String> = schemas.iter().map(|s| s.name.clone()).collect();
953
954        // Only server tools should be in the registry
955        assert!(tool_names.contains(&"dispatch_agent".to_string()));
956        assert!(tool_names.contains(&"web_fetch".to_string()));
957
958        // Verify workspace tools are NOT in the registry (they're handled by Workspace)
959        let workspace_tool_names = vec!["bash", "grep", "glob", "ls", "read", "write", "edit"];
960        for tool_name in workspace_tool_names {
961            assert!(
962                !tool_names.contains(&tool_name.to_string()),
963                "Workspace tool {tool_name} should not be in registry"
964            );
965        }
966    }
967
968    // Test removed: workspace tools are no longer in the registry
969
970    // Test removed: tool visibility filtering for workspace tools happens at the Workspace level
971
972    // Test removed: workspace backend no longer exists in the registry
973
974    #[test]
975    fn test_mcp_status_tracking() {
976        // Test that MCP server info is properly tracked in session state
977        let mut session_state = SessionState::default();
978
979        // Add some MCP server info
980        let mcp_info = McpServerInfo {
981            server_name: "test-server".to_string(),
982            transport: crate::tools::McpTransport::Stdio {
983                command: "python".to_string(),
984                args: vec!["-m".to_string(), "test_server".to_string()],
985            },
986            state: McpConnectionState::Connected {
987                tool_names: vec![
988                    "tool1".to_string(),
989                    "tool2".to_string(),
990                    "tool3".to_string(),
991                    "tool4".to_string(),
992                    "tool5".to_string(),
993                ],
994            },
995            last_updated: Utc::now(),
996        };
997
998        session_state
999            .mcp_servers
1000            .insert("test-server".to_string(), mcp_info.clone());
1001
1002        // Verify it's stored
1003        assert_eq!(session_state.mcp_servers.len(), 1);
1004        let stored = session_state.mcp_servers.get("test-server").unwrap();
1005        assert_eq!(stored.server_name, "test-server");
1006        assert!(matches!(
1007            stored.state,
1008            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
1009        ));
1010
1011        // Test failed connection
1012        let failed_info = McpServerInfo {
1013            server_name: "failed-server".to_string(),
1014            transport: crate::tools::McpTransport::Tcp {
1015                host: "localhost".to_string(),
1016                port: 9999,
1017            },
1018            state: McpConnectionState::Failed {
1019                error: "Connection refused".to_string(),
1020            },
1021            last_updated: Utc::now(),
1022        };
1023
1024        session_state
1025            .mcp_servers
1026            .insert("failed-server".to_string(), failed_info);
1027        assert_eq!(session_state.mcp_servers.len(), 2);
1028    }
1029
1030    #[tokio::test]
1031    async fn test_mcp_server_tracking_in_build_registry() {
1032        use crate::auth::DefaultAuthStorage;
1033        use crate::config::LlmConfigProvider;
1034
1035        // Create a session config with both good and bad MCP servers
1036        let mut config = SessionConfig::read_only();
1037
1038        // This one should fail (invalid transport)
1039        config.tool_config.backends.push(BackendConfig::Mcp {
1040            server_name: "bad-server".to_string(),
1041            transport: crate::tools::McpTransport::Tcp {
1042                host: "nonexistent.invalid".to_string(),
1043                port: 12345,
1044            },
1045            tool_filter: ToolFilter::All,
1046        });
1047
1048        // This one would succeed if we had a real server running
1049        config.tool_config.backends.push(BackendConfig::Mcp {
1050            server_name: "good-server".to_string(),
1051            transport: crate::tools::McpTransport::Stdio {
1052                command: "echo".to_string(),
1053                args: vec!["test".to_string()],
1054            },
1055            tool_filter: ToolFilter::All,
1056        });
1057
1058        let auth_storage =
1059            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
1060        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));
1061        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
1062            .await
1063            .unwrap();
1064
1065        let (_registry, mcp_servers) = config
1066            .build_registry(llm_config_provider, workspace)
1067            .await
1068            .unwrap();
1069
1070        // Should have tracked both servers
1071        assert_eq!(mcp_servers.len(), 2);
1072
1073        // Check the bad server
1074        let bad_server = mcp_servers.get("bad-server").unwrap();
1075        assert_eq!(bad_server.server_name, "bad-server");
1076        assert!(matches!(
1077            bad_server.state,
1078            McpConnectionState::Failed { .. }
1079        ));
1080
1081        // Check the good server (will also fail in tests since echo isn't an MCP server)
1082        let good_server = mcp_servers.get("good-server").unwrap();
1083        assert_eq!(good_server.server_name, "good-server");
1084        assert!(matches!(
1085            good_server.state,
1086            McpConnectionState::Failed { .. }
1087        ));
1088    }
1089
1090    #[test]
1091    fn test_backend_config_variants() {
1092        // Test Local variant
1093        let local_config = BackendConfig::Local {
1094            tool_filter: ToolFilter::Include(vec![
1095                VIEW_TOOL_NAME.to_string(),
1096                LS_TOOL_NAME.to_string(),
1097            ]),
1098        };
1099
1100        assert!(matches!(local_config, BackendConfig::Local { .. }));
1101        if let BackendConfig::Local { tool_filter } = local_config {
1102            assert!(matches!(tool_filter, ToolFilter::Include(_)));
1103            if let ToolFilter::Include(tools) = tool_filter {
1104                assert_eq!(tools.len(), 2);
1105            }
1106        }
1107
1108        // Test Mcp variant
1109        let mcp_config = BackendConfig::Mcp {
1110            server_name: "test-mcp".to_string(),
1111            transport: crate::tools::McpTransport::Stdio {
1112                command: "python".to_string(),
1113                args: vec!["-m".to_string(), "test_server".to_string()],
1114            },
1115            tool_filter: ToolFilter::All,
1116        };
1117
1118        assert!(matches!(mcp_config, BackendConfig::Mcp { .. }));
1119        if let BackendConfig::Mcp {
1120            server_name,
1121            transport,
1122            ..
1123        } = mcp_config
1124        {
1125            assert_eq!(server_name, "test-mcp");
1126            assert!(matches!(
1127                transport,
1128                crate::tools::McpTransport::Stdio { .. }
1129            ));
1130            if let crate::tools::McpTransport::Stdio { command, args } = transport {
1131                assert_eq!(command, "python");
1132                assert_eq!(args.len(), 2);
1133            }
1134        }
1135    }
1136}