1use crate::ChatMessage;
23use anyhow::{Context, Result, anyhow};
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use tokio::sync::mpsc::Sender;
27
/// Timeout, in seconds, for reaching a local server. It is used as both the connect and
/// the total timeout of the probe in `auto_detect_local_provider`, and as the connect
/// timeout of `OllamaProvider::chat_stream`.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Total request timeout, in seconds, for a chat request sent to a local server.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Endpoint that `OpenRouterProvider::chat_stream` posts chat requests to.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect timeout, in seconds, for requests to OpenRouter.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total request timeout, in seconds, for requests to OpenRouter.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value sent in the `HTTP-Referer` header of OpenRouter requests.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value sent in the `X-Title` header of OpenRouter requests.
const X_TITLE: &str = "trusty-common";
35
/// Serializable settings describing a locally hosted chat model.
///
/// The `Default` implementation yields an enabled configuration pointing at
/// `http://localhost:11434` with model `qwen3:30b`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalModelConfig {
    /// Whether the local model is enabled.
    // NOTE(review): nothing in this file reads this flag; presumably callers consult it
    // before probing for a local server. Confirm against the call sites.
    pub enabled: bool,
    /// Base URL of the local server (no endpoint path).
    pub base_url: String,
    /// Model identifier to use with the local server.
    pub model: String,
}
51
52impl Default for LocalModelConfig {
53 fn default() -> Self {
54 Self {
55 enabled: true,
56 base_url: "http://localhost:11434".to_string(),
57 model: "qwen3:30b".to_string(),
58 }
59 }
60}
61
/// Description of a tool offered to the model.
///
/// `tools_wire` converts each definition into a wire entry whose `type` is `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
    /// Name of the tool, sent as the function name.
    pub name: String,
    /// Description of the tool, sent alongside the name.
    pub description: String,
    /// Arbitrary JSON value sent as the function's `parameters`.
    // NOTE(review): the tests pass a JSON-Schema-style object here; presumably that is
    // the expected shape, but this type does not enforce it.
    pub parameters: serde_json::Value,
}
80
/// A tool invocation requested by the model, assembled from the streamed response.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier taken from the stream; empty if the stream never supplied one.
    pub id: String,
    /// Name of the function to invoke.
    pub name: String,
    /// Argument text, formed by concatenating the streamed argument fragments.
    /// It is kept as a raw string and is not parsed or validated in this file.
    pub arguments: String,
}
98
/// Events delivered over the channel passed to `ChatProvider::chat_stream`.
#[derive(Debug, Clone)]
pub enum ChatEvent {
    /// A non-empty fragment of assistant text.
    Delta(String),
    /// A fully assembled tool call; `pump_openai_sse` emits these once the stream has
    /// finished, immediately before `Done`.
    ToolCall(ToolCall),
    /// The stream finished (a `[DONE]` marker was read or the response body ended).
    Done,
    /// An error described as text.
    // NOTE(review): no code path visible in this file sends this variant; presumably it
    // is produced by callers that forward `chat_stream` failures. Confirm.
    Error(String),
}
117
/// A chat backend that can stream a completion for a conversation.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ChatProvider: Send + Sync {
    /// Short identifier of the backend; the implementations in this file return
    /// `"openrouter"` and `"ollama"`.
    fn name(&self) -> &str;
    /// Model identifier this provider puts in its requests.
    fn model(&self) -> &str;
    /// Streams a completion for `messages`, delivering output as [`ChatEvent`]s on `tx`.
    ///
    /// `tools` lists the tools offered to the model. The implementations in this file
    /// omit the field from the request when the list is empty.
    ///
    /// # Errors
    ///
    /// The implementations in this file return an error when the HTTP client cannot be
    /// built, the request cannot be sent, the server answers with a non-success status,
    /// or reading the response stream fails.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<ToolDef>,
        tx: Sender<ChatEvent>,
    ) -> Result<()>;
}
145
/// Serialization-only view of one entry in the request's `tools` array.
#[derive(Debug, Serialize)]
struct OpenAiToolWire<'a> {
    /// Serialized under the key `type`; `tools_wire` always sets it to `"function"`.
    #[serde(rename = "type")]
    kind: &'static str,
    /// The function description borrowed from a [`ToolDef`].
    function: OpenAiFunctionWire<'a>,
}
154
/// Serialization-only view of a tool's function description; every field borrows from
/// the [`ToolDef`] it was built from.
#[derive(Debug, Serialize)]
struct OpenAiFunctionWire<'a> {
    /// Function name.
    name: &'a str,
    /// Function description.
    description: &'a str,
    /// JSON value describing the function's parameters.
    parameters: &'a serde_json::Value,
}
161
/// Serialization-only body of a chat request, borrowing its contents from the caller.
#[derive(Debug, Serialize)]
struct ChatRequestWire<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation messages.
    messages: &'a [ChatMessage],
    /// Whether a streamed response is requested; both providers in this file set `true`.
    stream: bool,
    /// Tools offered to the model; the key is left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiToolWire<'a>>>,
}
170
171fn tools_wire(tools: &[ToolDef]) -> Option<Vec<OpenAiToolWire<'_>>> {
172 if tools.is_empty() {
173 None
174 } else {
175 Some(
176 tools
177 .iter()
178 .map(|t| OpenAiToolWire {
179 kind: "function",
180 function: OpenAiFunctionWire {
181 name: &t.name,
182 description: &t.description,
183 parameters: &t.parameters,
184 },
185 })
186 .collect(),
187 )
188 }
189}
190
191#[derive(Debug, Default)]
203struct ToolCallAccumulator {
204 slots: Vec<Option<(String, String, String)>>,
206}
207
208impl ToolCallAccumulator {
209 fn apply_delta(&mut self, tool_calls: &serde_json::Value) {
210 let Some(arr) = tool_calls.as_array() else {
211 return;
212 };
213 for tc in arr {
214 let idx = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
215 while self.slots.len() <= idx {
216 self.slots.push(None);
217 }
218 let slot = self.slots[idx]
219 .get_or_insert_with(|| (String::new(), String::new(), String::new()));
220 if let Some(id) = tc.get("id").and_then(|v| v.as_str())
221 && !id.is_empty()
222 {
223 slot.0 = id.to_string();
224 }
225 if let Some(func) = tc.get("function") {
226 if let Some(name) = func.get("name").and_then(|v| v.as_str())
227 && !name.is_empty()
228 {
229 slot.1 = name.to_string();
230 }
231 if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
232 slot.2.push_str(args);
233 }
234 }
235 }
236 }
237
238 fn finalize(self) -> Vec<ToolCall> {
239 self.slots
240 .into_iter()
241 .filter_map(|opt| {
242 opt.and_then(|(id, name, arguments)| {
243 if name.is_empty() {
244 None
245 } else {
246 Some(ToolCall {
247 id,
248 name,
249 arguments,
250 })
251 }
252 })
253 })
254 .collect()
255 }
256}
257
258async fn pump_openai_sse(resp: reqwest::Response, tx: Sender<ChatEvent>) -> Result<()> {
271 use futures_util::StreamExt;
272
273 let mut acc = ToolCallAccumulator::default();
274 let mut buf = String::new();
275 let mut stream = resp.bytes_stream();
276
277 while let Some(chunk) = stream.next().await {
278 let bytes = chunk.context("read chat stream chunk")?;
279 let text = match std::str::from_utf8(&bytes) {
280 Ok(s) => s,
281 Err(_) => continue,
282 };
283 buf.push_str(text);
284
285 while let Some(idx) = buf.find('\n') {
286 let line: String = buf.drain(..=idx).collect();
287 let line = line.trim();
288 let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
289 continue;
290 };
291 if payload.is_empty() {
292 continue;
293 }
294 if payload == "[DONE]" {
295 for call in std::mem::take(&mut acc).finalize() {
297 if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
298 return Ok(());
299 }
300 }
301 let _ = tx.send(ChatEvent::Done).await;
302 return Ok(());
303 }
304 let v: serde_json::Value = match serde_json::from_str(payload) {
305 Ok(v) => v,
306 Err(_) => continue,
307 };
308 let delta = v
309 .get("choices")
310 .and_then(|c| c.get(0))
311 .and_then(|c| c.get("delta"));
312 if let Some(delta) = delta {
313 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
314 && !content.is_empty()
315 && tx
316 .send(ChatEvent::Delta(content.to_string()))
317 .await
318 .is_err()
319 {
320 return Ok(());
321 }
322 if let Some(tc) = delta.get("tool_calls") {
323 acc.apply_delta(tc);
324 }
325 }
326 }
327 }
328
329 for call in acc.finalize() {
331 if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
332 return Ok(());
333 }
334 }
335 let _ = tx.send(ChatEvent::Done).await;
336 Ok(())
337}
338
/// Chat provider configuration for the OpenRouter API.
pub struct OpenRouterProvider {
    /// API key sent as the bearer token on every request.
    pub api_key: String,
    /// Model identifier placed in the request body.
    pub model: String,
}

