sqlx_core/sqlite/connection/
explain.rs

1use crate::error::Error;
2use crate::from_row::FromRow;
3use crate::sqlite::connection::{execute, ConnectionState};
4use crate::sqlite::type_info::DataType;
5use crate::sqlite::SqliteTypeInfo;
6use crate::HashMap;
7use std::str::from_utf8;
8
// Column affinity codes. SQLite encodes each affinity as a single ASCII
// byte, so they are written here as byte literals.
const SQLITE_AFF_NONE: u8 = b'@'; // 0x40
const SQLITE_AFF_BLOB: u8 = b'A'; // 0x41
const SQLITE_AFF_TEXT: u8 = b'B'; // 0x42
const SQLITE_AFF_NUMERIC: u8 = b'C'; // 0x43
const SQLITE_AFF_INTEGER: u8 = b'D'; // 0x44
const SQLITE_AFF_REAL: u8 = b'E'; // 0x45
16
// Opcode names, exactly as they appear in the `opcode` column of `EXPLAIN`
// output. They are grouped below by the role they play in the analysis.

// -- unconditional control flow
const OP_INIT: &str = "Init";
const OP_GOTO: &str = "Goto";
const OP_HALT: &str = "Halt";

// -- conditional jumps / branching
const OP_DECR_JUMP_ZERO: &str = "DecrJumpZero";
const OP_ELSE_EQ: &str = "ElseEq";
const OP_EQ: &str = "Eq";
const OP_FILTER: &str = "Filter";
const OP_FK_IF_ZERO: &str = "FkIfZero";
const OP_FOUND: &str = "Found";
const OP_GE: &str = "Ge";
const OP_GO_SUB: &str = "Gosub";
const OP_GT: &str = "Gt";
const OP_IDX_GE: &str = "IdxGE";
const OP_IDX_GT: &str = "IdxGT";
const OP_IDX_LE: &str = "IdxLE";
const OP_IDX_LT: &str = "IdxLT";
const OP_IF: &str = "If";
const OP_IF_NO_HOPE: &str = "IfNoHope";
const OP_IF_NOT: &str = "IfNot";
const OP_IF_NOT_OPEN: &str = "IfNotOpen";
const OP_IF_NOT_ZERO: &str = "IfNotZero";
const OP_IF_NULL_ROW: &str = "IfNullRow";
const OP_IF_POS: &str = "IfPos";
const OP_IF_SMALLER: &str = "IfSmaller";
const OP_INCR_VACUUM: &str = "IncrVacuum";
const OP_IS_NULL: &str = "IsNull";
const OP_IS_NULL_OR_TYPE: &str = "IsNullOrType";
const OP_LAST: &str = "Last";
const OP_LE: &str = "Le";
const OP_LT: &str = "Lt";
const OP_MUST_BE_INT: &str = "MustBeInt";
const OP_NE: &str = "Ne";
const OP_NEXT: &str = "Next";
const OP_NO_CONFLICT: &str = "NoConflict";
const OP_NOT_EXISTS: &str = "NotExists";
const OP_NOT_NULL: &str = "NotNull";
const OP_ONCE: &str = "Once";
const OP_PREV: &str = "Prev";
const OP_PROGRAM: &str = "Program";
const OP_REWIND: &str = "Rewind";
const OP_ROW_SET_READ: &str = "RowSetRead";
const OP_ROW_SET_TEST: &str = "RowSetTest";
const OP_SEEK_GE: &str = "SeekGE";
const OP_SEEK_GT: &str = "SeekGT";
const OP_SEEK_LE: &str = "SeekLE";
const OP_SEEK_LT: &str = "SeekLT";
const OP_SEEK_ROW_ID: &str = "SeekRowId";
const OP_SEEK_SCAN: &str = "SeekScan";
const OP_SEQUENCE_TEST: &str = "SequenceTest";
const OP_SORTER_NEXT: &str = "SorterNext";
const OP_SORTER_SORT: &str = "SorterSort";
const OP_V_FILTER: &str = "VFilter";
const OP_V_NEXT: &str = "VNext";
const OP_JUMP: &str = "Jump";

// -- coroutines and subroutine return
const OP_INIT_COROUTINE: &str = "InitCoroutine";
const OP_END_COROUTINE: &str = "EndCoroutine";
const OP_YIELD: &str = "Yield";
const OP_RETURN: &str = "Return";

