rusty_commit/providers/
xai.rs1use anyhow::{Context, Result};
2use async_openai::{
3 config::OpenAIConfig,
4 types::chat::{
5 ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
6 CreateChatCompletionRequestArgs,
7 },
8 Client,
9};
10use async_trait::async_trait;
11
12use super::prompt::split_prompt;
13use super::AIProvider;
14use crate::config::accounts::AccountConfig;
15use crate::config::Config;
16
17pub struct XAIProvider {
18 client: Client<OpenAIConfig>,
19 model: String,
20}
21
22impl XAIProvider {
23 pub fn new(config: &Config) -> Result<Self> {
24 let api_key = config
25 .api_key
26 .as_ref()
27 .context("xAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://x.ai/api")?;
28
29 let openai_config = OpenAIConfig::new()
30 .with_api_key(api_key)
31 .with_api_base(config.api_url.as_deref().unwrap_or("https://api.x.ai/v1"));
32
33 let client = Client::with_config(openai_config);
34 let model = config.model.as_deref().unwrap_or("grok-beta").to_string();
35
36 Ok(Self { client, model })
37 }
38
39 #[allow(dead_code)]
41 pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
42 let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
43 account
44 .api_url
45 .as_deref()
46 .or(config.api_url.as_deref())
47 .unwrap_or("https://api.x.ai/v1"),
48 );
49
50 let client = Client::with_config(openai_config);
51 let model = account
52 .model
53 .as_deref()
54 .or(config.model.as_deref())
55 .unwrap_or("grok-beta")
56 .to_string();
57
58 Ok(Self { client, model })
59 }
60}
61
62#[async_trait]
63impl AIProvider for XAIProvider {
64 async fn generate_commit_message(
65 &self,
66 diff: &str,
67 context: Option<&str>,
68 full_gitmoji: bool,
69 config: &Config,
70 ) -> Result<String> {
71 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
72
73 let messages = vec![
74 ChatCompletionRequestSystemMessage::from(system_prompt).into(),
75 ChatCompletionRequestUserMessage::from(user_prompt).into(),
76 ];
77
78 let request = CreateChatCompletionRequestArgs::default()
79 .model(&self.model)
80 .messages(messages)
81 .temperature(0.7)
82 .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
83 .build()?;
84
85 let response = self
86 .client
87 .chat()
88 .create(request)
89 .await
90 .context("Failed to generate commit message from xAI")?;
91
92 let message = response
93 .choices
94 .first()
95 .and_then(|choice| choice.message.content.as_ref())
96 .context("xAI returned an empty response")?
97 .trim()
98 .to_string();
99
100 Ok(message)
101 }
102}
103
104pub struct XAIProviderBuilder;
106
107impl super::registry::ProviderBuilder for XAIProviderBuilder {
108 fn name(&self) -> &'static str {
109 "xai"
110 }
111
112 fn aliases(&self) -> Vec<&'static str> {
113 vec!["grok", "x-ai"]
114 }
115
116 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
117 Ok(Box::new(XAIProvider::new(config)?))
118 }
119
120 fn requires_api_key(&self) -> bool {
121 true
122 }
123
124 fn default_model(&self) -> Option<&'static str> {
125 Some("grok-beta")
126 }
127}