Skip to main content

rusty_commit/providers/
azure.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::{Deserialize, Serialize};
5
6use super::prompt::split_prompt;
7use super::AIProvider;
8use crate::config::Config;
9
10pub struct AzureProvider {
11    client: Client,
12    api_key: String,
13    endpoint: String,
14    deployment: String,
15}
16
17#[derive(Serialize)]
18struct AzureRequest {
19    messages: Vec<Message>,
20    max_tokens: u32,
21    temperature: f32,
22}
23
24#[derive(Serialize)]
25struct Message {
26    role: String,
27    content: String,
28}
29
30#[derive(Deserialize)]
31struct AzureResponse {
32    choices: Vec<Choice>,
33}
34
35#[derive(Deserialize)]
36struct Choice {
37    message: ResponseMessage,
38}
39
40#[derive(Deserialize)]
41struct ResponseMessage {
42    content: String,
43}
44
45impl AzureProvider {
46    pub fn new(config: &Config) -> Result<Self> {
47        let api_key = config
48            .api_key
49            .as_ref()
50            .context("Azure API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
51            .clone();
52
53        let endpoint = config
54            .api_url
55            .as_ref()
56            .context(
57                "Azure endpoint not configured. Run: rco config set RCO_API_URL=<your_endpoint>",
58            )?
59            .clone();
60
61        let deployment = config
62            .model
63            .as_deref()
64            .unwrap_or("gpt-35-turbo")
65            .to_string();
66
67        let client = Client::new();
68
69        Ok(Self {
70            client,
71            api_key,
72            endpoint,
73            deployment,
74        })
75    }
76
77    /// Create provider from account configuration
78    #[allow(dead_code)]
79    pub fn from_account(
80        account: &crate::config::accounts::AccountConfig,
81        api_key: &str,
82        config: &Config,
83    ) -> Result<Self> {
84        let endpoint = account
85            .api_url
86            .as_ref()
87            .context(
88                "Azure endpoint required. Set with: rco config set RCO_API_URL=<your_endpoint>",
89            )?
90            .clone();
91
92        let deployment = account
93            .model
94            .as_deref()
95            .or(config.model.as_deref())
96            .unwrap_or("gpt-35-turbo")
97            .to_string();
98
99        let client = Client::new();
100
101        Ok(Self {
102            client,
103            api_key: api_key.to_string(),
104            endpoint,
105            deployment,
106        })
107    }
108}
109
110#[async_trait]
111impl AIProvider for AzureProvider {
112    async fn generate_commit_message(
113        &self,
114        diff: &str,
115        context: Option<&str>,
116        full_gitmoji: bool,
117        config: &Config,
118    ) -> Result<String> {
119        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
120
121        let request = AzureRequest {
122            messages: vec![
123                Message {
124                    role: "system".to_string(),
125                    content: system_prompt,
126                },
127                Message {
128                    role: "user".to_string(),
129                    content: user_prompt,
130                },
131            ],
132            max_tokens: config.tokens_max_output.unwrap_or(500),
133            temperature: 0.7,
134        };
135
136        let url = format!(
137            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
138            self.endpoint, self.deployment
139        );
140
141        let response = self
142            .client
143            .post(&url)
144            .header("api-key", &self.api_key)
145            .header(header::CONTENT_TYPE, "application/json")
146            .json(&request)
147            .send()
148            .await
149            .context("Failed to connect to Azure OpenAI")?;
150
151        if !response.status().is_success() {
152            let error_text = response.text().await?;
153            anyhow::bail!("Azure OpenAI API error: {}", error_text);
154        }
155
156        let azure_response: AzureResponse = response
157            .json()
158            .await
159            .context("Failed to parse Azure OpenAI response")?;
160
161        let message = azure_response
162            .choices
163            .first()
164            .map(|c| c.message.content.trim().to_string())
165            .context("No response from Azure OpenAI")?;
166
167        Ok(message)
168    }
169}
170
/// ProviderBuilder for Azure.
///
/// Registers the Azure OpenAI provider under `azure` (alias `azure-openai`).
#[derive(Debug, Default, Clone, Copy)]
pub struct AzureProviderBuilder;
173
174impl super::registry::ProviderBuilder for AzureProviderBuilder {
175    fn name(&self) -> &'static str {
176        "azure"
177    }
178
179    fn aliases(&self) -> Vec<&'static str> {
180        vec!["azure-openai"]
181    }
182
183    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
184        Ok(Box::new(AzureProvider::new(config)?))
185    }
186
187    fn requires_api_key(&self) -> bool {
188        true
189    }
190
191    fn default_model(&self) -> Option<&'static str> {
192        Some("gpt-4o")
193    }
194}