1use crate::error::{Error, Result};
4
/// Aggregation applied to a pivot table data field.
///
/// Each variant corresponds to one SpreadsheetML `subtotal` token, obtained
/// with [`AggregateFunction::to_xml_str`] and parsed back with
/// [`AggregateFunction::from_xml_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// Serialized as `sum`.
    Sum,
    /// Serialized as `count`.
    Count,
    /// Serialized as `average`.
    Average,
    /// Serialized as `max`.
    Max,
    /// Serialized as `min`.
    Min,
    /// Serialized as `product`.
    Product,
    /// Serialized as `countNums`.
    CountNums,
    /// Serialized as `stdDev`.
    StdDev,
    /// Serialized as `stdDevp`.
    StdDevP,
    /// Serialized as `var`.
    Var,
    /// Serialized as `varp`.
    VarP,
}

impl AggregateFunction {
    /// Returns the SpreadsheetML token for this function (e.g. `"sum"`).
    ///
    /// The token is a string literal, so the returned reference is `'static`
    /// and may outlive `self`.
    pub fn to_xml_str(&self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Count => "count",
            Self::Average => "average",
            Self::Max => "max",
            Self::Min => "min",
            Self::Product => "product",
            Self::CountNums => "countNums",
            Self::StdDev => "stdDev",
            Self::StdDevP => "stdDevp",
            Self::Var => "var",
            Self::VarP => "varp",
        }
    }

    /// Parses a SpreadsheetML token produced by [`AggregateFunction::to_xml_str`].
    ///
    /// Matching is exact and case-sensitive (`"SUM"` is rejected); returns
    /// `None` for any unrecognized token.
    pub fn from_xml_str(s: &str) -> Option<Self> {
        match s {
            "sum" => Some(Self::Sum),
            "count" => Some(Self::Count),
            "average" => Some(Self::Average),
            "max" => Some(Self::Max),
            "min" => Some(Self::Min),
            "product" => Some(Self::Product),
            "countNums" => Some(Self::CountNums),
            "stdDev" => Some(Self::StdDev),
            "stdDevp" => Some(Self::StdDevP),
            "var" => Some(Self::Var),
            "varp" => Some(Self::VarP),
            _ => None,
        }
    }
}
57
58#[derive(Debug, Clone)]
60pub struct PivotTableConfig {
61 pub name: String,
63 pub source_sheet: String,
65 pub source_range: String,
67 pub target_sheet: String,
69 pub target_cell: String,
71 pub rows: Vec<PivotField>,
73 pub columns: Vec<PivotField>,
75 pub data: Vec<PivotDataField>,
77}
78
79#[derive(Debug, Clone)]
81pub struct PivotField {
82 pub name: String,
84}
85
86#[derive(Debug, Clone)]
88pub struct PivotDataField {
89 pub name: String,
91 pub function: AggregateFunction,
93 pub display_name: Option<String>,
95}
96
97#[derive(Debug, Clone)]
99pub struct PivotTableInfo {
100 pub name: String,
101 pub source_sheet: String,
102 pub source_range: String,
103 pub target_sheet: String,
104 pub location: String,
105}
106
107pub fn build_pivot_table_xml(
109 config: &PivotTableConfig,
110 cache_id: u32,
111 field_names: &[String],
112) -> Result<sheetkit_xml::pivot_table::PivotTableDefinition> {
113 use sheetkit_xml::pivot_table::*;
114
115 let ns = sheetkit_xml::namespaces::SPREADSHEET_ML;
116
117 let find_field_index = |name: &str| -> Result<usize> {
118 field_names.iter().position(|n| n == name).ok_or_else(|| {
119 Error::Internal(format!("pivot field '{}' not found in source data", name))
120 })
121 };
122
123 let mut pivot_field_defs = Vec::new();
124 for field_name in field_names {
125 let is_row = config.rows.iter().any(|r| r.name == *field_name);
126 let is_col = config.columns.iter().any(|c| c.name == *field_name);
127 let is_data = config.data.iter().any(|d| d.name == *field_name);
128
129 let axis = if is_row {
130 Some("axisRow".to_string())
131 } else if is_col {
132 Some("axisCol".to_string())
133 } else {
134 None
135 };
136
137 pivot_field_defs.push(PivotFieldDef {
138 axis,
139 data_field: if is_data { Some(true) } else { None },
140 show_all: Some(false),
141 items: None,
142 });
143 }
144
145 let row_fields = if config.rows.is_empty() {
146 None
147 } else {
148 let fields: Result<Vec<FieldRef>> = config
149 .rows
150 .iter()
151 .map(|r| find_field_index(&r.name).map(|i| FieldRef { index: i as i32 }))
152 .collect();
153 Some(FieldList {
154 count: Some(config.rows.len() as u32),
155 fields: fields?,
156 })
157 };
158
159 let col_fields = if config.columns.is_empty() {
160 None
161 } else {
162 let fields: Result<Vec<FieldRef>> = config
163 .columns
164 .iter()
165 .map(|c| find_field_index(&c.name).map(|i| FieldRef { index: i as i32 }))
166 .collect();
167 Some(FieldList {
168 count: Some(config.columns.len() as u32),
169 fields: fields?,
170 })
171 };
172
173 let data_fields = if config.data.is_empty() {
174 None
175 } else {
176 let fields: Result<Vec<DataFieldDef>> = config
177 .data
178 .iter()
179 .map(|d| {
180 let idx = find_field_index(&d.name)?;
181 Ok(DataFieldDef {
182 name: d.display_name.clone().or_else(|| {
183 Some(format!(
184 "{} of {}",
185 capitalize_first(d.function.to_xml_str()),
186 d.name
187 ))
188 }),
189 field_index: idx as u32,
190 subtotal: Some(d.function.to_xml_str().to_string()),
191 base_field: Some(0),
192 base_item: Some(0),
193 })
194 })
195 .collect();
196 Some(DataFields {
197 count: Some(config.data.len() as u32),
198 fields: fields?,
199 })
200 };
201
202 Ok(PivotTableDefinition {
203 xmlns: ns.to_string(),
204 name: config.name.clone(),
205 cache_id,
206 data_on_rows: Some(false),
207 apply_number_formats: Some(false),
208 apply_border_formats: Some(false),
209 apply_font_formats: Some(false),
210 apply_pattern_formats: Some(false),
211 apply_alignment_formats: Some(false),
212 apply_width_height_formats: Some(true),
213 location: PivotLocation {
214 reference: config.target_cell.clone(),
215 first_header_row: 1,
216 first_data_row: 1,
217 first_data_col: 1,
218 },
219 pivot_fields: PivotFields {
220 count: Some(field_names.len() as u32),
221 fields: pivot_field_defs,
222 },
223 row_fields,
224 col_fields,
225 data_fields,
226 })
227}
228
229pub fn build_pivot_cache_definition(
231 source_sheet: &str,
232 source_range: &str,
233 field_names: &[String],
234) -> sheetkit_xml::pivot_cache::PivotCacheDefinition {
235 use sheetkit_xml::pivot_cache::*;
236
237 let cache_fields = CacheFields {
238 count: Some(field_names.len() as u32),
239 fields: field_names
240 .iter()
241 .map(|name| CacheField {
242 name: name.clone(),
243 num_fmt_id: Some(0),
244 shared_items: Some(SharedItems {
245 contains_semi_mixed_types: None,
246 contains_string: None,
247 contains_number: None,
248 contains_blank: None,
249 count: Some(0),
250 string_items: vec![],
251 number_items: vec![],
252 }),
253 })
254 .collect(),
255 };
256
257 PivotCacheDefinition {
258 xmlns: sheetkit_xml::namespaces::SPREADSHEET_ML.to_string(),
259 xmlns_r: sheetkit_xml::namespaces::RELATIONSHIPS.to_string(),
260 r_id: None,
261 record_count: Some(0),
262 cache_source: CacheSource {
263 source_type: "worksheet".to_string(),
264 worksheet_source: Some(WorksheetSource {
265 reference: source_range.to_string(),
266 sheet: source_sheet.to_string(),
267 }),
268 },
269 cache_fields,
270 }
271}
272
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// Upper-casing uses `char::to_uppercase`, so a single character may expand
/// into several (e.g. `ß` becomes `SS`). An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    if let Some(head) = chars.next() {
        out.extend(head.to_uppercase());
        out.push_str(chars.as_str());
    }
    out
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Every aggregate function paired with its SpreadsheetML token.
    fn all_functions() -> Vec<(AggregateFunction, &'static str)> {
        vec![
            (AggregateFunction::Sum, "sum"),
            (AggregateFunction::Count, "count"),
            (AggregateFunction::Average, "average"),
            (AggregateFunction::Max, "max"),
            (AggregateFunction::Min, "min"),
            (AggregateFunction::Product, "product"),
            (AggregateFunction::CountNums, "countNums"),
            (AggregateFunction::StdDev, "stdDev"),
            (AggregateFunction::StdDevP, "stdDevp"),
            (AggregateFunction::Var, "var"),
            (AggregateFunction::VarP, "varp"),
        ]
    }

    /// Converts borrowed names into the owned list the builders expect.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    /// Shorthand for an axis (row/column) field.
    fn axis(name: &str) -> PivotField {
        PivotField {
            name: name.to_string(),
        }
    }

    /// Shorthand for a data (values) field.
    fn value(
        name: &str,
        function: AggregateFunction,
        display_name: Option<&str>,
    ) -> PivotDataField {
        PivotDataField {
            name: name.to_string(),
            function,
            display_name: display_name.map(|s| s.to_string()),
        }
    }

    #[test]
    fn test_aggregate_function_to_xml_str() {
        for (function, token) in all_functions() {
            assert_eq!(function.to_xml_str(), token);
        }
    }

    #[test]
    fn test_aggregate_function_from_xml_str() {
        for (function, token) in all_functions() {
            assert_eq!(AggregateFunction::from_xml_str(token), Some(function));
        }
    }

    #[test]
    fn test_aggregate_function_from_xml_str_unknown() {
        // Unknown, empty and wrongly-cased tokens are all rejected.
        for token in ["unknown", "", "SUM"].iter() {
            assert_eq!(AggregateFunction::from_xml_str(token), None);
        }
    }

    #[test]
    fn test_aggregate_function_roundtrip() {
        for (function, _) in all_functions() {
            let parsed = AggregateFunction::from_xml_str(function.to_xml_str())
                .expect("every token produced by to_xml_str must parse back");
            assert_eq!(function, parsed);
        }
    }

    #[test]
    fn test_capitalize_first() {
        let cases = [
            ("sum", "Sum"),
            ("count", "Count"),
            ("average", "Average"),
            ("", ""),
            ("a", "A"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(capitalize_first(input), *expected);
        }
    }

    #[test]
    fn test_build_pivot_table_xml_basic() {
        let config = PivotTableConfig {
            name: "PivotTable1".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:C5".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![axis("Region")],
            columns: Vec::new(),
            data: vec![value("Sales", AggregateFunction::Sum, None)],
        };
        let fields = owned(&["Region", "Product", "Sales"]);

        let def = build_pivot_table_xml(&config, 0, &fields).unwrap();
        assert_eq!(def.name, "PivotTable1");
        assert_eq!(def.cache_id, 0);

        let pivot_fields = &def.pivot_fields;
        assert_eq!(pivot_fields.count, Some(3));
        assert_eq!(pivot_fields.fields.len(), 3);
        // Region: row axis, not aggregated.
        assert_eq!(pivot_fields.fields[0].axis, Some("axisRow".to_string()));
        assert_eq!(pivot_fields.fields[0].data_field, None);
        // Product: unused.
        assert_eq!(pivot_fields.fields[1].axis, None);
        // Sales: aggregated, not on an axis.
        assert_eq!(pivot_fields.fields[2].axis, None);
        assert_eq!(pivot_fields.fields[2].data_field, Some(true));

        let rows = def.row_fields.as_ref().expect("row fields should be present");
        assert_eq!(rows.count, Some(1));
        assert_eq!(rows.fields[0].index, 0);

        assert!(def.col_fields.is_none());

        let data = def.data_fields.as_ref().expect("data fields should be present");
        assert_eq!(data.count, Some(1));
        assert_eq!(data.fields[0].field_index, 2);
        assert_eq!(data.fields[0].subtotal, Some("sum".to_string()));
        assert_eq!(data.fields[0].name, Some("Sum of Sales".to_string()));
    }

    #[test]
    fn test_build_pivot_table_xml_with_columns() {
        let config = PivotTableConfig {
            name: "SalesReport".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:D10".to_string(),
            target_sheet: "Report".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![axis("Region")],
            columns: vec![axis("Quarter")],
            data: vec![value(
                "Revenue",
                AggregateFunction::Average,
                Some("Avg Revenue"),
            )],
        };
        let fields = owned(&["Region", "Quarter", "Revenue"]);

        let def = build_pivot_table_xml(&config, 1, &fields).unwrap();
        assert_eq!(def.cache_id, 1);

        let pivot_fields = &def.pivot_fields.fields;
        assert_eq!(pivot_fields[0].axis, Some("axisRow".to_string()));
        assert_eq!(pivot_fields[1].axis, Some("axisCol".to_string()));

        let cols = def.col_fields.as_ref().expect("column fields should be present");
        assert_eq!(cols.count, Some(1));
        assert_eq!(cols.fields[0].index, 1);

        let data = def.data_fields.as_ref().expect("data fields should be present");
        assert_eq!(data.fields[0].name, Some("Avg Revenue".to_string()));
        assert_eq!(data.fields[0].subtotal, Some("average".to_string()));
    }

    #[test]
    fn test_build_pivot_table_xml_unknown_field() {
        let config = PivotTableConfig {
            name: "Bad".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:B2".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![axis("NonExistent")],
            columns: Vec::new(),
            data: Vec::new(),
        };
        let fields = owned(&["Actual"]);

        let message = match build_pivot_table_xml(&config, 0, &fields) {
            Ok(_) => panic!("expected an error for an unknown pivot field"),
            Err(e) => e.to_string(),
        };
        assert!(message.contains("NonExistent"));
    }

    #[test]
    fn test_build_pivot_table_xml_no_rows_or_cols() {
        let config = PivotTableConfig {
            name: "DataOnly".to_string(),
            source_sheet: "Sheet1".to_string(),
            source_range: "A1:B5".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: Vec::new(),
            columns: Vec::new(),
            data: vec![value("Amount", AggregateFunction::Count, None)],
        };
        let fields = owned(&["Amount"]);

        let def = build_pivot_table_xml(&config, 0, &fields).unwrap();
        assert!(def.row_fields.is_none());
        assert!(def.col_fields.is_none());
        assert!(def.data_fields.is_some());
    }

    #[test]
    fn test_build_pivot_cache_definition() {
        let fields = owned(&["Name", "Region", "Sales"]);
        let def = build_pivot_cache_definition("Sheet1", "A1:C10", &fields);

        assert_eq!(def.xmlns, sheetkit_xml::namespaces::SPREADSHEET_ML);
        assert_eq!(def.cache_source.source_type, "worksheet");
        let source = def
            .cache_source
            .worksheet_source
            .as_ref()
            .expect("worksheet source should be present");
        assert_eq!(source.sheet, "Sheet1");
        assert_eq!(source.reference, "A1:C10");

        assert_eq!(def.cache_fields.count, Some(3));
        let cached_names: Vec<&str> = def
            .cache_fields
            .fields
            .iter()
            .map(|f| f.name.as_str())
            .collect();
        assert_eq!(cached_names, ["Name", "Region", "Sales"]);

        // Every cache field starts with an empty shared-items list.
        for field in &def.cache_fields.fields {
            let items = field
                .shared_items
                .as_ref()
                .expect("shared items should be present");
            assert_eq!(items.count, Some(0));
        }

        assert_eq!(def.record_count, Some(0));
        assert!(def.r_id.is_none());
    }

    #[test]
    fn test_build_pivot_cache_definition_empty_fields() {
        let fields: Vec<String> = Vec::new();
        let def = build_pivot_cache_definition("Sheet1", "A1:A1", &fields);
        assert_eq!(def.cache_fields.count, Some(0));
        assert!(def.cache_fields.fields.is_empty());
    }

    #[test]
    fn test_pivot_table_info_struct() {
        let info = PivotTableInfo {
            name: "PT1".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:D10".to_string(),
            target_sheet: "Report".to_string(),
            location: "A3:E20".to_string(),
        };
        assert_eq!(info.name, "PT1");
        assert_eq!(info.source_sheet, "Data");
        assert_eq!(info.source_range, "A1:D10");
        assert_eq!(info.target_sheet, "Report");
        assert_eq!(info.location, "A3:E20");
    }

    #[test]
    fn test_build_pivot_table_xml_generates_default_display_name() {
        let config = PivotTableConfig {
            name: "PT".to_string(),
            source_sheet: "S".to_string(),
            source_range: "A1:B2".to_string(),
            target_sheet: "T".to_string(),
            target_cell: "A1".to_string(),
            rows: Vec::new(),
            columns: Vec::new(),
            data: vec![
                value("Amount", AggregateFunction::Sum, None),
                value("Count", AggregateFunction::Count, Some("Total Count")),
            ],
        };
        let fields = owned(&["Amount", "Count"]);

        let def = build_pivot_table_xml(&config, 0, &fields).unwrap();
        let data = def.data_fields.as_ref().expect("data fields should be present");

        // No display name given: caption is generated from the function.
        assert_eq!(data.fields[0].name, Some("Sum of Amount".to_string()));
        // Explicit display name wins.
        assert_eq!(data.fields[1].name, Some("Total Count".to_string()));
    }

    #[test]
    fn test_error_pivot_table_not_found() {
        let err = Error::PivotTableNotFound {
            name: "Missing".to_string(),
        };
        assert_eq!(err.to_string(), "pivot table 'Missing' not found");
    }

    #[test]
    fn test_error_pivot_table_already_exists() {
        let err = Error::PivotTableAlreadyExists {
            name: "PT1".to_string(),
        };
        assert_eq!(err.to_string(), "pivot table 'PT1' already exists");
    }

    #[test]
    fn test_error_invalid_source_range() {
        let err = Error::InvalidSourceRange("bad range".to_string());
        assert_eq!(err.to_string(), "invalid source range: bad range");
    }
}