Skip to main content

steer_core/runners/
one_shot_runner.rs

1use serde::{Deserialize, Serialize};
2use tokio_util::sync::CancellationToken;
3use tracing::{error, info, warn};
4
5use crate::agents::default_agent_spec_id;
6use crate::app::conversation::Message;
7use crate::app::domain::event::SessionEvent;
8use crate::app::domain::runtime::{RuntimeError, RuntimeHandle};
9use crate::app::domain::types::SessionId;
10use crate::config::model::ModelId;
11use crate::error::{Error, Result};
12use crate::session::ToolApprovalPolicy;
13use crate::session::state::SessionConfig;
14use crate::tools::{DISPATCH_AGENT_TOOL_NAME, DispatchAgentParams, DispatchAgentTarget};
15use steer_tools::ToolCall;
16use steer_tools::tools::BASH_TOOL_NAME;
17use steer_tools::tools::bash::BashParams;
18
/// Outcome of a single one-shot run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOnceResult {
    /// The last assistant message observed while the run's events were processed.
    pub final_message: Message,
    /// Identifier of the session the run executed in.
    pub session_id: SessionId,
}
24
/// Stateless entry point for one-shot runs; the run methods are associated functions.
pub struct OneShotRunner;
26
27impl Default for OneShotRunner {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl OneShotRunner {
    /// Creates the runner. The type has no fields, so this is just the unit value.
    pub fn new() -> Self {
        Self
    }
37
38    pub async fn run_in_session(
39        runtime: &RuntimeHandle,
40        session_id: SessionId,
41        message: String,
42        model: ModelId,
43    ) -> Result<RunOnceResult> {
44        Self::run_in_session_with_cancel(
45            runtime,
46            session_id,
47            message,
48            model,
49            CancellationToken::new(),
50        )
51        .await
52    }
53
54    pub async fn run_in_session_with_cancel(
55        runtime: &RuntimeHandle,
56        session_id: SessionId,
57        message: String,
58        model: ModelId,
59        cancel_token: CancellationToken,
60    ) -> Result<RunOnceResult> {
61        runtime.resume_session(session_id).await.map_err(|e| {
62            Error::InvalidOperation(format!("Failed to resume session {session_id}: {e}"))
63        })?;
64
65        let subscription = runtime.subscribe_events(session_id).await.map_err(|e| {
66            Error::InvalidOperation(format!(
67                "Failed to subscribe to session {session_id} events: {e}"
68            ))
69        })?;
70
71        let approval_policy = match runtime.get_session_state(session_id).await {
72            Ok(state) => state
73                .session_config
74                .map(|config| config.tool_config.approval_policy)
75                .unwrap_or_default(),
76            Err(err) => {
77                warn!(
78                    session_id = %session_id,
79                    error = %err,
80                    "Failed to load session approval policy; defaulting to deny"
81                );
82                ToolApprovalPolicy::default()
83            }
84        };
85
86        info!(session_id = %session_id, message = %message, "Sending message to session");
87
88        let op_id = runtime
89            .submit_user_input(session_id, message, model)
90            .await
91            .map_err(|e| {
92                Error::InvalidOperation(format!(
93                    "Failed to send message to session {session_id}: {e}"
94                ))
95            })?;
96
97        let cancel_task = {
98            let runtime = runtime.clone();
99            let cancel_token = cancel_token.clone();
100            tokio::spawn(async move {
101                cancel_token.cancelled().await;
102                if let Err(err) = runtime.cancel_operation(session_id, Some(op_id)).await {
103                    warn!(
104                        session_id = %session_id,
105                        error = %err,
106                        "Failed to cancel one-shot operation"
107                    );
108                }
109            })
110        };
111
112        let result =
113            Self::process_events(runtime, subscription, session_id, op_id, approval_policy).await;
114
115        cancel_task.abort();
116
117        if let Err(e) = runtime.suspend_session(session_id).await {
118            error!(session_id = %session_id, error = %e, "Failed to suspend session");
119        } else {
120            info!(session_id = %session_id, "Session suspended successfully");
121        }
122
123        result
124    }
125
126    pub async fn run_new_session(
127        runtime: &RuntimeHandle,
128        config: SessionConfig,
129        message: String,
130        model: ModelId,
131    ) -> Result<RunOnceResult> {
132        Self::run_new_session_with_cancel(runtime, config, message, model, CancellationToken::new())
133            .await
134    }
135
136    pub async fn run_new_session_with_cancel(
137        runtime: &RuntimeHandle,
138        config: SessionConfig,
139        message: String,
140        model: ModelId,
141        cancel_token: CancellationToken,
142    ) -> Result<RunOnceResult> {
143        let session_id = runtime
144            .create_session(config)
145            .await
146            .map_err(|e| Error::InvalidOperation(format!("Failed to create session: {e}")))?;
147
148        info!(session_id = %session_id, "Created new session for one-shot run");
149
150        Self::run_in_session_with_cancel(runtime, session_id, message, model, cancel_token).await
151    }
152
153    async fn process_events(
154        runtime: &RuntimeHandle,
155        mut subscription: crate::app::domain::runtime::SessionEventSubscription,
156        session_id: SessionId,
157        op_id: crate::app::domain::types::OpId,
158        approval_policy: ToolApprovalPolicy,
159    ) -> Result<RunOnceResult> {
160        let mut messages = Vec::new();
161        info!(session_id = %session_id, "Starting event processing loop");
162
163        while let Some(envelope) = subscription.recv().await {
164            match envelope.event {
165                SessionEvent::AssistantMessageAdded { message, model: _ } => {
166                    info!(
167                        session_id = %session_id,
168                        role = ?message.role(),
169                        id = %message.id(),
170                        "AssistantMessageAdded event"
171                    );
172                    messages.push(message);
173                }
174
175                SessionEvent::MessageUpdated { message } => {
176                    info!(
177                        session_id = %session_id,
178                        id = %message.id(),
179                        "MessageUpdated event"
180                    );
181                }
182
183                SessionEvent::OperationCompleted {
184                    op_id: completed_op,
185                } => {
186                    if completed_op != op_id {
187                        continue;
188                    }
189                    info!(
190                        session_id = %session_id,
191                        op_id = %completed_op,
192                        "OperationCompleted event received"
193                    );
194                    if !messages.is_empty() {
195                        info!(session_id = %session_id, "Final message received, exiting event loop");
196                        break;
197                    }
198                }
199
200                SessionEvent::OperationCancelled {
201                    op_id: cancelled_op,
202                    ..
203                } => {
204                    if cancelled_op != op_id {
205                        continue;
206                    }
207                    warn!(
208                        session_id = %session_id,
209                        op_id = %cancelled_op,
210                        "OperationCancelled event received"
211                    );
212                    return Err(Error::Cancelled);
213                }
214
215                SessionEvent::Error { message } => {
216                    error!(session_id = %session_id, error = %message, "Error event");
217                    return Err(Error::InvalidOperation(format!(
218                        "Error during processing: {message}"
219                    )));
220                }
221
222                SessionEvent::ApprovalRequested {
223                    request_id,
224                    tool_call,
225                } => {
226                    let approved = tool_is_preapproved(&tool_call, &approval_policy);
227                    if approved {
228                        info!(
229                            session_id = %session_id,
230                            request_id = %request_id,
231                            tool = %tool_call.name,
232                            "Auto-approving preapproved tool"
233                        );
234                    } else {
235                        warn!(
236                            session_id = %session_id,
237                            request_id = %request_id,
238                            tool = %tool_call.name,
239                            "Auto-denying unapproved tool"
240                        );
241                    }
242
243                    runtime
244                        .submit_tool_approval(session_id, request_id, approved, None)
245                        .await
246                        .map_err(|e| {
247                            Error::InvalidOperation(format!(
248                                "Failed to submit tool approval decision: {e}"
249                            ))
250                        })?;
251                }
252
253                _ => {}
254            }
255        }
256
257        match messages.last() {
258            Some(final_message) => {
259                info!(
260                    session_id = %session_id,
261                    message_count = messages.len(),
262                    "Returning final result"
263                );
264                Ok(RunOnceResult {
265                    final_message: final_message.clone(),
266                    session_id,
267                })
268            }
269            None => Err(Error::InvalidOperation("No message received".to_string())),
270        }
271    }
272}
273
274fn tool_is_preapproved(tool_call: &ToolCall, policy: &ToolApprovalPolicy) -> bool {
275    if policy.preapproved.tools.contains(&tool_call.name) {
276        return true;
277    }
278
279    if tool_call.name == DISPATCH_AGENT_TOOL_NAME {
280        let params = serde_json::from_value::<DispatchAgentParams>(tool_call.parameters.clone());
281        if let Ok(params) = params {
282            return match params.target {
283                DispatchAgentTarget::Resume { .. } => true,
284                DispatchAgentTarget::New { agent, .. } => {
285                    let agent_id = agent
286                        .as_deref()
287                        .filter(|value| !value.trim().is_empty())
288                        .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
289                    policy.is_dispatch_agent_pattern_preapproved(&agent_id)
290                }
291            };
292        }
293    }
294
295    if tool_call.name == BASH_TOOL_NAME {
296        let params = serde_json::from_value::<BashParams>(tool_call.parameters.clone());
297        if let Ok(params) = params {
298            return policy.is_bash_pattern_preapproved(&params.command);
299        }
300    }
301
302    false
303}
304
305impl From<RuntimeError> for Error {
306    fn from(e: RuntimeError) -> Self {
307        match e {
308            RuntimeError::SessionNotFound { session_id } => {
309                Error::InvalidOperation(format!("Session not found: {session_id}"))
310            }
311            RuntimeError::SessionAlreadyExists { session_id } => {
312                Error::InvalidOperation(format!("Session already exists: {session_id}"))
313            }
314            RuntimeError::InvalidInput { message } => Error::InvalidOperation(message),
315            RuntimeError::ChannelClosed => {
316                Error::InvalidOperation("Runtime channel closed".to_string())
317            }
318            RuntimeError::ShuttingDown => {
319                Error::InvalidOperation("Runtime is shutting down".to_string())
320            }
321            RuntimeError::Session(e) => Error::InvalidOperation(format!("Session error: {e}")),
322            RuntimeError::EventStore(e) => {
323                Error::InvalidOperation(format!("Event store error: {e}"))
324            }
325        }
326    }
327}
328
329#[cfg(test)]
330mod tests {
331    use super::*;
332    use crate::api::Client as ApiClient;
333    use crate::api::{ApiError, CompletionResponse, Provider};
334    use crate::app::conversation::{AssistantContent, Message, MessageData};
335    use crate::app::domain::action::ApprovalDecision;
336    use crate::app::domain::runtime::RuntimeService;
337    use crate::app::domain::session::event_store::InMemoryEventStore;
338    use crate::app::validation::ValidatorRegistry;
339    use crate::config::model::builtin;
340    use crate::session::SessionPolicyOverrides;
341    use crate::session::ToolApprovalPolicy;
342    use crate::session::state::{
343        ApprovalRules, SessionToolConfig, UnapprovedBehavior, WorkspaceConfig,
344    };
345    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
346    use crate::tools::{BackendRegistry, ToolExecutor};
347    use dotenvy::dotenv;
348    use serde_json::json;
349    use std::collections::{HashMap, HashSet};
350    use std::sync::Arc;
351    use std::sync::Mutex as StdMutex;
352    use steer_tools::ToolCall;
353    use steer_tools::tools::BASH_TOOL_NAME;
354    use tokio_util::sync::CancellationToken;
355
    /// Stub provider: its first `complete` call returns `tool_call`, every later call
    /// returns `final_text`.
    #[derive(Clone)]
    struct ToolCallThenTextProvider {
        /// Tool call returned by the first completion.
        tool_call: ToolCall,
        /// Text returned by every completion after the first.
        final_text: String,
        /// Number of completions served so far; clones share it through the `Arc`.
        call_count: Arc<StdMutex<usize>>,
    }
362
363    impl ToolCallThenTextProvider {
364        fn new(tool_call: ToolCall, final_text: impl Into<String>) -> Self {
365            Self {
366                tool_call,
367                final_text: final_text.into(),
368                call_count: Arc::new(StdMutex::new(0)),
369            }
370        }
371    }
372
373    #[async_trait::async_trait]
374    impl Provider for ToolCallThenTextProvider {
375        fn name(&self) -> &'static str {
376            "stub-tool-call"
377        }
378
379        async fn complete(
380            &self,
381            _model_id: &crate::config::model::ModelId,
382            _messages: Vec<Message>,
383            _system: Option<crate::app::SystemContext>,
384            _tools: Option<Vec<steer_tools::ToolSchema>>,
385            _call_options: Option<crate::config::model::ModelParameters>,
386            _token: CancellationToken,
387        ) -> std::result::Result<CompletionResponse, ApiError> {
388            let mut count = self
389                .call_count
390                .lock()
391                .expect("tool call counter lock poisoned");
392            let response = if *count == 0 {
393                CompletionResponse {
394                    content: vec![AssistantContent::ToolCall {
395                        tool_call: self.tool_call.clone(),
396                        thought_signature: None,
397                    }],
398                }
399            } else {
400                CompletionResponse {
401                    content: vec![AssistantContent::Text {
402                        text: self.final_text.clone(),
403                    }],
404                }
405            };
406            *count += 1;
407            Ok(response)
408        }
409    }
410
411    async fn create_test_runtime() -> RuntimeService {
412        let event_store = Arc::new(InMemoryEventStore::new());
413        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
414        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
415        let api_client = Arc::new(ApiClient::new_with_deps(
416            crate::test_utils::test_llm_config_provider().unwrap(),
417            provider_registry,
418            model_registry,
419        ));
420
421        let tool_executor = Arc::new(ToolExecutor::with_components(
422            Arc::new(BackendRegistry::new()),
423            Arc::new(ValidatorRegistry::new()),
424        ));
425
426        RuntimeService::spawn(event_store, api_client, tool_executor)
427    }
428
429    fn create_test_session_config() -> SessionConfig {
430        SessionConfig {
431            default_model: builtin::claude_sonnet_4_5(),
432            workspace: WorkspaceConfig::default(),
433            workspace_ref: None,
434            workspace_id: None,
435            repo_ref: None,
436            parent_session_id: None,
437            workspace_name: None,
438            tool_config: SessionToolConfig::default(),
439            system_prompt: None,
440            primary_agent_id: None,
441            policy_overrides: SessionPolicyOverrides::empty(),
442            metadata: std::collections::HashMap::new(),
443        }
444    }
445
446    fn create_test_tool_approval_policy() -> ToolApprovalPolicy {
447        let tool_names = READ_ONLY_TOOL_NAMES
448            .iter()
449            .map(|name| (*name).to_string())
450            .collect();
451        ToolApprovalPolicy {
452            default_behavior: UnapprovedBehavior::Prompt,
453            preapproved: ApprovalRules {
454                tools: tool_names,
455                per_tool: std::collections::HashMap::new(),
456            },
457        }
458    }
459
460    #[test]
461    fn tool_is_preapproved_allows_whitelisted_tool() {
462        let policy = create_test_tool_approval_policy();
463        let tool_call = ToolCall {
464            id: "tc_read".to_string(),
465            name: READ_ONLY_TOOL_NAMES[0].to_string(),
466            parameters: json!({}),
467        };
468
469        assert!(tool_is_preapproved(&tool_call, &policy));
470    }
471
472    #[test]
473    fn tool_is_preapproved_allows_bash_pattern() {
474        use crate::session::state::{ApprovalRules, ToolRule, UnapprovedBehavior};
475
476        let mut per_tool = HashMap::new();
477        per_tool.insert(
478            "bash".to_string(),
479            ToolRule::Bash {
480                patterns: vec!["echo *".to_string()],
481            },
482        );
483
484        let policy = ToolApprovalPolicy {
485            default_behavior: UnapprovedBehavior::Prompt,
486            preapproved: ApprovalRules {
487                tools: HashSet::new(),
488                per_tool,
489            },
490        };
491
492        let tool_call = ToolCall {
493            id: "tc_bash".to_string(),
494            name: BASH_TOOL_NAME.to_string(),
495            parameters: json!({ "command": "echo hello" }),
496        };
497
498        assert!(tool_is_preapproved(&tool_call, &policy));
499    }
500
501    #[test]
502    fn tool_is_preapproved_allows_dispatch_agent_pattern() {
503        use crate::session::state::{ApprovalRules, ToolRule, UnapprovedBehavior};
504
505        let mut per_tool = HashMap::new();
506        per_tool.insert(
507            "dispatch_agent".to_string(),
508            ToolRule::DispatchAgent {
509                agent_patterns: vec!["explore".to_string()],
510            },
511        );
512
513        let policy = ToolApprovalPolicy {
514            default_behavior: UnapprovedBehavior::Prompt,
515            preapproved: ApprovalRules {
516                tools: HashSet::new(),
517                per_tool,
518            },
519        };
520
521        let tool_call = ToolCall {
522            id: "tc_dispatch".to_string(),
523            name: DISPATCH_AGENT_TOOL_NAME.to_string(),
524            parameters: json!({
525                "prompt": "find files",
526                "target": {
527                    "session": "new",
528                    "workspace": {
529                        "location": "current"
530                    },
531                    "agent": "explore"
532                }
533            }),
534        };
535
536        assert!(tool_is_preapproved(&tool_call, &policy));
537    }
538
539    #[test]
540    fn tool_is_preapproved_denies_unlisted_tool() {
541        let policy = create_test_tool_approval_policy();
542        let tool_call = ToolCall {
543            id: "tc_other".to_string(),
544            name: "bash".to_string(),
545            parameters: json!({ "command": "rm -rf /" }),
546        };
547
548        assert!(!tool_is_preapproved(&tool_call, &policy));
549    }
550
551    #[tokio::test]
552    async fn run_new_session_denies_unapproved_tool_requests() {
553        let event_store = Arc::new(InMemoryEventStore::new());
554        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
555        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
556        let api_client = Arc::new(ApiClient::new_with_deps(
557            crate::test_utils::test_llm_config_provider().unwrap(),
558            provider_registry,
559            model_registry.clone(),
560        ));
561
562        let tool_call = ToolCall {
563            id: "tc_1".to_string(),
564            name: "bash".to_string(),
565            parameters: json!({ "command": "echo denied" }),
566        };
567        api_client.insert_test_provider(
568            builtin::claude_sonnet_4_5().provider.clone(),
569            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
570        );
571
572        let tool_executor = Arc::new(ToolExecutor::with_components(
573            Arc::new(BackendRegistry::new()),
574            Arc::new(ValidatorRegistry::new()),
575        ));
576        let runtime = RuntimeService::spawn(event_store, api_client, tool_executor);
577
578        let mut config = create_test_session_config();
579        config.tool_config.approval_policy = ToolApprovalPolicy {
580            default_behavior: UnapprovedBehavior::Prompt,
581            preapproved: ApprovalRules {
582                tools: HashSet::new(),
583                per_tool: HashMap::new(),
584            },
585        };
586
587        let model = builtin::claude_sonnet_4_5();
588        let result = OneShotRunner::run_new_session(
589            &runtime.handle,
590            config,
591            "Trigger tool call".to_string(),
592            model,
593        )
594        .await
595        .expect("run_new_session should complete");
596
597        let events = runtime
598            .handle
599            .load_events_after(result.session_id, 0)
600            .await
601            .expect("load events");
602
603        let mut saw_request = false;
604        let mut saw_decision = false;
605        let mut saw_denied = false;
606
607        for (_, event) in events {
608            match event {
609                SessionEvent::ApprovalRequested { .. } => saw_request = true,
610                SessionEvent::ApprovalDecided { decision, .. } => {
611                    saw_decision = true;
612                    if decision == ApprovalDecision::Denied {
613                        saw_denied = true;
614                    }
615                }
616                _ => {}
617            }
618        }
619
620        assert!(saw_request, "expected ApprovalRequested event");
621        assert!(saw_decision, "expected ApprovalDecided event");
622        assert!(saw_denied, "expected denied decision");
623
624        runtime.shutdown().await;
625    }
626
627    #[tokio::test]
628    #[ignore = "Requires API keys and network access"]
629    async fn test_run_new_session_basic() {
630        dotenv().ok();
631        let runtime = create_test_runtime().await;
632
633        let mut config = create_test_session_config();
634        config.tool_config = SessionToolConfig::read_only();
635        config.tool_config.approval_policy = create_test_tool_approval_policy();
636        config
637            .metadata
638            .insert("mode".to_string(), "headless".to_string());
639
640        let model = builtin::claude_sonnet_4_5();
641        let result = OneShotRunner::run_new_session(
642            &runtime.handle,
643            config,
644            "What is 2 + 2?".to_string(),
645            model,
646        )
647        .await;
648
649        let result = tokio::time::timeout(std::time::Duration::from_secs(30), async { result })
650            .await
651            .expect("Timed out waiting for response")
652            .expect("run_new_session failed");
653
654        assert!(!result.final_message.id().is_empty());
655        println!("New session run succeeded: {:?}", result.final_message);
656
657        let content = match &result.final_message.data {
658            MessageData::Assistant { content, .. } => content,
659            _ => panic!("expected assistant message, got {:?}", result.final_message),
660        };
661        let text_content = content.iter().find_map(|c| match c {
662            AssistantContent::Text { text } => Some(text),
663            _ => None,
664        });
665        let content = text_content.expect("No text content found in assistant message");
666        assert!(!content.is_empty(), "Response should not be empty");
667        assert!(
668            content.contains('4'),
669            "Expected response to contain '4', got: {content}"
670        );
671
672        runtime.shutdown().await;
673    }
674
675    #[tokio::test]
676    async fn test_session_creation() {
677        let runtime = create_test_runtime().await;
678
679        let mut config = create_test_session_config();
680        config.tool_config.approval_policy = create_test_tool_approval_policy();
681        config
682            .metadata
683            .insert("test".to_string(), "value".to_string());
684
685        let session_id = runtime.handle.create_session(config).await.unwrap();
686
687        assert!(runtime.handle.is_session_active(session_id).await.unwrap());
688
689        let state = runtime.handle.get_session_state(session_id).await.unwrap();
690        assert_eq!(
691            state.session_config.as_ref().unwrap().metadata.get("test"),
692            Some(&"value".to_string())
693        );
694
695        runtime.shutdown().await;
696    }
697
698    #[tokio::test]
699    async fn test_run_in_session_nonexistent_session() {
700        let runtime = create_test_runtime().await;
701
702        let fake_session_id = SessionId::new();
703        let model = builtin::claude_sonnet_4_5();
704        let result = OneShotRunner::run_in_session(
705            &runtime.handle,
706            fake_session_id,
707            "Test message".to_string(),
708            model,
709        )
710        .await;
711
712        assert!(result.is_err());
713        let err = result.err().unwrap().to_string();
714        assert!(
715            err.contains("not found") || err.contains("Session"),
716            "Expected session not found error, got: {err}"
717        );
718
719        runtime.shutdown().await;
720    }
721
722    #[tokio::test]
723    #[ignore = "Requires API keys and network access"]
724    async fn test_run_in_session_with_real_api() {
725        dotenv().ok();
726        let runtime = create_test_runtime().await;
727
728        let mut config = create_test_session_config();
729        config.tool_config = SessionToolConfig::read_only();
730        config.tool_config.approval_policy = create_test_tool_approval_policy();
731        config
732            .metadata
733            .insert("test".to_string(), "api_test".to_string());
734
735        let session_id = runtime.handle.create_session(config).await.unwrap();
736        let model = builtin::claude_sonnet_4_5();
737
738        let result = OneShotRunner::run_in_session(
739            &runtime.handle,
740            session_id,
741            "What is the capital of France?".to_string(),
742            model,
743        )
744        .await;
745
746        match result {
747            Ok(run_result) => {
748                println!("Session run succeeded: {:?}", run_result.final_message);
749
750                let content = match &run_result.final_message.data {
751                    MessageData::Assistant { content, .. } => content.clone(),
752                    _ => panic!(
753                        "expected assistant message, got {:?}",
754                        run_result.final_message
755                    ),
756                };
757                let text_content = content.iter().find_map(|c| match c {
758                    AssistantContent::Text { text } => Some(text),
759                    _ => None,
760                });
761                let content = text_content.expect("expected text response in assistant message");
762                assert!(!content.is_empty(), "Response should not be empty");
763                assert!(
764                    content.to_lowercase().contains("paris"),
765                    "Expected response to contain 'Paris', got: {content}"
766                );
767            }
768            Err(e) => {
769                println!("Session run failed (expected if no API key): {e}");
770                assert!(
771                    e.to_string().contains("API key")
772                        || e.to_string().contains("authentication")
773                        || e.to_string().contains("timed out"),
774                    "Unexpected error: {e}"
775                );
776            }
777        }
778
779        runtime.shutdown().await;
780    }
781
782    #[tokio::test]
783    #[ignore = "Requires API keys and network access"]
784    async fn test_run_in_session_preserves_context() {
785        dotenv().ok();
786        let runtime = create_test_runtime().await;
787
788        let mut config = create_test_session_config();
789        config.tool_config = SessionToolConfig::read_only();
790        config.tool_config.approval_policy = create_test_tool_approval_policy();
791        config
792            .metadata
793            .insert("test".to_string(), "context_test".to_string());
794
795        let session_id = runtime.handle.create_session(config).await.unwrap();
796        let model = builtin::claude_sonnet_4_5();
797
798        let result1 = OneShotRunner::run_in_session(
799            &runtime.handle,
800            session_id,
801            "My name is Alice and I like pizza.".to_string(),
802            model.clone(),
803        )
804        .await
805        .expect("First session run should succeed");
806
807        println!("First interaction: {:?}", result1.final_message);
808
809        runtime.handle.resume_session(session_id).await.unwrap();
810
811        let result2 = OneShotRunner::run_in_session(
812            &runtime.handle,
813            session_id,
814            "What is my name and what do I like?".to_string(),
815            model,
816        )
817        .await
818        .expect("Second session run should succeed");
819
820        println!("Second interaction: {:?}", result2.final_message);
821
822        match &result2.final_message.data {
823            MessageData::Assistant { content, .. } => {
824                let text_content = content.iter().find_map(|c| match c {
825                    AssistantContent::Text { text } => Some(text),
826                    _ => None,
827                });
828
829                match text_content {
830                    Some(content) => {
831                        assert!(!content.is_empty(), "Response should not be empty");
832                        let content_lower = content.to_lowercase();
833
834                        assert!(
835                            content_lower.contains("alice") || content_lower.contains("name"),
836                            "Expected response to reference the name or context, got: {content}"
837                        );
838                    }
839                    None => {
840                        panic!("expected text response in assistant message");
841                    }
842                }
843            }
844            _ => {
845                panic!(
846                    "expected assistant message, got {:?}",
847                    result2.final_message
848                );
849            }
850        }
851
852        runtime.shutdown().await;
853    }
854
855    #[tokio::test]
856    #[ignore = "Requires API keys and network access"]
857    async fn test_run_new_session_with_tool_usage() {
858        dotenv().ok();
859        let runtime = create_test_runtime().await;
860
861        let mut config = create_test_session_config();
862        config.tool_config = SessionToolConfig::read_only();
863        config.tool_config.approval_policy = create_test_tool_approval_policy();
864        let model = builtin::claude_sonnet_4_5();
865
866        let result = OneShotRunner::run_new_session(
867            &runtime.handle,
868            config,
869            "List the files in the current directory".to_string(),
870            model,
871        )
872        .await
873        .expect("New session run with tools should succeed with valid API key");
874
875        assert!(!result.final_message.id().is_empty());
876        println!(
877            "New session run with tools succeeded: {:?}",
878            result.final_message
879        );
880
881        let has_content = match &result.final_message.data {
882            MessageData::Assistant { content, .. } => content.iter().any(|c| match c {
883                AssistantContent::Text { text } => !text.is_empty(),
884                _ => true,
885            }),
886            _ => false,
887        };
888        assert!(has_content, "Response should have some content");
889
890        runtime.shutdown().await;
891    }
892}