Skip to main content

steer_core/runners/
one_shot_runner.rs

1use serde::{Deserialize, Serialize};
2use tokio_util::sync::CancellationToken;
3use tracing::{error, info, warn};
4
5use crate::agents::default_agent_spec_id;
6use crate::app::conversation::{Message, UserContent};
7use crate::app::domain::event::SessionEvent;
8use crate::app::domain::runtime::{RuntimeError, RuntimeHandle};
9use crate::app::domain::types::SessionId;
10use crate::config::model::ModelId;
11use crate::error::{Error, Result};
12use crate::session::ToolApprovalPolicy;
13use crate::session::state::SessionConfig;
14use crate::tools::{DISPATCH_AGENT_TOOL_NAME, DispatchAgentParams, DispatchAgentTarget};
15use steer_tools::ToolCall;
16use steer_tools::tools::BASH_TOOL_NAME;
17use steer_tools::tools::bash::BashParams;
18
/// Outcome of a single one-shot run against a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOnceResult {
    /// Last assistant message observed before the run stopped processing events.
    pub final_message: Message,
    /// Session the run executed in.
    pub session_id: SessionId,
}
24
/// Stateless entry point for one-shot runs: every operation is an associated
/// function that works through a [`RuntimeHandle`] passed in by the caller.
pub struct OneShotRunner;
26
27impl Default for OneShotRunner {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl OneShotRunner {
34    pub fn new() -> Self {
35        Self
36    }
37
38    pub async fn run_in_session(
39        runtime: &RuntimeHandle,
40        session_id: SessionId,
41        message: String,
42        model: ModelId,
43    ) -> Result<RunOnceResult> {
44        Self::run_in_session_with_cancel(
45            runtime,
46            session_id,
47            message,
48            model,
49            CancellationToken::new(),
50        )
51        .await
52    }
53
54    pub async fn run_in_session_with_cancel(
55        runtime: &RuntimeHandle,
56        session_id: SessionId,
57        message: String,
58        model: ModelId,
59        cancel_token: CancellationToken,
60    ) -> Result<RunOnceResult> {
61        runtime.resume_session(session_id).await.map_err(|e| {
62            Error::InvalidOperation(format!("Failed to resume session {session_id}: {e}"))
63        })?;
64
65        let subscription = runtime.subscribe_events(session_id).await.map_err(|e| {
66            Error::InvalidOperation(format!(
67                "Failed to subscribe to session {session_id} events: {e}"
68            ))
69        })?;
70
71        let approval_policy = match runtime.get_session_state(session_id).await {
72            Ok(state) => state
73                .session_config
74                .map(|config| config.tool_config.approval_policy)
75                .unwrap_or_default(),
76            Err(err) => {
77                warn!(
78                    session_id = %session_id,
79                    error = %err,
80                    "Failed to load session approval policy; defaulting to deny"
81                );
82                ToolApprovalPolicy::default()
83            }
84        };
85
86        info!(session_id = %session_id, message = %message, "Sending message to session");
87
88        let (op_id, _message_id) = runtime
89            .submit_user_input(
90                session_id,
91                vec![UserContent::Text {
92                    text: message.clone(),
93                }],
94                model,
95            )
96            .await
97            .map_err(|e| {
98                Error::InvalidOperation(format!(
99                    "Failed to send message to session {session_id}: {e}"
100                ))
101            })?;
102
103        let cancel_task = {
104            let runtime = runtime.clone();
105            let cancel_token = cancel_token.clone();
106            tokio::spawn(async move {
107                cancel_token.cancelled().await;
108                if let Err(err) = runtime.cancel_operation(session_id, Some(op_id)).await {
109                    warn!(
110                        session_id = %session_id,
111                        error = %err,
112                        "Failed to cancel one-shot operation"
113                    );
114                }
115            })
116        };
117
118        let result =
119            Self::process_events(runtime, subscription, session_id, op_id, approval_policy).await;
120
121        cancel_task.abort();
122
123        if let Err(e) = runtime.suspend_session(session_id).await {
124            error!(session_id = %session_id, error = %e, "Failed to suspend session");
125        } else {
126            info!(session_id = %session_id, "Session suspended successfully");
127        }
128
129        result
130    }
131
132    pub async fn run_new_session(
133        runtime: &RuntimeHandle,
134        config: SessionConfig,
135        message: String,
136        model: ModelId,
137    ) -> Result<RunOnceResult> {
138        Self::run_new_session_with_cancel(runtime, config, message, model, CancellationToken::new())
139            .await
140    }
141
142    pub async fn run_new_session_with_cancel(
143        runtime: &RuntimeHandle,
144        config: SessionConfig,
145        message: String,
146        model: ModelId,
147        cancel_token: CancellationToken,
148    ) -> Result<RunOnceResult> {
149        let session_id = runtime
150            .create_session(config)
151            .await
152            .map_err(|e| Error::InvalidOperation(format!("Failed to create session: {e}")))?;
153
154        info!(session_id = %session_id, "Created new session for one-shot run");
155
156        Self::run_in_session_with_cancel(runtime, session_id, message, model, cancel_token).await
157    }
158
159    async fn process_events(
160        runtime: &RuntimeHandle,
161        mut subscription: crate::app::domain::runtime::SessionEventSubscription,
162        session_id: SessionId,
163        op_id: crate::app::domain::types::OpId,
164        mut approval_policy: ToolApprovalPolicy,
165    ) -> Result<RunOnceResult> {
166        let mut messages = Vec::new();
167        info!(session_id = %session_id, "Starting event processing loop");
168
169        while let Some(envelope) = subscription.recv().await {
170            match envelope.event {
171                SessionEvent::AssistantMessageAdded { message, model: _ } => {
172                    info!(
173                        session_id = %session_id,
174                        role = ?message.role(),
175                        id = %message.id(),
176                        "AssistantMessageAdded event"
177                    );
178                    messages.push(message);
179                }
180
181                SessionEvent::MessageUpdated { message } => {
182                    info!(
183                        session_id = %session_id,
184                        id = %message.id(),
185                        "MessageUpdated event"
186                    );
187                }
188
189                SessionEvent::OperationCompleted {
190                    op_id: completed_op,
191                } => {
192                    if completed_op != op_id {
193                        continue;
194                    }
195                    info!(
196                        session_id = %session_id,
197                        op_id = %completed_op,
198                        "OperationCompleted event received"
199                    );
200                    if !messages.is_empty() {
201                        info!(session_id = %session_id, "Final message received, exiting event loop");
202                        break;
203                    }
204                }
205
206                SessionEvent::OperationCancelled {
207                    op_id: cancelled_op,
208                    ..
209                } => {
210                    if cancelled_op != op_id {
211                        continue;
212                    }
213                    warn!(
214                        session_id = %session_id,
215                        op_id = %cancelled_op,
216                        "OperationCancelled event received"
217                    );
218                    return Err(Error::Cancelled);
219                }
220
221                SessionEvent::Error { message } => {
222                    error!(session_id = %session_id, error = %message, "Error event");
223                    return Err(Error::InvalidOperation(format!(
224                        "Error during processing: {message}"
225                    )));
226                }
227
228                SessionEvent::ApprovalRequested {
229                    request_id,
230                    tool_call,
231                } => {
232                    let approved = tool_is_preapproved(&tool_call, &approval_policy);
233                    if approved {
234                        info!(
235                            session_id = %session_id,
236                            request_id = %request_id,
237                            tool = %tool_call.name,
238                            "Auto-approving preapproved tool"
239                        );
240                    } else {
241                        warn!(
242                            session_id = %session_id,
243                            request_id = %request_id,
244                            tool = %tool_call.name,
245                            "Auto-denying unapproved tool"
246                        );
247                    }
248
249                    runtime
250                        .submit_tool_approval(session_id, request_id, approved, None)
251                        .await
252                        .map_err(|e| {
253                            Error::InvalidOperation(format!(
254                                "Failed to submit tool approval decision: {e}"
255                            ))
256                        })?;
257                }
258
259                SessionEvent::SessionConfigUpdated { config, .. } => {
260                    approval_policy = config.tool_config.approval_policy.clone();
261                }
262
263                _ => {}
264            }
265        }
266
267        match messages.last() {
268            Some(final_message) => {
269                info!(
270                    session_id = %session_id,
271                    message_count = messages.len(),
272                    "Returning final result"
273                );
274                Ok(RunOnceResult {
275                    final_message: final_message.clone(),
276                    session_id,
277                })
278            }
279            None => Err(Error::InvalidOperation("No message received".to_string())),
280        }
281    }
282}
283
284fn tool_is_preapproved(tool_call: &ToolCall, policy: &ToolApprovalPolicy) -> bool {
285    if policy.preapproved.tools.contains(&tool_call.name) {
286        return true;
287    }
288
289    if tool_call.name == DISPATCH_AGENT_TOOL_NAME {
290        let params = serde_json::from_value::<DispatchAgentParams>(tool_call.parameters.clone());
291        if let Ok(params) = params {
292            return match params.target {
293                DispatchAgentTarget::Resume { .. } => true,
294                DispatchAgentTarget::New { agent, .. } => {
295                    let agent_id = agent
296                        .as_deref()
297                        .filter(|value| !value.trim().is_empty())
298                        .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
299                    policy.is_dispatch_agent_pattern_preapproved(&agent_id)
300                }
301            };
302        }
303    }
304
305    if tool_call.name == BASH_TOOL_NAME {
306        let params = serde_json::from_value::<BashParams>(tool_call.parameters.clone());
307        if let Ok(params) = params {
308            return policy.is_bash_pattern_preapproved(&params.command);
309        }
310    }
311
312    false
313}
314
315impl From<RuntimeError> for Error {
316    fn from(e: RuntimeError) -> Self {
317        match e {
318            RuntimeError::SessionNotFound { session_id } => {
319                Error::InvalidOperation(format!("Session not found: {session_id}"))
320            }
321            RuntimeError::SessionAlreadyExists { session_id } => {
322                Error::InvalidOperation(format!("Session already exists: {session_id}"))
323            }
324            RuntimeError::InvalidInput { message } => Error::InvalidOperation(message),
325            RuntimeError::ChannelClosed => {
326                Error::InvalidOperation("Runtime channel closed".to_string())
327            }
328            RuntimeError::ShuttingDown => {
329                Error::InvalidOperation("Runtime is shutting down".to_string())
330            }
331            RuntimeError::Session(e) => Error::InvalidOperation(format!("Session error: {e}")),
332            RuntimeError::EventStore(e) => {
333                Error::InvalidOperation(format!("Event store error: {e}"))
334            }
335        }
336    }
337}
338
339#[cfg(test)]
340mod tests {
341    use super::*;
342    use crate::api::Client as ApiClient;
343    use crate::api::{ApiError, CompletionResponse, Provider};
344    use crate::app::conversation::{AssistantContent, Message, MessageData};
345    use crate::app::domain::action::ApprovalDecision;
346    use crate::app::domain::runtime::RuntimeService;
347    use crate::app::domain::session::event_store::InMemoryEventStore;
348    use crate::app::validation::ValidatorRegistry;
349    use crate::config::model::builtin;
350    use crate::session::SessionPolicyOverrides;
351    use crate::session::ToolApprovalPolicy;
352    use crate::session::state::{
353        ApprovalRules, SessionToolConfig, UnapprovedBehavior, WorkspaceConfig,
354    };
355    use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
356    use crate::tools::{BackendRegistry, ToolExecutor};
357    use dotenvy::dotenv;
358    use serde_json::json;
359    use std::collections::{HashMap, HashSet};
360    use std::sync::Arc;
361    use std::sync::Mutex as StdMutex;
362    use std::time::Duration;
363    use steer_tools::ToolCall;
364    use steer_tools::tools::BASH_TOOL_NAME;
365    use tokio_util::sync::CancellationToken;
366
    /// Stub [`Provider`] that answers the first completion request with a tool call
    /// and every later request with plain text.
    #[derive(Clone)]
    struct ToolCallThenTextProvider {
        /// Tool call returned for the first completion request.
        tool_call: ToolCall,
        /// Text returned for every request after the first.
        final_text: String,
        /// Number of completion requests served so far; shared between clones.
        call_count: Arc<StdMutex<usize>>,
        /// Delay applied before answering the first request (skipped when zero).
        first_call_delay: Duration,
    }
374
375    impl ToolCallThenTextProvider {
376        fn new(tool_call: ToolCall, final_text: impl Into<String>) -> Self {
377            Self::new_with_delay(tool_call, final_text, Duration::ZERO)
378        }
379
380        fn new_with_delay(
381            tool_call: ToolCall,
382            final_text: impl Into<String>,
383            first_call_delay: Duration,
384        ) -> Self {
385            Self {
386                tool_call,
387                final_text: final_text.into(),
388                call_count: Arc::new(StdMutex::new(0)),
389                first_call_delay,
390            }
391        }
392    }
393
394    #[async_trait::async_trait]
395    impl Provider for ToolCallThenTextProvider {
396        fn name(&self) -> &'static str {
397            "stub-tool-call"
398        }
399
400        async fn complete(
401            &self,
402            _model_id: &crate::config::model::ModelId,
403            _messages: Vec<Message>,
404            _system: Option<crate::app::SystemContext>,
405            _tools: Option<Vec<steer_tools::ToolSchema>>,
406            _call_options: Option<crate::config::model::ModelParameters>,
407            _token: CancellationToken,
408        ) -> std::result::Result<CompletionResponse, ApiError> {
409            let call_index = {
410                let mut count = self
411                    .call_count
412                    .lock()
413                    .expect("tool call counter lock poisoned");
414                let idx = *count;
415                *count += 1;
416                idx
417            };
418
419            if call_index == 0 && !self.first_call_delay.is_zero() {
420                tokio::time::sleep(self.first_call_delay).await;
421            }
422
423            if call_index == 0 {
424                Ok(CompletionResponse {
425                    content: vec![AssistantContent::ToolCall {
426                        tool_call: self.tool_call.clone(),
427                        thought_signature: None,
428                    }],
429                    usage: None,
430                })
431            } else {
432                Ok(CompletionResponse {
433                    content: vec![AssistantContent::Text {
434                        text: self.final_text.clone(),
435                    }],
436                    usage: None,
437                })
438            }
439        }
440    }
441
442    async fn create_test_runtime() -> RuntimeService {
443        let event_store = Arc::new(InMemoryEventStore::new());
444        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
445        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
446        let api_client = Arc::new(ApiClient::new_with_deps(
447            crate::test_utils::test_llm_config_provider().unwrap(),
448            provider_registry,
449            model_registry,
450        ));
451
452        let tool_executor = Arc::new(ToolExecutor::with_components(
453            Arc::new(BackendRegistry::new()),
454            Arc::new(ValidatorRegistry::new()),
455        ));
456
457        RuntimeService::spawn(event_store, api_client, tool_executor)
458    }
459
460    fn create_test_session_config() -> SessionConfig {
461        SessionConfig {
462            default_model: builtin::claude_sonnet_4_5(),
463            workspace: WorkspaceConfig::default(),
464            workspace_ref: None,
465            workspace_id: None,
466            repo_ref: None,
467            parent_session_id: None,
468            workspace_name: None,
469            tool_config: SessionToolConfig::default(),
470            system_prompt: None,
471            primary_agent_id: None,
472            policy_overrides: SessionPolicyOverrides::empty(),
473            title: None,
474            metadata: std::collections::HashMap::new(),
475            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
476        }
477    }
478
479    fn create_test_tool_approval_policy() -> ToolApprovalPolicy {
480        let tool_names = READ_ONLY_TOOL_NAMES
481            .iter()
482            .map(|name| (*name).to_string())
483            .collect();
484        ToolApprovalPolicy {
485            default_behavior: UnapprovedBehavior::Prompt,
486            preapproved: ApprovalRules {
487                tools: tool_names,
488                per_tool: std::collections::HashMap::new(),
489            },
490        }
491    }
492
493    #[test]
494    fn tool_is_preapproved_allows_whitelisted_tool() {
495        let policy = create_test_tool_approval_policy();
496        let tool_call = ToolCall {
497            id: "tc_read".to_string(),
498            name: READ_ONLY_TOOL_NAMES[0].to_string(),
499            parameters: json!({}),
500        };
501
502        assert!(tool_is_preapproved(&tool_call, &policy));
503    }
504
505    #[test]
506    fn tool_is_preapproved_allows_bash_pattern() {
507        use crate::session::state::{ApprovalRules, ToolRule, UnapprovedBehavior};
508
509        let mut per_tool = HashMap::new();
510        per_tool.insert(
511            "bash".to_string(),
512            ToolRule::Bash {
513                patterns: vec!["echo *".to_string()],
514            },
515        );
516
517        let policy = ToolApprovalPolicy {
518            default_behavior: UnapprovedBehavior::Prompt,
519            preapproved: ApprovalRules {
520                tools: HashSet::new(),
521                per_tool,
522            },
523        };
524
525        let tool_call = ToolCall {
526            id: "tc_bash".to_string(),
527            name: BASH_TOOL_NAME.to_string(),
528            parameters: json!({ "command": "echo hello" }),
529        };
530
531        assert!(tool_is_preapproved(&tool_call, &policy));
532    }
533
534    #[test]
535    fn tool_is_preapproved_allows_dispatch_agent_pattern() {
536        use crate::session::state::{ApprovalRules, ToolRule, UnapprovedBehavior};
537
538        let mut per_tool = HashMap::new();
539        per_tool.insert(
540            "dispatch_agent".to_string(),
541            ToolRule::DispatchAgent {
542                agent_patterns: vec!["explore".to_string()],
543            },
544        );
545
546        let policy = ToolApprovalPolicy {
547            default_behavior: UnapprovedBehavior::Prompt,
548            preapproved: ApprovalRules {
549                tools: HashSet::new(),
550                per_tool,
551            },
552        };
553
554        let tool_call = ToolCall {
555            id: "tc_dispatch".to_string(),
556            name: DISPATCH_AGENT_TOOL_NAME.to_string(),
557            parameters: json!({
558                "prompt": "find files",
559                "target": {
560                    "session": "new",
561                    "workspace": {
562                        "location": "current"
563                    },
564                    "agent": "explore"
565                }
566            }),
567        };
568
569        assert!(tool_is_preapproved(&tool_call, &policy));
570    }
571
572    #[test]
573    fn tool_is_preapproved_denies_unlisted_tool() {
574        let policy = create_test_tool_approval_policy();
575        let tool_call = ToolCall {
576            id: "tc_other".to_string(),
577            name: "bash".to_string(),
578            parameters: json!({ "command": "rm -rf /" }),
579        };
580
581        assert!(!tool_is_preapproved(&tool_call, &policy));
582    }
583
584    #[tokio::test]
585    async fn run_new_session_denies_unapproved_tool_requests() {
586        let event_store = Arc::new(InMemoryEventStore::new());
587        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
588        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
589        let api_client = Arc::new(ApiClient::new_with_deps(
590            crate::test_utils::test_llm_config_provider().unwrap(),
591            provider_registry,
592            model_registry.clone(),
593        ));
594
595        let tool_call = ToolCall {
596            id: "tc_1".to_string(),
597            name: "bash".to_string(),
598            parameters: json!({ "command": "echo denied" }),
599        };
600        api_client.insert_test_provider(
601            builtin::claude_sonnet_4_5().provider.clone(),
602            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
603        );
604
605        let tool_executor = Arc::new(ToolExecutor::with_components(
606            Arc::new(BackendRegistry::new()),
607            Arc::new(ValidatorRegistry::new()),
608        ));
609        let runtime = RuntimeService::spawn(event_store, api_client, tool_executor);
610
611        let mut config = create_test_session_config();
612        config.tool_config.approval_policy = ToolApprovalPolicy {
613            default_behavior: UnapprovedBehavior::Prompt,
614            preapproved: ApprovalRules {
615                tools: HashSet::new(),
616                per_tool: HashMap::new(),
617            },
618        };
619
620        let model = builtin::claude_sonnet_4_5();
621        let result = OneShotRunner::run_new_session(
622            &runtime.handle,
623            config,
624            "Trigger tool call".to_string(),
625            model,
626        )
627        .await
628        .expect("run_new_session should complete");
629
630        let events = runtime
631            .handle
632            .load_events_after(result.session_id, 0)
633            .await
634            .expect("load events");
635
636        let mut saw_request = false;
637        let mut saw_decision = false;
638        let mut saw_denied = false;
639
640        for (_, event) in events {
641            match event {
642                SessionEvent::ApprovalRequested { .. } => saw_request = true,
643                SessionEvent::ApprovalDecided { decision, .. } => {
644                    saw_decision = true;
645                    if decision == ApprovalDecision::Denied {
646                        saw_denied = true;
647                    }
648                }
649                _ => {}
650            }
651        }
652
653        assert!(saw_request, "expected ApprovalRequested event");
654        assert!(saw_decision, "expected ApprovalDecided event");
655        assert!(saw_denied, "expected denied decision");
656
657        runtime.shutdown().await;
658    }
659
    /// Switching the primary agent to "yolo" while the first model call is still
    /// pending should let the later bash tool call start without any approval
    /// request or decision being recorded.
    ///
    /// NOTE(review): this relies on wall-clock ordering — the stub provider delays
    /// its first response by 75 ms and the switch is issued after a 20 ms sleep.
    /// Presumably the "yolo" agent carries a policy that pre-approves the bash call;
    /// confirm against the agent definitions.
    #[tokio::test]
    async fn run_new_session_updates_approval_policy_after_agent_switch_event() {
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
        let api_client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry,
        ));

        let tool_call = ToolCall {
            id: "tc_switch".to_string(),
            name: "bash".to_string(),
            parameters: json!({ "command": "echo switched" }),
        };
        // The delayed first response opens a window in which the agent switch below
        // can land before the tool call is produced.
        api_client.insert_test_provider(
            builtin::claude_sonnet_4_5().provider.clone(),
            Arc::new(ToolCallThenTextProvider::new_with_delay(
                tool_call,
                "done",
                Duration::from_millis(75),
            )),
        );

        let tool_executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));
        let runtime = RuntimeService::spawn(event_store, api_client, tool_executor);

        // Start from a policy that pre-approves nothing.
        let mut config = create_test_session_config();
        config.tool_config.approval_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: HashMap::new(),
            },
        };

        let session_id = runtime
            .handle
            .create_session(config)
            .await
            .expect("create session");

        // Run the one-shot in the background so the agent switch can happen while
        // it is in flight.
        let runtime_handle_for_run = runtime.handle.clone();
        let run_task = tokio::spawn(async move {
            OneShotRunner::run_in_session(
                &runtime_handle_for_run,
                session_id,
                "Trigger tool call".to_string(),
                builtin::claude_sonnet_4_5(),
            )
            .await
        });

        // Give the run time to submit its message before switching agents.
        tokio::time::sleep(Duration::from_millis(20)).await;

        runtime
            .handle
            .switch_primary_agent(session_id, "yolo".to_string())
            .await
            .expect("switch to yolo");

        let result = run_task
            .await
            .expect("run task join")
            .expect("run_in_session should complete");

        let events = runtime
            .handle
            .load_events_after(result.session_id, 0)
            .await
            .expect("load events");

        let mut saw_approval_requested = false;
        let mut saw_session_config_updated = false;
        let mut saw_tool_started = false;
        let mut saw_approval_decision = false;

        // Scan the stored event log for the events the assertions below depend on.
        for (_, event) in events {
            match event {
                SessionEvent::ApprovalRequested { .. } => saw_approval_requested = true,
                SessionEvent::SessionConfigUpdated {
                    primary_agent_id, ..
                } if primary_agent_id == "yolo" => saw_session_config_updated = true,
                SessionEvent::ToolCallStarted { id, .. } if id.as_str() == "tc_switch" => {
                    saw_tool_started = true
                }
                SessionEvent::ApprovalDecided { .. } => saw_approval_decision = true,
                _ => {}
            }
        }

        assert!(
            saw_session_config_updated,
            "expected in-flight session config update event"
        );
        assert!(
            saw_tool_started,
            "expected bash tool call to auto-run after policy switched to yolo"
        );
        assert!(
            !saw_approval_requested,
            "expected no approval request after switching to yolo before the first tool call"
        );
        assert!(
            !saw_approval_decision,
            "expected no approval decision when no approval was requested"
        );

        runtime.shutdown().await;
    }
