strands_agents/models/
mod.rs1pub mod anthropic;
6pub mod bedrock;
7pub mod validation;
8pub mod gemini;
9pub mod litellm;
10pub mod llamaapi;
11pub mod llamacpp;
12pub mod mistral;
13pub mod ollama;
14pub mod openai;
15pub mod sagemaker;
16pub mod writer;
17
18use std::pin::Pin;
19
20use async_trait::async_trait;
21use futures::Stream;
22
23use crate::types::{
24 content::{Message, SystemContentBlock},
25 errors::StrandsError,
26 streaming::StreamEvent,
27 tools::{ToolChoice, ToolSpec},
28};
29
30#[derive(Debug, Clone, Default)]
32pub struct ModelConfig {
33 pub model_id: String,
34 pub max_tokens: Option<u32>,
35 pub temperature: Option<f32>,
36 pub top_p: Option<f32>,
37 pub stop_sequences: Option<Vec<String>>,
38 pub additional: std::collections::HashMap<String, serde_json::Value>,
39}
40
41impl ModelConfig {
42 pub fn new(model_id: impl Into<String>) -> Self {
43 Self { model_id: model_id.into(), ..Default::default() }
44 }
45
46 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
47 self.max_tokens = Some(max_tokens);
48 self
49 }
50
51 pub fn with_temperature(mut self, temperature: f32) -> Self {
52 self.temperature = Some(temperature);
53 self
54 }
55
56 pub fn with_top_p(mut self, top_p: f32) -> Self {
57 self.top_p = Some(top_p);
58 self
59 }
60}
61
62pub type StreamEventStream<'a> = Pin<Box<dyn Stream<Item = Result<StreamEvent, StrandsError>> + Send + 'a>>;
64
65#[async_trait]
67pub trait Model: Send + Sync {
68 fn config(&self) -> &ModelConfig;
70
71 fn update_config(&mut self, config: ModelConfig);
73
74 fn stream<'a>(
76 &'a self,
77 messages: &'a [Message],
78 tool_specs: Option<&'a [ToolSpec]>,
79 system_prompt: Option<&'a str>,
80 tool_choice: Option<ToolChoice>,
81 system_prompt_content: Option<&'a [SystemContentBlock]>,
82 ) -> StreamEventStream<'a>;
83}
84
85#[async_trait]
87pub trait ModelExt: Model {
88 async fn structured_output<T>(
90 &self,
91 messages: &[Message],
92 system_prompt: Option<&str>,
93 ) -> Result<T, StrandsError>
94 where
95 T: serde::de::DeserializeOwned + schemars::JsonSchema + Send,
96 {
97 use futures::StreamExt;
98
99 let mut content = String::new();
100 let mut stream = self.stream(messages, None, system_prompt, None, None);
101
102 while let Some(event) = stream.next().await {
103 let event = event?;
104 if let Some(text) = event.as_text_delta() {
105 content.push_str(text);
106 }
107 }
108
109 serde_json::from_str(&content).map_err(|e| StrandsError::StructuredOutputError {
110 message: format!("Failed to parse structured output: {e}"),
111 })
112 }
113}
114
115impl<T: Model> ModelExt for T {}
116
117pub use anthropic::AnthropicModel;
118pub use bedrock::BedrockModel;
119pub use gemini::{GeminiConfig, GeminiModel};
120pub use litellm::{LiteLLMConfig, LiteLLMModel};
121pub use llamaapi::{LlamaAPIConfig, LlamaAPIModel};
122pub use llamacpp::{LlamaCppConfig, LlamaCppModel};
123pub use mistral::{MistralConfig, MistralModel};
124pub use ollama::OllamaModel;
125pub use openai::OpenAIModel;
126pub use sagemaker::{SageMakerEndpointConfig, SageMakerModel, SageMakerPayloadConfig};
127pub use writer::{WriterConfig, WriterModel};
128pub use validation::{
129 config_keys, validate_config_keys, warn_on_tool_choice_not_supported,
130};