// -- cursors, rows and records
const OP_OPEN_PSEUDO: &str = "OpenPseudo";
const OP_OPEN_READ: &str = "OpenRead";
const OP_OPEN_WRITE: &str = "OpenWrite";
const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral";
const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex";
const OP_NULL_ROW: &str = "NullRow";
const OP_COLUMN: &str = "Column";
const OP_ROW_DATA: &str = "RowData";
const OP_MAKE_RECORD: &str = "MakeRecord";
const OP_INSERT: &str = "Insert";
const OP_IDX_INSERT: &str = "IdxInsert";
const OP_ROWID: &str = "Rowid";
const OP_NEWROWID: &str = "NewRowid";
const OP_COUNT: &str = "Count";

// -- function calls and aggregates
const OP_AGG_FINAL: &str = "AggFinal";
const OP_AGG_STEP: &str = "AggStep";
const OP_FUNCTION: &str = "Function";

// -- register-to-register moves and casts
const OP_MOVE: &str = "Move";
const OP_COPY: &str = "Copy";
const OP_SCOPY: &str = "SCopy";
const OP_INT_COPY: &str = "IntCopy";
const OP_CAST: &str = "Cast";

// -- constants and bound parameters loaded into registers
const OP_NULL: &str = "Null";
const OP_STRING8: &str = "String8";
const OP_INT64: &str = "Int64";
const OP_INTEGER: &str = "Integer";
const OP_REAL: &str = "Real";
const OP_BLOB: &str = "Blob";
const OP_VARIABLE: &str = "Variable";

// -- unary and binary operators
const OP_NOT: &str = "Not";
const OP_OR: &str = "Or";
const OP_AND: &str = "And";
const OP_BIT_AND: &str = "BitAnd";
const OP_BIT_OR: &str = "BitOr";
const OP_SHIFT_LEFT: &str = "ShiftLeft";
const OP_SHIFT_RIGHT: &str = "ShiftRight";
const OP_ADD: &str = "Add";
const OP_SUBTRACT: &str = "Subtract";
const OP_MULTIPLY: &str = "Multiply";
const OP_DIVIDE: &str = "Divide";
const OP_REMAINDER: &str = "Remainder";
const OP_CONCAT: &str = "Concat";

