1use std::{future::Future, pin::Pin, sync::Arc};
4
5use crate::groups::HandoffPolicy;
6use crate::provider::{ModelService, OpenAIProvider, ProviderResponse};
7use async_openai::{
8 config::OpenAIConfig,
9 types::{
10 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
11 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
12 ChatCompletionRequestUserMessageArgs, ChatCompletionResponseMessage, ChatCompletionTool,
13 ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
14 CreateChatCompletionRequestArgs, FunctionObjectArgs, ReasoningEffort,
15 },
16 Client,
17};
18use async_trait::async_trait;
19use futures::future::BoxFuture;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23use tokio::sync::Semaphore;
24use tower::{
25 util::{BoxCloneService, BoxService},
26 BoxError, Layer, Service, ServiceExt,
27};
28use tracing::{debug, trace};
29
/// Strategy for combining the results of a batch of parallel tool calls.
#[derive(Debug, Clone, Copy, Default)]
pub enum ToolJoinPolicy {
    /// Abort the whole batch as soon as any tool returns an error.
    #[default]
    FailFast,
    /// Run every tool to completion; report all errors together at the end.
    JoinAll,
}
39
40#[derive(Debug, Clone)]
46pub struct ToolInvocation {
47 pub id: String, pub name: String, pub arguments: Value,
50}
51
52#[derive(Debug, Clone)]
54pub struct ToolOutput {
55 pub id: String, pub result: Value,
57}
58
/// Boxed, cloneable tower service that executes a single tool invocation.
pub type ToolSvc = BoxCloneService<ToolInvocation, ToolOutput, BoxError>;
61
/// A callable tool: its model-facing metadata plus the service that runs it.
pub struct ToolDef {
    /// Tool name advertised to the model and used for routing.
    pub name: &'static str,
    /// Human-readable description sent in the tool spec.
    pub description: &'static str,
    /// JSON Schema describing the tool's arguments.
    pub parameters_schema: Value,
    /// Service invoked when the model calls this tool.
    pub service: ToolSvc,
}
69
70impl ToolDef {
71 pub fn from_handler(
73 name: &'static str,
74 description: &'static str,
75 parameters_schema: Value,
76 handler: std::sync::Arc<
77 dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync + 'static,
78 >,
79 ) -> Self {
80 let handler_arc = handler.clone();
81 let svc = tower::service_fn(move |inv: ToolInvocation| {
82 let handler = handler_arc.clone();
83 async move {
84 if inv.name != name {
85 return Err::<ToolOutput, BoxError>(
86 format!("routed to wrong tool: expected={}, got={}", name, inv.name).into(),
87 );
88 }
89 let out = (handler)(inv.arguments).await?;
90 Ok(ToolOutput {
91 id: inv.id,
92 result: out,
93 })
94 }
95 });
96 Self {
97 name,
98 description,
99 parameters_schema,
100 service: BoxCloneService::new(svc),
101 }
102 }
103
104 pub fn to_openai_tool(&self) -> ChatCompletionTool {
106 let func = FunctionObjectArgs::default()
107 .name(self.name)
108 .description(self.description)
109 .parameters(self.parameters_schema.clone())
110 .build()
111 .expect("valid function object");
112 ChatCompletionToolArgs::default()
113 .r#type(ChatCompletionToolType::Function)
114 .function(func)
115 .build()
116 .expect("valid chat tool")
117 }
118}
119
120pub fn tool_typed<A, H, Fut, R>(
124 name: &'static str,
125 description: &'static str,
126 handler: H,
127) -> ToolDef
128where
129 A: DeserializeOwned + JsonSchema + Send + 'static,
130 R: serde::Serialize + Send + 'static,
131 H: Fn(A) -> Fut + Send + Sync + 'static,
132 Fut: Future<Output = Result<R, BoxError>> + Send + 'static,
133{
134 let schema = schemars::schema_for!(A);
135 let params_value = serde_json::to_value(schema.schema).expect("schema to value");
136 let handler_arc_inner = Arc::new(handler);
137 let handler_arc: Arc<
138 dyn Fn(Value) -> BoxFuture<'static, Result<Value, BoxError>> + Send + Sync,
139 > = Arc::new(move |raw: Value| {
140 let h = handler_arc_inner.clone();
141 Box::pin(async move {
142 let args: A = serde_json::from_value(raw)?;
143 let out: R = (h.as_ref())(args).await?;
144 let val = serde_json::to_value(out)?;
145 Ok(val)
146 })
147 });
148 ToolDef::from_handler(name, description, params_value, handler_arc)
149}
150
151#[derive(Clone)]
153pub struct ToolRouter {
154 name_to_index: std::collections::HashMap<&'static str, usize>,
155 services: Vec<ToolSvc>, }
157
158impl ToolRouter {
159 pub fn new(tools: Vec<ToolDef>) -> (Self, Vec<ChatCompletionTool>) {
160 use std::collections::HashMap;
161
162 let unknown = BoxCloneService::new(tower::service_fn(|inv: ToolInvocation| async move {
163 Err::<ToolOutput, BoxError>(format!("unknown tool: {}", inv.name).into())
164 }));
165
166 let mut services: Vec<ToolSvc> = vec![unknown];
167 let mut specs: Vec<ChatCompletionTool> = Vec::with_capacity(tools.len());
168 let mut name_to_index: HashMap<&'static str, usize> = HashMap::new();
169
170 for (i, td) in tools.into_iter().enumerate() {
171 name_to_index.insert(td.name, i + 1);
172 specs.push(td.to_openai_tool());
173 services.push(td.service);
174 }
175
176 (
177 Self {
178 name_to_index,
179 services,
180 },
181 specs,
182 )
183 }
184}
185
impl Service<ToolInvocation> for ToolRouter {
    type Response = ToolOutput;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // The router itself never blocks; it is always ready.
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: ToolInvocation) -> Self::Future {
        // Unknown names fall back to slot 0, which rejects the invocation
        // with an "unknown tool" error (see `ToolRouter::new`).
        let idx = self
            .name_to_index
            .get(req.name.as_str())
            .copied()
            .unwrap_or(0);

        let svc: &mut ToolSvc = &mut self.services[idx];
        let fut = svc.call(req);
        Box::pin(fut)
    }
}
213
/// Supplies system-prompt instructions at call time.
#[async_trait]
pub trait LLMInstructionProvider: Send + Sync {
    /// Returns the instructions to inject, or `None` to leave the request untouched.
    async fn instructions(&self) -> Option<String>;
}
221
/// Where a step's instructions come from: fixed text or a per-call provider.
#[derive(Clone)]
enum InstructionSource {
    /// Fixed instruction text injected on every call.
    Static(String),
    /// Provider queried on each call; may return `None` to skip injection.
    Dynamic(Arc<dyn LLMInstructionProvider>),
}
227
/// Usage counters accumulated while executing a single step.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub struct StepAux {
    /// Prompt tokens reported by the provider for this step.
    pub prompt_tokens: usize,
    /// Completion tokens reported by the provider for this step.
    pub completion_tokens: usize,
    /// Number of tool invocations executed during this step.
    pub tool_invocations: usize,
}
240
/// Result of a single [`Step`] execution.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum StepOutcome {
    /// The assistant requested tools; they were executed and their outputs
    /// appended, so the loop should send the extended transcript back.
    Next {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
        invoked_tools: Vec<String>,
    },
    /// The assistant replied without tool calls; the transcript is final for
    /// this step.
    Done {
        messages: Vec<ChatCompletionRequestMessage>,
        aux: StepAux,
    },
}
255
256fn summarize_request_messages(messages: &[ChatCompletionRequestMessage]) -> Vec<&'static str> {
257 messages
258 .iter()
259 .map(|message| match message {
260 ChatCompletionRequestMessage::System(_) => "system",
261 ChatCompletionRequestMessage::User(_) => "user",
262 ChatCompletionRequestMessage::Assistant(_) => "assistant",
263 ChatCompletionRequestMessage::Tool(_) => "tool",
264 ChatCompletionRequestMessage::Function(_) => "function",
265 ChatCompletionRequestMessage::Developer(_) => "developer",
266 })
267 .collect()
268}
269
270fn summarize_assistant_message(message: &ChatCompletionResponseMessage) -> String {
271 let content_shape = if message
272 .content
273 .as_ref()
274 .map(|c| !c.is_empty())
275 .unwrap_or(false)
276 {
277 "text"
278 } else {
279 "none"
280 };
281 let tool_calls = message.tool_calls.as_ref().map(|c| c.len()).unwrap_or(0);
282 format!("assistant(content={content_shape}, tool_calls={tool_calls})")
283}
284
285pub struct Step<S, P> {
287 provider: Arc<tokio::sync::Mutex<P>>,
288 model: String,
289 temperature: Option<f32>,
290 max_tokens: Option<u32>,
291 reasoning_effort: Option<ReasoningEffort>,
292 instructions: Option<InstructionSource>,
293 tools: S,
294 tool_specs: Arc<Vec<ChatCompletionTool>>, parallel_tools: bool,
296 tool_concurrency_limit: Option<usize>,
297 join_policy: ToolJoinPolicy,
298}
299
impl<S, P> Step<S, P> {
    /// Creates a step with the given provider, default model, tool service and
    /// tool specs; all tuning options start unset/disabled.
    pub fn new(
        provider: P,
        model: impl Into<String>,
        tools: S,
        tool_specs: Vec<ChatCompletionTool>,
    ) -> Self {
        Self {
            provider: Arc::new(tokio::sync::Mutex::new(provider)),
            model: model.into(),
            temperature: None,
            max_tokens: None,
            reasoning_effort: None,
            instructions: None,
            tools,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    /// Sets the default sampling temperature (request-level values win).
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Sets the default completion-token cap (request-level values win).
    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    /// Enables or disables concurrent execution of multiple tool calls.
    pub fn enable_parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    /// Caps how many tool calls may run concurrently when parallel.
    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    /// Chooses how parallel tool results/errors are joined.
    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }

    /// Sets the reasoning-effort hint forwarded to the provider.
    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(effort);
        self
    }
}
352
/// Tower [`Layer`] that wraps a tool service in a fully configured [`Step`].
///
/// Holds the same configuration as [`Step`]; `layer` copies it into each
/// service it produces.
pub struct StepLayer<P> {
    provider: P,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    reasoning_effort: Option<ReasoningEffort>,
    instructions: Option<InstructionSource>,
    tool_specs: Arc<Vec<ChatCompletionTool>>,
    parallel_tools: bool,
    tool_concurrency_limit: Option<usize>,
    join_policy: ToolJoinPolicy,
}
366
impl<P> StepLayer<P> {
    /// Creates a layer with the given provider, default model and tool specs;
    /// all tuning options start unset/disabled.
    pub fn new(provider: P, model: impl Into<String>, tool_specs: Vec<ChatCompletionTool>) -> Self {
        Self {
            provider,
            model: model.into(),
            temperature: None,
            max_tokens: None,
            reasoning_effort: None,
            instructions: None,
            tool_specs: Arc::new(tool_specs),
            parallel_tools: false,
            tool_concurrency_limit: None,
            join_policy: ToolJoinPolicy::FailFast,
        }
    }

