Skip to main content

sql_cli/sql/functions/
bitwise.rs

1//! Bitwise operations and binary visualization functions
2//!
3//! This module provides additional bitwise functions not covered by the bigint module.
4//! Basic operations like BITAND, BITOR, BITXOR, TO_BINARY, FROM_BINARY are in bigint.
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9use crate::sql::functions::{
10    ArgCount, FunctionCategory, FunctionRegistry, FunctionSignature, SqlFunction,
11};
12
13/// BITNOT(a) - Bitwise NOT operation
14pub struct BitNotFunction;
15
16impl SqlFunction for BitNotFunction {
17    fn signature(&self) -> FunctionSignature {
18        FunctionSignature {
19            name: "BITNOT",
20            category: FunctionCategory::Bitwise,
21            arg_count: ArgCount::Fixed(1),
22            description: "Performs bitwise NOT operation (ones' complement)",
23            returns: "Integer result of ~a",
24            examples: vec![
25                "SELECT BITNOT(0)      -- Returns -1 (all bits set)",
26                "SELECT BITNOT(255)    -- Returns -256",
27                "SELECT BITNOT(-1)     -- Returns 0",
28            ],
29        }
30    }
31
32    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
33        if args.len() != 1 {
34            return Err(anyhow!("BITNOT requires exactly 1 argument"));
35        }
36
37        match &args[0] {
38            DataValue::Integer(a) => Ok(DataValue::Integer(!a)),
39            DataValue::Null => Ok(DataValue::Null),
40            _ => Err(anyhow!("BITNOT requires an integer argument")),
41        }
42    }
43}
44
45/// IS_POWER_OF_TWO(n) - Check if number is exact power of two
46pub struct IsPowerOfTwoFunction;
47
48impl SqlFunction for IsPowerOfTwoFunction {
49    fn signature(&self) -> FunctionSignature {
50        FunctionSignature {
51            name: "IS_POWER_OF_TWO",
52            category: FunctionCategory::Bitwise,
53            arg_count: ArgCount::Fixed(1),
54            description: "Checks if a number is an exact power of two using n & (n-1) == 0",
55            returns: "Boolean true if power of two, false otherwise",
56            examples: vec![
57                "SELECT IS_POWER_OF_TWO(16)  -- Returns true (2^4)",
58                "SELECT IS_POWER_OF_TWO(15)  -- Returns false",
59                "SELECT IS_POWER_OF_TWO(1)   -- Returns true (2^0)",
60                "SELECT IS_POWER_OF_TWO(0)   -- Returns false",
61            ],
62        }
63    }
64
65    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
66        if args.len() != 1 {
67            return Err(anyhow!("IS_POWER_OF_TWO requires exactly 1 argument"));
68        }
69
70        match &args[0] {
71            DataValue::Integer(n) => {
72                // A number is a power of two if:
73                // 1. It's positive (greater than 0)
74                // 2. n & (n-1) == 0
75                let is_power = *n > 0 && (n & (n - 1)) == 0;
76                Ok(DataValue::Boolean(is_power))
77            }
78            DataValue::Null => Ok(DataValue::Null),
79            _ => Err(anyhow!("IS_POWER_OF_TWO requires an integer argument")),
80        }
81    }
82}
83
84/// COUNT_BITS(n) - Count number of set bits (popcount)
85pub struct CountBitsFunction;
86
87impl SqlFunction for CountBitsFunction {
88    fn signature(&self) -> FunctionSignature {
89        FunctionSignature {
90            name: "COUNT_BITS",
91            category: FunctionCategory::Bitwise,
92            arg_count: ArgCount::Fixed(1),
93            description: "Counts the number of set bits (1s) in the binary representation",
94            returns: "Integer count of set bits",
95            examples: vec![
96                "SELECT COUNT_BITS(7)    -- Returns 3 (111 has three 1s)",
97                "SELECT COUNT_BITS(255)  -- Returns 8 (11111111 has eight 1s)",
98                "SELECT COUNT_BITS(16)   -- Returns 1 (10000 has one 1)",
99            ],
100        }
101    }
102
103    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
104        if args.len() != 1 {
105            return Err(anyhow!("COUNT_BITS requires exactly 1 argument"));
106        }
107
108        match &args[0] {
109            DataValue::Integer(n) => {
110                // Use built-in popcount
111                let count = (*n as u64).count_ones() as i64;
112                Ok(DataValue::Integer(count))
113            }
114            DataValue::Null => Ok(DataValue::Null),
115            _ => Err(anyhow!("COUNT_BITS requires an integer argument")),
116        }
117    }
118}
119
120/// BINARY_FORMAT(n, separator, group_size) - Format binary with separator
121pub struct BinaryFormatFunction;
122
123impl SqlFunction for BinaryFormatFunction {
124    fn signature(&self) -> FunctionSignature {
125        FunctionSignature {
126            name: "BINARY_FORMAT",
127            category: FunctionCategory::Bitwise,
128            arg_count: ArgCount::Range(1, 3),
129            description: "Formats binary string with separators for readability",
130            returns: "Formatted binary string",
131            examples: vec![
132                "SELECT BINARY_FORMAT(255)           -- Returns '11111111'",
133                "SELECT BINARY_FORMAT(255, '_')      -- Returns '1111_1111' (groups of 4)",
134                "SELECT BINARY_FORMAT(65535, '_', 8) -- Returns '11111111_11111111' (groups of 8)",
135            ],
136        }
137    }
138
139    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
140        if args.is_empty() || args.len() > 3 {
141            return Err(anyhow!("BINARY_FORMAT requires 1-3 arguments"));
142        }
143
144        let value = match &args[0] {
145            DataValue::Integer(n) => *n,
146            DataValue::Null => return Ok(DataValue::Null),
147            _ => {
148                return Err(anyhow!(
149                    "BINARY_FORMAT requires an integer as first argument"
150                ))
151            }
152        };
153
154        let separator = if args.len() >= 2 {
155            match &args[1] {
156                DataValue::String(s) => s.clone(),
157                DataValue::Null => String::new(),
158                _ => return Err(anyhow!("Separator must be a string")),
159            }
160        } else {
161            String::new()
162        };
163
164        let group_size = if args.len() == 3 {
165            match &args[2] {
166                DataValue::Integer(g) => {
167                    if *g <= 0 {
168                        return Err(anyhow!("Group size must be positive"));
169                    }
170                    *g as usize
171                }
172                DataValue::Null => 4, // Default to groups of 4
173                _ => return Err(anyhow!("Group size must be an integer")),
174            }
175        } else {
176            4 // Default to groups of 4
177        };
178
179        // Convert to binary string
180        let binary = if value >= 0 {
181            format!("{:b}", value)
182        } else {
183            format!("{:b}", value as u64)
184        };
185
186        // Add separators if requested
187        let result = if !separator.is_empty() && group_size > 0 {
188            let mut formatted = String::new();
189            let mut chars: Vec<char> = binary.chars().collect();
190
191            // Process from right to left for consistent grouping
192            while !chars.is_empty() {
193                let group_start = chars.len().saturating_sub(group_size);
194                let group: String = chars.drain(group_start..).collect();
195
196                if !formatted.is_empty() {
197                    formatted = format!("{}{}{}", group, separator, formatted);
198                } else {
199                    formatted = group;
200                }
201            }
202            formatted
203        } else {
204            binary
205        };
206
207        Ok(DataValue::String(result))
208    }
209}
210
211/// NEXT_POWER_OF_TWO(n) - Returns the next power of two greater than or equal to n
212pub struct NextPowerOfTwoFunction;
213
214impl SqlFunction for NextPowerOfTwoFunction {
215    fn signature(&self) -> FunctionSignature {
216        FunctionSignature {
217            name: "NEXT_POWER_OF_TWO",
218            category: FunctionCategory::Bitwise,
219            arg_count: ArgCount::Fixed(1),
220            description: "Returns the next power of two greater than or equal to n",
221            returns: "Integer that is the next power of two",
222            examples: vec![
223                "SELECT NEXT_POWER_OF_TWO(5)   -- Returns 8",
224                "SELECT NEXT_POWER_OF_TWO(16)  -- Returns 16 (already power of 2)",
225                "SELECT NEXT_POWER_OF_TWO(17)  -- Returns 32",
226            ],
227        }
228    }
229
230    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
231        if args.len() != 1 {
232            return Err(anyhow!("NEXT_POWER_OF_TWO requires exactly 1 argument"));
233        }
234
235        match &args[0] {
236            DataValue::Integer(n) => {
237                if *n <= 0 {
238                    return Ok(DataValue::Integer(1));
239                }
240
241                // Find the next power of two
242                let mut power = 1i64;
243                while power < *n && power < i64::MAX / 2 {
244                    power <<= 1;
245                }
246
247                Ok(DataValue::Integer(power))
248            }
249            DataValue::Null => Ok(DataValue::Null),
250            _ => Err(anyhow!("NEXT_POWER_OF_TWO requires an integer argument")),
251        }
252    }
253}
254
255/// HIGHEST_BIT(n) - Returns the position of the highest set bit
256pub struct HighestBitFunction;
257
258impl SqlFunction for HighestBitFunction {
259    fn signature(&self) -> FunctionSignature {
260        FunctionSignature {
261            name: "HIGHEST_BIT",
262            category: FunctionCategory::Bitwise,
263            arg_count: ArgCount::Fixed(1),
264            description: "Returns the position of the highest set bit (0-indexed)",
265            returns: "Integer position of highest bit, or -1 if no bits set",
266            examples: vec![
267                "SELECT HIGHEST_BIT(8)    -- Returns 3 (bit 3 is set in 1000)",
268                "SELECT HIGHEST_BIT(255)  -- Returns 7 (bit 7 is highest in 11111111)",
269                "SELECT HIGHEST_BIT(0)    -- Returns -1 (no bits set)",
270            ],
271        }
272    }
273
274    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
275        if args.len() != 1 {
276            return Err(anyhow!("HIGHEST_BIT requires exactly 1 argument"));
277        }
278
279        match &args[0] {
280            DataValue::Integer(n) => {
281                if *n <= 0 {
282                    return Ok(DataValue::Integer(-1));
283                }
284
285                // Find the position of the highest bit
286                let position = 63 - (*n as u64).leading_zeros() as i64;
287                Ok(DataValue::Integer(position))
288            }
289            DataValue::Null => Ok(DataValue::Null),
290            _ => Err(anyhow!("HIGHEST_BIT requires an integer argument")),
291        }
292    }
293}
294
295/// LOWEST_BIT(n) - Returns the position of the lowest set bit
296pub struct LowestBitFunction;
297
298impl SqlFunction for LowestBitFunction {
299    fn signature(&self) -> FunctionSignature {
300        FunctionSignature {
301            name: "LOWEST_BIT",
302            category: FunctionCategory::Bitwise,
303            arg_count: ArgCount::Fixed(1),
304            description: "Returns the position of the lowest set bit (0-indexed)",
305            returns: "Integer position of lowest bit, or -1 if no bits set",
306            examples: vec![
307                "SELECT LOWEST_BIT(8)    -- Returns 3 (bit 3 is the only bit in 1000)",
308                "SELECT LOWEST_BIT(12)   -- Returns 2 (bit 2 is lowest in 1100)",
309                "SELECT LOWEST_BIT(0)    -- Returns -1 (no bits set)",
310            ],
311        }
312    }
313
314    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
315        if args.len() != 1 {
316            return Err(anyhow!("LOWEST_BIT requires exactly 1 argument"));
317        }
318
319        match &args[0] {
320            DataValue::Integer(n) => {
321                if *n == 0 {
322                    return Ok(DataValue::Integer(-1));
323                }
324
325                // Find the position of the lowest bit using trailing zeros
326                let position = (*n as u64).trailing_zeros() as i64;
327                Ok(DataValue::Integer(position))
328            }
329            DataValue::Null => Ok(DataValue::Null),
330            _ => Err(anyhow!("LOWEST_BIT requires an integer argument")),
331        }
332    }
333}
334
335/// POPCOUNT(n) - Population count: number of set bits in an integer.
336///
337/// This is the canonical CPU/intrinsic name for what `COUNT_BITS` also does,
338/// and is polymorphic with `BIT_COUNT` (which also accepts binary strings).
339pub struct PopcountFunction;
340
341impl SqlFunction for PopcountFunction {
342    fn signature(&self) -> FunctionSignature {
343        FunctionSignature {
344            name: "POPCOUNT",
345            category: FunctionCategory::Bitwise,
346            arg_count: ArgCount::Fixed(1),
347            description: "Population count: number of set bits (1s) in an integer",
348            returns: "Integer count of set bits",
349            examples: vec![
350                "SELECT POPCOUNT(7)    -- Returns 3 (0b111)",
351                "SELECT POPCOUNT(255)  -- Returns 8",
352                "SELECT POPCOUNT(0)    -- Returns 0",
353            ],
354        }
355    }
356
357    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
358        if args.len() != 1 {
359            return Err(anyhow!("POPCOUNT requires exactly 1 argument"));
360        }
361
362        match &args[0] {
363            DataValue::Integer(n) => Ok(DataValue::Integer((*n as u64).count_ones() as i64)),
364            DataValue::Null => Ok(DataValue::Null),
365            _ => Err(anyhow!("POPCOUNT requires an integer argument")),
366        }
367    }
368}
369
370/// Shared helper: parse the optional `width` argument (8/16/32/64) for
371/// LEADING_ZEROS / TRAILING_ONES style functions. Returns the width in bits.
372fn parse_bit_width(arg: &DataValue, func_name: &str) -> Result<u32> {
373    match arg {
374        DataValue::Integer(w) => match *w {
375            8 | 16 | 32 | 64 => Ok(*w as u32),
376            other => Err(anyhow!(
377                "{func_name}: width must be 8, 16, 32, or 64 (got {other})"
378            )),
379        },
380        _ => Err(anyhow!("{func_name}: width must be an integer")),
381    }
382}
383
384/// LEADING_ZEROS(n [, width]) - Number of leading zero bits in the binary
385/// representation of `n`, viewed as an unsigned integer of the given width.
386///
387/// `width` defaults to 64 so the result passes through `u64::leading_zeros`
388/// directly. Use an explicit width (8/16/32/64) to get the intuitive answer
389/// within a smaller container (e.g. `LEADING_ZEROS(8, 8)` -> 4).
390pub struct LeadingZerosFunction;
391
392impl SqlFunction for LeadingZerosFunction {
393    fn signature(&self) -> FunctionSignature {
394        FunctionSignature {
395            name: "LEADING_ZEROS",
396            category: FunctionCategory::Bitwise,
397            arg_count: ArgCount::Range(1, 2),
398            description: "Number of leading zero bits (optionally within a given bit width: 8/16/32/64, default 64)",
399            returns: "Integer count of leading zeros",
400            examples: vec![
401                "SELECT LEADING_ZEROS(1)          -- Returns 63 (64-bit view)",
402                "SELECT LEADING_ZEROS(8, 8)       -- Returns 4 (00001000)",
403                "SELECT LEADING_ZEROS(1, 16)      -- Returns 15",
404                "SELECT LEADING_ZEROS(0, 32)      -- Returns 32",
405            ],
406        }
407    }
408
409    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
410        if args.is_empty() || args.len() > 2 {
411            return Err(anyhow!("LEADING_ZEROS requires 1 or 2 arguments"));
412        }
413
414        let n = match &args[0] {
415            DataValue::Integer(n) => *n,
416            DataValue::Null => return Ok(DataValue::Null),
417            _ => return Err(anyhow!("LEADING_ZEROS requires an integer argument")),
418        };
419
420        let width = if args.len() == 2 {
421            parse_bit_width(&args[1], "LEADING_ZEROS")?
422        } else {
423            64
424        };
425
426        // Mask down to `width` bits so the count is relative to that container.
427        // For width=64 the mask is !0 which keeps every bit; smaller widths
428        // truncate so e.g. LEADING_ZEROS(-1, 8) == 0 (all eight bits are set).
429        let mask: u64 = if width == 64 {
430            !0u64
431        } else {
432            (1u64 << width) - 1
433        };
434        let masked = (n as u64) & mask;
435
436        let count = if masked == 0 {
437            width as i64
438        } else {
439            // `u64::leading_zeros` counts against the 64-bit container, so we
440            // subtract the overhead above `width` to get the result within the
441            // requested window.
442            (masked.leading_zeros() as i64) - (64 - width as i64)
443        };
444
445        Ok(DataValue::Integer(count))
446    }
447}
448
449/// TRAILING_ZEROS(n) - Number of trailing zero bits in `n`.
450///
451/// Width-invariant (trailing zeros from the low end are the same regardless of
452/// the container width), so no width argument is needed. For `n == 0` this
453/// returns -1 to match the convention used by `LOWEST_BIT` — "no bits set".
454pub struct TrailingZerosFunction;
455
456impl SqlFunction for TrailingZerosFunction {
457    fn signature(&self) -> FunctionSignature {
458        FunctionSignature {
459            name: "TRAILING_ZEROS",
460            category: FunctionCategory::Bitwise,
461            arg_count: ArgCount::Fixed(1),
462            description: "Number of trailing zero bits; returns -1 if n is zero",
463            returns: "Integer count of trailing zeros, or -1",
464            examples: vec![
465                "SELECT TRAILING_ZEROS(8)   -- Returns 3 (0b1000)",
466                "SELECT TRAILING_ZEROS(12)  -- Returns 2 (0b1100)",
467                "SELECT TRAILING_ZEROS(1)   -- Returns 0",
468                "SELECT TRAILING_ZEROS(0)   -- Returns -1",
469            ],
470        }
471    }
472
473    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
474        if args.len() != 1 {
475            return Err(anyhow!("TRAILING_ZEROS requires exactly 1 argument"));
476        }
477
478        match &args[0] {
479            DataValue::Integer(n) => {
480                if *n == 0 {
481                    Ok(DataValue::Integer(-1))
482                } else {
483                    Ok(DataValue::Integer((*n as u64).trailing_zeros() as i64))
484                }
485            }
486            DataValue::Null => Ok(DataValue::Null),
487            _ => Err(anyhow!("TRAILING_ZEROS requires an integer argument")),
488        }
489    }
490}
491
492/// Register all bitwise functions
493pub fn register_bitwise_functions(registry: &mut FunctionRegistry) {
494    registry.register(Box::new(BitNotFunction));
495    registry.register(Box::new(IsPowerOfTwoFunction));
496    registry.register(Box::new(CountBitsFunction));
497    registry.register(Box::new(PopcountFunction));
498    registry.register(Box::new(BinaryFormatFunction));
499    registry.register(Box::new(NextPowerOfTwoFunction));
500    registry.register(Box::new(HighestBitFunction));
501    registry.register(Box::new(LowestBitFunction));
502    registry.register(Box::new(LeadingZerosFunction));
503    registry.register(Box::new(TrailingZerosFunction));
504}