tower_llm/
core.rs

1//! Core agent implementation using Tower services and static dependency injection.
2
3use std::{future::Future, pin::Pin, sync::Arc};
4
5use crate::groups::HandoffPolicy;
6use crate::provider::{ModelService, OpenAIProvider, ProviderResponse};
7use async_openai::{
8    config::OpenAIConfig,
9    types::{
10        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
11        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
12        ChatCompletionRequestUserMessageArgs, ChatCompletionResponseMessage, ChatCompletionTool,
13        ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
14        CreateChatCompletionRequestArgs, FunctionObjectArgs, ReasoningEffort,
15    },
16    Client,
17};
18use async_trait::async_trait;
19use futures::future::BoxFuture;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23use tokio::sync::Semaphore;
24use tower::{
25    util::{BoxCloneService, BoxService},
26    BoxError, Layer, Service, ServiceExt,
27};
28use tracing::{debug, trace};
29
/// Join policy for parallel tool execution.
///
/// Only consulted when a step runs more than one tool call concurrently; sequential
/// execution always stops at the first failing tool.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolJoinPolicy {
    /// Return an error on the first failing tool; pending tools are cancelled.
    #[default]
    FailFast,
    /// Run all tools to completion; if any fail, surface an aggregated error at the end.
    JoinAll,
}
39
40// =============================
41// Tool service modeling
42// =============================
43
44/// Uniform tool invocation passed to routed tool services.
45#[derive(Debug, Clone)]
46pub struct ToolInvocation {
47    pub id: String,   // tool_call_id
48    pub name: String, // function.name
49    pub arguments: Value,
50}
51
52/// Uniform tool output produced by tool services.
53#[derive(Debug, Clone)]
54pub struct ToolOutput {
55    pub id: String, // same as invocation.id
56    pub result: Value,
57}
58
59/// Boxed tool service type alias.
60pub type ToolSvc = BoxCloneService<ToolInvocation, ToolOutput, BoxError>;
61
62/// Definition of a tool: function spec (for OpenAI) + service implementation.
63pub struct ToolDef {
64    pub name: &'static str,
65    pub description: &'static str,
66    pub parameters_schema: Value,
67    pub service: ToolSvc,
68}
69
70impl ToolDef {
71    /// Create a tool definition from a handler function that takes JSON args and returns JSON.
72    pub fn from_handler(
73        name: &'static str,
74        description: &'static str,
75        parameters_schema: Value,
76        handler: std::sync::Arc<
77            dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync + 'static,
78        >,
79    ) -> Self {
80        let handler_arc = handler.clone();
81        let svc = tower::service_fn(move |inv: ToolInvocation| {
82            let handler = handler_arc.clone();
83            async move {
84                if inv.name != name {
85                    return Err::<ToolOutput, BoxError>(
86                        format!("routed to wrong tool: expected={}, got={}", name, inv.name).into(),
87                    );
88                }
89                let out = (handler)(inv.arguments).await?;
90                Ok(ToolOutput {
91                    id: inv.id,
92                    result: out,
93                })
94            }
95        });
96        Self {
97            name,
98            description,
99            parameters_schema,
100            service: BoxCloneService::new(svc),
101        }
102    }
103
104    /// Convert this tool's function signature into an OpenAI ChatCompletionTool spec.
105    pub fn to_openai_tool(&self) -> ChatCompletionTool {
106        let func = FunctionObjectArgs::default()
107            .name(self.name)
108            .description(self.description)
109            .parameters(self.parameters_schema.clone())
110            .build()
111            .expect("valid function object");
112        ChatCompletionToolArgs::default()
113            .r#type(ChatCompletionToolType::Function)
114            .function(func)
115            .build()
116            .expect("valid chat tool")
117    }
118}
119
120/// DX sugar: create a tool from a typed handler.
121/// - `A` is the input args struct (Deserialize + JsonSchema)
122/// - `R` is the output type (Serialize)
123pub fn tool_typed<A, H, Fut, R>(
124    name: &'static str,
125    description: &'static str,
126    handler: H,
127) -> ToolDef
128where
129    A: DeserializeOwned + JsonSchema + Send + 'static,
130    R: serde::Serialize + Send + 'static,
131    H: Fn(A) -> Fut + Send + Sync + 'static,
132    Fut: Future<Output = Result<R, BoxError>> + Send + 'static,
133{
134    let schema = schemars::schema_for!(A);
135    let params_value = serde_json::to_value(schema.schema).expect("schema to value");
136    let handler_arc_inner = Arc::new(handler);
137    let handler_arc: Arc<
138        dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync,
139    > = Arc::new(move |raw: Value| {
140        let h = handler_arc_inner.clone();
141        Box::pin(async move {
142            let args: A = serde_json::from_value(raw)?;
143            let out: R = (h.as_ref())(args).await?;
144            let val = serde_json::to_value(out)?;
145            Ok(val)
146        })
147    });
148    ToolDef::from_handler(name, description, params_value, handler_arc)
149}
150
151/// Simple router service over tools using a name → index table.
152#[derive(Clone)]
153pub struct ToolRouter {
154    name_to_index: std::collections::HashMap<&'static str, usize>,
155    services: Vec<ToolSvc>, // index 0 is the unknown-tool fallback
156}
157
158impl ToolRouter {
159    pub fn new(tools: Vec<ToolDef>) -> (Self, Vec<ChatCompletionTool>) {
160        use std::collections::HashMap;
161
162        let unknown = BoxCloneService::new(tower::service_fn(|inv: ToolInvocation| async move {
163            Err::<ToolOutput, BoxError>(format!("unknown tool: {}", inv.name).into())
164        }));
165
166        let mut services: Vec<ToolSvc> = vec![unknown];
167        let mut specs: Vec<ChatCompletionTool> = Vec::with_capacity(tools.len());
168        let mut name_to_index: HashMap<&'static str, usize> = HashMap::new();
169
170        for (i, td) in tools.into_iter().enumerate() {
171            name_to_index.insert(td.name, i + 1);
172            specs.push(td.to_openai_tool());
173            services.push(td.service);
174        }
175
176        (
177            Self {
178                name_to_index,
179                services,
180            },
181            specs,
182        )
183    }
184}
185
186impl Service<ToolInvocation> for ToolRouter {
187    type Response = ToolOutput;
188    type Error = BoxError;
189    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
190
191    fn poll_ready(
192        &mut self,
193        _cx: &mut std::task::Context<'_>,
194    ) -> std::task::Poll<Result<(), Self::Error>> {
195        // We check readiness per selected service inside `call`.
196        std::task::Poll::Ready(Ok(()))
197    }
198
199    fn call(&mut self, req: ToolInvocation) -> Self::Future {
200        let idx = self
201            .name_to_index
202            .get(req.name.as_str())
203            .copied()
204            .unwrap_or(0);
205
206        // Safe: index 0 is always present (unknown fallback)
207        let svc: &mut ToolSvc = &mut self.services[idx];
208        // Call selected service and forward its future
209        let fut = svc.call(req);
210        Box::pin(fut)
211    }
212}
213
214// ======================================================================
215// Agent instruction provider - abstraction over creating a system prompt
216// ======================================================================
217#[async_trait]
218pub trait LLMInstructionProvider: Send + Sync {
219    async fn instructions(&self) -> Option<String>;
220}
221
222#[derive(Clone)]
223enum InstructionSource {
224    Static(String),
225    Dynamic(Arc<dyn LLMInstructionProvider>),
226}
227
228// =============================
229// Step service and layer
230// =============================
231
/// Per-step accounting: token usage reported by the provider plus executed tool count.
// dead_code allowed: consumers of these fields live outside this module.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub struct StepAux {
    /// Prompt tokens reported by the provider for this step's model call.
    pub prompt_tokens: usize,
    /// Completion tokens reported by the provider for this step's model call.
    pub completion_tokens: usize,
    /// Number of tool calls that completed successfully and were appended as messages.
    pub tool_invocations: usize,
}
240
241/// Outcome of a single agent step.
242#[derive(Debug, Clone)]
243#[allow(dead_code)]
244pub enum StepOutcome {
245    Next {
246        messages: Vec<ChatCompletionRequestMessage>,
247        aux: StepAux,
248        invoked_tools: Vec<String>,
249    },
250    Done {
251        messages: Vec<ChatCompletionRequestMessage>,
252        aux: StepAux,
253    },
254}
255
256fn summarize_request_messages(messages: &[ChatCompletionRequestMessage]) -> Vec<&'static str> {
257    messages
258        .iter()
259        .map(|message| match message {
260            ChatCompletionRequestMessage::System(_) => "system",
261            ChatCompletionRequestMessage::User(_) => "user",
262            ChatCompletionRequestMessage::Assistant(_) => "assistant",
263            ChatCompletionRequestMessage::Tool(_) => "tool",
264            ChatCompletionRequestMessage::Function(_) => "function",
265            ChatCompletionRequestMessage::Developer(_) => "developer",
266        })
267        .collect()
268}
269
270fn summarize_assistant_message(message: &ChatCompletionResponseMessage) -> String {
271    let content_shape = if message
272        .content
273        .as_ref()
274        .map(|c| !c.is_empty())
275        .unwrap_or(false)
276    {
277        "text"
278    } else {
279        "none"
280    };
281    let tool_calls = message.tool_calls.as_ref().map(|c| c.len()).unwrap_or(0);
282    format!("assistant(content={content_shape}, tool_calls={tool_calls})")
283}
284
285/// One-step agent service parameterized by a routed tool service `S`.
286pub struct Step<S, P> {
287    provider: Arc<tokio::sync::Mutex<P>>,
288    model: String,
289    temperature: Option<f32>,
290    max_tokens: Option<u32>,
291    reasoning_effort: Option<ReasoningEffort>,
292    instructions: Option<InstructionSource>,
293    tools: S,
294    tool_specs: Arc<Vec<ChatCompletionTool>>, // supplied to requests if missing
295    parallel_tools: bool,
296    tool_concurrency_limit: Option<usize>,
297    join_policy: ToolJoinPolicy,
298}
299
300impl<S, P> Step<S, P> {
301    pub fn new(
302        provider: P,
303        model: impl Into<String>,
304        tools: S,
305        tool_specs: Vec<ChatCompletionTool>,
306    ) -> Self {
307        Self {
308            provider: Arc::new(tokio::sync::Mutex::new(provider)),
309            model: model.into(),
310            temperature: None,
311            max_tokens: None,
312            reasoning_effort: None,
313            instructions: None,
314            tools,
315            tool_specs: Arc::new(tool_specs),
316            parallel_tools: false,
317            tool_concurrency_limit: None,
318            join_policy: ToolJoinPolicy::FailFast,
319        }
320    }
321
322    pub fn temperature(mut self, t: f32) -> Self {
323        self.temperature = Some(t);
324        self
325    }
326
327    pub fn max_tokens(mut self, mt: u32) -> Self {
328        self.max_tokens = Some(mt);
329        self
330    }
331
332    pub fn enable_parallel_tools(mut self, enabled: bool) -> Self {
333        self.parallel_tools = enabled;
334        self
335    }
336
337    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
338        self.tool_concurrency_limit = Some(limit);
339        self
340    }
341
342    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
343        self.join_policy = policy;
344        self
345    }
346
347    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
348        self.reasoning_effort = Some(effort);
349        self
350    }
351}
352
353/// Layer that lifts a routed tool service `S` into a `Step<S>` service.
354pub struct StepLayer<P> {
355    provider: P,
356    model: String,
357    temperature: Option<f32>,
358    max_tokens: Option<u32>,
359    reasoning_effort: Option<ReasoningEffort>,
360    instructions: Option<InstructionSource>,
361    tool_specs: Arc<Vec<ChatCompletionTool>>,
362    parallel_tools: bool,
363    tool_concurrency_limit: Option<usize>,
364    join_policy: ToolJoinPolicy,
365}
366
367impl<P> StepLayer<P> {
368    pub fn new(provider: P, model: impl Into<String>, tool_specs: Vec<ChatCompletionTool>) -> Self {
369        Self {
370            provider,
371            model: model.into(),
372            temperature: None,
373            max_tokens: None,
374            reasoning_effort: None,
375            instructions: None,
376            tool_specs: Arc::new(tool_specs),
377            parallel_tools: false,
378            tool_concurrency_limit: None,
379            join_policy: ToolJoinPolicy::FailFast,
380        }
381    }
382
383    pub fn temperature(mut self, t: f32) -> Self {
384        self.temperature = Some(t);
385        self
386    }
387
388    pub fn max_tokens(mut self, mt: u32) -> Self {
389        self.max_tokens = Some(mt);
390        self
391    }
392
393    pub fn parallel_tools(mut self, enabled: bool) -> Self {
394        self.parallel_tools = enabled;
395        self
396    }
397
398    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
399        self.tool_concurrency_limit = Some(limit);
400        self
401    }
402
403    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
404        self.join_policy = policy;
405        self
406    }
407
408    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
409        self.reasoning_effort = Some(effort);
410        self
411    }
412
413    pub fn instructions(mut self, text: impl Into<String>) -> Self {
414        self.instructions = Some(InstructionSource::Static(text.into()));
415        self
416    }
417
418    pub fn instruction_provider(mut self, provider: Arc<dyn LLMInstructionProvider>) -> Self {
419        self.instructions = Some(InstructionSource::Dynamic(provider));
420        self
421    }
422}
423
424impl<S, P> Layer<S> for StepLayer<P>
425where
426    P: Clone,
427{
428    type Service = Step<S, P>;
429
430    fn layer(&self, tools: S) -> Self::Service {
431        let mut s = Step::new(
432            self.provider.clone(),
433            self.model.clone(),
434            tools,
435            (*self.tool_specs).clone(),
436        );
437        s.temperature = self.temperature;
438        s.max_tokens = self.max_tokens;
439        s.reasoning_effort = self.reasoning_effort.clone();
440        s.instructions = self.instructions.clone();
441        // propagate instructions if StepLayer has it
442        // Note: StepLayer currently doesn't store instructions; this will be set via AgentBuilder mapping below
443        s.parallel_tools = self.parallel_tools;
444        s.tool_concurrency_limit = self.tool_concurrency_limit;
445        s.join_policy = self.join_policy;
446        s
447    }
448}
449
450impl<S, P> Service<CreateChatCompletionRequest> for Step<S, P>
451where
452    S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
453    S::Future: Send + 'static,
454    P: ModelService + Send + 'static,
455    P::Future: Send + 'static,
456{
457    type Response = StepOutcome;
458    type Error = BoxError;
459    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
460
461    fn poll_ready(
462        &mut self,
463        cx: &mut std::task::Context<'_>,
464    ) -> std::task::Poll<Result<(), Self::Error>> {
465        let _ = cx; // Always ready; we await tools readiness inside `call`
466        std::task::Poll::Ready(Ok(()))
467    }
468
469    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
470        let provider = self.provider.clone();
471        let model = self.model.clone();
472        let temperature = self.temperature;
473        let max_tokens = self.max_tokens;
474        let reasoning_effort = self.reasoning_effort.clone();
475        let tools = self.tools.clone();
476        let tool_specs = self.tool_specs.clone();
477        let parallel_tools = self.parallel_tools;
478        let _tool_concurrency_limit = self.tool_concurrency_limit;
479        let join_policy = self.join_policy;
480        let instruction_source = self.instructions.clone();
481
482        Box::pin(async move {
483            // Rebuild request using builder to avoid deprecated field access
484            let effective_model: Option<String> = req.model.clone().into();
485
486            // Determine which model will be used
487            let model_to_use = if let Some(m) = effective_model.as_ref() {
488                m.clone()
489            } else {
490                model.clone()
491            };
492
493            // Log model parameters
494            debug!(
495                model = %model_to_use,
496                temperature = ?req.temperature.or(temperature),
497                max_tokens = ?max_tokens,
498                tools_count = if req.tools.is_some() {
499                    req.tools.as_ref().map(|t| t.len())
500                } else {
501                    Some(tool_specs.len())
502                },
503                "Step service preparing API request"
504            );
505
506            // Prepare messages with optional agent-level instructions injection
507            let mut injected_messages = req.messages.clone();
508            let instructions = match instruction_source {
509                Some(InstructionSource::Static(text)) => Some(text),
510                Some(InstructionSource::Dynamic(provider)) => provider.instructions().await,
511                None => None,
512            };
513
514            if let Some(instr) = instructions {
515                // Build a system message for the instructions
516                let sys_msg = ChatCompletionRequestSystemMessageArgs::default()
517                    .content(instr)
518                    .build()
519                    .map(ChatCompletionRequestMessage::from)
520                    .map_err(|e| format!("system msg build error: {}", e))?;
521                // Ensure exactly one system message at the front
522                if let Some(pos) = injected_messages
523                    .iter()
524                    .position(|m| matches!(m, ChatCompletionRequestMessage::System(_)))
525                {
526                    injected_messages.remove(pos);
527                }
528                injected_messages.insert(0, sys_msg);
529            }
530
531            let mut builder = CreateChatCompletionRequestArgs::default();
532            builder.messages(injected_messages);
533            if let Some(m) = effective_model.as_ref() {
534                builder.model(m);
535            } else {
536                builder.model(&model);
537            }
538            if let Some(t) = req.temperature.or(temperature) {
539                builder.temperature(t);
540            }
541            // Use request's max_tokens if set, otherwise use layer's max_tokens if set
542            #[allow(deprecated)]
543            if let Some(mt) = req.max_tokens.or(max_tokens) {
544                builder.max_tokens(mt);
545            }
546            if let Some(effort) = reasoning_effort {
547                builder.reasoning_effort(effort);
548            }
549            if let Some(ts) = req.tools.clone() {
550                builder.tools(ts);
551            } else if !tool_specs.is_empty() {
552                builder.tools((*tool_specs).clone());
553            }
554
555            let rebuilt_req = builder
556                .build()
557                .map_err(|e| format!("request build error: {}", e))?;
558
559            let request_shape = summarize_request_messages(&rebuilt_req.messages);
560
561            // Trace the final request model
562            debug!(
563                final_model = ?rebuilt_req.model,
564                messages_count = rebuilt_req.messages.len(),
565                message_roles = ?request_shape,
566                "Step service final request built"
567            );
568
569            let mut messages = rebuilt_req.messages.clone();
570
571            // Single OpenAI call
572            // Provider call
573            let mut p = provider.lock().await;
574            let ProviderResponse {
575                assistant,
576                prompt_tokens,
577                completion_tokens,
578            } = ServiceExt::ready(&mut *p).await?.call(rebuilt_req).await?;
579            let response_shape = summarize_assistant_message(&assistant);
580            debug!(
581                response_shape = %response_shape,
582                prompt_tokens,
583                completion_tokens,
584                "Step service received provider response"
585            );
586            let mut aux = StepAux {
587                prompt_tokens,
588                completion_tokens,
589                tool_invocations: 0,
590            };
591
592            // Append assistant message by constructing request-side equivalent
593            let mut asst_builder = ChatCompletionRequestAssistantMessageArgs::default();
594            if let Some(content) = assistant.content.clone() {
595                asst_builder.content(content);
596            } else {
597                asst_builder.content("");
598            }
599            if let Some(tool_calls) = assistant.tool_calls.clone() {
600                asst_builder.tool_calls(tool_calls);
601            }
602            let asst_req = asst_builder
603                .build()
604                .map_err(|e| format!("assistant msg build error: {}", e))?;
605            messages.push(ChatCompletionRequestMessage::from(asst_req));
606
607            // Execute tool calls if present
608            let tool_calls = assistant.tool_calls.unwrap_or_default();
609            if tool_calls.is_empty() {
610                return Ok(StepOutcome::Done { messages, aux });
611            }
612
613            let mut invoked_names: Vec<String> = Vec::with_capacity(tool_calls.len());
614            let invocations: Vec<ToolInvocation> = tool_calls
615                .into_iter()
616                .map(|tc| {
617                    let name = tc.function.name;
618                    invoked_names.push(name.clone());
619                    let args: Value =
620                        serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
621                    ToolInvocation {
622                        id: tc.id,
623                        name,
624                        arguments: args,
625                    }
626                })
627                .collect();
628
629            if !invoked_names.is_empty() {
630                debug!(tool_names = ?invoked_names, "Step service invoking tools");
631            }
632
633            if invocations.len() > 1 && parallel_tools {
634                // Fire in parallel, preserve order
635                let sem = _tool_concurrency_limit.map(|n| Arc::new(Semaphore::new(n)));
636                match join_policy {
637                    ToolJoinPolicy::FailFast => {
638                        let futures: Vec<_> = invocations
639                            .into_iter()
640                            .map(|inv| {
641                                let mut svc = tools.clone();
642                                let sem_cl = sem.clone();
643                                async move {
644                                    let _permit = match &sem_cl {
645                                        Some(s) => Some(
646                                            s.clone().acquire_owned().await.expect("semaphore"),
647                                        ),
648                                        None => None,
649                                    };
650                                    let ToolOutput { id, result } =
651                                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
652                                    Ok::<(String, Value), BoxError>((id, result))
653                                }
654                            })
655                            .collect();
656                        let outputs: Vec<(String, Value)> =
657                            futures::future::try_join_all(futures).await?;
658                        for (id, result) in outputs {
659                            aux.tool_invocations += 1;
660                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
661                                .content(result.to_string())
662                                .tool_call_id(id)
663                                .build()?;
664                            messages.push(tool_msg.into());
665                        }
666                    }
667                    ToolJoinPolicy::JoinAll => {
668                        let futures: Vec<_> =
669                            invocations
670                                .into_iter()
671                                .enumerate()
672                                .map(|(idx, inv)| {
673                                    let mut svc = tools.clone();
674                                    let sem_cl = sem.clone();
675                                    async move {
676                                        let _permit = match &sem_cl {
677                                            Some(s) => Some(
678                                                s.clone().acquire_owned().await.expect("semaphore"),
679                                            ),
680                                            None => None,
681                                        };
682                                        let res =
683                                            ServiceExt::ready(&mut svc).await?.call(inv).await;
684                                        match res {
685                                            Ok(ToolOutput { id, result }) => Ok::<
686                                                Result<(usize, String, Value), BoxError>,
687                                                BoxError,
688                                            >(
689                                                Ok((idx, id, result)),
690                                            ),
691                                            Err(e) => Ok(Err(e)),
692                                        }
693                                    }
694                                })
695                                .collect();
696                        let results = futures::future::join_all(futures).await;
697                        let mut successes: Vec<(usize, String, Value)> = Vec::new();
698                        let mut errors: Vec<String> = Vec::new();
699                        for item in results.into_iter() {
700                            match item {
701                                Ok(Ok((idx, id, result))) => successes.push((idx, id, result)),
702                                Ok(Err(e)) => errors.push(format!("{}", e)),
703                                Err(e) => errors.push(format!("{}", e)),
704                            }
705                        }
706                        successes.sort_by_key(|(idx, _, _)| *idx);
707                        for (_idx, id, result) in successes.into_iter() {
708                            aux.tool_invocations += 1;
709                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
710                                .content(result.to_string())
711                                .tool_call_id(id)
712                                .build()?;
713                            messages.push(tool_msg.into());
714                        }
715                        if !errors.is_empty() {
716                            return Err(
717                                format!("one or more tools failed: {}", errors.join("; ")).into()
718                            );
719                        }
720                    }
721                }
722            } else {
723                // Sequential
724                for inv in invocations {
725                    let mut svc = tools.clone();
726                    let ToolOutput { id, result } =
727                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
728                    aux.tool_invocations += 1;
729                    let tool_msg = ChatCompletionRequestToolMessageArgs::default()
730                        .content(result.to_string())
731                        .tool_call_id(id)
732                        .build()?;
733                    messages.push(tool_msg.into());
734                }
735            }
736
737            Ok(StepOutcome::Next {
738                messages,
739                aux,
740                invoked_tools: invoked_names,
741            })
742        })
743    }
744}
745
746// =============================
747// Convenience helpers for examples/tests
748// =============================
749
750/// Build a simple chat request from plain strings.
751pub fn simple_chat_request(system: &str, user: &str) -> CreateChatCompletionRequest {
752    let sys = ChatCompletionRequestSystemMessageArgs::default()
753        .content(system)
754        .build()
755        .expect("system msg");
756    let usr = ChatCompletionRequestUserMessageArgs::default()
757        .content(user)
758        .build()
759        .expect("user msg");
760    CreateChatCompletionRequestArgs::default()
761        .model("gpt-4o")
762        .messages(vec![sys.into(), usr.into()])
763        .build()
764        .expect("chat req")
765}
766
767/// Build a simple chat request with only a user message.
768#[allow(dead_code)]
769pub fn simple_user_request(user: &str) -> CreateChatCompletionRequest {
770    let usr = ChatCompletionRequestUserMessageArgs::default()
771        .content(user)
772        .build()
773        .expect("user msg");
774    CreateChatCompletionRequestArgs::default()
775        .model("gpt-4o")
776        .messages(vec![usr.into()])
777        .build()
778        .expect("chat req")
779}
780
781// =============================
782// Agent loop: composable policies and layer
783// =============================
784
/// Why the agent loop stopped.
///
/// NOTE(review): the `String` in `ToolCalled` presumably names the tool that triggered the
/// stop. Confirm at the producer.
// dead_code allowed: some variants are only constructed by policies outside this view.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentStopReason {
    DoneNoToolCalls,
    MaxSteps,
    ToolCalled(String),
    TokensBudgetExceeded,
    ToolBudgetExceeded,
    TimeBudgetExceeded,
}
796
797// =============================
798// DX sugar: Policy builder, Agent builder, run helpers
799// =============================
800
801/// Chainable policy builder.
802#[derive(Default, Clone)]
803pub struct Policy {
804    inner: CompositePolicy,
805}
806
807#[allow(dead_code)]
808impl Policy {
809    pub fn new() -> Self {
810        Self {
811            inner: CompositePolicy::default(),
812        }
813    }
814    pub fn until_no_tool_calls(mut self) -> Self {
815        self.inner.policies.push(policies::until_no_tool_calls());
816        self
817    }
818    pub fn or_tool(mut self, name: impl Into<String>) -> Self {
819        self.inner.policies.push(policies::until_tool_called(name));
820        self
821    }
822    pub fn or_max_steps(mut self, max: usize) -> Self {
823        self.inner.policies.push(policies::max_steps(max));
824        self
825    }
826    pub fn build(self) -> CompositePolicy {
827        self.inner
828    }
829}
830
831/// Boxed agent service type for ergonomic returns.
832pub type AgentSvc = BoxService<CreateChatCompletionRequest, AgentRun, BoxError>;
833
/// Entry point for assembling an agent stack (tools, model, policy). See [`Agent::builder`].
pub struct Agent;
836
837pub struct AgentBuilder {
838    client: Arc<Client<OpenAIConfig>>,
839    model: String,
840    temperature: Option<f32>,
841    max_tokens: Option<u32>,
842    reasoning_effort: Option<ReasoningEffort>,
843    instructions: Option<String>,
844    instruction_provider: Option<Arc<dyn LLMInstructionProvider>>,
845    tools: Vec<ToolDef>,
846    policy: CompositePolicy,
847    handoff: Option<crate::groups::AnyHandoffPolicy>,
848    provider: Option<
849        tower::util::BoxCloneService<
850            CreateChatCompletionRequest,
851            crate::provider::ProviderResponse,
852            BoxError,
853        >,
854    >,
855    enable_parallel_tools: bool,
856    tool_concurrency_limit: Option<usize>,
857    tool_join_policy: ToolJoinPolicy,
858    agent_service_map: Option<Arc<dyn Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static>>, // optional final wrapper
859    auto_compaction: Option<crate::auto_compaction::CompactionPolicy>,
860}
861
862impl Agent {
863    pub fn builder(client: Arc<Client<OpenAIConfig>>) -> AgentBuilder {
864        AgentBuilder {
865            client,
866            model: "gpt-4o".to_string(),
867            temperature: None,
868            max_tokens: None,
869            reasoning_effort: None,
870            instructions: None,
871            instruction_provider: None,
872            tools: Vec::new(),
873            policy: CompositePolicy::default(),
874            handoff: None,
875            provider: None,
876            enable_parallel_tools: false,
877            tool_concurrency_limit: None,
878            tool_join_policy: ToolJoinPolicy::FailFast,
879            agent_service_map: None,
880            auto_compaction: None,
881        }
882    }
883}
884
885impl AgentBuilder {
886    pub fn model(mut self, model: impl Into<String>) -> Self {
887        self.model = model.into();
888        self
889    }
890    pub fn temperature(mut self, t: f32) -> Self {
891        self.temperature = Some(t);
892        self
893    }
894    pub fn max_tokens(mut self, mt: u32) -> Self {
895        self.max_tokens = Some(mt);
896        self
897    }
898    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
899        self.reasoning_effort = Some(effort);
900        self
901    }
902
903    /// Set agent-level instructions (system prompt). These will be injected on each step.
904    pub fn instructions(mut self, text: impl Into<String>) -> Self {
905        self.instructions = Some(text.into());
906        self
907    }
908    pub fn instruction_provider(mut self, provider: Arc<dyn LLMInstructionProvider>) -> Self {
909        self.instruction_provider = Some(provider);
910        self
911    }
912    pub fn tool(mut self, tool: ToolDef) -> Self {
913        self.tools.push(tool);
914        self
915    }
916    pub fn tools(mut self, tools: Vec<ToolDef>) -> Self {
917        self.tools.extend(tools);
918        self
919    }
920    pub fn policy(mut self, policy: CompositePolicy) -> Self {
921        self.policy = policy;
922        self
923    }
924
925    /// Enable handoff-aware tool interception and advertise handoff tools
926    pub fn handoff_policy(mut self, policy: crate::groups::AnyHandoffPolicy) -> Self {
927        self.handoff = Some(policy);
928        self
929    }
930
931    /// Override the non-streaming provider (useful for testing with a fixed/mocked model)
932    pub fn with_provider<P>(mut self, provider: P) -> Self
933    where
934        P: crate::provider::ModelService + Clone + Send + 'static,
935        P::Future: Send + 'static,
936    {
937        self.provider = Some(tower::util::BoxCloneService::new(provider));
938        self
939    }
940
941    /// Enable or disable parallel tool execution within a step
942    pub fn parallel_tools(mut self, enabled: bool) -> Self {
943        self.enable_parallel_tools = enabled;
944        self
945    }
946
947    /// Set an optional concurrency limit for parallel tool execution
948    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
949        self.tool_concurrency_limit = Some(limit);
950        self
951    }
952
953    /// Configure how parallel tool errors are handled (fail fast or join all)
954    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
955        self.tool_join_policy = policy;
956        self
957    }
958
959    /// Optional: wrap the final built agent service with a custom function.
960    /// This enables applying Tower layers at the agent boundary.
961    pub fn map_agent_service<F>(mut self, f: F) -> Self
962    where
963        F: Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static,
964    {
965        self.agent_service_map = Some(Arc::new(f));
966        self
967    }
968
969    /// Enable auto-compaction with the specified policy
970    pub fn auto_compaction(mut self, policy: crate::auto_compaction::CompactionPolicy) -> Self {
971        self.auto_compaction = Some(policy);
972        self
973    }
974
975    pub fn build(self) -> AgentSvc {
976        let (router, mut specs) = ToolRouter::new(self.tools);
977        // If handoff policy provided, wrap router and extend tool specs
978        let routed: ToolSvc = if let Some(policy) = &self.handoff {
979            let hand_spec = policy.handoff_tools();
980            if !hand_spec.is_empty() {
981                specs.extend(hand_spec);
982            }
983            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
984        } else {
985            // No handoff layer; clonable box of the router
986            BoxCloneService::new(router)
987        };
988
989        let base_provider: tower::util::BoxCloneService<
990            CreateChatCompletionRequest,
991            crate::provider::ProviderResponse,
992            BoxError,
993        > = if let Some(p) = self.provider {
994            p
995        } else {
996            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
997        };
998        let mut step_layer = StepLayer::new(base_provider.clone(), self.model, specs)
999            .parallel_tools(self.enable_parallel_tools)
1000            .tool_join_policy(self.tool_join_policy);
1001        if let Some(instr) = &self.instructions {
1002            step_layer = step_layer.instructions(instr.clone());
1003        } else if let Some(provider) = &self.instruction_provider {
1004            step_layer = step_layer.instruction_provider(provider.clone());
1005        }
1006        // Only set temperature if explicitly provided
1007        if let Some(t) = self.temperature {
1008            step_layer = step_layer.temperature(t);
1009        }
1010        // Only set max_tokens if explicitly provided
1011        if let Some(mt) = self.max_tokens {
1012            step_layer = step_layer.max_tokens(mt);
1013        }
1014        if let Some(effort) = self.reasoning_effort {
1015            step_layer = step_layer.reasoning_effort(effort);
1016        }
1017        if let Some(lim) = self.tool_concurrency_limit {
1018            step_layer = step_layer.tool_concurrency_limit(lim);
1019        }
1020        let step = step_layer.layer(routed);
1021
1022        // Apply auto-compaction if configured
1023        let step_with_compaction: BoxService<CreateChatCompletionRequest, StepOutcome, BoxError> =
1024            if let Some(compaction_policy) = self.auto_compaction {
1025                // Use the provider for compaction
1026                let compaction_provider = base_provider;
1027                let token_counter = crate::auto_compaction::SimpleTokenCounter::new();
1028                let compaction_layer = crate::auto_compaction::AutoCompactionLayer::new(
1029                    compaction_policy,
1030                    compaction_provider,
1031                    token_counter,
1032                );
1033                BoxService::new(compaction_layer.layer(step))
1034            } else {
1035                BoxService::new(step)
1036            };
1037
1038        let agent = AgentLoopLayer::new(self.policy).layer(step_with_compaction);
1039        let boxed = BoxService::new(agent);
1040        match &self.agent_service_map {
1041            Some(map) => (map)(boxed),
1042            None => boxed,
1043        }
1044    }
1045
1046    /// Build an agent service wrapped with session memory persistence
1047    pub fn build_with_session<Ls, Ss>(
1048        self,
1049        load: Arc<Ls>,
1050        save: Arc<Ss>,
1051        session_id: crate::sessions::SessionId,
1052    ) -> AgentSvc
1053    where
1054        Ls: Service<
1055                crate::sessions::LoadSession,
1056                Response = crate::sessions::History,
1057                Error = BoxError,
1058            > + Send
1059            + Sync
1060            + Clone
1061            + 'static,
1062        Ls::Future: Send + 'static,
1063        Ss: Service<crate::sessions::SaveSession, Response = (), Error = BoxError>
1064            + Send
1065            + Sync
1066            + Clone
1067            + 'static,
1068        Ss::Future: Send + 'static,
1069    {
1070        let (router, mut specs) = ToolRouter::new(self.tools);
1071        let routed: ToolSvc = if let Some(policy) = &self.handoff {
1072            let hand_spec = policy.handoff_tools();
1073            if !hand_spec.is_empty() {
1074                specs.extend(hand_spec);
1075            }
1076            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
1077        } else {
1078            BoxCloneService::new(router)
1079        };
1080
1081        let base_provider: tower::util::BoxCloneService<
1082            CreateChatCompletionRequest,
1083            crate::provider::ProviderResponse,
1084            BoxError,
1085        > = if let Some(p) = self.provider {
1086            p
1087        } else {
1088            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
1089        };
1090        let mut step_layer = StepLayer::new(base_provider.clone(), self.model, specs)
1091            .parallel_tools(self.enable_parallel_tools)
1092            .tool_join_policy(self.tool_join_policy);
1093        if let Some(instr) = &self.instructions {
1094            step_layer = step_layer.instructions(instr.clone());
1095        } else if let Some(provider) = &self.instruction_provider {
1096            step_layer = step_layer.instruction_provider(provider.clone());
1097        }
1098        // Only set temperature if explicitly provided
1099        if let Some(t) = self.temperature {
1100            step_layer = step_layer.temperature(t);
1101        }
1102        // Only set max_tokens if explicitly provided
1103        if let Some(mt) = self.max_tokens {
1104            step_layer = step_layer.max_tokens(mt);
1105        }
1106        if let Some(effort) = self.reasoning_effort {
1107            step_layer = step_layer.reasoning_effort(effort);
1108        }
1109        if let Some(lim) = self.tool_concurrency_limit {
1110            step_layer = step_layer.tool_concurrency_limit(lim);
1111        }
1112        let step = step_layer.layer(routed);
1113
1114        // Apply auto-compaction if configured (before memory layer)
1115        let step_with_compaction: BoxService<CreateChatCompletionRequest, StepOutcome, BoxError> =
1116            if let Some(compaction_policy) = self.auto_compaction {
1117                // Use the provider for compaction
1118                let compaction_provider = base_provider;
1119                let token_counter = crate::auto_compaction::SimpleTokenCounter::new();
1120                let compaction_layer = crate::auto_compaction::AutoCompactionLayer::new(
1121                    compaction_policy,
1122                    compaction_provider,
1123                    token_counter,
1124                );
1125                BoxService::new(compaction_layer.layer(step))
1126            } else {
1127                BoxService::new(step)
1128            };
1129
1130        // Attach memory layer
1131        let mem_layer = crate::sessions::MemoryLayer::new(load, save, session_id);
1132        let step_with_mem = mem_layer.layer(step_with_compaction);
1133        let agent = AgentLoopLayer::new(self.policy).layer(step_with_mem);
1134        let boxed = BoxService::new(agent);
1135        match &self.agent_service_map {
1136            Some(map) => (map)(boxed),
1137            None => boxed,
1138        }
1139    }
1140}
1141
1142/// Convenience: run a prompt through an agent service.
1143pub async fn run(agent: &mut AgentSvc, system: &str, user: &str) -> Result<AgentRun, BoxError> {
1144    let req = simple_chat_request(system, user);
1145    let resp = ServiceExt::ready(agent).await?.call(req).await?;
1146    Ok(resp)
1147}
1148
1149/// Convenience: run a user message through an agent service. System instructions come from the agent.
1150#[allow(dead_code)]
1151pub async fn run_user(agent: &mut AgentSvc, user: &str) -> Result<AgentRun, BoxError> {
1152    let req = simple_user_request(user);
1153    let resp = ServiceExt::ready(agent).await?.call(req).await?;
1154    Ok(resp)
1155}
1156
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::ChatCompletionRequestUserMessageArgs;
    use async_trait::async_trait;
    use std::sync::Arc;

    // ---------------------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------------------

    /// Canned provider reply: the assistant says "ok", requests no tools, 1/1 tokens.
    fn ok_response() -> crate::provider::ProviderResponse {
        // The literal names the deprecated `function_call` field.
        #[allow(deprecated)]
        let assistant = async_openai::types::ChatCompletionResponseMessage {
            content: Some("ok".into()),
            role: async_openai::types::Role::Assistant,
            tool_calls: None,
            function_call: None,
            refusal: None,
            audio: None,
        };
        crate::provider::ProviderResponse {
            assistant,
            prompt_tokens: 1,
            completion_tokens: 1,
        }
    }

    /// Request for model "gpt-4o" that carries a single user message.
    fn user_request(content: &str) -> CreateChatCompletionRequest {
        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(content)
            .build()
            .unwrap();
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![user.into()])
            .build()
            .unwrap()
    }

    /// Messages carried by either step outcome variant.
    fn step_messages(outcome: StepOutcome) -> Vec<ChatCompletionRequestMessage> {
        match outcome {
            StepOutcome::Next { messages, .. } | StepOutcome::Done { messages, .. } => messages,
        }
    }

    /// Assert that the first message is a system message whose text equals `expected`.
    fn assert_first_system_text(messages: &[ChatCompletionRequestMessage], expected: &str) {
        match messages.first() {
            Some(ChatCompletionRequestMessage::System(s)) => match &s.content {
                async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
                    assert_eq!(t, expected);
                }
                _ => panic!("expected text content in system message"),
            },
            _ => panic!("expected first message to be system"),
        }
    }

    /// Client for the builder. Each test installs a provider, so the client is never called.
    fn unused_client() -> Arc<async_openai::Client<async_openai::config::OpenAIConfig>> {
        Arc::new(async_openai::Client::<async_openai::config::OpenAIConfig>::new())
    }

    /// Provider that records the most recent request it received and replies with
    /// [`ok_response`]. Clones share the same capture slot.
    #[derive(Clone)]
    struct CapturingProvider {
        captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
    }

    impl CapturingProvider {
        fn new() -> Self {
            Self {
                captured: Arc::new(tokio::sync::Mutex::new(None)),
            }
        }

        /// Last request seen; panics if the provider was never called.
        async fn last_request(&self) -> CreateChatCompletionRequest {
            self.captured.lock().await.clone().expect("captured")
        }
    }

    impl tower::Service<CreateChatCompletionRequest> for CapturingProvider {
        type Response = crate::provider::ProviderResponse;
        type Error = BoxError;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
            let captured = self.captured.clone();
            Box::pin(async move {
                *captured.lock().await = Some(req);
                Ok(ok_response())
            })
        }
    }

    /// Instruction provider that returns "CTX <n>", where n counts its invocations.
    #[derive(Clone)]
    struct CountingInstructionProvider {
        counter: Arc<tokio::sync::Mutex<usize>>,
    }

    #[async_trait]
    impl LLMInstructionProvider for CountingInstructionProvider {
        async fn instructions(&self) -> Option<String> {
            let mut guard = self.counter.lock().await;
            *guard += 1;
            Some(format!("CTX {}", *guard))
        }
    }

    // ---------------------------------------------------------------------------------
    // Tests
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn step_injects_instructions_prepend_or_replace() {
        let provider = crate::provider::FixedProvider::new(ok_response());

        // No tools
        let (router, specs) = ToolRouter::new(vec![]);
        let step = StepLayer::new(provider, "gpt-4o", specs)
            .instructions("AGENT INSTR")
            .layer(router);
        let mut svc = tower::ServiceExt::boxed(step);

        // Request carries only a user message; the step must inject the system prompt.
        let out = tower::ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(user_request("hello"))
            .await
            .unwrap();
        assert_first_system_text(&step_messages(out), "AGENT INSTR");
    }

    #[tokio::test]
    async fn step_uses_instruction_provider_each_call() {
        let provider = crate::provider::FixedProvider::new(ok_response());
        let (router, specs) = ToolRouter::new(vec![]);
        let instruction_provider = Arc::new(CountingInstructionProvider {
            counter: Arc::new(tokio::sync::Mutex::new(0)),
        });
        let step = StepLayer::new(provider, "gpt-4o", specs)
            .instruction_provider(instruction_provider)
            .layer(router);
        let mut svc = tower::ServiceExt::boxed(step);

        // The provider must be consulted anew on every call.
        for expected in ["CTX 1", "CTX 2"] {
            let resp = tower::ServiceExt::ready(&mut svc)
                .await
                .unwrap()
                .call(user_request("hello"))
                .await
                .unwrap();
            assert_first_system_text(&step_messages(resp), expected);
        }
    }

    #[test]
    fn builds_user_request() {
        let _ = simple_user_request("hi");
    }

    #[tokio::test]
    async fn run_user_executes_with_instructions() {
        let provider = crate::provider::FixedProvider::new(ok_response());
        let mut agent = Agent::builder(unused_client())
            .with_provider(provider)
            .model("gpt-4o")
            .instructions("INSTR")
            .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
            .build();
        let run = run_user(&mut agent, "hello").await.unwrap();
        assert!(!run.messages.is_empty());
        assert_first_system_text(&run.messages, "INSTR");
    }

    #[tokio::test]
    async fn sessions_preserve_agent_instructions_in_merged_request() {
        use crate::sessions::{InMemorySessionStore, SessionId};

        let provider = CapturingProvider::new();
        let load = Arc::new(InMemorySessionStore::default());
        let save = load.clone();
        let mut agent = Agent::builder(unused_client())
            .with_provider(provider.clone())
            .model("gpt-4o")
            .instructions("INSTR")
            .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
            .build_with_session(load, save, SessionId("s1".into()));

        let _ = tower::ServiceExt::ready(&mut agent)
            .await
            .unwrap()
            .call(user_request("hi"))
            .await
            .unwrap();

        let got = provider.last_request().await;
        assert_first_system_text(&got.messages, "INSTR");
    }

    #[tokio::test]
    async fn auto_compaction_preserves_instructions() {
        use crate::auto_compaction::{CompactionPolicy, CompactionStrategy, ProactiveThreshold};

        let provider = CapturingProvider::new();

        // A 1-token threshold forces compaction on the very first step.
        let policy = CompactionPolicy {
            compaction_model: "gpt-4o-mini".to_string(),
            proactive_threshold: Some(ProactiveThreshold {
                token_threshold: 1,
                percentage_threshold: None,
            }),
            compaction_strategy: CompactionStrategy::PreserveSystemAndRecent { recent_count: 1 },
            ..Default::default()
        };

        let mut agent = Agent::builder(unused_client())
            .with_provider(provider.clone())
            .model("gpt-4o")
            .instructions("INSTR")
            .auto_compaction(policy)
            .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
            .build();

        let long_user = "x".repeat(200);
        let _ = tower::ServiceExt::ready(&mut agent)
            .await
            .unwrap()
            .call(user_request(&long_user))
            .await
            .unwrap();

        let got = provider.last_request().await;
        assert_first_system_text(&got.messages, "INSTR");
    }
}
1529
/// Loop state visible to policies.
///
/// Handed to every [`AgentPolicy::decide`] call alongside the latest step outcome.
#[derive(Clone, Debug, Default)]
pub struct LoopState {
    /// Step counter maintained by the agent loop; `policies::max_steps` compares it
    /// against its limit.
    pub steps: usize,
}
1535
1536/// Policy interface controlling loop termination.
1537pub trait AgentPolicy: Send + Sync {
1538    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason>;
1539}
1540
1541/// Function-backed policy for ergonomic composition.
1542#[derive(Clone)]
1543#[allow(clippy::type_complexity)]
1544pub struct PolicyFn(
1545    pub Arc<dyn Fn(&LoopState, &StepOutcome) -> Option<AgentStopReason> + Send + Sync + 'static>,
1546);
1547
1548impl AgentPolicy for PolicyFn {
1549    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
1550        (self.0)(state, last)
1551    }
1552}
1553
1554/// Composite policy: stop when any sub-policy returns a stop reason.
1555#[derive(Clone, Default)]
1556pub struct CompositePolicy {
1557    policies: Vec<PolicyFn>,
1558}
1559
1560#[allow(dead_code)]
1561impl CompositePolicy {
1562    pub fn new(policies: Vec<PolicyFn>) -> Self {
1563        Self { policies }
1564    }
1565    pub fn push(&mut self, p: PolicyFn) {
1566        self.policies.push(p);
1567    }
1568}
1569
1570impl AgentPolicy for CompositePolicy {
1571    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
1572        for p in &self.policies {
1573            if let Some(r) = p.decide(state, last) {
1574                return Some(r);
1575            }
1576        }
1577        None
1578    }
1579}
1580
1581/// Built-in policies
1582#[allow(dead_code)]
1583pub mod policies {
1584    use super::*;
1585
1586    pub fn until_no_tool_calls() -> PolicyFn {
1587        PolicyFn(Arc::new(|_s, last| match last {
1588            StepOutcome::Done { .. } => Some(AgentStopReason::DoneNoToolCalls),
1589            _ => None,
1590        }))
1591    }
1592
1593    pub fn until_tool_called(tool_name: impl Into<String>) -> PolicyFn {
1594        let target = tool_name.into();
1595        PolicyFn(Arc::new(move |_s, last| match last {
1596            StepOutcome::Next { invoked_tools, .. } => {
1597                if invoked_tools.iter().any(|n| n == &target) {
1598                    Some(AgentStopReason::ToolCalled(target.clone()))
1599                } else {
1600                    None
1601                }
1602            }
1603            _ => None,
1604        }))
1605    }
1606
1607    pub fn max_steps(max: usize) -> PolicyFn {
1608        PolicyFn(Arc::new(move |s, _| {
1609            if s.steps >= max {
1610                Some(AgentStopReason::MaxSteps)
1611            } else {
1612                None
1613            }
1614        }))
1615    }
1616}
1617
1618/// Final run summary from the agent loop.
1619#[derive(Debug, Clone)]
1620pub struct AgentRun {
1621    pub messages: Vec<ChatCompletionRequestMessage>,
1622    pub steps: usize,
1623    pub stop: AgentStopReason,
1624}
1625
/// Layer to wrap a step service with an agent loop controlled by a policy.
///
/// The stored policy is cloned into every [`AgentLoop`] this layer produces.
pub struct AgentLoopLayer<P> {
    policy: P,
}