    /// Sets the default sampling temperature.
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Sets the default completion-token cap.
    pub fn max_tokens(mut self, mt: u32) -> Self {
        self.max_tokens = Some(mt);
        self
    }

    /// Enables or disables concurrent tool execution.
    pub fn parallel_tools(mut self, enabled: bool) -> Self {
        self.parallel_tools = enabled;
        self
    }

    /// Caps how many tool calls may run concurrently when parallel.
    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
        self.tool_concurrency_limit = Some(limit);
        self
    }

    /// Chooses how parallel tool results/errors are joined.
    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
        self.join_policy = policy;
        self
    }

    /// Sets the reasoning-effort hint forwarded to the provider.
    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(effort);
        self
    }

    /// Sets fixed instruction text injected as the system message.
    pub fn instructions(mut self, text: impl Into<String>) -> Self {
        self.instructions = Some(InstructionSource::Static(text.into()));
        self
    }

    /// Sets a dynamic instruction provider consulted on each call.
    pub fn instruction_provider(mut self, provider: Arc<dyn LLMInstructionProvider>) -> Self {
        self.instructions = Some(InstructionSource::Dynamic(provider));
        self
    }
}
423
424impl<S, P> Layer<S> for StepLayer<P>
425where
426 P: Clone,
427{
428 type Service = Step<S, P>;
429
430 fn layer(&self, tools: S) -> Self::Service {
431 let mut s = Step::new(
432 self.provider.clone(),
433 self.model.clone(),
434 tools,
435 (*self.tool_specs).clone(),
436 );
437 s.temperature = self.temperature;
438 s.max_tokens = self.max_tokens;
439 s.reasoning_effort = self.reasoning_effort.clone();
440 s.instructions = self.instructions.clone();
441 s.parallel_tools = self.parallel_tools;
444 s.tool_concurrency_limit = self.tool_concurrency_limit;
445 s.join_policy = self.join_policy;
446 s
447 }
448}
449
450impl<S, P> Service<CreateChatCompletionRequest> for Step<S, P>
451where
452 S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
453 S::Future: Send + 'static,
454 P: ModelService + Send + 'static,
455 P::Future: Send + 'static,
456{
457 type Response = StepOutcome;
458 type Error = BoxError;
459 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
460
461 fn poll_ready(
462 &mut self,
463 cx: &mut std::task::Context<'_>,
464 ) -> std::task::Poll<Result<(), Self::Error>> {
465 let _ = cx; std::task::Poll::Ready(Ok(()))
467 }
468
469 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
470 let provider = self.provider.clone();
471 let model = self.model.clone();
472 let temperature = self.temperature;
473 let max_tokens = self.max_tokens;
474 let reasoning_effort = self.reasoning_effort.clone();
475 let tools = self.tools.clone();
476 let tool_specs = self.tool_specs.clone();
477 let parallel_tools = self.parallel_tools;
478 let _tool_concurrency_limit = self.tool_concurrency_limit;
479 let join_policy = self.join_policy;
480 let instruction_source = self.instructions.clone();
481
482 Box::pin(async move {
483 let effective_model: Option<String> = req.model.clone().into();
485
486 let model_to_use = if let Some(m) = effective_model.as_ref() {
488 m.clone()
489 } else {
490 model.clone()
491 };
492
493 debug!(
495 model = %model_to_use,
496 temperature = ?req.temperature.or(temperature),
497 max_tokens = ?max_tokens,
498 tools_count = if req.tools.is_some() {
499 req.tools.as_ref().map(|t| t.len())
500 } else {
501 Some(tool_specs.len())
502 },
503 "Step service preparing API request"
504 );
505
506 let mut injected_messages = req.messages.clone();
508 let instructions = match instruction_source {
509 Some(InstructionSource::Static(text)) => Some(text),
510 Some(InstructionSource::Dynamic(provider)) => provider.instructions().await,
511 None => None,
512 };
513
514 if let Some(instr) = instructions {
515 let sys_msg = ChatCompletionRequestSystemMessageArgs::default()
517 .content(instr)
518 .build()
519 .map(ChatCompletionRequestMessage::from)
520 .map_err(|e| format!("system msg build error: {}", e))?;
521 if let Some(pos) = injected_messages
523 .iter()
524 .position(|m| matches!(m, ChatCompletionRequestMessage::System(_)))
525 {
526 injected_messages.remove(pos);
527 }
528 injected_messages.insert(0, sys_msg);
529 }
530
531 let mut builder = CreateChatCompletionRequestArgs::default();
532 builder.messages(injected_messages);
533 if let Some(m) = effective_model.as_ref() {
534 builder.model(m);
535 } else {
536 builder.model(&model);
537 }
538 if let Some(t) = req.temperature.or(temperature) {
539 builder.temperature(t);
540 }
541 #[allow(deprecated)]
543 if let Some(mt) = req.max_tokens.or(max_tokens) {
544 builder.max_tokens(mt);
545 }
546 if let Some(effort) = reasoning_effort {
547 builder.reasoning_effort(effort);
548 }
549 if let Some(ts) = req.tools.clone() {
550 builder.tools(ts);
551 } else if !tool_specs.is_empty() {
552 builder.tools((*tool_specs).clone());
553 }
554
555 let rebuilt_req = builder
556 .build()
557 .map_err(|e| format!("request build error: {}", e))?;
558
559 let request_shape = summarize_request_messages(&rebuilt_req.messages);
560
561 debug!(
563 final_model = ?rebuilt_req.model,
564 messages_count = rebuilt_req.messages.len(),
565 message_roles = ?request_shape,
566 "Step service final request built"
567 );
568
569 let mut messages = rebuilt_req.messages.clone();
570
571 let mut p = provider.lock().await;
574 let ProviderResponse {
575 assistant,
576 prompt_tokens,
577 completion_tokens,
578 } = ServiceExt::ready(&mut *p).await?.call(rebuilt_req).await?;
579 let response_shape = summarize_assistant_message(&assistant);
580 debug!(
581 response_shape = %response_shape,
582 prompt_tokens,
583 completion_tokens,
584 "Step service received provider response"
585 );
586 let mut aux = StepAux {
587 prompt_tokens,
588 completion_tokens,
589 tool_invocations: 0,
590 };
591
592 let mut asst_builder = ChatCompletionRequestAssistantMessageArgs::default();
594 if let Some(content) = assistant.content.clone() {
595 asst_builder.content(content);
596 } else {
597 asst_builder.content("");
598 }
599 if let Some(tool_calls) = assistant.tool_calls.clone() {
600 asst_builder.tool_calls(tool_calls);
601 }
602 let asst_req = asst_builder
603 .build()
604 .map_err(|e| format!("assistant msg build error: {}", e))?;
605 messages.push(ChatCompletionRequestMessage::from(asst_req));
606
607 let tool_calls = assistant.tool_calls.unwrap_or_default();
609 if tool_calls.is_empty() {
610 return Ok(StepOutcome::Done { messages, aux });
611 }
612
613 let mut invoked_names: Vec<String> = Vec::with_capacity(tool_calls.len());
614 let invocations: Vec<ToolInvocation> = tool_calls
615 .into_iter()
616 .map(|tc| {
617 let name = tc.function.name;
618 invoked_names.push(name.clone());
619 let args: Value =
620 serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
621 ToolInvocation {
622 id: tc.id,
623 name,
624 arguments: args,
625 }
626 })
627 .collect();
628
629 if !invoked_names.is_empty() {
630 debug!(tool_names = ?invoked_names, "Step service invoking tools");
631 }
632
633 if invocations.len() > 1 && parallel_tools {
634 let sem = _tool_concurrency_limit.map(|n| Arc::new(Semaphore::new(n)));
636 match join_policy {
637 ToolJoinPolicy::FailFast => {
638 let futures: Vec<_> = invocations
639 .into_iter()
640 .map(|inv| {
641 let mut svc = tools.clone();
642 let sem_cl = sem.clone();
643 async move {
644 let _permit = match &sem_cl {
645 Some(s) => Some(
646 s.clone().acquire_owned().await.expect("semaphore"),
647 ),
648 None => None,
649 };
650 let ToolOutput { id, result } =
651 ServiceExt::ready(&mut svc).await?.call(inv).await?;
652 Ok::<(String, Value), BoxError>((id, result))
653 }
654 })
655 .collect();
656 let outputs: Vec<(String, Value)> =
657 futures::future::try_join_all(futures).await?;
658 for (id, result) in outputs {
659 aux.tool_invocations += 1;
660 let tool_msg = ChatCompletionRequestToolMessageArgs::default()
661 .content(result.to_string())
662 .tool_call_id(id)
663 .build()?;
664 messages.push(tool_msg.into());
665 }
666 }
667 ToolJoinPolicy::JoinAll => {
668 let futures: Vec<_> =
669 invocations
670 .into_iter()
671 .enumerate()
672 .map(|(idx, inv)| {
673 let mut svc = tools.clone();
674 let sem_cl = sem.clone();
675 async move {
676 let _permit = match &sem_cl {
677 Some(s) => Some(
678 s.clone().acquire_owned().await.expect("semaphore"),
679 ),
680 None => None,
681 };
682 let res =
683 ServiceExt::ready(&mut svc).await?.call(inv).await;
684 match res {
685 Ok(ToolOutput { id, result }) => Ok::<
686 Result<(usize, String, Value), BoxError>,
687 BoxError,
688 >(
689 Ok((idx, id, result)),
690 ),
691 Err(e) => Ok(Err(e)),
692 }
693 }
694 })
695 .collect();
696 let results = futures::future::join_all(futures).await;
697 let mut successes: Vec<(usize, String, Value)> = Vec::new();
698 let mut errors: Vec<String> = Vec::new();
699 for item in results.into_iter() {
700 match item {
701 Ok(Ok((idx, id, result))) => successes.push((idx, id, result)),
702 Ok(Err(e)) => errors.push(format!("{}", e)),
703 Err(e) => errors.push(format!("{}", e)),
704 }
705 }
706 successes.sort_by_key(|(idx, _, _)| *idx);
707 for (_idx, id, result) in successes.into_iter() {
708 aux.tool_invocations += 1;
709 let tool_msg = ChatCompletionRequestToolMessageArgs::default()
710 .content(result.to_string())
711 .tool_call_id(id)
712 .build()?;
713 messages.push(tool_msg.into());
714 }
715 if !errors.is_empty() {
716 return Err(
717 format!("one or more tools failed: {}", errors.join("; ")).into()
718 );
719 }
720 }
721 }
722 } else {
723 for inv in invocations {
725 let mut svc = tools.clone();
726 let ToolOutput { id, result } =
727 ServiceExt::ready(&mut svc).await?.call(inv).await?;
728 aux.tool_invocations += 1;
729 let tool_msg = ChatCompletionRequestToolMessageArgs::default()
730 .content(result.to_string())
731 .tool_call_id(id)
732 .build()?;
733 messages.push(tool_msg.into());
734 }
735 }
736
737 Ok(StepOutcome::Next {
738 messages,
739 aux,
740 invoked_tools: invoked_names,
741 })
742 })
743 }
744}
745
746pub fn simple_chat_request(system: &str, user: &str) -> CreateChatCompletionRequest {
752 let sys = ChatCompletionRequestSystemMessageArgs::default()
753 .content(system)
754 .build()
755 .expect("system msg");
756 let usr = ChatCompletionRequestUserMessageArgs::default()
757 .content(user)
758 .build()
759 .expect("user msg");
760 CreateChatCompletionRequestArgs::default()
761 .model("gpt-4o")
762 .messages(vec![sys.into(), usr.into()])
763 .build()
764 .expect("chat req")
765}
766
767#[allow(dead_code)]
769pub fn simple_user_request(user: &str) -> CreateChatCompletionRequest {
770 let usr = ChatCompletionRequestUserMessageArgs::default()
771 .content(user)
772 .build()
773 .expect("user msg");
774 CreateChatCompletionRequestArgs::default()
775 .model("gpt-4o")
776 .messages(vec![usr.into()])
777 .build()
778 .expect("chat req")
779}
780
/// Reason an agent run terminated.
///
/// NOTE(review): variant meanings below are read from the names; the loop
/// that produces them is outside this chunk — verify against `AgentLoopLayer`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum AgentStopReason {
    /// The model replied without requesting any tool calls.
    DoneNoToolCalls,
    /// The configured maximum number of steps was reached.
    MaxSteps,
    /// A designated tool was invoked (carries the tool name).
    ToolCalled(String),
    TokensBudgetExceeded,
    ToolBudgetExceeded,
    TimeBudgetExceeded,
}
796
/// Fluent builder over `CompositePolicy` for common loop stop conditions.
#[derive(Default, Clone)]
pub struct Policy {
    // Accumulates individual policies; extracted via `build`.
    inner: CompositePolicy,
}
806
#[allow(dead_code)]
impl Policy {
    /// Creates an empty policy set (no stop conditions registered).
    pub fn new() -> Self {
        Self {
            inner: CompositePolicy::default(),
        }
    }
    /// Adds the `policies::until_no_tool_calls` stop condition.
    pub fn until_no_tool_calls(mut self) -> Self {
        self.inner.policies.push(policies::until_no_tool_calls());
        self
    }
    /// Adds the `policies::until_tool_called` stop condition for `name`.
    pub fn or_tool(mut self, name: impl Into<String>) -> Self {
        self.inner.policies.push(policies::until_tool_called(name));
        self
    }
    /// Adds the `policies::max_steps` stop condition with limit `max`.
    pub fn or_max_steps(mut self, max: usize) -> Self {
        self.inner.policies.push(policies::max_steps(max));
        self
    }
    /// Consumes the builder, returning the accumulated composite policy.
    pub fn build(self) -> CompositePolicy {
        self.inner
    }
}
830
/// Boxed agent service: takes an initial chat request and produces a full run.
pub type AgentSvc = BoxService<CreateChatCompletionRequest, AgentRun, BoxError>;

