spreadsheet_mcp/formula/
pattern.rs

1use crate::utils::column_number_to_name;
2use anyhow::{Result, anyhow, bail};
3use formualizer_parse::parser::ReferenceType;
4use formualizer_parse::{ASTNode, ASTNodeType, LiteralValue};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7pub enum RelativeMode {
8    Excel,
9    AbsCols,
10    AbsRows,
11}
12
13impl RelativeMode {
14    pub fn parse(mode: Option<&str>) -> Result<Self> {
15        match mode.unwrap_or("excel").to_ascii_lowercase().as_str() {
16            "excel" => Ok(Self::Excel),
17            "abs_cols" | "abscols" | "columns_absolute" => Ok(Self::AbsCols),
18            "abs_rows" | "absrows" | "rows_absolute" => Ok(Self::AbsRows),
19            other => bail!("invalid relative_mode: {}", other),
20        }
21    }
22}
23
/// Which components of a single coordinate are absolute (carry a `$` marker).
///
/// An absolute component is left unchanged when a reference is shifted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct CoordFlags {
    /// The column part is absolute (`$A1`).
    abs_col: bool,
    /// The row part is absolute (`A$1`).
    abs_row: bool,
}
29
30pub fn parse_base_formula(formula: &str) -> Result<ASTNode> {
31    let trimmed = formula.trim();
32    let with_equals = if trimmed.starts_with('=') {
33        trimmed.to_string()
34    } else {
35        format!("={}", trimmed)
36    };
37    formualizer_parse::parse(&with_equals)
38        .map_err(|e| anyhow!("failed to parse base_formula: {}", e.message))
39}
40
41pub fn shift_formula_ast(
42    ast: &ASTNode,
43    delta_col: i32,
44    delta_row: i32,
45    mode: RelativeMode,
46) -> Result<String> {
47    Ok(format!("={}", shift_node(ast, delta_col, delta_row, mode)?))
48}
49
50fn shift_node(
51    node: &ASTNode,
52    delta_col: i32,
53    delta_row: i32,
54    mode: RelativeMode,
55) -> Result<String> {
56    Ok(match &node.node_type {
57        ASTNodeType::Literal(value) => match value {
58            LiteralValue::Text(s) => {
59                let escaped = s.replace('"', "\"\"");
60                format!("\"{escaped}\"")
61            }
62            _ => format!("{value}"),
63        },
64        ASTNodeType::Reference {
65            original,
66            reference,
67        } => shift_reference(original, reference, delta_col, delta_row, mode)?,
68        ASTNodeType::UnaryOp { op, expr } => {
69            format!("{}{}", op, shift_node(expr, delta_col, delta_row, mode)?)
70        }
71        ASTNodeType::BinaryOp { op, left, right } => {
72            if op == ":" {
73                format!(
74                    "{}:{}",
75                    shift_node(left, delta_col, delta_row, mode)?,
76                    shift_node(right, delta_col, delta_row, mode)?
77                )
78            } else {
79                format!(
80                    "{} {} {}",
81                    shift_node(left, delta_col, delta_row, mode)?,
82                    op,
83                    shift_node(right, delta_col, delta_row, mode)?
84                )
85            }
86        }
87        ASTNodeType::Function { name, args } => {
88            let args_str = args
89                .iter()
90                .map(|a| shift_node(a, delta_col, delta_row, mode))
91                .collect::<Result<Vec<_>>>()?
92                .join(", ");
93            format!("{}({})", name.to_uppercase(), args_str)
94        }
95        ASTNodeType::Array(rows) => {
96            let rows_str = rows
97                .iter()
98                .map(|row| {
99                    row.iter()
100                        .map(|a| shift_node(a, delta_col, delta_row, mode))
101                        .collect::<Result<Vec<_>>>()
102                        .map(|parts| parts.join(", "))
103                })
104                .collect::<Result<Vec<_>>>()?
105                .join("; ");
106            format!("{{{rows_str}}}")
107        }
108    })
109}
110
111fn shift_reference(
112    original: &str,
113    reference: &ReferenceType,
114    delta_col: i32,
115    delta_row: i32,
116    mode: RelativeMode,
117) -> Result<String> {
118    match reference {
119        ReferenceType::Cell { sheet, row, col } => {
120            let coord_part = strip_sheet_prefix(original);
121            let mut flags = coord_abs_flags(coord_part);
122            match mode {
123                RelativeMode::AbsCols => flags.abs_col = true,
124                RelativeMode::AbsRows => flags.abs_row = true,
125                RelativeMode::Excel => {}
126            }
127            let new_col = shift_u32(*col, flags.abs_col, delta_col)?;
128            let new_row = shift_u32(*row, flags.abs_row, delta_row)?;
129            let coord = format_cell_coord(new_col, new_row, flags);
130            Ok(format!("{}{}", format_sheet_prefix(sheet), coord))
131        }
132        ReferenceType::Range {
133            sheet,
134            start_row,
135            start_col,
136            end_row,
137            end_col,
138        } => {
139            let ref_part = strip_sheet_prefix(original);
140            let (start_str, end_str) = ref_part.split_once(':').unwrap_or((ref_part, ref_part));
141            let mut start_flags = coord_abs_flags(start_str);
142            let mut end_flags = coord_abs_flags(end_str);
143
144            match mode {
145                RelativeMode::AbsCols => {
146                    if start_col.is_some() {
147                        start_flags.abs_col = true;
148                    }
149                    if end_col.is_some() {
150                        end_flags.abs_col = true;
151                    }
152                }
153                RelativeMode::AbsRows => {
154                    if start_row.is_some() {
155                        start_flags.abs_row = true;
156                    }
157                    if end_row.is_some() {
158                        end_flags.abs_row = true;
159                    }
160                }
161                RelativeMode::Excel => {}
162            }
163
164            let new_start_col = shift_opt_u32(*start_col, start_flags.abs_col, delta_col)?;
165            let new_end_col = shift_opt_u32(*end_col, end_flags.abs_col, delta_col)?;
166            let new_start_row = shift_opt_u32(*start_row, start_flags.abs_row, delta_row)?;
167            let new_end_row = shift_opt_u32(*end_row, end_flags.abs_row, delta_row)?;
168
169            let start_coord = format_range_coord(new_start_col, new_start_row, start_flags);
170            let end_coord = format_range_coord(new_end_col, new_end_row, end_flags);
171            if start_coord.is_empty() || end_coord.is_empty() {
172                bail!("invalid range reference after shift: {}", original);
173            }
174            let coord = format!("{start_coord}:{end_coord}");
175            Ok(format!("{}{}", format_sheet_prefix(sheet), coord))
176        }
177        ReferenceType::Table(_) | ReferenceType::NamedRange(_) => Ok(reference.to_string()),
178    }
179}
180
181fn shift_u32(value: u32, abs: bool, delta: i32) -> Result<u32> {
182    if abs || delta == 0 {
183        return Ok(value);
184    }
185    let shifted = value as i64 + delta as i64;
186    if shifted < 1 {
187        bail!("shift would move reference before A1");
188    }
189    Ok(shifted as u32)
190}
191
192fn shift_opt_u32(value: Option<u32>, abs: bool, delta: i32) -> Result<Option<u32>> {
193    match value {
194        Some(v) => Ok(Some(shift_u32(v, abs, delta)?)),
195        None => Ok(None),
196    }
197}
198
/// Returns the coordinate text after the last `!` (or the whole input if there is
/// none), with surrounding whitespace trimmed.
///
/// Splitting on the last `!` keeps quoted sheet names containing `!` intact.
fn strip_sheet_prefix(original: &str) -> &str {
    let coord = match original.rfind('!') {
        // `!` is a single ASCII byte, so `bang + 1` is a valid char boundary.
        Some(bang) => &original[bang + 1..],
        None => original,
    };
    coord.trim()
}
206
207fn coord_abs_flags(coord: &str) -> CoordFlags {
208    let bytes = coord.as_bytes();
209    let len = bytes.len();
210    let mut i = 0;
211    let leading_dollar = i < len && bytes[i] == b'$';
212    if leading_dollar {
213        i += 1;
214    }
215
216    let letters_start = i;
217    while i < len && bytes[i].is_ascii_alphabetic() {
218        i += 1;
219    }
220    let has_letters = i > letters_start;
221
222    let second_dollar = i < len && bytes[i] == b'$';
223    let digits_start = if second_dollar { i + 1 } else { i };
224    let mut j = digits_start;
225    while j < len && bytes[j].is_ascii_digit() {
226        j += 1;
227    }
228    let has_digits = j > digits_start;
229
230    let abs_col = leading_dollar && has_letters;
231    let abs_row = if has_letters {
232        second_dollar && has_digits
233    } else {
234        leading_dollar && has_digits
235    };
236
237    CoordFlags { abs_col, abs_row }
238}
239
240fn format_cell_coord(col: u32, row: u32, flags: CoordFlags) -> String {
241    let col_str = column_number_to_name(col);
242    let mut out = String::new();
243    if flags.abs_col {
244        out.push('$');
245    }
246    out.push_str(&col_str);
247    if flags.abs_row {
248        out.push('$');
249    }
250    out.push_str(&row.to_string());
251    out
252}
253
254fn format_range_coord(col: Option<u32>, row: Option<u32>, flags: CoordFlags) -> String {
255    match (col, row) {
256        (Some(c), Some(r)) => format_cell_coord(c, r, flags),
257        (Some(c), None) => {
258            let col_str = column_number_to_name(c);
259            if flags.abs_col {
260                format!("${col_str}")
261            } else {
262                col_str
263            }
264        }
265        (None, Some(r)) => {
266            if flags.abs_row {
267                format!("${r}")
268            } else {
269                r.to_string()
270            }
271        }
272        (None, None) => String::new(),
273    }
274}
275
276fn format_sheet_prefix(sheet: &Option<String>) -> String {
277    if let Some(name) = sheet {
278        if sheet_name_needs_quoting(name) {
279            let escaped = name.replace('\'', "''");
280            format!("'{escaped}'!")
281        } else {
282            format!("{name}!")
283        }
284    } else {
285        String::new()
286    }
287}
288
/// Returns `true` when `name` must be wrapped in single quotes to be used as a sheet
/// prefix in a formula (`'My Sheet'!A1`). An empty name is never quoted.
///
/// Quoting is required when the name does any of the following:
/// - starts with a digit;
/// - contains a space or punctuation;
/// - matches a reserved literal/error word;
/// - could be read as a cell reference, in A1 (`Q1`, `FY24`) or R1C1 (`R`, `C`,
///   `R1C1`) notation.
///
/// Quoting a name that did not strictly need it is still valid, so the checks
/// deliberately err on the side of quoting.
fn sheet_name_needs_quoting(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }
    if name.starts_with(|c: char| c.is_ascii_digit()) {
        return true;
    }
    let has_special = name.bytes().any(|byte| {
        matches!(
            byte,
            b' ' | b'!' | b'"' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'(' | b')' | b'*' | b'+'
                | b',' | b'-' | b'.' | b'/' | b':' | b';' | b'<' | b'=' | b'>' | b'?' | b'@'
                | b'[' | b'\\' | b']' | b'^' | b'`' | b'{' | b'|' | b'}' | b'~'
        )
    });
    if has_special {
        return true;
    }
    let upper = name.to_uppercase();
    matches!(
        upper.as_str(),
        "TRUE" | "FALSE" | "NULL" | "REF" | "DIV" | "NAME" | "NUM" | "VALUE" | "N/A"
    ) || looks_like_a1_reference(&upper)
        || looks_like_r1c1_reference(&upper)
}

/// Returns `true` for an upper-cased name shaped like an A1 cell address: one to three
/// ASCII letters followed by one or more ASCII digits (e.g. `A1`, `FY24`, `XFD100`).
fn looks_like_a1_reference(upper: &str) -> bool {
    let letters = upper.bytes().take_while(u8::is_ascii_alphabetic).count();
    // Letters are single-byte ASCII, so `letters` is a valid char boundary.
    let digits = &upper[letters..];
    (1..=3).contains(&letters) && !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
}

/// Returns `true` for an upper-cased name shaped like an R1C1 reference: `R`, `C`,
/// `RC`, `R5`, `C3`, `R1C1`, ...
fn looks_like_r1c1_reference(upper: &str) -> bool {
    let mut rest = upper;
    let mut matched = false;
    for marker in ['R', 'C'] {
        if let Some(tail) = rest.strip_prefix(marker) {
            rest = tail.trim_start_matches(|c: char| c.is_ascii_digit());
            matched = true;
        }
    }
    matched && rest.is_empty()
}