tower_llm/provider/
mod.rs

//! Model provider abstractions (streaming and non-streaming)
//!
//! What this module provides (spec)
//! - An interface for LLM providers (OpenAI, local mocks) decoupled from the step logic
//!
//! Exports
//! - Models
//!   - `ModelRequest { messages, tools, temperature, max_tokens }`
//!   - `ModelResponse { assistant: AssistantMessage, usage: Usage }`
//! - Services
//!   - `ModelService: Service<ModelRequest, Response=ModelResponse, Error=BoxError>`
//!   - Implementations: `OpenAIProvider`, `MockProvider`
//! - Layers
//!   - `RequestMapLayer<S>`: `Service<RawChatRequest> -> Service<ModelRequest>` adapter
//! - Utils
//!   - Provider presets: `providers::openai(model)`, `providers::mock(script)`
//!
//! Implementation strategy
//! - `Step` is made generic over `M: ModelService`
//! - The OpenAI adapter maps our message types to async-openai types
//! - Tools are mapped to function specs on demand
//!
//! Composition
//! - For simple usage, keep `Step` with an OpenAI provider injected (see the example below)
//! - For testing, swap in a scripted provider (e.g. `FixedProvider`) that returns canned responses
//!
//! Testing strategy
//! - Unit tests for mapping correctness; round-trip assistant tool_calls through the adapter
//! - Integration tests: `Step` plus a mock provider to exercise tool routing and loop logic without touching the network
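//!
//! # Example
//!
//! A minimal sketch of the simple-usage path: calling `OpenAIProvider` directly as a
//! `tower::Service` via `ServiceExt::oneshot`. The model name and empty message list are
//! placeholders, an async context and `OPENAI_API_KEY` are assumed, and the snippet is
//! not compiled as a doctest.
//!
//! ```ignore
//! use std::sync::Arc;
//! use async_openai::{types::CreateChatCompletionRequestArgs, Client};
//! use tower::ServiceExt; // provides `oneshot`
//!
//! let provider = OpenAIProvider::new(Arc::new(Client::new()));
//! let req = CreateChatCompletionRequestArgs::default()
//!     .model("gpt-4o")
//!     .messages(vec![])
//!     .build()?;
//! let resp = provider.oneshot(req).await?;
//! println!("assistant: {:?}", resp.assistant.content);
//! ```
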
use std::future::Future;
use std::pin::Pin;

use async_openai::{
    config::OpenAIConfig,
    types::{ChatCompletionResponseMessage, CreateChatCompletionRequest},
    Client,
};
use futures::stream;
use futures::Stream;
use tower::BoxError;
use tracing::{debug, trace};

pub use crate::streaming::{StepChunk, StepProvider};

/// A provider that always yields a fixed sequence of chunks.
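///
/// Useful for streaming tests: every call to `stream_step` replays the same scripted chunks.
///
/// # Example (sketch, not doctested)
///
/// A minimal sketch; `request` stands in for any `CreateChatCompletionRequest` (the
/// provider ignores it) and the loop body is left to the caller.
///
/// ```ignore
/// use futures::StreamExt;
///
/// let provider = SequenceProvider::new(vec![
///     StepChunk::Token("hel".into()),
///     StepChunk::Token("lo".into()),
/// ]);
/// let mut stream = provider.stream_step(request).await?;
/// while let Some(chunk) = stream.next().await {
///     // forward each chunk (e.g. a token) to the caller
/// }
/// ```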
#[derive(Clone)]
pub struct SequenceProvider {
    items: Vec<StepChunk>,
}
impl SequenceProvider {
    pub fn new(items: Vec<StepChunk>) -> Self {
        Self { items }
    }
}

impl StepProvider for SequenceProvider {
    type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
    fn stream_step(
        &self,
        _req: CreateChatCompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
        // Replay the scripted chunks; the incoming request is ignored.
        let iter = stream::iter(self.items.clone());
        Box::pin(async move { Ok(Box::pin(iter) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>) })
    }
}

// =============================
// Non-streaming provider (ModelService)
// =============================

/// Assistant message plus token usage returned by a non-streaming provider.
#[derive(Debug, Clone)]
pub struct ProviderResponse {
    pub assistant: ChatCompletionResponseMessage,
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
}

/// Trait alias for non-streaming model services
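///
/// Blanket-implemented for any `tower::Service` with the matching request, response,
/// and error types, so callers can stay generic over `M: ModelService` and swap
/// `OpenAIProvider` for `FixedProvider` in tests.
///
/// # Example (sketch, not doctested)
///
/// `run_once` is a hypothetical helper (not part of this module) showing the intended bound.
///
/// ```ignore
/// use tower::ServiceExt;
///
/// async fn run_once<M>(
///     mut model: M,
///     req: CreateChatCompletionRequest,
/// ) -> Result<ProviderResponse, BoxError>
/// where
///     M: ModelService,
/// {
///     model.ready().await?.call(req).await
/// }
/// ```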
pub trait ModelService:
    tower::Service<CreateChatCompletionRequest, Response = ProviderResponse, Error = BoxError>
{
}
impl<T> ModelService for T where
    T: tower::Service<CreateChatCompletionRequest, Response = ProviderResponse, Error = BoxError>
{
}

/// OpenAI adapter for the non-streaming provider interface
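///
/// Wraps a shared `async_openai` client; cloning the provider clones the `Arc`, not the client.
///
/// # Example (sketch, not doctested)
///
/// A construction sketch. The API key string is a placeholder; `Client::new()` with the
/// default `OpenAIConfig` (which reads `OPENAI_API_KEY`) works as well.
///
/// ```ignore
/// use std::sync::Arc;
/// use async_openai::{config::OpenAIConfig, Client};
///
/// let config = OpenAIConfig::new().with_api_key("sk-...");
/// let provider = OpenAIProvider::new(Arc::new(Client::with_config(config)));
/// ```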
#[derive(Clone)]
pub struct OpenAIProvider {
    client: std::sync::Arc<Client<OpenAIConfig>>,
}
impl OpenAIProvider {
    pub fn new(client: std::sync::Arc<Client<OpenAIConfig>>) -> Self {
        Self { client }
    }
}

impl tower::Service<CreateChatCompletionRequest> for OpenAIProvider {
    type Response = ProviderResponse;
    type Error = BoxError;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
        let client = self.client.clone();

        // Log the actual model being sent to the OpenAI API
        debug!(
            model = ?req.model,
            temperature = ?req.temperature,
            messages_count = req.messages.len(),
            tools_count = req.tools.as_ref().map(|t| t.len()).unwrap_or(0),
            "OpenAIProvider sending request to API"
        );

        Box::pin(async move {
            let model_debug = format!("{:?}", req.model);
            let resp = client.chat().create(req).await.map_err(|e| {
                // Log API errors with model information
                debug!(
                    model = %model_debug,
                    error = %e,
                    "OpenAI API error"
                );
                e
            })?;
            let usage = resp.usage.unwrap_or_default();
            let choice = resp
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| "no choices".to_string())?;

            trace!(
                prompt_tokens = usage.prompt_tokens,
                completion_tokens = usage.completion_tokens,
                "OpenAI API response received"
            );

            Ok(ProviderResponse {
                assistant: choice.message,
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
            })
        })
    }
}

/// Fixed-response provider for tests
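///
/// Always returns a clone of the scripted `ProviderResponse`, regardless of the request.
///
/// # Example (sketch, not doctested)
///
/// A test sketch assuming `ChatCompletionResponseMessage` can be deserialized from a
/// minimal JSON payload (it is a serde type) and that `serde_json` is available as a
/// dev-dependency; `req` is any `CreateChatCompletionRequest`.
///
/// ```ignore
/// use tower::ServiceExt;
///
/// let assistant: ChatCompletionResponseMessage = serde_json::from_value(serde_json::json!({
///     "role": "assistant",
///     "content": "scripted reply"
/// }))?;
/// let provider = FixedProvider::new(ProviderResponse {
///     assistant,
///     prompt_tokens: 3,
///     completion_tokens: 5,
/// });
/// let resp = provider.oneshot(req).await?;
/// assert_eq!(resp.completion_tokens, 5);
/// ```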
#[derive(Clone)]
pub struct FixedProvider {
    output: ProviderResponse,
}
impl FixedProvider {
    pub fn new(output: ProviderResponse) -> Self {
        Self { output }
    }
}

impl tower::Service<CreateChatCompletionRequest> for FixedProvider {
    type Response = ProviderResponse;
    type Error = BoxError;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, _req: CreateChatCompletionRequest) -> Self::Future {
        let out = self.output.clone();
        Box::pin(async move { Ok(out) })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::CreateChatCompletionRequestArgs;
    use futures::StreamExt;

    #[tokio::test]
    async fn sequence_provider_streams_items() {
        let p = SequenceProvider::new(vec![
            StepChunk::Token("a".into()),
            StepChunk::Token("b".into()),
        ]);
        let req = CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap();
        let mut s = p.stream_step(req).await.unwrap();
        let items: Vec<_> = s.by_ref().collect().await;
        assert_eq!(items.len(), 2);
    }
}