1use serde::{Deserialize, Serialize};
32use std::collections::{HashMap, HashSet};
33use std::fmt;
34
35use crate::{IrError, ParametricType, TLExpr, Term};
36
/// A predicate constraining a single bound variable.
///
/// Unlike [`RefinementType`], a `Refinement` carries no base type: it is just
/// the bound variable name plus the predicate over it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Refinement {
    /// Name of the variable bound by this refinement.
    pub var_name: String,
    /// Predicate that constrains `var_name`.
    pub predicate: TLExpr,
}
45
46impl Refinement {
47 pub fn new(var_name: impl Into<String>, predicate: TLExpr) -> Self {
48 Refinement {
49 var_name: var_name.into(),
50 predicate,
51 }
52 }
53
54 pub fn free_vars(&self) -> HashSet<String> {
56 let mut vars = self.predicate.free_vars();
57 vars.remove(&self.var_name);
58 vars
59 }
60
61 pub fn substitute(&self, subst: &HashMap<String, Term>) -> Refinement {
63 let mut filtered_subst = subst.clone();
65 filtered_subst.remove(&self.var_name);
66
67 Refinement {
68 var_name: self.var_name.clone(),
69 predicate: self.predicate.clone(), }
71 }
72
73 pub fn simplify(&self) -> Refinement {
75 use crate::optimize_expr;
76
77 Refinement {
78 var_name: self.var_name.clone(),
79 predicate: optimize_expr(&self.predicate),
80 }
81 }
82
83 pub fn implies(&self, other: &Refinement) -> bool {
85 if self.var_name != other.var_name {
87 return false;
88 }
89
90 self.predicate == other.predicate
92 }
93}
94
95impl fmt::Display for Refinement {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 write!(f, "{{{}: | {}}}", self.var_name, self.predicate)
98 }
99}
100
/// A base type narrowed by a predicate over a bound variable.
///
/// Displayed as `{var_name: base_type | refinement}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RefinementType {
    /// Name of the variable bound by the refinement.
    pub var_name: String,
    /// The underlying, unrefined type.
    pub base_type: ParametricType,
    /// Predicate that constrains `var_name`.
    pub refinement: TLExpr,
}
111
112impl RefinementType {
113 pub fn new(
114 var_name: impl Into<String>,
115 base_type: impl Into<String>,
116 refinement: TLExpr,
117 ) -> Self {
118 RefinementType {
119 var_name: var_name.into(),
120 base_type: ParametricType::concrete(base_type),
121 refinement,
122 }
123 }
124
125 pub fn from_parametric(
127 var_name: impl Into<String>,
128 base_type: ParametricType,
129 refinement: TLExpr,
130 ) -> Self {
131 RefinementType {
132 var_name: var_name.into(),
133 base_type,
134 refinement,
135 }
136 }
137
138 pub fn positive_int(var_name: impl Into<String>) -> Self {
140 let var_name = var_name.into();
141 RefinementType::new(
142 var_name.clone(),
143 "Int",
144 TLExpr::gt(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
145 )
146 }
147
148 pub fn nat(var_name: impl Into<String>) -> Self {
150 let var_name = var_name.into();
151 RefinementType::new(
152 var_name.clone(),
153 "Int",
154 TLExpr::gte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
155 )
156 }
157
158 pub fn probability(var_name: impl Into<String>) -> Self {
160 let var_name = var_name.into();
161 RefinementType::new(
162 var_name.clone(),
163 "Float",
164 TLExpr::and(
165 TLExpr::gte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
166 TLExpr::lte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(1.0)),
167 ),
168 )
169 }
170
171 pub fn non_empty_vec(var_name: impl Into<String>, element_type: impl Into<String>) -> Self {
173 let var_name = var_name.into();
174 use crate::TypeConstructor;
175
176 let elem_type = ParametricType::concrete(element_type);
177 let vec_type = ParametricType::apply(TypeConstructor::List, vec![elem_type]);
178
179 RefinementType::from_parametric(
180 var_name.clone(),
181 vec_type,
182 TLExpr::gt(TLExpr::pred("length", vec![]), TLExpr::constant(0.0)),
183 )
184 }
185
186 pub fn free_vars(&self) -> HashSet<String> {
188 let mut vars = self.refinement.free_vars();
189 vars.remove(&self.var_name);
190 vars
191 }
192
193 pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
195 if self.base_type != other.base_type {
197 return false;
198 }
199
200 if self.var_name != other.var_name {
202 return false;
203 }
204
205 self.refinement == other.refinement
208 }
209
210 pub fn weaken(&self) -> RefinementType {
212 RefinementType {
214 var_name: self.var_name.clone(),
215 base_type: self.base_type.clone(),
216 refinement: TLExpr::constant(1.0), }
218 }
219
220 pub fn strengthen(&self, additional: TLExpr) -> RefinementType {
222 RefinementType {
223 var_name: self.var_name.clone(),
224 base_type: self.base_type.clone(),
225 refinement: TLExpr::and(self.refinement.clone(), additional),
226 }
227 }
228}
229
230impl fmt::Display for RefinementType {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 write!(
233 f,
234 "{{{}: {} | {}}}",
235 self.var_name, self.base_type, self.refinement
236 )
237 }
238}
239
/// Typing context for refinement checking: the refinement type bound to each
/// name, plus logical facts recorded for entailment checks.
#[derive(Clone, Debug, Default)]
pub struct RefinementContext {
    /// Refinement type bound to each name.
    bindings: HashMap<String, RefinementType>,
    /// Logical facts recorded for `check_refinement`.
    assumptions: Vec<TLExpr>,
}
248
249impl RefinementContext {
250 pub fn new() -> Self {
251 Self::default()
252 }
253
254 pub fn bind(&mut self, name: impl Into<String>, typ: RefinementType) {
256 let name = name.into();
257
258 let assumption = typ.refinement.clone();
260 self.assumptions.push(assumption);
261
262 self.bindings.insert(name, typ);
263 }
264
265 pub fn get_type(&self, name: &str) -> Option<&RefinementType> {
267 self.bindings.get(name)
268 }
269
270 pub fn assume(&mut self, fact: TLExpr) {
272 self.assumptions.push(fact);
273 }
274
275 pub fn check_refinement(&self, refinement: &TLExpr) -> bool {
277 self.assumptions.contains(refinement)
280 }
281
282 pub fn verify(&self, _value: &Term, _typ: &RefinementType) -> Result<(), IrError> {
284 Ok(())
287 }
288}
289
/// State for a simplified liquid-type inference pass: each named unknown
/// refinement is paired with a list of candidate predicates.
#[derive(Clone, Debug)]
pub struct LiquidTypeInference {
    /// Refinement context, exposed read-only through `context()`.
    context: RefinementContext,
    /// Candidate predicates per unknown, in the order passed to `add_unknown`.
    unknowns: HashMap<String, Vec<TLExpr>>,
}
297
298impl LiquidTypeInference {
299 pub fn new() -> Self {
300 LiquidTypeInference {
301 context: RefinementContext::new(),
302 unknowns: HashMap::new(),
303 }
304 }
305
306 pub fn add_unknown(&mut self, name: impl Into<String>, candidates: Vec<TLExpr>) {
308 self.unknowns.insert(name.into(), candidates);
309 }
310
311 pub fn infer(&mut self) -> HashMap<String, TLExpr> {
313 let mut inferred = HashMap::new();
315
316 for (name, candidates) in &self.unknowns {
317 if let Some(refinement) = candidates.first() {
319 inferred.insert(name.clone(), refinement.clone());
320 }
321 }
322
323 inferred
324 }
325
326 pub fn context(&self) -> &RefinementContext {
328 &self.context
329 }
330}
331
332impl Default for LiquidTypeInference {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// References a variable the way the constructors in this module do:
    /// as a nullary predicate.
    fn var(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![])
    }

    #[test]
    fn test_refinement_creation() {
        let positive = TLExpr::gt(var("x"), TLExpr::constant(0.0));

        let r = Refinement::new("x", positive.clone());
        assert_eq!(r.var_name, "x");
        assert_eq!(r.predicate, positive);
    }

    #[test]
    fn test_refinement_type_positive_int() {
        let ty = RefinementType::positive_int("x");
        assert_eq!(ty.var_name, "x");
        assert_eq!(ty.base_type, ParametricType::concrete("Int"));
        assert!(ty.free_vars().is_empty());
    }

    #[test]
    fn test_refinement_type_nat() {
        assert_eq!(
            RefinementType::nat("n").to_string(),
            "{n: Int | (n() ≥ 0)}"
        );
    }

    #[test]
    fn test_refinement_type_probability() {
        let rendered = RefinementType::probability("p").to_string();
        assert!(rendered.contains("Float"));
        // Accept either the Unicode or the ASCII spelling of the comparisons.
        assert!(["≥", ">="].iter().any(|op| rendered.contains(*op)));
        assert!(["≤", "<="].iter().any(|op| rendered.contains(*op)));
    }

    #[test]
    fn test_refinement_context() {
        let mut ctx = RefinementContext::new();
        let ty = RefinementType::positive_int("x");

        ctx.bind("x", ty.clone());
        let bound = ctx.get_type("x").expect("x should be bound");
        assert_eq!(bound, &ty);
    }

    #[test]
    fn test_refinement_type_weaken() {
        let original = RefinementType::positive_int("x");
        let loose = original.weaken();

        assert_eq!(loose.base_type, original.base_type);
        assert_eq!(loose.refinement, TLExpr::constant(1.0));
    }

    #[test]
    fn test_refinement_type_strengthen() {
        let original = RefinementType::positive_int("x");
        let upper_bound = TLExpr::lt(var("x"), TLExpr::constant(100.0));

        let tightened = original.strengthen(upper_bound);

        match &tightened.refinement {
            TLExpr::And(lhs, rhs) => {
                assert!(**lhs == original.refinement || **rhs == original.refinement);
            }
            _ => panic!("Expected AND expression"),
        }
    }

    #[test]
    fn test_liquid_type_inference() {
        let mut engine = LiquidTypeInference::new();
        engine.add_unknown(
            "x_refinement",
            vec![
                TLExpr::gt(var("x"), TLExpr::constant(0.0)),
                TLExpr::gte(var("x"), TLExpr::constant(0.0)),
            ],
        );

        assert!(engine.infer().contains_key("x_refinement"));
    }

    #[test]
    fn test_refinement_free_vars() {
        let pred = TLExpr::and(
            TLExpr::gt(var("x"), TLExpr::constant(0.0)),
            TLExpr::lt(var("x"), var("y")),
        );

        let fv = Refinement::new("x", pred).free_vars();

        assert!(!fv.contains("x"));
        // `y` appears only as a nullary predicate; whether `TLExpr::free_vars`
        // reports it is left open by this test.
        assert!(fv.contains("y") || fv.is_empty());
    }

    #[test]
    fn test_non_empty_vec() {
        let rendered = RefinementType::non_empty_vec("v", "Int").to_string();
        assert!(rendered.contains("List"));
        assert!(rendered.contains("length"));
    }

    #[test]
    fn test_refinement_context_assumptions() {
        let mut ctx = RefinementContext::new();
        let fact = TLExpr::gt(var("x"), TLExpr::constant(0.0));

        ctx.assume(fact.clone());
        assert!(ctx.check_refinement(&fact));
    }

    #[test]
    fn test_refinement_type_subtyping() {
        // The syntactic subtyping check cannot relate `x > 0` to `x >= 0`,
        // so it must not report a subtype relationship.
        let positive = RefinementType::positive_int("x");
        let natural = RefinementType::nat("x");

        assert!(!positive.is_subtype_of(&natural));
    }
}