impl OpenRouterProvider {
    /// Creates a provider from an API key and a model identifier.
    ///
    /// Both arguments accept anything convertible into `String`.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let model: String = model.into();
        Self { api_key, model }
    }
}
369
370#[async_trait]
371impl ChatProvider for OpenRouterProvider {
372 fn name(&self) -> &str {
373 "openrouter"
374 }
375
376 fn model(&self) -> &str {
377 &self.model
378 }
379
380 async fn chat_stream(
381 &self,
382 messages: Vec<ChatMessage>,
383 tools: Vec<ToolDef>,
384 tx: Sender<ChatEvent>,
385 ) -> Result<()> {
386 if self.api_key.is_empty() {
387 return Err(anyhow!("openrouter api key is empty"));
388 }
389 let client = reqwest::Client::builder()
390 .connect_timeout(std::time::Duration::from_secs(
391 OPENROUTER_CONNECT_TIMEOUT_SECS,
392 ))
393 .timeout(std::time::Duration::from_secs(
394 OPENROUTER_REQUEST_TIMEOUT_SECS,
395 ))
396 .build()
397 .context("build reqwest client for OpenRouterProvider::chat_stream")?;
398
399 let tools_wire = tools_wire(&tools);
400 let body = ChatRequestWire {
401 model: &self.model,
402 messages: &messages,
403 stream: true,
404 tools: tools_wire,
405 };
406 let resp = client
407 .post(OPENROUTER_URL)
408 .bearer_auth(&self.api_key)
409 .header("HTTP-Referer", HTTP_REFERER)
410 .header("X-Title", X_TITLE)
411 .json(&body)
412 .send()
413 .await
414 .context("POST openrouter chat completions (stream)")?;
415
416 let status = resp.status();
417 if !status.is_success() {
418 let text = resp.text().await.unwrap_or_default();
419 return Err(anyhow!("openrouter HTTP {status}: {text}"));
420 }
421
422 pump_openai_sse(resp, tx).await
423 }
424}
425
/// Chat provider configuration for a local server reached at `base_url`.
pub struct OllamaProvider {
    /// Base URL of the server; the chat endpoint path is appended when a request is made.
    pub base_url: String,
    /// Model identifier placed in the request body.
    pub model: String,
}

