1use std::io::{Cursor, Read};
14use torg_core::{BoolOp, Graph, Node, Source};
15
16use crate::error::SerdeError;
17
/// File signature written at offset 0: the ASCII bytes `"TG"` (`0x54, 0x47`).
const MAGIC: [u8; 2] = *b"TG";

/// Format version written by `to_bytes`; `from_bytes` rejects any other value.
const VERSION: u8 = 1;

// Tag byte preceding each serialized `Source` operand.
/// Tag for `Source::Id`; followed by a little-endian `u16` id.
const SOURCE_ID: u8 = 0;
/// Tag for `Source::True`; no payload.
const SOURCE_TRUE: u8 = 1;
/// Tag for `Source::False`; no payload.
const SOURCE_FALSE: u8 = 2;

// One-byte op codes for `BoolOp`.
/// Op code for `BoolOp::Or`.
const OP_OR: u8 = 0;
/// Op code for `BoolOp::Nor`.
const OP_NOR: u8 = 1;
/// Op code for `BoolOp::Xor`.
const OP_XOR: u8 = 2;
37
38pub fn to_bytes(graph: &Graph) -> Vec<u8> {
43 let mut buf = Vec::with_capacity(estimate_size(graph));
44
45 buf.extend_from_slice(&MAGIC);
47 buf.push(VERSION);
48 buf.push(0); write_u16(&mut buf, graph.inputs.len() as u16);
52 for &id in &graph.inputs {
53 write_u16(&mut buf, id);
54 }
55
56 write_u16(&mut buf, graph.nodes.len() as u16);
58 for node in &graph.nodes {
59 write_u16(&mut buf, node.id);
60 buf.push(encode_op(node.op));
61 write_source(&mut buf, &node.left);
62 write_source(&mut buf, &node.right);
63 }
64
65 write_u16(&mut buf, graph.outputs.len() as u16);
67 for &id in &graph.outputs {
68 write_u16(&mut buf, id);
69 }
70
71 buf
72}
73
74pub fn from_bytes(bytes: &[u8]) -> Result<Graph, SerdeError> {
78 let mut cursor = Cursor::new(bytes);
79 let mut pos = 0usize;
80
81 let magic = read_bytes(&mut cursor, 2, &mut pos)?;
83 if magic[0] != MAGIC[0] || magic[1] != MAGIC[1] {
84 return Err(SerdeError::InvalidMagic(magic[0], magic[1]));
85 }
86
87 let version = read_u8(&mut cursor, &mut pos)?;
88 if version != VERSION {
89 return Err(SerdeError::UnsupportedVersion(version));
90 }
91
92 let _flags = read_u8(&mut cursor, &mut pos)?; let input_count = read_u16(&mut cursor, &mut pos)?;
96 let mut inputs = Vec::with_capacity(input_count as usize);
97 for _ in 0..input_count {
98 inputs.push(read_u16(&mut cursor, &mut pos)?);
99 }
100
101 let node_count = read_u16(&mut cursor, &mut pos)?;
103 let mut nodes = Vec::with_capacity(node_count as usize);
104 for _ in 0..node_count {
105 let id = read_u16(&mut cursor, &mut pos)?;
106 let op = decode_op(read_u8(&mut cursor, &mut pos)?, pos)?;
107 let left = read_source(&mut cursor, &mut pos)?;
108 let right = read_source(&mut cursor, &mut pos)?;
109 nodes.push(Node::new(id, op, left, right));
110 }
111
112 let output_count = read_u16(&mut cursor, &mut pos)?;
114 let mut outputs = Vec::with_capacity(output_count as usize);
115 for _ in 0..output_count {
116 outputs.push(read_u16(&mut cursor, &mut pos)?);
117 }
118
119 Ok(Graph {
120 inputs,
121 nodes,
122 outputs,
123 })
124}
125
126fn estimate_size(graph: &Graph) -> usize {
128 4 + 2 + graph.inputs.len() * 2 + 2 + graph.nodes.len() * 7 + 2 + graph.outputs.len() * 2 }
133
/// Appends `val` to `buf` as two bytes, least-significant byte first.
#[inline]
fn write_u16(buf: &mut Vec<u8>, val: u16) {
    let [low, high] = val.to_le_bytes();
    buf.push(low);
    buf.push(high);
}
139
140#[inline]
142fn encode_op(op: BoolOp) -> u8 {
143 match op {
144 BoolOp::Or => OP_OR,
145 BoolOp::Nor => OP_NOR,
146 BoolOp::Xor => OP_XOR,
147 }
148}
149
150fn write_source(buf: &mut Vec<u8>, source: &Source) {
152 match source {
153 Source::Id(id) => {
154 buf.push(SOURCE_ID);
155 write_u16(buf, *id);
156 }
157 Source::True => buf.push(SOURCE_TRUE),
158 Source::False => buf.push(SOURCE_FALSE),
159 }
160}
161
162fn read_bytes(
164 cursor: &mut Cursor<&[u8]>,
165 n: usize,
166 pos: &mut usize,
167) -> Result<Vec<u8>, SerdeError> {
168 let mut buf = vec![0u8; n];
169 cursor
170 .read_exact(&mut buf)
171 .map_err(|_| SerdeError::UnexpectedEof(*pos))?;
172 *pos += n;
173 Ok(buf)
174}
175
176fn read_u8(cursor: &mut Cursor<&[u8]>, pos: &mut usize) -> Result<u8, SerdeError> {
178 let mut buf = [0u8; 1];
179 cursor
180 .read_exact(&mut buf)
181 .map_err(|_| SerdeError::UnexpectedEof(*pos))?;
182 *pos += 1;
183 Ok(buf[0])
184}
185
186fn read_u16(cursor: &mut Cursor<&[u8]>, pos: &mut usize) -> Result<u16, SerdeError> {
188 let mut buf = [0u8; 2];
189 cursor
190 .read_exact(&mut buf)
191 .map_err(|_| SerdeError::UnexpectedEof(*pos))?;
192 *pos += 2;
193 Ok(u16::from_le_bytes(buf))
194}
195
196fn decode_op(byte: u8, _pos: usize) -> Result<BoolOp, SerdeError> {
198 match byte {
199 OP_OR => Ok(BoolOp::Or),
200 OP_NOR => Ok(BoolOp::Nor),
201 OP_XOR => Ok(BoolOp::Xor),
202 _ => Err(SerdeError::InvalidOpCode(byte)),
203 }
204}
205
206fn read_source(cursor: &mut Cursor<&[u8]>, pos: &mut usize) -> Result<Source, SerdeError> {
208 let tag = read_u8(cursor, pos)?;
209 match tag {
210 SOURCE_ID => {
211 let id = read_u16(cursor, pos)?;
212 Ok(Source::Id(id))
213 }
214 SOURCE_TRUE => Ok(Source::True),
215 SOURCE_FALSE => Ok(Source::False),
216 _ => Err(SerdeError::InvalidSourceTag(tag)),
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Two inputs feeding an OR node, whose result feeds a NOR with a constant.
    fn sample_graph() -> Graph {
        Graph {
            inputs: vec![0, 1],
            nodes: vec![
                Node::new(2, BoolOp::Or, Source::Id(0), Source::Id(1)),
                Node::new(3, BoolOp::Nor, Source::Id(2), Source::True),
            ],
            outputs: vec![3],
        }
    }

    #[test]
    fn test_roundtrip() {
        let graph = sample_graph();
        let bytes = to_bytes(&graph);
        let restored = from_bytes(&bytes).unwrap();
        assert_eq!(graph, restored);
    }

    #[test]
    fn test_empty_graph() {
        let graph = Graph::default();
        let bytes = to_bytes(&graph);
        let restored = from_bytes(&bytes).unwrap();
        assert_eq!(graph, restored);
    }

    #[test]
    fn test_all_source_types() {
        let graph = Graph {
            inputs: vec![0],
            nodes: vec![
                Node::new(1, BoolOp::Or, Source::Id(0), Source::True),
                Node::new(2, BoolOp::Xor, Source::False, Source::Id(1)),
            ],
            outputs: vec![2],
        };
        let bytes = to_bytes(&graph);
        let restored = from_bytes(&bytes).unwrap();
        assert_eq!(graph, restored);
    }

    #[test]
    fn test_compact_size() {
        let graph = sample_graph();
        let bytes = to_bytes(&graph);
        assert!(
            bytes.len() < 50,
            "Expected compact encoding, got {} bytes",
            bytes.len()
        );
    }

    /// Pins the exact wire layout so accidental format changes are caught.
    #[test]
    fn test_exact_layout() {
        let bytes = to_bytes(&sample_graph());
        let expected: Vec<u8> = vec![
            MAGIC[0], MAGIC[1], VERSION, 0, // header
            2, 0, 0, 0, 1, 0, // inputs: count 2, ids 0 and 1
            2, 0, // node count
            2, 0, OP_OR, SOURCE_ID, 0, 0, SOURCE_ID, 1, 0, // node 2
            3, 0, OP_NOR, SOURCE_ID, 2, 0, SOURCE_TRUE, // node 3
            1, 0, 3, 0, // outputs: count 1, id 3
        ];
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_invalid_magic() {
        let bytes = [0x00, 0x00, VERSION, 0, 0, 0, 0, 0, 0, 0];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::InvalidMagic(0x00, 0x00))));
    }

    #[test]
    fn test_unsupported_version() {
        let bytes = [MAGIC[0], MAGIC[1], 99, 0, 0, 0, 0, 0, 0, 0];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::UnsupportedVersion(99))));
    }

    #[test]
    fn test_unexpected_eof() {
        let bytes = [MAGIC[0], MAGIC[1], VERSION, 0];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::UnexpectedEof(_))));
    }

    #[test]
    fn test_invalid_op_code() {
        // Header, zero inputs, one node (id 5) whose op byte is unknown.
        let bytes = [MAGIC[0], MAGIC[1], VERSION, 0, 0, 0, 1, 0, 5, 0, 0xFF];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::InvalidOpCode(0xFF))));
    }

    #[test]
    fn test_invalid_source_tag() {
        // Header, zero inputs, one OR node whose left operand tag is unknown.
        let bytes = [MAGIC[0], MAGIC[1], VERSION, 0, 0, 0, 1, 0, 5, 0, OP_OR, 7];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::InvalidSourceTag(7))));
    }

    #[test]
    fn test_eof_reports_offset_inside_node() {
        // The left operand's id needs bytes 12..14, but only byte 12 exists.
        let bytes = [
            MAGIC[0], MAGIC[1], VERSION, 0, 0, 0, 1, 0, 5, 0, OP_OR, SOURCE_ID, 0,
        ];
        let result = from_bytes(&bytes);
        assert!(matches!(result, Err(SerdeError::UnexpectedEof(12))));
    }
}