1use std::collections::HashSet;
4use std::fmt::Write as _;
5
6use crate::llm::{ChatClient, ChatCompletionOutput, ChatMessage, strip_code_fences};
7use crate::query::AskSource;
8use crate::text::nfd;
9
10use super::error::AskError;
11use super::types::AskPlanBody;
12
/// Temperature value handed to `complete_raw` for both the planning and the
/// synthesis request issued by `AskClient`.
const ASK_TEMPERATURE: f32 = 0.0;
14
/// System message for the planning request: instructs the model to reply with
/// a JSON object of the form `{"queries":[...]}` holding 3 to 6 search queries
/// and nothing else.
const PLAN_SYSTEM_PROMPT: &str = "Return only valid JSON of the form \
 {\"queries\":[\"...\"]}. Generate 3 to 6 concise search queries for an \
 Obsidian vault. Prefer concrete domain terms, likely note titles, aliases, \
 and useful synonyms. Do not explain.";
19
/// System message for the synthesis request: instructs the model to answer
/// from the supplied vault snippets only, to stay compact, and to mention when
/// the snippets are thin or conflicting.
const ANSWER_SYSTEM_PROMPT: &str = "Answer the user's question using the vault \
 snippets provided. Keep the answer compact and practical. Do not invent \
 facts that are not supported by the snippets; if the snippets are thin or \
 conflicting, say so briefly. Do not use a forced citation style.";
24
/// Client for the two-stage "ask" flow: a planning request that turns a
/// question into search queries, and a synthesis request that answers the
/// question from vault snippets. Each stage talks to its own [`ChatClient`].
#[derive(Debug, Clone)]
pub struct AskClient {
    /// Chat client used for the query-planning request.
    planning_chat: ChatClient,
    /// Chat client used for the answer-synthesis request.
    synthesis_chat: ChatClient,
}
31
/// Result of the planning stage, including the model output it was derived
/// from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AskPlan {
    /// Normalized search queries; empty when the model output could not be
    /// parsed as the expected JSON body.
    pub queries: Vec<String>,
    /// The `content` field of the chat completion output, unmodified.
    pub content: String,
    /// The `reasoning_content` field of the chat completion output, if any.
    pub reasoning_content: Option<String>,
    /// The `raw_response` field of the chat completion output.
    pub raw_response: String,
}
44
/// Result of the synthesis stage, including the model output it was derived
/// from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AskSynthesis {
    /// The completion `content` with leading and trailing whitespace trimmed.
    pub answer: String,
    /// The `content` field of the chat completion output, unmodified.
    pub content: String,
    /// The `reasoning_content` field of the chat completion output, if any.
    pub reasoning_content: Option<String>,
    /// The `raw_response` field of the chat completion output.
    pub raw_response: String,
}
57
58impl AskClient {
59 #[must_use]
61 pub fn new(chat: ChatClient) -> Self {
62 Self {
63 planning_chat: chat.clone(),
64 synthesis_chat: chat,
65 }
66 }
67
68 #[must_use]
70 pub const fn with_stage_clients(planning_chat: ChatClient, synthesis_chat: ChatClient) -> Self {
71 Self {
72 planning_chat,
73 synthesis_chat,
74 }
75 }
76
77 pub fn plan_queries(&self, question: &str, limit: u8) -> Result<Vec<String>, AskError> {
84 self.plan_queries_detailed(question, limit)
85 .map(|plan| plan.queries)
86 }
87
88 pub fn plan_queries_detailed(&self, question: &str, limit: u8) -> Result<AskPlan, AskError> {
95 let output = self.planning_chat.complete_raw(
96 vec![
97 ChatMessage::new("system", PLAN_SYSTEM_PROMPT),
98 ChatMessage::new("user", format!("Question: {question}")),
99 ],
100 ASK_TEMPERATURE,
101 )?;
102 let cleaned = strip_code_fences(&output.content);
103 let body: AskPlanBody = match serde_json::from_str(&cleaned) {
104 Ok(body) => body,
105 Err(_) => {
106 return Ok(AskPlan {
107 queries: Vec::new(),
108 content: output.content,
109 reasoning_content: output.reasoning_content,
110 raw_response: output.raw_response,
111 });
112 }
113 };
114 Ok(AskPlan {
115 queries: normalize_queries(question, body.queries, limit),
116 content: output.content,
117 reasoning_content: output.reasoning_content,
118 raw_response: output.raw_response,
119 })
120 }
121
122 pub fn synthesize(
128 &self,
129 question: &str,
130 queries: &[String],
131 sources: &[AskSource],
132 ) -> Result<String, AskError> {
133 self.synthesize_detailed(question, queries, sources)
134 .map(|synthesis| synthesis.answer)
135 }
136
137 pub fn synthesize_detailed(
143 &self,
144 question: &str,
145 queries: &[String],
146 sources: &[AskSource],
147 ) -> Result<AskSynthesis, AskError> {
148 let user_message = build_answer_user_message(question, queries, sources);
149 let output: ChatCompletionOutput = self.synthesis_chat.complete_raw(
150 vec![
151 ChatMessage::new("system", ANSWER_SYSTEM_PROMPT),
152 ChatMessage::new("user", user_message),
153 ],
154 ASK_TEMPERATURE,
155 )?;
156 Ok(AskSynthesis {
157 answer: output.content.trim().to_owned(),
158 content: output.content,
159 reasoning_content: output.reasoning_content,
160 raw_response: output.raw_response,
161 })
162 }
163
164 #[must_use]
166 pub fn model(&self) -> &str {
167 self.planning_chat.model()
168 }
169
170 #[must_use]
172 pub fn base_url(&self) -> &str {
173 self.planning_chat.base_url()
174 }
175}
176
177fn build_answer_user_message(question: &str, queries: &[String], sources: &[AskSource]) -> String {
178 let mut message = format!("Question:\n{question}\n\nSearch queries:\n");
179 for query in queries {
180 message.push_str("- ");
181 message.push_str(query);
182 message.push('\n');
183 }
184 message.push_str("\nVault snippets:\n");
185 for (index, source) in sources.iter().enumerate() {
186 let _ = writeln!(
187 message,
188 "[{}] {}\nTitle: {}\nScore: {:.3}\nSnippet: {}\n",
189 index + 1,
190 source.vault_path.as_str(),
191 source.title.as_str(),
192 source.score,
193 source.snippet.as_str()
194 );
195 }
196 message
197}
198
199fn normalize_queries(original: &str, queries: Vec<String>, limit: u8) -> Vec<String> {
200 let normalized_original = nfd::normalize(original.trim()).to_lowercase();
201 let limit = usize::from(limit);
202 let mut seen: HashSet<String> = HashSet::new();
203 let mut result = Vec::with_capacity(limit);
204 for candidate in queries {
205 let trimmed = candidate.trim().to_owned();
206 if trimmed.is_empty() {
207 continue;
208 }
209 let normalized = nfd::normalize(&trimmed).to_lowercase();
210 if normalized != normalized_original && seen.insert(normalized) {
211 result.push(trimmed);
212 if result.len() >= limit {
213 break;
214 }
215 }
216 }
217 result
218}
219
220#[cfg(test)]
221mod tests;