steer_core/session/
state.rs

1use crate::error::Result;
2use chrono::{DateTime, Utc};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::path::PathBuf;
7use std::sync::Arc;
8
9use crate::api::Model;
10use crate::app::{Message, MessageData};
11use crate::config::LlmConfigProvider;
12use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};
13use steer_tools::tools::read_only_workspace_tools;
14use steer_tools::{ToolCall, result::ToolResult};
15
16/// Container runtime to use
17#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
18#[serde(rename_all = "snake_case")]
19pub enum ContainerRuntime {
20    Docker,
21    Podman,
22}
23
24/// Defines the primary execution environment for a session's workspace
25#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum WorkspaceConfig {
28    Local {
29        path: PathBuf,
30    },
31    Remote {
32        agent_address: String,
33        auth: Option<RemoteAuth>,
34    },
35    Container {
36        image: String,
37        runtime: ContainerRuntime,
38    },
39}
40
41impl WorkspaceConfig {
42    pub fn get_path(&self) -> Option<String> {
43        match self {
44            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
45            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
46            WorkspaceConfig::Container { .. } => None,
47        }
48    }
49
50    /// Convert to steer_workspace::WorkspaceConfig
51    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
52        match self {
53            WorkspaceConfig::Local { path } => {
54                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
55            }
56            WorkspaceConfig::Remote {
57                agent_address,
58                auth,
59            } => steer_workspace::WorkspaceConfig::Remote {
60                address: agent_address.clone(),
61                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
62            },
63            WorkspaceConfig::Container { .. } => {
64                // For now, container workspaces are not implemented in steer_workspace
65                steer_workspace::WorkspaceConfig::Local {
66                    path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
67                }
68            }
69        }
70    }
71}
72
73impl Default for WorkspaceConfig {
74    fn default() -> Self {
75        Self::Local {
76            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
77        }
78    }
79}
80
81/// Complete session representation
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Session {
84    pub id: String,
85    pub created_at: DateTime<Utc>,
86    pub updated_at: DateTime<Utc>,
87    pub config: SessionConfig,
88    pub state: SessionState,
89}
90
91impl Session {
92    pub fn new(id: String, config: SessionConfig) -> Self {
93        let now = Utc::now();
94        Self {
95            id,
96            created_at: now,
97            updated_at: now,
98            config,
99            state: SessionState::default(),
100        }
101    }
102
103    pub fn update_timestamp(&mut self) {
104        self.updated_at = Utc::now();
105    }
106
107    /// Check if session has any recent activity
108    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
109        let cutoff = Utc::now() - threshold;
110        self.updated_at > cutoff
111    }
112
113    /// Build a workspace from this session's configuration
114    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
115        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
116    }
117}
118
119/// Session configuration - immutable once created
120#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
121pub struct SessionConfig {
122    pub workspace: WorkspaceConfig,
123    pub tool_config: SessionToolConfig,
124    /// Optional custom system prompt to use for the session. If `None`, Steer will
125    /// fall back to its built-in default prompt.
126    pub system_prompt: Option<String>,
127    pub metadata: HashMap<String, String>,
128}
129
130impl SessionConfig {
131    /// Build a BackendRegistry from this configuration for external tools only.
132    /// Workspace tools are now handled directly by the Workspace.
133    pub async fn build_registry(
134        &self,
135        llm_config_provider: Arc<LlmConfigProvider>,
136        workspace: Arc<dyn crate::workspace::Workspace>,
137    ) -> Result<BackendRegistry> {
138        let mut registry = BackendRegistry::new();
139
140        // 1. Register all USER-DEFINED backends first.
141        // Their tool mappings may be overwritten by the more authoritative backends below.
142        for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
143            match backend_config {
144                BackendConfig::Local { tool_filter } => {
145                    let backend = match tool_filter {
146                        ToolFilter::All => {
147                            LocalBackend::full(llm_config_provider.clone(), workspace.clone())
148                        }
149                        ToolFilter::Include(tools) => LocalBackend::with_tools(
150                            tools.clone(),
151                            llm_config_provider.clone(),
152                            workspace.clone(),
153                        ),
154                        ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
155                            excluded.clone(),
156                            llm_config_provider.clone(),
157                            workspace.clone(),
158                        ),
159                    };
160                    registry
161                        .register(format!("user_local_{idx}"), Arc::new(backend))
162                        .await;
163                }
164                BackendConfig::Mcp {
165                    server_name,
166                    transport,
167                    tool_filter,
168                } => {
169                    tracing::info!(
170                        "Attempting to initialize MCP backend '{}' with transport: {:?}",
171                        server_name,
172                        transport
173                    );
174                    match crate::tools::McpBackend::new(
175                        server_name.clone(),
176                        transport.clone(),
177                        tool_filter.clone(),
178                    )
179                    .await
180                    {
181                        Ok(mcp_backend) => {
182                            tracing::info!(
183                                "Successfully initialized MCP backend '{}'",
184                                server_name
185                            );
186                            registry
187                                .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
188                                .await;
189                        }
190                        Err(e) => {
191                            tracing::error!(
192                                "Failed to initialize MCP backend '{}': {}",
193                                server_name,
194                                e
195                            );
196                        }
197                    }
198                }
199            }
200        }
201
202        // 2. Register SERVER tools (like dispatch_agent and web_fetch).
203        // These are external tools, not workspace tools.
204        let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
205        if !server_backend.supported_tools().await.is_empty() {
206            registry
207                .register("server".to_string(), Arc::new(server_backend))
208                .await;
209        }
210
211        // Note: Workspace tools are handled directly by the Workspace implementation.
212
213        Ok(registry)
214    }
215
216    /// Filter tools based on visibility settings
217    pub fn filter_tools_by_visibility(
218        &self,
219        tools: Vec<steer_tools::ToolSchema>,
220    ) -> Vec<steer_tools::ToolSchema> {
221        match &self.tool_config.visibility {
222            ToolVisibility::All => tools,
223            ToolVisibility::ReadOnly => {
224                let read_only_names: HashSet<String> = read_only_workspace_tools()
225                    .iter()
226                    .map(|t| t.name().to_string())
227                    .collect();
228
229                tools
230                    .into_iter()
231                    .filter(|schema| read_only_names.contains(&schema.name))
232                    .collect()
233            }
234            ToolVisibility::Whitelist(allowed) => tools
235                .into_iter()
236                .filter(|schema| allowed.contains(&schema.name))
237                .collect(),
238            ToolVisibility::Blacklist(blocked) => tools
239                .into_iter()
240                .filter(|schema| !blocked.contains(&schema.name))
241                .collect(),
242        }
243    }
244
245    /// Minimal read-only configuration
246    pub fn read_only() -> Self {
247        Self {
248            workspace: WorkspaceConfig::Local {
249                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
250            },
251            tool_config: SessionToolConfig::read_only(),
252            system_prompt: None,
253            metadata: HashMap::new(),
254        }
255    }
256}
257
258/// Tool visibility configuration - controls which tools are shown to the AI agent
259#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
260#[serde(tag = "type", rename_all = "snake_case")]
261pub enum ToolVisibility {
262    /// Show all registered tools to the AI
263    All,
264
265    /// Only show read-only tools to the AI
266    ReadOnly,
267
268    /// Show only specific tools to the AI (whitelist)
269    Whitelist(HashSet<String>),
270
271    /// Hide specific tools from the AI (blacklist)
272    Blacklist(HashSet<String>),
273}
274
275impl Default for ToolVisibility {
276    fn default() -> Self {
277        Self::All
278    }
279}
280
281/// Tool-specific configuration for bash
282#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
283pub struct BashToolConfig {
284    /// Command patterns that are pre-approved for execution
285    #[serde(default)]
286    pub approved_patterns: Vec<String>,
287}
288
289/// Tool approval policy configuration
290#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
291#[serde(tag = "type", rename_all = "snake_case")]
292pub enum ToolApprovalPolicy {
293    /// Always ask for approval before executing any tool
294    AlwaysAsk,
295
296    /// Pre-approved tools execute without asking
297    PreApproved { tools: HashSet<String> },
298
299    /// Mixed policy: some tools pre-approved, others require approval
300    Mixed {
301        pre_approved: HashSet<String>,
302        ask_for_others: bool,
303    },
304}
305
306impl ToolApprovalPolicy {
307    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
308        match self {
309            ToolApprovalPolicy::AlwaysAsk => false,
310            ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
311            ToolApprovalPolicy::Mixed {
312                pre_approved,
313                ask_for_others: _,
314            } => pre_approved.contains(tool_name),
315        }
316    }
317
318    pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
319        match self {
320            ToolApprovalPolicy::AlwaysAsk => true,
321            ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
322            ToolApprovalPolicy::Mixed {
323                pre_approved,
324                ask_for_others,
325            } => !pre_approved.contains(tool_name) && *ask_for_others,
326        }
327    }
328}
329
330/// Authentication configuration for remote backends
331#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
332pub enum RemoteAuth {
333    Bearer { token: String },
334    ApiKey { key: String },
335}
336
337impl RemoteAuth {
338    /// Convert to steer_workspace RemoteAuth type
339    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
340        match self {
341            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
342            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
343        }
344    }
345}
346
347/// Tool filtering configuration for backends
348#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
349#[serde(rename_all = "snake_case")]
350pub enum ToolFilter {
351    /// Include all available tools
352    All,
353    /// Include only the specified tools
354    Include(Vec<String>),
355    /// Include all tools except the specified ones
356    Exclude(Vec<String>),
357}
358
359impl Default for ToolFilter {
360    fn default() -> Self {
361        Self::All
362    }
363}
364
365/// Backend configuration for different tool execution environments
366#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
367#[serde(tag = "type", rename_all = "snake_case")]
368pub enum BackendConfig {
369    Local {
370        /// Tool filtering configuration for the local backend
371        tool_filter: ToolFilter,
372    },
373    Mcp {
374        server_name: String,
375        transport: McpTransport,
376        tool_filter: ToolFilter,
377    },
378}
379
380/// Tool configuration for the session
381#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
382pub struct SessionToolConfig {
383    /// Backend configurations for this session
384    pub backends: Vec<BackendConfig>,
385    /// Tool visibility - controls which tools are shown to the AI agent
386    pub visibility: ToolVisibility,
387    /// Tool approval policy - controls when user approval is needed
388    pub approval_policy: ToolApprovalPolicy,
389    /// Additional metadata for tool configuration
390    pub metadata: HashMap<String, String>,
391    /// Tool-specific configurations
392    #[serde(default)]
393    pub tools: HashMap<String, ToolSpecificConfig>,
394}
395
396/// Tool-specific configurations
397#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
398#[serde(untagged)]
399pub enum ToolSpecificConfig {
400    /// Configuration for the bash tool
401    Bash(BashToolConfig),
402}
403
404impl Default for SessionToolConfig {
405    fn default() -> Self {
406        Self {
407            backends: Vec::new(),
408            visibility: ToolVisibility::All,
409            approval_policy: ToolApprovalPolicy::AlwaysAsk,
410            metadata: HashMap::new(),
411            tools: HashMap::new(),
412        }
413    }
414}
415
416impl SessionToolConfig {
417    /// Minimal read-only configuration
418    pub fn read_only() -> Self {
419        Self {
420            backends: Vec::new(), // Use default backends
421            visibility: ToolVisibility::ReadOnly,
422            approval_policy: ToolApprovalPolicy::AlwaysAsk,
423            metadata: HashMap::new(),
424            tools: HashMap::new(),
425        }
426    }
427}
428
429/// Mutable session state that changes during execution
430#[derive(Debug, Clone, Serialize, Deserialize, Default)]
431pub struct SessionState {
432    /// Conversation messages
433    pub messages: Vec<Message>,
434
435    /// Tool call tracking
436    pub tool_calls: HashMap<String, ToolCallState>,
437
438    /// Tools that have been approved for this session
439    pub approved_tools: HashSet<String>,
440
441    /// Bash commands that have been approved for this session (dynamically added)
442    #[serde(default)]
443    pub approved_bash_patterns: HashSet<String>,
444
445    /// Last processed event sequence number for replay
446    pub last_event_sequence: u64,
447
448    /// Additional runtime metadata
449    pub metadata: HashMap<String, String>,
450
451    /// The ID of the currently active message (head of selected branch)
452    /// None means use last message semantics for backward compatibility
453    #[serde(default, skip_serializing_if = "Option::is_none")]
454    pub active_message_id: Option<String>,
455}
456
457impl SessionState {
458    /// Add a message to the conversation
459    pub fn add_message(&mut self, message: Message) {
460        self.messages.push(message);
461    }
462
463    /// Get the number of messages in the conversation
464    pub fn message_count(&self) -> usize {
465        self.messages.len()
466    }
467
468    /// Get the last message in the conversation
469    pub fn last_message(&self) -> Option<&Message> {
470        self.messages.last()
471    }
472
473    /// Add a tool call to tracking
474    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
475        let state = ToolCallState {
476            tool_call: tool_call.clone(),
477            status: ToolCallStatus::PendingApproval,
478            started_at: None,
479            completed_at: None,
480            result: None,
481        };
482        self.tool_calls.insert(tool_call.id, state);
483    }
484
485    /// Update tool call status
486    pub fn update_tool_call_status(
487        &mut self,
488        tool_call_id: &str,
489        status: ToolCallStatus,
490    ) -> std::result::Result<(), String> {
491        let tool_call = self
492            .tool_calls
493            .get_mut(tool_call_id)
494            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
495
496        // Update timestamps based on status changes
497        match (&tool_call.status, &status) {
498            (_, ToolCallStatus::Executing) => {
499                tool_call.started_at = Some(Utc::now());
500            }
501            (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
502                tool_call.completed_at = Some(Utc::now());
503            }
504            _ => {}
505        }
506
507        tool_call.status = status;
508        Ok(())
509    }
510
511    /// Approve a tool for future use
512    pub fn approve_tool(&mut self, tool_name: String) {
513        self.approved_tools.insert(tool_name);
514    }
515
516    /// Check if a tool is approved
517    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
518        self.approved_tools.contains(tool_name)
519    }
520
521    /// Validate internal consistency
522    pub fn validate(&self) -> std::result::Result<(), String> {
523        // Check that all tool calls referenced in messages exist
524        for message in &self.messages {
525            let tool_calls = self.extract_tool_calls_from_message(message);
526            if !tool_calls.is_empty() {
527                for tool_call_id in tool_calls {
528                    if !self.tool_calls.contains_key(&tool_call_id) {
529                        return Err(format!(
530                            "Message references unknown tool call: {tool_call_id}"
531                        ));
532                    }
533                }
534            }
535        }
536
537        Ok(())
538    }
539
540    /// Extract tool call IDs from a message
541    fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
542        let mut tool_call_ids = Vec::new();
543
544        match &message.data {
545            MessageData::Assistant { content, .. } => {
546                for c in content {
547                    if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
548                        tool_call_ids.push(tool_call.id.clone());
549                    }
550                }
551            }
552            MessageData::Tool { tool_use_id, .. } => {
553                tool_call_ids.push(tool_use_id.clone());
554            }
555            _ => {}
556        }
557
558        tool_call_ids
559    }
560
561    /// Apply an event to the session state
562    pub fn apply_event(
563        &mut self,
564        event: &crate::events::StreamEvent,
565    ) -> std::result::Result<(), String> {
566        use crate::events::StreamEvent;
567
568        match event {
569            StreamEvent::MessageComplete { message, .. } => {
570                self.add_message(message.clone());
571            }
572            StreamEvent::ToolCallStarted { tool_call, .. } => {
573                self.add_tool_call(tool_call.clone());
574            }
575            StreamEvent::ToolCallCompleted {
576                tool_call_id,
577                result,
578                ..
579            } => {
580                self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
581                if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
582                    tool_call_state.result = Some(result.clone());
583                }
584            }
585            StreamEvent::ToolCallFailed {
586                tool_call_id,
587                error,
588                ..
589            } => {
590                self.update_tool_call_status(
591                    tool_call_id,
592                    ToolCallStatus::Failed {
593                        error: error.clone(),
594                    },
595                )?;
596            }
597            StreamEvent::ToolApprovalRequired { tool_call, .. } => {
598                // Tool call should already be added with PendingApproval status
599                if !self.tool_calls.contains_key(&tool_call.id) {
600                    self.add_tool_call(tool_call.clone());
601                }
602            }
603            // Other events don't modify state directly
604            _ => {}
605        }
606
607        Ok(())
608    }
609}
610
611/// Tool call state tracking
612#[derive(Debug, Clone, Serialize, Deserialize)]
613pub struct ToolCallState {
614    pub tool_call: ToolCall,
615    pub status: ToolCallStatus,
616    pub started_at: Option<DateTime<Utc>>,
617    pub completed_at: Option<DateTime<Utc>>,
618    pub result: Option<ToolResult>,
619}
620
621impl ToolCallState {
622    pub fn is_pending(&self) -> bool {
623        matches!(self.status, ToolCallStatus::PendingApproval)
624    }
625
626    pub fn is_complete(&self) -> bool {
627        matches!(
628            self.status,
629            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
630        )
631    }
632
633    pub fn duration(&self) -> Option<chrono::Duration> {
634        match (self.started_at, self.completed_at) {
635            (Some(start), Some(end)) => Some(end - start),
636            _ => None,
637        }
638    }
639}
640
641/// Tool call execution status
642#[derive(Debug, Clone, Serialize, Deserialize)]
643#[serde(tag = "status", rename_all = "snake_case")]
644pub enum ToolCallStatus {
645    PendingApproval,
646    Approved,
647    Denied,
648    Executing,
649    Completed,
650    Failed { error: String },
651}
652
653impl ToolCallStatus {
654    pub fn is_terminal(&self) -> bool {
655        matches!(
656            self,
657            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
658        )
659    }
660}
661
662/// Tool execution statistics
663#[derive(Debug, Clone, Serialize, Deserialize)]
664pub struct ToolExecutionStats {
665    #[serde(skip_serializing_if = "Option::is_none")]
666    pub output: Option<String>, // Legacy string output
667    #[serde(skip_serializing_if = "Option::is_none")]
668    pub json_output: Option<serde_json::Value>, // Typed JSON output
669    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
670    pub success: bool,
671    pub execution_time_ms: u64,
672    pub metadata: HashMap<String, String>,
673}
674
675impl ToolExecutionStats {
676    pub fn success(output: String, execution_time_ms: u64) -> Self {
677        Self {
678            output: Some(output),
679            json_output: None,
680            result_type: None,
681            success: true,
682            execution_time_ms,
683            metadata: HashMap::new(),
684        }
685    }
686
687    pub fn success_typed(
688        json_output: serde_json::Value,
689        result_type: String,
690        execution_time_ms: u64,
691    ) -> Self {
692        Self {
693            output: None,
694            json_output: Some(json_output),
695            result_type: Some(result_type),
696            success: true,
697            execution_time_ms,
698            metadata: HashMap::new(),
699        }
700    }
701
702    pub fn failure(error: String, execution_time_ms: u64) -> Self {
703        Self {
704            output: Some(error),
705            json_output: None,
706            result_type: None,
707            success: false,
708            execution_time_ms,
709            metadata: HashMap::new(),
710        }
711    }
712
713    pub fn with_metadata(mut self, key: String, value: String) -> Self {
714        self.metadata.insert(key, value);
715        self
716    }
717}
718
719/// Session metadata for listing and filtering
720#[derive(Debug, Clone, Serialize, Deserialize)]
721pub struct SessionInfo {
722    pub id: String,
723    pub created_at: DateTime<Utc>,
724    pub updated_at: DateTime<Utc>,
725    /// The last known model used in this session
726    pub last_model: Option<Model>,
727    pub message_count: usize,
728    pub metadata: HashMap<String, String>,
729}
730
731impl From<&Session> for SessionInfo {
732    fn from(session: &Session) -> Self {
733        Self {
734            id: session.id.clone(),
735            created_at: session.created_at,
736            updated_at: session.updated_at,
737            last_model: None, // TODO: Track last model used from events
738            message_count: session.state.message_count(),
739            metadata: session.config.metadata.clone(),
740        }
741    }
742}
743
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};

    /// Local-workspace config at a fixed dummy path, with default tool settings,
    /// no custom prompt and no metadata.
    fn test_path_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), test_path_config());

        assert_eq!(session.id, "test-session");
        // The default tool config asks before running any tool.
        let policy = &session.config.tool_config.approval_policy;
        assert!(policy.should_ask_for_approval("any_tool"));
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy() {
        let policy = ToolApprovalPolicy::PreApproved {
            tools: HashSet::from(["read_file".to_string(), "list_files".to_string()]),
        };

        assert!(policy.is_tool_approved("read_file"));
        assert!(!policy.should_ask_for_approval("read_file"));

        assert!(!policy.is_tool_approved("write_file"));
        assert!(policy.should_ask_for_approval("write_file"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();
        assert!(state.validate().is_ok(), "an empty state is valid");

        // A plain user message references no tool calls, so the state stays valid.
        state.add_message(Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        });

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();
        state.add_tool_call(ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        });
        assert!(state.tool_calls["tool1"].is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .expect("tool1 is tracked");
        let executing = &state.tool_calls["tool1"];
        assert!(executing.started_at.is_some());
        assert!(!executing.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .expect("tool1 is tracked");
        let completed = &state.tool_calls["tool1"];
        assert!(completed.completed_at.is_some());
        assert!(completed.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        assert!(SessionToolConfig::default().backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        // A local backend can exclude specific tools.
        let config = SessionToolConfig {
            backends: vec![BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(vec![
                    BASH_TOOL_NAME.to_string(),
                    EDIT_TOOL_NAME.to_string(),
                ]),
            }],
            visibility: ToolVisibility::All,
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            metadata: HashMap::new(),
            tools: HashMap::new(),
        };

        match &config.backends[0] {
            BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(excluded_tools),
            } => {
                assert_eq!(excluded_tools.len(), 2);
                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
            }
            other => panic!("expected a local backend with an exclude filter, got {other:?}"),
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        // No explicit backends: the defaults are used.
        assert!(config.backends.is_empty());
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert!(matches!(
            config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
    }

    #[tokio::test]
    async fn test_session_config_build_registry_server_tools() {
        use crate::auth::DefaultAuthStorage;
        use crate::config::LlmConfigProvider;

        let config = test_path_config();

        // Test environment: failing to create auth storage is a test setup bug.
        let auth_storage =
            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));

        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
            .await
            .unwrap();

        let registry = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();
        let schemas = registry.get_tool_schemas().await;
        let tool_names: HashSet<String> = schemas.iter().map(|s| s.name.clone()).collect();

        // Server tools are registered...
        assert!(tool_names.contains("dispatch_agent"));
        assert!(tool_names.contains("web_fetch"));

        // ...but workspace tools are not: the Workspace serves those itself.
        for tool_name in ["bash", "grep", "glob", "ls", "read", "write", "edit"] {
            assert!(
                !tool_names.contains(tool_name),
                "Workspace tool {tool_name} should not be in registry"
            );
        }
    }

    // Tests for workspace tools in the registry, their visibility filtering, and the
    // workspace backend were removed: those now live at the Workspace level.

    #[test]
    fn test_backend_config_variants() {
        let local_config = BackendConfig::Local {
            tool_filter: ToolFilter::Include(vec![
                VIEW_TOOL_NAME.to_string(),
                LS_TOOL_NAME.to_string(),
            ]),
        };
        match local_config {
            BackendConfig::Local {
                tool_filter: ToolFilter::Include(tools),
            } => assert_eq!(tools.len(), 2),
            other => panic!("expected a local backend with an include filter, got {other:?}"),
        }

        // The container backend has been removed, and so have its tests.

        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };
        match mcp_config {
            BackendConfig::Mcp {
                server_name,
                transport: crate::tools::McpTransport::Stdio { command, args },
                ..
            } => {
                assert_eq!(server_name, "test-mcp");
                assert_eq!(command, "python");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected an MCP backend over stdio, got {other:?}"),
        }
    }
}