1use std::collections::HashMap;
2
3use serde_json::Value;
4
/// Settings shared by the bulk insert and upsert entry points.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InsertAllConfig {
    /// Maximum number of records per batch; used to compute the reported batch count.
    ///
    /// The `*_with_config` functions reject a value of zero.
    pub batch_size: usize,
    /// Name of the record key that is looked up for an explicit integer id.
    pub id_field: String,
}

impl Default for InsertAllConfig {
    /// Returns a configuration with batches of 1000 records and ids read from `"id"`.
    fn default() -> Self {
        let batch_size = 1000;
        let id_field = String::from("id");
        Self {
            batch_size,
            id_field,
        }
    }
}
22
/// Summary returned by the bulk insert functions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InsertResult {
    /// Number of records that were inserted.
    pub inserted_count: usize,
    /// Id resolved for each inserted record, in input order.
    pub inserted_ids: Vec<i64>,
    /// Number of batches needed for the input at the configured batch size.
    pub batches: usize,
}
33
/// Summary returned by the bulk upsert functions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UpsertResult {
    /// Number of records counted as inserts.
    pub inserted_count: usize,
    /// Number of records counted as updates of an earlier record.
    pub updated_count: usize,
    /// Id resolved for each inserted record, in input order.
    pub inserted_ids: Vec<i64>,
    /// For each updated record, the id of the earlier record it matched, in input order.
    pub updated_ids: Vec<i64>,
    /// Number of batches needed for the input at the configured batch size.
    pub batches: usize,
}
48
49#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
51pub enum InsertAllError {
52 #[error("batch size must be greater than zero")]
54 InvalidBatchSize,
55 #[error("unique_by must not be empty")]
57 EmptyUniqueBy,
58}
59
60pub fn insert_all(records: &[HashMap<String, Value>]) -> Result<InsertResult, InsertAllError> {
62 insert_all_with_config(records, &InsertAllConfig::default())
63}
64
65pub fn insert_all_with_config(
67 records: &[HashMap<String, Value>],
68 config: &InsertAllConfig,
69) -> Result<InsertResult, InsertAllError> {
70 if config.batch_size == 0 {
71 return Err(InsertAllError::InvalidBatchSize);
72 }
73
74 let inserted_ids = collect_ids(records, &config.id_field);
75 Ok(InsertResult {
76 inserted_count: records.len(),
77 batches: batch_count(records.len(), config.batch_size),
78 inserted_ids,
79 })
80}
81
82pub fn upsert_all(
84 records: &[HashMap<String, Value>],
85 unique_by: &str,
86) -> Result<UpsertResult, InsertAllError> {
87 upsert_all_with_config(records, unique_by, &InsertAllConfig::default())
88}
89
90pub fn upsert_all_with_config(
92 records: &[HashMap<String, Value>],
93 unique_by: &str,
94 config: &InsertAllConfig,
95) -> Result<UpsertResult, InsertAllError> {
96 if config.batch_size == 0 {
97 return Err(InsertAllError::InvalidBatchSize);
98 }
99 if unique_by.is_empty() {
100 return Err(InsertAllError::EmptyUniqueBy);
101 }
102
103 let mut ids_by_unique_value = HashMap::<String, i64>::new();
104 let mut inserted_ids = Vec::new();
105 let mut updated_ids = Vec::new();
106 let mut next_id = 1_i64;
107
108 for record in records {
109 let unique_key = record
110 .get(unique_by)
111 .map(unique_value_key)
112 .unwrap_or_else(|| format!("__missing__:{next_id}"));
113 let explicit_id = record.get(&config.id_field).and_then(Value::as_i64);
114
115 if let Some(existing_id) = ids_by_unique_value.get(&unique_key).copied() {
116 updated_ids.push(existing_id);
117 continue;
118 }
119
120 let id = explicit_id.unwrap_or_else(|| {
121 let assigned = next_id;
122 next_id += 1;
123 assigned
124 });
125 ids_by_unique_value.insert(unique_key, id);
126 inserted_ids.push(id);
127 }
128
129 Ok(UpsertResult {
130 inserted_count: inserted_ids.len(),
131 updated_count: updated_ids.len(),
132 inserted_ids,
133 updated_ids,
134 batches: batch_count(records.len(), config.batch_size),
135 })
136}
137
138fn collect_ids(records: &[HashMap<String, Value>], id_field: &str) -> Vec<i64> {
139 let mut next_id = 1_i64;
140 records
141 .iter()
142 .map(|record| {
143 let id = record
144 .get(id_field)
145 .and_then(Value::as_i64)
146 .unwrap_or(next_id);
147 next_id = next_id.max(id.saturating_add(1));
148 id
149 })
150 .collect()
151}
152
/// Returns how many batches of at most `batch_size` items are needed for `total` items.
///
/// Zero items need zero batches; otherwise the result is `total / batch_size` rounded up.
/// `batch_size` must be non-zero whenever `total` is non-zero (callers validate this).
fn batch_count(total: usize, batch_size: usize) -> usize {
    match total {
        0 => 0,
        // Subtracting one before dividing rounds up without risking overflow.
        count => (count - 1) / batch_size + 1,
    }
}
160
161fn unique_value_key(value: &Value) -> String {
162 match serde_json::to_string(value) {
163 Ok(serialized) => serialized,
164 Err(_) => "null".to_owned(),
165 }
166}
167
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{Value, json};

    use super::{InsertAllConfig, InsertAllError, insert_all, insert_all_with_config, upsert_all};

    /// Builds a record with an `email` field and, when given, an integer `id` field.
    fn record(id: Option<i64>, email: &str) -> HashMap<String, Value> {
        let mut fields = HashMap::from([("email".to_owned(), json!(email))]);
        if let Some(id) = id {
            fields.insert("id".to_owned(), json!(id));
        }
        fields
    }

    /// Default configuration with only the batch size overridden.
    fn config_with_batch_size(batch_size: usize) -> InsertAllConfig {
        InsertAllConfig {
            batch_size,
            ..InsertAllConfig::default()
        }
    }

    #[test]
    fn insert_all_returns_inserted_count_and_ids() {
        let records = [
            record(Some(10), "a@example.com"),
            record(None, "b@example.com"),
        ];

        let result = insert_all(&records).expect("insert should succeed");

        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.inserted_ids, vec![10, 11]);
        assert_eq!(result.batches, 1);
    }

    #[test]
    fn insert_all_respects_batch_size() {
        let records = [record(None, "a"), record(None, "b"), record(None, "c")];
        let config = config_with_batch_size(2);

        let result = insert_all_with_config(&records, &config).expect("insert should succeed");

        assert_eq!(result.batches, 2);
    }

    #[test]
    fn insert_all_rejects_zero_batch_size() {
        let config = config_with_batch_size(0);

        let outcome = insert_all_with_config(&[record(None, "a")], &config);

        assert_eq!(outcome, Err(InsertAllError::InvalidBatchSize));
    }

    #[test]
    fn upsert_all_inserts_unique_rows_and_updates_duplicates() {
        let records = [
            record(Some(1), "a@example.com"),
            record(Some(2), "b@example.com"),
            record(Some(3), "a@example.com"),
        ];

        let result = upsert_all(&records, "email").expect("upsert should succeed");

        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.updated_count, 1);
        assert_eq!(result.inserted_ids, vec![1, 2]);
        assert_eq!(result.updated_ids, vec![1]);
    }

    #[test]
    fn upsert_all_requires_unique_by() {
        let outcome = upsert_all(&[record(None, "a")], "");

        assert_eq!(outcome, Err(InsertAllError::EmptyUniqueBy));
    }

    #[test]
    fn upsert_all_handles_missing_unique_field_as_distinct_rows() {
        let records: [HashMap<String, Value>; 2] = [HashMap::new(), HashMap::new()];

        let result = upsert_all(&records, "email").expect("upsert should succeed");

        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.updated_count, 0);
    }
}