Skip to main content

tuitbot_server/routes/vault/
mod.rs

1//! Vault API endpoints for searching notes, previewing fragments,
2//! and resolving selected references from the dashboard.
3//!
4//! All endpoints are account-scoped via `AccountContext` and return
5//! privacy-safe responses (no raw note bodies — only titles, paths,
6//! tags, heading paths, and truncated snippets).
7
8pub mod selections;
9
10use std::sync::Arc;
11
12use axum::extract::{Path, Query, State};
13use axum::Json;
14use serde::{Deserialize, Serialize};
15
16use tuitbot_core::context::retrieval::{self, VaultCitation};
17use tuitbot_core::storage::watchtower;
18
19use crate::account::AccountContext;
20use crate::error::ApiError;
21use crate::state::AppState;
22
/// Maximum snippet length returned in API responses, measured in UTF-8
/// **bytes** (not characters).
///
/// `truncate_snippet` compares against `str::len()`, counts its `...` marker
/// toward this budget, and backs the cut off to a char boundary, so snippets
/// of multi-byte text contain fewer than this many characters.
const SNIPPET_MAX_LEN: usize = 120;

/// Result limit used by search endpoints when the client omits `limit`.
const DEFAULT_LIMIT: u32 = 20;

/// Upper bound applied to every search `limit` (requested or default).
/// Also used as the fixed fragment cap for `resolve_refs`.
const MAX_LIMIT: u32 = 100;
31
32fn clamp_limit(limit: Option<u32>) -> u32 {
33    limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT)
34}
35
/// Shortens `text` to at most `max_len` bytes for display and appends `...`
/// when anything was cut.
///
/// Length is measured in UTF-8 bytes (`str::len`). The three-byte ellipsis
/// counts toward `max_len`, and the cut point moves left to the nearest char
/// boundary so a multi-byte character is never split. When `max_len < 3` and
/// the text is longer than `max_len`, the result is just `...`, which then
/// exceeds `max_len`.
fn truncate_snippet(text: &str, max_len: usize) -> String {
    if text.len() <= max_len {
        return text.to_owned();
    }

    let budget = max_len.saturating_sub(3);
    // Index 0 is always a char boundary, so the search below never falls through.
    let cut = (0..=budget)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);

    let mut snippet = String::with_capacity(cut + 3);
    snippet.push_str(&text[..cut]);
    snippet.push_str("...");
    snippet
}
47
48// ---------------------------------------------------------------------------
49// GET /api/vault/sources
50// ---------------------------------------------------------------------------
51
52#[derive(Serialize)]
53pub struct VaultSourcesResponse {
54    pub sources: Vec<VaultSourceStatusItem>,
55    pub deployment_mode: String,
56    pub privacy_envelope: String,
57}
58
59#[derive(Serialize)]
60pub struct VaultSourceStatusItem {
61    pub id: i64,
62    pub source_type: String,
63    pub status: String,
64    pub error_message: Option<String>,
65    pub node_count: i64,
66    pub updated_at: String,
67    /// For `local_fs` sources, the configured vault path.  Used by the
68    /// desktop frontend to construct `obsidian://` deep-link URIs.
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub path: Option<String>,
71}
72
73pub async fn vault_sources(
74    State(state): State<Arc<AppState>>,
75    ctx: AccountContext,
76) -> Result<Json<VaultSourcesResponse>, ApiError> {
77    let sources = watchtower::get_all_source_contexts_for(&state.db, &ctx.account_id).await?;
78
79    let is_cloud = matches!(
80        state.deployment_mode,
81        tuitbot_core::config::DeploymentMode::Cloud
82    );
83
84    let mut items = Vec::with_capacity(sources.len());
85    for src in sources {
86        let count = watchtower::count_nodes_for_source(&state.db, &ctx.account_id, src.id)
87            .await
88            .unwrap_or(0);
89        // Only expose local path for non-Cloud modes (defense in depth).
90        let path = if src.source_type == "local_fs" && !is_cloud {
91            serde_json::from_str::<serde_json::Value>(&src.config_json)
92                .ok()
93                .and_then(|v| v.get("path").and_then(|p| p.as_str().map(String::from)))
94        } else {
95            None
96        };
97        items.push(VaultSourceStatusItem {
98            id: src.id,
99            source_type: src.source_type,
100            status: src.status,
101            error_message: src.error_message,
102            node_count: count,
103            updated_at: src.updated_at,
104            path,
105        });
106    }
107
108    Ok(Json(VaultSourcesResponse {
109        sources: items,
110        deployment_mode: state.deployment_mode.to_string(),
111        privacy_envelope: state.deployment_mode.privacy_envelope().to_string(),
112    }))
113}
114
115// ---------------------------------------------------------------------------
116// GET /api/vault/notes?q=&source_id=&limit=
117// ---------------------------------------------------------------------------
118
119#[derive(Deserialize)]
120pub struct SearchNotesQuery {
121    pub q: Option<String>,
122    pub source_id: Option<i64>,
123    pub limit: Option<u32>,
124}
125
126#[derive(Serialize)]
127pub struct SearchNotesResponse {
128    pub notes: Vec<VaultNoteItem>,
129}
130
131#[derive(Serialize)]
132pub struct VaultNoteItem {
133    pub node_id: i64,
134    pub source_id: i64,
135    pub title: Option<String>,
136    pub relative_path: String,
137    pub tags: Option<String>,
138    pub status: String,
139    pub chunk_count: i64,
140    pub updated_at: String,
141}
142
143pub async fn search_notes(
144    State(state): State<Arc<AppState>>,
145    ctx: AccountContext,
146    Query(params): Query<SearchNotesQuery>,
147) -> Result<Json<SearchNotesResponse>, ApiError> {
148    let limit = clamp_limit(params.limit);
149
150    let nodes = match (&params.q, params.source_id) {
151        (Some(q), _) if !q.is_empty() => {
152            watchtower::search_nodes_for(&state.db, &ctx.account_id, q, limit).await?
153        }
154        (_, Some(sid)) => {
155            watchtower::get_nodes_for_source_for(&state.db, &ctx.account_id, sid, limit).await?
156        }
157        _ => {
158            // No query and no source_id — return recent nodes.
159            watchtower::search_nodes_for(&state.db, &ctx.account_id, "", limit).await?
160        }
161    };
162
163    let mut notes = Vec::with_capacity(nodes.len());
164    for node in nodes {
165        let chunk_count =
166            watchtower::count_chunks_for_node(&state.db, &ctx.account_id, node.id).await?;
167        notes.push(VaultNoteItem {
168            node_id: node.id,
169            source_id: node.source_id,
170            title: node.title,
171            relative_path: node.relative_path,
172            tags: node.tags,
173            status: node.status,
174            chunk_count,
175            updated_at: node.updated_at,
176        });
177    }
178
179    Ok(Json(SearchNotesResponse { notes }))
180}
181
182// ---------------------------------------------------------------------------
183// GET /api/vault/notes/{id}
184// ---------------------------------------------------------------------------
185
186#[derive(Serialize)]
187pub struct VaultNoteDetail {
188    pub node_id: i64,
189    pub source_id: i64,
190    pub title: Option<String>,
191    pub relative_path: String,
192    pub tags: Option<String>,
193    pub status: String,
194    pub ingested_at: String,
195    pub updated_at: String,
196    pub chunks: Vec<VaultChunkSummary>,
197}
198
199#[derive(Serialize)]
200pub struct VaultChunkSummary {
201    pub chunk_id: i64,
202    pub heading_path: String,
203    pub snippet: String,
204    pub retrieval_boost: f64,
205}
206
207pub async fn note_detail(
208    State(state): State<Arc<AppState>>,
209    ctx: AccountContext,
210    Path(id): Path<i64>,
211) -> Result<Json<VaultNoteDetail>, ApiError> {
212    let node = watchtower::get_content_node_for(&state.db, &ctx.account_id, id)
213        .await?
214        .ok_or_else(|| ApiError::NotFound(format!("note {id} not found")))?;
215
216    let chunks = watchtower::get_chunks_for_node(&state.db, &ctx.account_id, id).await?;
217
218    let chunk_summaries: Vec<VaultChunkSummary> = chunks
219        .into_iter()
220        .map(|c| VaultChunkSummary {
221            chunk_id: c.id,
222            heading_path: c.heading_path,
223            snippet: truncate_snippet(&c.chunk_text, SNIPPET_MAX_LEN),
224            retrieval_boost: c.retrieval_boost,
225        })
226        .collect();
227
228    Ok(Json(VaultNoteDetail {
229        node_id: node.id,
230        source_id: node.source_id,
231        title: node.title,
232        relative_path: node.relative_path,
233        tags: node.tags,
234        status: node.status,
235        ingested_at: node.ingested_at,
236        updated_at: node.updated_at,
237        chunks: chunk_summaries,
238    }))
239}
240
241// ---------------------------------------------------------------------------
242// GET /api/vault/search?q=&limit=
243// ---------------------------------------------------------------------------
244
245#[derive(Deserialize)]
246pub struct SearchFragmentsQuery {
247    pub q: String,
248    pub limit: Option<u32>,
249}
250
251#[derive(Serialize)]
252pub struct SearchFragmentsResponse {
253    pub fragments: Vec<VaultCitation>,
254}
255
256pub async fn search_fragments(
257    State(state): State<Arc<AppState>>,
258    ctx: AccountContext,
259    Query(params): Query<SearchFragmentsQuery>,
260) -> Result<Json<SearchFragmentsResponse>, ApiError> {
261    let limit = clamp_limit(params.limit);
262
263    if params.q.is_empty() {
264        return Ok(Json(SearchFragmentsResponse { fragments: vec![] }));
265    }
266
267    let keywords: Vec<String> = params.q.split_whitespace().map(|s| s.to_string()).collect();
268
269    let fragments =
270        retrieval::retrieve_vault_fragments(&state.db, &ctx.account_id, &keywords, None, limit)
271            .await?;
272
273    let citations = retrieval::build_citations(&fragments);
274
275    Ok(Json(SearchFragmentsResponse {
276        fragments: citations,
277    }))
278}
279
280// ---------------------------------------------------------------------------
281// POST /api/vault/resolve-refs
282// ---------------------------------------------------------------------------
283
284#[derive(Deserialize)]
285pub struct ResolveRefsRequest {
286    pub node_ids: Vec<i64>,
287}
288
289#[derive(Serialize)]
290pub struct ResolveRefsResponse {
291    pub citations: Vec<VaultCitation>,
292}
293
294pub async fn resolve_refs(
295    State(state): State<Arc<AppState>>,
296    ctx: AccountContext,
297    Json(body): Json<ResolveRefsRequest>,
298) -> Result<Json<ResolveRefsResponse>, ApiError> {
299    if body.node_ids.is_empty() {
300        return Ok(Json(ResolveRefsResponse { citations: vec![] }));
301    }
302
303    let fragments = retrieval::retrieve_vault_fragments(
304        &state.db,
305        &ctx.account_id,
306        &[],
307        Some(&body.node_ids),
308        MAX_LIMIT,
309    )
310    .await?;
311
312    let citations = retrieval::build_citations(&fragments);
313
314    Ok(Json(ResolveRefsResponse { citations }))
315}
316
317// ---------------------------------------------------------------------------
318// Tests
319// ---------------------------------------------------------------------------
320
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use std::path::PathBuf;

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::response::Response;
    use axum::routing::{get, post};
    use axum::Router;
    use tokio::sync::{broadcast, Mutex, RwLock};
    use tower::ServiceExt;

    use crate::ws::AccountWsEvent;

    /// Maximum response body size (bytes) the tests will buffer.
    const BODY_LIMIT: usize = 1024 * 64;

    // --- fixtures ---

    /// Builds an `AppState` over a fresh test database with the given deployment mode.
    ///
    /// This is the only place these tests list `AppState` fields, so a new field
    /// is filled in once instead of once per fixture.
    async fn test_state_with_mode(mode: tuitbot_core::config::DeploymentMode) -> Arc<AppState> {
        let db = tuitbot_core::storage::init_test_db()
            .await
            .expect("init test db");
        let (event_tx, _) = broadcast::channel::<AccountWsEvent>(16);
        Arc::new(AppState {
            db,
            config_path: PathBuf::from("/tmp/test-config.toml"),
            data_dir: PathBuf::from("/tmp"),
            event_tx,
            api_token: "test-token".to_string(),
            passphrase_hash: RwLock::new(None),
            passphrase_hash_mtime: RwLock::new(None),
            bind_host: "127.0.0.1".to_string(),
            bind_port: 3001,
            login_attempts: Mutex::new(HashMap::new()),
            runtimes: Mutex::new(HashMap::new()),
            content_generators: Mutex::new(HashMap::new()),
            circuit_breaker: None,
            scraper_health: None,
            watchtower_cancel: RwLock::new(None),
            content_sources: RwLock::new(Default::default()),
            connector_config: Default::default(),
            deployment_mode: mode,
            pending_oauth: Mutex::new(HashMap::new()),
            token_managers: Mutex::new(HashMap::new()),
            x_client_id: String::new(),
        })
    }

    /// Builds an `AppState` using the default deployment mode.
    async fn test_state() -> Arc<AppState> {
        test_state_with_mode(Default::default()).await
    }

    /// Mounts every vault handler under `/vault/...`.
    fn test_router(state: Arc<AppState>) -> Router {
        Router::new()
            .route("/vault/sources", get(vault_sources))
            .route("/vault/notes", get(search_notes))
            .route("/vault/notes/{id}", get(note_detail))
            .route("/vault/search", get(search_fragments))
            .route("/vault/resolve-refs", post(resolve_refs))
            .with_state(state)
    }

    /// Builds an empty-bodied GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Routes one request through a fresh router over `state`.
    async fn send(state: Arc<AppState>, request: Request<Body>) -> Response {
        test_router(state).oneshot(request).await.unwrap()
    }

    /// Buffers the response body (up to `BODY_LIMIT` bytes) and parses it as JSON.
    async fn json_body(resp: Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), BODY_LIMIT)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    // --- endpoint tests ---

    #[tokio::test]
    async fn vault_sources_returns_empty_when_no_sources() {
        let resp = send(test_state().await, get_request("/vault/sources")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        assert_eq!(body["sources"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn search_notes_returns_empty_for_no_matches() {
        let resp = send(test_state().await, get_request("/vault/notes?q=nonexistent")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        assert_eq!(body["notes"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn note_detail_returns_404_for_missing_node() {
        let resp = send(test_state().await, get_request("/vault/notes/999")).await;

        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn search_fragments_returns_empty_for_no_chunks() {
        let resp = send(test_state().await, get_request("/vault/search?q=nonexistent")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        assert_eq!(body["fragments"].as_array().unwrap().len(), 0);
    }

    // --- clamp_limit tests ---

    #[test]
    fn clamp_limit_default() {
        assert_eq!(clamp_limit(None), DEFAULT_LIMIT);
    }

    #[test]
    fn clamp_limit_under_max() {
        assert_eq!(clamp_limit(Some(50)), 50);
    }

    #[test]
    fn clamp_limit_at_max() {
        assert_eq!(clamp_limit(Some(MAX_LIMIT)), MAX_LIMIT);
    }

    #[test]
    fn clamp_limit_over_max() {
        assert_eq!(clamp_limit(Some(500)), MAX_LIMIT);
    }

    // --- truncate_snippet tests ---

    #[test]
    fn truncate_snippet_short_text() {
        assert_eq!(truncate_snippet("hello", 120), "hello");
    }

    #[test]
    fn truncate_snippet_at_limit() {
        let text = "a".repeat(120);
        assert_eq!(truncate_snippet(&text, 120), text);
    }

    #[test]
    fn truncate_snippet_over_limit() {
        let text = "a".repeat(200);
        let result = truncate_snippet(&text, 120);
        assert!(result.ends_with("..."));
        assert!(result.len() <= 120);
        // The ellipsis counts toward the byte budget: 117 + 3 == 120.
        assert_eq!(result, "a".repeat(117) + "...");
    }

    #[test]
    fn truncate_snippet_unicode_safe() {
        // 115 ASCII bytes + three 4-byte emoji = 127 bytes, over the 120-byte limit.
        let text = "a".repeat(115) + "\u{1F600}\u{1F600}\u{1F600}";
        let result = truncate_snippet(&text, 120);
        assert!(result.ends_with("..."));
        // The cut backs off to the char boundary before the first emoji instead of
        // splitting it (which would panic on slicing).
        assert_eq!(result, "a".repeat(115) + "...");
    }

    // --- deserialization tests ---

    #[test]
    fn search_notes_query_defaults() {
        let json = "{}";
        let q: SearchNotesQuery = serde_json::from_str(json).expect("deser");
        assert!(q.q.is_none());
        assert!(q.source_id.is_none());
        assert!(q.limit.is_none());
    }

    #[test]
    fn search_fragments_query_deserializes() {
        let json = r#"{"q":"rust","limit":10}"#;
        let q: SearchFragmentsQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.q, "rust");
        assert_eq!(q.limit, Some(10));
    }

    #[test]
    fn resolve_refs_request_deserializes() {
        let json = r#"{"node_ids":[1,2,3]}"#;
        let req: ResolveRefsRequest = serde_json::from_str(json).expect("deser");
        assert_eq!(req.node_ids.len(), 3);
    }

    #[test]
    fn resolve_refs_request_empty_ids() {
        let json = r#"{"node_ids":[]}"#;
        let req: ResolveRefsRequest = serde_json::from_str(json).expect("deser");
        assert!(req.node_ids.is_empty());
    }

    #[tokio::test]
    async fn resolve_refs_returns_empty_for_empty_ids() {
        let request = Request::builder()
            .method("POST")
            .uri("/vault/resolve-refs")
            .header("content-type", "application/json")
            .body(Body::from(r#"{"node_ids":[]}"#))
            .unwrap();
        let resp = send(test_state().await, request).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        assert_eq!(body["citations"].as_array().unwrap().len(), 0);
    }

    // --- deployment mode / privacy envelope ---

    #[tokio::test]
    async fn vault_sources_includes_privacy_envelope() {
        let resp = send(test_state().await, get_request("/vault/sources")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        // Default deployment_mode is Desktop
        assert_eq!(body["deployment_mode"], "desktop");
        assert_eq!(body["privacy_envelope"], "local_first");
    }

    #[tokio::test]
    async fn vault_sources_cloud_mode_privacy_envelope() {
        let state = test_state_with_mode(tuitbot_core::config::DeploymentMode::Cloud).await;
        let resp = send(state, get_request("/vault/sources")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = json_body(resp).await;
        assert_eq!(body["deployment_mode"], "cloud");
        assert_eq!(body["privacy_envelope"], "provider_controlled");
    }
}