1use crate::ChatMessage;
23use anyhow::{Context, Result, anyhow};
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use tokio::sync::mpsc::Sender;
27
/// Connect timeout (seconds) for `OllamaProvider::chat_stream`, and both the connect and
/// the total timeout for the `auto_detect_local_provider` probe.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Total request timeout (seconds) for `OllamaProvider::chat_stream`.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Chat-completions endpoint used by `OpenRouterProvider`.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect timeout (seconds) for OpenRouter requests.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total request timeout (seconds) for OpenRouter requests.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value of the `HTTP-Referer` header sent with every OpenRouter request.
/// NOTE(review): presumably for OpenRouter app attribution — confirm against its API docs.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value of the `X-Title` header sent with every OpenRouter request.
const X_TITLE: &str = "trusty-common";
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LocalModelConfig {
47 pub enabled: bool,
48 pub base_url: String,
49 pub model: String,
50}
51
52impl Default for LocalModelConfig {
53 fn default() -> Self {
54 Self {
55 enabled: true,
56 base_url: "http://localhost:11434".to_string(),
57 model: "llama3.2".to_string(),
58 }
59 }
60}
61
/// A tool (function) the model may call, passed to [`ChatProvider::chat_stream`].
///
/// On the wire it is wrapped as `{"type": "function", "function": {...}}` by `tools_wire`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
    /// Function name the model uses to call this tool.
    pub name: String,
    /// Description forwarded to the model.
    pub description: String,
    /// Parameter schema, forwarded verbatim. NOTE(review): presumably a JSON Schema object
    /// (the tests use `{"type": "object", ...}`) — confirm with callers.
    pub parameters: serde_json::Value,
}
80
/// A complete tool call reassembled from streamed `delta.tool_calls` fragments.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Call id reported by the server; left empty if the server never sent one.
    pub id: String,
    /// Tool name; calls with an empty name are discarded before being emitted.
    pub name: String,
    /// Concatenation of all `function.arguments` fragments, unparsed.
    /// NOTE(review): expected to be JSON-encoded, but nothing here validates it.
    pub arguments: String,
}
98
/// One event of a streamed chat response, sent over the channel passed to
/// [`ChatProvider::chat_stream`].
#[derive(Debug, Clone)]
pub enum ChatEvent {
    /// A non-empty text fragment taken from `choices[0].delta.content`.
    Delta(String),
    /// A fully accumulated tool call; sent after the text deltas, just before `Done`.
    ToolCall(ToolCall),
    /// End of the response.
    Done,
    /// An error message. NOTE(review): no code in this module sends it — presumably for
    /// other providers or callers.
    Error(String),
}
117
/// A chat backend that streams its responses as [`ChatEvent`]s.
#[async_trait]
pub trait ChatProvider: Send + Sync {
    /// Short backend identifier (`"openrouter"` / `"ollama"` for the providers in this module).
    fn name(&self) -> &str;
    /// Model identifier sent with each request.
    fn model(&self) -> &str;
    /// Sends `messages` (plus `tools`, when non-empty) and streams the response into `tx`.
    ///
    /// The implementations in this module do the following:
    /// - Send `Delta` events as text arrives, then any accumulated `ToolCall`s, then `Done`.
    /// - Return `Ok(())` early if the receiver of `tx` is dropped.
    /// - Return `Err` if the request cannot be sent, the server replies with a non-2xx
    ///   status, or reading the body fails.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<ToolDef>,
        tx: Sender<ChatEvent>,
    ) -> Result<()>;
}
145
/// Wire form of one entry of the request's `tools` array:
/// `{"type": "function", "function": {...}}`.
#[derive(Debug, Serialize)]
struct OpenAiToolWire<'a> {
    /// Serialized under the key `"type"`; `tools_wire` always sets `"function"`.
    #[serde(rename = "type")]
    kind: &'static str,
    /// The function definition, borrowed from a [`ToolDef`].
    function: OpenAiFunctionWire<'a>,
}
154
/// Borrowed view of a [`ToolDef`], serialized as the `function` object of a tool entry.
#[derive(Debug, Serialize)]
struct OpenAiFunctionWire<'a> {
    /// Tool name.
    name: &'a str,
    /// Tool description.
    description: &'a str,
    /// Parameter schema, serialized verbatim.
    parameters: &'a serde_json::Value,
}
161
/// JSON body of a `/chat/completions` request, as sent by both providers in this module.
#[derive(Debug, Serialize)]
struct ChatRequestWire<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation so far, serialized as-is.
    messages: &'a [ChatMessage],
    /// Requests an SSE stream; both providers here always set `true`.
    stream: bool,
    /// Tool definitions; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiToolWire<'a>>>,
}
170
171fn tools_wire(tools: &[ToolDef]) -> Option<Vec<OpenAiToolWire<'_>>> {
172 if tools.is_empty() {
173 None
174 } else {
175 Some(
176 tools
177 .iter()
178 .map(|t| OpenAiToolWire {
179 kind: "function",
180 function: OpenAiFunctionWire {
181 name: &t.name,
182 description: &t.description,
183 parameters: &t.parameters,
184 },
185 })
186 .collect(),
187 )
188 }
189}
190
/// Reassembles tool calls whose `id`, `name`, and `arguments` arrive in fragments spread
/// across streamed `delta.tool_calls` chunks.
#[derive(Debug, Default)]
struct ToolCallAccumulator {
    /// One slot per tool-call position. A filled slot is `(id, name, arguments)`, where
    /// `arguments` is the concatenation of every fragment received so far. `None` marks a
    /// position that no fragment has addressed yet.
    slots: Vec<Option<(String, String, String)>>,
}
207
208impl ToolCallAccumulator {
209 fn apply_delta(&mut self, tool_calls: &serde_json::Value) {
210 let Some(arr) = tool_calls.as_array() else {
211 return;
212 };
213 for tc in arr {
214 let idx = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
215 while self.slots.len() <= idx {
216 self.slots.push(None);
217 }
218 let slot = self.slots[idx].get_or_insert_with(|| {
219 (String::new(), String::new(), String::new())
220 });
221 if let Some(id) = tc.get("id").and_then(|v| v.as_str())
222 && !id.is_empty()
223 {
224 slot.0 = id.to_string();
225 }
226 if let Some(func) = tc.get("function") {
227 if let Some(name) = func.get("name").and_then(|v| v.as_str())
228 && !name.is_empty()
229 {
230 slot.1 = name.to_string();
231 }
232 if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
233 slot.2.push_str(args);
234 }
235 }
236 }
237 }
238
239 fn finalize(self) -> Vec<ToolCall> {
240 self.slots
241 .into_iter()
242 .filter_map(|opt| {
243 opt.and_then(|(id, name, arguments)| {
244 if name.is_empty() {
245 None
246 } else {
247 Some(ToolCall {
248 id,
249 name,
250 arguments,
251 })
252 }
253 })
254 })
255 .collect()
256 }
257}
258
259async fn pump_openai_sse(resp: reqwest::Response, tx: Sender<ChatEvent>) -> Result<()> {
272 use futures_util::StreamExt;
273
274 let mut acc = ToolCallAccumulator::default();
275 let mut buf = String::new();
276 let mut stream = resp.bytes_stream();
277
278 while let Some(chunk) = stream.next().await {
279 let bytes = chunk.context("read chat stream chunk")?;
280 let text = match std::str::from_utf8(&bytes) {
281 Ok(s) => s,
282 Err(_) => continue,
283 };
284 buf.push_str(text);
285
286 while let Some(idx) = buf.find('\n') {
287 let line: String = buf.drain(..=idx).collect();
288 let line = line.trim();
289 let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
290 continue;
291 };
292 if payload.is_empty() {
293 continue;
294 }
295 if payload == "[DONE]" {
296 for call in std::mem::take(&mut acc).finalize() {
298 if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
299 return Ok(());
300 }
301 }
302 let _ = tx.send(ChatEvent::Done).await;
303 return Ok(());
304 }
305 let v: serde_json::Value = match serde_json::from_str(payload) {
306 Ok(v) => v,
307 Err(_) => continue,
308 };
309 let delta = v
310 .get("choices")
311 .and_then(|c| c.get(0))
312 .and_then(|c| c.get("delta"));
313 if let Some(delta) = delta {
314 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
315 && !content.is_empty()
316 && tx
317 .send(ChatEvent::Delta(content.to_string()))
318 .await
319 .is_err()
320 {
321 return Ok(());
322 }
323 if let Some(tc) = delta.get("tool_calls") {
324 acc.apply_delta(tc);
325 }
326 }
327 }
328 }
329
330 for call in acc.finalize() {
332 if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
333 return Ok(());
334 }
335 }
336 let _ = tx.send(ChatEvent::Done).await;
337 Ok(())
338}
339
/// [`ChatProvider`] that streams chat completions from OpenRouter's chat-completions endpoint.
pub struct OpenRouterProvider {
    /// Sent as a bearer token; `chat_stream` fails before any network I/O when it is empty.
    pub api_key: String,
    /// Model identifier passed through in the request's `model` field.
    pub model: String,
}
355
356impl OpenRouterProvider {
357 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
364 Self {
365 api_key: api_key.into(),
366 model: model.into(),
367 }
368 }
369}
370
371#[async_trait]
372impl ChatProvider for OpenRouterProvider {
373 fn name(&self) -> &str {
374 "openrouter"
375 }
376
377 fn model(&self) -> &str {
378 &self.model
379 }
380
381 async fn chat_stream(
382 &self,
383 messages: Vec<ChatMessage>,
384 tools: Vec<ToolDef>,
385 tx: Sender<ChatEvent>,
386 ) -> Result<()> {
387 if self.api_key.is_empty() {
388 return Err(anyhow!("openrouter api key is empty"));
389 }
390 let client = reqwest::Client::builder()
391 .connect_timeout(std::time::Duration::from_secs(
392 OPENROUTER_CONNECT_TIMEOUT_SECS,
393 ))
394 .timeout(std::time::Duration::from_secs(
395 OPENROUTER_REQUEST_TIMEOUT_SECS,
396 ))
397 .build()
398 .context("build reqwest client for OpenRouterProvider::chat_stream")?;
399
400 let tools_wire = tools_wire(&tools);
401 let body = ChatRequestWire {
402 model: &self.model,
403 messages: &messages,
404 stream: true,
405 tools: tools_wire,
406 };
407 let resp = client
408 .post(OPENROUTER_URL)
409 .bearer_auth(&self.api_key)
410 .header("HTTP-Referer", HTTP_REFERER)
411 .header("X-Title", X_TITLE)
412 .json(&body)
413 .send()
414 .await
415 .context("POST openrouter chat completions (stream)")?;
416
417 let status = resp.status();
418 if !status.is_success() {
419 let text = resp.text().await.unwrap_or_default();
420 return Err(anyhow!("openrouter HTTP {status}: {text}"));
421 }
422
423 pump_openai_sse(resp, tx).await
424 }
425}
426
/// [`ChatProvider`] for a server reached at `base_url` that exposes an OpenAI-compatible
/// `/v1/chat/completions` endpoint.
pub struct OllamaProvider {
    /// Server root URL; a trailing `/` is trimmed before `/v1/chat/completions` is appended.
    pub base_url: String,
    /// Model identifier sent with each request. Empty when the provider was created by
    /// `auto_detect_local_provider`.
    pub model: String,
}
445
446impl OllamaProvider {
447 pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
455 Self {
456 base_url: base_url.into(),
457 model: model.into(),
458 }
459 }
460}
461
462#[async_trait]
463impl ChatProvider for OllamaProvider {
464 fn name(&self) -> &str {
465 "ollama"
466 }
467
468 fn model(&self) -> &str {
469 &self.model
470 }
471
472 async fn chat_stream(
473 &self,
474 messages: Vec<ChatMessage>,
475 tools: Vec<ToolDef>,
476 tx: Sender<ChatEvent>,
477 ) -> Result<()> {
478 let client = reqwest::Client::builder()
479 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
480 .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
481 .build()
482 .context("build reqwest client for OllamaProvider::chat_stream")?;
483
484 let url = format!(
485 "{}/v1/chat/completions",
486 self.base_url.trim_end_matches('/')
487 );
488 let tools_wire = tools_wire(&tools);
489 let body = ChatRequestWire {
490 model: &self.model,
491 messages: &messages,
492 stream: true,
493 tools: tools_wire,
494 };
495 let resp = client
496 .post(&url)
497 .json(&body)
498 .send()
499 .await
500 .with_context(|| format!("POST {url}"))?;
501
502 let status = resp.status();
503 if !status.is_success() {
504 let text = resp.text().await.unwrap_or_default();
505 return Err(anyhow!("local chat HTTP {status}: {text}"));
506 }
507
508 pump_openai_sse(resp, tx).await
509 }
510}
511
512pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
529 let client = reqwest::Client::builder()
530 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
531 .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
532 .build()
533 .ok()?;
534
535 let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
536 match client.get(&url).send().await {
537 Ok(resp) if resp.status().is_success() => {
538 Some(OllamaProvider::new(base_url.to_string(), String::new()))
539 }
540 _ => None,
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the values produced by `LocalModelConfig::default()`.
    #[test]
    fn local_model_config_defaults() {
        let cfg = LocalModelConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:11434");
        assert_eq!(cfg.model, "llama3.2");
    }

    #[test]
    fn openrouter_provider_reports_metadata() {
        let p = OpenRouterProvider::new("sk-xxx", "anthropic/claude-3.5-sonnet");
        assert_eq!(p.name(), "openrouter");
        assert_eq!(p.model(), "anthropic/claude-3.5-sonnet");
    }

    #[test]
    fn ollama_provider_reports_metadata() {
        let p = OllamaProvider::new("http://localhost:11434", "llama3.2");
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.model(), "llama3.2");
    }

    // Tools must serialize in the `{"type": "function", "function": {...}}` shape.
    #[test]
    fn tool_def_serializes_as_function() {
        let tools = vec![ToolDef {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"],
            }),
        }];
        let wire = tools_wire(&tools).expect("expected Some");
        let v = serde_json::to_value(&wire).unwrap();
        assert_eq!(v[0]["type"], "function");
        assert_eq!(v[0]["function"]["name"], "search");
        assert_eq!(v[0]["function"]["parameters"]["type"], "object");
    }

    // No tools yields `None`, so `ChatRequestWire` omits the `tools` key.
    #[test]
    fn empty_tools_serializes_to_none() {
        assert!(tools_wire(&[]).is_none());
    }

    // Simulates an id/name fragment followed by two argument fragments for the same index;
    // the arguments must be concatenated in arrival order.
    #[test]
    fn accumulates_streamed_tool_call_fragments() {
        let mut acc = ToolCallAccumulator::default();
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "id": "call_abc",
            "function": { "name": "search", "arguments": "" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "{\"query\":\"" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "rust\"}" }
        }]));
        let calls = acc.finalize();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_abc");
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].arguments, "{\"query\":\"rust\"}");
    }

    #[tokio::test]
    async fn auto_detect_returns_none_on_unreachable() {
        // Bind and immediately drop a listener to obtain a local port that is very likely
        // closed, so the probe fails fast instead of waiting for a timeout.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let base = format!("http://127.0.0.1:{port}");
        let start = std::time::Instant::now();
        let got = auto_detect_local_provider(&base).await;
        let elapsed = start.elapsed();
        assert!(got.is_none(), "expected None for unreachable server");
        // The probe uses a 1-second timeout, so it should finish well within 2 seconds.
        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "auto-detect took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn auto_detect_returns_some_on_200() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let base = format!("http://{addr}");

        // One-shot fake server: reads the request, then answers 200 with an empty model list.
        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                use tokio::io::{AsyncReadExt, AsyncWriteExt};
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let body = b"{\"data\":[]}";
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    body.len()
                );
                let _ = sock.write_all(response.as_bytes()).await;
                let _ = sock.write_all(body).await;
                let _ = sock.shutdown().await;
            }
        });

        let got = auto_detect_local_provider(&base).await;
        assert!(got.is_some(), "expected Some for reachable 200 server");
        let p = got.unwrap();
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.base_url, base);
    }

    #[test]
    fn local_model_config_deserializes_from_toml() {
        let toml_src = r#"
            enabled = true
            base_url = "http://localhost:1234"
            model = "qwen2.5-coder"
        "#;
        let cfg: LocalModelConfig = toml::from_str(toml_src).expect("parse TOML");
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:1234");
        assert_eq!(cfg.model, "qwen2.5-coder");
    }

    #[tokio::test]
    async fn ollama_provider_streams_sse_deltas() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let base = format!("http://{addr}");

        // One-shot fake server that replies with two content deltas and a `[DONE]` marker.
        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                use tokio::io::{AsyncReadExt, AsyncWriteExt};
                let mut buf = [0u8; 4096];
                let _ = sock.read(&mut buf).await;

                let sse_body = concat!(
                    "data: {\"choices\":[{\"delta\":{\"content\":\"hello \"}}]}\n\n",
                    "data: {\"choices\":[{\"delta\":{\"content\":\"world\"}}]}\n\n",
                    "data: [DONE]\n\n",
                );
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                    sse_body.len(),
                    sse_body
                );
                let _ = sock.write_all(response.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let provider = OllamaProvider::new(base, "test-model");
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let handle = tokio::spawn(async move {
            provider
                .chat_stream(
                    vec![ChatMessage {
                        role: "user".into(),
                        content: "hi".into(),
                        tool_call_id: None,
                        tool_calls: None,
                    }],
                    vec![],
                    tx,
                )
                .await
        });

        // Drain until the provider drops its sender.
        let mut deltas = Vec::new();
        let mut saw_done = false;
        while let Some(ev) = rx.recv().await {
            match ev {
                ChatEvent::Delta(s) => deltas.push(s),
                ChatEvent::Done => saw_done = true,
                ChatEvent::ToolCall(_) => panic!("unexpected tool call"),
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        let result = handle.await.expect("task panicked");
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
        assert!(saw_done, "expected ChatEvent::Done");
    }

    #[tokio::test]
    async fn ollama_provider_emits_tool_call() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let base = format!("http://{addr}");

        // One-shot fake server that streams a single tool call split across two chunks.
        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                use tokio::io::{AsyncReadExt, AsyncWriteExt};
                let mut buf = [0u8; 4096];
                let _ = sock.read(&mut buf).await;

                let sse_body = concat!(
                    "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\"}}]}}]}\n\n",
                    "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"rust\\\"}\"}}]}}]}\n\n",
                    "data: [DONE]\n\n",
                );
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                    sse_body.len(),
                    sse_body
                );
                let _ = sock.write_all(response.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let provider = OllamaProvider::new(base, "test-model");
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let handle = tokio::spawn(async move {
            provider
                .chat_stream(
                    vec![ChatMessage {
                        role: "user".into(),
                        content: "search rust".into(),
                        tool_call_id: None,
                        tool_calls: None,
                    }],
                    vec![ToolDef {
                        name: "search".into(),
                        description: "search the web".into(),
                        parameters: serde_json::json!({"type":"object"}),
                    }],
                    tx,
                )
                .await
        });

        // Drain until the provider drops its sender; text deltas are irrelevant here.
        let mut tool_calls = Vec::new();
        let mut saw_done = false;
        while let Some(ev) = rx.recv().await {
            match ev {
                ChatEvent::ToolCall(tc) => tool_calls.push(tc),
                ChatEvent::Done => saw_done = true,
                ChatEvent::Delta(_) => {}
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        let result = handle.await.expect("task panicked");
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].name, "search");
        assert_eq!(tool_calls[0].arguments, "{\"q\":\"rust\"}");
        assert!(saw_done);
    }
}