/// Entry point for constructing agents; see [`Agent::builder`].
pub struct Agent;
836
837pub struct AgentBuilder {
838 client: Arc<Client<OpenAIConfig>>,
839 model: String,
840 temperature: Option<f32>,
841 max_tokens: Option<u32>,
842 reasoning_effort: Option<ReasoningEffort>,
843 instructions: Option<String>,
844 instruction_provider: Option<Arc<dyn LLMInstructionProvider>>,
845 tools: Vec<ToolDef>,
846 policy: CompositePolicy,
847 handoff: Option<crate::groups::AnyHandoffPolicy>,
848 provider: Option<
849 tower::util::BoxCloneService<
850 CreateChatCompletionRequest,
851 crate::provider::ProviderResponse,
852 BoxError,
853 >,
854 >,
855 enable_parallel_tools: bool,
856 tool_concurrency_limit: Option<usize>,
857 tool_join_policy: ToolJoinPolicy,
858 agent_service_map: Option<Arc<dyn Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static>>, auto_compaction: Option<crate::auto_compaction::CompactionPolicy>,
860}
861
impl Agent {
    /// Starts an [`AgentBuilder`] that uses `client` for the default OpenAI
    /// provider. Defaults: model "gpt-4o", no tools, fail-fast sequential tool
    /// execution, default (empty) loop policy.
    pub fn builder(client: Arc<Client<OpenAIConfig>>) -> AgentBuilder {
        AgentBuilder {
            client,
            model: "gpt-4o".to_string(),
            temperature: None,
            max_tokens: None,
            reasoning_effort: None,
            instructions: None,
            instruction_provider: None,
            tools: Vec::new(),
            policy: CompositePolicy::default(),
            handoff: None,
            provider: None,
            enable_parallel_tools: false,
            tool_concurrency_limit: None,
            tool_join_policy: ToolJoinPolicy::FailFast,
            agent_service_map: None,
            auto_compaction: None,
        }
    }
}
884
885impl AgentBuilder {
886 pub fn model(mut self, model: impl Into<String>) -> Self {
887 self.model = model.into();
888 self
889 }
890 pub fn temperature(mut self, t: f32) -> Self {
891 self.temperature = Some(t);
892 self
893 }
894 pub fn max_tokens(mut self, mt: u32) -> Self {
895 self.max_tokens = Some(mt);
896 self
897 }
898 pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
899 self.reasoning_effort = Some(effort);
900 self
901 }
902
903 pub fn instructions(mut self, text: impl Into<String>) -> Self {
905 self.instructions = Some(text.into());
906 self
907 }
908 pub fn instruction_provider(mut self, provider: Arc<dyn LLMInstructionProvider>) -> Self {
909 self.instruction_provider = Some(provider);
910 self
911 }
912 pub fn tool(mut self, tool: ToolDef) -> Self {
913 self.tools.push(tool);
914 self
915 }
916 pub fn tools(mut self, tools: Vec<ToolDef>) -> Self {
917 self.tools.extend(tools);
918 self
919 }
920 pub fn policy(mut self, policy: CompositePolicy) -> Self {
921 self.policy = policy;
922 self
923 }
924
925 pub fn handoff_policy(mut self, policy: crate::groups::AnyHandoffPolicy) -> Self {
927 self.handoff = Some(policy);
928 self
929 }
930
931 pub fn with_provider<P>(mut self, provider: P) -> Self
933 where
934 P: crate::provider::ModelService + Clone + Send + 'static,
935 P::Future: Send + 'static,
936 {
937 self.provider = Some(tower::util::BoxCloneService::new(provider));
938 self
939 }
940
941 pub fn parallel_tools(mut self, enabled: bool) -> Self {
943 self.enable_parallel_tools = enabled;
944 self
945 }
946
947 pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
949 self.tool_concurrency_limit = Some(limit);
950 self
951 }
952
953 pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
955 self.tool_join_policy = policy;
956 self
957 }
958
959 pub fn map_agent_service<F>(mut self, f: F) -> Self
962 where
963 F: Fn(AgentSvc) -> AgentSvc + Send + Sync + 'static,
964 {
965 self.agent_service_map = Some(Arc::new(f));
966 self
967 }
968
969 pub fn auto_compaction(mut self, policy: crate::auto_compaction::CompactionPolicy) -> Self {
971 self.auto_compaction = Some(policy);
972 self
973 }
974
975 pub fn build(self) -> AgentSvc {
976 let (router, mut specs) = ToolRouter::new(self.tools);
977 let routed: ToolSvc = if let Some(policy) = &self.handoff {
979 let hand_spec = policy.handoff_tools();
980 if !hand_spec.is_empty() {
981 specs.extend(hand_spec);
982 }
983 crate::groups::layer_tool_router_with_handoff(router, policy.clone())
984 } else {
985 BoxCloneService::new(router)
987 };
988
989 let base_provider: tower::util::BoxCloneService<
990 CreateChatCompletionRequest,
991 crate::provider::ProviderResponse,
992 BoxError,
993 > = if let Some(p) = self.provider {
994 p
995 } else {
996 tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
997 };
998 let mut step_layer = StepLayer::new(base_provider.clone(), self.model, specs)
999 .parallel_tools(self.enable_parallel_tools)
1000 .tool_join_policy(self.tool_join_policy);
1001 if let Some(instr) = &self.instructions {
1002 step_layer = step_layer.instructions(instr.clone());
1003 } else if let Some(provider) = &self.instruction_provider {
1004 step_layer = step_layer.instruction_provider(provider.clone());
1005 }
1006 if let Some(t) = self.temperature {
1008 step_layer = step_layer.temperature(t);
1009 }
1010 if let Some(mt) = self.max_tokens {
1012 step_layer = step_layer.max_tokens(mt);
1013 }
1014 if let Some(effort) = self.reasoning_effort {
1015 step_layer = step_layer.reasoning_effort(effort);
1016 }
1017 if let Some(lim) = self.tool_concurrency_limit {
1018 step_layer = step_layer.tool_concurrency_limit(lim);
1019 }
1020 let step = step_layer.layer(routed);
1021
1022 let step_with_compaction: BoxService<CreateChatCompletionRequest, StepOutcome, BoxError> =
1024 if let Some(compaction_policy) = self.auto_compaction {
1025 let compaction_provider = base_provider;
1027 let token_counter = crate::auto_compaction::SimpleTokenCounter::new();
1028 let compaction_layer = crate::auto_compaction::AutoCompactionLayer::new(
1029 compaction_policy,
1030 compaction_provider,
1031 token_counter,
1032 );
1033 BoxService::new(compaction_layer.layer(step))
1034 } else {
1035 BoxService::new(step)
1036 };
1037
1038 let agent = AgentLoopLayer::new(self.policy).layer(step_with_compaction);
1039 let boxed = BoxService::new(agent);
1040 match &self.agent_service_map {
1041 Some(map) => (map)(boxed),
1042 None => boxed,
1043 }
1044 }
1045
1046 pub fn build_with_session<Ls, Ss>(
1048 self,
1049 load: Arc<Ls>,
1050 save: Arc<Ss>,
1051 session_id: crate::sessions::SessionId,
1052 ) -> AgentSvc
1053 where
1054 Ls: Service<
1055 crate::sessions::LoadSession,
1056 Response = crate::sessions::History,
1057 Error = BoxError,
1058 > + Send
1059 + Sync
1060 + Clone
1061 + 'static,
1062 Ls::Future: Send + 'static,
1063 Ss: Service<crate::sessions::SaveSession, Response = (), Error = BoxError>
1064 + Send
1065 + Sync
1066 + Clone
1067 + 'static,
1068 Ss::Future: Send + 'static,
1069 {
1070 let (router, mut specs) = ToolRouter::new(self.tools);
1071 let routed: ToolSvc = if let Some(policy) = &self.handoff {
1072 let hand_spec = policy.handoff_tools();
1073 if !hand_spec.is_empty() {
1074 specs.extend(hand_spec);
1075 }
1076 crate::groups::layer_tool_router_with_handoff(router, policy.clone())
1077 } else {
1078 BoxCloneService::new(router)
1079 };
1080
1081 let base_provider: tower::util::BoxCloneService<
1082 CreateChatCompletionRequest,
1083 crate::provider::ProviderResponse,
1084 BoxError,
1085 > = if let Some(p) = self.provider {
1086 p
1087 } else {
1088 tower::util::BoxCloneService::new(OpenAIProvider::new(self.client))
1089 };
1090 let mut step_layer = StepLayer::new(base_provider.clone(), self.model, specs)
1091 .parallel_tools(self.enable_parallel_tools)
1092 .tool_join_policy(self.tool_join_policy);
1093 if let Some(instr) = &self.instructions {
1094 step_layer = step_layer.instructions(instr.clone());
1095 } else if let Some(provider) = &self.instruction_provider {
1096 step_layer = step_layer.instruction_provider(provider.clone());
1097 }
1098 if let Some(t) = self.temperature {
1100 step_layer = step_layer.temperature(t);
1101 }
1102 if let Some(mt) = self.max_tokens {
1104 step_layer = step_layer.max_tokens(mt);
1105 }
1106 if let Some(effort) = self.reasoning_effort {
1107 step_layer = step_layer.reasoning_effort(effort);
1108 }
1109 if let Some(lim) = self.tool_concurrency_limit {
1110 step_layer = step_layer.tool_concurrency_limit(lim);
1111 }
1112 let step = step_layer.layer(routed);
1113
1114 let step_with_compaction: BoxService<CreateChatCompletionRequest, StepOutcome, BoxError> =
1116 if let Some(compaction_policy) = self.auto_compaction {
1117 let compaction_provider = base_provider;
1119 let token_counter = crate::auto_compaction::SimpleTokenCounter::new();
1120 let compaction_layer = crate::auto_compaction::AutoCompactionLayer::new(
1121 compaction_policy,
1122 compaction_provider,
1123 token_counter,
1124 );
1125 BoxService::new(compaction_layer.layer(step))
1126 } else {
1127 BoxService::new(step)
1128 };
1129
1130 let mem_layer = crate::sessions::MemoryLayer::new(load, save, session_id);
1132 let step_with_mem = mem_layer.layer(step_with_compaction);
1133 let agent = AgentLoopLayer::new(self.policy).layer(step_with_mem);
1134 let boxed = BoxService::new(agent);
1135 match &self.agent_service_map {
1136 Some(map) => (map)(boxed),
1137 None => boxed,
1138 }
1139 }
1140}
1141
1142pub async fn run(agent: &mut AgentSvc, system: &str, user: &str) -> Result<AgentRun, BoxError> {
1144 let req = simple_chat_request(system, user);
1145 let resp = ServiceExt::ready(agent).await?.call(req).await?;
1146 Ok(resp)
1147}
1148
1149#[allow(dead_code)]
1151pub async fn run_user(agent: &mut AgentSvc, user: &str) -> Result<AgentRun, BoxError> {
1152 let req = simple_user_request(user);
1153 let resp = ServiceExt::ready(agent).await?.call(req).await?;
1154 Ok(resp)
1155}
1156
1157#[cfg(test)]
1158mod tests {
1159 use super::*;
1160 use async_openai::types::ChatCompletionRequestUserMessageArgs;
1161 use async_trait::async_trait;
1162 use std::sync::Arc;
1163
    /// Verifies that a `Step` configured with static instructions places them
    /// as the first (system) message of the outgoing transcript.
    #[tokio::test]
    async fn step_injects_instructions_prepend_or_replace() {
        // Canned assistant reply with no tool calls, so the step finishes in
        // one round trip.
        #[allow(deprecated)]
        let assistant = async_openai::types::ChatCompletionResponseMessage {
            content: Some("ok".into()),
            role: async_openai::types::Role::Assistant,
            tool_calls: None,
            function_call: None,
            refusal: None,
            audio: None,
        };
        let provider = crate::provider::FixedProvider::new(crate::provider::ProviderResponse {
            assistant,
            prompt_tokens: 1,
            completion_tokens: 1,
        });

        // No tools; only the instruction-injection path is exercised.
        let (router, specs) = ToolRouter::new(vec![]);
        let step = StepLayer::new(provider, "gpt-4o", specs)
            .instructions("AGENT INSTR")
            .layer(router);
        let mut svc = tower::ServiceExt::boxed(step);

        // The request deliberately has no system message.
        let user = ChatCompletionRequestUserMessageArgs::default()
            .content("hello")
            .build()
            .unwrap();
        let req = CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![user.into()])
            .build()
            .unwrap();

        let out = tower::ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(req)
            .await
            .unwrap();
        let msgs = match out {
            StepOutcome::Next { messages, .. } => messages,
            StepOutcome::Done { messages, .. } => messages,
        };
        // First transcript message must be the injected system instructions.
        match &msgs[0] {
            ChatCompletionRequestMessage::System(s) => {
                if let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) =
                    &s.content
                {
                    assert_eq!(t, "AGENT INSTR");
                } else {
                    panic!("expected text content in system message");
                }
            }
            _ => panic!("expected first message to be system"),
        }
    }
