1use std::collections::HashMap;
4
5use crate::{
6 AggregateOp, EinsumGraph, EinsumNode, FuzzyImplicationKind, FuzzyNegationKind, Metadata,
7 OpType, TCoNormKind, TLExpr, TNormKind, Term, TypeAnnotation,
8};
9
10use super::ExprSerializeError;
11use super::{
12 AGG_ALL, AGG_ANY, AGG_AVERAGE, AGG_COUNT, AGG_MAX, AGG_MIN, AGG_PRODUCT, AGG_SUM, FIMP_GODEL,
13 FIMP_GOGUEN, FIMP_KLEENE_DIENES, FIMP_LUKASIEWICZ, FIMP_REICHENBACH, FIMP_RESCHER,
14 FNEG_STANDARD, FNEG_SUGENO, FNEG_YAGER, FORMAT_VER, OP_EINSUM, OP_ELEM_BINARY, OP_ELEM_UNARY,
15 OP_REDUCE, TAG_ABDUCIBLE, TAG_ABS, TAG_ADD, TAG_AGGREGATE, TAG_ALL_DIFFERENT, TAG_ALWAYS,
16 TAG_AND, TAG_APPLY, TAG_AT, TAG_BOX, TAG_CEIL, TAG_CONSTANT, TAG_COS, TAG_COUNTING_EXISTS,
17 TAG_COUNTING_FORALL, TAG_DIAMOND, TAG_DIV, TAG_EMPTY_SET, TAG_EQ, TAG_EVENTUALLY,
18 TAG_EVERYWHERE, TAG_EXACT_COUNT, TAG_EXISTS, TAG_EXP, TAG_EXPLAIN, TAG_FLOOR, TAG_FORALL,
19 TAG_FUZZY_IMPLICATION, TAG_FUZZY_NOT, TAG_GLOBAL_CARDINALITY, TAG_GREATEST_FIXPOINT, TAG_GT,
20 TAG_GTE, TAG_IF_THEN_ELSE, TAG_IMPLY, TAG_LAMBDA, TAG_LEAST_FIXPOINT, TAG_LET, TAG_LOG, TAG_LT,
21 TAG_LTE, TAG_MAJORITY, TAG_MATCH, TAG_MAX, TAG_MIN, TAG_MOD, TAG_MUL, TAG_NEXT, TAG_NOMINAL,
22 TAG_NOT, TAG_OR, TAG_PATTERN_CONST_NUMBER, TAG_PATTERN_CONST_SYMBOL, TAG_PATTERN_WILDCARD,
23 TAG_POW, TAG_PRED, TAG_PROBABILISTIC_CHOICE, TAG_RELEASE, TAG_ROUND, TAG_SCORE,
24 TAG_SET_CARDINALITY, TAG_SET_COMPREHENSION, TAG_SET_DIFFERENCE, TAG_SET_INTERSECTION,
25 TAG_SET_MEMBERSHIP, TAG_SET_UNION, TAG_SIN, TAG_SOFT_EXISTS, TAG_SOFT_FORALL, TAG_SOMEWHERE,
26 TAG_SQRT, TAG_STRONG_RELEASE, TAG_SUB, TAG_SYMBOL_LITERAL, TAG_TAN, TAG_TCONORM, TAG_TNORM,
27 TAG_UNTIL, TAG_WEAK_UNTIL, TAG_WEIGHTED_RULE, TCONORM_BOUNDED_SUM, TCONORM_DRASTIC,
28 TCONORM_HAMACHER, TCONORM_MAXIMUM, TCONORM_NILPOTENT_MAXIMUM, TCONORM_PROBABILISTIC_SUM,
29 TERM_TAG_CONST, TERM_TAG_TYPED, TERM_TAG_VAR, TLEX_MAGIC, TLGR_MAGIC, TNORM_DRASTIC,
30 TNORM_HAMACHER, TNORM_LUKASIEWICZ, TNORM_MINIMUM, TNORM_NILPOTENT_MINIMUM, TNORM_PRODUCT,
31};
32
33pub(super) struct BinWriter {
35 buf: Vec<u8>,
36}
37
38impl BinWriter {
39 pub(super) fn new() -> Self {
40 Self { buf: Vec::new() }
41 }
42
43 pub(super) fn write_u8(&mut self, v: u8) {
44 self.buf.push(v);
45 }
46
47 pub(super) fn write_u32(&mut self, v: u32) {
48 self.buf.extend_from_slice(&v.to_le_bytes());
49 }
50
51 fn write_u64(&mut self, v: u64) {
52 self.buf.extend_from_slice(&v.to_le_bytes());
53 }
54
55 fn write_i32(&mut self, v: i32) {
56 self.buf.extend_from_slice(&v.to_le_bytes());
57 }
58
59 fn write_f64(&mut self, v: f64) {
60 self.buf.extend_from_slice(&v.to_le_bytes());
61 }
62
63 fn write_string(&mut self, s: &str) {
64 self.write_u32(s.len() as u32);
65 self.buf.extend_from_slice(s.as_bytes());
66 }
67
68 pub(super) fn write_magic(&mut self, magic: &[u8; 4]) {
69 self.buf.extend_from_slice(magic);
70 }
71
72 pub(super) fn into_bytes(self) -> Vec<u8> {
73 self.buf
74 }
75}
76
77pub(super) struct BinReader<'a> {
79 data: &'a [u8],
80 pos: usize,
81}
82
83impl<'a> BinReader<'a> {
84 pub(super) fn new(data: &'a [u8]) -> Self {
85 Self { data, pos: 0 }
86 }
87
88 fn remaining(&self) -> usize {
89 self.data.len().saturating_sub(self.pos)
90 }
91
92 pub(super) fn read_u8(&mut self) -> Result<u8, ExprSerializeError> {
93 if self.remaining() < 1 {
94 return Err(ExprSerializeError::TruncatedInput);
95 }
96 let v = self.data[self.pos];
97 self.pos += 1;
98 Ok(v)
99 }
100
101 pub(super) fn read_u32(&mut self) -> Result<u32, ExprSerializeError> {
102 if self.remaining() < 4 {
103 return Err(ExprSerializeError::TruncatedInput);
104 }
105 let bytes: [u8; 4] = self.data[self.pos..self.pos + 4]
106 .try_into()
107 .map_err(|_| ExprSerializeError::TruncatedInput)?;
108 self.pos += 4;
109 Ok(u32::from_le_bytes(bytes))
110 }
111
112 fn read_u64(&mut self) -> Result<u64, ExprSerializeError> {
113 if self.remaining() < 8 {
114 return Err(ExprSerializeError::TruncatedInput);
115 }
116 let bytes: [u8; 8] = self.data[self.pos..self.pos + 8]
117 .try_into()
118 .map_err(|_| ExprSerializeError::TruncatedInput)?;
119 self.pos += 8;
120 Ok(u64::from_le_bytes(bytes))
121 }
122
123 fn read_i32(&mut self) -> Result<i32, ExprSerializeError> {
124 if self.remaining() < 4 {
125 return Err(ExprSerializeError::TruncatedInput);
126 }
127 let bytes: [u8; 4] = self.data[self.pos..self.pos + 4]
128 .try_into()
129 .map_err(|_| ExprSerializeError::TruncatedInput)?;
130 self.pos += 4;
131 Ok(i32::from_le_bytes(bytes))
132 }
133
134 fn read_f64(&mut self) -> Result<f64, ExprSerializeError> {
135 if self.remaining() < 8 {
136 return Err(ExprSerializeError::TruncatedInput);
137 }
138 let bytes: [u8; 8] = self.data[self.pos..self.pos + 8]
139 .try_into()
140 .map_err(|_| ExprSerializeError::TruncatedInput)?;
141 self.pos += 8;
142 Ok(f64::from_le_bytes(bytes))
143 }
144
145 fn read_str(&mut self) -> Result<String, ExprSerializeError> {
146 let len = self.read_u32()? as usize;
147 if self.remaining() < len {
148 return Err(ExprSerializeError::TruncatedInput);
149 }
150 let s = std::str::from_utf8(&self.data[self.pos..self.pos + len])
151 .map_err(|e| ExprSerializeError::Utf8Error(e.to_string()))?
152 .to_string();
153 self.pos += len;
154 Ok(s)
155 }
156
157 pub(super) fn read_magic(&mut self) -> Result<[u8; 4], ExprSerializeError> {
158 if self.remaining() < 4 {
159 return Err(ExprSerializeError::TruncatedInput);
160 }
161 let magic: [u8; 4] = self.data[self.pos..self.pos + 4]
162 .try_into()
163 .map_err(|_| ExprSerializeError::TruncatedInput)?;
164 self.pos += 4;
165 Ok(magic)
166 }
167}
168
169pub fn to_binary(expr: &TLExpr) -> Vec<u8> {
171 let mut w = BinWriter::new();
172 w.write_magic(&TLEX_MAGIC);
173 w.write_u32(FORMAT_VER);
174 write_expr_bin(expr, &mut w);
175 w.into_bytes()
176}
177
178pub fn from_binary(bytes: &[u8]) -> Result<TLExpr, ExprSerializeError> {
180 let mut r = BinReader::new(bytes);
181 let magic = r.read_magic()?;
182 if magic != TLEX_MAGIC {
183 return Err(ExprSerializeError::InvalidMagic);
184 }
185 let version = r.read_u32()?;
186 if version != FORMAT_VER {
187 return Err(ExprSerializeError::VersionMismatch {
188 expected: FORMAT_VER,
189 got: version,
190 });
191 }
192 read_expr_bin(&mut r)
193}
194
195fn write_term_bin(term: &Term, w: &mut BinWriter) {
196 match term {
197 Term::Var(name) => {
198 w.write_u8(TERM_TAG_VAR);
199 w.write_string(name);
200 }
201 Term::Const(name) => {
202 w.write_u8(TERM_TAG_CONST);
203 w.write_string(name);
204 }
205 Term::Typed {
206 value,
207 type_annotation,
208 } => {
209 w.write_u8(TERM_TAG_TYPED);
210 write_term_bin(value, w);
211 w.write_string(&type_annotation.type_name);
212 }
213 }
214}
215
216fn read_term_bin(r: &mut BinReader<'_>) -> Result<Term, ExprSerializeError> {
217 let tag = r.read_u8()?;
218 match tag {
219 TERM_TAG_VAR => Ok(Term::Var(r.read_str()?)),
220 TERM_TAG_CONST => Ok(Term::Const(r.read_str()?)),
221 TERM_TAG_TYPED => {
222 let value = read_term_bin(r)?;
223 let type_name = r.read_str()?;
224 Ok(Term::Typed {
225 value: Box::new(value),
226 type_annotation: TypeAnnotation::new(type_name),
227 })
228 }
229 _ => Err(ExprSerializeError::UnknownVariant(format!(
230 "Term tag {tag}"
231 ))),
232 }
233}
234
235fn write_optional_string(s: &Option<String>, w: &mut BinWriter) {
236 match s {
237 Some(val) => {
238 w.write_u8(1);
239 w.write_string(val);
240 }
241 None => w.write_u8(0),
242 }
243}
244
245fn read_optional_string(r: &mut BinReader<'_>) -> Result<Option<String>, ExprSerializeError> {
246 let has = r.read_u8()?;
247 if has == 0 {
248 Ok(None)
249 } else {
250 Ok(Some(r.read_str()?))
251 }
252}
253
254fn write_string_vec(v: &[String], w: &mut BinWriter) {
255 w.write_u32(v.len() as u32);
256 for s in v {
257 w.write_string(s);
258 }
259}
260
261fn read_string_vec(r: &mut BinReader<'_>) -> Result<Vec<String>, ExprSerializeError> {
262 let count = r.read_u32()? as usize;
263 let mut result = Vec::with_capacity(count);
264 for _ in 0..count {
265 result.push(r.read_str()?);
266 }
267 Ok(result)
268}
269
270fn write_usize_vec(v: &[usize], w: &mut BinWriter) {
271 w.write_u32(v.len() as u32);
272 for &val in v {
273 w.write_u64(val as u64);
274 }
275}
276
277fn read_usize_vec(r: &mut BinReader<'_>) -> Result<Vec<usize>, ExprSerializeError> {
278 let count = r.read_u32()? as usize;
279 let mut result = Vec::with_capacity(count);
280 for _ in 0..count {
281 result.push(r.read_u64()? as usize);
282 }
283 Ok(result)
284}
285
286pub(super) fn write_expr_bin(expr: &TLExpr, w: &mut BinWriter) {
287 match expr {
288 TLExpr::Pred { name, args } => {
289 w.write_u8(TAG_PRED);
290 w.write_string(name);
291 w.write_u32(args.len() as u32);
292 for arg in args {
293 write_term_bin(arg, w);
294 }
295 }
296 TLExpr::And(a, b) => {
297 w.write_u8(TAG_AND);
298 write_expr_bin(a, w);
299 write_expr_bin(b, w);
300 }
301 TLExpr::Or(a, b) => {
302 w.write_u8(TAG_OR);
303 write_expr_bin(a, w);
304 write_expr_bin(b, w);
305 }
306 TLExpr::Not(e) => {
307 w.write_u8(TAG_NOT);
308 write_expr_bin(e, w);
309 }
310 TLExpr::Exists { var, domain, body } => {
311 w.write_u8(TAG_EXISTS);
312 w.write_string(var);
313 w.write_string(domain);
314 write_expr_bin(body, w);
315 }
316 TLExpr::ForAll { var, domain, body } => {
317 w.write_u8(TAG_FORALL);
318 w.write_string(var);
319 w.write_string(domain);
320 write_expr_bin(body, w);
321 }
322 TLExpr::Imply(a, b) => {
323 w.write_u8(TAG_IMPLY);
324 write_expr_bin(a, w);
325 write_expr_bin(b, w);
326 }
327 TLExpr::Score(e) => {
328 w.write_u8(TAG_SCORE);
329 write_expr_bin(e, w);
330 }
331 TLExpr::Add(a, b) => {
332 w.write_u8(TAG_ADD);
333 write_expr_bin(a, w);
334 write_expr_bin(b, w);
335 }
336 TLExpr::Sub(a, b) => {
337 w.write_u8(TAG_SUB);
338 write_expr_bin(a, w);
339 write_expr_bin(b, w);
340 }
341 TLExpr::Mul(a, b) => {
342 w.write_u8(TAG_MUL);
343 write_expr_bin(a, w);
344 write_expr_bin(b, w);
345 }
346 TLExpr::Div(a, b) => {
347 w.write_u8(TAG_DIV);
348 write_expr_bin(a, w);
349 write_expr_bin(b, w);
350 }
351 TLExpr::Pow(a, b) => {
352 w.write_u8(TAG_POW);
353 write_expr_bin(a, w);
354 write_expr_bin(b, w);
355 }
356 TLExpr::Mod(a, b) => {
357 w.write_u8(TAG_MOD);
358 write_expr_bin(a, w);
359 write_expr_bin(b, w);
360 }
361 TLExpr::Min(a, b) => {
362 w.write_u8(TAG_MIN);
363 write_expr_bin(a, w);
364 write_expr_bin(b, w);
365 }
366 TLExpr::Max(a, b) => {
367 w.write_u8(TAG_MAX);
368 write_expr_bin(a, w);
369 write_expr_bin(b, w);
370 }
371 TLExpr::Abs(e) => {
372 w.write_u8(TAG_ABS);
373 write_expr_bin(e, w);
374 }
375 TLExpr::Floor(e) => {
376 w.write_u8(TAG_FLOOR);
377 write_expr_bin(e, w);
378 }
379 TLExpr::Ceil(e) => {
380 w.write_u8(TAG_CEIL);
381 write_expr_bin(e, w);
382 }
383 TLExpr::Round(e) => {
384 w.write_u8(TAG_ROUND);
385 write_expr_bin(e, w);
386 }
387 TLExpr::Sqrt(e) => {
388 w.write_u8(TAG_SQRT);
389 write_expr_bin(e, w);
390 }
391 TLExpr::Exp(e) => {
392 w.write_u8(TAG_EXP);
393 write_expr_bin(e, w);
394 }
395 TLExpr::Log(e) => {
396 w.write_u8(TAG_LOG);
397 write_expr_bin(e, w);
398 }
399 TLExpr::Sin(e) => {
400 w.write_u8(TAG_SIN);
401 write_expr_bin(e, w);
402 }
403 TLExpr::Cos(e) => {
404 w.write_u8(TAG_COS);
405 write_expr_bin(e, w);
406 }
407 TLExpr::Tan(e) => {
408 w.write_u8(TAG_TAN);
409 write_expr_bin(e, w);
410 }
411 TLExpr::Eq(a, b) => {
412 w.write_u8(TAG_EQ);
413 write_expr_bin(a, w);
414 write_expr_bin(b, w);
415 }
416 TLExpr::Lt(a, b) => {
417 w.write_u8(TAG_LT);
418 write_expr_bin(a, w);
419 write_expr_bin(b, w);
420 }
421 TLExpr::Gt(a, b) => {
422 w.write_u8(TAG_GT);
423 write_expr_bin(a, w);
424 write_expr_bin(b, w);
425 }
426 TLExpr::Lte(a, b) => {
427 w.write_u8(TAG_LTE);
428 write_expr_bin(a, w);
429 write_expr_bin(b, w);
430 }
431 TLExpr::Gte(a, b) => {
432 w.write_u8(TAG_GTE);
433 write_expr_bin(a, w);
434 write_expr_bin(b, w);
435 }
436 TLExpr::IfThenElse {
437 condition,
438 then_branch,
439 else_branch,
440 } => {
441 w.write_u8(TAG_IF_THEN_ELSE);
442 write_expr_bin(condition, w);
443 write_expr_bin(then_branch, w);
444 write_expr_bin(else_branch, w);
445 }
446 TLExpr::Constant(v) => {
447 w.write_u8(TAG_CONSTANT);
448 w.write_f64(*v);
449 }
450 TLExpr::Aggregate {
451 op,
452 var,
453 domain,
454 body,
455 group_by,
456 } => {
457 w.write_u8(TAG_AGGREGATE);
458 w.write_u8(aggregate_op_tag(op));
459 w.write_string(var);
460 w.write_string(domain);
461 write_expr_bin(body, w);
462 match group_by {
463 Some(gb) => {
464 w.write_u8(1);
465 write_string_vec(gb, w);
466 }
467 None => w.write_u8(0),
468 }
469 }
470 TLExpr::Let { var, value, body } => {
471 w.write_u8(TAG_LET);
472 w.write_string(var);
473 write_expr_bin(value, w);
474 write_expr_bin(body, w);
475 }
476 TLExpr::Box(e) => {
477 w.write_u8(TAG_BOX);
478 write_expr_bin(e, w);
479 }
480 TLExpr::Diamond(e) => {
481 w.write_u8(TAG_DIAMOND);
482 write_expr_bin(e, w);
483 }
484 TLExpr::Next(e) => {
485 w.write_u8(TAG_NEXT);
486 write_expr_bin(e, w);
487 }
488 TLExpr::Eventually(e) => {
489 w.write_u8(TAG_EVENTUALLY);
490 write_expr_bin(e, w);
491 }
492 TLExpr::Always(e) => {
493 w.write_u8(TAG_ALWAYS);
494 write_expr_bin(e, w);
495 }
496 TLExpr::Until { before, after } => {
497 w.write_u8(TAG_UNTIL);
498 write_expr_bin(before, w);
499 write_expr_bin(after, w);
500 }
501 TLExpr::TNorm { kind, left, right } => {
502 w.write_u8(TAG_TNORM);
503 w.write_u8(tnorm_kind_tag(kind));
504 write_expr_bin(left, w);
505 write_expr_bin(right, w);
506 }
507 TLExpr::TCoNorm { kind, left, right } => {
508 w.write_u8(TAG_TCONORM);
509 w.write_u8(tconorm_kind_tag(kind));
510 write_expr_bin(left, w);
511 write_expr_bin(right, w);
512 }
513 TLExpr::FuzzyNot { kind, expr: e } => {
514 w.write_u8(TAG_FUZZY_NOT);
515 write_fuzzy_neg_kind_bin(kind, w);
516 write_expr_bin(e, w);
517 }
518 TLExpr::FuzzyImplication {
519 kind,
520 premise,
521 conclusion,
522 } => {
523 w.write_u8(TAG_FUZZY_IMPLICATION);
524 w.write_u8(fuzzy_imp_kind_tag(kind));
525 write_expr_bin(premise, w);
526 write_expr_bin(conclusion, w);
527 }
528 TLExpr::SoftExists {
529 var,
530 domain,
531 body,
532 temperature,
533 } => {
534 w.write_u8(TAG_SOFT_EXISTS);
535 w.write_string(var);
536 w.write_string(domain);
537 w.write_f64(*temperature);
538 write_expr_bin(body, w);
539 }
540 TLExpr::SoftForAll {
541 var,
542 domain,
543 body,
544 temperature,
545 } => {
546 w.write_u8(TAG_SOFT_FORALL);
547 w.write_string(var);
548 w.write_string(domain);
549 w.write_f64(*temperature);
550 write_expr_bin(body, w);
551 }
552 TLExpr::WeightedRule { weight, rule } => {
553 w.write_u8(TAG_WEIGHTED_RULE);
554 w.write_f64(*weight);
555 write_expr_bin(rule, w);
556 }
557 TLExpr::ProbabilisticChoice { alternatives } => {
558 w.write_u8(TAG_PROBABILISTIC_CHOICE);
559 w.write_u32(alternatives.len() as u32);
560 for (prob, alt_expr) in alternatives {
561 w.write_f64(*prob);
562 write_expr_bin(alt_expr, w);
563 }
564 }
565 TLExpr::Release { released, releaser } => {
566 w.write_u8(TAG_RELEASE);
567 write_expr_bin(released, w);
568 write_expr_bin(releaser, w);
569 }
570 TLExpr::WeakUntil { before, after } => {
571 w.write_u8(TAG_WEAK_UNTIL);
572 write_expr_bin(before, w);
573 write_expr_bin(after, w);
574 }
575 TLExpr::StrongRelease { released, releaser } => {
576 w.write_u8(TAG_STRONG_RELEASE);
577 write_expr_bin(released, w);
578 write_expr_bin(releaser, w);
579 }
580 TLExpr::Lambda {
581 var,
582 var_type,
583 body,
584 } => {
585 w.write_u8(TAG_LAMBDA);
586 w.write_string(var);
587 write_optional_string(var_type, w);
588 write_expr_bin(body, w);
589 }
590 TLExpr::Apply { function, argument } => {
591 w.write_u8(TAG_APPLY);
592 write_expr_bin(function, w);
593 write_expr_bin(argument, w);
594 }
595 TLExpr::SetMembership { element, set } => {
596 w.write_u8(TAG_SET_MEMBERSHIP);
597 write_expr_bin(element, w);
598 write_expr_bin(set, w);
599 }
600 TLExpr::SetUnion { left, right } => {
601 w.write_u8(TAG_SET_UNION);
602 write_expr_bin(left, w);
603 write_expr_bin(right, w);
604 }
605 TLExpr::SetIntersection { left, right } => {
606 w.write_u8(TAG_SET_INTERSECTION);
607 write_expr_bin(left, w);
608 write_expr_bin(right, w);
609 }
610 TLExpr::SetDifference { left, right } => {
611 w.write_u8(TAG_SET_DIFFERENCE);
612 write_expr_bin(left, w);
613 write_expr_bin(right, w);
614 }
615 TLExpr::SetCardinality { set } => {
616 w.write_u8(TAG_SET_CARDINALITY);
617 write_expr_bin(set, w);
618 }
619 TLExpr::EmptySet => {
620 w.write_u8(TAG_EMPTY_SET);
621 }
622 TLExpr::SetComprehension {
623 var,
624 domain,
625 condition,
626 } => {
627 w.write_u8(TAG_SET_COMPREHENSION);
628 w.write_string(var);
629 w.write_string(domain);
630 write_expr_bin(condition, w);
631 }
632 TLExpr::CountingExists {
633 var,
634 domain,
635 body,
636 min_count,
637 } => {
638 w.write_u8(TAG_COUNTING_EXISTS);
639 w.write_string(var);
640 w.write_string(domain);
641 w.write_u64(*min_count as u64);
642 write_expr_bin(body, w);
643 }
644 TLExpr::CountingForAll {
645 var,
646 domain,
647 body,
648 min_count,
649 } => {
650 w.write_u8(TAG_COUNTING_FORALL);
651 w.write_string(var);
652 w.write_string(domain);
653 w.write_u64(*min_count as u64);
654 write_expr_bin(body, w);
655 }
656 TLExpr::ExactCount {
657 var,
658 domain,
659 body,
660 count,
661 } => {
662 w.write_u8(TAG_EXACT_COUNT);
663 w.write_string(var);
664 w.write_string(domain);
665 w.write_u64(*count as u64);
666 write_expr_bin(body, w);
667 }
668 TLExpr::Majority { var, domain, body } => {
669 w.write_u8(TAG_MAJORITY);
670 w.write_string(var);
671 w.write_string(domain);
672 write_expr_bin(body, w);
673 }
674 TLExpr::LeastFixpoint { var, body } => {
675 w.write_u8(TAG_LEAST_FIXPOINT);
676 w.write_string(var);
677 write_expr_bin(body, w);
678 }
679 TLExpr::GreatestFixpoint { var, body } => {
680 w.write_u8(TAG_GREATEST_FIXPOINT);
681 w.write_string(var);
682 write_expr_bin(body, w);
683 }
684 TLExpr::Nominal { name } => {
685 w.write_u8(TAG_NOMINAL);
686 w.write_string(name);
687 }
688 TLExpr::At { nominal, formula } => {
689 w.write_u8(TAG_AT);
690 w.write_string(nominal);
691 write_expr_bin(formula, w);
692 }
693 TLExpr::Somewhere { formula } => {
694 w.write_u8(TAG_SOMEWHERE);
695 write_expr_bin(formula, w);
696 }
697 TLExpr::Everywhere { formula } => {
698 w.write_u8(TAG_EVERYWHERE);
699 write_expr_bin(formula, w);
700 }
701 TLExpr::AllDifferent { variables } => {
702 w.write_u8(TAG_ALL_DIFFERENT);
703 write_string_vec(variables, w);
704 }
705 TLExpr::GlobalCardinality {
706 variables,
707 values,
708 min_occurrences,
709 max_occurrences,
710 } => {
711 w.write_u8(TAG_GLOBAL_CARDINALITY);
712 write_string_vec(variables, w);
713 w.write_u32(values.len() as u32);
714 for val in values {
715 write_expr_bin(val, w);
716 }
717 write_usize_vec(min_occurrences, w);
718 write_usize_vec(max_occurrences, w);
719 }
720 TLExpr::Abducible { name, cost } => {
721 w.write_u8(TAG_ABDUCIBLE);
722 w.write_string(name);
723 w.write_f64(*cost);
724 }
725 TLExpr::Explain { formula } => {
726 w.write_u8(TAG_EXPLAIN);
727 write_expr_bin(formula, w);
728 }
729 TLExpr::SymbolLiteral(s) => {
730 w.write_u8(TAG_SYMBOL_LITERAL);
731 w.write_string(s);
732 }
733 TLExpr::Match { scrutinee, arms } => {
734 w.write_u8(TAG_MATCH);
735 write_expr_bin(scrutinee, w);
736 w.write_u32(arms.len() as u32);
737 for (pat, body) in arms {
738 match pat {
739 crate::pattern::MatchPattern::ConstSymbol(s) => {
740 w.write_u8(TAG_PATTERN_CONST_SYMBOL);
741 w.write_string(s);
742 }
743 crate::pattern::MatchPattern::ConstNumber(n) => {
744 w.write_u8(TAG_PATTERN_CONST_NUMBER);
745 w.write_f64(*n);
746 }
747 crate::pattern::MatchPattern::Wildcard => {
748 w.write_u8(TAG_PATTERN_WILDCARD);
749 }
750 }
751 write_expr_bin(body, w);
752 }
753 }
754 }
755}
756
757pub(super) fn read_expr_bin(r: &mut BinReader<'_>) -> Result<TLExpr, ExprSerializeError> {
758 let tag = r.read_u8()?;
759 match tag {
760 TAG_PRED => {
761 let name = r.read_str()?;
762 let count = r.read_u32()? as usize;
763 let mut args = Vec::with_capacity(count);
764 for _ in 0..count {
765 args.push(read_term_bin(r)?);
766 }
767 Ok(TLExpr::Pred { name, args })
768 }
769 TAG_AND => {
770 let a = read_expr_bin(r)?;
771 let b = read_expr_bin(r)?;
772 Ok(TLExpr::And(Box::new(a), Box::new(b)))
773 }
774 TAG_OR => {
775 let a = read_expr_bin(r)?;
776 let b = read_expr_bin(r)?;
777 Ok(TLExpr::Or(Box::new(a), Box::new(b)))
778 }
779 TAG_NOT => {
780 let e = read_expr_bin(r)?;
781 Ok(TLExpr::Not(Box::new(e)))
782 }
783 TAG_EXISTS => {
784 let var = r.read_str()?;
785 let domain = r.read_str()?;
786 let body = read_expr_bin(r)?;
787 Ok(TLExpr::Exists {
788 var,
789 domain,
790 body: Box::new(body),
791 })
792 }
793 TAG_FORALL => {
794 let var = r.read_str()?;
795 let domain = r.read_str()?;
796 let body = read_expr_bin(r)?;
797 Ok(TLExpr::ForAll {
798 var,
799 domain,
800 body: Box::new(body),
801 })
802 }
803 TAG_IMPLY => {
804 let a = read_expr_bin(r)?;
805 let b = read_expr_bin(r)?;
806 Ok(TLExpr::Imply(Box::new(a), Box::new(b)))
807 }
808 TAG_SCORE => {
809 let e = read_expr_bin(r)?;
810 Ok(TLExpr::Score(Box::new(e)))
811 }
812 TAG_ADD => read_binary_expr(r, TLExpr::Add),
813 TAG_SUB => read_binary_expr(r, TLExpr::Sub),
814 TAG_MUL => read_binary_expr(r, TLExpr::Mul),
815 TAG_DIV => read_binary_expr(r, TLExpr::Div),
816 TAG_POW => read_binary_expr(r, TLExpr::Pow),
817 TAG_MOD => read_binary_expr(r, TLExpr::Mod),
818 TAG_MIN => read_binary_expr(r, TLExpr::Min),
819 TAG_MAX => read_binary_expr(r, TLExpr::Max),
820 TAG_ABS => read_unary_expr(r, TLExpr::Abs),
821 TAG_FLOOR => read_unary_expr(r, TLExpr::Floor),
822 TAG_CEIL => read_unary_expr(r, TLExpr::Ceil),
823 TAG_ROUND => read_unary_expr(r, TLExpr::Round),
824 TAG_SQRT => read_unary_expr(r, TLExpr::Sqrt),
825 TAG_EXP => read_unary_expr(r, TLExpr::Exp),
826 TAG_LOG => read_unary_expr(r, TLExpr::Log),
827 TAG_SIN => read_unary_expr(r, TLExpr::Sin),
828 TAG_COS => read_unary_expr(r, TLExpr::Cos),
829 TAG_TAN => read_unary_expr(r, TLExpr::Tan),
830 TAG_EQ => read_binary_expr(r, TLExpr::Eq),
831 TAG_LT => read_binary_expr(r, TLExpr::Lt),
832 TAG_GT => read_binary_expr(r, TLExpr::Gt),
833 TAG_LTE => read_binary_expr(r, TLExpr::Lte),
834 TAG_GTE => read_binary_expr(r, TLExpr::Gte),
835 TAG_IF_THEN_ELSE => {
836 let cond = read_expr_bin(r)?;
837 let then_b = read_expr_bin(r)?;
838 let else_b = read_expr_bin(r)?;
839 Ok(TLExpr::IfThenElse {
840 condition: Box::new(cond),
841 then_branch: Box::new(then_b),
842 else_branch: Box::new(else_b),
843 })
844 }
845 TAG_CONSTANT => {
846 let v = r.read_f64()?;
847 Ok(TLExpr::Constant(v))
848 }
849 TAG_AGGREGATE => {
850 let op_tag = r.read_u8()?;
851 let op = read_aggregate_op_tag(op_tag)?;
852 let var = r.read_str()?;
853 let domain = r.read_str()?;
854 let body = read_expr_bin(r)?;
855 let has_gb = r.read_u8()?;
856 let group_by = if has_gb == 0 {
857 None
858 } else {
859 Some(read_string_vec(r)?)
860 };
861 Ok(TLExpr::Aggregate {
862 op,
863 var,
864 domain,
865 body: Box::new(body),
866 group_by,
867 })
868 }
869 TAG_LET => {
870 let var = r.read_str()?;
871 let value = read_expr_bin(r)?;
872 let body = read_expr_bin(r)?;
873 Ok(TLExpr::Let {
874 var,
875 value: Box::new(value),
876 body: Box::new(body),
877 })
878 }
879 TAG_BOX => read_unary_expr(r, TLExpr::Box),
880 TAG_DIAMOND => read_unary_expr(r, TLExpr::Diamond),
881 TAG_NEXT => read_unary_expr(r, TLExpr::Next),
882 TAG_EVENTUALLY => read_unary_expr(r, TLExpr::Eventually),
883 TAG_ALWAYS => read_unary_expr(r, TLExpr::Always),
884 TAG_UNTIL => {
885 let before = read_expr_bin(r)?;
886 let after = read_expr_bin(r)?;
887 Ok(TLExpr::Until {
888 before: Box::new(before),
889 after: Box::new(after),
890 })
891 }
892 TAG_TNORM => {
893 let kind_tag = r.read_u8()?;
894 let kind = read_tnorm_kind_tag(kind_tag)?;
895 let left = read_expr_bin(r)?;
896 let right = read_expr_bin(r)?;
897 Ok(TLExpr::TNorm {
898 kind,
899 left: Box::new(left),
900 right: Box::new(right),
901 })
902 }
903 TAG_TCONORM => {
904 let kind_tag = r.read_u8()?;
905 let kind = read_tconorm_kind_tag(kind_tag)?;
906 let left = read_expr_bin(r)?;
907 let right = read_expr_bin(r)?;
908 Ok(TLExpr::TCoNorm {
909 kind,
910 left: Box::new(left),
911 right: Box::new(right),
912 })
913 }
914 TAG_FUZZY_NOT => {
915 let kind = read_fuzzy_neg_kind_bin(r)?;
916 let e = read_expr_bin(r)?;
917 Ok(TLExpr::FuzzyNot {
918 kind,
919 expr: Box::new(e),
920 })
921 }
922 TAG_FUZZY_IMPLICATION => {
923 let kind_tag = r.read_u8()?;
924 let kind = read_fuzzy_imp_kind_tag(kind_tag)?;
925 let premise = read_expr_bin(r)?;
926 let conclusion = read_expr_bin(r)?;
927 Ok(TLExpr::FuzzyImplication {
928 kind,
929 premise: Box::new(premise),
930 conclusion: Box::new(conclusion),
931 })
932 }
933 TAG_SOFT_EXISTS => {
934 let var = r.read_str()?;
935 let domain = r.read_str()?;
936 let temperature = r.read_f64()?;
937 let body = read_expr_bin(r)?;
938 Ok(TLExpr::SoftExists {
939 var,
940 domain,
941 body: Box::new(body),
942 temperature,
943 })
944 }
945 TAG_SOFT_FORALL => {
946 let var = r.read_str()?;
947 let domain = r.read_str()?;
948 let temperature = r.read_f64()?;
949 let body = read_expr_bin(r)?;
950 Ok(TLExpr::SoftForAll {
951 var,
952 domain,
953 body: Box::new(body),
954 temperature,
955 })
956 }
957 TAG_WEIGHTED_RULE => {
958 let weight = r.read_f64()?;
959 let rule = read_expr_bin(r)?;
960 Ok(TLExpr::WeightedRule {
961 weight,
962 rule: Box::new(rule),
963 })
964 }
965 TAG_PROBABILISTIC_CHOICE => {
966 let count = r.read_u32()? as usize;
967 let mut alternatives = Vec::with_capacity(count);
968 for _ in 0..count {
969 let prob = r.read_f64()?;
970 let alt_expr = read_expr_bin(r)?;
971 alternatives.push((prob, alt_expr));
972 }
973 Ok(TLExpr::ProbabilisticChoice { alternatives })
974 }
975 TAG_RELEASE => {
976 let released = read_expr_bin(r)?;
977 let releaser = read_expr_bin(r)?;
978 Ok(TLExpr::Release {
979 released: Box::new(released),
980 releaser: Box::new(releaser),
981 })
982 }
983 TAG_WEAK_UNTIL => {
984 let before = read_expr_bin(r)?;
985 let after = read_expr_bin(r)?;
986 Ok(TLExpr::WeakUntil {
987 before: Box::new(before),
988 after: Box::new(after),
989 })
990 }
991 TAG_STRONG_RELEASE => {
992 let released = read_expr_bin(r)?;
993 let releaser = read_expr_bin(r)?;
994 Ok(TLExpr::StrongRelease {
995 released: Box::new(released),
996 releaser: Box::new(releaser),
997 })
998 }
999 TAG_LAMBDA => {
1000 let var = r.read_str()?;
1001 let var_type = read_optional_string(r)?;
1002 let body = read_expr_bin(r)?;
1003 Ok(TLExpr::Lambda {
1004 var,
1005 var_type,
1006 body: Box::new(body),
1007 })
1008 }
1009 TAG_APPLY => {
1010 let function = read_expr_bin(r)?;
1011 let argument = read_expr_bin(r)?;
1012 Ok(TLExpr::Apply {
1013 function: Box::new(function),
1014 argument: Box::new(argument),
1015 })
1016 }
1017 TAG_SET_MEMBERSHIP => {
1018 let element = read_expr_bin(r)?;
1019 let set = read_expr_bin(r)?;
1020 Ok(TLExpr::SetMembership {
1021 element: Box::new(element),
1022 set: Box::new(set),
1023 })
1024 }
1025 TAG_SET_UNION => read_binary_expr(r, |a, b| TLExpr::SetUnion { left: a, right: b }),
1026 TAG_SET_INTERSECTION => {
1027 read_binary_expr(r, |a, b| TLExpr::SetIntersection { left: a, right: b })
1028 }
1029 TAG_SET_DIFFERENCE => {
1030 read_binary_expr(r, |a, b| TLExpr::SetDifference { left: a, right: b })
1031 }
1032 TAG_SET_CARDINALITY => read_unary_expr(r, |e| TLExpr::SetCardinality { set: e }),
1033 TAG_EMPTY_SET => Ok(TLExpr::EmptySet),
1034 TAG_SET_COMPREHENSION => {
1035 let var = r.read_str()?;
1036 let domain = r.read_str()?;
1037 let condition = read_expr_bin(r)?;
1038 Ok(TLExpr::SetComprehension {
1039 var,
1040 domain,
1041 condition: Box::new(condition),
1042 })
1043 }
1044 TAG_COUNTING_EXISTS => {
1045 let var = r.read_str()?;
1046 let domain = r.read_str()?;
1047 let min_count = r.read_u64()? as usize;
1048 let body = read_expr_bin(r)?;
1049 Ok(TLExpr::CountingExists {
1050 var,
1051 domain,
1052 body: Box::new(body),
1053 min_count,
1054 })
1055 }
1056 TAG_COUNTING_FORALL => {
1057 let var = r.read_str()?;
1058 let domain = r.read_str()?;
1059 let min_count = r.read_u64()? as usize;
1060 let body = read_expr_bin(r)?;
1061 Ok(TLExpr::CountingForAll {
1062 var,
1063 domain,
1064 body: Box::new(body),
1065 min_count,
1066 })
1067 }
1068 TAG_EXACT_COUNT => {
1069 let var = r.read_str()?;
1070 let domain = r.read_str()?;
1071 let count = r.read_u64()? as usize;
1072 let body = read_expr_bin(r)?;
1073 Ok(TLExpr::ExactCount {
1074 var,
1075 domain,
1076 body: Box::new(body),
1077 count,
1078 })
1079 }
1080 TAG_MAJORITY => {
1081 let var = r.read_str()?;
1082 let domain = r.read_str()?;
1083 let body = read_expr_bin(r)?;
1084 Ok(TLExpr::Majority {
1085 var,
1086 domain,
1087 body: Box::new(body),
1088 })
1089 }
1090 TAG_LEAST_FIXPOINT => {
1091 let var = r.read_str()?;
1092 let body = read_expr_bin(r)?;
1093 Ok(TLExpr::LeastFixpoint {
1094 var,
1095 body: Box::new(body),
1096 })
1097 }
1098 TAG_GREATEST_FIXPOINT => {
1099 let var = r.read_str()?;
1100 let body = read_expr_bin(r)?;
1101 Ok(TLExpr::GreatestFixpoint {
1102 var,
1103 body: Box::new(body),
1104 })
1105 }
1106 TAG_NOMINAL => {
1107 let name = r.read_str()?;
1108 Ok(TLExpr::Nominal { name })
1109 }
1110 TAG_AT => {
1111 let nominal = r.read_str()?;
1112 let formula = read_expr_bin(r)?;
1113 Ok(TLExpr::At {
1114 nominal,
1115 formula: Box::new(formula),
1116 })
1117 }
1118 TAG_SOMEWHERE => read_unary_expr(r, |e| TLExpr::Somewhere { formula: e }),
1119 TAG_EVERYWHERE => read_unary_expr(r, |e| TLExpr::Everywhere { formula: e }),
1120 TAG_ALL_DIFFERENT => {
1121 let variables = read_string_vec(r)?;
1122 Ok(TLExpr::AllDifferent { variables })
1123 }
1124 TAG_GLOBAL_CARDINALITY => {
1125 let variables = read_string_vec(r)?;
1126 let val_count = r.read_u32()? as usize;
1127 let mut values = Vec::with_capacity(val_count);
1128 for _ in 0..val_count {
1129 values.push(read_expr_bin(r)?);
1130 }
1131 let min_occurrences = read_usize_vec(r)?;
1132 let max_occurrences = read_usize_vec(r)?;
1133 Ok(TLExpr::GlobalCardinality {
1134 variables,
1135 values,
1136 min_occurrences,
1137 max_occurrences,
1138 })
1139 }
1140 TAG_ABDUCIBLE => {
1141 let name = r.read_str()?;
1142 let cost = r.read_f64()?;
1143 Ok(TLExpr::Abducible { name, cost })
1144 }
1145 TAG_EXPLAIN => read_unary_expr(r, |e| TLExpr::Explain { formula: e }),
1146 TAG_SYMBOL_LITERAL => {
1147 let s = r.read_str()?;
1148 Ok(TLExpr::SymbolLiteral(s))
1149 }
1150 TAG_MATCH => {
1151 let scrutinee = read_expr_bin(r)?;
1152 let arm_count = r.read_u32()? as usize;
1153 let mut arms = Vec::with_capacity(arm_count);
1154 for _ in 0..arm_count {
1155 let pat_tag = r.read_u8()?;
1156 let pat = match pat_tag {
1157 TAG_PATTERN_CONST_SYMBOL => {
1158 let s = r.read_str()?;
1159 crate::pattern::MatchPattern::ConstSymbol(s)
1160 }
1161 TAG_PATTERN_CONST_NUMBER => {
1162 let n = r.read_f64()?;
1163 crate::pattern::MatchPattern::ConstNumber(n)
1164 }
1165 TAG_PATTERN_WILDCARD => crate::pattern::MatchPattern::Wildcard,
1166 other => {
1167 return Err(ExprSerializeError::UnknownVariant(format!(
1168 "pattern tag {other}"
1169 )));
1170 }
1171 };
1172 let body = read_expr_bin(r)?;
1173 arms.push((pat, Box::new(body)));
1174 }
1175 Ok(TLExpr::Match {
1176 scrutinee: Box::new(scrutinee),
1177 arms,
1178 })
1179 }
1180 _ => Err(ExprSerializeError::UnknownVariant(format!(
1181 "binary tag {tag}"
1182 ))),
1183 }
1184}
1185
1186fn read_unary_expr(
1187 r: &mut BinReader<'_>,
1188 ctor: impl FnOnce(Box<TLExpr>) -> TLExpr,
1189) -> Result<TLExpr, ExprSerializeError> {
1190 let e = read_expr_bin(r)?;
1191 Ok(ctor(Box::new(e)))
1192}
1193
1194fn read_binary_expr(
1195 r: &mut BinReader<'_>,
1196 ctor: impl FnOnce(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
1197) -> Result<TLExpr, ExprSerializeError> {
1198 let a = read_expr_bin(r)?;
1199 let b = read_expr_bin(r)?;
1200 Ok(ctor(Box::new(a), Box::new(b)))
1201}
1202
1203fn aggregate_op_tag(op: &AggregateOp) -> u8 {
1204 match op {
1205 AggregateOp::Count => AGG_COUNT,
1206 AggregateOp::Sum => AGG_SUM,
1207 AggregateOp::Average => AGG_AVERAGE,
1208 AggregateOp::Max => AGG_MAX,
1209 AggregateOp::Min => AGG_MIN,
1210 AggregateOp::Product => AGG_PRODUCT,
1211 AggregateOp::Any => AGG_ANY,
1212 AggregateOp::All => AGG_ALL,
1213 }
1214}
1215
1216fn read_aggregate_op_tag(tag: u8) -> Result<AggregateOp, ExprSerializeError> {
1217 match tag {
1218 AGG_COUNT => Ok(AggregateOp::Count),
1219 AGG_SUM => Ok(AggregateOp::Sum),
1220 AGG_AVERAGE => Ok(AggregateOp::Average),
1221 AGG_MAX => Ok(AggregateOp::Max),
1222 AGG_MIN => Ok(AggregateOp::Min),
1223 AGG_PRODUCT => Ok(AggregateOp::Product),
1224 AGG_ANY => Ok(AggregateOp::Any),
1225 AGG_ALL => Ok(AggregateOp::All),
1226 _ => Err(ExprSerializeError::UnknownVariant(format!(
1227 "AggregateOp tag {tag}"
1228 ))),
1229 }
1230}
1231
1232fn tnorm_kind_tag(kind: &TNormKind) -> u8 {
1233 match kind {
1234 TNormKind::Minimum => TNORM_MINIMUM,
1235 TNormKind::Product => TNORM_PRODUCT,
1236 TNormKind::Lukasiewicz => TNORM_LUKASIEWICZ,
1237 TNormKind::Drastic => TNORM_DRASTIC,
1238 TNormKind::NilpotentMinimum => TNORM_NILPOTENT_MINIMUM,
1239 TNormKind::Hamacher => TNORM_HAMACHER,
1240 }
1241}
1242
1243fn read_tnorm_kind_tag(tag: u8) -> Result<TNormKind, ExprSerializeError> {
1244 match tag {
1245 TNORM_MINIMUM => Ok(TNormKind::Minimum),
1246 TNORM_PRODUCT => Ok(TNormKind::Product),
1247 TNORM_LUKASIEWICZ => Ok(TNormKind::Lukasiewicz),
1248 TNORM_DRASTIC => Ok(TNormKind::Drastic),
1249 TNORM_NILPOTENT_MINIMUM => Ok(TNormKind::NilpotentMinimum),
1250 TNORM_HAMACHER => Ok(TNormKind::Hamacher),
1251 _ => Err(ExprSerializeError::UnknownVariant(format!(
1252 "TNormKind tag {tag}"
1253 ))),
1254 }
1255}
1256
1257fn tconorm_kind_tag(kind: &TCoNormKind) -> u8 {
1258 match kind {
1259 TCoNormKind::Maximum => TCONORM_MAXIMUM,
1260 TCoNormKind::ProbabilisticSum => TCONORM_PROBABILISTIC_SUM,
1261 TCoNormKind::BoundedSum => TCONORM_BOUNDED_SUM,
1262 TCoNormKind::Drastic => TCONORM_DRASTIC,
1263 TCoNormKind::NilpotentMaximum => TCONORM_NILPOTENT_MAXIMUM,
1264 TCoNormKind::Hamacher => TCONORM_HAMACHER,
1265 }
1266}
1267
1268fn read_tconorm_kind_tag(tag: u8) -> Result<TCoNormKind, ExprSerializeError> {
1269 match tag {
1270 TCONORM_MAXIMUM => Ok(TCoNormKind::Maximum),
1271 TCONORM_PROBABILISTIC_SUM => Ok(TCoNormKind::ProbabilisticSum),
1272 TCONORM_BOUNDED_SUM => Ok(TCoNormKind::BoundedSum),
1273 TCONORM_DRASTIC => Ok(TCoNormKind::Drastic),
1274 TCONORM_NILPOTENT_MAXIMUM => Ok(TCoNormKind::NilpotentMaximum),
1275 TCONORM_HAMACHER => Ok(TCoNormKind::Hamacher),
1276 _ => Err(ExprSerializeError::UnknownVariant(format!(
1277 "TCoNormKind tag {tag}"
1278 ))),
1279 }
1280}
1281
1282fn write_fuzzy_neg_kind_bin(kind: &FuzzyNegationKind, w: &mut BinWriter) {
1283 match kind {
1284 FuzzyNegationKind::Standard => w.write_u8(FNEG_STANDARD),
1285 FuzzyNegationKind::Sugeno { lambda } => {
1286 w.write_u8(FNEG_SUGENO);
1287 w.write_i32(*lambda);
1288 }
1289 FuzzyNegationKind::Yager { w: wval } => {
1290 w.write_u8(FNEG_YAGER);
1291 w.write_u32(*wval);
1292 }
1293 }
1294}
1295
1296fn read_fuzzy_neg_kind_bin(r: &mut BinReader<'_>) -> Result<FuzzyNegationKind, ExprSerializeError> {
1297 let tag = r.read_u8()?;
1298 match tag {
1299 FNEG_STANDARD => Ok(FuzzyNegationKind::Standard),
1300 FNEG_SUGENO => {
1301 let lambda = r.read_i32()?;
1302 Ok(FuzzyNegationKind::Sugeno { lambda })
1303 }
1304 FNEG_YAGER => {
1305 let w = r.read_u32()?;
1306 Ok(FuzzyNegationKind::Yager { w })
1307 }
1308 _ => Err(ExprSerializeError::UnknownVariant(format!(
1309 "FuzzyNegationKind tag {tag}"
1310 ))),
1311 }
1312}
1313
1314fn fuzzy_imp_kind_tag(kind: &FuzzyImplicationKind) -> u8 {
1315 match kind {
1316 FuzzyImplicationKind::Godel => FIMP_GODEL,
1317 FuzzyImplicationKind::Lukasiewicz => FIMP_LUKASIEWICZ,
1318 FuzzyImplicationKind::Reichenbach => FIMP_REICHENBACH,
1319 FuzzyImplicationKind::KleeneDienes => FIMP_KLEENE_DIENES,
1320 FuzzyImplicationKind::Rescher => FIMP_RESCHER,
1321 FuzzyImplicationKind::Goguen => FIMP_GOGUEN,
1322 }
1323}
1324
1325fn read_fuzzy_imp_kind_tag(tag: u8) -> Result<FuzzyImplicationKind, ExprSerializeError> {
1326 match tag {
1327 FIMP_GODEL => Ok(FuzzyImplicationKind::Godel),
1328 FIMP_LUKASIEWICZ => Ok(FuzzyImplicationKind::Lukasiewicz),
1329 FIMP_REICHENBACH => Ok(FuzzyImplicationKind::Reichenbach),
1330 FIMP_KLEENE_DIENES => Ok(FuzzyImplicationKind::KleeneDienes),
1331 FIMP_RESCHER => Ok(FuzzyImplicationKind::Rescher),
1332 FIMP_GOGUEN => Ok(FuzzyImplicationKind::Goguen),
1333 _ => Err(ExprSerializeError::UnknownVariant(format!(
1334 "FuzzyImplicationKind tag {tag}"
1335 ))),
1336 }
1337}
1338
1339pub fn graph_to_binary(graph: &EinsumGraph) -> Vec<u8> {
1345 let mut w = BinWriter::new();
1346 w.write_magic(&TLGR_MAGIC);
1347 w.write_u32(FORMAT_VER);
1348
1349 write_string_vec(&graph.tensors, &mut w);
1351
1352 w.write_u32(graph.nodes.len() as u32);
1354 for node in &graph.nodes {
1355 write_optype_bin(&node.op, &mut w);
1356 write_usize_vec(&node.inputs, &mut w);
1357 write_usize_vec(&node.outputs, &mut w);
1358 match &node.metadata {
1360 Some(meta) => {
1361 w.write_u8(1);
1362 write_optional_string(&meta.name, &mut w);
1363 }
1364 None => w.write_u8(0),
1365 }
1366 }
1367
1368 write_usize_vec(&graph.inputs, &mut w);
1370 write_usize_vec(&graph.outputs, &mut w);
1372
1373 w.write_u32(graph.tensor_metadata.len() as u32);
1375 let mut keys: Vec<&usize> = graph.tensor_metadata.keys().collect();
1377 keys.sort();
1378 for &key in &keys {
1379 if let Some(meta) = graph.tensor_metadata.get(key) {
1380 w.write_u64(*key as u64);
1381 write_optional_string(&meta.name, &mut w);
1382 }
1383 }
1384
1385 w.into_bytes()
1386}
1387
1388pub fn graph_from_binary(bytes: &[u8]) -> Result<EinsumGraph, ExprSerializeError> {
1390 let mut r = BinReader::new(bytes);
1391 let magic = r.read_magic()?;
1392 if magic != TLGR_MAGIC {
1393 return Err(ExprSerializeError::InvalidMagic);
1394 }
1395 let version = r.read_u32()?;
1396 if version != FORMAT_VER {
1397 return Err(ExprSerializeError::VersionMismatch {
1398 expected: FORMAT_VER,
1399 got: version,
1400 });
1401 }
1402
1403 let tensors = read_string_vec(&mut r)?;
1404 let node_count = r.read_u32()? as usize;
1405 let mut nodes = Vec::with_capacity(node_count);
1406 for _ in 0..node_count {
1407 let op = read_optype_bin(&mut r)?;
1408 let inputs = read_usize_vec(&mut r)?;
1409 let outputs = read_usize_vec(&mut r)?;
1410 let has_meta = r.read_u8()?;
1411 let metadata = if has_meta != 0 {
1412 let name = read_optional_string(&mut r)?;
1413 let mut meta = Metadata::new();
1414 if let Some(n) = name {
1415 meta = meta.with_name(n);
1416 }
1417 Some(meta)
1418 } else {
1419 None
1420 };
1421 nodes.push(EinsumNode {
1422 op,
1423 inputs,
1424 outputs,
1425 metadata,
1426 });
1427 }
1428
1429 let inputs = read_usize_vec(&mut r)?;
1430 let outputs = read_usize_vec(&mut r)?;
1431
1432 let meta_count = r.read_u32()? as usize;
1433 let mut tensor_metadata = HashMap::new();
1434 for _ in 0..meta_count {
1435 let key = r.read_u64()? as usize;
1436 let name_opt = read_optional_string(&mut r)?;
1437 let mut meta = Metadata::new();
1438 if let Some(n) = name_opt {
1439 meta = meta.with_name(n);
1440 }
1441 tensor_metadata.insert(key, meta);
1442 }
1443
1444 Ok(EinsumGraph {
1445 tensors,
1446 nodes,
1447 inputs,
1448 outputs,
1449 tensor_metadata,
1450 })
1451}
1452
1453fn write_optype_bin(op: &OpType, w: &mut BinWriter) {
1454 match op {
1455 OpType::Einsum { spec } => {
1456 w.write_u8(OP_EINSUM);
1457 w.write_string(spec);
1458 }
1459 OpType::ElemUnary { op: op_name } => {
1460 w.write_u8(OP_ELEM_UNARY);
1461 w.write_string(op_name);
1462 }
1463 OpType::ElemBinary { op: op_name } => {
1464 w.write_u8(OP_ELEM_BINARY);
1465 w.write_string(op_name);
1466 }
1467 OpType::Reduce { op: op_name, axes } => {
1468 w.write_u8(OP_REDUCE);
1469 w.write_string(op_name);
1470 write_usize_vec(axes, w);
1471 }
1472 }
1473}
1474
1475fn read_optype_bin(r: &mut BinReader<'_>) -> Result<OpType, ExprSerializeError> {
1476 let tag = r.read_u8()?;
1477 match tag {
1478 OP_EINSUM => {
1479 let spec = r.read_str()?;
1480 Ok(OpType::Einsum { spec })
1481 }
1482 OP_ELEM_UNARY => {
1483 let op_name = r.read_str()?;
1484 Ok(OpType::ElemUnary { op: op_name })
1485 }
1486 OP_ELEM_BINARY => {
1487 let op_name = r.read_str()?;
1488 Ok(OpType::ElemBinary { op: op_name })
1489 }
1490 OP_REDUCE => {
1491 let op_name = r.read_str()?;
1492 let axes = read_usize_vec(r)?;
1493 Ok(OpType::Reduce { op: op_name, axes })
1494 }
1495 _ => Err(ExprSerializeError::UnknownVariant(format!(
1496 "OpType tag {tag}"
1497 ))),
1498 }
1499}