// serdes_ai/direct.rs

//! Direct model request functions.
//!
//! These functions allow making imperative requests to models with minimal abstraction.
//! The only abstraction is input/output schema translation for unified API access.
//!
//! Use these when you want simple, direct access to models without the full agent
//! infrastructure. Great for one-off queries, scripts, and simple integrations.
//!
//! # Examples
//!
//! ## Non-streaming request
//!
//! ```rust,ignore
//! use serdes_ai::direct::model_request;
//! use serdes_ai_core::ModelRequest;
//!
//! let response = model_request(
//!     "openai:gpt-4o",
//!     &[ModelRequest::user("What is the capital of France?")],
//!     None,
//!     None,
//! ).await?;
//!
//! println!("{}", response.text());
//! ```
//!
//! ## Streaming request
//!
//! ```rust,ignore
//! use serdes_ai::direct::model_request_stream;
//! use serdes_ai_core::ModelRequest;
//! use futures::StreamExt;
//!
//! let mut stream = model_request_stream(
//!     "anthropic:claude-3-5-sonnet",
//!     &[ModelRequest::user("Write a poem")],
//!     None,
//!     None,
//! ).await?;
//!
//! while let Some(event) = stream.next().await {
//!     // Handle streaming events
//! }
//! ```
//!
//! ## Using a pre-built model instance
//!
//! ```rust,ignore
//! use serdes_ai::direct::{model_request, ModelSpec};
//! use serdes_ai_core::ModelRequest;
//! use serdes_ai_models::openai::OpenAIChatModel;
//!
//! let model = OpenAIChatModel::from_env("gpt-4o")?;
//! let response = model_request(
//!     ModelSpec::from_model(model),
//!     &[ModelRequest::user("Hello!")],
//!     None,
//!     None,
//! ).await?;
//! ```

use std::sync::Arc;

use futures::StreamExt;
use serdes_ai_core::{
    messages::ModelResponseStreamEvent, ModelRequest, ModelResponse, ModelSettings,
};
use serdes_ai_models::{BoxedModel, Model, ModelError, ModelRequestParameters, StreamedResponse};
use thiserror::Error;

// ============================================================================
// Error Type
// ============================================================================

/// Error type for direct requests.
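///
/// A minimal sketch of handling the error, assuming `messages` is already in
/// scope and the calling function returns a compatible error type:
///
/// ```rust,ignore
/// use serdes_ai::direct::{model_request, DirectError};
///
/// match model_request("openai:gpt-4o", &messages, None, None).await {
///     Ok(response) => println!("{}", response.text()),
///     Err(DirectError::ProviderNotAvailable(provider)) => {
///         eprintln!("enable the `{provider}` feature to use this model");
///     }
///     Err(other) => return Err(other.into()),
/// }
/// ```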
#[derive(Debug, Error)]
pub enum DirectError {
    /// Invalid model name format.
    #[error("Invalid model name: {0}")]
    InvalidModelName(String),

    /// Model-level error (API, network, etc.).
    #[error("Model error: {0}")]
    ModelError(#[from] ModelError),

    /// Runtime error (e.g., a sync function called from an async context).
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// Provider not available (feature not enabled).
    #[error("Provider not available: {0}. Enable the corresponding feature.")]
    ProviderNotAvailable(String),
}

// ============================================================================
// Model Specification
// ============================================================================

/// Model specification - either a string like "openai:gpt-4o" or a Model instance.
///
/// This allows flexible model specification in the direct API functions.
///
/// # Examples
///
/// ```rust,ignore
/// // From string
/// let spec: ModelSpec = "openai:gpt-4o".into();
///
/// // From model instance
/// let model = OpenAIChatModel::from_env("gpt-4o")?;
/// let spec = ModelSpec::from_model(model);
/// ```
#[derive(Clone)]
pub enum ModelSpec {
    /// Model specified by name (e.g., "openai:gpt-4o").
    Name(String),
    /// Pre-built model instance.
    Instance(BoxedModel),
}

impl From<&str> for ModelSpec {
    fn from(s: &str) -> Self {
        ModelSpec::Name(s.to_string())
    }
}

impl From<String> for ModelSpec {
    fn from(s: String) -> Self {
        ModelSpec::Name(s)
    }
}

impl From<BoxedModel> for ModelSpec {
    fn from(model: BoxedModel) -> Self {
        ModelSpec::Instance(model)
    }
}

impl ModelSpec {
    /// Create a ModelSpec from any concrete Model type.
    ///
    /// This is a convenience method for wrapping concrete model types.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use serdes_ai::direct::ModelSpec;
    /// use serdes_ai_models::openai::OpenAIChatModel;
    ///
    /// let model = OpenAIChatModel::from_env("gpt-4o")?;
    /// let spec = ModelSpec::from_model(model);
    /// ```
    pub fn from_model<M: Model + 'static>(model: M) -> Self {
        ModelSpec::Instance(Arc::new(model))
    }

    /// Resolve the spec into a concrete model instance.
    fn resolve(self) -> Result<BoxedModel, DirectError> {
        match self {
            ModelSpec::Name(name) => parse_model_name(&name),
            ModelSpec::Instance(model) => Ok(model),
        }
    }
}

// ============================================================================
// Non-Streaming Requests
// ============================================================================

/// Make a non-streamed request to a model.
///
/// This is the simplest way to get a response from a model. The returned future
/// resolves once the full response is available.
///
/// # Arguments
///
/// * `model` - Model specification (string like "openai:gpt-4o" or a Model instance)
/// * `messages` - Slice of request messages
/// * `model_settings` - Optional model settings (temperature, max_tokens, etc.)
/// * `model_request_parameters` - Optional request parameters (tools, output schema, etc.)
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request;
/// use serdes_ai_core::ModelRequest;
///
/// let response = model_request(
///     "openai:gpt-4o",
///     &[ModelRequest::user("What is the capital of France?")],
///     None,
///     None,
/// ).await?;
///
/// println!("{}", response.text());
/// ```
pub async fn model_request(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();

    let response = model.request(messages, &settings, &params).await?;
    Ok(response)
}

/// Make a synchronous (blocking) non-streamed request.
///
/// This wraps `model_request` with a tokio runtime. It creates a new runtime
/// for each call, so it's not the most efficient for high-throughput scenarios.
///
/// # Warning
///
/// Cannot be used inside async code; calling it from an async context returns a
/// `DirectError::RuntimeError`. Use `model_request` instead in async contexts.
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_sync;
/// use serdes_ai_core::ModelRequest;
///
/// fn main() {
///     let response = model_request_sync(
///         "openai:gpt-4o",
///         &[ModelRequest::user("Hello!")],
///         None,
///         None,
///     ).unwrap();
///
///     println!("{}", response.text());
/// }
/// ```
pub fn model_request_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    // Check if we're already in an async context
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_sync cannot be called from async context. Use model_request instead."
                .to_string(),
        ));
    }

    // Create a new runtime for the blocking call
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

    // Clone what we need since we can't move references
    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    let settings = model_settings;
    let params = model_request_parameters;

    rt.block_on(async move { model_request(model_spec, &messages_owned, settings, params).await })
}

// ============================================================================
// Streaming Requests
// ============================================================================

/// Make a streaming request to a model.
///
/// Returns a stream of response events that can be processed as they arrive.
/// This is useful for real-time output and long responses.
///
/// # Arguments
///
/// * `model` - Model specification (string like "openai:gpt-4o" or a Model instance)
/// * `messages` - Slice of request messages
/// * `model_settings` - Optional model settings (temperature, max_tokens, etc.)
/// * `model_request_parameters` - Optional request parameters (tools, output schema, etc.)
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_stream;
/// use serdes_ai_core::ModelRequest;
/// use serdes_ai_core::messages::ModelResponseStreamEvent;
/// use futures::StreamExt;
///
/// let mut stream = model_request_stream(
///     "anthropic:claude-3-5-sonnet",
///     &[ModelRequest::user("Write a poem about Rust")],
///     None,
///     None,
/// ).await?;
///
/// while let Some(event) = stream.next().await {
///     match event? {
///         ModelResponseStreamEvent::PartDelta(delta) => {
///             if let Some(text) = delta.delta.content_delta() {
///                 print!("{}", text);
///             }
///         }
///         _ => {}
///     }
/// }
/// ```
pub async fn model_request_stream(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();

    let stream = model.request_stream(messages, &settings, &params).await?;
    Ok(stream)
}

/// Synchronous streaming request wrapper.
///
/// This struct wraps a streaming response and provides a synchronous iterator
/// interface for consuming streaming events.
///
/// # Warning
///
/// Do not iterate this from inside an async context: each call to `next()` blocks
/// on an internal runtime and will panic if a tokio runtime is already running.
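///
/// # Example
///
/// A minimal sketch of draining the iterator from a plain synchronous function:
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_stream_sync;
/// use serdes_ai_core::ModelRequest;
///
/// let stream = model_request_stream_sync(
///     "openai:gpt-4o",
///     &[ModelRequest::user("Tell me a story")],
///     None,
///     None,
/// )?;
///
/// for event in stream {
///     let event = event?;
///     // Inspect each ModelResponseStreamEvent as it arrives
/// }
/// ```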
pub struct StreamedResponseSync {
    /// The underlying async runtime.
    runtime: tokio::runtime::Runtime,
    /// The underlying async stream.
    stream: Option<StreamedResponse>,
}

impl StreamedResponseSync {
    /// Create a new sync wrapper around an async stream.
    fn new(stream: StreamedResponse) -> Result<Self, DirectError> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

        Ok(Self {
            runtime,
            stream: Some(stream),
        })
    }
}

impl Iterator for StreamedResponseSync {
    type Item = Result<ModelResponseStreamEvent, ModelError>;

    fn next(&mut self) -> Option<Self::Item> {
        let stream = self.stream.as_mut()?;
        self.runtime.block_on(stream.next())
    }
}

/// Synchronous streaming request.
///
/// This creates a streaming request and wraps it in a synchronous iterator.
///
/// # Warning
///
/// Cannot be used inside async code; calling it from an async context returns a
/// `DirectError::RuntimeError`. Use `model_request_stream` instead in async contexts.
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_stream_sync;
/// use serdes_ai_core::ModelRequest;
///
/// fn main() {
///     let stream = model_request_stream_sync(
///         "openai:gpt-4o",
///         &[ModelRequest::user("Tell me a story")],
///         None,
///         None,
///     ).unwrap();
///
///     for event in stream {
///         // Handle each event
///     }
/// }
/// ```
pub fn model_request_stream_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponseSync, DirectError> {
    // Check if we're already in an async context
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_stream_sync cannot be called from async context. Use model_request_stream instead."
                .to_string(),
        ));
    }

    // Create a runtime to set up the stream
    let setup_rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    let settings = model_settings;
    let params = model_request_parameters;

    let stream = setup_rt.block_on(async move {
        model_request_stream(model_spec, &messages_owned, settings, params).await
    })?;

    // Drop the setup runtime and create the iterator with its own runtime
    drop(setup_rt);

    StreamedResponseSync::new(stream)
}

// ============================================================================
// Model Parsing
// ============================================================================

/// Parse a model name like "openai:gpt-4o" into a model instance.
///
/// Supported formats:
/// - `provider:model_name` (e.g., "openai:gpt-4o", "anthropic:claude-3-5-sonnet")
/// - `model_name` (defaults to OpenAI)
///
/// Available providers (when their features are enabled):
/// - `openai` / `gpt`: OpenAI models
/// - `anthropic` / `claude`: Anthropic Claude models
/// - `groq`: Groq fast inference
/// - `mistral`: Mistral AI models
/// - `ollama`: Local Ollama models
/// - `bedrock` / `aws`: AWS Bedrock models
/// - `openrouter` / `or`: OpenRouter multi-provider
/// - `huggingface` / `hf`: HuggingFace Inference API
/// - `cohere` / `co`: Cohere models
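///
/// # Example
///
/// A minimal sketch (requires the matching provider feature and credentials):
///
/// ```rust,ignore
/// let model = parse_model_name("anthropic:claude-3-5-sonnet")?;
/// let default_provider = parse_model_name("gpt-4o")?; // falls back to OpenAI
/// ```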
fn parse_model_name(name: &str) -> Result<BoxedModel, DirectError> {
    // Use the infer_model function from serdes-ai-models
    #[cfg(feature = "openai")]
    {
        serdes_ai_models::infer_model(name).map_err(DirectError::ModelError)
    }

    #[cfg(not(feature = "openai"))]
    {
        // Without openai feature, we need manual parsing
        let (provider, model_name) = if name.contains(':') {
            let parts: Vec<&str> = name.splitn(2, ':').collect();
            (parts[0], parts[1])
        } else {
            return Err(DirectError::InvalidModelName(format!(
                "Model name '{}' requires a provider prefix (e.g., 'anthropic:{}') \
                 when the 'openai' feature is not enabled.",
                name, name
            )));
        };

        match provider {
            #[cfg(feature = "anthropic")]
            "anthropic" | "claude" => {
                let model = serdes_ai_models::AnthropicModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "groq")]
            "groq" => {
                let model = serdes_ai_models::GroqModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "mistral")]
            "mistral" => {
                let model = serdes_ai_models::MistralModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "ollama")]
            "ollama" => {
                let model = serdes_ai_models::OllamaModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "bedrock")]
            "bedrock" | "aws" => {
                let model = serdes_ai_models::BedrockModel::new(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            _ => Err(DirectError::ProviderNotAvailable(provider.to_string())),
        }
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_spec_from_str() {
        let spec: ModelSpec = "openai:gpt-4o".into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "openai:gpt-4o"));
    }

    #[test]
    fn test_model_spec_from_string() {
        let spec: ModelSpec = String::from("anthropic:claude-3").into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "anthropic:claude-3"));
    }

    #[test]
    fn test_direct_error_display() {
        let err = DirectError::InvalidModelName("bad-model".to_string());
        assert!(err.to_string().contains("bad-model"));

        let err = DirectError::ProviderNotAvailable("unknown".to_string());
        assert!(err.to_string().contains("unknown"));

        let err = DirectError::RuntimeError("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    #[test]
    fn test_sync_runtime_detection() {
        // In a plain sync context there is no ambient tokio runtime, so the
        // async-context guard in the `*_sync` functions does not trigger. We
        // avoid making a real request here because it would need API keys;
        // the guard itself is exercised in the test below.
    }
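
    // A sketch of exercising the async-context guard. Assumes tokio's `rt` and
    // `macros` features are available for tests; adjust if the dev-dependency
    // setup differs.
    #[tokio::test]
    async fn test_sync_rejects_async_context() {
        // Inside a tokio runtime, the sync wrappers should refuse to run and
        // return a RuntimeError instead of panicking or blocking.
        let result = model_request_sync("openai:gpt-4o", &[], None, None);
        assert!(matches!(result, Err(DirectError::RuntimeError(_))));

        let stream_result = model_request_stream_sync("openai:gpt-4o", &[], None, None);
        assert!(matches!(stream_result, Err(DirectError::RuntimeError(_))));
    }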
}