strands_agents/models/
mod.rs

//! Model traits and implementations.
//!
//! All model providers are always compiled (no feature flags required).

pub mod anthropic;
pub mod bedrock;
pub mod gemini;
pub mod litellm;
pub mod llamaapi;
pub mod llamacpp;
pub mod mistral;
pub mod ollama;
pub mod openai;
pub mod sagemaker;
pub mod validation;
pub mod writer;

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::types::{
    content::{Message, SystemContentBlock},
    errors::StrandsError,
    streaming::StreamEvent,
    tools::{ToolChoice, ToolSpec},
};

/// Configuration for a model.
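///
/// # Examples
///
/// A minimal builder sketch; the model ID is a placeholder and the crate
/// path is assumed from this file's location:
///
/// ```
/// use strands_agents::models::ModelConfig;
///
/// let config = ModelConfig::new("example-model-id")
///     .with_max_tokens(1024)
///     .with_temperature(0.7);
/// assert_eq!(config.max_tokens, Some(1024));
/// ```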
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
    /// Provider-specific model identifier.
    pub model_id: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus sampling (top-p) probability mass.
    pub top_p: Option<f32>,
    /// Sequences that cause generation to stop.
    pub stop_sequences: Option<Vec<String>>,
    /// Additional provider-specific parameters.
    pub additional: std::collections::HashMap<String, serde_json::Value>,
}

impl ModelConfig {
    /// Creates a configuration for the given model ID.
    pub fn new(model_id: impl Into<String>) -> Self {
        Self { model_id: model_id.into(), ..Default::default() }
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the nucleus sampling (top-p) value.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }
}

/// A stream of model response events.
pub type StreamEventStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StreamEvent, StrandsError>> + Send + 'a>>;

/// Trait for model implementations.
#[async_trait]
pub trait Model: Send + Sync {
    /// Returns the model configuration.
    fn config(&self) -> &ModelConfig;

    /// Updates the model configuration.
    fn update_config(&mut self, config: ModelConfig);

    /// Streams a response from the model.
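    ///
    /// # Examples
    ///
    /// A consumption sketch; the `use` paths are assumed from this file's
    /// location:
    ///
    /// ```no_run
    /// # use futures::StreamExt;
    /// # use strands_agents::models::Model;
    /// # use strands_agents::types::{content::Message, errors::StrandsError};
    /// # async fn consume(model: &dyn Model, messages: &[Message]) -> Result<(), StrandsError> {
    /// let mut stream = model.stream(messages, None, Some("You are concise."), None, None);
    /// while let Some(event) = stream.next().await {
    ///     let event = event?;
    ///     if let Some(text) = event.as_text_delta() {
    ///         print!("{text}");
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```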
    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        tool_choice: Option<ToolChoice>,
        system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a>;
}

/// Extension trait for models with additional functionality.
#[async_trait]
pub trait ModelExt: Model {
    /// Generates a structured output from the model.
    ///
    /// The default implementation streams the response as plain text and
    /// parses the accumulated output as JSON; it does not send `T`'s schema
    /// to the provider. Providers with native structured-output support can
    /// override this.
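    ///
    /// # Examples
    ///
    /// A sketch of requesting a typed response; the `use` paths are assumed
    /// from this file's location:
    ///
    /// ```no_run
    /// # use serde::Deserialize;
    /// # use schemars::JsonSchema;
    /// # use strands_agents::models::{Model, ModelExt};
    /// # use strands_agents::types::{content::Message, errors::StrandsError};
    /// #[derive(Deserialize, JsonSchema)]
    /// struct Weather {
    ///     temperature_c: f32,
    ///     conditions: String,
    /// }
    ///
    /// # async fn run(model: &impl Model, messages: &[Message]) -> Result<(), StrandsError> {
    /// let weather: Weather = model.structured_output(messages, None).await?;
    /// println!("{}, {} degrees C", weather.conditions, weather.temperature_c);
    /// # Ok(())
    /// # }
    /// ```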
    async fn structured_output<T>(
        &self,
        messages: &[Message],
        system_prompt: Option<&str>,
    ) -> Result<T, StrandsError>
    where
        T: serde::de::DeserializeOwned + schemars::JsonSchema + Send,
    {
        use futures::StreamExt;

        let mut content = String::new();
        let mut stream = self.stream(messages, None, system_prompt, None, None);

        while let Some(event) = stream.next().await {
            let event = event?;
            if let Some(text) = event.as_text_delta() {
                content.push_str(text);
            }
        }

        serde_json::from_str(&content).map_err(|e| StrandsError::StructuredOutputError {
            message: format!("Failed to parse structured output: {e}"),
        })
    }
}

impl<T: Model> ModelExt for T {}

pub use anthropic::AnthropicModel;
pub use bedrock::BedrockModel;
pub use gemini::{GeminiConfig, GeminiModel};
pub use litellm::{LiteLLMConfig, LiteLLMModel};
pub use llamaapi::{LlamaAPIConfig, LlamaAPIModel};
pub use llamacpp::{LlamaCppConfig, LlamaCppModel};
pub use mistral::{MistralConfig, MistralModel};
pub use ollama::OllamaModel;
pub use openai::OpenAIModel;
pub use sagemaker::{SageMakerEndpointConfig, SageMakerModel, SageMakerPayloadConfig};
pub use validation::{config_keys, validate_config_keys, warn_on_tool_choice_not_supported};
pub use writer::{WriterConfig, WriterModel};