use std::{future::Future, pin::Pin, sync::Arc};

use crate::groups::HandoffPolicy;
use crate::provider::{ModelService, OpenAIProvider, ProviderResponse};
use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
        ChatCompletionToolType, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
        FunctionObjectArgs,
    },
    Client,
};
use futures::future::BoxFuture;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::Semaphore;
use tower::{
    util::{BoxCloneService, BoxService},
    BoxError, Layer, Service, ServiceExt,
};

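/// Controls how a step awaits multiple tool calls issued in one assistant turn.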
#[derive(Debug, Clone, Copy, Default)]
pub enum ToolJoinPolicy {
    /// Abort the step on the first tool error (default).
    #[default]
    FailFast,
    /// Run every tool to completion, then report any errors together.
    JoinAll,
}

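/// A single tool call requested by the model, routed by `name`.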
#[derive(Debug, Clone)]
pub struct ToolInvocation {
    pub id: String,
    pub name: String,
    pub arguments: Value,
}

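/// The result of a tool invocation, keyed by the originating tool-call id.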
#[derive(Debug, Clone)]
pub struct ToolOutput {
    pub id: String,
    pub result: Value,
}

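/// Boxed, cloneable tower service that executes a tool invocation.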
pub type ToolSvc = BoxCloneService<ToolInvocation, ToolOutput, BoxError>;

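/// A tool definition: the OpenAI-facing metadata plus the service that executes it.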
pub struct ToolDef {
    pub name: &'static str,
    pub description: &'static str,
    pub parameters_schema: Value,
    pub service: ToolSvc,
}

impl ToolDef {
    /// Wraps a type-erased async handler in a routed tool service.
    pub fn from_handler(
        name: &'static str,
        description: &'static str,
        parameters_schema: Value,
        handler: std::sync::Arc<
            dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync + 'static,
        >,
    ) -> Self {
        let handler_arc = handler.clone();
        let svc = tower::service_fn(move |inv: ToolInvocation| {
            let handler = handler_arc.clone();
            async move {
                if inv.name != name {
                    return Err::<ToolOutput, BoxError>(
                        format!("routed to wrong tool: expected={}, got={}", name, inv.name).into(),
                    );
                }
                let out = (handler)(inv.arguments).await?;
                Ok(ToolOutput {
                    id: inv.id,
                    result: out,
                })
            }
        });
        Self {
            name,
            description,
            parameters_schema,
            service: BoxCloneService::new(svc),
        }
    }

    /// Converts this definition into the OpenAI function-tool spec.
    pub fn to_openai_tool(&self) -> ChatCompletionTool {
        let func = FunctionObjectArgs::default()
            .name(self.name)
            .description(self.description)
            .parameters(self.parameters_schema.clone())
            .build()
            .expect("valid function object");
        ChatCompletionToolArgs::default()
            .r#type(ChatCompletionToolType::Function)
            .function(func)
            .build()
            .expect("valid chat tool")
    }
}

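/// Builds a [`ToolDef`] from a strongly typed async handler.
///
/// The JSON schema for the parameters is derived from `A` via `schemars`, the raw
/// JSON arguments are deserialized before the handler runs, and the handler's
/// return value is serialized back to JSON for the model.
///
/// # Example (sketch; `EchoArgs` is a hypothetical argument type)
///
/// ```ignore
/// #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
/// struct EchoArgs {
///     text: String,
/// }
///
/// let tool = tool_typed("echo", "Echo the input text", |args: EchoArgs| async move {
///     Ok::<_, tower::BoxError>(serde_json::json!({ "echoed": args.text }))
/// });
/// ```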
pub fn tool_typed<A, H, Fut, R>(
    name: &'static str,
    description: &'static str,
    handler: H,
) -> ToolDef
where
    A: DeserializeOwned + JsonSchema + Send + 'static,
    R: serde::Serialize + Send + 'static,
    H: Fn(A) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<R, BoxError>> + Send + 'static,
{
    let schema = schemars::schema_for!(A);
    let params_value = serde_json::to_value(schema.schema).expect("schema to value");
    let handler_arc_inner = Arc::new(handler);
    let handler_arc: Arc<
        dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync,
    > = Arc::new(move |raw: Value| {
        let h = handler_arc_inner.clone();
        Box::pin(async move {
            let args: A = serde_json::from_value(raw)?;
            let out: R = (h.as_ref())(args).await?;
            let val = serde_json::to_value(out)?;
            Ok(val)
        })
    });
    ToolDef::from_handler(name, description, params_value, handler_arc)
}

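/// Routes a [`ToolInvocation`] to the registered tool service by name.
///
/// Index 0 holds a fallback service that rejects unknown tool names.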
#[derive(Clone)]
pub struct ToolRouter {
    name_to_index: std::collections::HashMap<&'static str, usize>,
    services: Vec<ToolSvc>,
}

impl ToolRouter {
    pub fn new(tools: Vec<ToolDef>) -> (Self, Vec<ChatCompletionTool>) {
        use std::collections::HashMap;

        // Index 0 is the fallback for tool names that were never registered.
        let unknown = BoxCloneService::new(tower::service_fn(|inv: ToolInvocation| async move {
            Err::<ToolOutput, BoxError>(format!("unknown tool: {}", inv.name).into())
        }));

        let mut services: Vec<ToolSvc> = vec![unknown];
        let mut specs: Vec<ChatCompletionTool> = Vec::with_capacity(tools.len());
        let mut name_to_index: HashMap<&'static str, usize> = HashMap::new();

        for (i, td) in tools.into_iter().enumerate() {
            name_to_index.insert(td.name, i + 1);
            specs.push(td.to_openai_tool());
            services.push(td.service);
        }

        (
            Self {
                name_to_index,
                services,
            },
            specs,
        )
    }
}

impl Service<ToolInvocation> for ToolRouter {
    type Response = ToolOutput;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: ToolInvocation) -> Self::Future {
        // Unknown names fall through to the rejecting service at index 0.
        let idx = self
            .name_to_index
            .get(req.name.as_str())
            .copied()
            .unwrap_or(0);

        let svc: &mut ToolSvc = &mut self.services[idx];
        let fut = svc.call(req);
        Box::pin(fut)
    }
}

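/// Token and tool-call counters accumulated during a single step.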
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub struct StepAux {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub tool_invocations: usize,
}

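/// Result of one model turn: either more tool output to feed back, or a final answer.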
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum StepOutcome {
    Next {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
        invoked_tools: Vec<String>,
    },
    Done {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
    },
}

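/// One agent step: send the conversation to the model, then execute any
/// requested tool calls and append their outputs to the transcript.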
pub struct Step<S, P> {
    provider: Arc<tokio::sync::Mutex<P>>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tools: S,
    tool_specs: Arc<Vec<ChatCompletionTool>>,
    parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    join_policy: ToolJoinPolicy,
}

impl<S, P> Step<S, P> {
    pub fn new(
        provider: P,
        model: impl Into<String>,
        tools: S,
        tool_specs: Vec<ChatCompletionTool>,
    ) -> Self {
        Self {
            provider: Arc::new(tokio::sync::Mutex::new(provider)),
            model: model.into(),
            temperature: None,
            max_tokens: None,
            tools,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    pub fn enable_parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }
}

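/// Tower [`Layer`] that wraps a tool service in a [`Step`] configured with the
/// given provider, model, and tool specs.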
pub struct StepLayer<P> {
    provider: P,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tool_specs: Arc<Vec<ChatCompletionTool>>,
    parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    join_policy: ToolJoinPolicy,
}

impl<P> StepLayer<P> {
    pub fn new(provider: P, model: impl Into<String>, tool_specs: Vec<ChatCompletionTool>) -> Self {
        Self {
            provider,
            model: model.into(),
            temperature: None,
            max_tokens: None,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    pub fn parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }
}

impl<S, P> Layer<S> for StepLayer<P>
where
    P: Clone,
{
    type Service = Step<S, P>;

    fn layer(&self, tools: S) -> Self::Service {
        let mut s = Step::new(
            self.provider.clone(),
            self.model.clone(),
            tools,
            (*self.tool_specs).clone(),
        );
        s.temperature = self.temperature;
        s.max_tokens = self.max_tokens;
        s.parallel_tools = self.parallel_tools;
        s.tool_concurrency_limit = self.tool_concurrency_limit;
        s.join_policy = self.join_policy;
        s
    }
}

impl<S, P> Service<CreateChatCompletionRequest> for Step<S, P>
where
    S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
    P: ModelService + Send + 'static,
    P::Future: Send + 'static,
{
    type Response = StepOutcome;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let _ = cx;
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let provider = self.provider.clone();
        let model = self.model.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let tools = self.tools.clone();
        let tool_specs = self.tool_specs.clone();
        let parallel_tools = self.parallel_tools;
        let _tool_concurrency_limit = self.tool_concurrency_limit;
        let join_policy = self.join_policy;

        Box::pin(async move {
            let effective_model: Option<String> = req.model.clone().into();

            // Rebuild the request, filling in defaults from the step config
            // wherever the incoming request left them unset.
            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.messages(req.messages.clone());
            if let Some(m) = effective_model.as_ref() {
                builder.model(m);
            } else {
                builder.model(&model);
            }
            if let Some(t) = req.temperature.or(temperature) {
                builder.temperature(t);
            }
            if let Some(mt) = max_tokens {
                builder.max_tokens(mt);
            }
            if let Some(ts) = req.tools.clone() {
                builder.tools(ts);
            } else if !tool_specs.is_empty() {
                builder.tools((*tool_specs).clone());
            }

            let rebuilt_req = builder
                .build()
                .map_err(|e| format!("request build error: {}", e))?;

            let mut messages = rebuilt_req.messages.clone();

            // Call the model through the shared provider service.
            let mut p = provider.lock().await;
            let ProviderResponse {
                assistant,
                prompt_tokens,
                completion_tokens,
            } = ServiceExt::ready(&mut *p).await?.call(rebuilt_req).await?;
            let mut aux = StepAux {
                prompt_tokens,
                completion_tokens,
                tool_invocations: 0,
            };

            // Echo the assistant turn (content and any tool calls) into the transcript.
            let mut asst_builder = ChatCompletionRequestAssistantMessageArgs::default();
            if let Some(content) = assistant.content.clone() {
                asst_builder.content(content);
            } else {
                asst_builder.content("");
            }
            if let Some(tool_calls) = assistant.tool_calls.clone() {
                asst_builder.tool_calls(tool_calls);
            }
            let asst_req = asst_builder
                .build()
                .map_err(|e| format!("assistant msg build error: {}", e))?;
            messages.push(ChatCompletionRequestMessage::from(asst_req));

            let tool_calls = assistant.tool_calls.unwrap_or_default();
            if tool_calls.is_empty() {
                return Ok(StepOutcome::Done { messages, aux });
            }

            let mut invoked_names: Vec<String> = Vec::with_capacity(tool_calls.len());
            let invocations: Vec<ToolInvocation> = tool_calls
                .into_iter()
                .map(|tc| {
                    let name = tc.function.name;
                    invoked_names.push(name.clone());
                    let args: Value =
                        serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
                    ToolInvocation {
                        id: tc.id,
                        name,
                        arguments: args,
                    }
                })
                .collect();

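            // Dispatch the tool calls. With parallel execution enabled and more than
            // one call, fan out concurrently (optionally bounded by a semaphore); the
            // join policy decides whether the first error aborts the step (FailFast)
            // or every tool runs to completion before errors are reported (JoinAll).
            // Otherwise run the calls sequentially.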
            if invocations.len() > 1 && parallel_tools {
                let sem = _tool_concurrency_limit.map(|n| Arc::new(Semaphore::new(n)));
                match join_policy {
                    ToolJoinPolicy::FailFast => {
                        let futures: Vec<_> = invocations
                            .into_iter()
                            .map(|inv| {
                                let mut svc = tools.clone();
                                let sem_cl = sem.clone();
                                async move {
                                    let _permit = match &sem_cl {
                                        Some(s) => Some(
                                            s.clone().acquire_owned().await.expect("semaphore"),
                                        ),
                                        None => None,
                                    };
                                    let ToolOutput { id, result } =
                                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
                                    Ok::<(String, Value), BoxError>((id, result))
                                }
                            })
                            .collect();
                        let outputs: Vec<(String, Value)> =
                            futures::future::try_join_all(futures).await?;
                        for (id, result) in outputs {
                            aux.tool_invocations += 1;
                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                                .content(result.to_string())
                                .tool_call_id(id)
                                .build()?;
                            messages.push(tool_msg.into());
                        }
                    }
                    ToolJoinPolicy::JoinAll => {
                        let futures: Vec<_> = invocations
                            .into_iter()
                            .enumerate()
                            .map(|(idx, inv)| {
                                let mut svc = tools.clone();
                                let sem_cl = sem.clone();
                                async move {
                                    let _permit = match &sem_cl {
                                        Some(s) => Some(
                                            s.clone().acquire_owned().await.expect("semaphore"),
                                        ),
                                        None => None,
                                    };
                                    // Capture per-call errors instead of short-circuiting,
                                    // so every tool gets a chance to run.
                                    let res =
                                        ServiceExt::ready(&mut svc).await?.call(inv).await;
                                    match res {
                                        Ok(ToolOutput { id, result }) => {
                                            Ok::<Result<(usize, String, Value), BoxError>, BoxError>(
                                                Ok((idx, id, result)),
                                            )
                                        }
                                        Err(e) => Ok(Err(e)),
                                    }
                                }
                            })
                            .collect();
                        let results = futures::future::join_all(futures).await;
                        let mut successes: Vec<(usize, String, Value)> = Vec::new();
                        let mut errors: Vec<String> = Vec::new();
                        for item in results.into_iter() {
                            match item {
                                Ok(Ok((idx, id, result))) => successes.push((idx, id, result)),
                                Ok(Err(e)) => errors.push(format!("{}", e)),
                                Err(e) => errors.push(format!("{}", e)),
                            }
                        }
                        // Preserve the original call order when appending tool outputs.
                        successes.sort_by_key(|(idx, _, _)| *idx);
                        for (_idx, id, result) in successes.into_iter() {
                            aux.tool_invocations += 1;
                            let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                                .content(result.to_string())
                                .tool_call_id(id)
                                .build()?;
                            messages.push(tool_msg.into());
                        }
                        if !errors.is_empty() {
                            return Err(
                                format!("one or more tools failed: {}", errors.join("; ")).into()
                            );
                        }
                    }
                }
            } else {
                // Sequential execution: run each tool call in turn.
                for inv in invocations {
                    let mut svc = tools.clone();
                    let ToolOutput { id, result } =
                        ServiceExt::ready(&mut svc).await?.call(inv).await?;
                    aux.tool_invocations += 1;
                    let tool_msg = ChatCompletionRequestToolMessageArgs::default()
                        .content(result.to_string())
                        .tool_call_id(id)
                        .build()?;
                    messages.push(tool_msg.into());
                }
            }

            Ok(StepOutcome::Next {
                messages,
                aux,
                invoked_tools: invoked_names,
            })
        })
    }
}

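/// Builds a minimal chat request: one system message, one user message, model "gpt-4o".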
pub fn simple_chat_request(system: &str, user: &str) -> CreateChatCompletionRequest {
    let sys = ChatCompletionRequestSystemMessageArgs::default()
        .content(system)
        .build()
        .expect("system msg");
    let usr = ChatCompletionRequestUserMessageArgs::default()
        .content(user)
        .build()
        .expect("user msg");
    CreateChatCompletionRequestArgs::default()
        .model("gpt-4o")
        .messages(vec![sys.into(), usr.into()])
        .build()
        .expect("chat req")
}

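/// Why the agent loop stopped.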
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentStopReason {
    DoneNoToolCalls,
    MaxSteps,
    ToolCalled(String),
    TokensBudgetExceeded,
    ToolBudgetExceeded,
    TimeBudgetExceeded,
}

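/// Fluent builder over [`CompositePolicy`] for common stop conditions.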
#[derive(Default, Clone)]
pub struct Policy {
    inner: CompositePolicy,
}

#[allow(dead_code)]
impl Policy {
    pub fn new() -> Self {
        Self {
            inner: CompositePolicy::default(),
        }
    }

    pub fn until_no_tool_calls(mut self) -> Self {
        self.inner.policies.push(policies::until_no_tool_calls());
        self
    }

    pub fn or_tool(mut self, name: impl Into<String>) -> Self {
        self.inner.policies.push(policies::until_tool_called(name));
        self
    }

    pub fn or_max_steps(mut self, max: usize) -> Self {
        self.inner.policies.push(policies::max_steps(max));
        self
    }

    pub fn build(self) -> CompositePolicy {
        self.inner
    }
}

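/// Boxed agent service: takes an initial chat request and runs the loop to completion.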
pub type AgentSvc = BoxService<CreateChatCompletionRequest, AgentRun, BoxError>;

/// Entry point for constructing agents; see [`Agent::builder`].
pub struct Agent;

/// Builder for an [`AgentSvc`], wiring tools, provider, policies, and tool
/// execution options together.
pub struct AgentBuilder {
    client: Arc<Client<OpenAIConfig>>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    tools: Vec<ToolDef>,
    policy: CompositePolicy,
    handoff: Option<crate::groups::AnyHandoffPolicy>,
    provider: Option<
        tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        >,
    >,
    enable_parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    tool_join_policy: ToolJoinPolicy,
    agent_service_map: Option<Arc<dyn Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static>>,
}

impl Agent {
    pub fn builder(client: Arc<Client<OpenAIConfig>>) -> AgentBuilder {
        AgentBuilder {
            client,
            model: "gpt-4o".to_string(),
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            policy: CompositePolicy::default(),
            handoff: None,
            provider: None,
            enable_parallel_tools: false,
            tool_concurrency_limit: None,
            tool_join_policy: ToolJoinPolicy::FailFast,
            agent_service_map: None,
        }
    }
}

impl AgentBuilder {
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    pub fn tool(mut self, tool: ToolDef) -> Self {
        self.tools.push(tool);
        self
    }

    pub fn tools(mut self, tools: Vec<ToolDef>) -> Self {
        self.tools.extend(tools);
        self
    }

    pub fn policy(mut self, policy: CompositePolicy) -> Self {
        self.policy = policy;
        self
    }

    pub fn handoff_policy(mut self, policy: crate::groups::AnyHandoffPolicy) -> Self {
        self.handoff = Some(policy);
        self
    }

    /// Swap in a custom model provider instead of the default OpenAI client.
    pub fn with_provider<P>(mut self, provider: P) -> Self
    where
        P: crate::provider::ModelService + Clone + Send + 'static,
        P::Future: Send + 'static,
    {
        self.provider = Some(tower::util::BoxCloneService::new(provider));
        self
    }

    pub fn parallel_tools(mut self, enabled: bool) -> Self {
        self.enable_parallel_tools = enabled;
        self
    }

    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.tool_join_policy = policy;
        self
    }

    /// Wrap the finished agent service with an arbitrary mapping (e.g. extra layers).
    pub fn map_agent_service<F>(mut self, f: F) -> Self
    where
        F: Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static,
    {
        self.agent_service_map = Some(Arc::new(f));
        self
    }

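    /// Assemble the agent: tool router (plus any handoff tools), provider,
    /// step layer, and the policy-driven loop.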
    pub fn build(self) -> AgentSvc {
        let (router, mut specs) = ToolRouter::new(self.tools);
        let routed: ToolSvc = if let Some(policy) = &self.handoff {
            // Expose the handoff tools to the model alongside the regular tools.
            let hand_spec = policy.handoff_tools();
            if !hand_spec.is_empty() {
                specs.extend(hand_spec);
            }
            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
        } else {
            BoxCloneService::new(router)
        };

        let base_provider: tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        > = if let Some(p) = self.provider {
            p
        } else {
            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
        };
        let mut step_layer = StepLayer::new(base_provider, self.model, specs)
            .temperature(self.temperature.unwrap_or(0.0))
            .max_tokens(self.max_tokens.unwrap_or(512))
            .parallel_tools(self.enable_parallel_tools)
            .tool_join_policy(self.tool_join_policy);
        if let Some(lim) = self.tool_concurrency_limit {
            step_layer = step_layer.tool_concurrency_limit(lim);
        }
        let step = step_layer.layer(routed);
        let agent = AgentLoopLayer::new(self.policy).layer(step);
        let boxed = BoxService::new(agent);
        match &self.agent_service_map {
            Some(map) => (map)(boxed),
            None => boxed,
        }
    }

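    /// Like [`AgentBuilder::build`], but wraps the step service in a session-memory
    /// layer backed by the given load/save services and session id.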
    pub fn build_with_session<Ls, Ss>(
        self,
        load: Arc<Ls>,
        save: Arc<Ss>,
        session_id: crate::sessions::SessionId,
    ) -> AgentSvc
    where
        Ls: Service<
                crate::sessions::LoadSession,
                Response = crate::sessions::History,
                Error = BoxError,
            > + Send
            + Sync
            + Clone
            + 'static,
        Ls::Future: Send + 'static,
        Ss: Service<crate::sessions::SaveSession, Response = (), Error = BoxError>
            + Send
            + Sync
            + Clone
            + 'static,
        Ss::Future: Send + 'static,
    {
        let (router, mut specs) = ToolRouter::new(self.tools);
        let routed: ToolSvc = if let Some(policy) = &self.handoff {
            let hand_spec = policy.handoff_tools();
            if !hand_spec.is_empty() {
                specs.extend(hand_spec);
            }
            crate::groups::layer_tool_router_with_handoff(router, policy.clone())
        } else {
            BoxCloneService::new(router)
        };

        let base_provider: tower::util::BoxCloneService<
            CreateChatCompletionRequest,
            crate::provider::ProviderResponse,
            BoxError,
        > = if let Some(p) = self.provider {
            p
        } else {
            tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
        };
        let mut step_layer = StepLayer::new(base_provider, self.model, specs)
            .temperature(self.temperature.unwrap_or(0.0))
            .max_tokens(self.max_tokens.unwrap_or(512))
            .parallel_tools(self.enable_parallel_tools)
            .tool_join_policy(self.tool_join_policy);
        if let Some(lim) = self.tool_concurrency_limit {
            step_layer = step_layer.tool_concurrency_limit(lim);
        }
        let step = step_layer.layer(routed);

        let mem_layer = crate::sessions::MemoryLayer::new(load, save, session_id);
        let step_with_mem = mem_layer.layer(step);
        let agent = AgentLoopLayer::new(self.policy).layer(step_with_mem);
        let boxed = BoxService::new(agent);
        match &self.agent_service_map {
            Some(map) => (map)(boxed),
            None => boxed,
        }
    }
}

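/// Convenience helper: build a simple system/user request and drive the agent to completion.
///
/// # Example (sketch; assumes an OpenAI client is already configured)
///
/// ```ignore
/// let client = std::sync::Arc::new(async_openai::Client::new());
/// let mut agent = Agent::builder(client)
///     .model("gpt-4o")
///     .policy(Policy::new().until_no_tool_calls().or_max_steps(4).build())
///     .build();
/// let result = run(&mut agent, "You are a helpful assistant.", "Say hello.").await?;
/// println!("stopped after {} steps: {:?}", result.steps, result.stop);
/// ```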
pub async fn run(agent: &mut AgentSvc, system: &str, user: &str) -> Result<AgentRun, BoxError> {
    let req = simple_chat_request(system, user);
    let resp = ServiceExt::ready(agent).await?.call(req).await?;
    Ok(resp)
}

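/// Mutable loop bookkeeping handed to stop policies.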
#[derive(Debug, Clone, Default)]
pub struct LoopState {
    pub steps: usize,
}

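/// Decides, after each step, whether the agent loop should stop and why.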
pub trait AgentPolicy: Send + Sync {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason>;
}

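/// An [`AgentPolicy`] backed by a plain closure.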
#[derive(Clone)]
#[allow(clippy::type_complexity)]
pub struct PolicyFn(
    pub Arc<dyn Fn(&LoopState, &StepOutcome) -> Option<AgentStopReason> + Send + Sync + 'static>,
);

impl AgentPolicy for PolicyFn {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
        (self.0)(state, last)
    }
}

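/// Runs several policies in order; the first one to return a stop reason wins.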
#[derive(Clone, Default)]
pub struct CompositePolicy {
    policies: Vec<PolicyFn>,
}

#[allow(dead_code)]
impl CompositePolicy {
    pub fn new(policies: Vec<PolicyFn>) -> Self {
        Self { policies }
    }

    pub fn push(&mut self, p: PolicyFn) {
        self.policies.push(p);
    }
}

impl AgentPolicy for CompositePolicy {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
        for p in &self.policies {
            if let Some(r) = p.decide(state, last) {
                return Some(r);
            }
        }
        None
    }
}

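/// Ready-made stop policies for composing an agent loop.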
#[allow(dead_code)]
pub mod policies {
    use super::*;

    /// Stop as soon as a step completes without requesting any tool calls.
    pub fn until_no_tool_calls() -> PolicyFn {
        PolicyFn(Arc::new(|_s, last| match last {
            StepOutcome::Done { .. } => Some(AgentStopReason::DoneNoToolCalls),
            _ => None,
        }))
    }

    /// Stop once the named tool has been invoked.
    pub fn until_tool_called(tool_name: impl Into<String>) -> PolicyFn {
        let target = tool_name.into();
        PolicyFn(Arc::new(move |_s, last| match last {
            StepOutcome::Next { invoked_tools, .. } => {
                if invoked_tools.iter().any(|n| n == &target) {
                    Some(AgentStopReason::ToolCalled(target.clone()))
                } else {
                    None
                }
            }
            _ => None,
        }))
    }

    /// Stop after at most `max` steps.
    pub fn max_steps(max: usize) -> PolicyFn {
        PolicyFn(Arc::new(move |s, _| {
            if s.steps >= max {
                Some(AgentStopReason::MaxSteps)
            } else {
                None
            }
        }))
    }
}

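/// Final output of an agent run: the transcript, step count, and stop reason.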
#[derive(Debug, Clone)]
pub struct AgentRun {
    pub messages: Vec<ChatCompletionRequestMessage>,
    pub steps: usize,
    pub stop: AgentStopReason,
}

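/// Tower [`Layer`] that drives a step service in a loop until a policy stops it.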
pub struct AgentLoopLayer<P> {
    policy: P,
}

impl<P> AgentLoopLayer<P> {
    pub fn new(policy: P) -> Self {
        Self { policy }
    }
}

pub struct AgentLoop<S, P> {
    inner: Arc<tokio::sync::Mutex<S>>,
    policy: P,
}

impl<S, P> Layer<S> for AgentLoopLayer<P>
where
    P: Clone,
{
    type Service = AgentLoop<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        AgentLoop {
            inner: Arc::new(tokio::sync::Mutex::new(inner)),
            policy: self.policy.clone(),
        }
    }
}

impl<S, P> Service<CreateChatCompletionRequest> for AgentLoop<S, P>
where
    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
        + Send
        + 'static,
    S::Future: Send + 'static,
    P: AgentPolicy + Send + Sync + Clone + 'static,
{
    type Response = AgentRun;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let inner = self.inner.clone();
        let policy = self.policy.clone();
        Box::pin(async move {
            let mut state = LoopState::default();
            let base_model = req.model.clone();
            let mut current_messages = req.messages.clone();
            loop {
                // Rebuild the request for this iteration with the accumulated transcript.
                let mut builder = CreateChatCompletionRequestArgs::default();
                builder.model(&base_model);
                builder.messages(current_messages.clone());
                let current_req = builder
                    .build()
                    .map_err(|e| format!("build req error: {}", e))?;

                let mut guard = inner.lock().await;
                let outcome = guard.ready().await?.call(current_req).await?;
                drop(guard);

                state.steps += 1;

                // Give the stop policy first say, regardless of the outcome variant.
                if let Some(stop) = policy.decide(&state, &outcome) {
                    let messages = match outcome {
                        StepOutcome::Next { messages, .. } => messages,
                        StepOutcome::Done { messages, .. } => messages,
                    };
                    return Ok(AgentRun {
                        messages,
                        steps: state.steps,
                        stop,
                    });
                }

                match outcome {
                    StepOutcome::Next { messages, .. } => {
                        current_messages = messages;
                    }
                    StepOutcome::Done { messages, .. } => {
                        return Ok(AgentRun {
                            messages,
                            steps: state.steps,
                            stop: AgentStopReason::DoneNoToolCalls,
                        });
                    }
                }
            }
        })
    }
}