// rusty_commit/providers/ollama.rs
use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{build_prompt, AIProvider};
7use crate::config::Config;
9use crate::utils::retry::retry_async;
10
/// AI provider that generates commit messages via a locally running
/// Ollama server over its HTTP API. No API key is required.
pub struct OllamaProvider {
    // Reusable HTTP client (connection pooling across requests).
    client: Client,
    // Base URL of the Ollama server, e.g. "http://localhost:11434".
    api_url: String,
    // Name of the Ollama model to run, e.g. "mistral".
    model: String,
}
16
/// Request body sent to Ollama's `/api/generate` endpoint.
#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    // Always set to false here: we want the full completion as one JSON
    // object rather than a stream of partial chunks.
    stream: bool,
    options: OllamaOptions,
}
24
/// Generation options forwarded to Ollama inside the request body.
#[derive(Serialize)]
struct OllamaOptions {
    // Sampling temperature.
    temperature: f32,
    // Maximum number of tokens to generate (Ollama's `num_predict` knob).
    num_predict: i32,
}
30
/// The subset of Ollama's `/api/generate` reply this module consumes.
#[derive(Deserialize)]
struct OllamaResponse {
    // The generated completion text.
    response: String,
}
35
36impl OllamaProvider {
37 pub fn new(config: &Config) -> Result<Self> {
38 let client = Client::new();
39 let api_url = config
40 .api_url
41 .as_deref()
42 .unwrap_or("http://localhost:11434")
43 .to_string();
44 let model = config.model.as_deref().unwrap_or("mistral").to_string();
45
46 Ok(Self {
47 client,
48 api_url,
49 model,
50 })
51 }
52
53 #[allow(dead_code)]
55 pub fn from_account(
56 account: &crate::config::accounts::AccountConfig,
57 _api_key: &str,
58 config: &Config,
59 ) -> Result<Self> {
60 let client = Client::new();
61 let api_url = account
62 .api_url
63 .as_deref()
64 .or(config.api_url.as_deref())
65 .unwrap_or("http://localhost:11434")
66 .to_string();
67 let model = account
68 .model
69 .as_deref()
70 .or(config.model.as_deref())
71 .unwrap_or("mistral")
72 .to_string();
73
74 Ok(Self {
75 client,
76 api_url,
77 model,
78 })
79 }
80}
81
82#[async_trait]
83impl AIProvider for OllamaProvider {
84 async fn generate_commit_message(
85 &self,
86 diff: &str,
87 context: Option<&str>,
88 full_gitmoji: bool,
89 config: &Config,
90 ) -> Result<String> {
91 let prompt = build_prompt(diff, context, config, full_gitmoji);
92
93 let request = OllamaRequest {
94 model: self.model.clone(),
95 prompt,
96 stream: false,
97 options: OllamaOptions {
98 temperature: 0.7,
99 num_predict: config.tokens_max_output.unwrap_or(500) as i32,
100 },
101 };
102
103 let ollama_response: OllamaResponse = retry_async(|| async {
104 let url = format!("{}/api/generate", self.api_url);
105 let response = self
106 .client
107 .post(&url)
108 .json(&request)
109 .send()
110 .await
111 .context("Failed to connect to Ollama")?;
112
113 if !response.status().is_success() {
114 let error_text = response.text().await?;
115 return Err(anyhow::anyhow!("Ollama API error: {}", error_text));
116 }
117
118 let ollama_response: OllamaResponse = response
119 .json()
120 .await
121 .context("Failed to parse Ollama response")?;
122
123 Ok(ollama_response)
124 })
125 .await
126 .context("Failed to generate commit message from Ollama after retries")?;
127
128 Ok(ollama_response.response.trim().to_string())
129 }
130}
131
/// Zero-sized registry builder that constructs [`OllamaProvider`] instances.
pub struct OllamaProviderBuilder;
impl super::registry::ProviderBuilder for OllamaProviderBuilder {
    // Registry key under which this provider is selected.
    fn name(&self) -> &'static str {
        "ollama"
    }

    // Ollama runs on the user's machine, so it is categorized as local.
    fn category(&self) -> super::registry::ProviderCategory {
        super::registry::ProviderCategory::Local
    }

    // Instantiates the provider from the global configuration.
    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(OllamaProvider::new(config)?))
    }

    // Local server: no API key is needed or checked.
    fn requires_api_key(&self) -> bool {
        false
    }

    // Suggested default model for this provider.
    // NOTE(review): this advertises "llama3.1", but the constructors fall
    // back to "mistral" when no model is configured — confirm which default
    // is intended and align the two.
    fn default_model(&self) -> Option<&'static str> {
        Some("llama3.1")
    }
}