Skip to main content

z3/
sort.rs

1use std::convert::TryInto;
2use std::ffi::CStr;
3use std::fmt;
4use std::ptr::NonNull;
5use z3_sys::*;
6
7use crate::{Context, FuncDecl, Sort, SortDiffers, Symbol};
8
9impl Sort {
10    pub(crate) unsafe fn wrap(ctx: &Context, z3_sort: Z3_sort) -> Sort {
11        unsafe {
12            Z3_inc_ref(ctx.z3_ctx.0, Z3_sort_to_ast(ctx.z3_ctx.0, z3_sort).unwrap());
13        }
14        Sort {
15            ctx: ctx.clone(),
16            z3_sort,
17        }
18    }
19
20    pub fn get_z3_sort(&self) -> Z3_sort {
21        self.z3_sort
22    }
23
24    pub fn uninterpreted(name: Symbol) -> Sort {
25        let ctx = &Context::thread_local();
26
27        unsafe {
28            Self::wrap(
29                ctx,
30                Z3_mk_uninterpreted_sort(ctx.z3_ctx.0, name.as_z3_symbol()).unwrap(),
31            )
32        }
33    }
34
35    pub fn bool() -> Sort {
36        unsafe {
37            let ctx = &Context::thread_local();
38            Self::wrap(ctx, Z3_mk_bool_sort(ctx.z3_ctx.0).unwrap())
39        }
40    }
41
42    pub fn int() -> Sort {
43        unsafe {
44            let ctx = &Context::thread_local();
45            Self::wrap(ctx, Z3_mk_int_sort(ctx.z3_ctx.0).unwrap())
46        }
47    }
48
49    pub fn real() -> Sort {
50        unsafe {
51            let ctx = &Context::thread_local();
52            Self::wrap(ctx, Z3_mk_real_sort(ctx.z3_ctx.0).unwrap())
53        }
54    }
55
56    pub fn float(ebits: u32, sbits: u32) -> Sort {
57        unsafe {
58            let ctx = &Context::thread_local();
59            Self::wrap(ctx, Z3_mk_fpa_sort(ctx.z3_ctx.0, ebits, sbits).unwrap())
60        }
61    }
62
63    pub fn float32() -> Sort {
64        unsafe {
65            let ctx = &Context::thread_local();
66            Self::wrap(ctx, Z3_mk_fpa_sort(ctx.z3_ctx.0, 8, 24).unwrap())
67        }
68    }
69
70    pub fn double() -> Sort {
71        unsafe {
72            let ctx = &Context::thread_local();
73            Self::wrap(ctx, Z3_mk_fpa_sort(ctx.z3_ctx.0, 11, 53).unwrap())
74        }
75    }
76
77    pub fn string() -> Sort {
78        unsafe {
79            let ctx = &Context::thread_local();
80            Self::wrap(ctx, Z3_mk_string_sort(ctx.z3_ctx.0).unwrap())
81        }
82    }
83
84    pub fn bitvector(sz: u32) -> Sort {
85        let ctx = &Context::thread_local();
86
87        unsafe {
88            Self::wrap(
89                ctx,
90                Z3_mk_bv_sort(ctx.z3_ctx.0, sz as ::std::os::raw::c_uint).unwrap(),
91            )
92        }
93    }
94
95    pub fn array(domain: &Sort, range: &Sort) -> Sort {
96        let ctx = &Context::thread_local();
97
98        unsafe {
99            Self::wrap(
100                ctx,
101                Z3_mk_array_sort(ctx.z3_ctx.0, domain.z3_sort, range.z3_sort).unwrap(),
102            )
103        }
104    }
105
106    pub fn set(elt: &Sort) -> Sort {
107        let ctx = &Context::thread_local();
108
109        unsafe { Self::wrap(ctx, Z3_mk_set_sort(ctx.z3_ctx.0, elt.z3_sort).unwrap()) }
110    }
111
112    pub fn seq(elt: &Sort) -> Sort {
113        let ctx = &Context::thread_local();
114
115        unsafe { Self::wrap(ctx, Z3_mk_seq_sort(ctx.z3_ctx.0, elt.z3_sort).unwrap()) }
116    }
117
118    /// Create an enumeration sort.
119    ///
120    /// Creates a Z3 enumeration sort with the given `name`.
121    /// The enum variants will have the names in `enum_names`.
122    /// Three things are returned:
123    /// - the created `Sort`,
124    /// - constants to create the variants,
125    /// - and testers to check if a value is equal to a variant.
126    ///
127    /// # Examples
128    /// ```
129    /// # use z3::{Config, Context, SatResult, Solver, Sort, Symbol};
130    /// # let cfg = Config::new();
131    /// # let solver = Solver::new();
132    /// let (colors, color_consts, color_testers) = Sort::enumeration(
133    ///     "Color".into(),
134    ///     &[
135    ///         "Red".into(),
136    ///         "Green".into(),
137    ///         "Blue".into(),
138    ///     ],
139    /// );
140    ///
141    /// let red_const = color_consts[0].apply(&[]);
142    /// let red_tester = &color_testers[0];
143    /// let eq = red_tester.apply(&[&red_const]);
144    ///
145    /// assert_eq!(solver.check(), SatResult::Sat);
146    /// let model = solver.get_model().unwrap();;
147    ///
148    /// assert!(model.eval(&eq, true).unwrap().as_bool().unwrap().as_bool().unwrap());
149    /// ```
150    pub fn enumeration(
151        name: Symbol,
152        enum_names: &[Symbol],
153    ) -> (Sort, Vec<FuncDecl>, Vec<FuncDecl>) {
154        let ctx = &Context::thread_local();
155        let enum_names: Vec<_> = enum_names.iter().map(|s| s.as_z3_symbol()).collect();
156        let mut enum_consts = vec![std::ptr::null_mut(); enum_names.len()];
157        let mut enum_testers = vec![std::ptr::null_mut(); enum_names.len()];
158
159        let sort = unsafe {
160            Self::wrap(
161                ctx,
162                Z3_mk_enumeration_sort(
163                    ctx.z3_ctx.0,
164                    name.as_z3_symbol(),
165                    enum_names.len().try_into().unwrap(),
166                    enum_names.as_ptr(),
167                    enum_consts.as_mut_ptr(),
168                    enum_testers.as_mut_ptr(),
169                )
170                .unwrap(),
171            )
172        };
173
174        // increase ref counts
175        for i in &enum_consts {
176            unsafe {
177                Z3_inc_ref(
178                    ctx.z3_ctx.0,
179                    Z3_func_decl_to_ast(ctx.z3_ctx.0, NonNull::new(*i).unwrap()).unwrap(),
180                );
181            }
182        }
183        for i in &enum_testers {
184            unsafe {
185                Z3_inc_ref(
186                    ctx.z3_ctx.0,
187                    Z3_func_decl_to_ast(ctx.z3_ctx.0, NonNull::new(*i).unwrap()).unwrap(),
188                );
189            }
190        }
191
192        // convert to Rust types
193        let enum_consts: Vec<_> = enum_consts
194            .into_iter()
195            .map(|z3_func_decl| unsafe { FuncDecl::wrap(ctx, NonNull::new(z3_func_decl).unwrap()) })
196            .collect();
197        let enum_testers: Vec<_> = enum_testers
198            .into_iter()
199            .map(|z3_func_decl| unsafe { FuncDecl::wrap(ctx, NonNull::new(z3_func_decl).unwrap()) })
200            .collect();
201
202        (sort, enum_consts, enum_testers)
203    }
204
205    pub fn kind(&self) -> SortKind {
206        unsafe { Z3_get_sort_kind(self.ctx.z3_ctx.0, self.z3_sort) }
207    }
208
209    /// Returns `Some(e)` where `e` is the number of exponent bits if the sort
210    /// is a `FloatingPoint` and `None` otherwise.
211    pub fn float_exponent_size(&self) -> Option<u32> {
212        if self.kind() == SortKind::FloatingPoint {
213            Some(unsafe { Z3_fpa_get_ebits(self.ctx.z3_ctx.0, self.z3_sort) })
214        } else {
215            None
216        }
217    }
218
219    /// Returns `Some(s)` where `s` is the number of significand bits if the sort
220    /// is a `FloatingPoint` and `None` otherwise.
221    pub fn float_significand_size(&self) -> Option<u32> {
222        if self.kind() == SortKind::FloatingPoint {
223            Some(unsafe { Z3_fpa_get_sbits(self.ctx.z3_ctx.0, self.z3_sort) })
224        } else {
225            None
226        }
227    }
228
229    /// Return if this Sort is for an `Array` or a `Set`.
230    ///
231    /// # Examples
232    /// ```
233    /// # use z3::{Config, Context, Sort, ast::Ast, ast::Int, ast::Bool};
234    /// let bool_sort = Sort::bool();
235    /// let int_sort = Sort::int();
236    /// let array_sort = Sort::array(&int_sort, &bool_sort);
237    /// let set_sort = Sort::set(&int_sort);
238    /// assert!(array_sort.is_array());
239    /// assert!(set_sort.is_array());
240    /// assert!(!int_sort.is_array());
241    /// assert!(!bool_sort.is_array());
242    /// ```
243    pub fn is_array(&self) -> bool {
244        self.kind() == SortKind::Array
245    }
246
247    /// Return the `Sort` of the domain for `Array`s of this `Sort`.
248    ///
249    /// If this `Sort` is an `Array` or `Set`, it has a domain sort, so return it.
250    /// If this is not an `Array` or `Set` `Sort`, return `None`.
251    /// # Examples
252    /// ```
253    /// # use z3::{Config, Context, Sort, ast::Ast, ast::Int, ast::Bool};
254    /// let bool_sort = Sort::bool();
255    /// let int_sort = Sort::int();
256    /// let array_sort = Sort::array(&int_sort, &bool_sort);
257    /// let set_sort = Sort::set(&int_sort);
258    /// assert_eq!(array_sort.array_domain().unwrap(), int_sort);
259    /// assert_eq!(set_sort.array_domain().unwrap(), int_sort);
260    /// assert!(int_sort.array_domain().is_none());
261    /// assert!(bool_sort.array_domain().is_none());
262    /// ```
263    pub fn array_domain(&self) -> Option<Sort> {
264        if self.is_array() {
265            unsafe {
266                let domain_sort = Z3_get_array_sort_domain(self.ctx.z3_ctx.0, self.z3_sort)?;
267                Some(Self::wrap(&self.ctx, domain_sort))
268            }
269        } else {
270            None
271        }
272    }
273
274    /// Return the `Sort` of the range for `Array`s of this `Sort`.
275    ///
276    /// If this `Sort` is an `Array` it has a range sort, so return it.
277    /// If this `Sort` is a `Set`, it has an implied range sort of `Bool`.
278    /// If this is not an `Array` or `Set` `Sort`, return `None`.
279    /// # Examples
280    /// ```
281    /// # use z3::{Config, Context, Sort, ast::Ast, ast::Int, ast::Bool};
282    /// let bool_sort = Sort::bool();
283    /// let int_sort = Sort::int();
284    /// let array_sort = Sort::array(&int_sort, &bool_sort);
285    /// let set_sort = Sort::set(&int_sort);
286    /// assert_eq!(array_sort.array_range().unwrap(), bool_sort);
287    /// assert_eq!(set_sort.array_range().unwrap(), bool_sort);
288    /// assert!(int_sort.array_range().is_none());
289    /// assert!(bool_sort.array_range().is_none());
290    /// ```
291    pub fn array_range(&self) -> Option<Sort> {
292        if self.is_array() {
293            unsafe {
294                let range_sort = Z3_get_array_sort_range(self.ctx.z3_ctx.0, self.z3_sort)?;
295                Some(Self::wrap(&self.ctx, range_sort))
296            }
297        } else {
298            None
299        }
300    }
301}
302
303impl Clone for Sort {
304    fn clone(&self) -> Self {
305        unsafe { Self::wrap(&self.ctx, self.z3_sort) }
306    }
307}
308
309impl fmt::Display for Sort {
310    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
311        let p = unsafe { Z3_sort_to_string(self.ctx.z3_ctx.0, self.z3_sort) };
312        if p.is_null() {
313            return Result::Err(fmt::Error);
314        }
315        match unsafe { CStr::from_ptr(p) }.to_str() {
316            Ok(s) => write!(f, "{s}"),
317            Err(_) => Result::Err(fmt::Error),
318        }
319    }
320}
321
322impl fmt::Debug for Sort {
323    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
324        <Self as fmt::Display>::fmt(self, f)
325    }
326}
327
328impl PartialEq<Sort> for Sort {
329    fn eq(&self, other: &Sort) -> bool {
330        unsafe { Z3_is_eq_sort(self.ctx.z3_ctx.0, self.z3_sort, other.z3_sort) }
331    }
332}
333
334impl Eq for Sort {}
335
336impl Drop for Sort {
337    fn drop(&mut self) {
338        unsafe {
339            Z3_dec_ref(
340                self.ctx.z3_ctx.0,
341                Z3_sort_to_ast(self.ctx.z3_ctx.0, self.z3_sort).unwrap(),
342            );
343        }
344    }
345}
346
347impl SortDiffers {
348    pub fn new(left: Sort, right: Sort) -> Self {
349        Self { left, right }
350    }
351
352    pub fn left(&self) -> &Sort {
353        &self.left
354    }
355
356    pub fn right(&self) -> &Sort {
357        &self.right
358    }
359}
360
361impl fmt::Display for SortDiffers {
362    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
363        write!(
364            f,
365            "Can not compare nodes, Sort does not match.  Nodes contain types {} and {}",
366            self.left, self.right
367        )
368    }
369}