impl OllamaProvider {
    /// Creates a provider from a base URL and a model identifier.
    ///
    /// Both arguments accept anything convertible into `String`; they are stored as given.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let base_url: String = base_url.into();
        let model: String = model.into();
        Self { base_url, model }
    }
}
460
461#[async_trait]
462impl ChatProvider for OllamaProvider {
463 fn name(&self) -> &str {
464 "ollama"
465 }
466
467 fn model(&self) -> &str {
468 &self.model
469 }
470
471 async fn chat_stream(
472 &self,
473 messages: Vec<ChatMessage>,
474 tools: Vec<ToolDef>,
475 tx: Sender<ChatEvent>,
476 ) -> Result<()> {
477 let client = reqwest::Client::builder()
478 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
479 .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
480 .build()
481 .context("build reqwest client for OllamaProvider::chat_stream")?;
482
483 let url = format!(
484 "{}/v1/chat/completions",
485 self.base_url.trim_end_matches('/')
486 );
487 let tools_wire = tools_wire(&tools);
488 let body = ChatRequestWire {
489 model: &self.model,
490 messages: &messages,
491 stream: true,
492 tools: tools_wire,
493 };
494 let resp = client
495 .post(&url)
496 .json(&body)
497 .send()
498 .await
499 .with_context(|| format!("POST {url}"))?;
500
501 let status = resp.status();
502 if !status.is_success() {
503 let text = resp.text().await.unwrap_or_default();
504 return Err(anyhow!("local chat HTTP {status}: {text}"));
505 }
506
507 pump_openai_sse(resp, tx).await
508 }
509}
510
511pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
528 let client = reqwest::Client::builder()
529 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
530 .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
531 .build()
532 .ok()?;
533
534 let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
535 match client.get(&url).send().await {
536 Ok(resp) if resp.status().is_success() => {
537 Some(OllamaProvider::new(base_url.to_string(), String::new()))
538 }
539 _ => None,
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a loopback server on an ephemeral port that answers exactly one
    /// connection with `200 OK`, the given content type and `body`, then closes the
    /// socket. Returns the server's base URL (`http://<addr>`).
    async fn serve_once(content_type: &'static str, body: &'static str) -> String {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // Read (and discard) the request before answering.
                let mut buf = [0u8; 4096];
                let _ = sock.read(&mut buf).await;
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                    body.len()
                );
                let _ = sock.write_all(response.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        format!("http://{addr}")
    }

    /// Builds a plain user message with the given content.
    fn user_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: "user".into(),
            content: content.into(),
            tool_call_id: None,
            tool_calls: None,
        }
    }

    /// Runs `chat_stream` on a spawned task while draining the channel, and returns the
    /// call's result together with every event received, in order.
    async fn run_chat_stream(
        provider: OllamaProvider,
        messages: Vec<ChatMessage>,
        tools: Vec<ToolDef>,
    ) -> (anyhow::Result<()>, Vec<ChatEvent>) {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let handle =
            tokio::spawn(async move { provider.chat_stream(messages, tools, tx).await });

        let mut events = Vec::new();
        while let Some(ev) = rx.recv().await {
            events.push(ev);
        }
        let result = handle.await.expect("task panicked");
        (result, events)
    }

    #[test]
    fn local_model_config_defaults() {
        let cfg = LocalModelConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:11434");
        assert_eq!(cfg.model, "qwen3:30b");
    }

    #[test]
    fn openrouter_provider_reports_metadata() {
        let p = OpenRouterProvider::new("sk-xxx", "anthropic/claude-3.5-sonnet");
        assert_eq!(p.name(), "openrouter");
        assert_eq!(p.model(), "anthropic/claude-3.5-sonnet");
    }

    #[test]
    fn ollama_provider_reports_metadata() {
        let p = OllamaProvider::new("http://localhost:11434", "llama3.2");
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.model(), "llama3.2");
    }

    #[test]
    fn tool_def_serializes_as_function() {
        let tools = vec![ToolDef {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"],
            }),
        }];
        let wire = tools_wire(&tools).expect("expected Some");
        let v = serde_json::to_value(&wire).unwrap();
        assert_eq!(v[0]["type"], "function");
        assert_eq!(v[0]["function"]["name"], "search");
        assert_eq!(v[0]["function"]["parameters"]["type"], "object");
    }

    #[test]
    fn empty_tools_serializes_to_none() {
        assert!(tools_wire(&[]).is_none());
    }

    #[test]
    fn accumulates_streamed_tool_call_fragments() {
        let mut acc = ToolCallAccumulator::default();
        // First fragment carries the id and name; later ones only argument text.
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "id": "call_abc",
            "function": { "name": "search", "arguments": "" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "{\"query\":\"" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "rust\"}" }
        }]));
        let calls = acc.finalize();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_abc");
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].arguments, "{\"query\":\"rust\"}");
    }

    #[tokio::test]
    async fn auto_detect_returns_none_on_unreachable() {
        // Bind to obtain a free port, then release it so nothing is listening there.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let base = format!("http://127.0.0.1:{port}");
        let start = std::time::Instant::now();
        let got = auto_detect_local_provider(&base).await;
        let elapsed = start.elapsed();
        assert!(got.is_none(), "expected None for unreachable server");
        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "auto-detect took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn auto_detect_returns_some_on_200() {
        let base = serve_once("application/json", "{\"data\":[]}").await;

        let got = auto_detect_local_provider(&base).await;
        assert!(got.is_some(), "expected Some for reachable 200 server");
        let p = got.unwrap();
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.base_url, base);
    }

    #[test]
    fn local_model_config_deserializes_from_toml() {
        let toml_src = r#"
            enabled = true
            base_url = "http://localhost:1234"
            model = "qwen2.5-coder"
        "#;
        let cfg: LocalModelConfig = toml::from_str(toml_src).expect("parse TOML");
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:1234");
        assert_eq!(cfg.model, "qwen2.5-coder");
    }

    #[tokio::test]
    async fn ollama_provider_streams_sse_deltas() {
        let sse_body = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"hello \"}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"world\"}}]}\n\n",
            "data: [DONE]\n\n",
        );
        let base = serve_once("text/event-stream", sse_body).await;

        let provider = OllamaProvider::new(base, "test-model");
        let (result, events) =
            run_chat_stream(provider, vec![user_message("hi")], vec![]).await;

        let mut deltas = Vec::new();
        let mut saw_done = false;
        for ev in events {
            match ev {
                ChatEvent::Delta(s) => deltas.push(s),
                ChatEvent::Done => saw_done = true,
                ChatEvent::ToolCall(_) => panic!("unexpected tool call"),
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
        assert!(saw_done, "expected ChatEvent::Done");
    }

    #[tokio::test]
    async fn ollama_provider_emits_tool_call() {
        let sse_body = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\"}}]}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"rust\\\"}\"}}]}}]}\n\n",
            "data: [DONE]\n\n",
        );
        let base = serve_once("text/event-stream", sse_body).await;

        let provider = OllamaProvider::new(base, "test-model");
        let tools = vec![ToolDef {
            name: "search".into(),
            description: "search the web".into(),
            parameters: serde_json::json!({"type":"object"}),
        }];
        let (result, events) =
            run_chat_stream(provider, vec![user_message("search rust")], tools).await;

        let mut tool_calls = Vec::new();
        let mut saw_done = false;
        for ev in events {
            match ev {
                ChatEvent::ToolCall(tc) => tool_calls.push(tc),
                ChatEvent::Done => saw_done = true,
                ChatEvent::Delta(_) => {}
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].name, "search");
        assert_eq!(tool_calls[0].arguments, "{\"q\":\"rust\"}");
        assert!(saw_done);
    }
}