rusty_commit/providers/
openai.rs1use anyhow::{Context, Result};
2use async_openai::{
3 config::OpenAIConfig,
4 types::chat::{
5 ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
6 CreateChatCompletionRequestArgs,
7 },
8 Client,
9};
10use async_trait::async_trait;
11
12use super::{split_prompt, AIProvider};
13use crate::config::accounts::AccountConfig;
14use crate::config::Config;
15use crate::utils::retry::retry_async;
16
17pub struct OpenAIProvider {
18 client: Client<OpenAIConfig>,
19 model: String,
20}
21
22impl OpenAIProvider {
23 pub fn new(config: &Config) -> Result<Self> {
24 let api_key = config
25 .api_key
26 .as_ref()
27 .context("OpenAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://platform.openai.com/api-keys")?;
28
29 let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
30 config
31 .api_url
32 .as_deref()
33 .unwrap_or("https://api.openai.com/v1"),
34 );
35
36 let client = Client::with_config(openai_config);
37 let model = config.model.as_deref().unwrap_or("gpt-4o-mini").to_string();
38
39 Ok(Self { client, model })
40 }
41
42 #[allow(dead_code)]
44 pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
45 let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
46 account
47 .api_url
48 .as_deref()
49 .or(config.api_url.as_deref())
50 .unwrap_or("https://api.openai.com/v1"),
51 );
52
53 let client = Client::with_config(openai_config);
54 let model = account
55 .model
56 .as_deref()
57 .or(config.model.as_deref())
58 .unwrap_or("gpt-4o-mini")
59 .to_string();
60
61 Ok(Self { client, model })
62 }
63}
64
65#[async_trait]
66impl AIProvider for OpenAIProvider {
67 async fn generate_commit_message(
68 &self,
69 diff: &str,
70 context: Option<&str>,
71 full_gitmoji: bool,
72 config: &Config,
73 ) -> Result<String> {
74 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
75
76 let messages = vec![
77 ChatCompletionRequestSystemMessage::from(system_prompt).into(),
78 ChatCompletionRequestUserMessage::from(user_prompt).into(),
79 ];
80
81 let request = if self.model.contains("gpt-5-nano") {
83 CreateChatCompletionRequestArgs::default()
85 .model(&self.model)
86 .messages(messages)
87 .temperature(1.0)
88 .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
89 .build()?
90 } else {
91 CreateChatCompletionRequestArgs::default()
93 .model(&self.model)
94 .messages(messages)
95 .temperature(0.7)
96 .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
97 .build()?
98 };
99
100 let response = retry_async(|| async {
101 match self.client.chat().create(request.clone()).await {
102 Ok(resp) => Ok(resp),
103 Err(e) => {
104 let error_msg = e.to_string();
105 if error_msg.contains("401") || error_msg.contains("invalid_api_key") {
106 Err(anyhow::anyhow!("Invalid OpenAI API key. Please check your API key configuration."))
107 } else if error_msg.contains("insufficient_quota") {
108 Err(anyhow::anyhow!("OpenAI API quota exceeded. Please check your billing status."))
109 } else {
110 Err(anyhow::anyhow!(e).context("Failed to generate commit message from OpenAI"))
111 }
112 }
113 }
114 }).await.context("Failed to generate commit message from OpenAI after retries. Please check your internet connection and API configuration.")?;
115
116 let message = response
117 .choices
118 .first()
119 .and_then(|choice| choice.message.content.as_ref())
120 .context("OpenAI returned an empty response. The model may be overloaded - please try again.")?
121 .trim()
122 .to_string();
123
124 Ok(message)
125 }
126}
127
/// Registry descriptor for the `openai` provider.
///
/// `compatible_providers` maps a provider id to the base URL of that vendor's
/// OpenAI-compatible Chat Completions API.
// NOTE(review): `dead_code` is allowed because fields such as `default_api_url`
// and `compatible_providers` are not read in this file; presumably consumed
// elsewhere or reserved — TODO confirm.
#[allow(dead_code)]
pub struct OpenAICompatibleProvider {
    pub name: &'static str,
    pub aliases: Vec<&'static str>,
    pub default_api_url: &'static str,
    pub default_model: Option<&'static str>,
    pub compatible_providers: std::collections::HashMap<&'static str, &'static str>,
}

impl OpenAICompatibleProvider {
    /// Creates the `openai` descriptor together with its table of
    /// OpenAI-compatible endpoints.
    pub fn new() -> Self {
        // DashScope's OpenAI-compatible API lives under `/compatible-mode/v1`;
        // `dashscope.console.aliyuncs.com` is the web console, not the API host.
        const DASHSCOPE: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
        // Fireworks serves its OpenAI-compatible API under `/inference/v1`.
        const FIREWORKS: &str = "https://api.fireworks.ai/inference/v1";
        const MOONSHOT: &str = "https://api.moonshot.cn/v1";

        let compatible_providers = std::collections::HashMap::from([
            ("deepseek", "https://api.deepseek.com/v1"),
            ("groq", "https://api.groq.com/openai/v1"),
            ("openrouter", "https://openrouter.ai/api/v1"),
            ("together", "https://api.together.ai/v1"),
            ("deepinfra", "https://api.deepinfra.com/v1/openai"),
            ("mistral", "https://api.mistral.ai/v1"),
            ("github-models", "https://models.inference.ai.azure.com"),
            ("fireworks", FIREWORKS),
            ("fireworks-ai", FIREWORKS),
            ("moonshot", MOONSHOT),
            ("moonshot-ai", MOONSHOT),
            ("dashscope", DASHSCOPE),
            ("alibaba", DASHSCOPE),
            ("qwen", DASHSCOPE),
            ("qwen-coder", DASHSCOPE),
            ("codex", "https://api.openai.com/v1"),
        ]);

        Self {
            name: "openai",
            aliases: vec!["openai"],
            default_api_url: "https://api.openai.com/v1",
            default_model: Some("gpt-4o-mini"),
            compatible_providers,
        }
    }
}

impl Default for OpenAICompatibleProvider {
    /// Equivalent to [`OpenAICompatibleProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
177
178impl super::registry::ProviderBuilder for OpenAICompatibleProvider {
179 fn name(&self) -> &'static str {
180 self.name
181 }
182
183 fn aliases(&self) -> Vec<&'static str> {
184 self.aliases.clone()
185 }
186
187 fn category(&self) -> super::registry::ProviderCategory {
188 super::registry::ProviderCategory::OpenAICompatible
189 }
190
191 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
192 Ok(Box::new(OpenAIProvider::new(config)?))
193 }
194
195 fn requires_api_key(&self) -> bool {
196 true
197 }
198
199 fn default_model(&self) -> Option<&'static str> {
200 self.default_model
201 }
202}