diff --git a/src/arith/base-term/Base_types.ml b/src/arith/base-term/Base_types.ml index dd60dc5d..ed7df0eb 100644 --- a/src/arith/base-term/Base_types.ml +++ b/src/arith/base-term/Base_types.ml @@ -209,7 +209,7 @@ let string_of_lra_pred = function | Leq -> "<=" | Neq -> "!=" | Eq -> "=" - | Gt-> ">" + | Gt -> ">" | Geq -> ">=" let pp_pred out p = Fmt.string out (string_of_lra_pred p) diff --git a/src/arith/lra/Sidekick_arith_lra.ml b/src/arith/lra/Sidekick_arith_lra.ml index 0f9911eb..543e9383 100644 --- a/src/arith/lra/Sidekick_arith_lra.ml +++ b/src/arith/lra/Sidekick_arith_lra.ml @@ -142,22 +142,22 @@ module Make(A : ARG) : S with module A = A = struct let t = fresh_term ~pre self Ty.bool in mk_lit t - (* turn the term into a linear expression *) - let rec as_linexp (t:T.t) : LE.t = + (* turn the term into a linear expression. Apply [f] on leaves. *) + let rec as_linexp ~f (t:T.t) : LE.t = let open LE.Infix in match A.view_as_lra t with - | LRA_other _ -> LE.var t + | LRA_other _ -> LE.var (f t) | LRA_pred _ -> Error.errorf "type error: in linexp, LRA predicate %a" T.pp t | LRA_op (op, t1, t2) -> - let t1 = as_linexp t1 in - let t2 = as_linexp t2 in + let t1 = as_linexp ~f t1 in + let t2 = as_linexp ~f t2 in begin match op with | Plus -> t1 + t2 | Minus -> t1 - t2 end | LRA_mult (n, x) -> - let t = as_linexp x in + let t = as_linexp ~f x in LE.( n * t ) | LRA_const q -> LE.const q @@ -166,21 +166,19 @@ module Make(A : ARG) : S with module A = A = struct *) (* preprocess linear expressions away *) - let preproc_lra self si ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option = + let preproc_lra self si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option = Log.debugf 50 (fun k->k "lra.preprocess %a" T.pp t); let _tst = SI.tst si in match A.view_as_lra t with | LRA_pred (pred, t1, t2) -> - (* TODO: map preproc on [l1] and [l2] *) - let l1 = as_linexp t1 in - let l2 = as_linexp t2 in + let l1 = as_linexp ~f:recurse t1 in + let l2 = as_linexp ~f:recurse t2 in let proxy = fresh_term self ~pre:"_pred_lra_" Ty.bool in T.Tbl.add self.pred_defs proxy (pred, l1, l2); Log.debugf 5 (fun k->k"lra.preprocess.step %a :into %a" T.pp t T.pp proxy); Some proxy | LRA_op _ | LRA_mult _ -> - let le = as_linexp t in - (* TODO: map preproc on [le] *) + let le = as_linexp ~f:recurse t in let proxy = fresh_term self ~pre:"_e_lra_" (T.ty t) in self.t_defs <- (proxy, le) :: self.t_defs; Log.debugf 5 (fun k->k"lra.preprocess.step %a :into %a" T.pp t T.pp proxy); diff --git a/src/arith/lra/fourier_motzkin.ml b/src/arith/lra/fourier_motzkin.ml index 91d894f5..3efb5ba7 100644 --- a/src/arith/lra/fourier_motzkin.ml +++ b/src/arith/lra/fourier_motzkin.ml @@ -59,6 +59,8 @@ module type S = sig val find_exn : term -> t -> Q.t val find : term -> t -> Q.t option +(* val map : (term -> term) -> t -> t *) + module Infix : sig val (+) : t -> t -> t val (-) : t -> t -> t @@ -112,10 +114,10 @@ module Make(A : ARG) let zero = const Q.zero let var x : t = {const=Q.zero; le=M.singleton x Q.one} - let find_exn v le = M.find v le.le - let find v le = M.get v le.le + let[@inline] find_exn v le = M.find v le.le + let[@inline] find v le = M.get v le.le - let remove v le : t = {le with le=M.remove v le.le} + let[@inline] remove v le : t = {le with le=M.remove v le.le} let neg a : t = {const=Q.neg a.const; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index bc6ba4ac..127aa65d 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -481,11 +481,16 @@ module type SOLVER_INTERNAL = sig type preprocess_hook = t -> + recurse:(term -> term) -> mk_lit:(term -> lit) -> add_clause:(lit list -> unit) -> term -> term option (** Given a term, try to preprocess it. Return [None] if it didn't change. - Can also add clauses to define new terms. *) + Can also add clauses to define new terms. + @param recurse call preprocessor on subterms. + @param mk_lit creates a new literal for a boolean term. + @param add_clause pushes a new clause into the SAT solver. + *) val add_preprocess : t -> preprocess_hook -> unit end diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index 617837bd..b5036682 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -174,6 +174,7 @@ module Make(A : ARG) and preprocess_hook = t -> + recurse:(term -> term) -> mk_lit:(term -> lit) -> add_clause:(lit list -> unit) -> term -> term option @@ -223,23 +224,27 @@ module Make(A : ARG) match Term.Tbl.find self.preprocess_cache t with | u -> u | exception Not_found -> - (* first, map subterms *) - let u = Term.map_shallow self.tst aux t in - (* then rewrite *) - let u = aux_rec u self.preprocess in + (* try rewrite here *) + let u = + match aux_rec t self.preprocess with + | None -> + Term.map_shallow self.tst aux t (* just map subterms *) + | Some u -> u + in Term.Tbl.add self.preprocess_cache t u; u (* try each function in [hooks] successively *) and aux_rec t hooks = match hooks with - | [] -> t + | [] -> None | h :: hooks_tl -> - match h self ~mk_lit ~add_clause t with + match h self ~recurse:aux ~mk_lit ~add_clause t with | None -> aux_rec t hooks_tl | Some u -> - Log.debugf 30 + Log.debugf 30 (fun k->k "(@[msat-solver.preprocess.step@ :from %a@ :to %a@])" Term.pp t Term.pp u); - aux u + let u' = aux u in + Some u' in let t = Lit.term lit |> simp_t self |> aux in let lit' = Lit.atom self.tst ~sign:(Lit.sign lit) t in diff --git a/src/smtlib/Typecheck.ml b/src/smtlib/Typecheck.ml index 221fad4e..ce89c1d4 100644 --- a/src/smtlib/Typecheck.ml +++ b/src/smtlib/Typecheck.ml @@ -294,9 +294,8 @@ let rec conv_term (ctx:Ctx.t) (t:PA.term) : T.t = | PA.Div, [a;b] -> begin match t_as_q a, t_as_q b with | Some a, Some b -> T.lra ctx.tst (LRA_const (Q.div a b)) - | Some a, _ -> T.lra ctx.tst (LRA_mult (Q.inv a, b)) | _, Some b -> T.lra ctx.tst (LRA_mult (Q.inv b, a)) - | None, None -> + | _, None -> errorf_ctx ctx "cannot handle non-linear div %a" PA.pp_term t end | _ -> diff --git a/src/th-bool-static/Sidekick_th_bool_static.ml b/src/th-bool-static/Sidekick_th_bool_static.ml index 7f729e11..bedf0033 100644 --- a/src/th-bool-static/Sidekick_th_bool_static.ml +++ b/src/th-bool-static/Sidekick_th_bool_static.ml @@ -84,7 +84,7 @@ module Make(A : ARG) : S with module A = A = struct let is_true t = match T.as_bool t with Some true -> true | _ -> false let is_false t = match T.as_bool t with Some false -> true | _ -> false - + let simplify (self:state) (simp:SI.Simplify.t) (t:T.t) : T.t option = let tst = self.tst in match A.view_as_bool t with @@ -133,18 +133,26 @@ module Make(A : ARG) : S with module A = A = struct mk_lit t (* preprocess "ite" away *) - let preproc_ite self _si ~mk_lit ~add_clause (t:T.t) : T.t option = + let preproc_ite self _si ~recurse ~mk_lit ~add_clause (t:T.t) : T.t option = match A.view_as_bool t with | B_ite (a,b,c) -> - let t_a = fresh_term self ~pre:"ite" (T.ty b) in - let lit_a = mk_lit a in - add_clause [Lit.neg lit_a; mk_lit (eq self.tst t_a b)]; - add_clause [lit_a; mk_lit (eq self.tst t_a c)]; - Some t_a + let a = recurse a in + begin match A.view_as_bool a with + | B_bool true -> Some (recurse b) + | B_bool false -> Some (recurse c) + | _ -> + let t_a = fresh_term self ~pre:"ite" (T.ty b) in + let lit_a = mk_lit a in + let b = recurse b in + let c = recurse c in + add_clause [Lit.neg lit_a; mk_lit (eq self.tst t_a b)]; + add_clause [lit_a; mk_lit (eq self.tst t_a c)]; + Some t_a + end | _ -> None (* TODO: polarity? *) - let cnf (self:state) (_si:SI.t) ~mk_lit ~add_clause (t:T.t) : T.t option = + let cnf (self:state) (_si:SI.t) ~recurse:_ ~mk_lit ~add_clause (t:T.t) : T.t option = let rec get_lit (t:T.t) : Lit.t = let t_abs, t_sign = T.abs self.tst t in let lit = @@ -217,6 +225,7 @@ module Make(A : ARG) : S with module A = A = struct in let cnf_of t = cnf self si t + ~recurse:(fun t -> t) ~mk_lit:(SI.mk_lit si acts) ~add_clause:(SI.add_clause_permanent si acts) in begin