Skip to main content

steer_core/session/
state.rs

1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// State of an MCP server connection.
///
/// Stored in [`McpServerInfo::state`]. `SessionConfig::build_registry` starts each
/// server in `Connecting` and finishes it in either `Connected` or `Failed`.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Currently attempting to connect
    Connecting,
    /// Successfully connected
    Connected {
        /// Names of tools available from this server (as reported by the backend's
        /// `supported_tools()` at connection time)
        tool_names: Vec<String>,
    },
    /// Gracefully disconnected
    Disconnected {
        /// Optional reason for disconnection
        reason: Option<String>,
    },
    /// Failed to connect
    Failed {
        /// Error message describing the failure (the backend error's `Display` text)
        error: String,
    },
}
36
/// Information about an MCP server.
///
/// One entry per configured backend is produced by `SessionConfig::build_registry`
/// and kept in the transient `SessionState::mcp_servers` map, keyed by `server_name`.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// The configured server name
    pub server_name: String,
    /// The transport configuration
    pub transport: McpTransport,
    /// Current connection state
    pub state: McpConnectionState,
    /// Timestamp when this state was last updated
    pub last_updated: DateTime<Utc>,
}
49
/// Defines the primary execution environment for a session's workspace.
///
/// Serialized as an internally tagged object whose `"type"` field is `"local"` or
/// `"remote"`, e.g. `{"type": "local", "path": "/repo"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// Workspace backed by a directory on this machine.
    Local {
        /// Directory used as the workspace (the `Default` impl uses the current
        /// working directory).
        path: PathBuf,
    },
    /// Workspace served by a remote agent.
    Remote {
        /// Address of the remote agent; forwarded as `address` to `steer_workspace`.
        agent_address: String,
        /// Optional credentials, converted with `RemoteAuth::to_workspace_auth`.
        auth: Option<RemoteAuth>,
    },
}
62
63impl WorkspaceConfig {
64    pub fn get_path(&self) -> Option<String> {
65        match self {
66            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68        }
69    }
70
71    /// Convert to steer_workspace::WorkspaceConfig
72    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73        match self {
74            WorkspaceConfig::Local { path } => {
75                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76            }
77            WorkspaceConfig::Remote {
78                agent_address,
79                auth,
80            } => steer_workspace::WorkspaceConfig::Remote {
81                address: agent_address.clone(),
82                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83            },
84        }
85    }
86}
87
88impl Default for WorkspaceConfig {
89    fn default() -> Self {
90        Self::Local {
91            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92        }
93    }
94}
95
/// Complete session representation: identity, timestamps, immutable configuration,
/// and mutable runtime state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier supplied to `Session::new`.
    pub id: String,
    /// Creation time (set once by `Session::new`).
    pub created_at: DateTime<Utc>,
    /// Last-modified time; refreshed by `Session::update_timestamp`.
    pub updated_at: DateTime<Utc>,
    pub config: SessionConfig,
    pub state: SessionState,
}
105
106impl Session {
107    pub fn new(id: String, config: SessionConfig) -> Self {
108        let now = Utc::now();
109        Self {
110            id,
111            created_at: now,
112            updated_at: now,
113            config,
114            state: SessionState::default(),
115        }
116    }
117
118    pub fn update_timestamp(&mut self) {
119        self.updated_at = Utc::now();
120    }
121
122    /// Check if session has any recent activity
123    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124        let cutoff = Utc::now() - threshold;
125        self.updated_at > cutoff
126    }
127
128    /// Build a workspace from this session's configuration
129    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131    }
132}
133
/// Session configuration - immutable once created.
///
/// Fields marked `#[serde(default)]` may be absent in serialized data; `workspace`,
/// `tool_config`, `metadata`, and `default_model` have no serde default and must be
/// present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Execution environment used by `Session::build_workspace`.
    pub workspace: WorkspaceConfig,
    // Optional workspace/repository/lineage identifiers; each is `None` when missing
    // from serialized data. Their semantics are defined outside this file.
    #[serde(default)]
    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
    #[serde(default)]
    pub workspace_id: Option<crate::workspace::WorkspaceId>,
    #[serde(default)]
    pub repo_ref: Option<crate::workspace::RepoRef>,
    #[serde(default)]
    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
    #[serde(default)]
    pub workspace_name: Option<String>,
    /// MCP backends, tool visibility, and approval policy for this session.
    pub tool_config: SessionToolConfig,
    /// Optional custom system prompt to use for the session. If `None`, Steer will
    /// fall back to its built-in default prompt.
    pub system_prompt: Option<String>,
    /// Primary agent mode for this session. Defaults to "normal" if unset.
    #[serde(default)]
    pub primary_agent_id: Option<String>,
    /// User-controlled policy overrides that apply on top of the primary agent base policy.
    #[serde(default = "SessionPolicyOverrides::empty")]
    pub policy_overrides: SessionPolicyOverrides,
    /// Arbitrary string key/value pairs; copied into `SessionInfo::metadata`.
    pub metadata: HashMap<String, String>,
    /// Model the session uses by default.
    pub default_model: ModelId,
}
161
162impl SessionConfig {
163    /// Build a BackendRegistry from MCP server configurations.
164    /// Returns the registry and a map of MCP server connection states.
165    pub async fn build_registry(
166        &self,
167    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
168        let mut registry = BackendRegistry::new();
169        let mut mcp_servers = HashMap::new();
170
171        for backend_config in &self.tool_config.backends {
172            let BackendConfig::Mcp {
173                server_name,
174                transport,
175                tool_filter,
176            } = backend_config;
177
178            tracing::info!(
179                "Attempting to initialize MCP backend '{}' with transport: {:?}",
180                server_name,
181                transport
182            );
183
184            let mut server_info = McpServerInfo {
185                server_name: server_name.clone(),
186                transport: transport.clone(),
187                state: McpConnectionState::Connecting,
188                last_updated: Utc::now(),
189            };
190
191            match crate::tools::McpBackend::new(
192                server_name.clone(),
193                transport.clone(),
194                tool_filter.clone(),
195            )
196            .await
197            {
198                Ok(mcp_backend) => {
199                    let tool_names = mcp_backend.supported_tools().await;
200                    let tool_count = tool_names.len();
201                    tracing::info!(
202                        "Successfully initialized MCP backend '{}' with {} tools",
203                        server_name,
204                        tool_count
205                    );
206                    server_info.state = McpConnectionState::Connected { tool_names };
207                    server_info.last_updated = Utc::now();
208                    registry
209                        .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
210                        .await;
211                }
212                Err(e) => {
213                    tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
214                    server_info.state = McpConnectionState::Failed {
215                        error: e.to_string(),
216                    };
217                    server_info.last_updated = Utc::now();
218                }
219            }
220
221            mcp_servers.insert(server_name.clone(), server_info);
222        }
223
224        Ok((registry, mcp_servers))
225    }
226
227    /// Filter tools based on visibility settings
228    pub fn filter_tools_by_visibility(
229        &self,
230        tools: Vec<steer_tools::ToolSchema>,
231    ) -> Vec<steer_tools::ToolSchema> {
232        match &self.tool_config.visibility {
233            ToolVisibility::All => tools,
234            ToolVisibility::ReadOnly => {
235                let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
236                    .iter()
237                    .map(|name| (*name).to_string())
238                    .collect();
239
240                tools
241                    .into_iter()
242                    .filter(|schema| read_only_names.contains(&schema.name))
243                    .collect()
244            }
245            ToolVisibility::Whitelist(allowed) => tools
246                .into_iter()
247                .filter(|schema| allowed.contains(&schema.name))
248                .collect(),
249            ToolVisibility::Blacklist(blocked) => tools
250                .into_iter()
251                .filter(|schema| !blocked.contains(&schema.name))
252                .collect(),
253        }
254    }
255
256    /// Minimal read-only configuration
257    #[cfg(test)]
258    pub fn read_only(default_model: ModelId) -> Self {
259        Self {
260            workspace: WorkspaceConfig::Local {
261                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
262            },
263            workspace_ref: None,
264            workspace_id: None,
265            repo_ref: None,
266            parent_session_id: None,
267            workspace_name: None,
268            tool_config: SessionToolConfig::read_only(),
269            system_prompt: None,
270            primary_agent_id: None,
271            policy_overrides: SessionPolicyOverrides::empty(),
272            metadata: HashMap::new(),
273            default_model,
274        }
275    }
276}
277
/// User-controlled policy overrides applied on top of a primary agent base policy.
///
/// `None` fields are omitted when serializing, and every field may be absent when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionPolicyOverrides {
    /// Model override; `None` leaves the base choice in place.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_model: Option<ModelId>,
    /// Tool-visibility override; `None` leaves the base choice in place.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_visibility: Option<ToolVisibility>,
    /// Approval-policy adjustments; an empty override set when missing.
    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
    pub approval_policy: ToolApprovalPolicyOverrides,
}
288
289impl SessionPolicyOverrides {
290    pub fn empty() -> Self {
291        Self {
292            default_model: None,
293            tool_visibility: None,
294            approval_policy: ToolApprovalPolicyOverrides::empty(),
295        }
296    }
297
298    pub fn is_empty(&self) -> bool {
299        self.default_model.is_none()
300            && self.tool_visibility.is_none()
301            && self.approval_policy.is_empty()
302    }
303}
304
/// Partial changes to a [`ToolApprovalPolicy`], merged into a base policy by
/// `ToolApprovalPolicyOverrides::apply_to`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicyOverrides {
    /// When `Some`, replaces the base policy's `default_behavior`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_behavior: Option<UnapprovedBehavior>,
    /// Preapprovals merged into the base policy's `preapproved` rules.
    #[serde(default = "ApprovalRulesOverrides::empty")]
    pub preapproved: ApprovalRulesOverrides,
}
312
313impl ToolApprovalPolicyOverrides {
314    pub fn empty() -> Self {
315        Self {
316            default_behavior: None,
317            preapproved: ApprovalRulesOverrides::empty(),
318        }
319    }
320
321    pub fn is_empty(&self) -> bool {
322        self.default_behavior.is_none() && self.preapproved.is_empty()
323    }
324}
325
/// Additional preapprovals layered onto a base [`ApprovalRules`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ApprovalRulesOverrides {
    /// Tool names added to the base preapproved set.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Per-tool pattern rules keyed by tool name; merged with any base rule of
    /// the same variant, otherwise replacing it (see `merge_tool_rule_override`).
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRuleOverrides>,
}
333
334impl ApprovalRulesOverrides {
335    pub fn empty() -> Self {
336        Self {
337            tools: HashSet::new(),
338            per_tool: HashMap::new(),
339        }
340    }
341
342    pub fn is_empty(&self) -> bool {
343        self.tools.is_empty() && self.per_tool.is_empty()
344    }
345}
346
/// Override counterpart of [`ToolRule`] with the same variants and serialized shape,
/// e.g. `{"type": "bash", "patterns": ["git status"]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRuleOverrides {
    /// Extra bash command patterns.
    Bash { patterns: Vec<String> },
    /// Extra dispatch-agent id patterns.
    DispatchAgent { agent_patterns: Vec<String> },
}
353
/// Tool visibility configuration - controls which tools are shown to the AI agent.
///
/// Serialized adjacently tagged, e.g. `{"type": "all"}` or
/// `{"type": "whitelist", "tools": ["read_file"]}`. Applied by
/// `SessionConfig::filter_tools_by_visibility`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
pub enum ToolVisibility {
    /// Show all registered tools to the AI
    #[default]
    All,

