// wae_testing/db_mock.rs

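//! Database mock utilities.
//!
//! Provides [`MockDatabase`], which simulates database operations and supports
//! configuring expected queries, recording calls, and verifying them.
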
use crate::mock::{Mock, MockCall};
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};
use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};

/// A single query recorded by [`MockDatabase`].
#[derive(Clone, Debug)]
pub struct DatabaseQuery {
    /// Name (or kind) of the query that was issued.
    pub query_name: String,
    /// Parameters exactly as the caller passed them. The mock stores these
    /// strings verbatim and enforces no encoding (callers may choose to pass
    /// JSON-serialized values).
    pub params: Vec<String>,
    /// Moment at which the query was recorded.
    pub timestamp: std::time::Instant,
}

/// The canned outcome a [`MockDatabase`] produces for its queries.
///
/// This does not describe what a test *expects* to happen; it configures what
/// the mock *returns* when queried.
#[derive(Debug, Clone)]
pub enum DatabaseResult<T> {
    /// Every query succeeds with a clone of this value.
    Return(T),
    /// Every query fails with a mock error carrying this message.
    Error(String),
    /// Each query is answered by looking up its name in this map; a query
    /// whose name has no entry fails.
    NamedReturns(HashMap<String, T>),
}

32impl<T: Clone> DatabaseResult<T> {
33    /// 创建返回值
34    pub fn return_value(value: T) -> Self {
35        DatabaseResult::Return(value)
36    }
37
38    /// 创建错误返回
39    pub fn error(msg: impl Into<String>) -> Self {
40        DatabaseResult::Error(msg.into())
41    }
42
43    /// 创建命名返回映射
44    pub fn named_returns(map: HashMap<String, T>) -> Self {
45        DatabaseResult::NamedReturns(map)
46    }
47}
48
/// Expected call counts that a [`MockDatabase`] checks when it is verified.
///
/// Each entry maps a query name to the exact number of times that query must
/// have been issued. Queries not listed here are not checked.
#[derive(Debug, Default, Clone)]
pub struct DatabaseExpectation {
    /// Query name -> exact number of calls expected.
    pub expected_queries: HashMap<String, usize>,
    /// Optional human-readable description of this expectation set.
    pub description: Option<String>,
}

58impl DatabaseExpectation {
59    /// 创建新的数据库期望配置
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    /// 设置期望查询次数
65    pub fn expect_query(mut self, query_name: impl Into<String>, count: usize) -> Self {
66        self.expected_queries.insert(query_name.into(), count);
67        self
68    }
69
70    /// 设置描述
71    pub fn description(mut self, desc: impl Into<String>) -> Self {
72        self.description = Some(desc.into());
73        self
74    }
75}
76
77/// Mock 数据库构建器
78pub struct MockDatabaseBuilder<T> {
79    result: Option<DatabaseResult<T>>,
80    expectation: DatabaseExpectation,
81    queries: Arc<RwLock<Vec<DatabaseQuery>>>,
82}
83
84impl<T: Clone + Send + Sync + 'static> MockDatabaseBuilder<T> {
85    /// 创建新的 Mock 数据库构建器
86    pub fn new() -> Self {
87        Self { result: None, expectation: DatabaseExpectation::default(), queries: Arc::new(RwLock::new(Vec::new())) }
88    }
89
90    /// 设置返回值
91    pub fn return_value(mut self, value: T) -> Self {
92        self.result = Some(DatabaseResult::return_value(value));
93        self
94    }
95
96    /// 设置错误返回
97    pub fn error(mut self, msg: impl Into<String>) -> Self {
98        self.result = Some(DatabaseResult::error(msg));
99        self
100    }
101
102    /// 设置命名返回映射
103    pub fn named_returns(mut self, map: HashMap<String, T>) -> Self {
104        self.result = Some(DatabaseResult::named_returns(map));
105        self
106    }
107
108    /// 设置期望
109    pub fn expect(mut self, expectation: DatabaseExpectation) -> Self {
110        self.expectation = expectation;
111        self
112    }
113
114    /// 构建 Mock 数据库
115    pub fn build(self) -> MockDatabase<T> {
116        MockDatabase { result: self.result, expectation: self.expectation, queries: self.queries }
117    }
118}
119
120impl<T: Clone + Send + Sync + 'static> Default for MockDatabaseBuilder<T> {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126/// Mock 数据库
127///
128/// 用于模拟数据库操作,支持查询期望配置、调用记录和验证。
129pub struct MockDatabase<T> {
130    result: Option<DatabaseResult<T>>,
131    expectation: DatabaseExpectation,
132    queries: Arc<RwLock<Vec<DatabaseQuery>>>,
133}
134
135impl<T: Clone> MockDatabase<T> {
136    /// 执行数据库查询
137    pub fn query(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
138        let query_name = query_name.into();
139        {
140            let mut queries = self.queries.write();
141            queries.push(DatabaseQuery { query_name: query_name.clone(), params, timestamp: std::time::Instant::now() });
142        }
143
144        match &self.result {
145            Some(DatabaseResult::Return(v)) => Ok(v.clone()),
146            Some(DatabaseResult::Error(e)) => Err(WaeError::new(WaeErrorKind::MockError { reason: e.clone() })),
147            Some(DatabaseResult::NamedReturns(map)) => {
148                if let Some(value) = map.get(&query_name) {
149                    Ok(value.clone())
150                }
151                else {
152                    Err(WaeError::new(WaeErrorKind::MockError { reason: format!("No mock result for query: {}", query_name) }))
153                }
154            }
155            None => Err(WaeError::new(WaeErrorKind::MockError { reason: "No mock result configured".to_string() })),
156        }
157    }
158
159    /// 异步执行数据库查询
160    pub async fn query_async(&self, query_name: impl Into<String>, params: Vec<String>) -> TestingResult<T> {
161        self.query(query_name, params)
162    }
163
164    /// 获取查询记录
165    pub fn queries(&self) -> Vec<DatabaseQuery> {
166        self.queries.read().clone()
167    }
168
169    /// 获取查询次数
170    pub fn query_count(&self) -> usize {
171        self.queries.read().len()
172    }
173
174    /// 获取指定查询的次数
175    pub fn query_count_by_name(&self, query_name: &str) -> usize {
176        self.queries.read().iter().filter(|q| q.query_name == query_name).count()
177    }
178}
179
180impl<T: Clone + Send + Sync + 'static> Mock for MockDatabase<T> {
181    fn calls(&self) -> Vec<MockCall> {
182        self.queries
183            .read()
184            .iter()
185            .map(|q| MockCall {
186                args: vec![q.query_name.clone()].into_iter().chain(q.params.clone()).collect(),
187                timestamp: q.timestamp,
188            })
189            .collect()
190    }
191
192    fn call_count(&self) -> usize {
193        self.query_count()
194    }
195
196    fn verify(&self) -> TestingResult<()> {
197        for (query_name, expected) in &self.expectation.expected_queries {
198            let actual = self.query_count_by_name(query_name);
199            if actual != *expected {
200                return Err(WaeError::new(WaeErrorKind::AssertionFailed {
201                    message: format!("Expected {} calls for query '{}', but got {}", expected, query_name, actual),
202                }));
203            }
204        }
205
206        Ok(())
207    }
208
209    fn reset(&self) {
210        let mut queries = self.queries.write();
211        queries.clear();
212    }
213}