strands_agents/models/
sagemaker.rs

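//! Amazon SageMaker model provider.
//!
//! This module defines the endpoint and payload configuration types and the
//! `SageMakerModel` provider for Amazon SageMaker endpoints. Endpoint invocation is
//! not implemented yet: `SageMakerModel::stream` currently yields a single
//! `StrandsError::ModelError` instead of model output.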
4
5use std::collections::HashMap;
6
7use crate::types::content::{Message, SystemContentBlock};
8use crate::types::errors::StrandsError;
9use crate::types::tools::{ToolChoice, ToolSpec};
10
11use super::{Model, ModelConfig, StreamEventStream};
12
13/// Endpoint configuration for SageMaker.
14#[derive(Debug, Clone, Default)]
15pub struct SageMakerEndpointConfig {
16    /// The name of the SageMaker endpoint to invoke.
17    pub endpoint_name: String,
18    /// AWS region name.
19    pub region_name: Option<String>,
20    /// The name of the inference component to use.
21    pub inference_component_name: Option<String>,
22    /// Target model for multi-model endpoints.
23    pub target_model: Option<String>,
24    /// Target variant.
25    pub target_variant: Option<String>,
26    /// Additional arguments for the request.
27    pub additional_args: Option<HashMap<String, serde_json::Value>>,
28}
29
30impl SageMakerEndpointConfig {
31    /// Create a new endpoint config.
32    pub fn new(endpoint_name: impl Into<String>) -> Self {
33        Self {
34            endpoint_name: endpoint_name.into(),
35            ..Default::default()
36        }
37    }
38
39    /// Set region name.
40    pub fn with_region(mut self, region: impl Into<String>) -> Self {
41        self.region_name = Some(region.into());
42        self
43    }
44
45    /// Set inference component name.
46    pub fn with_inference_component(mut self, component: impl Into<String>) -> Self {
47        self.inference_component_name = Some(component.into());
48        self
49    }
50
51    /// Set target model.
52    pub fn with_target_model(mut self, model: impl Into<String>) -> Self {
53        self.target_model = Some(model.into());
54        self
55    }
56}
57
58/// Payload configuration for SageMaker.
59#[derive(Debug, Clone, Default)]
60pub struct SageMakerPayloadConfig {
61    /// Maximum number of tokens to generate.
62    pub max_tokens: Option<u32>,
63    /// Whether to stream the response.
64    pub stream: bool,
65    /// Sampling temperature.
66    pub temperature: Option<f64>,
67    /// Top-p (nucleus sampling).
68    pub top_p: Option<f64>,
69    /// Top-k sampling.
70    pub top_k: Option<u32>,
71    /// Stop sequences.
72    pub stop: Option<Vec<String>>,
73    /// Convert tool results to user messages.
74    pub tool_results_as_user_messages: bool,
75    /// Additional arguments.
76    pub additional_args: Option<HashMap<String, serde_json::Value>>,
77}
78
79impl SageMakerPayloadConfig {
80    /// Create a new payload config.
81    pub fn new() -> Self {
82        Self {
83            stream: true,
84            ..Default::default()
85        }
86    }
87
88    /// Set max tokens.
89    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
90        self.max_tokens = Some(max_tokens);
91        self
92    }
93
94    /// Set temperature.
95    pub fn with_temperature(mut self, temperature: f64) -> Self {
96        self.temperature = Some(temperature);
97        self
98    }
99
100    /// Set streaming mode.
101    pub fn with_stream(mut self, stream: bool) -> Self {
102        self.stream = stream;
103        self
104    }
105}
106
107/// Amazon SageMaker model provider implementation.
108pub struct SageMakerModel {
109    config: ModelConfig,
110    endpoint_config: SageMakerEndpointConfig,
111    payload_config: SageMakerPayloadConfig,
112}
113
114impl SageMakerModel {
115    /// Create a new SageMaker model.
116    pub fn new(endpoint_config: SageMakerEndpointConfig, payload_config: SageMakerPayloadConfig) -> Self {
117        Self {
118            config: ModelConfig::new(&endpoint_config.endpoint_name),
119            endpoint_config,
120            payload_config,
121        }
122    }
123
124    /// Get the endpoint configuration.
125    pub fn endpoint_config(&self) -> &SageMakerEndpointConfig {
126        &self.endpoint_config
127    }
128
129    /// Get the payload configuration.
130    pub fn payload_config(&self) -> &SageMakerPayloadConfig {
131        &self.payload_config
132    }
133
134    /// Update the endpoint configuration.
135    pub fn update_endpoint_config(&mut self, config: SageMakerEndpointConfig) {
136        self.config = ModelConfig::new(&config.endpoint_name);
137        self.endpoint_config = config;
138    }
139
140    /// Update the payload configuration.
141    pub fn update_payload_config(&mut self, config: SageMakerPayloadConfig) {
142        self.payload_config = config;
143    }
144}
145
146impl Model for SageMakerModel {
147    fn config(&self) -> &ModelConfig {
148        &self.config
149    }
150
151    fn update_config(&mut self, config: ModelConfig) {
152        self.config = config;
153    }
154
155    fn stream<'a>(
156        &'a self,
157        _messages: &'a [Message],
158        _tool_specs: Option<&'a [ToolSpec]>,
159        _system_prompt: Option<&'a str>,
160        _tool_choice: Option<ToolChoice>,
161        _system_prompt_content: Option<&'a [SystemContentBlock]>,
162    ) -> StreamEventStream<'a> {
163        Box::pin(futures::stream::once(async {
164            Err(StrandsError::ModelError {
165                message: "SageMaker integration requires aws-sdk-sagemakerruntime implementation".into(),
166                source: None,
167            })
168        }))
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// The endpoint builders store the values they are given.
    #[test]
    fn test_sagemaker_endpoint_config() {
        let endpoint = SageMakerEndpointConfig::new("my-endpoint")
            .with_region("us-west-2")
            .with_target_model("my-model");

        assert_eq!(endpoint.endpoint_name, "my-endpoint");
        assert_eq!(endpoint.region_name.as_deref(), Some("us-west-2"));
        assert_eq!(endpoint.target_model.as_deref(), Some("my-model"));
    }

    /// The payload builders store the values they are given, and `new` enables streaming.
    #[test]
    fn test_sagemaker_payload_config() {
        let payload = SageMakerPayloadConfig::new()
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(payload.max_tokens, Some(1000));
        assert_eq!(payload.temperature, Some(0.7));
        assert!(payload.stream);
    }

    /// A new model takes its model id from the endpoint name.
    #[test]
    fn test_sagemaker_model_creation() {
        let model = SageMakerModel::new(
            SageMakerEndpointConfig::new("test-endpoint"),
            SageMakerPayloadConfig::new(),
        );

        assert_eq!(model.config().model_id, "test-endpoint");
    }
}