    /// Only show read-only tools to the AI (those named in `READ_ONLY_TOOL_NAMES`)
    ReadOnly,

    /// Show only specific tools to the AI (whitelist)
    Whitelist(HashSet<String>),

    /// Hide specific tools from the AI (blacklist)
    Blacklist(HashSet<String>),
}
371
/// Outcome of `ToolApprovalPolicy::tool_decision` for a single tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// Run without asking.
    Allow,
    /// Ask the user before running.
    Ask,
    /// Refuse to run.
    Deny,
}
378
/// What to do with a tool that is not in the preapproved set.
///
/// Serialized as `"prompt"`, `"deny"`, or `"allow"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum UnapprovedBehavior {
    /// Maps to [`ToolDecision::Ask`].
    #[default]
    Prompt,
    /// Maps to [`ToolDecision::Deny`].
    Deny,
    /// Maps to [`ToolDecision::Allow`].
    Allow,
}
387
/// Pattern-based preapproval rule for a specific tool.
///
/// `ApprovalRules` looks up `Bash` rules under the key `"bash"` and `DispatchAgent`
/// rules under `"dispatch_agent"`. Each pattern matches either exactly or as a glob.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRule {
    /// Preapproved bash command patterns.
    Bash { patterns: Vec<String> },
    /// Preapproved sub-agent id patterns.
    DispatchAgent { agent_patterns: Vec<String> },
}
394
/// Set of preapprovals: whole tools by name plus per-tool pattern rules.
///
/// Both fields default to empty when missing from serialized data.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
pub struct ApprovalRules {
    /// Tool names that are always allowed.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Pattern rules keyed by tool name (e.g. `"bash"`, `"dispatch_agent"`).
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRule>,
}
402
403impl ApprovalRules {
404    pub fn is_empty(&self) -> bool {
405        self.tools.is_empty() && self.per_tool.is_empty()
406    }
407
408    pub fn bash_patterns(&self) -> Option<&[String]> {
409        self.per_tool.get("bash").and_then(|rule| match rule {
410            ToolRule::Bash { patterns } => Some(patterns.as_slice()),
411            ToolRule::DispatchAgent { .. } => None,
412        })
413    }
414
415    pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
416        self.per_tool
417            .get("dispatch_agent")
418            .and_then(|rule| match rule {
419                ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
420                ToolRule::Bash { .. } => None,
421            })
422    }
423}
424
/// Decides which tool calls run without asking the user.
///
/// NOTE: when deserializing, a missing `preapproved` field yields an EMPTY
/// [`ApprovalRules`] (its derived `Default`). This differs from
/// `ToolApprovalPolicy::default()`, which preapproves every read-only tool.
/// `default_behavior` has no serde default and must be present.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicy {
    /// Fallback for tools not in `preapproved.tools`.
    pub default_behavior: UnapprovedBehavior,
    /// Tools and patterns that are allowed without prompting.
    #[serde(default)]
    pub preapproved: ApprovalRules,
}
431
432impl Default for ToolApprovalPolicy {
433    fn default() -> Self {
434        Self {
435            default_behavior: UnapprovedBehavior::Prompt,
436            preapproved: ApprovalRules {
437                tools: READ_ONLY_TOOL_NAMES
438                    .iter()
439                    .map(|name| (*name).to_string())
440                    .collect(),
441                per_tool: HashMap::new(),
442            },
443        }
444    }
445}
446
447impl ToolApprovalPolicy {
448    pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
449        if self.preapproved.tools.contains(tool_name) {
450            ToolDecision::Allow
451        } else {
452            match self.default_behavior {
453                UnapprovedBehavior::Prompt => ToolDecision::Ask,
454                UnapprovedBehavior::Deny => ToolDecision::Deny,
455                UnapprovedBehavior::Allow => ToolDecision::Allow,
456            }
457        }
458    }
459
460    pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
461        let Some(patterns) = self.preapproved.bash_patterns() else {
462            return false;
463        };
464        patterns.iter().any(|pattern| {
465            if pattern == command {
466                return true;
467            }
468            glob::Pattern::new(pattern)
469                .map(|glob| glob.matches(command))
470                .unwrap_or(false)
471        })
472    }
473
474    pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
475        let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
476            return false;
477        };
478        patterns.iter().any(|pattern| {
479            if pattern == agent_id {
480                return true;
481            }
482            glob::Pattern::new(pattern)
483                .map(|glob| glob.matches(agent_id))
484                .unwrap_or(false)
485        })
486    }
487
488    pub fn pre_approved_tools(&self) -> &HashSet<String> {
489        &self.preapproved.tools
490    }
491}
492
493impl ToolApprovalPolicyOverrides {
494    pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
495        let mut merged = base.clone();
496
497        if let Some(default_behavior) = self.default_behavior {
498            merged.default_behavior = default_behavior;
499        }
500
501        if !self.preapproved.tools.is_empty() {
502            merged
503                .preapproved
504                .tools
505                .extend(self.preapproved.tools.iter().cloned());
506        }
507
508        for (tool_name, override_rule) in &self.preapproved.per_tool {
509            let base_rule = merged.preapproved.per_tool.get(tool_name);
510            let merged_rule = merge_tool_rule_override(base_rule, override_rule);
511            merged
512                .preapproved
513                .per_tool
514                .insert(tool_name.clone(), merged_rule);
515        }
516
517        merged
518    }
519}
520
521fn merge_tool_rule_override(
522    base: Option<&ToolRule>,
523    override_rule: &ToolRuleOverrides,
524) -> ToolRule {
525    match (base, override_rule) {
526        (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
527            ToolRule::Bash {
528                patterns: merge_patterns(patterns, extra),
529            }
530        }
531        (
532            Some(ToolRule::DispatchAgent { agent_patterns }),
533            ToolRuleOverrides::DispatchAgent {
534                agent_patterns: extra,
535            },
536        ) => ToolRule::DispatchAgent {
537            agent_patterns: merge_patterns(agent_patterns, extra),
538        },
539        (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
540            patterns: patterns.clone(),
541        },
542        (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
543            agent_patterns: agent_patterns.clone(),
544        },
545    }
546}
547
/// Returns `base` followed by each pattern from `extra` that is not already present.
///
/// Duplicates already inside `base` are kept as-is; only additions are deduplicated.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    extra.iter().fold(base.to_vec(), |mut acc, candidate| {
        if !acc.contains(candidate) {
            acc.push(candidate.clone());
        }
        acc
    })
}
557
/// Authentication configuration for remote backends.
///
/// Serialized with serde's default external tagging, e.g. `{"Bearer": {"token": "..."}}`.
///
/// NOTE(review): the derived `Debug` and `Serialize` impls emit the token/key
/// verbatim. Because `Session` derives both, secrets may end up in logs or in
/// persisted session data. Consider a redacting `Debug` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer token credential.
    Bearer { token: String },
    /// API key credential.
    ApiKey { key: String },
}
564
565impl RemoteAuth {
566    /// Convert to steer_workspace RemoteAuth type
567    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
568        match self {
569            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
570            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
571        }
572    }
573}
574
/// Tool filtering configuration for backends.
///
/// Passed to `McpBackend::new` by `SessionConfig::build_registry`. Serialized with
/// external tagging and snake_case names: `"all"`, `{"include": [...]}`,
/// `{"exclude": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum ToolFilter {
    /// Include all available tools
    #[default]
    All,
    /// Include only the specified tools
    Include(Vec<String>),
    /// Include all tools except the specified ones
    Exclude(Vec<String>),
}
588
/// Configuration for MCP server backends.
///
/// Internally tagged by `"type"`, e.g.
/// `{"type": "mcp", "server_name": "...", "transport": ..., "tool_filter": "all"}`.
/// `tool_filter` has no serde default and must be present.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    Mcp {
        /// Name used as the key in the MCP server map and in the registry id
        /// `mcp_<server_name>`.
        server_name: String,
        /// How to reach the server.
        transport: McpTransport,
        /// Which of the server's tools to expose.
        tool_filter: ToolFilter,
    },
}
599
/// Tool-related session configuration: external backends, visibility, and approvals.
///
/// No field has a serde default, so all four must be present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// MCP backends connected by `SessionConfig::build_registry`.
    pub backends: Vec<BackendConfig>,
    /// Which tool schemas are exposed (see `SessionConfig::filter_tools_by_visibility`).
    pub visibility: ToolVisibility,
    /// Which tool calls may run without prompting.
    pub approval_policy: ToolApprovalPolicy,
    /// Arbitrary string key/value pairs; not interpreted by the code shown here.
    pub metadata: HashMap<String, String>,
}
607
608impl Default for SessionToolConfig {
609    fn default() -> Self {
610        Self {
611            backends: Vec::new(),
612            visibility: ToolVisibility::All,
613            approval_policy: ToolApprovalPolicy::default(),
614            metadata: HashMap::new(),
615        }
616    }
617}
618
619impl SessionToolConfig {
620    pub fn read_only() -> Self {
621        Self {
622            backends: Vec::new(),
623            visibility: ToolVisibility::ReadOnly,
624            approval_policy: ToolApprovalPolicy::default(),
625            metadata: HashMap::new(),
626        }
627    }
628}
629
/// Mutable session state that changes during execution.
///
/// Only `approved_bash_patterns`, `active_message_id`, and `mcp_servers` have serde
/// defaults; the other fields must be present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation messages
    pub messages: Vec<Message>,

    /// Tool call tracking, keyed by tool call id
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tools that have been approved for this session (via `approve_tool`)
    pub approved_tools: HashSet<String>,

    /// Bash commands that have been approved for this session (dynamically added)
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Last processed event sequence number for replay
    pub last_event_sequence: u64,

    /// Additional runtime metadata
    pub metadata: HashMap<String, String>,

    /// The ID of the currently active message (head of selected branch)
    /// None means use last message semantics for backward compatibility
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Status of MCP server connections
    /// This is a transient field that is rebuilt on session activation; it is
    /// never serialized and is always empty after deserialization.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
