1use crate::client::LlmClient;
17use crate::retry::RetryClient;
18use crate::schema::response_schema_for;
19use crate::tool::ToolDef;
20use crate::types::{LlmConfig, Message, SgrError, ToolCall};
21use schemars::JsonSchema;
22use serde::de::DeserializeOwned;
23use serde_json::Value;
24
/// The concrete client driving an `Llm`; selected exactly once in `Llm::new`.
enum Backend {
    /// Oxide client wrapped in retry logic.
    Oxide(RetryClient<crate::oxide_client::OxideClient>),
    /// Chat-Completions-style Oxide client, also retry-wrapped.
    OxideChat(RetryClient<crate::oxide_chat_client::OxideChatClient>),
    /// genai-based client; only exists when the `genai` feature is compiled in.
    #[cfg(feature = "genai")]
    Genai(crate::genai_client::GenaiClient),
    /// CLI-backed client (see `crate::cli_client`); not retry-wrapped.
    Cli(crate::cli_client::CliClient),
}
35
/// Unified LLM handle: a single public type that hides which of the
/// available backends (`oxide`, `oxide-chat`, `genai`, `cli`) is in use.
pub struct Llm {
    // Chosen once at construction; every call is dispatched through it.
    inner: Backend,
}
40
impl Llm {
    /// Builds an `Llm` by picking a backend from `config`.
    ///
    /// Selection order (first match wins):
    /// 1. `use_cli` — CLI backend inferred from the model name, defaulting
    ///    to the Claude CLI backend when inference fails.
    /// 2. `use_genai` (requires the `genai` feature) — explicit genai client.
    /// 3. `project_id` set (requires the `genai` feature) — genai client.
    /// 4. `use_chat_api` — Chat Completions client; a construction failure
    ///    here is silently ignored and selection falls through.
    /// 5. Oxide client (retry-wrapped) when its construction succeeds.
    /// 6. Otherwise fall back to genai when that feature is enabled.
    ///
    /// # Panics
    /// Panics when `OxideClient::from_config` fails and the `genai` feature
    /// is not compiled in (step 6 has no fallback).
    pub fn new(config: &LlmConfig) -> Self {
        if config.use_cli {
            // Unknown model names fall back to the Claude CLI backend.
            let backend = crate::cli_client::CliBackend::from_model(&config.model)
                .unwrap_or(crate::cli_client::CliBackend::Claude);
            let client = crate::cli_client::CliClient::new(backend).with_model(&config.model);
            tracing::debug!(model = %config.model, backend = "cli", "Llm backend selected");
            return Self {
                inner: Backend::Cli(client),
            };
        }

        // Explicit genai opt-in takes precedence over auto-detection below.
        #[cfg(feature = "genai")]
        if config.use_genai {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (explicit)");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // A project id implies genai — presumably a cloud-project-scoped
        // deployment; confirm against LlmConfig docs.
        #[cfg(feature = "genai")]
        if config.project_id.is_some() {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // Chat Completions is best-effort: if the client cannot be built we
        // fall through to the default Oxide path rather than erroring out.
        if config.use_chat_api
            && let Ok(client) = crate::oxide_chat_client::OxideChatClient::from_config(config)
        {
            tracing::debug!(model = %config.model, backend = "oxide-chat", "Llm backend selected (Chat Completions)");
            return Self {
                inner: Backend::OxideChat(RetryClient::new(client)),
            };
        }

        if let Ok(client) = crate::oxide_client::OxideClient::from_config(config) {
            Self {
                inner: Backend::Oxide(RetryClient::new(client)),
            }
        } else {
            // Last resort: genai when available, otherwise a hard failure.
            #[cfg(feature = "genai")]
            {
                tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (oxide fallback)");
                return Self {
                    inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
                };
            }
            #[cfg(not(feature = "genai"))]
            panic!("OxideClient::from_config failed and genai feature not enabled");
        }
    }

    /// Returns the active backend as a trait object so the delegating
    /// methods below need only one dispatch site.
    fn client(&self) -> &dyn LlmClient {
        match &self.inner {
            Backend::Oxide(c) => c,
            Backend::OxideChat(c) => c,
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c,
            Backend::Cli(c) => c,
        }
    }

    /// Like [`Llm::new`], but additionally attempts a WebSocket upgrade when
    /// `config.websocket` is set. An upgrade failure is logged as a warning
    /// and otherwise ignored — the HTTP path still works.
    pub async fn new_async(config: &LlmConfig) -> Self {
        let llm = Self::new(config);
        if config.websocket
            && let Err(e) = llm.connect_ws().await
        {
            tracing::warn!("WebSocket upgrade skipped: {}", e);
        }
        llm
    }

    /// Upgrades the connection to WebSocket where supported.
    ///
    /// Only the Oxide backend supports this, and only when the `oxide-ws`
    /// feature is compiled in; every other combination is a no-op `Ok(())`.
    pub async fn connect_ws(&self) -> Result<(), SgrError> {
        #[cfg(feature = "oxide-ws")]
        if let Backend::Oxide(c) = &self.inner {
            return c.inner().connect_ws().await;
        }
        Ok(())
    }

    /// Streams a completion, invoking `on_token` as text arrives.
    ///
    /// True token-by-token streaming only happens on the genai backend. All
    /// other backends fall back to a blocking [`Llm::generate`] call and
    /// then invoke `on_token` exactly once with the full response text.
    pub async fn stream_complete<F>(
        &self,
        messages: &[Message],
        mut on_token: F,
    ) -> Result<String, SgrError>
    where
        F: FnMut(&str),
    {
        match &self.inner {
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c.stream_complete(messages, on_token).await,
            Backend::Oxide(_) | Backend::OxideChat(_) | Backend::Cli(_) => {
                let text = self.generate(messages).await?;
                // Degraded streaming: the whole response in one callback.
                on_token(&text);
                Ok(text)
            }
        }
    }

    /// Plain text completion via the active backend.
    pub async fn generate(&self, messages: &[Message]) -> Result<String, SgrError> {
        self.client().complete(messages).await
    }

    /// Tool-call round trip that threads a `previous_response_id` through to
    /// the backend, returning the calls plus an optional new response id.
    pub async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        self.client()
            .tools_call_stateful(messages, tools, previous_response_id)
            .await
    }

    /// Tool-call round trip that also returns the model's text output.
    pub async fn tools_call_with_text(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
    ) -> Result<(Vec<ToolCall>, String), SgrError> {
        self.client().tools_call_with_text(messages, tools).await
    }

    /// Requests a response conforming to `T`'s JSON schema and deserializes it.
    ///
    /// # Errors
    /// - [`SgrError::Schema`] when the backend returned a value that does not
    ///   deserialize into `T` (the raw text is embedded in the message).
    /// - [`SgrError::EmptyResponse`] when the backend produced no parsed value.
    pub async fn structured<T: JsonSchema + DeserializeOwned>(
        &self,
        messages: &[Message],
    ) -> Result<T, SgrError> {
        let schema = response_schema_for::<T>();
        let (parsed, _tool_calls, raw_text) =
            self.client().structured_call(messages, &schema).await?;
        match parsed {
            Some(value) => serde_json::from_value::<T>(value)
                .map_err(|e| SgrError::Schema(format!("Parse error: {e}\nRaw: {raw_text}"))),
            None => Err(SgrError::EmptyResponse),
        }
    }

    /// Stable identifier of the selected backend, mainly for logs and tests.
    pub fn backend_name(&self) -> &'static str {
        match &self.inner {
            Backend::Oxide(_) => "oxide",
            Backend::OxideChat(_) => "oxide-chat",
            #[cfg(feature = "genai")]
            Backend::Genai(_) => "genai",
            Backend::Cli(_) => "cli",
        }
    }
}
213
214#[async_trait::async_trait]
215impl LlmClient for Llm {
216 async fn structured_call(
217 &self,
218 messages: &[Message],
219 schema: &Value,
220 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
221 self.client().structured_call(messages, schema).await
222 }
223
224 async fn tools_call(
225 &self,
226 messages: &[Message],
227 tools: &[ToolDef],
228 ) -> Result<Vec<ToolCall>, SgrError> {
229 self.client().tools_call(messages, tools).await
230 }
231
232 async fn tools_call_stateful(
233 &self,
234 messages: &[Message],
235 tools: &[ToolDef],
236 previous_response_id: Option<&str>,
237 ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
238 self.client()
239 .tools_call_stateful(messages, tools, previous_response_id)
240 .await
241 }
242
243 async fn tools_call_with_text(
244 &self,
245 messages: &[Message],
246 tools: &[ToolDef],
247 ) -> Result<(Vec<ToolCall>, String), SgrError> {
248 self.client().tools_call_with_text(messages, tools).await
249 }
250
251 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
252 self.client().complete(messages).await
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain endpoint config with no CLI/genai/chat flags lands on oxide.
    #[test]
    fn llm_from_auto_config() {
        let cfg = LlmConfig::endpoint("sk-test-dummy", "https://api.openai.com/v1", "gpt-5.4");
        assert_eq!(Llm::new(&cfg).backend_name(), "oxide");
    }

    /// Non-OpenAI base URLs still route through the oxide backend.
    #[test]
    fn llm_custom_endpoint_uses_oxide() {
        let cfg = LlmConfig::endpoint("sk-test", "https://openrouter.ai/api/v1", "gpt-5.4");
        assert_eq!(Llm::new(&cfg).backend_name(), "oxide");
    }

    /// Serializing a config and reading it back preserves every field set here.
    #[test]
    fn llm_config_serde_roundtrip() {
        let original = LlmConfig::endpoint("key", "https://example.com/v1", "model")
            .temperature(0.9)
            .max_tokens(1000);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LlmConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model, "model");
        assert_eq!(decoded.api_key.as_deref(), Some("key"));
        assert_eq!(decoded.base_url.as_deref(), Some("https://example.com/v1"));
        assert_eq!(decoded.temp, 0.9);
        assert_eq!(decoded.max_tokens, Some(1000));
    }

    /// A JSON object containing only a model name deserializes with defaults.
    #[test]
    fn llm_config_auto_minimal_json() {
        let parsed: LlmConfig = serde_json::from_str(r#"{"model": "gpt-4o"}"#).unwrap();
        assert_eq!(parsed.model, "gpt-4o");
        assert!(parsed.api_key.is_none());
        assert_eq!(parsed.temp, 0.7);
    }
}
297}