Skip to main content

talon_core/ask/
client.rs

1//! Ask-mode LLM client built on the shared chat-completions client.
2
3use std::collections::HashSet;
4use std::fmt::Write as _;
5
6use crate::llm::{ChatClient, ChatCompletionOutput, ChatMessage, strip_code_fences};
7use crate::query::AskSource;
8use crate::text::nfd;
9
10use super::error::AskError;
11use super::types::AskPlanBody;
12
/// Sampling temperature passed to every ask-mode chat completion request.
const ASK_TEMPERATURE: f32 = 0.0;

/// System prompt for the planning stage: asks the model for a JSON object
/// holding a short list of search queries and nothing else.
const PLAN_SYSTEM_PROMPT: &str = concat!(
    "Return only valid JSON of the form {\"queries\":[\"...\"]}. ",
    "Generate 3 to 6 concise search queries for an Obsidian vault. ",
    "Prefer concrete domain terms, likely note titles, aliases, and useful synonyms. ",
    "Do not explain.",
);

/// System prompt for the synthesis stage: asks the model to answer strictly
/// from the supplied vault snippets.
const ANSWER_SYSTEM_PROMPT: &str = concat!(
    "Answer the user's question using the vault snippets provided. ",
    "Keep the answer compact and practical. ",
    "Do not invent facts that are not supported by the snippets; ",
    "if the snippets are thin or conflicting, say so briefly. ",
    "Do not use a forced citation style.",
);
24
25/// High-level client for `talon ask` planning and synthesis.
26#[derive(Debug, Clone)]
27pub struct AskClient {
28    planning_chat: ChatClient,
29    synthesis_chat: ChatClient,
30}
31
/// Outcome of the query-planning stage, bundled with the raw model output so
/// callers can surface diagnostics.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AskPlan {
    /// Search queries after normalization.
    pub queries: Vec<String>,
    /// Visible content exactly as the model returned it.
    pub content: String,
    /// Hidden/thinking trace, when the model server returned one.
    pub reasoning_content: Option<String>,
    /// Unparsed JSON response body from the model server.
    pub raw_response: String,
}
44
/// Outcome of the answer-synthesis stage, bundled with the raw model output
/// so callers can surface diagnostics.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AskSynthesis {
    /// Answer text with surrounding whitespace removed.
    pub answer: String,
    /// Visible content exactly as the model returned it.
    pub content: String,
    /// Hidden/thinking trace, when the model server returned one.
    pub reasoning_content: Option<String>,
    /// Unparsed JSON response body from the model server.
    pub raw_response: String,
}
57
58impl AskClient {
59    /// Builds an ask client from an existing chat client.
60    #[must_use]
61    pub fn new(chat: ChatClient) -> Self {
62        Self {
63            planning_chat: chat.clone(),
64            synthesis_chat: chat,
65        }
66    }
67
68    /// Builds an ask client with distinct planner and synthesis clients.
69    #[must_use]
70    pub const fn with_stage_clients(planning_chat: ChatClient, synthesis_chat: ChatClient) -> Self {
71        Self {
72            planning_chat,
73            synthesis_chat,
74        }
75    }
76
77    /// Plans search queries for a broad natural-language question.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`AskError`] for transport failures. Malformed planner JSON
82    /// gracefully falls back to the original question.
83    pub fn plan_queries(&self, question: &str, limit: u8) -> Result<Vec<String>, AskError> {
84        self.plan_queries_detailed(question, limit)
85            .map(|plan| plan.queries)
86    }
87
88    /// Plans search queries and returns raw model output for diagnostics.
89    ///
90    /// # Errors
91    ///
92    /// Returns [`AskError`] for transport failures. Malformed planner JSON
93    /// gracefully falls back to the original question.
94    pub fn plan_queries_detailed(&self, question: &str, limit: u8) -> Result<AskPlan, AskError> {
95        let output = self.planning_chat.complete_raw(
96            vec![
97                ChatMessage::new("system", PLAN_SYSTEM_PROMPT),
98                ChatMessage::new("user", format!("Question: {question}")),
99            ],
100            ASK_TEMPERATURE,
101        )?;
102        let cleaned = strip_code_fences(&output.content);
103        let body: AskPlanBody = match serde_json::from_str(&cleaned) {
104            Ok(body) => body,
105            Err(_) => {
106                return Ok(AskPlan {
107                    queries: Vec::new(),
108                    content: output.content,
109                    reasoning_content: output.reasoning_content,
110                    raw_response: output.raw_response,
111                });
112            }
113        };
114        Ok(AskPlan {
115            queries: normalize_queries(question, body.queries, limit),
116            content: output.content,
117            reasoning_content: output.reasoning_content,
118            raw_response: output.raw_response,
119        })
120    }
121
122    /// Synthesizes an answer from ranked vault snippets.
123    ///
124    /// # Errors
125    ///
126    /// Returns [`AskError`] for chat transport or response-shape failures.
127    pub fn synthesize(
128        &self,
129        question: &str,
130        queries: &[String],
131        sources: &[AskSource],
132    ) -> Result<String, AskError> {
133        self.synthesize_detailed(question, queries, sources)
134            .map(|synthesis| synthesis.answer)
135    }
136
137    /// Synthesizes an answer and returns raw model output for diagnostics.
138    ///
139    /// # Errors
140    ///
141    /// Returns [`AskError`] for chat transport or response-shape failures.
142    pub fn synthesize_detailed(
143        &self,
144        question: &str,
145        queries: &[String],
146        sources: &[AskSource],
147    ) -> Result<AskSynthesis, AskError> {
148        let user_message = build_answer_user_message(question, queries, sources);
149        let output: ChatCompletionOutput = self.synthesis_chat.complete_raw(
150            vec![
151                ChatMessage::new("system", ANSWER_SYSTEM_PROMPT),
152                ChatMessage::new("user", user_message),
153            ],
154            ASK_TEMPERATURE,
155        )?;
156        Ok(AskSynthesis {
157            answer: output.content.trim().to_owned(),
158            content: output.content,
159            reasoning_content: output.reasoning_content,
160            raw_response: output.raw_response,
161        })
162    }
163
164    /// Returns the configured ask model.
165    #[must_use]
166    pub fn model(&self) -> &str {
167        self.planning_chat.model()
168    }
169
170    /// Returns the configured ask endpoint.
171    #[must_use]
172    pub fn base_url(&self) -> &str {
173        self.planning_chat.base_url()
174    }
175}
176
177fn build_answer_user_message(question: &str, queries: &[String], sources: &[AskSource]) -> String {
178    let mut message = format!("Question:\n{question}\n\nSearch queries:\n");
179    for query in queries {
180        message.push_str("- ");
181        message.push_str(query);
182        message.push('\n');
183    }
184    message.push_str("\nVault snippets:\n");
185    for (index, source) in sources.iter().enumerate() {
186        let _ = writeln!(
187            message,
188            "[{}] {}\nTitle: {}\nScore: {:.3}\nSnippet: {}\n",
189            index + 1,
190            source.vault_path.as_str(),
191            source.title.as_str(),
192            source.score,
193            source.snippet.as_str()
194        );
195    }
196    message
197}
198
199fn normalize_queries(original: &str, queries: Vec<String>, limit: u8) -> Vec<String> {
200    let normalized_original = nfd::normalize(original.trim()).to_lowercase();
201    let limit = usize::from(limit);
202    let mut seen: HashSet<String> = HashSet::new();
203    let mut result = Vec::with_capacity(limit);
204    for candidate in queries {
205        let trimmed = candidate.trim().to_owned();
206        if trimmed.is_empty() {
207            continue;
208        }
209        let normalized = nfd::normalize(&trimmed).to_lowercase();
210        if normalized != normalized_original && seen.insert(normalized) {
211            result.push(trimmed);
212            if result.len() >= limit {
213                break;
214            }
215        }
216    }
217    result
218}
219
220#[cfg(test)]
221mod tests;