662
663impl SessionState {
664    /// Add a message to the conversation
665    pub fn add_message(&mut self, message: Message) {
666        self.messages.push(message);
667    }
668
669    /// Get the number of messages in the conversation
670    pub fn message_count(&self) -> usize {
671        self.messages.len()
672    }
673
674    /// Get the last message in the conversation
675    pub fn last_message(&self) -> Option<&Message> {
676        self.messages.last()
677    }
678
679    /// Add a tool call to tracking
680    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
681        let state = ToolCallState {
682            tool_call: tool_call.clone(),
683            status: ToolCallStatus::PendingApproval,
684            started_at: None,
685            completed_at: None,
686            result: None,
687        };
688        self.tool_calls.insert(tool_call.id, state);
689    }
690
691    /// Update tool call status
692    pub fn update_tool_call_status(
693        &mut self,
694        tool_call_id: &str,
695        status: ToolCallStatus,
696    ) -> std::result::Result<(), String> {
697        let tool_call = self
698            .tool_calls
699            .get_mut(tool_call_id)
700            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
701
702        // Update timestamps based on status changes
703        match (&tool_call.status, &status) {
704            (_, ToolCallStatus::Executing) => {
705                tool_call.started_at = Some(Utc::now());
706            }
707            (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
708                tool_call.completed_at = Some(Utc::now());
709            }
710            _ => {}
711        }
712
713        tool_call.status = status;
714        Ok(())
715    }
716
717    /// Approve a tool for future use
718    pub fn approve_tool(&mut self, tool_name: String) {
719        self.approved_tools.insert(tool_name);
720    }
721
722    /// Check if a tool is approved
723    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
724        self.approved_tools.contains(tool_name)
725    }
726
727    /// Validate internal consistency
728    pub fn validate(&self) -> std::result::Result<(), String> {
729        // Check that all tool calls referenced in messages exist
730        for message in &self.messages {
731            let tool_calls = Self::extract_tool_calls_from_message(message);
732            if !tool_calls.is_empty() {
733                for tool_call_id in tool_calls {
734                    if !self.tool_calls.contains_key(&tool_call_id) {
735                        return Err(format!(
736                            "Message references unknown tool call: {tool_call_id}"
737                        ));
738                    }
739                }
740            }
741        }
742
743        Ok(())
744    }
745
746    /// Extract tool call IDs from a message
747    fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
748        let mut tool_call_ids = Vec::new();
749
750        match &message.data {
751            MessageData::Assistant { content, .. } => {
752                for c in content {
753                    if let crate::app::conversation::AssistantContent::ToolCall {
754                        tool_call, ..
755                    } = c
756                    {
757                        tool_call_ids.push(tool_call.id.clone());
758                    }
759                }
760            }
761            MessageData::Tool { tool_use_id, .. } => {
762                tool_call_ids.push(tool_use_id.clone());
763            }
764            MessageData::User { .. } => {}
765        }
766
767        tool_call_ids
768    }
769}
770
/// Tool call state tracking: the call itself plus its lifecycle status and timing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tracked call.
    pub tool_call: ToolCall,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Set by `SessionState::update_tool_call_status` when moving to `Executing`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set by `SessionState::update_tool_call_status` when moving to `Completed`
    /// or `Failed`.
    pub completed_at: Option<DateTime<Utc>>,
    /// Tool output, if recorded (`SessionState::add_tool_call` initializes it to `None`).
    pub result: Option<ToolResult>,
}
780
781impl ToolCallState {
782    pub fn is_pending(&self) -> bool {
783        matches!(self.status, ToolCallStatus::PendingApproval)
784    }
785
786    pub fn is_complete(&self) -> bool {
787        matches!(
788            self.status,
789            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
790        )
791    }
792
793    pub fn duration(&self) -> Option<chrono::Duration> {
794        match (self.started_at, self.completed_at) {
795            (Some(start), Some(end)) => Some(end - start),
796            _ => None,
797        }
798    }
799}
800
/// Tool call execution status.
///
/// Internally tagged by `"status"`, e.g. `{"status": "pending_approval"}` or
/// `{"status": "failed", "error": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Initial state assigned by `SessionState::add_tool_call`.
    PendingApproval,
    Approved,
    /// Terminal: the call will not run.
    Denied,
    Executing,
    /// Terminal: finished successfully.
    Completed,
    /// Terminal: finished with an error.
    Failed { error: String },
}
812
813impl ToolCallStatus {
814    pub fn is_terminal(&self) -> bool {
815        matches!(
816            self,
817            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
818        )
819    }
820}
821
/// Tool execution statistics.
///
/// Note: `output` and `json_output` are omitted from serialized output when `None`,
/// but `result_type` is serialized as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>, // Legacy string output (also carries the error text on failure)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>, // Typed JSON output
    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
    pub success: bool,
    pub execution_time_ms: u64,
    pub metadata: HashMap<String, String>,
}
834
835impl ToolExecutionStats {
836    pub fn success(output: String, execution_time_ms: u64) -> Self {
837        Self {
838            output: Some(output),
839            json_output: None,
840            result_type: None,
841            success: true,
842            execution_time_ms,
843            metadata: HashMap::new(),
844        }
845    }
846
847    pub fn success_typed(
848        json_output: serde_json::Value,
849        result_type: String,
850        execution_time_ms: u64,
851    ) -> Self {
852        Self {
853            output: None,
854            json_output: Some(json_output),
855            result_type: Some(result_type),
856            success: true,
857            execution_time_ms,
858            metadata: HashMap::new(),
859        }
860    }
861
862    pub fn failure(error: String, execution_time_ms: u64) -> Self {
863        Self {
864            output: Some(error),
865            json_output: None,
866            result_type: None,
867            success: false,
868            execution_time_ms,
869            metadata: HashMap::new(),
870        }
871    }
872
873    pub fn with_metadata(mut self, key: String, value: String) -> Self {
874        self.metadata.insert(key, value);
875        self
876    }
877}
878
/// Session metadata for listing and filtering.
///
/// Built from a [`Session`] via `From<&Session>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    pub id: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// The last known model used in this session (currently always `None` when
    /// built from a `Session`; see the TODO in the `From` impl)
    pub last_model: Option<ModelId>,
    pub message_count: usize,
    /// Copy of the session config's metadata.
    pub metadata: HashMap<String, String>,
}
890
891impl From<&Session> for SessionInfo {
892    fn from(session: &Session) -> Self {
893        Self {
894            id: session.id.clone(),
895            created_at: session.created_at,
896            updated_at: session.updated_at,
897            last_model: None, // TODO: Track last model used from events
898            message_count: session.state.message_count(),
899            metadata: session.config.metadata.clone(),
900        }
901    }
902}
903
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};

    /// Builds a `SessionConfig` for a local workspace at `/test/path` with all
    /// optional fields unset and the default tool configuration (no backends).
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: test_model(),
        }
    }

    /// Builds a policy that preapproves exactly `read_file` and `list_files`
    /// (no per-tool rules) and applies `default_behavior` to every other tool.
    fn read_list_policy(default_behavior: UnapprovedBehavior) -> ToolApprovalPolicy {
        ToolApprovalPolicy {
            default_behavior,
            preapproved: ApprovalRules {
                tools: ["read_file", "list_files"]
                    .iter()
                    .map(|s| (*s).to_string())
                    .collect(),
                per_tool: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());

        assert_eq!(session.id, "test-session");
        assert_eq!(
            session
                .config
                .tool_config
                .approval_policy
                .tool_decision("any_tool"),
            ToolDecision::Ask
        );
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy_prompt_unapproved() {
        let policy = read_list_policy(UnapprovedBehavior::Prompt);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_deny_unapproved() {
        let policy = read_list_policy(UnapprovedBehavior::Deny);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
    }

    #[test]
    fn test_tool_approval_policy_default() {
        let policy = ToolApprovalPolicy::default();

        assert_eq!(
            policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
            ToolDecision::Allow
        );
        assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_allow_unapproved() {
        let policy = read_list_policy(UnapprovedBehavior::Allow);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
    }

    #[test]
    fn test_tool_approval_policy_overrides_union_rules() {
        let base_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRule::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRule::DispatchAgent {
                            agent_patterns: vec!["explore".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let overrides = ToolApprovalPolicyOverrides {
            default_behavior: Some(UnapprovedBehavior::Deny),
            preapproved: ApprovalRulesOverrides {
                tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git log".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRuleOverrides::DispatchAgent {
                            agent_patterns: vec!["review".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let merged = overrides.apply_to(&base_policy);

        // The override's default behavior wins; preapproved tool sets are unioned.
        assert_eq!(merged.default_behavior, UnapprovedBehavior::Deny);
        assert!(merged.preapproved.tools.contains("read_file"));
        assert!(merged.preapproved.tools.contains("write_file"));

        let bash_patterns = match merged
            .preapproved
            .per_tool
            .get(BASH_TOOL_NAME)
            .expect("bash rule")
        {
            ToolRule::Bash { patterns } => patterns,
            ToolRule::DispatchAgent { .. } => {
                panic!("Unexpected bash rule: dispatch agent")
            }
        };
        assert!(bash_patterns.contains(&"git status".to_string()));
        assert!(bash_patterns.contains(&"git log".to_string()));
        assert_eq!(bash_patterns.len(), 2);

        let agent_patterns = match merged
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch_agent rule")
        {
            ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
            ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
        };
        assert!(agent_patterns.contains(&"explore".to_string()));
        assert!(agent_patterns.contains(&"review".to_string()));
        assert_eq!(agent_patterns.len(), 2);
    }

    #[test]
    fn test_bash_pattern_matching() {
        // Key the rule by the canonical tool-name constant (as the other tests
        // do) rather than a hard-coded literal that could drift from it.
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    BASH_TOOL_NAME.to_string(),
                    ToolRule::Bash {
                        patterns: vec![
                            "git status".to_string(),
                            "git log*".to_string(),
                            "git * --oneline".to_string(),
                            "ls -?a*".to_string(),
                            "cargo build*".to_string(),
                        ],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_bash_pattern_preapproved("git status"));
        assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
        assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
        assert!(policy.is_bash_pattern_preapproved("ls -la"));
        assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
        assert!(!policy.is_bash_pattern_preapproved("git commit"));
        assert!(!policy.is_bash_pattern_preapproved("ls -l"));
        assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
    }

    #[test]
    fn test_dispatch_agent_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    DISPATCH_AGENT_TOOL_NAME.to_string(),
                    ToolRule::DispatchAgent {
                        agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
        assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
        assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        // An empty state is valid.
        assert!(state.validate().is_ok());

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123_456_789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        state.add_message(message);

        // Still valid after adding a single root user message.
        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        let tool_call = ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        };

        state.add_tool_call(tool_call);
        assert!(
            state
                .tool_calls
                .get("tool1")
                .expect("tool1 should be tracked")
                .is_pending()
        );

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .expect("transition to Executing");
        let tool_state = state.tool_calls.get("tool1").expect("tool1 should be tracked");
        assert!(tool_state.started_at.is_some());
        assert!(!tool_state.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .expect("transition to Completed");
        let tool_state = state.tool_calls.get("tool1").expect("tool1 should be tracked");
        assert!(tool_state.completed_at.is_some());
        assert!(tool_state.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let excluded =
            ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);

        if let ToolFilter::Exclude(tools) = &excluded {
            assert_eq!(tools.len(), 2);
            assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
            assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
        } else {
            panic!("Expected ToolFilter::Exclude");
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert_eq!(config.backends.len(), 0);
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert_eq!(
            config.approval_policy.default_behavior,
            UnapprovedBehavior::Prompt
        );
    }

    #[tokio::test]
    async fn test_session_config_build_registry_no_default_backends() {
        // BackendRegistry must only contain user-configured backends. Static
        // tools (dispatch_agent, web_fetch) live in ToolRegistry instead.
        // The default tool config used here configures no backends.
        let config = local_test_config();

        let (registry, _mcp_servers) = config
            .build_registry()
            .await
            .expect("building a registry with no backends should succeed");
        let schemas = registry.get_tool_schemas().await;

        assert!(
            schemas.is_empty(),
            "BackendRegistry should be empty with default config; got: {:?}",
            schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_mcp_status_tracking() {
        // MCP server info stored in the session state is retrievable by name.
        let mut session_state = SessionState::default();

        let mcp_info = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            state: McpConnectionState::Connected {
                tool_names: vec![
                    "tool1".to_string(),
                    "tool2".to_string(),
                    "tool3".to_string(),
                    "tool4".to_string(),
                    "tool5".to_string(),
                ],
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("test-server".to_string(), mcp_info);

        assert_eq!(session_state.mcp_servers.len(), 1);
        let stored = session_state
            .mcp_servers
            .get("test-server")
            .expect("test-server should be stored");
        assert_eq!(stored.server_name, "test-server");
        assert!(matches!(
            stored.state,
            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
        ));

        // A failed connection is tracked alongside the connected one.
        let failed_info = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed_info);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        // Every configured MCP server is reported by build_registry, including
        // the ones that fail to connect.
        let mut config = SessionConfig::read_only(test_model());

        // TCP transport to an unresolvable host (`.invalid` is a reserved TLD),
        // so the connection attempt must fail.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "bad-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "nonexistent.invalid".to_string(),
                port: 12345,
            },
            tool_filter: ToolFilter::All,
        });

        // Stdio transport that does launch a process, but `echo` does not speak
        // MCP, so this connection is expected to fail as well.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "good-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "echo".to_string(),
                args: vec!["test".to_string()],
            },
            tool_filter: ToolFilter::All,
        });

        let (_registry, mcp_servers) = config
            .build_registry()
            .await
            .expect("build_registry should record MCP failures, not return an error");

        assert_eq!(mcp_servers.len(), 2);

        let bad_server = mcp_servers
            .get("bad-server")
            .expect("bad-server should be tracked");
        assert_eq!(bad_server.server_name, "bad-server");
        assert!(matches!(
            bad_server.state,
            McpConnectionState::Failed { .. }
        ));

        let good_server = mcp_servers
            .get("good-server")
            .expect("good-server should be tracked");
        assert_eq!(good_server.server_name, "good-server");
        assert!(matches!(
            good_server.state,
            McpConnectionState::Failed { .. }
        ));
    }

    #[test]
    fn test_backend_config_mcp_variant() {
        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        let BackendConfig::Mcp {
            server_name,
            transport,
            ..
        } = mcp_config;

        assert_eq!(server_name, "test-mcp");
        if let crate::tools::McpTransport::Stdio { command, args } = transport {
            assert_eq!(command, "python");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }
}