1use crate::{Error, FieldType, Record, RecordSet, Result, Schema, Value};
4use std::collections::HashMap;
5use std::io::{Seek, SeekFrom, Write};
6
7#[derive(Debug)]
9pub struct DbcWriter<W: Write + Seek> {
10 writer: W,
12 schema: Option<Schema>,
14}
15
16impl<W: Write + Seek> DbcWriter<W> {
17 pub fn new(writer: W) -> Self {
19 Self {
20 writer,
21 schema: None,
22 }
23 }
24
25 pub fn with_schema(mut self, schema: Schema) -> Self {
27 self.schema = Some(schema);
28 self
29 }
30
31 pub fn write_records(&mut self, record_set: &RecordSet) -> Result<()> {
33 let schema = if let Some(schema) = self.schema.as_ref() {
35 schema.clone()
36 } else if let Some(schema) = record_set.schema() {
37 schema.clone()
38 } else {
39 return Err(Error::InvalidRecord(
40 "No schema provided for writing records".to_string(),
41 ));
42 };
43
44 let (string_block, string_offsets) = self.build_string_block(record_set)?;
46
47 let record_count = record_set.len() as u32;
49 let field_count = schema.fields.len() as u32;
50 let record_size = schema.record_size() as u32;
51 let string_block_size = string_block.len() as u32;
52
53 self.writer.seek(SeekFrom::Start(0))?;
55 self.writer.write_all(&crate::header::DBC_MAGIC)?;
56 self.writer.write_all(&record_count.to_le_bytes())?;
57 self.writer.write_all(&field_count.to_le_bytes())?;
58 self.writer.write_all(&record_size.to_le_bytes())?;
59 self.writer.write_all(&string_block_size.to_le_bytes())?;
60
61 for record in record_set.records() {
63 self.write_record(record, &schema, record_set, &string_offsets)?;
64 }
65
66 self.writer.write_all(&string_block)?;
68
69 Ok(())
70 }
71
72 fn build_string_block(
74 &self,
75 record_set: &RecordSet,
76 ) -> Result<(Vec<u8>, HashMap<String, u32>)> {
77 let mut string_block = Vec::new();
78 let mut string_offsets = HashMap::new();
79
80 string_block.push(0);
82 string_offsets.insert(String::new(), 0);
83
84 for record in record_set.records() {
86 for value in record.values() {
87 if let Value::StringRef(string_ref) = value {
88 let string = record_set.get_string(*string_ref)?;
89
90 if !string_offsets.contains_key(string) {
91 let offset = string_block.len() as u32;
92 string_offsets.insert(string.to_string(), offset);
93
94 string_block.extend_from_slice(string.as_bytes());
96 string_block.push(0); }
98 }
99 }
100 }
101
102 Ok((string_block, string_offsets))
103 }
104
105 fn write_record(
107 &mut self,
108 record: &Record,
109 schema: &Schema,
110 record_set: &RecordSet,
111 string_offsets: &HashMap<String, u32>,
112 ) -> Result<()> {
113 for (i, field) in schema.fields.iter().enumerate() {
114 if let Some(value) = record.get_value(i) {
115 self.write_value(value, field.field_type, record_set, string_offsets)?;
116 } else {
117 match field.field_type {
119 FieldType::Int32 => self.writer.write_all(&0i32.to_le_bytes())?,
120 FieldType::UInt32 => self.writer.write_all(&0u32.to_le_bytes())?,
121 FieldType::Float32 => self.writer.write_all(&0.0f32.to_le_bytes())?,
122 FieldType::String => self.writer.write_all(&0u32.to_le_bytes())?,
123 FieldType::Bool => self.writer.write_all(&0u32.to_le_bytes())?,
124 FieldType::UInt8 => self.writer.write_all(&[0u8])?,
125 FieldType::Int8 => self.writer.write_all(&[0u8])?,
126 FieldType::UInt16 => self.writer.write_all(&0u16.to_le_bytes())?,
127 FieldType::Int16 => self.writer.write_all(&0i16.to_le_bytes())?,
128 }
129 }
130 }
131
132 Ok(())
133 }
134
135 fn write_value(
137 &mut self,
138 value: &Value,
139 field_type: FieldType,
140 record_set: &RecordSet,
141 string_offsets: &HashMap<String, u32>,
142 ) -> Result<()> {
143 match (value, field_type) {
144 (Value::Int32(v), FieldType::Int32) => self.writer.write_all(&v.to_le_bytes())?,
145 (Value::UInt32(v), FieldType::UInt32) => self.writer.write_all(&v.to_le_bytes())?,
146 (Value::Float32(v), FieldType::Float32) => self.writer.write_all(&v.to_le_bytes())?,
147 (Value::StringRef(v), FieldType::String) => {
148 let string = record_set.get_string(*v)?;
149 let offset = string_offsets.get(string).unwrap_or(&0);
150 self.writer.write_all(&offset.to_le_bytes())?;
151 }
152 (Value::Bool(v), FieldType::Bool) => self
153 .writer
154 .write_all(&(if *v { 1u32 } else { 0u32 }).to_le_bytes())?,
155 (Value::UInt8(v), FieldType::UInt8) => self.writer.write_all(&[*v])?,
156 (Value::Int8(v), FieldType::Int8) => self.writer.write_all(&[*v as u8])?,
157 (Value::UInt16(v), FieldType::UInt16) => self.writer.write_all(&v.to_le_bytes())?,
158 (Value::Int16(v), FieldType::Int16) => self.writer.write_all(&v.to_le_bytes())?,
159 (Value::Array(vals), _) => {
160 for val in vals {
162 self.write_value(val, field_type, record_set, string_offsets)?;
163 }
164 }
165 _ => {
166 return Err(Error::TypeConversion(format!(
167 "Type mismatch: {value:?} is not compatible with {field_type:?}"
168 )));
169 }
170 }
171
172 Ok(())
173 }
174}