rusty_commit/providers/xai.rs

use anyhow::{Context, Result};
use async_openai::{
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
        CreateChatCompletionRequestArgs,
    },
    Client,
};
use async_trait::async_trait;

use super::prompt::split_prompt;
use super::AIProvider;
use crate::config::accounts::AccountConfig;
use crate::config::Config;

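/// xAI provider: a chat-completions client for xAI's OpenAI-compatible API
/// plus the model name to request.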
pub struct XAIProvider {
    client: Client<OpenAIConfig>,
    model: String,
}

impl XAIProvider {
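    /// Create provider from the global configuration, defaulting to the
    /// public xAI endpoint and the `grok-beta` model when unset.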
    pub fn new(config: &Config) -> Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .context("xAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://x.ai/api")?;

        let openai_config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(config.api_url.as_deref().unwrap_or("https://api.x.ai/v1"));

        let client = Client::with_config(openai_config);
        let model = config.model.as_deref().unwrap_or("grok-beta").to_string();

        Ok(Self { client, model })
    }

    /// Create provider from an account configuration, preferring the account's
    /// API URL and model over the global config, with the same defaults as `new`.
    #[allow(dead_code)]
    pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
        let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
            account
                .api_url
                .as_deref()
                .or(config.api_url.as_deref())
                .unwrap_or("https://api.x.ai/v1"),
        );

        let client = Client::with_config(openai_config);
        let model = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("grok-beta")
            .to_string();

        Ok(Self { client, model })
    }
}

#[async_trait]
impl AIProvider for XAIProvider {
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);

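        // Send the split prompt as standard OpenAI-style system/user messages.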
        let messages = vec![
            ChatCompletionRequestSystemMessage::from(system_prompt).into(),
            ChatCompletionRequestUserMessage::from(user_prompt).into(),
        ];

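        // Build the request; fall back to a 500-token output cap if none is configured.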
        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.model)
            .messages(messages)
            .temperature(0.7)
            .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
            .build()?;

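        // Call the chat completions endpoint and attach context so failures point at xAI.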
        let response = self
            .client
            .chat()
            .create(request)
            .await
            .context("Failed to generate commit message from xAI")?;

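        // Take the first choice's content and trim surrounding whitespace.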
        let message = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            .context("xAI returned an empty response")?
            .trim()
            .to_string();

        Ok(message)
    }
}

/// `ProviderBuilder` registration for the xAI provider (aliases: `grok`, `x-ai`).
pub struct XAIProviderBuilder;

impl super::registry::ProviderBuilder for XAIProviderBuilder {
    fn name(&self) -> &'static str {
        "xai"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["grok", "x-ai"]
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(XAIProvider::new(config)?))
    }

    fn requires_api_key(&self) -> bool {
        true
    }

    fn default_model(&self) -> Option<&'static str> {
        Some("grok-beta")
    }
}