Skip to main content

rusty_commit/providers/
bedrock.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use aws_config::Region;
4use aws_sdk_bedrockruntime as bedrock;
5use aws_sdk_bedrockruntime::types::{ContentBlock, SystemContentBlock};
6
7use super::{split_prompt, AIProvider};
8use crate::config::Config;
9
10pub struct BedrockProvider {
11    client: bedrock::Client,
12    model: String,
13    #[allow(dead_code)]
14    region: String,
15}
16
17impl BedrockProvider {
18    pub fn new(config: &Config) -> Result<Self> {
19        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
20        rt.block_on(async { Self::new_async(config).await })
21    }
22
23    async fn new_async(config: &Config) -> Result<Self> {
24        let region = config
25            .api_url
26            .as_ref()
27            .and_then(|url| {
28                url.split("bedrock.")
29                    .nth(1)
30                    .and_then(|s| s.split('.').next())
31                    .map(|s| s.to_string())
32            })
33            .unwrap_or_else(|| {
34                std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string())
35            });
36
37        let region_provider = Region::new(region.clone());
38        let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
39            .region(region_provider)
40            .load()
41            .await;
42
43        let client = bedrock::Client::new(&shared_config);
44
45        let model = config
46            .model
47            .as_deref()
48            .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
49            .to_string();
50
51        Ok(Self {
52            client,
53            model,
54            region,
55        })
56    }
57
58    #[allow(dead_code)]
59    pub fn from_account(
60        account: &crate::config::accounts::AccountConfig,
61        _api_key: &str,
62        config: &Config,
63    ) -> Result<Self> {
64        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
65        rt.block_on(async { Self::from_account_async(account, config).await })
66    }
67
68    async fn from_account_async(
69        account: &crate::config::accounts::AccountConfig,
70        config: &Config,
71    ) -> Result<Self> {
72        let region = account
73            .api_url
74            .as_ref()
75            .and_then(|url| {
76                url.split("bedrock.")
77                    .nth(1)
78                    .and_then(|s| s.split('.').next())
79                    .map(|s| s.to_string())
80            })
81            .unwrap_or_else(|| "us-east-1".to_string());
82
83        let region_provider = Region::new(region.clone());
84        let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
85            .region(region_provider)
86            .load()
87            .await;
88
89        let client = bedrock::Client::new(&shared_config);
90
91        let model = account
92            .model
93            .as_deref()
94            .or(config.model.as_deref())
95            .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
96            .to_string();
97
98        Ok(Self {
99            client,
100            model,
101            region,
102        })
103    }
104}
105
106#[async_trait]
107impl AIProvider for BedrockProvider {
108    async fn generate_commit_message(
109        &self,
110        diff: &str,
111        context: Option<&str>,
112        full_gitmoji: bool,
113        config: &Config,
114    ) -> Result<String> {
115        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
116
117        // Build system message
118        let system_block = SystemContentBlock::Text(system_prompt);
119
120        // Build user message
121        let user_content = ContentBlock::Text(user_prompt);
122        let user_message = bedrock::types::Message::builder()
123            .role(bedrock::types::ConversationRole::User)
124            .content(user_content)
125            .build()
126            .context("Failed to build user message")?;
127
128        let inference_config = bedrock::types::InferenceConfiguration::builder()
129            .max_tokens(config.tokens_max_output.unwrap_or(500) as i32)
130            .temperature(0.7)
131            .build();
132
133        let converse_output = self
134            .client
135            .converse()
136            .model_id(&self.model)
137            .messages(user_message)
138            .system(system_block)
139            .inference_config(inference_config)
140            .send()
141            .await
142            .context("Failed to communicate with Bedrock")?;
143
144        let message = converse_output
145            .output()
146            .and_then(|o| o.as_message().ok())
147            .context("No response from Bedrock")?;
148
149        let content = message
150            .content()
151            .first()
152            .and_then(|c| c.as_text().ok())
153            .context("Empty response from Bedrock")?;
154
155        Ok(content.trim().to_string())
156    }
157}
158
/// Stateless registry entry for the Bedrock provider; it carries no data and
/// exists only to implement the provider-builder trait.
pub struct BedrockProviderBuilder;
161
162impl super::registry::ProviderBuilder for BedrockProviderBuilder {
163    fn name(&self) -> &'static str {
164        "bedrock"
165    }
166
167    fn aliases(&self) -> Vec<&'static str> {
168        vec!["aws-bedrock", "amazon-bedrock"]
169    }
170
171    fn category(&self) -> super::registry::ProviderCategory {
172        super::registry::ProviderCategory::Cloud
173    }
174
175    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
176        Ok(Box::new(BedrockProvider::new(config)?))
177    }
178
179    fn requires_api_key(&self) -> bool {
180        false
181    }
182
183    fn default_model(&self) -> Option<&'static str> {
184        Some("anthropic.claude-3-5-sonnet-20241022-v2:0")
185    }
186}