// -- result output
const OP_RESULT_ROW: &str = "ResultRow";
119
120#[derive(Debug, Copy, Clone, Eq, PartialEq)]
121struct ColumnType {
122    pub datatype: DataType,
123    pub nullable: Option<bool>,
124}
125
126impl Default for ColumnType {
127    fn default() -> Self {
128        Self {
129            datatype: DataType::Null,
130            nullable: None,
131        }
132    }
133}
134
135impl ColumnType {
136    fn null() -> Self {
137        Self {
138            datatype: DataType::Null,
139            nullable: Some(true),
140        }
141    }
142}
143
144#[derive(Debug, Clone, Eq, PartialEq)]
145enum RegDataType {
146    Single(ColumnType),
147    Record(Vec<ColumnType>),
148    Int(i64),
149}
150
151impl RegDataType {
152    fn map_to_datatype(&self) -> DataType {
153        match self {
154            RegDataType::Single(d) => d.datatype,
155            RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
156            RegDataType::Int(_) => DataType::Int,
157        }
158    }
159    fn map_to_nullable(&self) -> Option<bool> {
160        match self {
161            RegDataType::Single(d) => d.nullable,
162            RegDataType::Record(_) => None, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
163            RegDataType::Int(_) => Some(false),
164        }
165    }
166    fn map_to_columntype(&self) -> ColumnType {
167        match self {
168            RegDataType::Single(d) => *d,
169            RegDataType::Record(_) => ColumnType {
170                datatype: DataType::Null,
171                nullable: None,
172            }, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
173            RegDataType::Int(_) => ColumnType {
174                datatype: DataType::Int,
175                nullable: Some(false),
176            },
177        }
178    }
179}
180
181#[derive(Debug, Clone, Eq, PartialEq)]
182enum CursorDataType {
183    Normal(HashMap<i64, ColumnType>),
184    Pseudo(i64),
185}
186
187impl CursorDataType {
188    fn from_sparse_record(record: &HashMap<i64, ColumnType>) -> Self {
189        Self::Normal(
190            record
191                .iter()
192                .map(|(&colnum, &datatype)| (colnum, datatype))
193                .collect(),
194        )
195    }
196
197    fn from_dense_record(record: &Vec<ColumnType>) -> Self {
198        Self::Normal((0..).zip(record.iter().copied()).collect())
199    }
200
201    fn map_to_dense_record(&self, registers: &HashMap<i64, RegDataType>) -> Vec<ColumnType> {
202        match self {
203            Self::Normal(record) => {
204                let mut rowdata = vec![ColumnType::default(); record.len()];
205                for (idx, col) in record.iter() {
206                    rowdata[*idx as usize] = col.clone();
207                }
208                rowdata
209            }
210            Self::Pseudo(i) => match registers.get(i) {
211                Some(RegDataType::Record(r)) => r.clone(),
212                _ => Vec::new(),
213            },
214        }
215    }
216
217    fn map_to_sparse_record(
218        &self,
219        registers: &HashMap<i64, RegDataType>,
220    ) -> HashMap<i64, ColumnType> {
221        match self {
222            Self::Normal(c) => c.clone(),
223            Self::Pseudo(i) => match registers.get(i) {
224                Some(RegDataType::Record(r)) => (0..).zip(r.iter().copied()).collect(),
225                _ => HashMap::new(),
226            },
227        }
228    }
229}
230
231#[allow(clippy::wildcard_in_or_patterns)]
232fn affinity_to_type(affinity: u8) -> DataType {
233    match affinity {
234        SQLITE_AFF_BLOB => DataType::Blob,
235        SQLITE_AFF_INTEGER => DataType::Int64,
236        SQLITE_AFF_NUMERIC => DataType::Numeric,
237        SQLITE_AFF_REAL => DataType::Float,
238        SQLITE_AFF_TEXT => DataType::Text,
239
240        SQLITE_AFF_NONE | _ => DataType::Null,
241    }
242}
243
244#[allow(clippy::wildcard_in_or_patterns)]
245fn opcode_to_type(op: &str) -> DataType {
246    match op {
247        OP_REAL => DataType::Float,
248        OP_BLOB => DataType::Blob,
249        OP_AND | OP_OR => DataType::Bool,
250        OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Int64,
251        OP_STRING8 => DataType::Text,
252        OP_COLUMN | _ => DataType::Null,
253    }
254}
255
256fn root_block_columns(
257    conn: &mut ConnectionState,
258) -> Result<HashMap<i64, HashMap<i64, ColumnType>>, Error> {
259    let table_block_columns: Vec<(i64, i64, String, bool)> = execute::iter(
260        conn,
261        "SELECT s.rootpage, col.cid as colnum, col.type, col.\"notnull\"
262         FROM (select * from sqlite_temp_schema UNION select * from sqlite_schema) s
263         JOIN pragma_table_info(s.name) AS col
264         WHERE s.type = 'table'",
265        None,
266        false,
267    )?
268    .filter_map(|res| res.map(|either| either.right()).transpose())
269    .map(|row| FromRow::from_row(&row?))
270    .collect::<Result<Vec<_>, Error>>()?;
271
272    let index_block_columns: Vec<(i64, i64, String, bool)> = execute::iter(
273        conn,
274        "SELECT s.rootpage, idx.seqno as colnum, col.type, col.\"notnull\"
275         FROM (select * from sqlite_temp_schema UNION select * from sqlite_schema) s
276         JOIN pragma_index_info(s.name) AS idx
277         LEFT JOIN pragma_table_info(s.tbl_name) as col
278           ON col.cid = idx.cid
279           WHERE s.type = 'index'",
280        None,
281        false,
282    )?
283    .filter_map(|res| res.map(|either| either.right()).transpose())
284    .map(|row| FromRow::from_row(&row?))
285    .collect::<Result<Vec<_>, Error>>()?;
286
287    let mut row_info: HashMap<i64, HashMap<i64, ColumnType>> = HashMap::new();
288    for (block, colnum, datatype, notnull) in table_block_columns {
289        let row_info = row_info.entry(block).or_default();
290        row_info.insert(
291            colnum,
292            ColumnType {
293                datatype: datatype.parse().unwrap_or(DataType::Null),
294                nullable: Some(!notnull),
295            },
296        );
297    }
298    for (block, colnum, datatype, notnull) in index_block_columns {
299        let row_info = row_info.entry(block).or_default();
300        row_info.insert(
301            colnum,
302            ColumnType {
303                datatype: datatype.parse().unwrap_or(DataType::Null),
304                nullable: Some(!notnull),
305            },
306        );
307    }
308
309    return Ok(row_info);
310}
311
312#[derive(Debug, Clone, PartialEq)]
313struct QueryState {
314    pub visited: Vec<bool>,
315    pub history: Vec<usize>,
316    // Registers
317    pub r: HashMap<i64, RegDataType>,
318    // Rows that pointers point to
319    pub p: HashMap<i64, CursorDataType>,
320    // Next instruction to execute
321    pub program_i: usize,
322    // Results published by the execution
323    pub result: Option<Vec<(Option<SqliteTypeInfo>, Option<bool>)>>,
324}
325
326// Opcode Reference: https://sqlite.org/opcode.html
327pub(super) fn explain(
328    conn: &mut ConnectionState,
329    query: &str,
330) -> Result<(Vec<SqliteTypeInfo>, Vec<Option<bool>>), Error> {
331    let root_block_cols = root_block_columns(conn)?;
332    let program: Vec<(i64, String, i64, i64, i64, Vec<u8>)> =
333        execute::iter(conn, &format!("EXPLAIN {}", query), None, false)?
334            .filter_map(|res| res.map(|either| either.right()).transpose())
335            .map(|row| FromRow::from_row(&row?))
336            .collect::<Result<Vec<_>, Error>>()?;
337    let program_size = program.len();
338
339    let mut logger =
340        crate::logger::QueryPlanLogger::new(query, &program, conn.log_settings.clone());
341
342    let mut states = vec![QueryState {
343        visited: vec![false; program_size],
344        history: Vec::new(),
345        r: HashMap::with_capacity(6),
346        p: HashMap::with_capacity(6),
347        program_i: 0,
348        result: None,
349    }];
350
351    let mut result_states = Vec::new();
352
353    while let Some(mut state) = states.pop() {
354        while state.program_i < program_size {
355            if state.visited[state.program_i] {
356                state.program_i += 1;
357                //avoid (infinite) loops by breaking if we ever hit the same instruction twice
358                break;
359            }
360            let (_, ref opcode, p1, p2, p3, ref p4) = program[state.program_i];
361            state.history.push(state.program_i);
362
363            match &**opcode {
364                OP_INIT => {
365                    // start at <p2>
366                    state.visited[state.program_i] = true;
367                    state.program_i = p2 as usize;
368                    continue;
369                }
370
371                OP_GOTO => {
372                    // goto <p2>
373                    state.visited[state.program_i] = true;
374                    state.program_i = p2 as usize;
375                    continue;
376                }
377
378                OP_DECR_JUMP_ZERO | OP_ELSE_EQ | OP_EQ | OP_FILTER | OP_FK_IF_ZERO | OP_FOUND
379                | OP_GE | OP_GO_SUB | OP_GT | OP_IDX_GE | OP_IDX_GT | OP_IDX_LE | OP_IDX_LT
380                | OP_IF | OP_IF_NO_HOPE | OP_IF_NOT | OP_IF_NOT_OPEN | OP_IF_NOT_ZERO
381                | OP_IF_NULL_ROW | OP_IF_POS | OP_IF_SMALLER | OP_INCR_VACUUM | OP_IS_NULL
382                | OP_IS_NULL_OR_TYPE | OP_LE | OP_LAST | OP_LT | OP_MUST_BE_INT | OP_NE
383                | OP_NEXT | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_NOT_NULL | OP_ONCE | OP_PREV
384                | OP_PROGRAM | OP_ROW_SET_READ | OP_ROW_SET_TEST | OP_SEEK_GE | OP_SEEK_GT
385                | OP_SEEK_LE | OP_SEEK_LT | OP_SEEK_ROW_ID | OP_SEEK_SCAN | OP_SEQUENCE_TEST
386                | OP_SORTER_NEXT | OP_SORTER_SORT | OP_V_FILTER | OP_V_NEXT | OP_REWIND => {
387                    // goto <p2> or next instruction (depending on actual values)
388                    state.visited[state.program_i] = true;
389
390                    let mut branch_state = state.clone();
391                    branch_state.program_i = p2 as usize;
392                    states.push(branch_state);
393
394                    state.program_i += 1;
395                    continue;
396                }
397
398                OP_INIT_COROUTINE => {
399                    // goto <p2> or next instruction (depending on actual values)
400                    state.visited[state.program_i] = true;
401                    state.r.insert(p1, RegDataType::Int(p3));
402
403                    if p2 != 0 {
404                        state.program_i = p2 as usize;
405                    } else {
406                        state.program_i += 1;
407                    }
408                    continue;
409                }
410
411                OP_END_COROUTINE => {
412                    // jump to p2 of the yield instruction pointed at by register p1
413                    state.visited[state.program_i] = true;
414                    if let Some(RegDataType::Int(yield_i)) = state.r.get(&p1) {
415                        if let Some((_, yield_op, _, yield_p2, _, _)) =
416                            program.get(*yield_i as usize)
417                        {
418                            if OP_YIELD == yield_op.as_str() {
419                                state.program_i = (*yield_p2) as usize;
420                                state.r.remove(&p1);
421                                continue;
422                            } else {
423                                break;
424                            }
425                        } else {
426                            break;
427                        }
428                    } else {
429                        break;
430                    }
431                }
432
433                OP_RETURN => {
434                    // jump to the instruction after the instruction pointed at by register p1
435                    state.visited[state.program_i] = true;
436                    if let Some(RegDataType::Int(return_i)) = state.r.get(&p1) {
437                        state.program_i = (*return_i + 1) as usize;
438                        state.r.remove(&p1);
439                        continue;
440                    } else {
441                        break;
442                    }
443                }
444
445                OP_YIELD => {
446                    // jump to p2 of the yield instruction pointed at by register p1, store prior instruction in p1
447                    state.visited[state.program_i] = true;
448                    if let Some(RegDataType::Int(yield_i)) = state.r.get_mut(&p1) {
449                        let program_i: usize = state.program_i;
450
451                        //if yielding to a yield operation, go to the NEXT instruction after that instruction
452                        if program
453                            .get(*yield_i as usize)
454                            .map(|(_, yield_op, _, _, _, _)| yield_op.as_str())
455                            == Some(OP_YIELD)
456                        {
457                            state.program_i = (*yield_i + 1) as usize;
458                            *yield_i = program_i as i64;
459                            continue;
460                        } else {
461                            state.program_i = *yield_i as usize;
462                            *yield_i = program_i as i64;
463                            continue;
464                        }
465                    } else {
466                        break;
467                    }
468                }
469
470                OP_JUMP => {
471                    // goto one of <p1>, <p2>, or <p3> based on the result of a prior compare
472                    state.visited[state.program_i] = true;
473
474                    let mut branch_state = state.clone();
475                    branch_state.program_i = p1 as usize;
476                    states.push(branch_state);
477
478                    let mut branch_state = state.clone();
479                    branch_state.program_i = p2 as usize;
480                    states.push(branch_state);
481
482                    let mut branch_state = state.clone();
483                    branch_state.program_i = p3 as usize;
484                    states.push(branch_state);
485                }
486
487                OP_COLUMN => {
488                    //Get the row stored at p1, or NULL; get the column stored at p2, or NULL
489                    if let Some(record) = state.p.get(&p1).map(|c| c.map_to_sparse_record(&state.r))
490                    {
491                        if let Some(col) = record.get(&p2) {
492                            // insert into p3 the datatype of the col
493                            state.r.insert(p3, RegDataType::Single(*col));
494                        } else {
495                            state
496                                .r
497                                .insert(p3, RegDataType::Single(ColumnType::default()));
498                        }
499                    } else {
500                        state
501                            .r
502                            .insert(p3, RegDataType::Single(ColumnType::default()));
503                    }
504                }
505
506                OP_ROW_DATA => {
507                    //Get entire row from cursor p1, store it into register p2
508                    if let Some(record) = state.p.get(&p1) {
509                        let rowdata = record.map_to_dense_record(&state.r);
510                        state.r.insert(p2, RegDataType::Record(rowdata));
511                    } else {
512                        state.r.insert(p2, RegDataType::Record(Vec::new()));
513                    }
514                }
515
516                OP_MAKE_RECORD => {
517                    // p3 = Record([p1 .. p1 + p2])
518                    let mut record = Vec::with_capacity(p2 as usize);
519                    for reg in p1..p1 + p2 {
520                        record.push(
521                            state
522                                .r
523                                .get(&reg)
524                                .map(|d| d.clone().map_to_columntype())
525                                .unwrap_or(ColumnType::default()),
526                        );
527                    }
528                    state.r.insert(p3, RegDataType::Record(record));
529                }
530
531                OP_INSERT | OP_IDX_INSERT => {
532                    if let Some(RegDataType::Record(record)) = state.r.get(&p2) {
533                        if let Some(CursorDataType::Normal(row)) = state.p.get_mut(&p1) {
534                            // Insert the record into wherever pointer p1 is
535                            *row = (0..).zip(record.iter().copied()).collect();
536                        }
537                    }
538                    //Noop if the register p2 isn't a record, or if pointer p1 does not exist
539                }
540
541                OP_OPEN_PSEUDO => {
542                    // Create a cursor p1 aliasing the record from register p2
543                    state.p.insert(p1, CursorDataType::Pseudo(p2));
544                }
545                OP_OPEN_READ | OP_OPEN_WRITE => {
546                    //Create a new pointer which is referenced by p1, take column metadata from db schema if found
547                    if p3 == 0 {
548                        if let Some(columns) = root_block_cols.get(&p2) {
549                            state
550                                .p
551                                .insert(p1, CursorDataType::from_sparse_record(columns));
552                        } else {
553                            state
554                                .p
555                                .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6)));
556                        }
557                    } else {
558                        state
559                            .p
560                            .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6)));
561                    }
562                }
563
564                OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => {
565                    //Create a new pointer which is referenced by p1
566                    state.p.insert(
567                        p1,
568                        CursorDataType::from_dense_record(&vec![ColumnType::null(); p2 as usize]),
569                    );
570                }
571
572                OP_VARIABLE => {
573                    // r[p2] = <value of variable>
574                    state.r.insert(p2, RegDataType::Single(ColumnType::null()));
575                }
576
577                OP_FUNCTION => {
578                    // r[p1] = func( _ )
579                    match from_utf8(p4).map_err(Error::protocol)? {
580                        "last_insert_rowid(0)" => {
581                            // last_insert_rowid() -> INTEGER
582                            state.r.insert(
583                                p3,
584                                RegDataType::Single(ColumnType {
585                                    datatype: DataType::Int64,
586                                    nullable: Some(false),
587                                }),
588                            );
589                        }
590
591                        _ => logger.add_unknown_operation(&program[state.program_i]),
592                    }
593                }
594
595                OP_NULL_ROW => {
596                    // all columns in cursor X are potentially nullable
597                    if let Some(CursorDataType::Normal(ref mut cursor)) = state.p.get_mut(&p1) {
598                        for ref mut col in cursor.values_mut() {
599                            col.nullable = Some(true);
600                        }
601                    }
602                    //else we don't know about the cursor
603                }
604
605                OP_AGG_STEP => {
606                    //assume that AGG_FINAL will be called
607                    let p4 = from_utf8(p4).map_err(Error::protocol)?;
608
609                    if p4.starts_with("count(") {
610                        // count(_) -> INTEGER
611                        state.r.insert(
612                            p3,
613                            RegDataType::Single(ColumnType {
614                                datatype: DataType::Int64,
615                                nullable: Some(false),
616                            }),
617                        );
618                    } else if let Some(v) = state.r.get(&p2).cloned() {
619                        // r[p3] = AGG ( r[p2] )
620                        state.r.insert(p3, v);
621                    }
622                }
623
624                OP_AGG_FINAL => {
625                    let p4 = from_utf8(p4).map_err(Error::protocol)?;
626
627                    if p4.starts_with("count(") {
628                        // count(_) -> INTEGER
629                        state.r.insert(
630                            p1,
631                            RegDataType::Single(ColumnType {
632                                datatype: DataType::Int64,
633                                nullable: Some(false),
634                            }),
635                        );
636                    } else if let Some(v) = state.r.get(&p2).cloned() {
637                        // r[p3] = AGG ( r[p2] )
638                        state.r.insert(p3, v);
639                    }
640                }
641
642                OP_CAST => {
643                    // affinity(r[p1])
644                    if let Some(v) = state.r.get_mut(&p1) {
645                        *v = RegDataType::Single(ColumnType {
646                            datatype: affinity_to_type(p2 as u8),
647                            nullable: v.map_to_nullable(),
648                        });
649                    }
650                }
651
652                OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => {
653                    // r[p2] = r[p1]
654                    if let Some(v) = state.r.get(&p1).cloned() {
655                        state.r.insert(p2, v);
656                    }
657                }
658
659                OP_INTEGER => {
660                    // r[p2] = p1
661                    state.r.insert(p2, RegDataType::Int(p1));
662                }
663
664                OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_ROWID | OP_NEWROWID => {
665                    // r[p2] = <value of constant>
666                    state.r.insert(
667                        p2,
668                        RegDataType::Single(ColumnType {
669                            datatype: opcode_to_type(&opcode),
670                            nullable: Some(false),
671                        }),
672                    );
673                }
674
675                OP_NOT => {
676                    // r[p2] = NOT r[p1]
677                    if let Some(a) = state.r.get(&p1).cloned() {
678                        state.r.insert(p2, a);
679                    }
680                }
681
682                OP_NULL => {
683                    // r[p2..p3] = null
684                    let idx_range = if p2 < p3 { p2..=p3 } else { p2..=p2 };
685
686                    for idx in idx_range {
687                        state.r.insert(idx, RegDataType::Single(ColumnType::null()));
688                    }
689                }
690
691                OP_OR | OP_AND | OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT
692                | OP_ADD | OP_SUBTRACT | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => {
693                    // r[p3] = r[p1] + r[p2]
694                    match (state.r.get(&p1).cloned(), state.r.get(&p2).cloned()) {
695                        (Some(a), Some(b)) => {
696                            state.r.insert(
697                                p3,
698                                RegDataType::Single(ColumnType {
699                                    datatype: if matches!(a.map_to_datatype(), DataType::Null) {
700                                        b.map_to_datatype()
701                                    } else {
702                                        a.map_to_datatype()
703                                    },
704                                    nullable: match (a.map_to_nullable(), b.map_to_nullable()) {
705                                        (Some(a_n), Some(b_n)) => Some(a_n | b_n),
706                                        (Some(a_n), None) => Some(a_n),
707                                        (None, Some(b_n)) => Some(b_n),
708                                        (None, None) => None,
709                                    },
710                                }),
711                            );
712                        }
713
714                        (Some(v), None) => {
715                            state.r.insert(
716                                p3,
717                                RegDataType::Single(ColumnType {
718                                    datatype: v.map_to_datatype(),
719                                    nullable: None,
720                                }),
721                            );
722                        }
723
724                        (None, Some(v)) => {
725                            state.r.insert(
726                                p3,
727                                RegDataType::Single(ColumnType {
728                                    datatype: v.map_to_datatype(),
729                                    nullable: None,
730                                }),
731                            );
732                        }
733
734                        _ => {}
735                    }
736                }
737
738                OP_RESULT_ROW => {
739                    // output = r[p1 .. p1 + p2]
740                    state.visited[state.program_i] = true;
741                    state.result = Some(
742                        (p1..p1 + p2)
743                            .map(|i| {
744                                let coltype = state.r.get(&i);
745
746                                let sqltype =
747                                    coltype.map(|d| d.map_to_datatype()).map(SqliteTypeInfo);
748                                let nullable =
749                                    coltype.map(|d| d.map_to_nullable()).unwrap_or_default();
750
751                                (sqltype, nullable)
752                            })
753                            .collect(),
754                    );
755
756                    if logger.log_enabled() {
757                        let program_history: Vec<&(i64, String, i64, i64, i64, Vec<u8>)> =
758                            state.history.iter().map(|i| &program[*i]).collect();
759                        logger.add_result((program_history, state.result.clone()));
760                    }
761
762                    result_states.push(state.clone());
763                }
764
765                OP_HALT => {
766                    break;
767                }
768
769                _ => {
770                    // ignore unsupported operations
771                    // if we fail to find an r later, we just give up
772                    logger.add_unknown_operation(&program[state.program_i]);
773                }
774            }
775
776            state.visited[state.program_i] = true;
777            state.program_i += 1;
778        }
779    }
780
781    let mut output: Vec<Option<SqliteTypeInfo>> = Vec::new();
782    let mut nullable: Vec<Option<bool>> = Vec::new();
783
784    while let Some(state) = result_states.pop() {
785        // find the datatype info from each ResultRow execution
786        if let Some(result) = state.result {
787            let mut idx = 0;
788            for (this_type, this_nullable) in result {
789                if output.len() == idx {
790                    output.push(this_type);
791                } else if output[idx].is_none()
792                    || matches!(output[idx], Some(SqliteTypeInfo(DataType::Null)))
793                {
794                    output[idx] = this_type;
795                }
796
797                if nullable.len() == idx {
798                    nullable.push(this_nullable);
799                } else if let Some(ref mut null) = nullable[idx] {
800                    //if any ResultRow's column is nullable, the final result is nullable
801                    if let Some(this_null) = this_nullable {
802                        *null |= this_null;
803                    }
804                } else {
805                    nullable[idx] = this_nullable;
806                }
807                idx += 1;
808            }
809        }
810    }
811
812    let output = output
813        .into_iter()
814        .map(|o| o.unwrap_or(SqliteTypeInfo(DataType::Null)))
815        .collect();
816
817    Ok((output, nullable))
818}
819
/// Builds a small in-memory schema (two tables, each with a plain and a unique
/// index) and checks that `root_block_columns` reports the expected datatype and
/// nullability for every column of every table/index root page.
#[test]
fn test_root_block_columns_has_types() {
    use crate::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    // DDL executed in order; every statement must yield at least one step result.
    const SCHEMA: [&str; 6] = [
        r"CREATE TABLE t(a INTEGER PRIMARY KEY, b_null TEXT NULL, b TEXT NOT NULL);",
        r"CREATE INDEX i1 on t (a,b_null);",
        r"CREATE UNIQUE INDEX i2 on t (a,b_null);",
        r"CREATE TABLE t2(a INTEGER NOT NULL, b_null NUMERIC NULL, b NUMERIC NOT NULL);",
        r"CREATE INDEX t2i1 on t2 (a,b_null);",
        r"CREATE UNIQUE INDEX t2i2 on t2 (a,b);",
    ];

    let options = SqliteConnectOptions::from_str("sqlite::memory:").unwrap();
    let mut conn = super::EstablishParams::from_options(&options)
        .unwrap()
        .establish()
        .unwrap();

    for ddl in SCHEMA.iter().copied() {
        let produced_step = execute::iter(&mut conn, ddl, None, false)
            .unwrap()
            .next()
            .is_some();
        assert!(produced_step);
    }

    // Map every schema object name to the root page recorded in sqlite_master.
    let root_pages = execute::iter(
        &mut conn,
        r"select name, rootpage from sqlite_master",
        None,
        false,
    )
    .unwrap()
    .filter_map(|step| step.map(|either| either.right()).transpose())
    .map(|row| FromRow::from_row(row.as_ref().unwrap()))
    .collect::<Result<HashMap<String, i64>, Error>>()
    .unwrap();

    let root_block_cols = root_block_columns(&mut conn).unwrap();

    assert_eq!(6, root_block_cols.len());

    // prove that we have some information for each table & index
    for root_page in root_pages.values() {
        assert!(root_block_cols.contains_key(root_page));
    }

    // Shorthand for an expected column with a known nullability.
    let col = |datatype: DataType, nullable: bool| ColumnType {
        datatype,
        nullable: Some(nullable),
    };

    // prove that each block has the correct information:
    // (schema object name, expected columns in column-index order)
    let expectations: Vec<(&str, Vec<ColumnType>)> = vec![
        (
            "t",
            vec![
                // sqlite primary key columns are nullable unless declared not null
                col(DataType::Int64, true),
                col(DataType::Text, true),
                col(DataType::Text, false),
            ],
        ),
        (
            "i1",
            vec![
                // sqlite primary key columns are nullable unless declared not null
                col(DataType::Int64, true),
                col(DataType::Text, true),
            ],
        ),
        (
            "i2",
            vec![
                // sqlite primary key columns are nullable unless declared not null
                col(DataType::Int64, true),
                col(DataType::Text, true),
            ],
        ),
        (
            "t2",
            vec![
                col(DataType::Int64, false),
                col(DataType::Null, true),
                col(DataType::Null, false),
            ],
        ),
        (
            "t2i1",
            vec![col(DataType::Int64, false), col(DataType::Null, true)],
        ),
        (
            "t2i2",
            vec![col(DataType::Int64, false), col(DataType::Null, false)],
        ),
    ];

    for (name, columns) in expectations {
        let blocknum = root_pages[name];
        for (column_idx, want) in (0i64..).zip(columns) {
            assert_eq!(want, root_block_cols[&blocknum][&column_idx]);
        }
    }
}