Skip to main content

trusty_memory/chat/
handler.rs

1//! The SSE-streaming chat handler (`chat_handler`).
2//!
3//! Why: the OpenRouter/Ollama tool-calling loop is by far the largest single
4//! concern in the chat HTTP surface; isolating it keeps the other handlers
5//! readable (split out of the former monolithic `chat.rs`, issue #607).
6//! What: the `chat_handler` axum handler, moved verbatim. Tool-loop building
7//! blocks (`all_tools`, `execute_tool`, `MAX_TOOL_ROUNDS`, `ChatBody`) are
8//! pulled in from the sibling `tools` submodule.
9//! Test: behaviour exercised end-to-end via the chat SSE integration paths.
10
11use crate::web::load_user_config;
12use crate::AppState;
13use axum::{
14    body::Body,
15    extract::State,
16    http::StatusCode,
17    response::{IntoResponse, Response},
18    Json,
19};
20use serde_json::{json, Value};
21use trusty_common::memory_core::palace::PalaceId;
22use trusty_common::memory_core::retrieval::recall_with_default_embedder;
23use trusty_common::memory_core::PalaceRegistry;
24use trusty_common::{ChatEvent, ChatMessage};
25
26// ---------------------------------------------------------------------------
27
28use super::tools::{all_tools, execute_get_dream_status, execute_tool, ChatBody, MAX_TOOL_ROUNDS};
29
30pub(crate) async fn chat_handler(
31    State(state): State<AppState>,
32    Json(body): Json<ChatBody>,
33) -> Response {
34    // Select the active provider (Ollama auto-detect, else OpenRouter).
35    let Some(provider) = state.chat_provider().await else {
36        return (
37            StatusCode::PRECONDITION_FAILED,
38            "No chat provider configured (no local Ollama detected and no OpenRouter key set)",
39        )
40            .into_response();
41    };
42
43    // Resolve palace id (explicit > default).
44    let palace_id = body
45        .palace_id
46        .clone()
47        .or_else(|| state.default_palace.clone())
48        .unwrap_or_default();
49
50    // Resolve / create chat session when a palace is bound.
51    let (session_id, mut history): (Option<String>, Vec<ChatMessage>) = if !palace_id.is_empty() {
52        let store = match state.session_store(&palace_id) {
53            Ok(s) => s,
54            Err(e) => {
55                tracing::warn!(palace = %palace_id, "session_store open failed: {e:#}");
56                return (
57                    StatusCode::INTERNAL_SERVER_ERROR,
58                    format!("session store: {e:#}"),
59                )
60                    .into_response();
61            }
62        };
63        match body.session_id.clone() {
64            Some(sid) => match store.get_session(&sid) {
65                Ok(Some(s)) => (
66                    Some(sid),
67                    s.history
68                        .into_iter()
69                        .map(|m| ChatMessage {
70                            role: m.role,
71                            content: m.content,
72                            tool_call_id: None,
73                            tool_calls: None,
74                        })
75                        .collect(),
76                ),
77                _ => (Some(sid), body.history.clone()),
78            },
79            None => {
80                let new_id = store.create_session(None).unwrap_or_else(|e| {
81                    tracing::warn!("create_session failed: {e:#}");
82                    String::new()
83                });
84                (
85                    if new_id.is_empty() {
86                        None
87                    } else {
88                        Some(new_id)
89                    },
90                    body.history.clone(),
91                )
92            }
93        }
94    } else {
95        (None, body.history.clone())
96    };
97
98    // Full palace roster for the identity block — names + ids, not just count,
99    // so the model can pick the right one when the user names a palace.
100    let all_palaces = PalaceRegistry::list_palaces(&state.data_root).unwrap_or_default();
101    let palace_count = all_palaces.len();
102    let palace_roster: String = all_palaces
103        .iter()
104        .map(|p| format!("- {} (id: {})", p.name, p.id.0))
105        .collect::<Vec<_>>()
106        .join("\n");
107
108    // Config + global dream snapshot — give the model an honest view of what's
109    // available so it doesn't invent tools or providers that aren't there.
110    let cfg = load_user_config().unwrap_or_default();
111    let active_provider_name = state
112        .chat_provider()
113        .await
114        .map(|p| p.name().to_string())
115        .unwrap_or_else(|| "none".to_string());
116    let dream_snapshot = execute_get_dream_status(&state).await;
117
118    // Look up the selected palace's metadata (name/description) and open its
119    // handle for live counts + recall context.
120    let selected_palace_meta = if palace_id.is_empty() {
121        None
122    } else {
123        all_palaces.iter().find(|p| p.id.0 == palace_id).cloned()
124    };
125
126    let mut palace_block = String::new();
127    let mut context = String::new();
128    let mut palace_display_name = palace_id.clone();
129
130    if !palace_id.is_empty() {
131        if let Ok(handle) = state
132            .registry
133            .open_palace(&state.data_root, &PalaceId::new(&palace_id))
134        {
135            // Live counts from the opened handle.
136            let drawer_count = handle.drawers.read().len();
137            let vector_count = handle.vector_store.index_size();
138            let kg_triple_count = handle.kg.count_active_triples();
139
140            // Prefer the on-disk palace.json name/description; fall back to id.
141            let (name, description) = match &selected_palace_meta {
142                Some(p) => (p.name.clone(), p.description.clone()),
143                None => (palace_id.clone(), None),
144            };
145            palace_display_name = name.clone();
146
147            palace_block.push_str(&format!(
148                "Currently selected palace:\n\
149                 - id: {id}\n\
150                 - name: {name}\n",
151                id = palace_id,
152                name = name,
153            ));
154            if let Some(desc) = description.as_deref().filter(|s| !s.is_empty()) {
155                palace_block.push_str(&format!("- description: {desc}\n"));
156            }
157            palace_block.push_str(&format!(
158                "- drawers: {drawer_count}\n\
159                 - vectors: {vector_count}\n\
160                 - kg_triples: {kg_triple_count}\n",
161            ));
162            let identity_trimmed = handle.identity.trim();
163            if !identity_trimmed.is_empty() {
164                palace_block.push_str(&format!("- identity:\n{identity_trimmed}\n",));
165            }
166
167            if let Ok(hits) = recall_with_default_embedder(&handle, &body.message, 5).await {
168                for r in hits.iter().take(5) {
169                    context.push_str(&format!("- (L{}) {}\n", r.layer, r.drawer.content));
170                }
171            }
172        }
173    }
174
175    // Build the grounded system prompt with identity, palace, RAG, config,
176    // dream-snapshot, and behavior blocks so the LLM never confuses
177    // trusty-memory palaces with real-world architectural palaces.
178    let mut system = String::new();
179    system.push_str(&format!(
180        "You are the assistant for trusty-memory, a machine-wide AI memory \
181         service running locally on this user's machine. trusty-memory stores \
182         knowledge in named \"palaces\" — isolated memory namespaces, each with \
183         its own vector index (usearch HNSW) and temporal knowledge graph \
184         (redb). Memories are organized as Palace -> Wing -> Room -> Closet \
185         -> Drawer, where a Drawer is an atomic memory unit.\n\
186         There are currently {palace_count} palace(s) on this machine.\n",
187    ));
188    if !palace_roster.is_empty() {
189        system.push_str(&format!("Palaces:\n{palace_roster}\n"));
190    }
191    system.push('\n');
192
193    // Config block — what providers/models are wired up right now.
194    system.push_str(&format!(
195        "System configuration:\n\
196         - active chat provider: {active_provider_name}\n\
197         - openrouter model: {or_model}\n\
198         - local model: {local_model} ({local_url}, enabled={local_enabled})\n\
199         - data root: {data_root}\n\n",
200        or_model = cfg.openrouter_model,
201        local_model = cfg.local_model.model,
202        local_url = cfg.local_model.base_url,
203        local_enabled = cfg.local_model.enabled,
204        data_root = state.data_root.display(),
205    ));
206
207    // Dream snapshot — give the model a sense of how stale memory state is.
208    system.push_str(&format!(
209        "Global dream status (background memory maintenance):\n{}\n\n",
210        dream_snapshot,
211    ));
212
213    if !palace_block.is_empty() {
214        system.push_str(&palace_block);
215        system.push('\n');
216    }
217
218    if !context.is_empty() {
219        system.push_str(&format!(
220            "Relevant memories from the '{palace_display_name}' palace \
221             (L0 = identity, L1 = essentials, L2 = topic-filtered, L3 = deep):\n\
222             {context}\n",
223        ));
224    }
225
226    system.push_str(
227        "You have a set of tools to introspect and modify this trusty-memory \
228         daemon. Prefer calling a tool over guessing — e.g. call \
229         `list_palaces` rather than relying on the roster above if you need \
230         live counts, and call `recall_memories` to search for facts you \
231         don't have in context. When the user asks about \"palaces\", they \
232         mean trusty-memory palaces (memory namespaces on this machine), not \
233         architectural palaces like Versailles. If a tool returns an error, \
234         report it honestly and don't fabricate results.",
235    );
236
237    // Append the new user message to the in-memory history we'll persist.
238    history.push(ChatMessage {
239        role: "user".to_string(),
240        content: body.message.clone(),
241        tool_call_id: None,
242        tool_calls: None,
243    });
244
245    let mut messages: Vec<ChatMessage> = Vec::with_capacity(history.len() + 1);
246    messages.push(ChatMessage {
247        role: "system".to_string(),
248        content: system,
249        tool_call_id: None,
250        tool_calls: None,
251    });
252    messages.extend(history.iter().cloned());
253
254    let tools = all_tools();
255    let (sse_tx, sse_rx) =
256        tokio::sync::mpsc::channel::<Result<axum::body::Bytes, std::io::Error>>(64);
257
258    // Capture session-persistence inputs.
259    let session_store = if !palace_id.is_empty() && session_id.is_some() {
260        state.session_store(&palace_id).ok()
261    } else {
262        None
263    };
264    let persist_session_id = session_id.clone();
265
266    // Drive the tool-execution loop in a background task so the response can
267    // start streaming immediately.
268    let loop_state = state.clone();
269    tokio::spawn(async move {
270        // Emit a leading session_id frame so the SPA can correlate this stream
271        // with a persisted session row.
272        if let Some(sid) = persist_session_id.as_deref() {
273            let frame = format!("data: {}\n\n", json!({ "session_id": sid }));
274            if sse_tx
275                .send(Ok(axum::body::Bytes::from(frame)))
276                .await
277                .is_err()
278            {
279                return;
280            }
281        }
282
283        let mut final_assistant_text = String::new();
284        let mut stream_err: Option<String> = None;
285
286        for round in 0..MAX_TOOL_ROUNDS {
287            let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<ChatEvent>(256);
288            let messages_clone = messages.clone();
289            let tools_clone = tools.clone();
290            let provider_clone = provider.clone();
291            let stream_handle = tokio::spawn(async move {
292                provider_clone
293                    .chat_stream(messages_clone, tools_clone, event_tx)
294                    .await
295            });
296
297            let mut tool_calls_this_round: Vec<trusty_common::ToolCall> = Vec::new();
298            let mut round_assistant_text = String::new();
299
300            while let Some(event) = event_rx.recv().await {
301                match event {
302                    ChatEvent::Delta(text) => {
303                        round_assistant_text.push_str(&text);
304                        let frame = format!("data: {}\n\n", json!({ "delta": text }));
305                        if sse_tx
306                            .send(Ok(axum::body::Bytes::from(frame)))
307                            .await
308                            .is_err()
309                        {
310                            return;
311                        }
312                    }
313                    ChatEvent::ToolCall(tc) => {
314                        let frame = format!(
315                            "data: {}\n\n",
316                            json!({ "tool_call": {
317                                "id": tc.id,
318                                "name": tc.name,
319                                "arguments": tc.arguments,
320                            }})
321                        );
322                        let _ = sse_tx.send(Ok(axum::body::Bytes::from(frame))).await;
323                        tool_calls_this_round.push(tc);
324                    }
325                    ChatEvent::Done => break,
326                    ChatEvent::Error(e) => {
327                        stream_err = Some(e);
328                        break;
329                    }
330                }
331            }
332
333            // Drain the spawned stream task; surface any error.
334            match stream_handle.await {
335                Ok(Ok(())) => {}
336                Ok(Err(e)) => stream_err = Some(e.to_string()),
337                Err(e) => stream_err = Some(format!("join: {e}")),
338            }
339
340            if stream_err.is_some() {
341                break;
342            }
343
344            final_assistant_text.push_str(&round_assistant_text);
345
346            if tool_calls_this_round.is_empty() {
347                // Model produced a plain answer — we're done.
348                break;
349            }
350
351            // Build the assistant message that requested these tool calls.
352            let assistant_tool_calls_json: Vec<Value> = tool_calls_this_round
353                .iter()
354                .map(|tc| {
355                    json!({
356                        "id": tc.id,
357                        "type": "function",
358                        "function": { "name": tc.name, "arguments": tc.arguments },
359                    })
360                })
361                .collect();
362            messages.push(ChatMessage {
363                role: "assistant".to_string(),
364                content: round_assistant_text,
365                tool_call_id: None,
366                tool_calls: Some(assistant_tool_calls_json),
367            });
368
369            // Execute each tool and append its result as a `role: "tool"`
370            // message. The next loop iteration feeds these back to the model.
371            for tc in &tool_calls_this_round {
372                let result = execute_tool(&tc.name, &tc.arguments, &loop_state).await;
373                let result_str = result.to_string();
374                let frame = format!(
375                    "data: {}\n\n",
376                    json!({ "tool_result": {
377                        "id": tc.id,
378                        "name": tc.name,
379                        "content": &result_str,
380                    }})
381                );
382                let _ = sse_tx.send(Ok(axum::body::Bytes::from(frame))).await;
383                messages.push(ChatMessage {
384                    role: "tool".to_string(),
385                    content: result_str,
386                    tool_call_id: Some(tc.id.clone()),
387                    tool_calls: None,
388                });
389            }
390
391            // Safety net: log when we walk off the round limit.
392            if round + 1 == MAX_TOOL_ROUNDS {
393                tracing::warn!(
394                    "chat: hit MAX_TOOL_ROUNDS={} — terminating tool loop",
395                    MAX_TOOL_ROUNDS
396                );
397            }
398        }
399
400        // Persist the completed conversation regardless of streaming error
401        // (partial assistant reply still better than nothing).
402        if let (Some(store), Some(sid)) = (session_store, persist_session_id.as_deref()) {
403            if !final_assistant_text.is_empty() {
404                history.push(ChatMessage {
405                    role: "assistant".into(),
406                    content: final_assistant_text,
407                    tool_call_id: None,
408                    tool_calls: None,
409                });
410            }
411            let core_history: Vec<trusty_common::memory_core::store::chat_sessions::ChatMessage> =
412                history
413                    .iter()
414                    .map(
415                        |m| trusty_common::memory_core::store::chat_sessions::ChatMessage {
416                            role: m.role.clone(),
417                            content: m.content.clone(),
418                        },
419                    )
420                    .collect();
421            if let Err(e) = store.upsert_session(sid, &core_history) {
422                tracing::warn!("upsert_session failed: {e:#}");
423            }
424        }
425
426        match stream_err {
427            None => {
428                let _ = sse_tx
429                    .send(Ok(axum::body::Bytes::from("data: [DONE]\n\n")))
430                    .await;
431            }
432            Some(e) => {
433                let out = format!("data: {}\n\n", json!({ "error": e }));
434                let _ = sse_tx.send(Ok(axum::body::Bytes::from(out))).await;
435            }
436        }
437    });
438
439    let stream = tokio_stream::wrappers::ReceiverStream::new(sse_rx);
440
441    Response::builder()
442        .header("Content-Type", "text/event-stream")
443        .header("Cache-Control", "no-cache")
444        .body(Body::from_stream(stream))
445        .expect("static SSE response builds")
446}