774
775    #[tokio::test]
776    #[ignore = "Requires API keys and network access"]
777    async fn test_run_new_session_basic() {
778        dotenv().ok();
779        let runtime = create_test_runtime().await;
780
781        let mut config = create_test_session_config();
782        config.tool_config = SessionToolConfig::read_only();
783        config.tool_config.approval_policy = create_test_tool_approval_policy();
784        config
785            .metadata
786            .insert("mode".to_string(), "headless".to_string());
787
788        let model = builtin::claude_sonnet_4_5();
789        let result = OneShotRunner::run_new_session(
790            &runtime.handle,
791            config,
792            "What is 2 + 2?".to_string(),
793            model,
794        )
795        .await;
796
797        let result = tokio::time::timeout(std::time::Duration::from_secs(30), async { result })
798            .await
799            .expect("Timed out waiting for response")
800            .expect("run_new_session failed");
801
802        assert!(!result.final_message.id().is_empty());
803        println!("New session run succeeded: {:?}", result.final_message);
804
805        let content = match &result.final_message.data {
806            MessageData::Assistant { content, .. } => content,
807            _ => panic!("expected assistant message, got {:?}", result.final_message),
808        };
809        let text_content = content.iter().find_map(|c| match c {
810            AssistantContent::Text { text } => Some(text),
811            _ => None,
812        });
813        let content = text_content.expect("No text content found in assistant message");
814        assert!(!content.is_empty(), "Response should not be empty");
815        assert!(
816            content.contains('4'),
817            "Expected response to contain '4', got: {content}"
818        );
819
820        runtime.shutdown().await;
821    }
822
823    #[tokio::test]
824    async fn test_session_creation() {
825        let runtime = create_test_runtime().await;
826
827        let mut config = create_test_session_config();
828        config.tool_config.approval_policy = create_test_tool_approval_policy();
829        config
830            .metadata
831            .insert("test".to_string(), "value".to_string());
832
833        let session_id = runtime.handle.create_session(config).await.unwrap();
834
835        assert!(runtime.handle.is_session_active(session_id).await.unwrap());
836
837        let state = runtime.handle.get_session_state(session_id).await.unwrap();
838        assert_eq!(
839            state.session_config.as_ref().unwrap().metadata.get("test"),
840            Some(&"value".to_string())
841        );
842
843        runtime.shutdown().await;
844    }
845
846    #[tokio::test]
847    async fn test_run_in_session_nonexistent_session() {
848        let runtime = create_test_runtime().await;
849
850        let fake_session_id = SessionId::new();
851        let model = builtin::claude_sonnet_4_5();
852        let result = OneShotRunner::run_in_session(
853            &runtime.handle,
854            fake_session_id,
855            "Test message".to_string(),
856            model,
857        )
858        .await;
859
860        assert!(result.is_err());
861        let err = result.err().unwrap().to_string();
862        assert!(
863            err.contains("not found") || err.contains("Session"),
864            "Expected session not found error, got: {err}"
865        );
866
867        runtime.shutdown().await;
868    }
869
870    #[tokio::test]
871    #[ignore = "Requires API keys and network access"]
872    async fn test_run_in_session_with_real_api() {
873        dotenv().ok();
874        let runtime = create_test_runtime().await;
875
876        let mut config = create_test_session_config();
877        config.tool_config = SessionToolConfig::read_only();
878        config.tool_config.approval_policy = create_test_tool_approval_policy();
879        config
880            .metadata
881            .insert("test".to_string(), "api_test".to_string());
882
883        let session_id = runtime.handle.create_session(config).await.unwrap();
884        let model = builtin::claude_sonnet_4_5();
885
886        let result = OneShotRunner::run_in_session(
887            &runtime.handle,
888            session_id,
889            "What is the capital of France?".to_string(),
890            model,
891        )
892        .await;
893
894        match result {
895            Ok(run_result) => {
896                println!("Session run succeeded: {:?}", run_result.final_message);
897
898                let content = match &run_result.final_message.data {
899                    MessageData::Assistant { content, .. } => content.clone(),
900                    _ => panic!(
901                        "expected assistant message, got {:?}",
902                        run_result.final_message
903                    ),
904                };
905                let text_content = content.iter().find_map(|c| match c {
906                    AssistantContent::Text { text } => Some(text),
907                    _ => None,
908                });
909                let content = text_content.expect("expected text response in assistant message");
910                assert!(!content.is_empty(), "Response should not be empty");
911                assert!(
912                    content.to_lowercase().contains("paris"),
913                    "Expected response to contain 'Paris', got: {content}"
914                );
915            }
916            Err(e) => {
917                println!("Session run failed (expected if no API key): {e}");
918                assert!(
919                    e.to_string().contains("API key")
920                        || e.to_string().contains("authentication")
921                        || e.to_string().contains("timed out"),
922                    "Unexpected error: {e}"
923                );
924            }
925        }
926
927        runtime.shutdown().await;
928    }
929
930    #[tokio::test]
931    #[ignore = "Requires API keys and network access"]
932    async fn test_run_in_session_preserves_context() {
933        dotenv().ok();
934        let runtime = create_test_runtime().await;
935
936        let mut config = create_test_session_config();
937        config.tool_config = SessionToolConfig::read_only();
938        config.tool_config.approval_policy = create_test_tool_approval_policy();
939        config
940            .metadata
941            .insert("test".to_string(), "context_test".to_string());
942
943        let session_id = runtime.handle.create_session(config).await.unwrap();
944        let model = builtin::claude_sonnet_4_5();
945
946        let result1 = OneShotRunner::run_in_session(
947            &runtime.handle,
948            session_id,
949            "My name is Alice and I like pizza.".to_string(),
950            model.clone(),
951        )
952        .await
953        .expect("First session run should succeed");
954
955        println!("First interaction: {:?}", result1.final_message);
956
957        runtime.handle.resume_session(session_id).await.unwrap();
958
959        let result2 = OneShotRunner::run_in_session(
960            &runtime.handle,
961            session_id,
962            "What is my name and what do I like?".to_string(),
963            model,
964        )
965        .await
966        .expect("Second session run should succeed");
967
968        println!("Second interaction: {:?}", result2.final_message);
969
970        match &result2.final_message.data {
971            MessageData::Assistant { content, .. } => {
972                let text_content = content.iter().find_map(|c| match c {
973                    AssistantContent::Text { text } => Some(text),
974                    _ => None,
975                });
976
977                match text_content {
978                    Some(content) => {
979                        assert!(!content.is_empty(), "Response should not be empty");
980                        let content_lower = content.to_lowercase();
981
982                        assert!(
983                            content_lower.contains("alice") || content_lower.contains("name"),
984                            "Expected response to reference the name or context, got: {content}"
985                        );
986                    }
987                    None => {
988                        panic!("expected text response in assistant message");
989                    }
990                }
991            }
992            _ => {
993                panic!(
994                    "expected assistant message, got {:?}",
995                    result2.final_message
996                );
997            }
998        }
999
1000        runtime.shutdown().await;
1001    }
1002
1003    #[tokio::test]
1004    #[ignore = "Requires API keys and network access"]
1005    async fn test_run_new_session_with_tool_usage() {
1006        dotenv().ok();
1007        let runtime = create_test_runtime().await;
1008
1009        let mut config = create_test_session_config();
1010        config.tool_config = SessionToolConfig::read_only();
1011        config.tool_config.approval_policy = create_test_tool_approval_policy();
1012        let model = builtin::claude_sonnet_4_5();
1013
1014        let result = OneShotRunner::run_new_session(
1015            &runtime.handle,
1016            config,
1017            "List the files in the current directory".to_string(),
1018            model,
1019        )
1020        .await
1021        .expect("New session run with tools should succeed with valid API key");
1022
1023        assert!(!result.final_message.id().is_empty());
1024        println!(
1025            "New session run with tools succeeded: {:?}",
1026            result.final_message
1027        );
1028
1029        let has_content = match &result.final_message.data {
1030            MessageData::Assistant { content, .. } => content.iter().any(|c| match c {
1031                AssistantContent::Text { text } => !text.is_empty(),
1032                _ => true,
1033            }),
1034            _ => false,
1035        };
1036        assert!(has_content, "Response should have some content");
1037
1038        runtime.shutdown().await;
1039    }
1040}