strands_agents/models/
sagemaker.rs1use std::collections::HashMap;
6
7use crate::types::content::{Message, SystemContentBlock};
8use crate::types::errors::StrandsError;
9use crate::types::tools::{ToolChoice, ToolSpec};
10
11use super::{Model, ModelConfig, StreamEventStream};
12
13#[derive(Debug, Clone, Default)]
15pub struct SageMakerEndpointConfig {
16 pub endpoint_name: String,
18 pub region_name: Option<String>,
20 pub inference_component_name: Option<String>,
22 pub target_model: Option<String>,
24 pub target_variant: Option<String>,
26 pub additional_args: Option<HashMap<String, serde_json::Value>>,
28}
29
30impl SageMakerEndpointConfig {
31 pub fn new(endpoint_name: impl Into<String>) -> Self {
33 Self {
34 endpoint_name: endpoint_name.into(),
35 ..Default::default()
36 }
37 }
38
39 pub fn with_region(mut self, region: impl Into<String>) -> Self {
41 self.region_name = Some(region.into());
42 self
43 }
44
45 pub fn with_inference_component(mut self, component: impl Into<String>) -> Self {
47 self.inference_component_name = Some(component.into());
48 self
49 }
50
51 pub fn with_target_model(mut self, model: impl Into<String>) -> Self {
53 self.target_model = Some(model.into());
54 self
55 }
56}
57
58#[derive(Debug, Clone, Default)]
60pub struct SageMakerPayloadConfig {
61 pub max_tokens: Option<u32>,
63 pub stream: bool,
65 pub temperature: Option<f64>,
67 pub top_p: Option<f64>,
69 pub top_k: Option<u32>,
71 pub stop: Option<Vec<String>>,
73 pub tool_results_as_user_messages: bool,
75 pub additional_args: Option<HashMap<String, serde_json::Value>>,
77}
78
79impl SageMakerPayloadConfig {
80 pub fn new() -> Self {
82 Self {
83 stream: true,
84 ..Default::default()
85 }
86 }
87
88 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
90 self.max_tokens = Some(max_tokens);
91 self
92 }
93
94 pub fn with_temperature(mut self, temperature: f64) -> Self {
96 self.temperature = Some(temperature);
97 self
98 }
99
100 pub fn with_stream(mut self, stream: bool) -> Self {
102 self.stream = stream;
103 self
104 }
105}
106
107pub struct SageMakerModel {
109 config: ModelConfig,
110 endpoint_config: SageMakerEndpointConfig,
111 payload_config: SageMakerPayloadConfig,
112}
113
114impl SageMakerModel {
115 pub fn new(endpoint_config: SageMakerEndpointConfig, payload_config: SageMakerPayloadConfig) -> Self {
117 Self {
118 config: ModelConfig::new(&endpoint_config.endpoint_name),
119 endpoint_config,
120 payload_config,
121 }
122 }
123
124 pub fn endpoint_config(&self) -> &SageMakerEndpointConfig {
126 &self.endpoint_config
127 }
128
129 pub fn payload_config(&self) -> &SageMakerPayloadConfig {
131 &self.payload_config
132 }
133
134 pub fn update_endpoint_config(&mut self, config: SageMakerEndpointConfig) {
136 self.config = ModelConfig::new(&config.endpoint_name);
137 self.endpoint_config = config;
138 }
139
140 pub fn update_payload_config(&mut self, config: SageMakerPayloadConfig) {
142 self.payload_config = config;
143 }
144}
145
146impl Model for SageMakerModel {
147 fn config(&self) -> &ModelConfig {
148 &self.config
149 }
150
151 fn update_config(&mut self, config: ModelConfig) {
152 self.config = config;
153 }
154
155 fn stream<'a>(
156 &'a self,
157 _messages: &'a [Message],
158 _tool_specs: Option<&'a [ToolSpec]>,
159 _system_prompt: Option<&'a str>,
160 _tool_choice: Option<ToolChoice>,
161 _system_prompt_content: Option<&'a [SystemContentBlock]>,
162 ) -> StreamEventStream<'a> {
163 Box::pin(futures::stream::once(async {
164 Err(StrandsError::ModelError {
165 message: "SageMaker integration requires aws-sdk-sagemakerruntime implementation".into(),
166 source: None,
167 })
168 }))
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// The endpoint builder stores the name and each optional value it is given.
    #[test]
    fn test_sagemaker_endpoint_config() {
        let SageMakerEndpointConfig {
            endpoint_name,
            region_name,
            target_model,
            ..
        } = SageMakerEndpointConfig::new("my-endpoint")
            .with_region("us-west-2")
            .with_target_model("my-model");

        assert_eq!(endpoint_name, "my-endpoint");
        assert_eq!(region_name.as_deref(), Some("us-west-2"));
        assert_eq!(target_model.as_deref(), Some("my-model"));
    }

    /// The payload builder records its values and `new()` leaves streaming on.
    #[test]
    fn test_sagemaker_payload_config() {
        let SageMakerPayloadConfig {
            max_tokens,
            temperature,
            stream,
            ..
        } = SageMakerPayloadConfig::new()
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(max_tokens, Some(1000));
        assert_eq!(temperature, Some(0.7));
        assert!(stream);
    }

    /// A freshly built model exposes the endpoint name as its model id.
    #[test]
    fn test_sagemaker_model_creation() {
        let model = SageMakerModel::new(
            SageMakerEndpointConfig::new("test-endpoint"),
            SageMakerPayloadConfig::new(),
        );

        assert_eq!(model.config().model_id, "test-endpoint");
    }
}