diff --git a/src/base-term/Base_types.ml b/src/base-term/Base_types.ml index ae67a54f..7d45cf3b 100644 --- a/src/base-term/Base_types.ml +++ b/src/base-term/Base_types.ml @@ -18,12 +18,6 @@ and 'a term_view = | Eq of 'a * 'a | Not of 'a -(* boolean literal *) -and lit = { - lit_term: term; - lit_sign: bool; -} - and fun_ = { fun_id: ID.t; fun_view: fun_view; @@ -101,17 +95,8 @@ let[@inline] term_equal_ (a:term) b = a==b let[@inline] term_hash_ a = a.term_id let[@inline] term_cmp_ a b = CCInt.compare a.term_id b.term_id -let cmp_lit a b = - let c = CCBool.compare a.lit_sign b.lit_sign in - if c<>0 then c - else term_cmp_ a.lit_term b.lit_term - let fun_compare a b = ID.compare a.fun_id b.fun_id -let hash_lit a = - let sign = a.lit_sign in - Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term) - let pp_fun out a = ID.pp out a.fun_id let id_of_fun a = a.fun_id @@ -167,10 +152,6 @@ let pp_term_top ~ids out t = let pp_term = pp_term_top ~ids:false let pp_term_view = pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:pp_term -let pp_lit out l = - if l.lit_sign then pp_term out l.lit_term - else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term - module Ty_card : sig type t = ty_card = Finite | Infinite @@ -756,63 +737,6 @@ end = struct | Eq (a,b) -> eq tst (f a) (f b) end -module Lit : sig - type t = lit = { - lit_term: term; - lit_sign : bool - } - - val neg : t -> t - val abs : t -> t - val sign : t -> bool - val term : t -> term - val as_atom : t -> term * bool - val atom : Term.state -> ?sign:bool -> term -> t - val hash : t -> int - val compare : t -> t -> int - val equal : t -> t -> bool - val print : t Fmt.printer - val pp : t Fmt.printer - val apply_sign : t -> bool -> t - val norm_sign : t -> t * bool - val norm : t -> t * Msat.negated - module Set : CCSet.S with type elt = t - module Tbl : CCHashtbl.S with type key = t -end = struct - type t = lit = { - lit_term: term; - lit_sign : bool - } - - let[@inline] neg l = {l with lit_sign=not l.lit_sign} - let[@inline] sign t = t.lit_sign - let[@inline] term (t:t): term = t.lit_term - - let[@inline] abs t: t = {t with lit_sign=true} - - let make ~sign t = {lit_sign=sign; lit_term=t} - - let atom tst ?(sign=true) (t:term) : t = - let t, sign' = Term.abs tst t in - let sign = if not sign' then not sign else sign in - make ~sign t - - let[@inline] as_atom (lit:t) = lit.lit_term, lit.lit_sign - - let hash = hash_lit - let compare = cmp_lit - let[@inline] equal a b = compare a b = 0 - let pp = pp_lit - let print = pp - - let apply_sign t s = if s then t else neg t - let norm_sign l = if l.lit_sign then l, true else neg l, false - let norm l = let l, sign = norm_sign l in l, if sign then Msat.Same_sign else Msat.Negated - - module Set = CCSet.Make(struct type t = lit let compare=compare end) - module Tbl = CCHashtbl.Make(struct type t = lit let equal=equal let hash=hash end) -end - module Value : sig type t = value = | V_bool of bool @@ -872,10 +796,3 @@ module Proof = struct type t = Default let default = Default end - -module type CC_ACTIONS = sig - val raise_conflict : Lit.t list -> Proof.t -> 'a - val propagate : Lit.t -> reason:(unit -> Lit.t list) -> Proof.t -> unit -end - -type cc_actions = (module CC_ACTIONS) diff --git a/src/base-term/Sidekick_base_term.ml b/src/base-term/Sidekick_base_term.ml index 3bd3882f..535ce45e 100644 --- a/src/base-term/Sidekick_base_term.ml +++ b/src/base-term/Sidekick_base_term.ml @@ -8,17 +8,14 @@ module Term = Base_types.Term module Value = Base_types.Value module Term_cell = Base_types.Term_cell module Ty = Base_types.Ty -module Lit = Base_types.Lit module Arg - : Sidekick_core.TERM_LIT + : Sidekick_core.TERM with type Term.t = Term.t - and type Lit.t = Lit.t and type Fun.t = Fun.t and type Ty.t = Ty.t = struct module Term = Term - module Lit = Lit module Fun = Fun module Ty = Ty end diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index f66d5c94..2c5b5027 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -17,14 +17,14 @@ module Make(CC_A: ARG) = struct module A = CC_A.A type term = A.Term.t type term_state = A.Term.state - type lit = A.Lit.t + type lit = CC_A.Lit.t type fun_ = A.Fun.t type proof = A.Proof.t type actions = CC_A.Actions.t module T = A.Term module Fun = A.Fun - module Lit = A.Lit + module Lit = CC_A.Lit module Bits = CCBitField.Make() (* TODO: give theories the possibility to allocate new bits in nodes *) diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 044b03eb..03867f7a 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -78,6 +78,8 @@ module type TERM = sig val bool : state -> bool -> t (* build true/false *) val as_bool : t -> bool option + val abs : state -> t -> t * bool + val map_shallow : state -> (t -> t) -> t -> t (** Map function on immediate subterms *) @@ -87,56 +89,37 @@ module type TERM = sig end end -module type TERM_LIT = sig +module type TERM_PROOF = sig include TERM - module Lit : sig - type t - val neg : t -> t - val equal : t -> t -> bool - val compare : t -> t -> int - val hash : t -> int - val pp : t Fmt.printer - - val term : t -> Term.t - val sign : t -> bool - val abs : t -> t - val apply_sign : t -> bool -> t - val norm_sign : t -> t * bool - (** Invariant: if [u, sign = norm_sign t] then [apply_sign u sign = t] *) - - - val atom : Term.state -> ?sign:bool -> Term.t -> t - end -end - -module type TERM_LIT_PROOF = sig - include TERM_LIT - module Proof : sig type t val pp : t Fmt.printer val default : t - (* TODO: to give more details? or make this extensible? - or have a generative function for new proof cstors? - val cc_lemma : unit -> t - *) end end module type CC_ARG = sig - module A : TERM_LIT_PROOF + module A : TERM_PROOF val cc_view : A.Term.t -> (A.Fun.t, A.Term.t, A.Term.t Iter.t) CC_view.t (** View the term through the lens of the congruence closure *) + module Lit : sig + type t + val term : t -> A.Term.t + val sign : t -> bool + val neg : t -> t + val pp : t Fmt.printer + end + module Actions : sig type t - val raise_conflict : t -> A.Lit.t list -> A.Proof.t -> 'a + val raise_conflict : t -> Lit.t list -> A.Proof.t -> 'a - val propagate : t -> A.Lit.t -> reason:(unit -> A.Lit.t list) -> A.Proof.t -> unit + val propagate : t -> Lit.t -> reason:(unit -> Lit.t list) -> A.Proof.t -> unit end end @@ -146,7 +129,7 @@ module type CC_S = sig type term_state = A.Term.state type term = A.Term.t type fun_ = A.Fun.t - type lit = A.Lit.t + type lit = CC_A.Lit.t type proof = A.Proof.t type actions = CC_A.Actions.t @@ -302,12 +285,11 @@ end (** A view of the solver from a theory's point of view *) module type SOLVER_INTERNAL = sig - module A : TERM_LIT_PROOF + module A : TERM_PROOF module CC_A : CC_ARG with module A = A module CC : CC_S with module CC_A = CC_A type ty = A.Ty.t - type lit = A.Lit.t type term = A.Term.t type term_state = A.Term.state type proof = A.Proof.t @@ -324,6 +306,23 @@ module type SOLVER_INTERNAL = sig val cc : t -> CC.t (** Congruence closure for this solver *) + (** {3 Literals} + + A literal is a (preprocessed) term along with its sign. + It is directly manipulated by the SAT solver. + *) + module Lit : sig + type t + val term : t -> term + val sign : t -> bool + val neg : t -> t + + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + end + type lit = Lit.t + (** {3 Simplifiers} *) module Simplify : sig @@ -440,7 +439,11 @@ module type SOLVER_INTERNAL = sig literals suitable for reasoning. Typically some clauses are also added to the solver. *) - type preprocess_hook = t -> add_clause:(lit list -> unit) -> term -> term option + type preprocess_hook = + t -> + mk_lit:(term -> lit) -> + add_clause:(lit list -> unit) -> + term -> term option (** Given a term, try to preprocess it. Return [None] if it didn't change. Can also add clauses to define new terms. *) @@ -449,16 +452,18 @@ end (** Public view of the solver *) module type SOLVER = sig - module A : TERM_LIT_PROOF + module A : TERM_PROOF module CC_A : CC_ARG with module A = A module Solver_internal : SOLVER_INTERNAL with module A = A and module CC_A = CC_A (** Internal solver, available to theories. *) + module Lit = Solver_internal.Lit + type t type solver = t type term = A.Term.t type ty = A.Ty.t - type lit = A.Lit.t + type lit = Solver_internal.Lit.t type lemma = A.Proof.t (** {3 A theory} @@ -583,10 +588,6 @@ module type SOLVER = sig val mk_atom_t : t -> ?sign:bool -> term -> Atom.t - val add_clause_lits : t -> lit IArray.t -> unit - - val add_clause_lits_l : t -> lit list -> unit - val add_clause : t -> Atom.t IArray.t -> unit val add_clause_l : t -> Atom.t list -> unit diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index 1f85faac..a29916f3 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -5,30 +5,71 @@ module Log = Msat.Log module IM = Util.Int_map module type ARG = sig - include Sidekick_core.TERM_LIT_PROOF + include Sidekick_core.TERM_PROOF val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t end module type S = Sidekick_core.SOLVER module Make(A : ARG) - : S with module A = A +(* : S with module A = A *) = struct module A = A module T = A.Term - module Lit = A.Lit module Ty = A.Ty - type lit = Lit.t type term = T.t type ty = Ty.t type lemma = A.Proof.t + module Lit = struct + type t = { + lit_term: term; + lit_sign : bool + } + + let[@inline] neg l = {l with lit_sign=not l.lit_sign} + let[@inline] sign t = t.lit_sign + let[@inline] term (t:t): term = t.lit_term + + let[@inline] abs t: t = {t with lit_sign=true} + + let make ~sign t = {lit_sign=sign; lit_term=t} + + let atom tst ?(sign=true) (t:term) : t = + let t, sign' = T.abs tst t in + let sign = if not sign' then not sign else sign in + make ~sign t + + let[@inline] as_atom (lit:t) = lit.lit_term, lit.lit_sign + + let equal a b = + a.lit_sign = b.lit_sign && + T.equal a.lit_term b.lit_term + + let hash a = + let sign = a.lit_sign in + CCHash.combine3 2 (CCHash.bool sign) (T.hash a.lit_term) + + let pp out l = + if l.lit_sign then T.pp out l.lit_term + else Format.fprintf out "(@[@<1>¬@ %a@])" T.pp l.lit_term + + let print = pp + + let apply_sign t s = if s then t else neg t + let norm_sign l = if l.lit_sign then l, true else neg l, false + let norm l = let l, sign = norm_sign l in l, if sign then Msat.Same_sign else Msat.Negated + end + + type lit = Lit.t + (* actions from msat *) type msat_acts = (Msat.void, Lit.t, Msat.void, A.Proof.t) Msat.acts (* the full argument to the congruence closure *) module CC_A = struct module A = A + module Lit = Lit let cc_view = A.cc_view module Actions = struct @@ -49,6 +90,7 @@ module Make(A : ARG) module Solver_internal = struct module A = A module CC_A = CC_A + module Lit = Lit module CC = CC module N = CC.N type term = T.t @@ -119,7 +161,12 @@ module Make(A : ARG) mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> actions -> lit Iter.t -> unit) list; } - and preprocess_hook = t -> add_clause:(lit list -> unit) -> term -> term option + + and preprocess_hook = + t -> + mk_lit:(term -> lit) -> + add_clause:(lit list -> unit) -> + term -> term option type solver = t @@ -159,6 +206,7 @@ module Make(A : ARG) acts.Msat.acts_add_clause ~keep lits A.Proof.default let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit = + let mk_lit t = Lit.atom self.tst t in (* compute and cache normal form of [t] *) let rec aux t = match T.Tbl.find self.preprocess_cache t with @@ -174,7 +222,7 @@ module Make(A : ARG) and aux_rec t hooks = match hooks with | [] -> t | h :: hooks_tl -> - match h self ~add_clause t with + match h self ~mk_lit ~add_clause t with | None -> aux_rec t hooks_tl | Some u -> Log.debugf 30 @@ -188,7 +236,7 @@ module Make(A : ARG) (fun k->k "(@[msat-solver.preprocess@ :lit %a@ :into %a@])" Lit.pp lit Lit.pp lit'); lit' - let[@inline] mk_lit self acts ?sign t = + let mk_lit self acts ?sign t = let add_clause lits = Stat.incr self.count_preprocess_clause; add_sat_clause_ self acts ~keep:true lits @@ -423,16 +471,18 @@ module Make(A : ARG) ignore (mk_atom_t_ self sub : Sat_solver.atom)) let rec mk_atom_lit self lit : Atom.t = - let lit = + let lit = preprocess_lit_ self lit in + add_bool_subterms_ self (Lit.term lit); + Sat_solver.make_atom self.solver lit + + and preprocess_lit_ self lit : Lit.t = Solver_internal.preprocess_lit_ ~add_clause:(fun lits -> (* recursively add these sub-literals, so they're also properly processed *) Stat.incr self.si.count_preprocess_clause; let atoms = List.map (mk_atom_lit self) lits in Sat_solver.add_clause self.solver atoms A.Proof.default) - self.si lit in - add_bool_subterms_ self (Lit.term lit); - Sat_solver.make_atom self.solver lit + self.si lit let[@inline] mk_atom_t self ?sign t : Atom.t = let lit = Lit.atom (tst self) ?sign t in @@ -500,13 +550,6 @@ module Make(A : ARG) let add_clause_l self c = add_clause self (IArray.of_list c) - let add_clause_lits (self:t) (c:Lit.t IArray.t) : unit = - let c = IArray.map (mk_atom_lit self) c in - add_clause self c - - let add_clause_lits_l (self:t) (c:Lit.t list) : unit = - add_clause self (IArray.of_list_map (mk_atom_lit self) c) - (* TODO: remove? use a special constant + micro theory instead? let[@inline] assume_distinct self l ~neq lit : unit = CC.assert_distinct (cc self) l lit ~neq diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 71f30ca8..1287d5a0 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -406,6 +406,7 @@ end module Solver = Sidekick_msat_solver.Make(Solver_arg) module Check_cc = struct + module Lit = Solver.Lit module SI = Solver.Solver_internal module CC = Solver.Solver_internal.CC module MCC = Sidekick_mini_cc.Make(Solver_arg) @@ -604,9 +605,9 @@ let process_stmt if pp_cnf then ( Format.printf "(@[assert@ %a@])@." Term.pp t ); - let atom = Lit.atom tst t in + let atom = Solver.mk_atom_t solver t in CCOpt.iter (fun h -> Vec.push h [atom]) hyps; - Solver.add_clause_lits solver (IArray.singleton atom); + Solver.add_clause solver (IArray.singleton atom); E.return() | A.Goal (_, _) -> Error.errorf "cannot deal with goals yet" diff --git a/src/smtlib/Process.mli b/src/smtlib/Process.mli index 6d845728..9eb818f3 100644 --- a/src/smtlib/Process.mli +++ b/src/smtlib/Process.mli @@ -5,7 +5,6 @@ open Sidekick_base_term module Solver : Sidekick_msat_solver.S with type A.Term.t = Term.t and type A.Term.state = Term.state - and type A.Lit.t = Lit.t and type A.Ty.t = Ty.t val th_bool : Solver.theory @@ -24,7 +23,7 @@ module Check_cc : sig end val process_stmt : - ?hyps:Lit.t list Vec.t -> + ?hyps:Solver.Atom.t list Vec.t -> ?gc:bool -> ?restarts:bool -> ?pp_cnf:bool -> diff --git a/src/th-bool-static/Sidekick_th_bool_static.ml b/src/th-bool-static/Sidekick_th_bool_static.ml index 602ac914..e5ce0711 100644 --- a/src/th-bool-static/Sidekick_th_bool_static.ml +++ b/src/th-bool-static/Sidekick_th_bool_static.ml @@ -54,7 +54,7 @@ module Make(A : ARG) : S with module A = A = struct module A = A module Ty = A.S.A.Ty module T = A.S.A.Term - module Lit = A.S.A.Lit + module Lit = A.S.Lit module SI = A.S.Solver_internal type state = { @@ -122,28 +122,28 @@ module Make(A : ARG) : S with module A = A = struct | B_atom _ -> None let fresh_term self ~pre ty = A.Gensym.fresh_term self.gensym ~pre ty - let fresh_lit (self:state) ~pre : Lit.t = + let fresh_lit (self:state) ~mk_lit ~pre : Lit.t = let t = fresh_term ~pre self Ty.bool in - Lit.atom self.tst t + mk_lit t (* TODO: polarity? *) - let cnf (self:state) (_solver:SI.t) ~add_clause (t:T.t) : T.t option = + let cnf (self:state) (_si:SI.t) ~mk_lit ~add_clause (t:T.t) : T.t option = let rec get_lit (t:T.t) : Lit.t = match A.view_as_bool t with - | B_bool b -> Lit.atom self.tst ~sign:b (T.bool self.tst true) + | B_bool b -> mk_lit (T.bool self.tst b) | B_not u -> let lit = get_lit u in Lit.neg lit | B_and l -> let subs = IArray.to_list_map get_lit l in - let proxy = fresh_lit ~pre:"and_" self in + let proxy = fresh_lit ~mk_lit ~pre:"and_" self in (* add clauses *) List.iter (fun u -> add_clause [Lit.neg proxy; u]) subs; add_clause (proxy :: List.map Lit.neg subs); proxy | B_or l -> let subs = IArray.to_list_map get_lit l in - let proxy = fresh_lit ~pre:"or_" self in + let proxy = fresh_lit ~mk_lit ~pre:"or_" self in (* add clauses *) List.iter (fun u -> add_clause [Lit.neg u; proxy]) subs; add_clause (Lit.neg proxy :: subs); @@ -154,11 +154,11 @@ module Make(A : ARG) : S with module A = A = struct IArray.append (IArray.map (not_ self.tst) args) (IArray.singleton u) in get_lit t' | B_ite _ | B_eq _ -> - Lit.atom self.tst t + mk_lit t | B_equiv (a,b) -> let a = get_lit a in let b = get_lit b in - let proxy = fresh_lit ~pre:"equiv_" self in + let proxy = fresh_lit ~mk_lit ~pre:"equiv_" self in (* proxy => a<=> b, ¬proxy => a xor b *) add_clause [Lit.neg proxy; Lit.neg a; b]; @@ -166,7 +166,7 @@ module Make(A : ARG) : S with module A = A = struct add_clause [proxy; a; b]; add_clause [proxy; Lit.neg a; Lit.neg b]; proxy - | B_atom u -> Lit.atom self.tst u + | B_atom u -> mk_lit u in let lit = get_lit t in let u = Lit.term lit in