1use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::{Value, json};
9
10use crate::access::AccessTracker;
11use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
12use crate::db::score_with_decay;
13use crate::graph::GraphStore;
14use crate::item::{Item, ItemFilters};
15use crate::retry::{RetryConfig, with_retry};
16use crate::{Database, ListScope, StoreScope};
17
18use super::protocol::{CallToolResult, Tool};
19use super::server::ServerContext;
20
21fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
24 tokio::spawn(async move {
25 let result = tokio::task::spawn(fut).await;
26 if let Err(e) = result {
27 tracing::error!("Background task '{}' panicked: {:?}", name, e);
28 }
29 });
30}
31
32pub fn get_tools() -> Vec<Tool> {
34 vec![
35 Tool {
36 name: "store".to_string(),
37 description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
38 input_schema: json!({
39 "type": "object",
40 "properties": {
41 "content": {
42 "type": "string",
43 "description": "The content to store"
44 },
45 "scope": {
46 "type": "string",
47 "enum": ["project", "global"],
48 "default": "project",
49 "description": "Where to store: 'project' (current project) or 'global' (all projects)"
50 }
51 },
52 "required": ["content"]
53 }),
54 },
55 Tool {
56 name: "recall".to_string(),
57 description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
58 input_schema: json!({
59 "type": "object",
60 "properties": {
61 "query": {
62 "type": "string",
63 "description": "What to search for (semantic search)"
64 },
65 "limit": {
66 "type": "number",
67 "default": 5,
68 "description": "Maximum number of results"
69 }
70 },
71 "required": ["query"]
72 }),
73 },
74 Tool {
75 name: "list".to_string(),
76 description: "List stored items.".to_string(),
77 input_schema: json!({
78 "type": "object",
79 "properties": {
80 "limit": {
81 "type": "number",
82 "default": 10,
83 "description": "Maximum number of results"
84 },
85 "scope": {
86 "type": "string",
87 "enum": ["project", "global", "all"],
88 "default": "project",
89 "description": "Which items to list: 'project', 'global', or 'all'"
90 }
91 }
92 }),
93 },
94 Tool {
95 name: "forget".to_string(),
96 description: "Delete a stored item by its ID.".to_string(),
97 input_schema: json!({
98 "type": "object",
99 "properties": {
100 "id": {
101 "type": "string",
102 "description": "The item ID to delete"
103 }
104 },
105 "required": ["id"]
106 }),
107 },
108 ]
109}
110
111#[derive(Debug, Deserialize)]
114pub struct StoreParams {
115 pub content: String,
116 #[serde(default)]
117 pub scope: Option<String>,
118}
119
120#[derive(Debug, Deserialize)]
121pub struct RecallParams {
122 pub query: String,
123 #[serde(default)]
124 pub limit: Option<usize>,
125}
126
127#[derive(Debug, Deserialize)]
128pub struct ListParams {
129 #[serde(default)]
130 pub limit: Option<usize>,
131 #[serde(default)]
132 pub scope: Option<String>,
133}
134
135#[derive(Debug, Deserialize)]
136pub struct ForgetParams {
137 pub id: String,
138}
139
/// Feature switches for the recall pipeline. Every switch is on by default.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecallConfig {
    /// Ensure every search hit has a node in the graph store.
    pub enable_graph_backfill: bool,
    /// Add graph neighbours of the top hits as extra `graph_expanded` entries.
    pub enable_graph_expansion: bool,
    /// Add items frequently co-accessed with the top hits as suggestions.
    pub enable_co_access: bool,
    /// Re-score hits with time/access decay and a trust bonus, then re-sort.
    pub enable_decay_scoring: bool,
    /// Switch for background maintenance after a recall. It is not consulted
    /// by `recall_pipeline` itself.
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}
163
164pub struct RecallResult {
166 pub results: Vec<crate::item::SearchResult>,
167 pub graph_expanded: Vec<Value>,
168 pub suggested: Vec<Value>,
169 pub raw_similarities: std::collections::HashMap<String, f32>,
171}
172
173pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
176 let config = RetryConfig::default();
177 let args_for_retry = args.clone();
178
179 let result = with_retry(&config, || {
180 let ctx_ref = ctx;
181 let name_ref = name;
182 let args_clone = args_for_retry.clone();
183
184 async move {
185 let mut db = Database::open_with_embedder(
187 &ctx_ref.db_path,
188 ctx_ref.project_id.clone(),
189 ctx_ref.embedder.clone(),
190 )
191 .await
192 .map_err(|e| sanitize_err("Failed to open database", e))?;
193
194 let tracker = AccessTracker::open(&ctx_ref.access_db_path)
196 .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
197
198 let graph = GraphStore::open(&ctx_ref.access_db_path)
200 .map_err(|e| sanitize_err("Failed to open graph store", e))?;
201
202 let result = match name_ref {
203 "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
204 "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
205 "list" => execute_list(&mut db, args_clone).await,
206 "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
207 _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
208 };
209
210 if result.is_error.unwrap_or(false)
211 && let Some(content) = result.content.first()
212 && is_retryable_error(&content.text)
213 {
214 return Err(content.text.clone());
215 }
216
217 Ok(result)
218 }
219 })
220 .await;
221
222 match result {
223 Ok(call_result) => call_result,
224 Err(e) => {
225 tracing::error!("Operation failed after retries: {}", e);
226 CallToolResult::error("Operation failed after retries")
227 }
228 }
229}
230
/// Returns `true` when `error_msg` looks like a transient failure worth retrying.
///
/// The check is a case-insensitive substring search against a fixed list of
/// patterns: connection, timeout, locking, busy and I/O problems, plus the
/// `Failed to open` / `Failed to connect` prefixes used when storage handles
/// cannot be opened. Because this is substring matching, `"lock"` also
/// matches words such as "locked" or "deadlock".
fn is_retryable_error(error_msg: &str) -> bool {
    // The patterns are stored already lowercased, so each call lowercases
    // only the message and not every pattern too.
    const RETRYABLE_PATTERNS: [&str; 8] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lower = error_msg.to_lowercase();
    RETRYABLE_PATTERNS
        .iter()
        .any(|&pattern| lower.contains(pattern))
}
248
249async fn execute_store(
252 db: &mut Database,
253 _tracker: &AccessTracker,
254 graph: &GraphStore,
255 ctx: &ServerContext,
256 args: Option<Value>,
257) -> CallToolResult {
258 let params: StoreParams = match args {
259 Some(v) => match serde_json::from_value(v) {
260 Ok(p) => p,
261 Err(e) => {
262 tracing::debug!("Parameter validation failed: {}", e);
263 return CallToolResult::error("Invalid parameters");
264 }
265 },
266 None => return CallToolResult::error("Missing parameters"),
267 };
268
269 const MAX_CONTENT_BYTES: usize = 1_000_000;
273 if params.content.len() > MAX_CONTENT_BYTES {
274 return CallToolResult::error(format!(
275 "Content too large: {} bytes (max {} bytes)",
276 params.content.len(),
277 MAX_CONTENT_BYTES
278 ));
279 }
280
281 let scope = params
283 .scope
284 .as_deref()
285 .map(|s| s.parse::<StoreScope>())
286 .transpose();
287
288 let scope = match scope {
289 Ok(s) => s.unwrap_or(StoreScope::Project),
290 Err(e) => return CallToolResult::error(e),
291 };
292
293 let mut item = Item::new(¶ms.content);
295
296 if scope == StoreScope::Project
298 && let Some(project_id) = db.project_id()
299 {
300 item = item.with_project_id(project_id);
301 }
302
303 match db.store_item(item).await {
304 Ok(store_result) => {
305 let new_id = store_result.id.clone();
306
307 let now = chrono::Utc::now().timestamp();
309 let project_id = db.project_id().map(|s| s.to_string());
310 if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
311 tracing::warn!("graph add_node failed: {}", e);
312 }
313
314 if !store_result.potential_conflicts.is_empty()
316 && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
317 {
318 for conflict in &store_result.potential_conflicts {
319 if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
320 {
321 tracing::warn!("enqueue consolidation failed: {}", e);
322 }
323 }
324 }
325
326 let mut result = json!({
327 "success": true,
328 "id": new_id,
329 "message": format!("Stored in {} scope", scope)
330 });
331
332 if !store_result.potential_conflicts.is_empty() {
333 let conflicts: Vec<Value> = store_result
334 .potential_conflicts
335 .iter()
336 .map(|c| {
337 json!({
338 "id": c.id,
339 "content": c.content,
340 "similarity": format!("{:.2}", c.similarity)
341 })
342 })
343 .collect();
344 result["potential_conflicts"] = json!(conflicts);
345 }
346
347 CallToolResult::success(
348 serde_json::to_string_pretty(&result)
349 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
350 )
351 }
352 Err(e) => sanitized_error("Failed to store item", e),
353 }
354}
355
/// Core recall pipeline: search, optional re-ranking, and graph-based extras.
///
/// Steps (flags refer to [`RecallConfig`]):
/// 1. Semantic search via `db.search_items(query, limit, filters)`. If there
///    are no hits, an empty [`RecallResult`] is returned immediately and no
///    access is recorded.
/// 2. `enable_graph_backfill`: make sure every hit has a graph node.
/// 3. `enable_decay_scoring`: remember each hit's raw similarity. Replace the
///    similarity with `score_with_decay(..)` (fed by the tracker's access
///    count and last-access time) multiplied by a trust bonus. The bonus grows
///    logarithmically with the validation count and linearly with the graph
///    edge count. The result is capped at 1.0, and hits are re-sorted by
///    descending score.
/// 4. Always: record an access for every returned hit.
/// 5. `enable_graph_expansion`: collect graph neighbours of the top 5 hits
///    that are not hits themselves, and render them as JSON entries.
/// 6. `enable_co_access`: collect items co-accessed with the top 3 hits that
///    are not hits themselves, and render them as suggestion entries.
///
/// Items from a different project get `cross_project: true` and their content
/// is omitted. Items without a project id are treated as same-project.
///
/// # Errors
/// Returns `Err("Search failed: …")` only when the initial search fails.
/// Tracker and graph failures are logged or ignored and do not fail the call.
pub async fn recall_pipeline(
    db: &mut Database,
    tracker: &AccessTracker,
    graph: &GraphStore,
    query: &str,
    limit: usize,
    filters: ItemFilters,
    config: &RecallConfig,
) -> std::result::Result<RecallResult, String> {
    let mut results = db
        .search_items(query, limit, filters)
        .await
        .map_err(|e| format!("Search failed: {}", e))?;

    if results.is_empty() {
        return Ok(RecallResult {
            results: Vec::new(),
            graph_expanded: Vec::new(),
            suggested: Vec::new(),
            raw_similarities: std::collections::HashMap::new(),
        });
    }

    if config.enable_graph_backfill {
        // Create graph nodes for hits that were stored before the graph existed
        // (or whose node creation failed); failures are only logged.
        for result in &results {
            if let Err(e) = graph.ensure_node_exists(
                &result.id,
                result.project_id.as_deref(),
                result.created_at.timestamp(),
            ) {
                tracing::warn!("ensure_node_exists failed: {}", e);
            }
        }
    }

    // Only filled when decay scoring runs below.
    let mut raw_similarities: std::collections::HashMap<String, f32> =
        std::collections::HashMap::new();
    if config.enable_decay_scoring {
        let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        // A tracker failure degrades to "no recorded accesses" rather than failing recall.
        let access_records = tracker.get_accesses(&item_ids).unwrap_or_default();
        let now = chrono::Utc::now().timestamp();

        for result in &mut results {
            raw_similarities.insert(result.id.clone(), result.similarity);

            let created_at = result.created_at.timestamp();
            let (access_count, last_accessed) = match access_records.get(&result.id) {
                Some(rec) => (rec.access_count, Some(rec.last_accessed_at)),
                None => (0, None),
            };

            let base_score = score_with_decay(
                result.similarity,
                now,
                created_at,
                access_count,
                last_accessed,
            );

            // Trust bonus: 1 + 0.05 * ln(1 + validations) + 0.02 * edges.
            // The edge term is unbounded; the product is capped at 1.0 below.
            let validation_count = tracker.get_validation_count(&result.id).unwrap_or(0);
            let edge_count = graph.get_edge_count(&result.id).unwrap_or(0);
            let trust_bonus =
                1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;

            result.similarity = (base_score * trust_bonus).min(1.0);
        }

        // Re-rank by the adjusted score, highest first; NaN compares as equal.
        results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    // Every returned hit counts as an access, regardless of config.
    for result in &results {
        let created_at = result.created_at.timestamp();
        if let Err(e) = tracker.record_access(&result.id, created_at) {
            tracing::warn!("record_access failed: {}", e);
        }
    }

    // Used to keep extras from repeating items already in `results`.
    let existing_ids: std::collections::HashSet<String> =
        results.iter().map(|r| r.id.clone()).collect();

    let mut graph_expanded = Vec::new();
    if config.enable_graph_expansion {
        let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
        // NOTE(review): 0.5 is presumably a minimum edge weight — confirm against
        // GraphStore::get_neighbors. Neighbours reachable from several top hits
        // are not deduplicated here unless get_neighbors already does so.
        if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
            let neighbor_info: Vec<(String, String)> = neighbors
                .into_iter()
                .filter(|(id, _, _)| !existing_ids.contains(id))
                .map(|(id, rel_type, _)| (id, rel_type))
                .collect();

            let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
            if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
                let item_map: std::collections::HashMap<&str, &Item> =
                    items.iter().map(|item| (item.id.as_str(), item)).collect();

                for (neighbor_id, rel_type) in &neighbor_info {
                    if let Some(item) = item_map.get(neighbor_id.as_str()) {
                        // The 0.05 similarity given here is not surfaced; the
                        // entry reports "graph" instead.
                        let sr = crate::item::SearchResult::from_item(item, 0.05);
                        let mut entry = json!({
                            "id": sr.id,
                            "similarity": "graph",
                            "created": sr.created_at.to_rfc3339(),
                            "graph_expanded": true,
                            "rel_type": rel_type,
                        });
                        // Items with no project id (global) are visible everywhere.
                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
                            (Some(current), Some(item_pid)) => current == item_pid,
                            (_, None) => true,
                            _ => false,
                        };
                        if same_project {
                            entry["content"] = json!(sr.content);
                        } else {
                            entry["cross_project"] = json!(true);
                        }
                        graph_expanded.push(entry);
                    }
                }
            }
        }
    }

    let mut suggested = Vec::new();
    if config.enable_co_access {
        let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
        // NOTE(review): the second argument (3) is presumably a minimum
        // co-access count or a result limit — confirm against get_co_accessed.
        if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
            let co_info: Vec<(String, i64)> = co_accessed
                .into_iter()
                .filter(|(id, _)| !existing_ids.contains(id))
                .collect();

            let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
            if let Ok(items) = db.get_items_batch(&co_ids).await {
                let item_map: std::collections::HashMap<&str, &Item> =
                    items.iter().map(|item| (item.id.as_str(), item)).collect();

                for (co_id, co_count) in &co_info {
                    if let Some(item) = item_map.get(co_id.as_str()) {
                        // Items with no project id (global) are visible everywhere.
                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
                            (Some(current), Some(item_pid)) => current == item_pid,
                            (_, None) => true,
                            _ => false,
                        };
                        let mut entry = json!({
                            "id": item.id,
                            "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
                        });
                        if same_project {
                            // Suggestions carry only a 100-character preview.
                            entry["content"] = json!(truncate(&item.content, 100));
                        } else {
                            entry["cross_project"] = json!(true);
                        }
                        suggested.push(entry);
                    }
                }
            }
        }
    }

    Ok(RecallResult {
        results,
        graph_expanded,
        suggested,
        raw_similarities,
    })
}
537
538async fn execute_recall(
539 db: &mut Database,
540 tracker: &AccessTracker,
541 graph: &GraphStore,
542 ctx: &ServerContext,
543 args: Option<Value>,
544) -> CallToolResult {
545 let params: RecallParams = match args {
546 Some(v) => match serde_json::from_value(v) {
547 Ok(p) => p,
548 Err(e) => {
549 tracing::debug!("Parameter validation failed: {}", e);
550 return CallToolResult::error("Invalid parameters");
551 }
552 },
553 None => return CallToolResult::error("Missing parameters"),
554 };
555
556 const MAX_QUERY_BYTES: usize = 10_000;
560 if params.query.len() > MAX_QUERY_BYTES {
561 return CallToolResult::error(format!(
562 "Query too large: {} bytes (max {} bytes)",
563 params.query.len(),
564 MAX_QUERY_BYTES
565 ));
566 }
567
568 let limit = params.limit.unwrap_or(5).min(100);
569
570 let filters = ItemFilters::new();
571
572 let config = RecallConfig::default();
573
574 let recall_result =
575 match recall_pipeline(db, tracker, graph, ¶ms.query, limit, filters, &config).await {
576 Ok(r) => r,
577 Err(e) => {
578 tracing::error!("Recall failed: {}", e);
579 return CallToolResult::error("Search failed");
580 }
581 };
582
583 if recall_result.results.is_empty() {
584 return CallToolResult::success("No items found matching your query.");
585 }
586
587 let results = &recall_result.results;
588
589 let formatted: Vec<Value> = results
590 .iter()
591 .map(|r| {
592 let mut obj = json!({
593 "id": r.id,
594 "content": r.content,
595 "similarity": format!("{:.2}", r.similarity),
596 "created": r.created_at.to_rfc3339(),
597 });
598
599 if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
601 && (raw_sim - r.similarity).abs() > 0.001
602 {
603 obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
604 }
605
606 if let Some(ref excerpt) = r.relevant_excerpt {
607 obj["relevant_excerpt"] = json!(excerpt);
608 }
609
610 if let Some(ref current_pid) = ctx.project_id
612 && let Some(ref item_pid) = r.project_id
613 && item_pid != current_pid
614 {
615 obj["cross_project"] = json!(true);
616 }
617
618 if let Ok(neighbors) = graph.get_neighbors(&[r.id.as_str()], 0.5) {
620 let related: Vec<String> = neighbors.iter().map(|(id, _, _)| id.clone()).collect();
621 if !related.is_empty() {
622 obj["related_ids"] = json!(related);
623 }
624 }
625
626 obj
627 })
628 .collect();
629
630 let mut result_json = json!({
631 "count": results.len(),
632 "results": formatted
633 });
634
635 if !recall_result.graph_expanded.is_empty() {
636 result_json["graph_expanded"] = json!(recall_result.graph_expanded);
637 }
638
639 if !recall_result.suggested.is_empty() {
640 result_json["suggested"] = json!(recall_result.suggested);
641 }
642
643 spawn_consolidation(
645 Arc::new(ctx.db_path.clone()),
646 Arc::new(ctx.access_db_path.clone()),
647 ctx.project_id.clone(),
648 ctx.embedder.clone(),
649 ctx.consolidation_semaphore.clone(),
650 );
651
652 let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
654 let access_db_path = ctx.access_db_path.clone();
655 spawn_logged("co_access", async move {
656 if let Ok(g) = GraphStore::open(&access_db_path) {
657 if let Err(e) = g.record_co_access(&result_ids) {
658 tracing::warn!("record_co_access failed: {}", e);
659 }
660 } else {
661 tracing::warn!("co_access: failed to open graph store");
662 }
663 });
664
665 let run_count = ctx
667 .recall_count
668 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
669 if run_count % 10 == 9 {
670 let access_db_path = ctx.access_db_path.clone();
672 spawn_logged("clustering", async move {
673 if let Ok(g) = GraphStore::open(&access_db_path)
674 && let Ok(clusters) = g.detect_clusters()
675 {
676 for (a, b, c) in &clusters {
677 let label = format!("cluster-{}", &a[..8.min(a.len())]);
678 if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
679 tracing::warn!("cluster add_related_edge failed: {}", e);
680 }
681 if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
682 tracing::warn!("cluster add_related_edge failed: {}", e);
683 }
684 if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
685 tracing::warn!("cluster add_related_edge failed: {}", e);
686 }
687 }
688 if !clusters.is_empty() {
689 tracing::info!("Detected {} clusters", clusters.len());
690 }
691 }
692 });
693
694 let access_db_path2 = ctx.access_db_path.clone();
696 spawn_logged("consolidation_cleanup", async move {
697 if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
698 && let Err(e) = q.cleanup_processed()
699 {
700 tracing::warn!("consolidation queue cleanup failed: {}", e);
701 }
702 });
703 }
704
705 CallToolResult::success(
706 serde_json::to_string_pretty(&result_json)
707 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
708 )
709}
710
711async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
712 let params: ListParams =
713 args.and_then(|v| serde_json::from_value(v).ok())
714 .unwrap_or(ListParams {
715 limit: None,
716 scope: None,
717 });
718
719 let limit = params.limit.unwrap_or(10).min(100);
720
721 let filters = ItemFilters::new();
722
723 let scope = params
724 .scope
725 .as_deref()
726 .map(|s| s.parse::<ListScope>())
727 .transpose();
728
729 let scope = match scope {
730 Ok(s) => s.unwrap_or(ListScope::Project),
731 Err(e) => return CallToolResult::error(e),
732 };
733
734 match db.list_items(filters, Some(limit), scope).await {
735 Ok(items) => {
736 if items.is_empty() {
737 CallToolResult::success("No items stored yet.")
738 } else {
739 let formatted: Vec<Value> = items
740 .iter()
741 .map(|item| {
742 let content_preview = truncate(&item.content, 100);
743 let mut obj = json!({
744 "id": item.id,
745 "content": content_preview,
746 "created": item.created_at.to_rfc3339(),
747 });
748
749 if item.is_chunked {
750 obj["chunked"] = json!(true);
751 }
752
753 obj
754 })
755 .collect();
756
757 let result = json!({
758 "count": items.len(),
759 "items": formatted
760 });
761
762 CallToolResult::success(
763 serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
764 format!("{{\"error\": \"serialization failed: {}\"}}", e)
765 }),
766 )
767 }
768 }
769 Err(e) => sanitized_error("Failed to list items", e),
770 }
771}
772
773async fn execute_forget(
774 db: &mut Database,
775 graph: &GraphStore,
776 ctx: &ServerContext,
777 args: Option<Value>,
778) -> CallToolResult {
779 let params: ForgetParams = match args {
780 Some(v) => match serde_json::from_value(v) {
781 Ok(p) => p,
782 Err(e) => {
783 tracing::debug!("Parameter validation failed: {}", e);
784 return CallToolResult::error("Invalid parameters");
785 }
786 },
787 None => return CallToolResult::error("Missing parameters"),
788 };
789
790 if let Some(ref current_pid) = ctx.project_id {
792 match db.get_item(¶ms.id).await {
793 Ok(Some(item)) => {
794 if let Some(ref item_pid) = item.project_id
795 && item_pid != current_pid
796 {
797 return CallToolResult::error(format!(
798 "Cannot delete item {} from a different project",
799 params.id
800 ));
801 }
802 }
803 Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
804 Err(e) => {
805 return sanitized_error("Failed to look up item", e);
806 }
807 }
808 }
809
810 match db.delete_item(¶ms.id).await {
811 Ok(true) => {
812 if let Err(e) = graph.remove_node(¶ms.id) {
814 tracing::warn!("remove_node failed: {}", e);
815 }
816
817 let result = json!({
818 "success": true,
819 "message": format!("Deleted item: {}", params.id)
820 });
821 CallToolResult::success(
822 serde_json::to_string_pretty(&result)
823 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
824 )
825 }
826 Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
827 Err(e) => sanitized_error("Failed to delete item", e),
828 }
829}
830
831fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
836 tracing::error!("{}: {}", context, err);
837 CallToolResult::error(context.to_string())
838}
839
840fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
842 tracing::error!("{}: {}", context, err);
843 context.to_string()
844}
845
/// Shortens `s` to at most `max_len` characters (Unicode scalar values).
///
/// A string that already fits is returned unchanged. A longer string is cut
/// to `max_len - 3` characters followed by `"..."`, so the result is exactly
/// `max_len` characters long. When `max_len <= 3` there is no room for the
/// ellipsis, and the first `max_len` characters are returned as-is.
fn truncate(s: &str, max_len: usize) -> String {
    // Look at no more than `max_len + 1` characters instead of counting the
    // whole string, so previews of very large items stay cheap.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }
    if max_len <= 3 {
        return s.chars().take(max_len).collect();
    }
    // Byte offset of the first character that is dropped. Slicing there is
    // always on a char boundary.
    let cut = s
        .char_indices()
        .nth(max_len - 3)
        .map_or(s.len(), |(i, _)| i);
    format!("{}...", &s[..cut])
}
861
#[cfg(test)]
mod tests {
    use super::*;

    /// Short budgets (no room for an ellipsis), exact fits, and the first
    /// budget that does get an ellipsis.
    #[test]
    fn test_truncate_small_max_len() {
        let cases: [(&str, usize, &str); 7] = [
            ("hello", 0, ""),
            ("hello", 1, "h"),
            ("hello", 2, "he"),
            ("hello", 3, "hel"),
            ("hi", 3, "hi"),
            ("hello", 5, "hello"),
            ("hello!", 5, "he..."),
        ];
        for (input, max_len, expected) in cases {
            assert_eq!(truncate(input, max_len), expected, "truncate({input:?}, {max_len})");
        }
    }

    /// Truncation counts characters, so multi-byte text is never split
    /// inside a character.
    #[test]
    fn test_truncate_unicode() {
        let cases: [(&str, usize, &str); 2] =
            [("héllo wörld", 5, "hé..."), ("日本語テスト", 4, "日...")];
        for (input, max_len, expected) in cases {
            assert_eq!(truncate(input, max_len), expected, "truncate({input:?}, {max_len})");
        }
    }
}