// Source path: trusty_common/chat/mod.rs

//! Provider-agnostic streaming chat abstraction with tool-use support.
//!
//! Why: trusty-memory and trusty-search both want to support more than one
//! upstream LLM (OpenRouter for cloud, Ollama / LM Studio for local). Rather
//! than each crate re-implementing the dispatch, we expose a small
//! [`ChatProvider`] trait plus concrete implementations and an auto-detector
//! for a running local model server. The trait also surfaces OpenAI-style
//! tool/function calling so downstream agents can let the model invoke tools
//! (search, memory recall, shell, etc.).
//!
//! What: defines the [`ChatProvider`] trait, [`ToolDef`] / [`ToolCall`] /
//! [`ChatEvent`] tool-use types, an [`OpenRouterProvider`] and an
//! [`OllamaProvider`] that both speak OpenAI-compatible
//! `/v1/chat/completions` with SSE streaming (including the streamed
//! `tool_calls` shape), a [`BedrockProvider`] that uses the AWS Bedrock
//! `Converse` API (behind the `bedrock` feature flag), and
//! [`auto_detect_local_provider`] which probes `{base_url}/v1/models` with a
//! 1-second timeout.
//!
//! Test: `cargo test -p trusty-common` covers default config values, the
//! unreachable-server path of `auto_detect_local_provider`, SSE delta
//! streaming, and accumulation of streamed tool-call fragments.

24mod openai_compat;
25
26#[cfg(feature = "bedrock")]
27mod bedrock_impl;
28#[cfg(not(feature = "bedrock"))]
29mod bedrock_stub;
30
31pub use openai_compat::{OllamaProvider, OpenRouterProvider, auto_detect_local_provider};
32
33#[cfg(feature = "bedrock")]
34pub use bedrock_impl::{
35    BedrockProvider, DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION, ENV_REGION_AWS,
36    ENV_REGION_TRUSTY,
37};
38
39// Re-expose the bedrock_impl module as `bedrock_provider` so downstream
40// crates can access constants (e.g. `DEFAULT_BEDROCK_MODEL`) without needing
41// to depend on the bedrock feature themselves.
42#[cfg(feature = "bedrock")]
43pub mod bedrock_provider {
44    pub use super::bedrock_impl::*;
45}
46
47#[cfg(not(feature = "bedrock"))]
48pub use bedrock_stub::BedrockProvider;
49
/// Default Bedrock model identifier used when the `bedrock` feature is off.
///
/// Why: stub constant so code that references `DEFAULT_BEDROCK_MODEL`
/// compiles without the bedrock feature.
/// What: must stay in sync with `bedrock_impl::DEFAULT_BEDROCK_MODEL`.
/// Claude Sonnet 4.6 drops the date stamp and `-v1:0` suffix (verified vs
/// AWS docs).
#[cfg(not(feature = "bedrock"))]
pub const DEFAULT_BEDROCK_MODEL: &str = "us.anthropic.claude-sonnet-4-6";

use crate::ChatMessage;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Sender;

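// ── Local model configuration ─────────────────────────────────────────────────
// (The public re-exports that give callers the full surface from `chat::*`
// are declared near the top of this module.)
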
64/// Configuration for a local OpenAI-compatible model server (Ollama, LM
65/// Studio, llama.cpp's server, etc.).
66///
67/// Why: callers want a single struct they can deserialize from config files
68/// and pass to [`auto_detect_local_provider`] without juggling defaults.
69/// What: holds an enable flag, the server's base URL (no trailing slash),
70/// and the default model to request. Defaults target Ollama's standard
71/// localhost binding.
72/// Test: `local_model_config_defaults` asserts the default values.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct LocalModelConfig {
75    pub enabled: bool,
76    pub base_url: String,
77    pub model: String,
78}
79
80impl Default for LocalModelConfig {
81    fn default() -> Self {
82        Self {
83            enabled: true,
84            base_url: "http://localhost:11434".to_string(),
85            model: "qwen3:30b".to_string(),
86        }
87    }
88}
89
90// ─── Tool-use types ───────────────────────────────────────────────────────────
91
92/// JSON-Schema description of a callable tool, in OpenAI function-calling
93/// shape.
94///
95/// Why: downstream agents (trusty-memory, trusty-search) expose tools like
96/// `memory_recall` or `web_search` to the LLM. The OpenAI tool format is the
97/// de-facto common denominator across OpenRouter, Ollama, LM Studio, and
98/// most cloud providers.
99/// What: `name` and `description` are passed verbatim; `parameters` is a
100/// JSON Schema object (typically `{"type":"object","properties":{...}}`).
101/// Test: `tool_def_serializes_as_function` checks the wire shape.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ToolDef {
104    pub name: String,
105    pub description: String,
106    pub parameters: serde_json::Value,
107}
108
109/// A tool invocation the model wants the host to perform.
110///
111/// Why: the streaming chat API emits `tool_calls` in fragments — first an
112/// `id` + `function.name`, then a string of `function.arguments` deltas.
113/// We accumulate fragments and surface one fully-formed [`ToolCall`] per
114/// invocation to the caller.
115/// What: `id` is the upstream's call id (echoed back in subsequent
116/// `role:"tool"` messages); `name` is the function name; `arguments` is a
117/// JSON string (NOT a parsed value — many models emit malformed JSON and
118/// callers want the raw text for error reporting / repair).
119/// Test: `accumulates_streamed_tool_call_fragments`.
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
121pub struct ToolCall {
122    pub id: String,
123    pub name: String,
124    pub arguments: String,
125}
126
127/// Streaming chat event.
128///
129/// Why: replaces the previous "string-only" channel so callers can
130/// distinguish text deltas from tool invocations and from terminal
131/// success/error without parsing magic markers out of the text stream.
132/// What: `Delta` is a content chunk; `ToolCall` is a fully-accumulated tool
133/// invocation; `Done` signals the upstream stream terminated normally;
134/// `Error` carries a human-readable message for stream-mid failures (the
135/// provider also returns `Err` from `chat_stream`, but `Error` lets the
136/// caller display partial-stream failures inline).
137/// Test: `ollama_provider_streams_sse_deltas`.
138#[derive(Debug, Clone)]
139pub enum ChatEvent {
140    Delta(String),
141    ToolCall(ToolCall),
142    Done,
143    Error(String),
144}
145
146/// Streaming chat provider abstraction.
147///
148/// Why: downstream crates (trusty-memory, trusty-search) want to support
149/// multiple LLM backends without hard-coding which one to call. Providers
150/// expose a uniform streaming interface so the caller can swap them at
151/// runtime based on configuration / availability.
152/// What: implementors stream [`ChatEvent`]s into `tx`. Pass an empty
153/// `tools` vec to disable tool use entirely (the provider MUST then omit
154/// the `tools` field from the upstream request — some models error on an
155/// empty array). Returning `Ok(())` means the stream completed normally;
156/// the caller should also expect a final [`ChatEvent::Done`].
157/// Test: implementations are covered by their own unit tests in this
158/// module plus integration tests in downstream crates.
159#[async_trait]
160pub trait ChatProvider: Send + Sync {
161    /// Human-readable provider name (e.g. `"openrouter"`, `"ollama"`).
162    fn name(&self) -> &str;
163    /// Model identifier sent on every request.
164    fn model(&self) -> &str;
165    /// Stream chat events into `tx`. `tools` empty disables tool use.
166    async fn chat_stream(
167        &self,
168        messages: Vec<ChatMessage>,
169        tools: Vec<ToolDef>,
170        tx: Sender<ChatEvent>,
171    ) -> Result<()>;
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must target Ollama's standard localhost binding.
    #[test]
    fn local_model_config_defaults() {
        let cfg = LocalModelConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:11434");
        assert_eq!(cfg.model, "qwen3:30b");
    }

    /// A fully-specified TOML document must round-trip into the struct.
    #[test]
    fn local_model_config_deserializes_from_toml() {
        let toml_src = r#"
            enabled = true
            base_url = "http://localhost:1234"
            model = "qwen2.5-coder"
        "#;
        let cfg: LocalModelConfig = toml::from_str(toml_src).expect("parse TOML");
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:1234");
        assert_eq!(cfg.model, "qwen2.5-coder");
    }
}