tower_llm/provider/mod.rs

use std::future::Future;
use std::pin::Pin;

use async_openai::{
    config::OpenAIConfig,
    types::{ChatCompletionResponseMessage, CreateChatCompletionRequest},
    Client,
};
use futures::stream;
use futures::Stream;
use tower::BoxError;
use tracing::{debug, trace};

pub use crate::streaming::{StepChunk, StepProvider};

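/// A test-oriented [`StepProvider`] that replays a fixed sequence of
/// [`StepChunk`]s, ignoring the incoming request.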
#[derive(Clone)]
pub struct SequenceProvider {
    items: Vec<StepChunk>,
}

impl SequenceProvider {
    pub fn new(items: Vec<StepChunk>) -> Self {
        Self { items }
    }
}

impl StepProvider for SequenceProvider {
    type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;

    fn stream_step(
        &self,
        _req: CreateChatCompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
        let iter = stream::iter(self.items.clone());
        Box::pin(async move { Ok(Box::pin(iter) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>) })
    }
}

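/// The assistant message plus token usage extracted from a single
/// chat-completion response.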
#[derive(Debug, Clone)]
pub struct ProviderResponse {
    pub assistant: ChatCompletionResponseMessage,
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
}

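/// Convenience alias for any `tower::Service` that maps a chat-completion
/// request to a [`ProviderResponse`]; the blanket impl below makes every
/// such service a [`ModelService`] automatically.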
pub trait ModelService:
    tower::Service<CreateChatCompletionRequest, Response = ProviderResponse, Error = BoxError>
{
}

impl<T> ModelService for T where
    T: tower::Service<CreateChatCompletionRequest, Response = ProviderResponse, Error = BoxError>
{
}

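/// A [`ModelService`] backed by the `async_openai` chat client.
///
/// A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the
/// environment and a `tokio` runtime is available (the crate path and
/// model name below are illustrative):
///
/// ```no_run
/// # use tower_llm::provider::OpenAIProvider;
/// use tower::{Service, ServiceExt};
///
/// # async fn run() -> Result<(), tower::BoxError> {
/// let client = std::sync::Arc::new(async_openai::Client::new());
/// let mut provider = OpenAIProvider::new(client);
/// let req = async_openai::types::CreateChatCompletionRequestArgs::default()
///     .model("gpt-4o")
///     .messages(vec![])
///     .build()?;
/// let resp = provider.ready().await?.call(req).await?;
/// println!("{} completion tokens", resp.completion_tokens);
/// # Ok(())
/// # }
/// ```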
#[derive(Clone)]
pub struct OpenAIProvider {
    client: std::sync::Arc<Client<OpenAIConfig>>,
}

impl OpenAIProvider {
    pub fn new(client: std::sync::Arc<Client<OpenAIConfig>>) -> Self {
        Self { client }
    }
}

impl tower::Service<CreateChatCompletionRequest> for OpenAIProvider {
    type Response = ProviderResponse;
    type Error = BoxError;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // The underlying HTTP client exposes no backpressure, so the
        // service is always ready.
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let client = self.client.clone();

        debug!(
            model = ?req.model,
            temperature = ?req.temperature,
            messages_count = req.messages.len(),
            tools_count = req.tools.as_ref().map(|t| t.len()).unwrap_or(0),
            "OpenAIProvider sending request to API"
        );

        Box::pin(async move {
            // Capture the model name before `req` is moved into the client
            // call, so it is still available for error logging.
            let model_debug = format!("{:?}", req.model);
            let resp = client.chat().create(req).await.map_err(|e| {
                debug!(
                    model = %model_debug,
                    error = %e,
                    "OpenAI API error"
                );
                e
            })?;
            let usage = resp.usage.unwrap_or_default();
            // The first choice carries the assistant message.
            let choice = resp
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| "no choices".to_string())?;

            trace!(
                prompt_tokens = usage.prompt_tokens,
                completion_tokens = usage.completion_tokens,
                "OpenAI API response received"
            );

            Ok(ProviderResponse {
                assistant: choice.message,
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
            })
        })
    }
}

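/// A [`ModelService`] that ignores the request and always returns a
/// pre-built [`ProviderResponse`]; handy as a stub in tests.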
#[derive(Clone)]
pub struct FixedProvider {
    output: ProviderResponse,
}

impl FixedProvider {
    pub fn new(output: ProviderResponse) -> Self {
        Self { output }
    }
}

impl tower::Service<CreateChatCompletionRequest> for FixedProvider {
    type Response = ProviderResponse;
    type Error = BoxError;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, _req: CreateChatCompletionRequest) -> Self::Future {
        // Ignore the request entirely and hand back a clone of the canned response.
        let out = self.output.clone();
        Box::pin(async move { Ok(out) })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::CreateChatCompletionRequestArgs;
    use futures::StreamExt;

    #[tokio::test]
    async fn sequence_provider_streams_items() {
        let p = SequenceProvider::new(vec![
            StepChunk::Token("a".into()),
            StepChunk::Token("b".into()),
        ]);
        let req = CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap();
        let s = p.stream_step(req).await.unwrap();
        let items: Vec<_> = s.collect().await;
        assert_eq!(items.len(), 2);
    }
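
    // A minimal sketch of driving `FixedProvider` through the `tower` service
    // interface via `ServiceExt::oneshot`. The assistant message is built from
    // JSON for brevity; this assumes `serde_json` is available as a
    // dev-dependency and that `ChatCompletionResponseMessage` keeps its serde
    // `Deserialize` derive.
    #[tokio::test]
    async fn fixed_provider_returns_canned_response() {
        use tower::ServiceExt;

        let assistant: ChatCompletionResponseMessage = serde_json::from_value(
            serde_json::json!({ "role": "assistant", "content": "ok" }),
        )
        .unwrap();
        let out = ProviderResponse {
            assistant,
            prompt_tokens: 3,
            completion_tokens: 1,
        };

        let req = CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap();
        let resp = FixedProvider::new(out).oneshot(req).await.unwrap();
        assert_eq!(resp.prompt_tokens, 3);
        assert_eq!(resp.completion_tokens, 1);
    }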
}