rust_agent/memory/
utils.rs

1// 记忆系统工具函数模块
2use std::path::{Path, PathBuf};
3use anyhow::{Error, Result};
4use serde::{Serialize, Deserialize};
5use serde_json::Value;
6use log::warn;
7
8/// 确保数据目录存在
9pub async fn ensure_data_dir_exists(data_dir: &Path) -> Result<()> {
10    if !data_dir.exists() {
11        tokio::fs::create_dir_all(data_dir).await?;
12    }
13    Ok(())
14}
15
16/// 估算文本的 token 数量
17/// 这是一个简化的实现,实际应用中可以使用更精确的 token 计算器
18pub fn estimate_token_count(text: &str) -> usize {
19    // 简化实现:假设平均每个 token 约 4 个字符
20    // 这对于英文来说是一个合理的近似值,但对于中文可能不准确
21    text.len() / 4
22}
23
24/// 估算文本的 token 数量(区分中文字符)
25/// 对于中文字符,1字符≈1token;对于非中文字符,4字符≈1token
26pub fn estimate_text_tokens(text: &str) -> usize {
27    // 简化实现:假设平均每个 token 约 4 个字符
28    // 对于英文,这个假设比较准确;对于中文,1个字符约等于1个token
29    // 这里我们采用一个混合策略
30    let chinese_chars = text.chars().filter(|c| {
31        let c = *c as u32;
32        // 中文字符的Unicode范围
33        (0x4E00..=0x9FFF).contains(&c) || 
34        (0x3400..=0x4DBF).contains(&c) || 
35        (0x20000..=0x2A6DF).contains(&c) ||
36        (0x2A700..=0x2B73F).contains(&c) ||
37        (0x2B740..=0x2B81F).contains(&c) ||
38        (0x2B820..=0x2CEAF).contains(&c) ||
39        (0xF900..=0xFAFF).contains(&c) ||
40        (0x2F800..=0x2FA1F).contains(&c)
41    }).count();
42    
43    let non_chinese_chars = text.chars().count() - chinese_chars;
44    
45    // 中文字符:1字符≈1token,非中文字符:4字符≈1token
46    chinese_chars + non_chinese_chars / 4
47}
48
49/// 估算 JSON 值的 token 数量
50pub fn estimate_json_token_count(value: &Value) -> usize {
51    match value {
52        Value::String(s) => estimate_token_count(s),
53        Value::Number(_) => 1, // 数字通常算作一个 token
54        Value::Bool(_) => 1,   // 布尔值通常算作一个 token
55        Value::Null => 1,      // null 通常算作一个 token
56        Value::Array(arr) => {
57            // 数组的 token 数量是所有元素之和加上方括号和逗号
58            let mut count = 2; // 方括号
59            for item in arr {
60                count += estimate_json_token_count(item) + 1; // 加上逗号
61            }
62            count
63        }
64        Value::Object(obj) => {
65            // 对象的 token 数量是所有键值对之和加上花括号和冒号
66            let mut count = 2; // 花括号
67            for (key, value) in obj {
68                count += estimate_token_count(key) + 1; // 键和冒号
69                count += estimate_json_token_count(value) + 1; // 值和逗号
70            }
71            count
72        }
73    }
74}
75
76/// 序列化 JSON 值到字符串,带错误处理
77pub fn serialize_to_string(value: &Value) -> Result<String> {
78    serde_json::to_string(value).map_err(|e| {
79        warn!("Failed to serialize JSON value: {}", e);
80        Error::from(e)
81    })
82}
83
84/// 反序列化 JSON 字符串到值,带错误处理
85pub fn deserialize_from_str(json_str: &str) -> Result<Value> {
86    serde_json::from_str(json_str).map_err(|e| {
87        warn!("Failed to deserialize JSON string: {}", e);
88        Error::from(e)
89    })
90}
91
92/// 生成当前时间戳 (ISO 8601 格式)
93pub fn current_timestamp() -> String {
94    chrono::Utc::now().to_rfc3339()
95}
96
97/// 解析时间戳字符串
98pub fn parse_timestamp(timestamp: &str) -> Result<chrono::DateTime<chrono::Utc>> {
99    timestamp.parse::<chrono::DateTime<chrono::Utc>>().map_err(|e| {
100        warn!("Failed to parse timestamp '{}': {}", timestamp, e);
101        Error::from(e)
102    })
103}
104
105/// 获取会话文件路径
106pub fn get_session_file_path(data_dir: &Path, session_id: &str, suffix: &str) -> PathBuf {
107    data_dir.join(format!("{}_{}", session_id, suffix))
108}
109
110/// 创建带时间戳的备份文件路径
111pub fn create_backup_path(file_path: &Path) -> PathBuf {
112    let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
113    let parent = file_path.parent().unwrap_or_else(|| Path::new("."));
114    let file_stem = file_path.file_stem().unwrap_or_else(|| std::ffi::OsStr::new("backup"));
115    let extension = file_path.extension().and_then(|s| s.to_str()).unwrap_or("");
116    
117    if extension.is_empty() {
118        parent.join(format!("{}_{}.backup", file_stem.to_string_lossy(), timestamp))
119    } else {
120        parent.join(format!("{}_{}_{}.backup", file_stem.to_string_lossy(), timestamp, extension))
121    }
122}
123
124/// 异步读取文件内容,带错误处理
125pub async fn read_file_content(file_path: &Path) -> Result<String> {
126    tokio::fs::read_to_string(file_path).await.map_err(|e| {
127        warn!("Failed to read file '{}': {}", file_path.display(), e);
128        Error::from(e)
129    })
130}
131
132/// 异步写入文件内容,带错误处理
133pub async fn write_file_content(file_path: &Path, content: &str) -> Result<()> {
134    tokio::fs::write(file_path, content).await.map_err(|e| {
135        warn!("Failed to write file '{}': {}", file_path.display(), e);
136        Error::from(e)
137    })
138}
139
140/// 原子写入文件内容(先写入临时文件,然后重命名)
141pub async fn atomic_write_file(file_path: &Path, content: &str) -> Result<()> {
142    // 创建临时文件路径
143    let temp_path = file_path.with_extension("tmp");
144    
145    // 确保父目录存在
146    if let Some(parent) = file_path.parent() {
147        ensure_data_dir_exists(parent).await?;
148    }
149    
150    // 写入临时文件
151    write_file_content(&temp_path, content).await?;
152    
153    // 原子重命名
154    tokio::fs::rename(&temp_path, file_path).await.map_err(|e| {
155        warn!("Failed to rename temporary file to '{}': {}", file_path.display(), e);
156        Error::from(e)
157    })?;
158    
159    Ok(())
160}
161
162/// 追加内容到文件,带错误处理
163pub async fn append_to_file(file_path: &Path, content: &str) -> Result<()> {
164    use tokio::io::AsyncWriteExt;
165    
166    // 确保父目录存在
167    if let Some(parent) = file_path.parent() {
168        ensure_data_dir_exists(parent).await?;
169    }
170    
171    // 打开文件并追加内容
172    let mut file = tokio::fs::OpenOptions::new()
173        .create(true)
174        .append(true)
175        .open(file_path)
176        .await?;
177    
178    file.write_all(content.as_bytes()).await?;
179    file.flush().await?;
180    
181    Ok(())
182}
183
184/// 检查文件是否存在
185pub async fn file_exists(file_path: &Path) -> bool {
186    tokio::fs::metadata(file_path).await.is_ok()
187}
188
189/// 删除文件,带错误处理
190pub async fn delete_file(file_path: &Path) -> Result<()> {
191    if file_exists(file_path).await {
192        tokio::fs::remove_file(file_path).await.map_err(|e| {
193            warn!("Failed to delete file '{}': {}", file_path.display(), e);
194            Error::from(e)
195        })?;
196    }
197    Ok(())
198}
199
200/// 创建目录(如果不存在)
201pub async fn ensure_dir_exists(dir_path: &Path) -> Result<()> {
202    if !dir_path.exists() {
203        tokio::fs::create_dir_all(dir_path).await.map_err(|e| {
204            warn!("Failed to create directory '{}': {}", dir_path.display(), e);
205            Error::from(e)
206        })?;
207    }
208    Ok(())
209}
210
211/// 获取环境变量值,如果不存在则返回默认值
212pub fn get_env_var(key: &str, default: &str) -> String {
213    std::env::var(key).unwrap_or_else(|_| default.to_string())
214}
215
216/// 从环境变量获取数据目录路径
217pub fn get_data_dir_from_env() -> PathBuf {
218    let default_dir = "./data/memory";
219    let dir_str = get_env_var("MEMORY_DATA_DIR", default_dir);
220    PathBuf::from(dir_str)
221}
222
223/// 从环境变量获取摘要阈值
224pub fn get_summary_threshold_from_env() -> usize {
225    let default_threshold = 3500;
226    let threshold_str = get_env_var("MEMORY_SUMMARY_THRESHOLD", &default_threshold.to_string());
227    threshold_str.parse().unwrap_or(default_threshold)
228}
229
230/// 从环境变量获取最近消息数量
231pub fn get_recent_messages_count_from_env() -> usize {
232    let default_count = 10;
233    let count_str = get_env_var("MEMORY_RECENT_MESSAGES_COUNT", &default_count.to_string());
234    count_str.parse().unwrap_or(default_count)
235}
236
237/// 生成随机会话 ID
238pub fn generate_session_id() -> String {
239    uuid::Uuid::new_v4().to_string()
240}
241
242#[cfg(test)]
243mod tests {
244    use super::*;
245    use tempfile::TempDir;
246    
247    #[tokio::test]
248    async fn test_ensure_data_dir_exists() {
249        let temp_dir = TempDir::new().unwrap();
250        let dir_path = temp_dir.path().join("test_dir");
251        
252        assert!(!dir_path.exists());
253        ensure_data_dir_exists(&dir_path).await.unwrap();
254        assert!(dir_path.exists());
255    }
256    
257    #[test]
258    fn test_estimate_token_count() {
259        assert_eq!(estimate_token_count(""), 0);
260        assert_eq!(estimate_token_count("hello world"), 2); // 11 chars / 4 = 2 (floor)
261        assert_eq!(estimate_token_count("a".repeat(20).as_str()), 5); // 20 chars / 4 = 5
262    }
263    
264    #[test]
265    fn test_current_timestamp() {
266        let timestamp = current_timestamp();
267        assert!(parse_timestamp(&timestamp).is_ok());
268    }
269    
270    #[test]
271    fn test_get_session_file_path() {
272        let data_dir = Path::new("/tmp");
273        let session_id = "test_session";
274        let suffix = "json";
275        
276        let path = get_session_file_path(data_dir, session_id, suffix);
277        assert_eq!(path, PathBuf::from("/tmp/test_session_json"));
278    }
279    
280    #[test]
281    fn test_get_env_var() {
282        let key = "MEMORY_TEST_VAR";
283        let default = "default_value";
284        
285        // 确保环境变量未设置
286        std::env::remove_var(key);
287        assert_eq!(get_env_var(key, default), default);
288        
289        // 设置环境变量
290        std::env::set_var(key, "test_value");
291        assert_eq!(get_env_var(key, default), "test_value");
292        
293        // 清理
294        std::env::remove_var(key);
295    }
296    
297    #[test]
298    fn test_generate_session_id() {
299        let id1 = generate_session_id();
300        let id2 = generate_session_id();
301        
302        assert_ne!(id1, id2);
303        assert_eq!(id1.len(), 36); // UUID length
304        assert_eq!(id2.len(), 36); // UUID length
305    }
306}