rusty_commit/providers/
xai.rs1use anyhow::{Context, Result};
2use async_openai::{
3 config::OpenAIConfig,
4 types::chat::{
5 ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
6 CreateChatCompletionRequestArgs,
7 },
8 Client,
9};
10use async_trait::async_trait;
11
12use super::{split_prompt, AIProvider};
13use crate::config::accounts::AccountConfig;
14use crate::config::Config;
15
16pub struct XAIProvider {
17 client: Client<OpenAIConfig>,
18 model: String,
19}
20
21impl XAIProvider {
22 pub fn new(config: &Config) -> Result<Self> {
23 let api_key = config
24 .api_key
25 .as_ref()
26 .context("xAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://x.ai/api")?;
27
28 let openai_config = OpenAIConfig::new()
29 .with_api_key(api_key)
30 .with_api_base(config.api_url.as_deref().unwrap_or("https://api.x.ai/v1"));
31
32 let client = Client::with_config(openai_config);
33 let model = config.model.as_deref().unwrap_or("grok-beta").to_string();
34
35 Ok(Self { client, model })
36 }
37
38 #[allow(dead_code)]
40 pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
41 let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
42 account
43 .api_url
44 .as_deref()
45 .or(config.api_url.as_deref())
46 .unwrap_or("https://api.x.ai/v1"),
47 );
48
49 let client = Client::with_config(openai_config);
50 let model = account
51 .model
52 .as_deref()
53 .or(config.model.as_deref())
54 .unwrap_or("grok-beta")
55 .to_string();
56
57 Ok(Self { client, model })
58 }
59}
60
61#[async_trait]
62impl AIProvider for XAIProvider {
63 async fn generate_commit_message(
64 &self,
65 diff: &str,
66 context: Option<&str>,
67 full_gitmoji: bool,
68 config: &Config,
69 ) -> Result<String> {
70 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
71
72 let messages = vec![
73 ChatCompletionRequestSystemMessage::from(system_prompt).into(),
74 ChatCompletionRequestUserMessage::from(user_prompt).into(),
75 ];
76
77 let request = CreateChatCompletionRequestArgs::default()
78 .model(&self.model)
79 .messages(messages)
80 .temperature(0.7)
81 .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
82 .build()?;
83
84 let response = self
85 .client
86 .chat()
87 .create(request)
88 .await
89 .context("Failed to generate commit message from xAI")?;
90
91 let message = response
92 .choices
93 .first()
94 .and_then(|choice| choice.message.content.as_ref())
95 .context("xAI returned an empty response")?
96 .trim()
97 .to_string();
98
99 Ok(message)
100 }
101}
102
/// Registry builder that creates [`XAIProvider`] instances under the name `xai`
/// (aliases `grok` and `x-ai`).
pub struct XAIProviderBuilder;
105
106impl super::registry::ProviderBuilder for XAIProviderBuilder {
107 fn name(&self) -> &'static str {
108 "xai"
109 }
110
111 fn aliases(&self) -> Vec<&'static str> {
112 vec!["grok", "x-ai"]
113 }
114
115 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
116 Ok(Box::new(XAIProvider::new(config)?))
117 }
118
119 fn requires_api_key(&self) -> bool {
120 true
121 }
122
123 fn default_model(&self) -> Option<&'static str> {
124 Some("grok-beta")
125 }
126}