Skip to main content

sql_cli/sql/functions/
bitwise_string.rs

1//! Bitwise operations on binary string representations
2
3use crate::data::datatable::DataValue;
4use crate::sql::functions::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5use anyhow::{anyhow, Result};
6
7/// BIT_AND_STR - Bitwise AND on binary strings
8pub struct BitAndStr;
9
10impl SqlFunction for BitAndStr {
11    fn signature(&self) -> FunctionSignature {
12        FunctionSignature {
13            name: "BIT_AND_STR",
14            category: FunctionCategory::Bitwise,
15            arg_count: ArgCount::Fixed(2),
16            description: "Performs bitwise AND on two binary strings",
17            returns: "Binary string result",
18            examples: vec![
19                "SELECT BIT_AND_STR('1101', '1011')",
20                "SELECT BIT_AND_STR(TO_BINARY(13), TO_BINARY(11))",
21            ],
22        }
23    }
24
25    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
26        self.validate_args(args)?;
27
28        let a = args[0].to_string();
29        let b = args[1].to_string();
30
31        // Pad to same length
32        let max_len = a.len().max(b.len());
33        let a_padded = format!("{:0>width$}", a, width = max_len);
34        let b_padded = format!("{:0>width$}", b, width = max_len);
35
36        let result: String = a_padded
37            .chars()
38            .zip(b_padded.chars())
39            .map(|(c1, c2)| match (c1, c2) {
40                ('1', '1') => '1',
41                _ => '0',
42            })
43            .collect();
44
45        Ok(DataValue::String(result))
46    }
47}
48
49/// BIT_OR_STR - Bitwise OR on binary strings
50pub struct BitOrStr;
51
52impl SqlFunction for BitOrStr {
53    fn signature(&self) -> FunctionSignature {
54        FunctionSignature {
55            name: "BIT_OR_STR",
56            category: FunctionCategory::Bitwise,
57            arg_count: ArgCount::Fixed(2),
58            description: "Performs bitwise OR on two binary strings",
59            returns: "Binary string result",
60            examples: vec![
61                "SELECT BIT_OR_STR('1100', '1010')",
62                "SELECT BIT_OR_STR(TO_BINARY(12), TO_BINARY(10))",
63            ],
64        }
65    }
66
67    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
68        self.validate_args(args)?;
69
70        let a = args[0].to_string();
71        let b = args[1].to_string();
72
73        let max_len = a.len().max(b.len());
74        let a_padded = format!("{:0>width$}", a, width = max_len);
75        let b_padded = format!("{:0>width$}", b, width = max_len);
76
77        let result: String = a_padded
78            .chars()
79            .zip(b_padded.chars())
80            .map(|(c1, c2)| match (c1, c2) {
81                ('0', '0') => '0',
82                _ => '1',
83            })
84            .collect();
85
86        Ok(DataValue::String(result))
87    }
88}
89
90/// BIT_XOR_STR - Bitwise XOR on binary strings
91pub struct BitXorStr;
92
93impl SqlFunction for BitXorStr {
94    fn signature(&self) -> FunctionSignature {
95        FunctionSignature {
96            name: "BIT_XOR_STR",
97            category: FunctionCategory::Bitwise,
98            arg_count: ArgCount::Fixed(2),
99            description: "Performs bitwise XOR on two binary strings",
100            returns: "Binary string result",
101            examples: vec![
102                "SELECT BIT_XOR_STR('1100', '1010')",
103                "SELECT BIT_XOR_STR(TO_BINARY(12), TO_BINARY(10))",
104            ],
105        }
106    }
107
108    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
109        self.validate_args(args)?;
110
111        let a = args[0].to_string();
112        let b = args[1].to_string();
113
114        let max_len = a.len().max(b.len());
115        let a_padded = format!("{:0>width$}", a, width = max_len);
116        let b_padded = format!("{:0>width$}", b, width = max_len);
117
118        let result: String = a_padded
119            .chars()
120            .zip(b_padded.chars())
121            .map(|(c1, c2)| if c1 == c2 { '0' } else { '1' })
122            .collect();
123
124        Ok(DataValue::String(result))
125    }
126}
127
128/// BIT_NOT_STR - Bitwise NOT on binary string
129pub struct BitNotStr;
130
131impl SqlFunction for BitNotStr {
132    fn signature(&self) -> FunctionSignature {
133        FunctionSignature {
134            name: "BIT_NOT_STR",
135            category: FunctionCategory::Bitwise,
136            arg_count: ArgCount::Fixed(1),
137            description: "Performs bitwise NOT on a binary string",
138            returns: "Binary string with bits flipped",
139            examples: vec![
140                "SELECT BIT_NOT_STR('1100')",
141                "SELECT BIT_NOT_STR(TO_BINARY(12))",
142            ],
143        }
144    }
145
146    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
147        self.validate_args(args)?;
148
149        let input = args[0].to_string();
150
151        let result: String = input
152            .chars()
153            .map(|c| if c == '0' { '1' } else { '0' })
154            .collect();
155
156        Ok(DataValue::String(result))
157    }
158}
159
160/// BIT_FLIP - Alias for BIT_NOT_STR
161pub struct BitFlip;
162
163impl SqlFunction for BitFlip {
164    fn signature(&self) -> FunctionSignature {
165        FunctionSignature {
166            name: "BIT_FLIP",
167            category: FunctionCategory::Bitwise,
168            arg_count: ArgCount::Fixed(1),
169            description: "Flips all bits in a binary string (alias for BIT_NOT_STR)",
170            returns: "Binary string with bits flipped",
171            examples: vec!["SELECT BIT_FLIP('11011010')"],
172        }
173    }
174
175    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
176        BitNotStr.evaluate(args)
177    }
178}
179
180/// BIT_COUNT - Count number of 1 bits (popcount).
181///
182/// Accepts either a binary string (e.g. `'11011010'`) or an integer.
183/// For integers the underlying 64-bit representation is counted via
184/// `u64::count_ones`, so negative values count their two's-complement bits.
185pub struct BitCount;
186
187impl SqlFunction for BitCount {
188    fn signature(&self) -> FunctionSignature {
189        FunctionSignature {
190            name: "BIT_COUNT",
191            category: FunctionCategory::Bitwise,
192            arg_count: ArgCount::Fixed(1),
193            description: "Counts the number of 1 bits in a binary string or integer (popcount)",
194            returns: "Integer count of 1 bits",
195            examples: vec![
196                "SELECT BIT_COUNT('11011010')  -- Returns 5",
197                "SELECT BIT_COUNT(218)         -- Returns 5 (same value, as integer)",
198                "SELECT BIT_COUNT(TO_BINARY(218))",
199            ],
200        }
201    }
202
203    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
204        self.validate_args(args)?;
205
206        let count = match &args[0] {
207            DataValue::Null => return Ok(DataValue::Null),
208            DataValue::Integer(n) => (*n as u64).count_ones() as i64,
209            DataValue::String(s) => s.chars().filter(|&c| c == '1').count() as i64,
210            // Fall back to the legacy string behaviour for anything else
211            // (e.g. booleans/floats stringify and are counted by '1' chars).
212            other => other.to_string().chars().filter(|&c| c == '1').count() as i64,
213        };
214        Ok(DataValue::Integer(count))
215    }
216}
217
218/// BIT_ROTATE_LEFT - Rotate binary string left by N positions
219pub struct BitRotateLeft;
220
221impl SqlFunction for BitRotateLeft {
222    fn signature(&self) -> FunctionSignature {
223        FunctionSignature {
224            name: "BIT_ROTATE_LEFT",
225            category: FunctionCategory::Bitwise,
226            arg_count: ArgCount::Fixed(2),
227            description: "Rotates a binary string left by N positions",
228            returns: "Rotated binary string",
229            examples: vec![
230                "SELECT BIT_ROTATE_LEFT('11011010', 2)",
231                "SELECT BIT_ROTATE_LEFT(TO_BINARY(218), 3)",
232            ],
233        }
234    }
235
236    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
237        self.validate_args(args)?;
238
239        let input = args[0].to_string();
240        let positions = match &args[1] {
241            DataValue::Integer(n) => *n as usize,
242            DataValue::Float(f) => *f as usize,
243            _ => return Err(anyhow!("Second argument must be a number")),
244        };
245
246        if input.is_empty() {
247            return Ok(DataValue::String(String::new()));
248        }
249
250        let effective_positions = positions % input.len();
251        let result = format!(
252            "{}{}",
253            &input[effective_positions..],
254            &input[..effective_positions]
255        );
256
257        Ok(DataValue::String(result))
258    }
259}
260
261/// BIT_ROTATE_RIGHT - Rotate binary string right by N positions
262pub struct BitRotateRight;
263
264impl SqlFunction for BitRotateRight {
265    fn signature(&self) -> FunctionSignature {
266        FunctionSignature {
267            name: "BIT_ROTATE_RIGHT",
268            category: FunctionCategory::Bitwise,
269            arg_count: ArgCount::Fixed(2),
270            description: "Rotates a binary string right by N positions",
271            returns: "Rotated binary string",
272            examples: vec![
273                "SELECT BIT_ROTATE_RIGHT('11011010', 2)",
274                "SELECT BIT_ROTATE_RIGHT(TO_BINARY(218), 3)",
275            ],
276        }
277    }
278
279    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
280        self.validate_args(args)?;
281
282        let input = args[0].to_string();
283        let positions = match &args[1] {
284            DataValue::Integer(n) => *n as usize,
285            DataValue::Float(f) => *f as usize,
286            _ => return Err(anyhow!("Second argument must be a number")),
287        };
288
289        if input.is_empty() {
290            return Ok(DataValue::String(String::new()));
291        }
292
293        let effective_positions = positions % input.len();
294        let split_point = input.len() - effective_positions;
295        let result = format!("{}{}", &input[split_point..], &input[..split_point]);
296
297        Ok(DataValue::String(result))
298    }
299}
300
301/// BIT_SHIFT_LEFT - Shift binary string left, filling with zeros
302pub struct BitShiftLeft;
303
304impl SqlFunction for BitShiftLeft {
305    fn signature(&self) -> FunctionSignature {
306        FunctionSignature {
307            name: "BIT_SHIFT_LEFT",
308            category: FunctionCategory::Bitwise,
309            arg_count: ArgCount::Fixed(2),
310            description: "Shifts a binary string left by N positions, filling with zeros",
311            returns: "Shifted binary string",
312            examples: vec![
313                "SELECT BIT_SHIFT_LEFT('11011010', 2)",
314                "SELECT BIT_SHIFT_LEFT(TO_BINARY(218), 3)",
315            ],
316        }
317    }
318
319    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
320        self.validate_args(args)?;
321
322        let input = args[0].to_string();
323        let positions = match &args[1] {
324            DataValue::Integer(n) => *n as usize,
325            DataValue::Float(f) => *f as usize,
326            _ => return Err(anyhow!("Second argument must be a number")),
327        };
328
329        if positions >= input.len() {
330            return Ok(DataValue::String("0".repeat(input.len())));
331        }
332
333        let result = format!("{}{}", &input[positions..], "0".repeat(positions));
334
335        Ok(DataValue::String(result))
336    }
337}
338
339/// BIT_SHIFT_RIGHT - Shift binary string right, filling with zeros
340pub struct BitShiftRight;
341
342impl SqlFunction for BitShiftRight {
343    fn signature(&self) -> FunctionSignature {
344        FunctionSignature {
345            name: "BIT_SHIFT_RIGHT",
346            category: FunctionCategory::Bitwise,
347            arg_count: ArgCount::Fixed(2),
348            description: "Shifts a binary string right by N positions, filling with zeros",
349            returns: "Shifted binary string",
350            examples: vec![
351                "SELECT BIT_SHIFT_RIGHT('11011010', 2)",
352                "SELECT BIT_SHIFT_RIGHT(TO_BINARY(218), 3)",
353            ],
354        }
355    }
356
357    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
358        self.validate_args(args)?;
359
360        let input = args[0].to_string();
361        let positions = match &args[1] {
362            DataValue::Integer(n) => *n as usize,
363            DataValue::Float(f) => *f as usize,
364            _ => return Err(anyhow!("Second argument must be a number")),
365        };
366
367        if positions >= input.len() {
368            return Ok(DataValue::String("0".repeat(input.len())));
369        }
370
371        let result = format!(
372            "{}{}",
373            "0".repeat(positions),
374            &input[..input.len() - positions]
375        );
376
377        Ok(DataValue::String(result))
378    }
379}
380
381/// HAMMING_DISTANCE - Count bit differences between two binary strings
382pub struct HammingDistance;
383
384impl SqlFunction for HammingDistance {
385    fn signature(&self) -> FunctionSignature {
386        FunctionSignature {
387            name: "HAMMING_DISTANCE",
388            category: FunctionCategory::Bitwise,
389            arg_count: ArgCount::Fixed(2),
390            description: "Counts the number of differing bits between two binary strings",
391            returns: "Integer count of different bits",
392            examples: vec![
393                "SELECT HAMMING_DISTANCE('1101', '1011')",
394                "SELECT HAMMING_DISTANCE(TO_BINARY(13), TO_BINARY(11))",
395            ],
396        }
397    }
398
399    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
400        self.validate_args(args)?;
401
402        let a = args[0].to_string();
403        let b = args[1].to_string();
404
405        let max_len = a.len().max(b.len());
406        let a_padded = format!("{:0>width$}", a, width = max_len);
407        let b_padded = format!("{:0>width$}", b, width = max_len);
408
409        let distance = a_padded
410            .chars()
411            .zip(b_padded.chars())
412            .filter(|(c1, c2)| c1 != c2)
413            .count() as i64;
414
415        Ok(DataValue::Integer(distance))
416    }
417}