1224
1225 #[derive(Clone)]
1226 struct CountingInstructionProvider {
1227 counter: Arc<tokio::sync::Mutex<usize>>,
1228 }
1229
1230 #[async_trait]
1231 impl LLMInstructionProvider for CountingInstructionProvider {
1232 async fn instructions(&self) -> Option<String> {
1233 let mut guard = self.counter.lock().await;
1234 *guard += 1;
1235 Some(format!("CTX {}", *guard))
1236 }
1237 }
1238
1239 #[tokio::test]
1240 async fn step_uses_instruction_provider_each_call() {
1241 #[allow(deprecated)]
1242 let assistant = async_openai::types::ChatCompletionResponseMessage {
1243 content: Some("ok".into()),
1244 role: async_openai::types::Role::Assistant,
1245 tool_calls: None,
1246 function_call: None,
1247 refusal: None,
1248 audio: None,
1249 };
1250 let provider = crate::provider::FixedProvider::new(crate::provider::ProviderResponse {
1251 assistant,
1252 prompt_tokens: 1,
1253 completion_tokens: 1,
1254 });
1255
1256 let (router, specs) = ToolRouter::new(vec![]);
1257 let instruction_provider = Arc::new(CountingInstructionProvider {
1258 counter: Arc::new(tokio::sync::Mutex::new(0)),
1259 });
1260 let step = StepLayer::new(provider, "gpt-4o", specs)
1261 .instruction_provider(instruction_provider)
1262 .layer(router);
1263 let mut svc = tower::ServiceExt::boxed(step);
1264
1265 let build_request = || {
1266 let user = ChatCompletionRequestUserMessageArgs::default()
1267 .content("hello")
1268 .build()
1269 .unwrap();
1270 CreateChatCompletionRequestArgs::default()
1271 .model("gpt-4o")
1272 .messages(vec![user.into()])
1273 .build()
1274 .unwrap()
1275 };
1276
1277 for expected in ["CTX 1", "CTX 2"] {
1278 let resp = tower::ServiceExt::ready(&mut svc)
1279 .await
1280 .unwrap()
1281 .call(build_request())
1282 .await
1283 .unwrap();
1284
1285 let messages = match resp {
1286 StepOutcome::Next { messages, .. } => messages,
1287 StepOutcome::Done { messages, .. } => messages,
1288 };
1289
1290 match &messages[0] {
1291 ChatCompletionRequestMessage::System(s) => match &s.content {
1292 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1293 assert_eq!(t, expected);
1294 }
1295 _ => panic!("expected text content in system message"),
1296 },
1297 _ => panic!("expected first message to be system"),
1298 }
1299 }
1300 }
1301
1302 #[test]
1303 fn builds_user_request() {
1304 let _ = simple_user_request("hi");
1305 }
1306
1307 #[tokio::test]
1308 async fn run_user_executes_with_instructions() {
1309 #[allow(deprecated)]
1310 let assistant = async_openai::types::ChatCompletionResponseMessage {
1311 content: Some("ok".into()),
1312 role: async_openai::types::Role::Assistant,
1313 tool_calls: None,
1314 function_call: None,
1315 refusal: None,
1316 audio: None,
1317 };
1318 let provider = crate::provider::FixedProvider::new(crate::provider::ProviderResponse {
1319 assistant,
1320 prompt_tokens: 1,
1321 completion_tokens: 1,
1322 });
1323 let client =
1324 std::sync::Arc::new(async_openai::Client::<async_openai::config::OpenAIConfig>::new());
1325 let mut agent = Agent::builder(client)
1326 .with_provider(provider)
1327 .model("gpt-4o")
1328 .instructions("INSTR")
1329 .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
1330 .build();
1331 let run = run_user(&mut agent, "hello").await.unwrap();
1332 assert!(!run.messages.is_empty());
1333 match &run.messages[0] {
1334 ChatCompletionRequestMessage::System(s) => match &s.content {
1335 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1336 assert_eq!(t, "INSTR");
1337 }
1338 _ => panic!("expected text content"),
1339 },
1340 _ => panic!("expected first message to be system"),
1341 }
1342 }
1343
1344 #[tokio::test]
1345 async fn sessions_preserve_agent_instructions_in_merged_request() {
1346 use crate::sessions::{InMemorySessionStore, SessionId};
1347 #[derive(Clone)]
1349 struct CapturingProvider {
1350 captured: std::sync::Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
1351 }
1352 impl tower::Service<CreateChatCompletionRequest> for CapturingProvider {
1353 type Response = crate::provider::ProviderResponse;
1354 type Error = BoxError;
1355 type Future = std::pin::Pin<
1356 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
1357 >;
1358 fn poll_ready(
1359 &mut self,
1360 _cx: &mut std::task::Context<'_>,
1361 ) -> std::task::Poll<Result<(), Self::Error>> {
1362 std::task::Poll::Ready(Ok(()))
1363 }
1364 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
1365 let captured = self.captured.clone();
1366 Box::pin(async move {
1367 *captured.lock().await = Some(req);
1368 #[allow(deprecated)]
1369 let assistant = async_openai::types::ChatCompletionResponseMessage {
1370 content: Some("ok".into()),
1371 role: async_openai::types::Role::Assistant,
1372 tool_calls: None,
1373 function_call: None,
1374 refusal: None,
1375 audio: None,
1376 };
1377 Ok(crate::provider::ProviderResponse {
1378 assistant,
1379 prompt_tokens: 1,
1380 completion_tokens: 1,
1381 })
1382 })
1383 }
1384 }
1385
1386 let captured = std::sync::Arc::new(tokio::sync::Mutex::new(None));
1387 let provider = CapturingProvider {
1388 captured: captured.clone(),
1389 };
1390
1391 let client =
1392 std::sync::Arc::new(async_openai::Client::<async_openai::config::OpenAIConfig>::new());
1393 let load = std::sync::Arc::new(InMemorySessionStore::default());
1394 let save = load.clone();
1395 let mut agent = Agent::builder(client)
1396 .with_provider(provider)
1397 .model("gpt-4o")
1398 .instructions("INSTR")
1399 .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
1400 .build_with_session(load, save, SessionId("s1".into()));
1401
1402 let req = CreateChatCompletionRequestArgs::default()
1403 .model("gpt-4o")
1404 .messages(vec![ChatCompletionRequestUserMessageArgs::default()
1405 .content("hi")
1406 .build()
1407 .unwrap()
1408 .into()])
1409 .build()
1410 .unwrap();
1411 let _ = tower::ServiceExt::ready(&mut agent)
1412 .await
1413 .unwrap()
1414 .call(req)
1415 .await
1416 .unwrap();
1417
1418 let got = captured.lock().await.clone().expect("captured");
1419 match &got.messages[0] {
1420 ChatCompletionRequestMessage::System(s) => match &s.content {
1421 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1422 assert_eq!(t, "INSTR");
1423 }
1424 _ => panic!("expected text content"),
1425 },
1426 _ => panic!("expected first message to be system"),
1427 }
1428 }
1429
1430 #[tokio::test]
1431 async fn auto_compaction_preserves_instructions() {
1432 use crate::auto_compaction::{CompactionPolicy, CompactionStrategy, ProactiveThreshold};
1433 #[derive(Clone)]
1435 struct CapturingProvider {
1436 captured: std::sync::Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
1437 }
1438 impl tower::Service<CreateChatCompletionRequest> for CapturingProvider {
1439 type Response = crate::provider::ProviderResponse;
1440 type Error = BoxError;
1441 type Future = std::pin::Pin<
1442 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
1443 >;
1444 fn poll_ready(
1445 &mut self,
1446 _cx: &mut std::task::Context<'_>,
1447 ) -> std::task::Poll<Result<(), Self::Error>> {
1448 std::task::Poll::Ready(Ok(()))
1449 }
1450 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
1451 let captured = self.captured.clone();
1452 Box::pin(async move {
1453 *captured.lock().await = Some(req);
1454 #[allow(deprecated)]
1455 let assistant = async_openai::types::ChatCompletionResponseMessage {
1456 content: Some("ok".into()),
1457 role: async_openai::types::Role::Assistant,
1458 tool_calls: None,
1459 function_call: None,
1460 refusal: None,
1461 audio: None,
1462 };
1463 Ok(crate::provider::ProviderResponse {
1464 assistant,
1465 prompt_tokens: 1,
1466 completion_tokens: 1,
1467 })
1468 })
1469 }
1470 }
1471
1472 let captured = std::sync::Arc::new(tokio::sync::Mutex::new(None));
1473 let provider = CapturingProvider {
1474 captured: captured.clone(),
1475 };
1476 let client =
1477 std::sync::Arc::new(async_openai::Client::<async_openai::config::OpenAIConfig>::new());
1478
1479 let policy = CompactionPolicy {
1480 compaction_model: "gpt-4o-mini".to_string(),
1481 proactive_threshold: Some(ProactiveThreshold {
1482 token_threshold: 1,
1483 percentage_threshold: None,
1484 }),
1485 compaction_strategy: CompactionStrategy::PreserveSystemAndRecent { recent_count: 1 },
1486 ..Default::default()
1487 };
1488
1489 let mut agent = Agent::builder(client)
1490 .with_provider(provider)
1491 .model("gpt-4o")
1492 .instructions("INSTR")
1493 .auto_compaction(policy)
1494 .policy(CompositePolicy::new(vec![policies::max_steps(1)]))
1495 .build();
1496
1497 let mut long_user = String::new();
1498 for _ in 0..200 {
1499 long_user.push('x');
1500 }
1501 let req = CreateChatCompletionRequestArgs::default()
1502 .model("gpt-4o")
1503 .messages(vec![ChatCompletionRequestUserMessageArgs::default()
1504 .content(long_user)
1505 .build()
1506 .unwrap()
1507 .into()])
1508 .build()
1509 .unwrap();
1510 let _ = tower::ServiceExt::ready(&mut agent)
1511 .await
1512 .unwrap()
1513 .call(req)
1514 .await
1515 .unwrap();
1516
1517 let got = captured.lock().await.clone().expect("captured");
1518 match &got.messages[0] {
1519 ChatCompletionRequestMessage::System(s) => match &s.content {
1520 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1521 assert_eq!(t, "INSTR");
1522 }
1523 _ => panic!("expected text content"),
1524 },
1525 _ => panic!("expected first message to be system"),
1526 }
1527 }
1528}
1529
/// Mutable bookkeeping carried across iterations of the agent loop.
#[derive(Debug, Clone, Default)]
pub struct LoopState {
    /// Number of step invocations completed so far.
    pub steps: usize,
}
1535
/// Decides whether the agent loop should stop after a completed step.
pub trait AgentPolicy: Send + Sync {
    /// Return `Some(reason)` to stop the loop, or `None` to keep iterating.
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason>;
}
1540
/// Closure-backed [`AgentPolicy`]; `Arc` makes it cheap to clone and share.
#[derive(Clone)]
#[allow(clippy::type_complexity)]
pub struct PolicyFn(
    pub Arc<dyn Fn(&LoopState, &StepOutcome) -> Option<AgentStopReason> + Send + Sync + 'static>,
);
1547
impl AgentPolicy for PolicyFn {
    fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
        // Delegate straight to the wrapped closure.
        (self.0)(state, last)
    }
}
1553
/// Ordered collection of policies; the first to return a stop reason wins.
#[derive(Clone, Default)]
pub struct CompositePolicy {
    policies: Vec<PolicyFn>,
}
1559
#[allow(dead_code)]
impl CompositePolicy {
    /// Build a composite from an ordered list of policies.
    pub fn new(policies: Vec<PolicyFn>) -> Self {
        Self { policies }
    }
    /// Append a policy; it is consulted after all previously added ones.
    pub fn push(&mut self, p: PolicyFn) {
        self.policies.push(p);
    }
}
1569
1570impl AgentPolicy for CompositePolicy {
1571 fn decide(&self, state: &LoopState, last: &StepOutcome) -> Option<AgentStopReason> {
1572 for p in &self.policies {
1573 if let Some(r) = p.decide(state, last) {
1574 return Some(r);
1575 }
1576 }
1577 None
1578 }
1579}
1580
1581#[allow(dead_code)]
1583pub mod policies {
1584 use super::*;
1585
1586 pub fn until_no_tool_calls() -> PolicyFn {
1587 PolicyFn(Arc::new(|_s, last| match last {
1588 StepOutcome::Done { .. } => Some(AgentStopReason::DoneNoToolCalls),
1589 _ => None,
1590 }))
1591 }
1592
1593 pub fn until_tool_called(tool_name: impl Into<String>) -> PolicyFn {
1594 let target = tool_name.into();
1595 PolicyFn(Arc::new(move |_s, last| match last {
1596 StepOutcome::Next { invoked_tools, .. } => {
1597 if invoked_tools.iter().any(|n| n == &target) {
1598 Some(AgentStopReason::ToolCalled(target.clone()))
1599 } else {
1600 None
1601 }
1602 }
1603 _ => None,
1604 }))
1605 }
1606
1607 pub fn max_steps(max: usize) -> PolicyFn {
1608 PolicyFn(Arc::new(move |s, _| {
1609 if s.steps >= max {
1610 Some(AgentStopReason::MaxSteps)
1611 } else {
1612 None
1613 }
1614 }))
1615 }
1616}
1617
/// Final result of driving an agent loop to completion.
#[derive(Debug, Clone)]
pub struct AgentRun {
    /// Full transcript as of the last completed step.
    pub messages: Vec<ChatCompletionRequestMessage>,
    /// Number of steps executed before stopping.
    pub steps: usize,
    /// Why the loop terminated.
    pub stop: AgentStopReason,
}
1625
/// Tower layer that wraps a step service in the [`AgentLoop`] driver.
pub struct AgentLoopLayer<P> {
    policy: P,
}
1630
impl<P> AgentLoopLayer<P> {
    /// Create a layer whose loop terminates according to `policy`.
    pub fn new(policy: P) -> Self {
        Self { policy }
    }
}
1636
/// Loop driver produced by [`AgentLoopLayer`]: repeatedly calls the inner
/// step service until the policy (or a `Done` outcome) ends the run.
pub struct AgentLoop<S, P> {
    // Async mutex so each in-flight call drives the inner service exclusively.
    inner: Arc<tokio::sync::Mutex<S>>,
    policy: P,
}
1641
impl<S, P> Layer<S> for AgentLoopLayer<P>
where
    P: Clone,
{
    type Service = AgentLoop<S, P>;
    // Wrap the step service; the mutex lets the returned futures share it.
    fn layer(&self, inner: S) -> Self::Service {
        AgentLoop {
            inner: Arc::new(tokio::sync::Mutex::new(inner)),
            policy: self.policy.clone(),
        }
    }
}
1654
impl<S, P> Service<CreateChatCompletionRequest> for AgentLoop<S, P>
where
    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
        + Send
        + 'static,
    S::Future: Send + 'static,
    P: AgentPolicy + Send + Sync + Clone + 'static,
{
    type Response = AgentRun;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    // Readiness is gated per-call via the inner mutex, so the loop itself is
    // always ready here.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Run the step service to completion: rebuild the request each
    /// iteration with the accumulated transcript, ask the policy whether to
    /// stop, and otherwise continue until a `Done` outcome.
    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let inner = self.inner.clone();
        let policy = self.policy.clone();
        Box::pin(async move {
            let mut state = LoopState::default();
            // Snapshot the request fields that are replayed on every
            // iteration; only the message list evolves between steps.
            //
            // NOTE(review): only model / messages / temperature / max_tokens /
            // max_completion_tokens / tools are carried into follow-up
            // requests; any other fields on the original request are dropped
            // after the first step — confirm this is intentional.
            let base_model = req.model.clone();
            let mut current_messages = req.messages.clone();
            let base_temperature = req.temperature;
            // `max_tokens` is deprecated upstream but still honored here.
            #[allow(deprecated)]
            let base_max_tokens = req.max_tokens;
            let base_max_completion_tokens = req.max_completion_tokens;
            let base_tools = req.tools.clone();

            debug!(
                model = ?base_model,
                initial_messages = current_messages.len(),
                "AgentLoop starting with model"
            );

            loop {
                // Rebuild a fresh request from the base fields plus the
                // current transcript.
                let mut builder = CreateChatCompletionRequestArgs::default();
                builder.model(&base_model);
                builder.messages(current_messages.clone());
                if let Some(t) = base_temperature {
                    builder.temperature(t);
                }
                if let Some(mt) = base_max_tokens {
                    builder.max_tokens(mt);
                }
                if let Some(mct) = base_max_completion_tokens {
                    builder.max_completion_tokens(mct);
                }
                if let Some(tools) = base_tools.clone() {
                    builder.tools(tools);
                }
                let current_req = builder
                    .build()
                    .map_err(|e| format!("build req error: {}", e))?;

                trace!(
                    step = state.steps + 1,
                    model = ?current_req.model,
                    messages = current_messages.len(),
                    "AgentLoop iteration"
                );

                // Hold the lock only for the duration of one step so the
                // policy evaluation below runs without it.
                let mut guard = inner.lock().await;
                let outcome = guard.ready().await?.call(current_req).await?;
                drop(guard);

                state.steps += 1;

                // The policy is consulted first, so it can stop the loop even
                // on a `Next` outcome (e.g. max_steps, until_tool_called).
                if let Some(stop) = policy.decide(&state, &outcome) {
                    let messages = match outcome {
                        StepOutcome::Next { messages, .. } => messages,
                        StepOutcome::Done { messages, .. } => messages,
                    };
                    return Ok(AgentRun {
                        messages,
                        steps: state.steps,
                        stop,
                    });
                }

                match outcome {
                    // Tool calls were made: feed the grown transcript back in.
                    StepOutcome::Next { messages, .. } => {
                        current_messages = messages;
                    }
                    // No tool calls: the run is naturally complete.
                    StepOutcome::Done { messages, .. } => {
                        return Ok(AgentRun {
                            messages,
                            steps: state.steps,
                            stop: AgentStopReason::DoneNoToolCalls,
                        });
                    }
                }
            }
        })
    }
}