rusty_commit/providers/azure.rs

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::{header, Client};
use serde::{Deserialize, Serialize};

use super::prompt::split_prompt;
use super::AIProvider;
use crate::config::Config;

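/// Provider for Azure OpenAI: sends chat-completion requests to a named deployment on a configured endpoint.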
pub struct AzureProvider {
    client: Client,
    api_key: String,
    endpoint: String,
    deployment: String,
}

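/// Request body for the Azure OpenAI chat completions API.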
#[derive(Serialize)]
struct AzureRequest {
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
}

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

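/// Minimal view of the chat completions response; only the fields we read are deserialized.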
#[derive(Deserialize)]
struct AzureResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}

impl AzureProvider {
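    /// Builds a provider from the global config. Requires an API key and endpoint;
    /// falls back to the "gpt-35-turbo" deployment when no model is configured.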
    pub fn new(config: &Config) -> Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .context("Azure API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
            .clone();

        let endpoint = config
            .api_url
            .as_ref()
            .context(
                "Azure endpoint not configured. Run: rco config set RCO_API_URL=<your_endpoint>",
            )?
            .clone();

        let deployment = config
            .model
            .as_deref()
            .unwrap_or("gpt-35-turbo")
            .to_string();

        let client = Client::new();

        Ok(Self {
            client,
            api_key,
            endpoint,
            deployment,
        })
    }

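    /// Builds a provider for a specific account, preferring the account's model
    /// over the global config before falling back to "gpt-35-turbo".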
    #[allow(dead_code)]
    pub fn from_account(
        account: &crate::config::accounts::AccountConfig,
        api_key: &str,
        config: &Config,
    ) -> Result<Self> {
        let endpoint = account
            .api_url
            .as_ref()
            .context(
                "Azure endpoint required. Set with: rco config set RCO_API_URL=<your_endpoint>",
            )?
            .clone();

        let deployment = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("gpt-35-turbo")
            .to_string();

        let client = Client::new();

        Ok(Self {
            client,
            api_key: api_key.to_string(),
            endpoint,
            deployment,
        })
    }
}

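/// Generates a commit message by posting the split system/user prompt to the deployment's chat completions endpoint.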
#[async_trait]
impl AIProvider for AzureProvider {
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);

        let request = AzureRequest {
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: system_prompt,
                },
                Message {
                    role: "user".to_string(),
                    content: user_prompt,
                },
            ],
            max_tokens: config.tokens_max_output.unwrap_or(500),
            temperature: 0.7,
        };

        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
            self.endpoint, self.deployment
        );

        let response = self
            .client
            .post(&url)
            .header("api-key", &self.api_key)
            .header(header::CONTENT_TYPE, "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to connect to Azure OpenAI")?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            anyhow::bail!("Azure OpenAI API error: {}", error_text);
        }

        let azure_response: AzureResponse = response
            .json()
            .await
            .context("Failed to parse Azure OpenAI response")?;

        let message = azure_response
            .choices
            .first()
            .map(|c| c.message.content.trim().to_string())
            .context("No response from Azure OpenAI")?;

        Ok(message)
    }
}

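/// Registers this provider under the name "azure" (alias "azure-openai").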
pub struct AzureProviderBuilder;

impl super::registry::ProviderBuilder for AzureProviderBuilder {
    fn name(&self) -> &'static str {
        "azure"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["azure-openai"]
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(AzureProvider::new(config)?))
    }

    fn requires_api_key(&self) -> bool {
        true
    }

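    // Default advertised by the registry; note that `new()` itself falls back to
    // "gpt-35-turbo" when no model is configured.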
    fn default_model(&self) -> Option<&'static str> {
        Some("gpt-4o")
    }
}