tower_llm/
core.rs

//! Core agent implementation using Tower services and static dependency injection.

use std::{future::Future, pin::Pin, sync::Arc};

use crate::groups::HandoffPolicy;
use crate::provider::{ModelService, OpenAIProvider, ProviderResponse};
use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
        ChatCompletionToolType, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
        FunctionObjectArgs,
    },
    Client,
};
use futures::future::BoxFuture;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::Semaphore;
use tower::{
    util::{BoxCloneService, BoxService},
    BoxError, Layer, Service, ServiceExt,
};

/// Join policy for parallel tool execution
#[derive(Debug, Clone, Copy, Default)]
pub enum ToolJoinPolicy {
    /// Return error on the first failing tool; pending tools are cancelled
    #[default]
    FailFast,
    /// Run all tools to completion; if any fail, surface an aggregated error at the end
    JoinAll,
}

// =============================
// Tool service modeling
// =============================

/// Uniform tool invocation passed to routed tool services.
#[derive(Debug, Clone)]
pub struct ToolInvocation {
    pub id: String,   // tool_call_id
    pub name: String, // function.name
    pub arguments: Value,
}

/// Uniform tool output produced by tool services.
#[derive(Debug, Clone)]
pub struct ToolOutput {
    pub id: String, // same as invocation.id
    pub result: Value,
}

/// Boxed tool service type alias.
pub type ToolSvc = BoxCloneService<ToolInvocation, ToolOutput, BoxError>;

/// Definition of a tool: function spec (for OpenAI) + service implementation.
pub struct ToolDef {
    pub name: &'static str,
    pub description: &'static str,
    pub parameters_schema: Value,
    pub service: ToolSvc,
}

impl ToolDef {
    /// Create a tool definition from a handler function that takes JSON args and returns JSON.
    pub fn from_handler(
        name: &'static str,
        description: &'static str,
        parameters_schema: Value,
        handler: std::sync::Arc<
            dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync + 'static,
        >,
    ) -> Self {
        let handler_arc = handler.clone();
        let svc = tower::service_fn(move |inv: ToolInvocation| {
            let handler = handler_arc.clone();
            async move {
                if inv.name != name {
                    return Err::<ToolOutput, BoxError>(
                        format!("routed to wrong tool: expected={}, got={}", name, inv.name).into(),
                    );
                }
                let out = (handler)(inv.arguments).await?;
                Ok(ToolOutput {
                    id: inv.id,
                    result: out,
                })
            }
        });
        Self {
            name,
            description,
            parameters_schema,
            service: BoxCloneService::new(svc),
        }
    }

    /// Convert this tool's function signature into an OpenAI ChatCompletionTool spec.
    pub fn to_openai_tool(&self) -> ChatCompletionTool {
        let func = FunctionObjectArgs::default()
            .name(self.name)
            .description(self.description)
            .parameters(self.parameters_schema.clone())
            .build()
            .expect("valid function object");
        ChatCompletionToolArgs::default()
            .r#type(ChatCompletionToolType::Function)
            .function(func)
            .build()
            .expect("valid chat tool")
    }
}

/// DX sugar: create a tool from a typed handler.
/// - `A` is the input args struct (Deserialize + JsonSchema)
/// - `R` is the output type (Serialize)
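///
/// A minimal usage sketch (the `AddArgs` struct and the `"add"` tool name are
/// illustrative, not part of this crate):
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct AddArgs { a: i64, b: i64 }
///
/// let add = tool_typed("add", "Add two integers", |args: AddArgs| async move {
///     Ok::<_, tower::BoxError>(serde_json::json!({ "sum": args.a + args.b }))
/// });
/// ```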
pub fn tool_typed<A, H, Fut, R>(
    name: &'static str,
    description: &'static str,
    handler: H,
) -> ToolDef
where
    A: DeserializeOwned + JsonSchema + Send + 'static,
    R: serde::Serialize + Send + 'static,
    H: Fn(A) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<R, BoxError>> + Send + 'static,
{
    let schema = schemars::schema_for!(A);
    let params_value = serde_json::to_value(schema.schema).expect("schema to value");
    let handler_arc_inner = Arc::new(handler);
    let handler_arc: Arc<
        dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync,
    > = Arc::new(move |raw: Value| {
        let h = handler_arc_inner.clone();
        Box::pin(async move {
            let args: A = serde_json::from_value(raw)?;
            let out: R = (h.as_ref())(args).await?;
            let val = serde_json::to_value(out)?;
            Ok(val)
        })
    });
    ToolDef::from_handler(name, description, params_value, handler_arc)
}

/// Simple router service over tools using a name → index table.
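///
/// A construction sketch (the tool values are hypothetical `ToolDef`s):
///
/// ```ignore
/// let (router, specs) = ToolRouter::new(vec![add_tool, lookup_tool]);
/// // `router` implements `Service<ToolInvocation>`; `specs` are the OpenAI tool specs.
/// ```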
#[derive(Clone)]
pub struct ToolRouter {
    name_to_index: std::collections::HashMap<&'static str, usize>,
    services: Vec<ToolSvc>, // index 0 is the unknown-tool fallback
}

impl ToolRouter {
    pub fn new(tools: Vec<ToolDef>) -> (Self, Vec<ChatCompletionTool>) {
        use std::collections::HashMap;

        let unknown = BoxCloneService::new(tower::service_fn(|inv: ToolInvocation| async move {
            Err::<ToolOutput, BoxError>(format!("unknown tool: {}", inv.name).into())
        }));

        let mut services: Vec<ToolSvc> = vec![unknown];
        let mut specs: Vec<ChatCompletionTool> = Vec::with_capacity(tools.len());
        let mut name_to_index: HashMap<&'static str, usize> = HashMap::new();

        for (i, td) in tools.into_iter().enumerate() {
            name_to_index.insert(td.name, i + 1);
            specs.push(td.to_openai_tool());
            services.push(td.service);
        }

        (
            Self {
                name_to_index,
                services,
            },
            specs,
        )
    }
}

impl Service<ToolInvocation> for ToolRouter {
    type Response = ToolOutput;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // Readiness of the selected service is awaited inside `call` (via `oneshot`).
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: ToolInvocation) -> Self::Future {
        let idx = self
            .name_to_index
            .get(req.name.as_str())
            .copied()
            .unwrap_or(0);

        // Safe: index 0 is always present (unknown fallback)
        let svc = self.services[idx].clone();
        // Drive the selected service to readiness, then call it
        Box::pin(svc.oneshot(req))
    }
}

// =============================
// Step service and layer
// =============================

/// Auxiliary accounting captured per step.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub struct StepAux {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub tool_invocations: usize,
}

/// Outcome of a single agent step.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum StepOutcome {
    Next {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
        invoked_tools: Vec<String>,
    },
    Done {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
    },
}

/// One-step agent service parameterized by a routed tool service `S` and a model provider `P`.
pub struct Step<S, P> {
    provider: Arc<tokio::sync::Mutex<P>>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tools: S,
    tool_specs: Arc<Vec<ChatCompletionTool>>, // supplied to requests if missing
    parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    join_policy: ToolJoinPolicy,
}

impl<S, P> Step<S, P> {
    pub fn new(
        provider: P,
        model: impl Into<String>,
        tools: S,
        tool_specs: Vec<ChatCompletionTool>,
    ) -> Self {
        Self {
            provider: Arc::new(tokio::sync::Mutex::new(provider)),
            model: model.into(),
            temperature: None,
            max_tokens: None,
            tools,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    pub fn enable_parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }
}

/// Layer that lifts a routed tool service `S` into a `Step<S, P>` service.
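///
/// A minimal composition sketch (`provider`, `specs`, and `router` are assumed
/// to come from the surrounding setup; `provider` must implement `ModelService`
/// and be `Clone`):
///
/// ```ignore
/// let step = StepLayer::new(provider, "gpt-4o", specs)
///     .temperature(0.2)
///     .parallel_tools(true)
///     .layer(router);
/// ```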
pub struct StepLayer<P> {
    provider: P,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tool_specs: Arc<Vec<ChatCompletionTool>>,
    parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    join_policy: ToolJoinPolicy,
}

impl<P> StepLayer<P> {
    pub fn new(provider: P, model: impl Into<String>, tool_specs: Vec<ChatCompletionTool>) -> Self {
        Self {
            provider,
            model: model.into(),
            temperature: None,
            max_tokens: None,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    pub fn parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }
}

impl<S, P> Layer<S> for StepLayer<P>
where
    P: Clone,
{
    type Service = Step<S, P>;

    fn layer(&self, tools: S) -> Self::Service {
        let mut s = Step::new(
            self.provider.clone(),
            self.model.clone(),
            tools,
            (*self.tool_specs).clone(),
        );
        s.temperature = self.temperature;
        s.max_tokens = self.max_tokens;
        s.parallel_tools = self.parallel_tools;
        s.tool_concurrency_limit = self.tool_concurrency_limit;
        s.join_policy = self.join_policy;
        s
    }
}

impl<S, P> Service<CreateChatCompletionRequest> for Step<S, P>
where
    S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
    P: ModelService + Send + 'static,
    P::Future: Send + 'static,
{
    type Response = StepOutcome;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let _ = cx; // Always ready; we await tools readiness inside `call`
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let provider = self.provider.clone();
        let model = self.model.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let tools = self.tools.clone();
        let tool_specs = self.tool_specs.clone();
        let parallel_tools = self.parallel_tools;
        let tool_concurrency_limit = self.tool_concurrency_limit;
        let join_policy = self.join_policy;

        Box::pin(async move {
            // Rebuild request using builder to avoid deprecated field access
            let effective_model: Option<String> = req.model.clone().into();

            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.messages(req.messages.clone());
            if let Some(m) = effective_model.as_ref() {
                builder.model(m);
            } else {
                builder.model(&model);
            }
            if let Some(t) = req.temperature.or(temperature) {
                builder.temperature(t);
            }
            if let Some(mt) = max_tokens {
                builder.max_tokens(mt);
            }
            if let Some(ts) = req.tools.clone() {
                builder.tools(ts);
            } else if !tool_specs.is_empty() {
                builder.tools((*tool_specs).clone());
            }

            let rebuilt_req = builder
                .build()
                .map_err(|e| format!("request build error: {}", e))?;

            let mut messages = rebuilt_req.messages.clone();

            // Single provider call (e.g. OpenAI) for this step
            let mut p = provider.lock().await;
            let ProviderResponse {
                assistant,
                prompt_tokens,
                completion_tokens,
            } = ServiceExt::ready(&mut *p).await?.call(rebuilt_req).await?;
            let mut aux = StepAux {
                prompt_tokens,
                completion_tokens,
                tool_invocations: 0,
            };

            // Append assistant message by constructing request-side equivalent
            let mut asst_builder = ChatCompletionRequestAssistantMessageArgs::default();
            if let Some(content) = assistant.content.clone() {
                asst_builder.content(content);
            } else {
                asst_builder.content("");
            }
            if let Some(tool_calls) = assistant.tool_calls.clone() {
                asst_builder.tool_calls(tool_calls);
            }
            let asst_req = asst_builder
                .build()
                .map_err(|e| format!("assistant msg build error: {}", e))?;
            messages.push(ChatCompletionRequestMessage::from(asst_req));

            // Execute tool calls if present
            let tool_calls = assistant.tool_calls.unwrap_or_default();
            if tool_calls.is_empty() {
                return Ok(StepOutcome::Done { messages, aux });
            }

            let mut invoked_names: Vec<String> = Vec::with_capacity(tool_calls.len());
            let invocations: Vec<ToolInvocation> = tool_calls
                .into_iter()
                .map(|tc| {
                    let name = tc.function.name;
                    invoked_names.push(name.clone());
                    let args: Value =
                        serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
                    ToolInvocation {
                        id: tc.id,
                        name,
                        arguments: args,
                    }
                })
                .collect();

            if invocations.len() > 1 && parallel_tools {
                // Fire in parallel, preserve order
                let sem = tool_concurrency_limit.map(|n| Arc::new(Semaphore::new(n)));
                match join_policy {
                    ToolJoinPolicy::FailFast => {
                        let futures: Vec<_> = invocations
                            .into_iter()
                            .map(|inv| {
                                let mut svc = tools.clone();
                                let sem_cl = sem.clone();
                                async move {
                                    let _permit = match &sem_cl {
                                        Some(s) => Some(
                                            s.clone().acquire_owned().await.expect("semaphore"),
                                        ),
                                        None => None,
                                    };
                                    let ToolOutput { id, result } =
                                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
                                    Ok::<(String, Value), BoxError>((id, result))
                                }
                            })
                            .collect();
                        let outputs: Vec<(String, Value)> =
                            futures::future::try_join_all(futures).await?;
                        for (id, result) in outputs {
                            aux.tool_invocations += 1;
                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                                .content(result.to_string())
                                .tool_call_id(id)
                                .build()?;
                            messages.push(tool_msg.into());
                        }
                    }
                    ToolJoinPolicy::JoinAll => {
                        let futures: Vec<_> =
                            invocations
                                .into_iter()
                                .enumerate()
                                .map(|(idx, inv)| {
                                    let mut svc = tools.clone();
                                    let sem_cl = sem.clone();
                                    async move {
                                        let _permit = match &sem_cl {
                                            Some(s) => Some(
                                                s.clone().acquire_owned().await.expect("semaphore"),
                                            ),
                                            None => None,
                                        };
                                        let res =
                                            ServiceExt::ready(&mut svc).await?.call(inv).await;
                                        match res {
                                            Ok(ToolOutput { id, result }) => Ok::<
                                                Result<(usize, String, Value), BoxError>,
                                                BoxError,
                                            >(
                                                Ok((idx, id, result)),
                                            ),
                                            Err(e) => Ok(Err(e)),
                                        }
                                    }
                                })
                                .collect();
                        let results = futures::future::join_all(futures).await;
                        let mut successes: Vec<(usize, String, Value)> = Vec::new();
                        let mut errors: Vec<String> = Vec::new();
                        for item in results.into_iter() {
                            match item {
                                Ok(Ok((idx, id, result))) => successes.push((idx, id, result)),
                                Ok(Err(e)) => errors.push(format!("{}", e)),
                                Err(e) => errors.push(format!("{}", e)),
                            }
                        }
                        successes.sort_by_key(|(idx, _, _)| *idx);
                        for (_idx, id, result) in successes.into_iter() {
                            aux.tool_invocations += 1;
                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                                .content(result.to_string())
                                .tool_call_id(id)
                                .build()?;
                            messages.push(tool_msg.into());
                        }
                        if !errors.is_empty() {
                            return Err(
                                format!("one or more tools failed: {}", errors.join("; ")).into()
                            );
                        }
                    }
                }
            } else {
                // Sequential
                for inv in invocations {
                    let mut svc = tools.clone();
                    let ToolOutput { id, result } =
                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
                    aux.tool_invocations += 1;
                    let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                        .content(result.to_string())
                        .tool_call_id(id)
                        .build()?;
                    messages.push(tool_msg.into());
                }
            }

            Ok(StepOutcome::Next {
                messages,
                aux,
                invoked_tools: invoked_names,
            })
        })
    }
}

// =============================
// Convenience helpers for examples/tests
// =============================

/// Build a simple chat request from plain strings.
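///
/// For example:
///
/// ```ignore
/// let req = simple_chat_request("You are terse.", "Summarize Tower in one sentence.");
/// ```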
pub fn simple_chat_request(system: &str, user: &str) -> CreateChatCompletionRequest {
    let sys = ChatCompletionRequestSystemMessageArgs::default()
        .content(system)
        .build()
        .expect("system msg");
    let usr = ChatCompletionRequestUserMessageArgs::default()
        .content(user)
        .build()
        .expect("user msg");
    CreateChatCompletionRequestArgs::default()
        .model("gpt-4o")
        .messages(vec![sys.into(), usr.into()])
        .build()
        .expect("chat req")
}

// =============================
// Agent loop: composable policies and layer
// =============================

/// Stop reasons reported by the agent loop.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentStopReason {
    DoneNoToolCalls,
    MaxSteps,
    ToolCalled(String),
    TokensBudgetExceeded,
    ToolBudgetExceeded,
    TimeBudgetExceeded,
}

// =============================
// DX sugar: Policy builder, Agent builder, run helpers
// =============================

/// Chainable policy builder.
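///
/// A short sketch of composing a stop policy:
///
/// ```ignore
/// let policy = Policy::new()
///     .until_no_tool_calls()
///     .or_max_steps(8)
///     .build();
/// ```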
#[derive(Default, Clone)]
pub struct Policy {
    inner: CompositePolicy,
}

#[allow(dead_code)]
impl Policy {
    pub fn new() -> Self {
        Self {
            inner: CompositePolicy::default(),
        }
    }
    pub fn until_no_tool_calls(mut self) -> Self {
        self.inner.policies.push(policies::until_no_tool_calls());
        self
    }
    pub fn or_tool(mut self, name: impl Into<String>) -> Self {
        self.inner.policies.push(policies::until_tool_called(name));
        self
    }
    pub fn or_max_steps(mut self, max: usize) -> Self {
        self.inner.policies.push(policies::max_steps(max));
        self
    }
    pub fn build(self) -> CompositePolicy {
        self.inner
    }
}

/// Boxed agent service type for ergonomic returns.
pub type AgentSvc = BoxService<CreateChatCompletionRequest, AgentRun, BoxError>;

/// Thin facade to build an agent stack from tools, model, and policy.
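///
/// A minimal build sketch (`add` is a `ToolDef`, e.g. from `tool_typed` above;
/// the client construction assumes `OPENAI_API_KEY` is set in the environment):
///
/// ```ignore
/// let client = Arc::new(Client::new());
/// let mut agent = Agent::builder(client)
///     .model("gpt-4o-mini")
///     .tool(add)
///     .policy(Policy::new().until_no_tool_calls().or_max_steps(8).build())
///     .build();
/// ```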
pub struct Agent;

pub struct AgentBuilder {
    client: Arc<Client<OpenAIConfig>>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tools: Vec<ToolDef>,
    policy: CompositePolicy,
    handoff: Option<crate::groups::AnyHandoffPolicy>,
    provider: Option<
        tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        >,
    >,
    enable_parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    tool_join_policy: ToolJoinPolicy,
    agent_service_map: Option<Arc<dyn Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static>>, // optional final wrapper
}

impl Agent {
    pub fn builder(client: Arc<Client<OpenAIConfig>>) -> AgentBuilder {
        AgentBuilder {
            client,
            model: "gpt-4o".to_string(),
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            policy: CompositePolicy::default(),
            handoff: None,
            provider: None,
            enable_parallel_tools: false,
            tool_concurrency_limit: None,
            tool_join_policy: ToolJoinPolicy::FailFast,
            agent_service_map: None,
        }
    }
}

impl AgentBuilder {
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }
    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }
    pub fn tool(mut self, tool: ToolDef) -> Self {
        self.tools.push(tool);
        self
    }
    pub fn tools(mut self, tools: Vec<ToolDef>) -> Self {
        self.tools.extend(tools);
        self
    }
    pub fn policy(mut self, policy: CompositePolicy) -> Self {
        self.policy = policy;
        self
    }

    /// Enable handoff-aware tool interception and advertise handoff tools
    pub fn handoff_policy(mut self, policy: crate::groups::AnyHandoffPolicy) -> Self {
        self.handoff = Some(policy);
        self
    }

    /// Override the non-streaming provider (useful for testing with a fixed/mocked model)
    pub fn with_provider<P>(mut self, provider: P) -> Self
    where
        P: crate::provider::ModelService + Clone + Send + 'static,
        P::Future: Send + 'static,
    {
        self.provider = Some(tower::util::BoxCloneService::new(provider));
        self
    }

    /// Enable or disable parallel tool execution within a step
    pub fn parallel_tools(mut self, enabled: bool) -> Self {
        self.enable_parallel_tools = enabled;
        self
    }

    /// Set an optional concurrency limit for parallel tool execution
    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    /// Configure how parallel tool errors are handled (fail fast or join all)
    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.tool_join_policy = policy;
        self
    }

    /// Optional: wrap the final built agent service with a custom function.
    /// This enables applying Tower layers at the agent boundary.
    pub fn map_agent_service<F>(mut self, f: F) -> Self
    where
        F: Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static,
    {
        self.agent_service_map = Some(Arc::new(f));
        self
    }

    pub fn build(self) -> AgentSvc {
        let (router, mut specs) = ToolRouter::new(self.tools);
        // If handoff policy provided, wrap router and extend tool specs
        let routed: ToolSvc = if let Some(policy) = &self.handoff {
            let hand_spec = policy.handoff_tools();
            if !hand_spec.is_empty() {
                specs.extend(hand_spec);
            }
            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
        } else {
            // No handoff layer; clonable box of the router
            BoxCloneService::new(router)
        };

        let base_provider: tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        > = if let Some(p) = self.provider {
            p
        } else {
            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
        };
        let mut step_layer = StepLayer::new(base_provider, self.model, specs)
            .temperature(self.temperature.unwrap_or(0.0))
            .max_tokens(self.max_tokens.unwrap_or(512))
            .parallel_tools(self.enable_parallel_tools)
            .tool_join_policy(self.tool_join_policy);
        if let Some(lim) = self.tool_concurrency_limit {
            step_layer = step_layer.tool_concurrency_limit(lim);
        }
        let step = step_layer.layer(routed);
        let agent = AgentLoopLayer::new(self.policy).layer(step);
        let boxed = BoxService::new(agent);
        match &self.agent_service_map {
            Some(map) => (map)(boxed),
            None => boxed,
        }
    }

    /// Build an agent service wrapped with session memory persistence
    pub fn build_with_session<Ls, Ss>(
        self,
        load: Arc<Ls>,
        save: Arc<Ss>,
        session_id: crate::sessions::SessionId,
    ) -> AgentSvc
    where
        Ls: Service<
                crate::sessions::LoadSession,
                Response = crate::sessions::History,
                Error = BoxError,
            > + Send
            + Sync
            + Clone
            + 'static,
        Ls::Future: Send + 'static,
        Ss: Service<crate::sessions::SaveSession, Response = (), Error = BoxError>
            + Send
            + Sync
            + Clone
            + 'static,
        Ss::Future: Send + 'static,
    {
        let (router, mut specs) = ToolRouter::new(self.tools);
        let routed: ToolSvc = if let Some(policy) = &self.handoff {
            let hand_spec = policy.handoff_tools();
            if !hand_spec.is_empty() {
                specs.extend(hand_spec);
            }
            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
        } else {
            BoxCloneService::new(router)
        };

        let base_provider: tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        > = if let Some(p) = self.provider {
            p
        } else {
            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
        };
        let mut step_layer = StepLayer::new(base_provider, self.model, specs)
            .temperature(self.temperature.unwrap_or(0.0))
            .max_tokens(self.max_tokens.unwrap_or(512))
            .parallel_tools(self.enable_parallel_tools)
            .tool_join_policy(self.tool_join_policy);
        if let Some(lim) = self.tool_concurrency_limit {
            step_layer = step_layer.tool_concurrency_limit(lim);
        }
        let step = step_layer.layer(routed);

        // Attach memory layer
        let mem_layer = crate::sessions::MemoryLayer::new(load, save, session_id);
        let step_with_mem = mem_layer.layer(step);
        let agent = AgentLoopLayer::new(self.policy).layer(step_with_mem);
        let boxed = BoxService::new(agent);
        match &self.agent_service_map {
            Some(map) => (map)(boxed),
            None => boxed,
        }
    }
}

/// Convenience: run a prompt through an agent service.
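///
/// A usage sketch, assuming an `agent` built via `Agent::builder`:
///
/// ```ignore
/// let result = run(&mut agent, "You are helpful.", "What is 2 + 2?").await?;
/// println!("stopped after {} steps: {:?}", result.steps, result.stop);
/// ```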
pub async fn run(agent: &mut AgentSvc, system: &str, user: &str) -> Result<AgentRun, BoxError> {
    let req = simple_chat_request(system, user);
    let resp = ServiceExt::ready(agent).await?.call(req).await?;
    Ok(resp)
}

/// Loop state visible to policies.
#[derive(Debug, Clone, Default)]
pub struct LoopState {
    pub steps: usize,
}

/// Policy interface controlling loop termination.
pub trait AgentPolicy: Send + Sync {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason>;
}

/// Function-backed policy for ergonomic composition.
#[derive(Clone)]
#[allow(clippy::type_complexity)]
pub struct PolicyFn(
    pub Arc<dyn Fn(&LoopState, &StepOutcome) -> Option<AgentStopReason> + Send + Sync + 'static>,
);

impl AgentPolicy for PolicyFn {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
        (self.0)(state, last)
    }
}

/// Composite policy: stop when any sub-policy returns a stop reason.
#[derive(Clone, Default)]
pub struct CompositePolicy {
    policies: Vec<PolicyFn>,
}

#[allow(dead_code)]
impl CompositePolicy {
    pub fn new(policies: Vec<PolicyFn>) -> Self {
        Self { policies }
    }
    pub fn push(&mut self, p: PolicyFn) {
        self.policies.push(p);
    }
}

impl AgentPolicy for CompositePolicy {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
        for p in &self.policies {
            if let Some(r) = p.decide(state, last) {
                return Some(r);
            }
        }
        None
    }
}

/// Built-in policies
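///
/// For example, a composite that stops when the model returns no tool calls
/// or after ten steps:
///
/// ```ignore
/// let policy = CompositePolicy::new(vec![
///     policies::until_no_tool_calls(),
///     policies::max_steps(10),
/// ]);
/// ```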
#[allow(dead_code)]
pub mod policies {
    use super::*;

    pub fn until_no_tool_calls() -> PolicyFn {
        PolicyFn(Arc::new(|_s, last| match last {
            StepOutcome::Done { .. } => Some(AgentStopReason::DoneNoToolCalls),
            _ => None,
        }))
    }

    pub fn until_tool_called(tool_name: impl Into<String>) -> PolicyFn {
        let target = tool_name.into();
        PolicyFn(Arc::new(move |_s, last| match last {
            StepOutcome::Next { invoked_tools, .. } => {
                if invoked_tools.iter().any(|n| n == &target) {
                    Some(AgentStopReason::ToolCalled(target.clone()))
                } else {
                    None
                }
            }
            _ => None,
        }))
    }

    pub fn max_steps(max: usize) -> PolicyFn {
        PolicyFn(Arc::new(move |s, _| {
            if s.steps >= max {
                Some(AgentStopReason::MaxSteps)
            } else {
                None
            }
        }))
    }
}

/// Final run summary from the agent loop.
#[derive(Debug, Clone)]
pub struct AgentRun {
    pub messages: Vec<ChatCompletionRequestMessage>,
    pub steps: usize,
    pub stop: AgentStopReason,
}

/// Layer to wrap a step service with an agent loop controlled by a policy.
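///
/// A composition sketch (`step` is a `Step` service built via `StepLayer`):
///
/// ```ignore
/// let agent = AgentLoopLayer::new(Policy::new().until_no_tool_calls().build()).layer(step);
/// ```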
pub struct AgentLoopLayer<P> {
    policy: P,
}

impl<P> AgentLoopLayer<P> {
    pub fn new(policy: P) -> Self {
        Self { policy }
    }
}

pub struct AgentLoop<S, P> {
    inner: Arc<tokio::sync::Mutex<S>>,
    policy: P,
}

impl<S, P> Layer<S> for AgentLoopLayer<P>
where
    P: Clone,
{
    type Service = AgentLoop<S, P>;
    fn layer(&self, inner: S) -> Self::Service {
        AgentLoop {
            inner: Arc::new(tokio::sync::Mutex::new(inner)),
            policy: self.policy.clone(),
        }
    }
}

impl<S, P> Service<CreateChatCompletionRequest> for AgentLoop<S, P>
where
    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
        + Send
        + 'static,
    S::Future: Send + 'static,
    P: AgentPolicy + Send + Sync + Clone + 'static,
{
    type Response = AgentRun;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let inner = self.inner.clone();
        let policy = self.policy.clone();
        Box::pin(async move {
            let mut state = LoopState::default();
            let base_model = req.model.clone();
            let mut current_messages = req.messages.clone();
            loop {
                // Rebuild request for this iteration
                let mut builder = CreateChatCompletionRequestArgs::default();
                builder.model(&base_model);
                builder.messages(current_messages.clone());
                let current_req = builder
                    .build()
                    .map_err(|e| format!("build req error: {}", e))?;

                let mut guard = inner.lock().await;
                let outcome = guard.ready().await?.call(current_req).await?;
                drop(guard);

                state.steps += 1;

                if let Some(stop) = policy.decide(&state, &outcome) {
                    let messages = match outcome {
                        StepOutcome::Next { messages, .. } => messages,
                        StepOutcome::Done { messages, .. } => messages,
                    };
                    return Ok(AgentRun {
                        messages,
                        steps: state.steps,
                        stop,
                    });
                }

                match outcome {
                    StepOutcome::Next { messages, .. } => {
                        current_messages = messages;
                    }
                    StepOutcome::Done { messages, .. } => {
                        return Ok(AgentRun {
                            messages,
                            steps: state.steps,
                            stop: AgentStopReason::DoneNoToolCalls,
                        });
                    }
                }
            }
        })
    }
}