impl<P> AgentLoopLayer<P> {
    /// Create a layer whose loops are governed by `policy`.
    pub fn new(policy: P) -> Self {
        AgentLoopLayer { policy }
    }
}
1636
1637pub struct AgentLoop<S, P> {
1638    inner: Arc<tokio::sync::Mutex<S>>,
1639    policy: P,
1640}
1641
1642impl<S, P> Layer<S> for AgentLoopLayer<P>
1643where
1644    P: Clone,
1645{
1646    type Service = AgentLoop<S, P>;
1647    fn layer(&self, inner: S) -> Self::Service {
1648        AgentLoop {
1649            inner: Arc::new(tokio::sync::Mutex::new(inner)),
1650            policy: self.policy.clone(),
1651        }
1652    }
1653}
1654
1655impl<S, P> Service<CreateChatCompletionRequest> for AgentLoop<S, P>
1656where
1657    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
1658        + Send
1659        + 'static,
1660    S::Future: Send + 'static,
1661    P: AgentPolicy + Send + Sync + Clone + 'static,
1662{
1663    type Response = AgentRun;
1664    type Error = BoxError;
1665    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1666
1667    fn poll_ready(
1668        &mut self,
1669        _cx: &mut std::task::Context<'_>,
1670    ) -> std::task::Poll<Result<(), Self::Error>> {
1671        std::task::Poll::Ready(Ok(()))
1672    }
1673
1674    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
1675        let inner = self.inner.clone();
1676        let policy = self.policy.clone();
1677        Box::pin(async move {
1678            let mut state = LoopState::default();
1679            let base_model = req.model.clone();
1680            let mut current_messages = req.messages.clone();
1681            // Preserve all original request parameters
1682            let base_temperature = req.temperature;
1683            #[allow(deprecated)]
1684            let base_max_tokens = req.max_tokens;
1685            let base_max_completion_tokens = req.max_completion_tokens;
1686            let base_tools = req.tools.clone();
1687
1688            // Log the initial model for the agent loop
1689            debug!(
1690                model = ?base_model,
1691                initial_messages = current_messages.len(),
1692                "AgentLoop starting with model"
1693            );
1694
1695            loop {
1696                // Rebuild request for this iteration, preserving original parameters
1697                let mut builder = CreateChatCompletionRequestArgs::default();
1698                builder.model(&base_model);
1699                builder.messages(current_messages.clone());
1700                if let Some(t) = base_temperature {
1701                    builder.temperature(t);
1702                }
1703                if let Some(mt) = base_max_tokens {
1704                    builder.max_tokens(mt);
1705                }
1706                if let Some(mct) = base_max_completion_tokens {
1707                    builder.max_completion_tokens(mct);
1708                }
1709                if let Some(tools) = base_tools.clone() {
1710                    builder.tools(tools);
1711                }
1712                let current_req = builder
1713                    .build()
1714                    .map_err(|e| format!("build req error: {}", e))?;
1715
1716                trace!(
1717                    step = state.steps + 1,
1718                    model = ?current_req.model,
1719                    messages = current_messages.len(),
1720                    "AgentLoop iteration"
1721                );
1722
1723                let mut guard = inner.lock().await;
1724                let outcome = guard.ready().await?.call(current_req).await?;
1725                drop(guard);
1726
1727                state.steps += 1;
1728
1729                if let Some(stop) = policy.decide(&state, &outcome) {
1730                    let messages = match outcome {
1731                        StepOutcome::Next { messages, .. } => messages,
1732                        StepOutcome::Done { messages, .. } => messages,
1733                    };
1734                    return Ok(AgentRun {
1735                        messages,
1736                        steps: state.steps,
1737                        stop,
1738                    });
1739                }
1740
1741                match outcome {
1742                    StepOutcome::Next { messages, .. } => {
1743                        current_messages = messages;
1744                    }
1745                    StepOutcome::Done { messages, .. } => {
1746                        return Ok(AgentRun {
1747                            messages,
1748                            steps: state.steps,
1749                            stop: AgentStopReason::DoneNoToolCalls,
1750                        });
1751                    }
1752                }
1753            }
1754        })
1755    }
1756}