1use serde::{Deserialize, Serialize};
2use tokio_util::sync::CancellationToken;
3use tracing::{error, info, warn};
4
5use crate::agents::default_agent_spec_id;
6use crate::app::conversation::{Message, UserContent};
7use crate::app::domain::event::SessionEvent;
8use crate::app::domain::runtime::{RuntimeError, RuntimeHandle};
9use crate::app::domain::types::SessionId;
10use crate::config::model::ModelId;
11use crate::error::{Error, Result};
12use crate::session::ToolApprovalPolicy;
13use crate::session::state::SessionConfig;
14use crate::tools::{DISPATCH_AGENT_TOOL_NAME, DispatchAgentParams, DispatchAgentTarget};
15use steer_tools::ToolCall;
16use steer_tools::tools::BASH_TOOL_NAME;
17use steer_tools::tools::bash::BashParams;
18
/// Outcome of a one-shot run.
///
/// Holds the last assistant message observed while the submitted operation ran, together with
/// the session it ran in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOnceResult {
    /// The last `AssistantMessageAdded` message seen before the run finished.
    pub final_message: Message,
    /// The session the user message was sent to.
    pub session_id: SessionId,
}
24
/// Stateless driver that sends one message to a session and collects the assistant's reply.
///
/// The type holds no data; `Default` yields the same value as `OneShotRunner::new()`.
#[derive(Default)]
pub struct OneShotRunner;
32
33impl OneShotRunner {
34 pub fn new() -> Self {
35 Self
36 }
37
38 pub async fn run_in_session(
39 runtime: &RuntimeHandle,
40 session_id: SessionId,
41 message: String,
42 model: ModelId,
43 ) -> Result<RunOnceResult> {
44 Self::run_in_session_with_cancel(
45 runtime,
46 session_id,
47 message,
48 model,
49 CancellationToken::new(),
50 )
51 .await
52 }
53
54 pub async fn run_in_session_with_cancel(
55 runtime: &RuntimeHandle,
56 session_id: SessionId,
57 message: String,
58 model: ModelId,
59 cancel_token: CancellationToken,
60 ) -> Result<RunOnceResult> {
61 runtime.resume_session(session_id).await.map_err(|e| {
62 Error::InvalidOperation(format!("Failed to resume session {session_id}: {e}"))
63 })?;
64
65 let subscription = runtime.subscribe_events(session_id).await.map_err(|e| {
66 Error::InvalidOperation(format!(
67 "Failed to subscribe to session {session_id} events: {e}"
68 ))
69 })?;
70
71 let approval_policy = match runtime.get_session_state(session_id).await {
72 Ok(state) => state
73 .session_config
74 .map(|config| config.tool_config.approval_policy)
75 .unwrap_or_default(),
76 Err(err) => {
77 warn!(
78 session_id = %session_id,
79 error = %err,
80 "Failed to load session approval policy; defaulting to deny"
81 );
82 ToolApprovalPolicy::default()
83 }
84 };
85
86 info!(session_id = %session_id, message = %message, "Sending message to session");
87
88 let (op_id, _message_id) = runtime
89 .submit_user_input(
90 session_id,
91 vec![UserContent::Text {
92 text: message.clone(),
93 }],
94 model,
95 )
96 .await
97 .map_err(|e| {
98 Error::InvalidOperation(format!(
99 "Failed to send message to session {session_id}: {e}"
100 ))
101 })?;
102
103 let cancel_task = {
104 let runtime = runtime.clone();
105 let cancel_token = cancel_token.clone();
106 tokio::spawn(async move {
107 cancel_token.cancelled().await;
108 if let Err(err) = runtime.cancel_operation(session_id, Some(op_id)).await {
109 warn!(
110 session_id = %session_id,
111 error = %err,
112 "Failed to cancel one-shot operation"
113 );
114 }
115 })
116 };
117
118 let result =
119 Self::process_events(runtime, subscription, session_id, op_id, approval_policy).await;
120
121 cancel_task.abort();
122
123 if let Err(e) = runtime.suspend_session(session_id).await {
124 error!(session_id = %session_id, error = %e, "Failed to suspend session");
125 } else {
126 info!(session_id = %session_id, "Session suspended successfully");
127 }
128
129 result
130 }
131
132 pub async fn run_new_session(
133 runtime: &RuntimeHandle,
134 config: SessionConfig,
135 message: String,
136 model: ModelId,
137 ) -> Result<RunOnceResult> {
138 Self::run_new_session_with_cancel(runtime, config, message, model, CancellationToken::new())
139 .await
140 }
141
142 pub async fn run_new_session_with_cancel(
143 runtime: &RuntimeHandle,
144 config: SessionConfig,
145 message: String,
146 model: ModelId,
147 cancel_token: CancellationToken,
148 ) -> Result<RunOnceResult> {
149 let session_id = runtime
150 .create_session(config)
151 .await
152 .map_err(|e| Error::InvalidOperation(format!("Failed to create session: {e}")))?;
153
154 info!(session_id = %session_id, "Created new session for one-shot run");
155
156 Self::run_in_session_with_cancel(runtime, session_id, message, model, cancel_token).await
157 }
158
159 async fn process_events(
160 runtime: &RuntimeHandle,
161 mut subscription: crate::app::domain::runtime::SessionEventSubscription,
162 session_id: SessionId,
163 op_id: crate::app::domain::types::OpId,
164 mut approval_policy: ToolApprovalPolicy,
165 ) -> Result<RunOnceResult> {
166 let mut messages = Vec::new();
167 info!(session_id = %session_id, "Starting event processing loop");
168
169 while let Some(envelope) = subscription.recv().await {
170 match envelope.event {
171 SessionEvent::AssistantMessageAdded { message, model: _ } => {
172 info!(
173 session_id = %session_id,
174 role = ?message.role(),
175 id = %message.id(),
176 "AssistantMessageAdded event"
177 );
178 messages.push(message);
179 }
180
181 SessionEvent::MessageUpdated { message } => {
182 info!(
183 session_id = %session_id,
184 id = %message.id(),
185 "MessageUpdated event"
186 );
187 }
188
189 SessionEvent::OperationCompleted {
190 op_id: completed_op,
191 } => {
192 if completed_op != op_id {
193 continue;
194 }
195 info!(
196 session_id = %session_id,
197 op_id = %completed_op,
198 "OperationCompleted event received"
199 );
200 if !messages.is_empty() {
201 info!(session_id = %session_id, "Final message received, exiting event loop");
202 break;
203 }
204 }
205
206 SessionEvent::OperationCancelled {
207 op_id: cancelled_op,
208 ..
209 } => {
210 if cancelled_op != op_id {
211 continue;
212 }
213 warn!(
214 session_id = %session_id,
215 op_id = %cancelled_op,
216 "OperationCancelled event received"
217 );
218 return Err(Error::Cancelled);
219 }
220
221 SessionEvent::Error { message } => {
222 error!(session_id = %session_id, error = %message, "Error event");
223 return Err(Error::InvalidOperation(format!(
224 "Error during processing: {message}"
225 )));
226 }
227
228 SessionEvent::ApprovalRequested {
229 request_id,
230 tool_call,
231 } => {
232 let approved = tool_is_preapproved(&tool_call, &approval_policy);
233 if approved {
234 info!(
235 session_id = %session_id,
236 request_id = %request_id,
237 tool = %tool_call.name,
238 "Auto-approving preapproved tool"
239 );
240 } else {
241 warn!(
242 session_id = %session_id,
243 request_id = %request_id,
244 tool = %tool_call.name,
245 "Auto-denying unapproved tool"
246 );
247 }
248
249 runtime
250 .submit_tool_approval(session_id, request_id, approved, None)
251 .await
252 .map_err(|e| {
253 Error::InvalidOperation(format!(
254 "Failed to submit tool approval decision: {e}"
255 ))
256 })?;
257 }
258
259 SessionEvent::SessionConfigUpdated { config, .. } => {
260 approval_policy = config.tool_config.approval_policy.clone();
261 }
262
263 _ => {}
264 }
265 }
266
267 match messages.last() {
268 Some(final_message) => {
269 info!(
270 session_id = %session_id,
271 message_count = messages.len(),
272 "Returning final result"
273 );
274 Ok(RunOnceResult {
275 final_message: final_message.clone(),
276 session_id,
277 })
278 }
279 None => Err(Error::InvalidOperation("No message received".to_string())),
280 }
281 }
282}
283
284fn tool_is_preapproved(tool_call: &ToolCall, policy: &ToolApprovalPolicy) -> bool {
285 if policy.preapproved.tools.contains(&tool_call.name) {
286 return true;
287 }
288
289 if tool_call.name == DISPATCH_AGENT_TOOL_NAME {
290 let params = serde_json::from_value::<DispatchAgentParams>(tool_call.parameters.clone());
291 if let Ok(params) = params {
292 return match params.target {
293 DispatchAgentTarget::Resume { .. } => true,
294 DispatchAgentTarget::New { agent, .. } => {
295 let agent_id = agent
296 .as_deref()
297 .filter(|value| !value.trim().is_empty())
298 .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
299 policy.is_dispatch_agent_pattern_preapproved(&agent_id)
300 }
301 };
302 }
303 }
304
305 if tool_call.name == BASH_TOOL_NAME {
306 let params = serde_json::from_value::<BashParams>(tool_call.parameters.clone());
307 if let Ok(params) = params {
308 return policy.is_bash_pattern_preapproved(¶ms.command);
309 }
310 }
311
312 false
313}
314
315impl From<RuntimeError> for Error {
316 fn from(e: RuntimeError) -> Self {
317 match e {
318 RuntimeError::SessionNotFound { session_id } => {
319 Error::InvalidOperation(format!("Session not found: {session_id}"))
320 }
321 RuntimeError::SessionAlreadyExists { session_id } => {
322 Error::InvalidOperation(format!("Session already exists: {session_id}"))
323 }
324 RuntimeError::InvalidInput { message } => Error::InvalidOperation(message),
325 RuntimeError::ChannelClosed => {
326 Error::InvalidOperation("Runtime channel closed".to_string())
327 }
328 RuntimeError::ShuttingDown => {
329 Error::InvalidOperation("Runtime is shutting down".to_string())
330 }
331 RuntimeError::Session(e) => Error::InvalidOperation(format!("Session error: {e}")),
332 RuntimeError::EventStore(e) => {
333 Error::InvalidOperation(format!("Event store error: {e}"))
334 }
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::{AssistantContent, Message, MessageData};
    use crate::app::domain::action::ApprovalDecision;
    use crate::app::domain::runtime::RuntimeService;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::config::model::builtin;
    use crate::session::SessionPolicyOverrides;
    use crate::session::ToolApprovalPolicy;
    use crate::session::state::{
        ApprovalRules, SessionToolConfig, ToolRule, UnapprovedBehavior, WorkspaceConfig,
    };
    use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
    use crate::tools::{BackendRegistry, ToolExecutor};
    use dotenvy::dotenv;
    use serde_json::json;
    use std::collections::{HashMap, HashSet};
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use std::time::Duration;
    use steer_tools::ToolCall;
    use steer_tools::tools::BASH_TOOL_NAME;
    use tokio_util::sync::CancellationToken;

    /// Stub provider: the first completion returns `tool_call`, every later one returns
    /// `final_text`. The first call can optionally be delayed.
    #[derive(Clone)]
    struct ToolCallThenTextProvider {
        tool_call: ToolCall,
        final_text: String,
        call_count: Arc<StdMutex<usize>>,
        first_call_delay: Duration,
    }

    impl ToolCallThenTextProvider {
        fn new(tool_call: ToolCall, final_text: impl Into<String>) -> Self {
            Self::new_with_delay(tool_call, final_text, Duration::ZERO)
        }

        fn new_with_delay(
            tool_call: ToolCall,
            final_text: impl Into<String>,
            first_call_delay: Duration,
        ) -> Self {
            Self {
                tool_call,
                final_text: final_text.into(),
                call_count: Arc::new(StdMutex::new(0)),
                first_call_delay,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for ToolCallThenTextProvider {
        fn name(&self) -> &'static str {
            "stub-tool-call"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<Message>,
            _system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            // Take and bump the counter in a short scope so the std mutex guard is released
            // before any `.await` below.
            let call_index = {
                let mut count = self
                    .call_count
                    .lock()
                    .expect("tool call counter lock poisoned");
                let idx = *count;
                *count += 1;
                idx
            };

            if call_index == 0 && !self.first_call_delay.is_zero() {
                tokio::time::sleep(self.first_call_delay).await;
            }

            if call_index == 0 {
                Ok(CompletionResponse {
                    content: vec![AssistantContent::ToolCall {
                        tool_call: self.tool_call.clone(),
                        thought_signature: None,
                    }],
                    usage: None,
                })
            } else {
                Ok(CompletionResponse {
                    content: vec![AssistantContent::Text {
                        text: self.final_text.clone(),
                    }],
                    usage: None,
                })
            }
        }
    }

    /// Spawns a runtime backed by an in-memory event store and empty tool/validator registries.
    async fn create_test_runtime() -> RuntimeService {
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
        let api_client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry,
        ));

        let tool_executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));

        RuntimeService::spawn(event_store, api_client, tool_executor)
    }

    /// Minimal session config using the builtin Claude Sonnet 4.5 model and default tool config.
    fn create_test_session_config() -> SessionConfig {
        SessionConfig {
            default_model: builtin::claude_sonnet_4_5(),
            workspace: WorkspaceConfig::default(),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            title: None,
            metadata: HashMap::new(),
            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
        }
    }

    /// Policy that pre-approves exactly the read-only builtin tools and prompts for the rest.
    fn create_test_tool_approval_policy() -> ToolApprovalPolicy {
        let tool_names = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: tool_names,
                per_tool: HashMap::new(),
            },
        }
    }

    #[test]
    fn tool_is_preapproved_allows_whitelisted_tool() {
        let policy = create_test_tool_approval_policy();
        let tool_call = ToolCall {
            id: "tc_read".to_string(),
            name: READ_ONLY_TOOL_NAMES[0].to_string(),
            parameters: json!({}),
        };

        assert!(tool_is_preapproved(&tool_call, &policy));
    }

    #[test]
    fn tool_is_preapproved_allows_bash_pattern() {
        let mut per_tool = HashMap::new();
        per_tool.insert(
            "bash".to_string(),
            ToolRule::Bash {
                patterns: vec!["echo *".to_string()],
            },
        );

        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool,
            },
        };

        let tool_call = ToolCall {
            id: "tc_bash".to_string(),
            name: BASH_TOOL_NAME.to_string(),
            parameters: json!({ "command": "echo hello" }),
        };

        assert!(tool_is_preapproved(&tool_call, &policy));
    }

    #[test]
    fn tool_is_preapproved_allows_dispatch_agent_pattern() {
        let mut per_tool = HashMap::new();
        per_tool.insert(
            "dispatch_agent".to_string(),
            ToolRule::DispatchAgent {
                agent_patterns: vec!["explore".to_string()],
            },
        );

        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool,
            },
        };

        let tool_call = ToolCall {
            id: "tc_dispatch".to_string(),
            name: DISPATCH_AGENT_TOOL_NAME.to_string(),
            parameters: json!({
                "prompt": "find files",
                "target": {
                    "session": "new",
                    "workspace": {
                        "location": "current"
                    },
                    "agent": "explore"
                }
            }),
        };

        assert!(tool_is_preapproved(&tool_call, &policy));
    }

    #[test]
    fn tool_is_preapproved_denies_unlisted_tool() {
        let policy = create_test_tool_approval_policy();
        let tool_call = ToolCall {
            id: "tc_other".to_string(),
            name: "bash".to_string(),
            parameters: json!({ "command": "rm -rf /" }),
        };

        assert!(!tool_is_preapproved(&tool_call, &policy));
    }

    #[tokio::test]
    async fn run_new_session_denies_unapproved_tool_requests() {
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
        let api_client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry.clone(),
        ));

        let tool_call = ToolCall {
            id: "tc_1".to_string(),
            name: "bash".to_string(),
            parameters: json!({ "command": "echo denied" }),
        };
        api_client.insert_test_provider(
            builtin::claude_sonnet_4_5().provider.clone(),
            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
        );

        let tool_executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));
        let runtime = RuntimeService::spawn(event_store, api_client, tool_executor);

        let mut config = create_test_session_config();
        config.tool_config.approval_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: HashMap::new(),
            },
        };

        let model = builtin::claude_sonnet_4_5();
        let result = OneShotRunner::run_new_session(
            &runtime.handle,
            config,
            "Trigger tool call".to_string(),
            model,
        )
        .await
        .expect("run_new_session should complete");

        let events = runtime
            .handle
            .load_events_after(result.session_id, 0)
            .await
            .expect("load events");

        let mut saw_request = false;
        let mut saw_decision = false;
        let mut saw_denied = false;

        for (_, event) in events {
            match event {
                SessionEvent::ApprovalRequested { .. } => saw_request = true,
                SessionEvent::ApprovalDecided { decision, .. } => {
                    saw_decision = true;
                    if decision == ApprovalDecision::Denied {
                        saw_denied = true;
                    }
                }
                _ => {}
            }
        }

        assert!(saw_request, "expected ApprovalRequested event");
        assert!(saw_decision, "expected ApprovalDecided event");
        assert!(saw_denied, "expected denied decision");

        runtime.shutdown().await;
    }

    #[tokio::test]
    async fn run_new_session_updates_approval_policy_after_agent_switch_event() {
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[]).unwrap());
        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
        let api_client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry,
        ));

        let tool_call = ToolCall {
            id: "tc_switch".to_string(),
            name: "bash".to_string(),
            parameters: json!({ "command": "echo switched" }),
        };
        // Delay the first completion so the agent switch below lands before the tool call.
        api_client.insert_test_provider(
            builtin::claude_sonnet_4_5().provider.clone(),
            Arc::new(ToolCallThenTextProvider::new_with_delay(
                tool_call,
                "done",
                Duration::from_millis(75),
            )),
        );

        let tool_executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));
        let runtime = RuntimeService::spawn(event_store, api_client, tool_executor);

        let mut config = create_test_session_config();
        config.tool_config.approval_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: HashMap::new(),
            },
        };

        let session_id = runtime
            .handle
            .create_session(config)
            .await
            .expect("create session");

        let runtime_handle_for_run = runtime.handle.clone();
        let run_task = tokio::spawn(async move {
            OneShotRunner::run_in_session(
                &runtime_handle_for_run,
                session_id,
                "Trigger tool call".to_string(),
                builtin::claude_sonnet_4_5(),
            )
            .await
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        runtime
            .handle
            .switch_primary_agent(session_id, "yolo".to_string())
            .await
            .expect("switch to yolo");

        let result = run_task
            .await
            .expect("run task join")
            .expect("run_in_session should complete");

        let events = runtime
            .handle
            .load_events_after(result.session_id, 0)
            .await
            .expect("load events");

        let mut saw_approval_requested = false;
        let mut saw_session_config_updated = false;
        let mut saw_tool_started = false;
        let mut saw_approval_decision = false;

        for (_, event) in events {
            match event {
                SessionEvent::ApprovalRequested { .. } => saw_approval_requested = true,
                SessionEvent::SessionConfigUpdated {
                    primary_agent_id, ..
                } if primary_agent_id == "yolo" => saw_session_config_updated = true,
                SessionEvent::ToolCallStarted { id, .. } if id.as_str() == "tc_switch" => {
                    saw_tool_started = true
                }
                SessionEvent::ApprovalDecided { .. } => saw_approval_decision = true,
                _ => {}
            }
        }

        assert!(
            saw_session_config_updated,
            "expected in-flight session config update event"
        );
        assert!(
            saw_tool_started,
            "expected bash tool call to auto-run after policy switched to yolo"
        );
        assert!(
            !saw_approval_requested,
            "expected no approval request after switching to yolo before the first tool call"
        );
        assert!(
            !saw_approval_decision,
            "expected no approval decision when no approval was requested"
        );

        runtime.shutdown().await;
    }

    #[tokio::test]
    #[ignore = "Requires API keys and network access"]
    async fn test_run_new_session_basic() {
        dotenv().ok();
        let runtime = create_test_runtime().await;

        let mut config = create_test_session_config();
        config.tool_config = SessionToolConfig::read_only();
        config.tool_config.approval_policy = create_test_tool_approval_policy();
        config
            .metadata
            .insert("mode".to_string(), "headless".to_string());

        let model = builtin::claude_sonnet_4_5();
        // Bound the run itself, not an already-computed value: the future must be wrapped
        // before it is awaited for the timeout to have any effect.
        let result = tokio::time::timeout(
            Duration::from_secs(30),
            OneShotRunner::run_new_session(
                &runtime.handle,
                config,
                "What is 2 + 2?".to_string(),
                model,
            ),
        )
        .await
        .expect("Timed out waiting for response")
        .expect("run_new_session failed");

        assert!(!result.final_message.id().is_empty());
        println!("New session run succeeded: {:?}", result.final_message);

        let content = match &result.final_message.data {
            MessageData::Assistant { content, .. } => content,
            _ => panic!("expected assistant message, got {:?}", result.final_message),
        };
        let text_content = content.iter().find_map(|c| match c {
            AssistantContent::Text { text } => Some(text),
            _ => None,
        });
        let content = text_content.expect("No text content found in assistant message");
        assert!(!content.is_empty(), "Response should not be empty");
        assert!(
            content.contains('4'),
            "Expected response to contain '4', got: {content}"
        );

        runtime.shutdown().await;
    }

    #[tokio::test]
    async fn test_session_creation() {
        let runtime = create_test_runtime().await;

        let mut config = create_test_session_config();
        config.tool_config.approval_policy = create_test_tool_approval_policy();
        config
            .metadata
            .insert("test".to_string(), "value".to_string());

        let session_id = runtime.handle.create_session(config).await.unwrap();

        assert!(runtime.handle.is_session_active(session_id).await.unwrap());

        let state = runtime.handle.get_session_state(session_id).await.unwrap();
        assert_eq!(
            state.session_config.as_ref().unwrap().metadata.get("test"),
            Some(&"value".to_string())
        );

        runtime.shutdown().await;
    }

    #[tokio::test]
    async fn test_run_in_session_nonexistent_session() {
        let runtime = create_test_runtime().await;

        let fake_session_id = SessionId::new();
        let model = builtin::claude_sonnet_4_5();
        let result = OneShotRunner::run_in_session(
            &runtime.handle,
            fake_session_id,
            "Test message".to_string(),
            model,
        )
        .await;

        assert!(result.is_err());
        let err = result.err().unwrap().to_string();
        assert!(
            err.contains("not found") || err.contains("Session"),
            "Expected session not found error, got: {err}"
        );

        runtime.shutdown().await;
    }

    #[tokio::test]
    #[ignore = "Requires API keys and network access"]
    async fn test_run_in_session_with_real_api() {
        dotenv().ok();
        let runtime = create_test_runtime().await;

        let mut config = create_test_session_config();
        config.tool_config = SessionToolConfig::read_only();
        config.tool_config.approval_policy = create_test_tool_approval_policy();
        config
            .metadata
            .insert("test".to_string(), "api_test".to_string());

        let session_id = runtime.handle.create_session(config).await.unwrap();
        let model = builtin::claude_sonnet_4_5();

        let result = OneShotRunner::run_in_session(
            &runtime.handle,
            session_id,
            "What is the capital of France?".to_string(),
            model,
        )
        .await;

        match result {
            Ok(run_result) => {
                println!("Session run succeeded: {:?}", run_result.final_message);

                let content = match &run_result.final_message.data {
                    MessageData::Assistant { content, .. } => content.clone(),
                    _ => panic!(
                        "expected assistant message, got {:?}",
                        run_result.final_message
                    ),
                };
                let text_content = content.iter().find_map(|c| match c {
                    AssistantContent::Text { text } => Some(text),
                    _ => None,
                });
                let content = text_content.expect("expected text response in assistant message");
                assert!(!content.is_empty(), "Response should not be empty");
                assert!(
                    content.to_lowercase().contains("paris"),
                    "Expected response to contain 'Paris', got: {content}"
                );
            }
            Err(e) => {
                println!("Session run failed (expected if no API key): {e}");
                assert!(
                    e.to_string().contains("API key")
                        || e.to_string().contains("authentication")
                        || e.to_string().contains("timed out"),
                    "Unexpected error: {e}"
                );
            }
        }

        runtime.shutdown().await;
    }

    #[tokio::test]
    #[ignore = "Requires API keys and network access"]
    async fn test_run_in_session_preserves_context() {
        dotenv().ok();
        let runtime = create_test_runtime().await;

        let mut config = create_test_session_config();
        config.tool_config = SessionToolConfig::read_only();
        config.tool_config.approval_policy = create_test_tool_approval_policy();
        config
            .metadata
            .insert("test".to_string(), "context_test".to_string());

        let session_id = runtime.handle.create_session(config).await.unwrap();
        let model = builtin::claude_sonnet_4_5();

        let result1 = OneShotRunner::run_in_session(
            &runtime.handle,
            session_id,
            "My name is Alice and I like pizza.".to_string(),
            model.clone(),
        )
        .await
        .expect("First session run should succeed");

        println!("First interaction: {:?}", result1.final_message);

        runtime.handle.resume_session(session_id).await.unwrap();

        let result2 = OneShotRunner::run_in_session(
            &runtime.handle,
            session_id,
            "What is my name and what do I like?".to_string(),
            model,
        )
        .await
        .expect("Second session run should succeed");

        println!("Second interaction: {:?}", result2.final_message);

        match &result2.final_message.data {
            MessageData::Assistant { content, .. } => {
                let text_content = content.iter().find_map(|c| match c {
                    AssistantContent::Text { text } => Some(text),
                    _ => None,
                });

                match text_content {
                    Some(content) => {
                        assert!(!content.is_empty(), "Response should not be empty");
                        let content_lower = content.to_lowercase();

                        assert!(
                            content_lower.contains("alice") || content_lower.contains("name"),
                            "Expected response to reference the name or context, got: {content}"
                        );
                    }
                    None => {
                        panic!("expected text response in assistant message");
                    }
                }
            }
            _ => {
                panic!(
                    "expected assistant message, got {:?}",
                    result2.final_message
                );
            }
        }

        runtime.shutdown().await;
    }

    #[tokio::test]
    #[ignore = "Requires API keys and network access"]
    async fn test_run_new_session_with_tool_usage() {
        dotenv().ok();
        let runtime = create_test_runtime().await;

        let mut config = create_test_session_config();
        config.tool_config = SessionToolConfig::read_only();
        config.tool_config.approval_policy = create_test_tool_approval_policy();
        let model = builtin::claude_sonnet_4_5();

        let result = OneShotRunner::run_new_session(
            &runtime.handle,
            config,
            "List the files in the current directory".to_string(),
            model,
        )
        .await
        .expect("New session run with tools should succeed with valid API key");

        assert!(!result.final_message.id().is_empty());
        println!(
            "New session run with tools succeeded: {:?}",
            result.final_message
        );

        let has_content = match &result.final_message.data {
            MessageData::Assistant { content, .. } => content.iter().any(|c| match c {
                AssistantContent::Text { text } => !text.is_empty(),
                _ => true,
            }),
            _ => false,
        };
        assert!(has_content, "Response should have some content");

        runtime.shutdown().await;
    }
}