refactor: change the functor stack

This commit is contained in:
Simon Cruanes 2019-10-29 15:06:19 -05:00
parent 94ba04a49e
commit 7d8589accd
10 changed files with 213 additions and 179 deletions

View file

@ -1,3 +1,4 @@
open Sidekick_core
module View = Sidekick_core.CC_view module View = Sidekick_core.CC_view
type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t = type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t =
@ -9,21 +10,27 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t =
| Not of 't | Not of 't
| Opaque of 't (* do not enter *) | Opaque of 't (* do not enter *)
module type ARG = Sidekick_core.CC_ARG
module type S = Sidekick_core.CC_S module type S = Sidekick_core.CC_S
module Make(CC_A: ARG) = struct module Make (A: CC_ARG)
module A = CC_A : S with module T = A.T
type term = A.A.Term.t and module Lit = A.Lit
type term_state = A.A.Term.state and module P = A.P
type lit = A.Lit.t and module Actions = A.Actions
type fun_ = A.A.Fun.t = struct
type proof = A.A.Proof.t module T = A.T
type actions = A.Actions.t module P = A.P
module Lit = A.Lit
module Actions = A.Actions
type term = T.Term.t
type term_state = T.Term.state
type lit = Lit.t
type fun_ = T.Fun.t
type proof = P.t
type actions = Actions.t
module T = A.A.Term module Term = T.Term
module Fun = A.A.Fun module Fun = T.Fun
module Lit = CC_A.Lit
module Bits : sig module Bits : sig
type t = private int type t = private int
@ -94,9 +101,9 @@ module Make(CC_A: ARG) = struct
type t = node type t = node
let[@inline] equal (n1:t) n2 = n1 == n2 let[@inline] equal (n1:t) n2 = n1 == n2
let[@inline] hash n = T.hash n.n_term let[@inline] hash n = Term.hash n.n_term
let[@inline] term n = n.n_term let[@inline] term n = n.n_term
let[@inline] pp out n = T.pp out n.n_term let[@inline] pp out n = Term.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit let[@inline] as_lit n = n.n_as_lit
let make (t:term) : t = let make (t:term) : t =
@ -158,7 +165,7 @@ module Make(CC_A: ARG) = struct
| E_lit lit -> Lit.pp out lit | E_lit lit -> Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b
| E_and (a,b) -> | E_and (a,b) ->
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
@ -167,7 +174,7 @@ module Make(CC_A: ARG) = struct
let[@inline] mk_merge a b : t = let[@inline] mk_merge a b : t =
assert (same_class a b); assert (same_class a b);
if N.equal a b then mk_reduction else E_merge (a,b) if N.equal a b then mk_reduction else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b)
let[@inline] mk_lit l : t = E_lit l let[@inline] mk_lit l : t = E_lit l
let rec mk_list l = let rec mk_list l =
@ -227,7 +234,7 @@ module Make(CC_A: ARG) = struct
end end
module Sig_tbl = CCHashtbl.Make(Signature) module Sig_tbl = CCHashtbl.Make(Signature)
module T_tbl = CCHashtbl.Make(T) module T_tbl = CCHashtbl.Make(Term)
type combine_task = type combine_task =
| CT_merge of node * node * explanation | CT_merge of node * node * explanation
@ -299,7 +306,7 @@ module Make(CC_A: ARG) = struct
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl
in in
let pp_n out n = let pp_n out n =
Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n
and pp_sig_e out (s,n) = and pp_sig_e out (s,n) =
Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n
in in
@ -357,7 +364,7 @@ module Make(CC_A: ARG) = struct
Vec.clear cc.combine; Vec.clear cc.combine;
List.iter (fun f -> f cc e) cc.on_conflict; List.iter (fun f -> f cc e) cc.on_conflict;
Stat.incr cc.count_conflict; Stat.incr cc.count_conflict;
CC_A.Actions.raise_conflict acts e A.A.Proof.default Actions.raise_conflict acts e P.default
let[@inline] all_classes cc : repr Iter.t = let[@inline] all_classes cc : repr Iter.t =
T_tbl.values cc.tbl T_tbl.values cc.tbl
@ -440,7 +447,7 @@ module Make(CC_A: ARG) = struct
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
| a, b -> explain_pair cc acc a b | a, b -> explain_pair cc acc a b
| exception Not_found -> | exception Not_found ->
Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b
end end
| E_and (a,b) -> | E_and (a,b) ->
let acc = explain_decompose cc acc a in let acc = explain_decompose cc acc a in
@ -477,7 +484,7 @@ module Make(CC_A: ARG) = struct
(* add [t] to [cc] when not present already *) (* add [t] to [cc] when not present already *)
and add_new_term_ cc (t:term) : node = and add_new_term_ cc (t:term) : node =
assert (not @@ mem cc t); assert (not @@ mem cc t);
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t); Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
let n = N.make t in let n = N.make t in
(* register sub-terms, add [t] to their parent list, and return the (* register sub-terms, add [t] to their parent list, and return the
corresponding initial signature *) corresponding initial signature *)
@ -486,7 +493,7 @@ module Make(CC_A: ARG) = struct
(* remove term when we backtrack *) (* remove term when we backtrack *)
on_backtrack cc on_backtrack cc
(fun () -> (fun () ->
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t); Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t);
T_tbl.remove cc.tbl t); T_tbl.remove cc.tbl t);
(* add term to the table *) (* add term to the table *)
T_tbl.add cc.tbl t n; T_tbl.add cc.tbl t n;
@ -514,7 +521,7 @@ module Make(CC_A: ARG) = struct
sub sub
in in
let[@inline] return x = Some x in let[@inline] return x = Some x in
match CC_A.cc_view n.n_term with match A.cc_view n.n_term with
| Bool _ | Opaque _ -> None | Bool _ | Opaque _ -> None
| Eq (a,b) -> | Eq (a,b) ->
let a = deref_sub a in let a = deref_sub a in
@ -733,7 +740,7 @@ module Make(CC_A: ARG) = struct
in in
List.iter (fun f -> f cc lit reason) cc.on_propagate; List.iter (fun f -> f cc lit reason) cc.on_propagate;
Stat.incr cc.count_props; Stat.incr cc.count_props;
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default Actions.propagate acts lit ~reason P.default
| _ -> ()) | _ -> ())
module Debug_ = struct module Debug_ = struct
@ -766,7 +773,7 @@ module Make(CC_A: ARG) = struct
let t = Lit.term lit in let t = Lit.term lit in
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit); Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
let sign = Lit.sign lit in let sign = Lit.sign lit in
begin match CC_A.cc_view t with begin match A.cc_view t with
| Eq (a,b) when sign -> | Eq (a,b) when sign ->
let a = add_term cc a in let a = add_term cc a in
let b = add_term cc b in let b = add_term cc b in
@ -837,9 +844,9 @@ module Make(CC_A: ARG) = struct
count_props=Stat.mk_int stat "cc.propagations"; count_props=Stat.mk_int stat "cc.propagations";
count_merge=Stat.mk_int stat "cc.merges"; count_merge=Stat.mk_int stat "cc.merges";
} and true_ = lazy ( } and true_ = lazy (
add_term cc (T.bool tst true) add_term cc (Term.bool tst true)
) and false_ = lazy ( ) and false_ = lazy (
add_term cc (T.bool tst false) add_term cc (Term.bool tst false)
) )
in in
ignore (Lazy.force true_ : node); ignore (Lazy.force true_ : node);

View file

@ -1,6 +1,10 @@
(** {2 Congruence Closure} *) (** {2 Congruence Closure} *)
module type ARG = Sidekick_core.CC_ARG open Sidekick_core
module type S = Sidekick_core.CC_S module type S = Sidekick_core.CC_S
module Make(A: ARG) : S with module A = A module Make (A: CC_ARG)
: S with module T = A.T
and module Lit = A.Lit
and module P = A.P
and module Actions = A.Actions

View file

@ -89,48 +89,58 @@ module type TERM = sig
end end
end end
module type TERM_PROOF = sig module type PROOF = sig
include TERM
module Proof : sig
type t type t
val pp : t Fmt.printer val pp : t Fmt.printer
val default : t val default : t
end end
module type LIT = sig
module T : TERM
type t
val term : t -> T.Term.t
val sign : t -> bool
val neg : t -> t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
end
module type CC_ACTIONS = sig
module T : TERM
module P : PROOF
module Lit : LIT with module T = T
type t
val raise_conflict : t -> Lit.t list -> P.t -> 'a
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list) -> P.t -> unit
end end
module type CC_ARG = sig module type CC_ARG = sig
module A : TERM_PROOF module T : TERM
module P : PROOF
module Lit : LIT with module T = T
module Actions : CC_ACTIONS with module T=T and module P = P and module Lit = Lit
val cc_view : A.Term.t -> (A.Fun.t, A.Term.t, A.Term.t Iter.t) CC_view.t val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
(** View the term through the lens of the congruence closure *) (** View the term through the lens of the congruence closure *)
module Lit : sig
type t
val term : t -> A.Term.t
val sign : t -> bool
val neg : t -> t
val pp : t Fmt.printer
end
module Actions : sig
type t
val raise_conflict : t -> Lit.t list -> A.Proof.t -> 'a
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list) -> A.Proof.t -> unit
end
end end
module type CC_S = sig module type CC_S = sig
module A : CC_ARG module T : TERM
type term_state = A.A.Term.state module P : PROOF
type term = A.A.Term.t module Lit : LIT with module T = T
type fun_ = A.A.Fun.t module Actions : CC_ACTIONS with module T = T and module Lit = Lit and module P = P
type lit = A.Lit.t type term_state = T.Term.state
type proof = A.A.Proof.t type term = T.Term.t
type actions = A.Actions.t type fun_ = T.Fun.t
type lit = Lit.t
type proof = P.t
type actions = Actions.t
type t type t
(** Global state of the congruence closure *) (** Global state of the congruence closure *)
@ -305,14 +315,13 @@ end
(** A view of the solver from a theory's point of view *) (** A view of the solver from a theory's point of view *)
module type SOLVER_INTERNAL = sig module type SOLVER_INTERNAL = sig
module A : TERM_PROOF module T : TERM
module CC_A : CC_ARG with module A = A module P : PROOF
module CC : CC_S with module A = CC_A
type ty = A.Ty.t type ty = T.Ty.t
type term = A.Term.t type term = T.Term.t
type term_state = A.Term.state type term_state = T.Term.state
type proof = A.Proof.t type proof = P.t
(** {3 Main type for a solver} *) (** {3 Main type for a solver} *)
type t type t
@ -320,26 +329,25 @@ module type SOLVER_INTERNAL = sig
val tst : t -> term_state val tst : t -> term_state
val cc : t -> CC.t
(** Congruence closure for this solver *)
(** {3 Literals} (** {3 Literals}
A literal is a (preprocessed) term along with its sign. A literal is a (preprocessed) term along with its sign.
It is directly manipulated by the SAT solver. It is directly manipulated by the SAT solver.
*) *)
module Lit : sig module Lit : LIT with module T = T
type t
val term : t -> term
val sign : t -> bool
val neg : t -> t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
end
type lit = Lit.t type lit = Lit.t
(** {2 Congruence Closure} *)
module CC : CC_S
with module T = T
and module P = P
and module Lit = Lit
val cc : t -> CC.t
(** Congruence closure for this solver *)
(** {3 Simplifiers} *) (** {3 Simplifiers} *)
module Simplify : sig module Simplify : sig
@ -368,11 +376,11 @@ module type SOLVER_INTERNAL = sig
(** {3 hooks for the theory} *) (** {3 hooks for the theory} *)
type actions = CC_A.Actions.t type actions
val propagate : t -> actions -> lit -> reason:(unit -> lit list) -> A.Proof.t -> unit val propagate : t -> actions -> lit -> reason:(unit -> lit list) -> proof -> unit
val raise_conflict : t -> actions -> lit list -> A.Proof.t -> 'a val raise_conflict : t -> actions -> lit list -> proof -> 'a
(** Give a conflict clause to the solver *) (** Give a conflict clause to the solver *)
val propagate: t -> actions -> lit -> (unit -> lit list) -> unit val propagate: t -> actions -> lit -> (unit -> lit list) -> unit
@ -472,17 +480,22 @@ end
(** Public view of the solver *) (** Public view of the solver *)
module type SOLVER = sig module type SOLVER = sig
module A : TERM_PROOF module T : TERM
module CC_A : CC_ARG with module A = A module P : PROOF
module Solver_internal : SOLVER_INTERNAL with module A = A and module CC_A = CC_A module Lit : LIT with module T = T
module Solver_internal
: SOLVER_INTERNAL
with module T = T
and module P = P
and module Lit = Lit
(** Internal solver, available to theories. *) (** Internal solver, available to theories. *)
type t type t
type solver = t type solver = t
type term = A.Term.t type term = T.Term.t
type ty = A.Ty.t type ty = T.Ty.t
type lit = Solver_internal.Lit.t type lit = Lit.t
type lemma = A.Proof.t type lemma = P.t
(** {3 A theory} (** {3 A theory}
@ -590,7 +603,7 @@ module type SOLVER = sig
(** {3 Main API} *) (** {3 Main API} *)
val stats : t -> Stat.t val stats : t -> Stat.t
val tst : t -> A.Term.state val tst : t -> T.Term.state
val create : val create :
?stat:Stat.t -> ?stat:Stat.t ->
@ -598,7 +611,7 @@ module type SOLVER = sig
(* TODO? ?config:Config.t -> *) (* TODO? ?config:Config.t -> *)
?store_proof:bool -> ?store_proof:bool ->
theories:theory list -> theories:theory list ->
A.Term.state -> T.Term.state ->
unit -> unit ->
t t
(** Create a new solver. (** Create a new solver.

View file

@ -1,9 +1,9 @@
module CC_view = Sidekick_core.CC_view module CC_view = Sidekick_core.CC_view
module type ARG = sig module type ARG = sig
include Sidekick_core.TERM module T : Sidekick_core.TERM
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
end end
module type S = sig module type S = sig
@ -26,11 +26,11 @@ end
module Make(A: ARG) = struct module Make(A: ARG) = struct
open CC_view open CC_view
module Fun = A.Fun module Fun = A.T.Fun
module T = A.Term module T = A.T.Term
type fun_ = A.Fun.t type fun_ = A.T.Fun.t
type term = T.t type term = T.t
type term_state = A.Term.state type term_state = T.state
module T_tbl = CCHashtbl.Make(T) module T_tbl = CCHashtbl.Make(T)

View file

@ -8,9 +8,9 @@
module CC_view = Sidekick_core.CC_view module CC_view = Sidekick_core.CC_view
module type ARG = sig module type ARG = sig
include Sidekick_core.TERM module T : Sidekick_core.TERM
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
end end
module type S = sig module type S = sig
@ -38,6 +38,6 @@ module type S = sig
end end
module Make(A: ARG) module Make(A: ARG)
: S with type term = A.Term.t : S with type term = A.T.Term.t
and type fun_ = A.Fun.t and type fun_ = A.T.Fun.t
and type term_state = A.Term.state and type term_state = A.T.Term.state

View file

@ -5,11 +5,13 @@ module Log = Msat.Log
module IM = Util.Int_map module IM = Util.Int_map
module type ARG = sig module type ARG = sig
include Sidekick_core.TERM_PROOF open Sidekick_core
module T : TERM
module P : PROOF
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
val is_valid_literal : Term.t -> bool val is_valid_literal : T.Term.t -> bool
(** Is this a valid boolean literal? (e.g. is it a closed term, not inside (** Is this a valid boolean literal? (e.g. is it a closed term, not inside
a quantifier) *) a quantifier) *)
end end
@ -17,16 +19,20 @@ end
module type S = Sidekick_core.SOLVER module type S = Sidekick_core.SOLVER
module Make(A : ARG) module Make(A : ARG)
: S with module A = A : S
with module T = A.T
and module P = A.P
= struct = struct
module A = A module T = A.T
module T = A.Term module P = A.P
module Ty = A.Ty module Ty = T.Ty
type term = T.t module Term = T.Term
type term = Term.t
type ty = Ty.t type ty = Ty.t
type lemma = A.Proof.t type lemma = P.t
module Lit_ = struct module Lit_ = struct
module T = T
type t = { type t = {
lit_term: term; lit_term: term;
lit_sign : bool lit_sign : bool
@ -39,21 +45,21 @@ module Make(A : ARG)
let make ~sign t = {lit_sign=sign; lit_term=t} let make ~sign t = {lit_sign=sign; lit_term=t}
let atom tst ?(sign=true) (t:term) : t = let atom tst ?(sign=true) (t:term) : t =
let t, sign' = T.abs tst t in let t, sign' = Term.abs tst t in
let sign = if not sign' then not sign else sign in let sign = if not sign' then not sign else sign in
make ~sign t make ~sign t
let equal a b = let equal a b =
a.lit_sign = b.lit_sign && a.lit_sign = b.lit_sign &&
T.equal a.lit_term b.lit_term Term.equal a.lit_term b.lit_term
let hash a = let hash a =
let sign = a.lit_sign in let sign = a.lit_sign in
CCHash.combine3 2 (CCHash.bool sign) (T.hash a.lit_term) CCHash.combine3 2 (CCHash.bool sign) (Term.hash a.lit_term)
let pp out l = let pp out l =
if l.lit_sign then T.pp out l.lit_term if l.lit_sign then Term.pp out l.lit_term
else Format.fprintf out "(@[@<1>¬@ %a@])" T.pp l.lit_term else Format.fprintf out "(@[@<1>¬@ %a@])" Term.pp l.lit_term
let apply_sign t s = if s then t else neg t let apply_sign t s = if s then t else neg t
let norm_sign l = if l.lit_sign then l, true else neg l, false let norm_sign l = if l.lit_sign then l, true else neg l, false
@ -63,15 +69,19 @@ module Make(A : ARG)
type lit = Lit_.t type lit = Lit_.t
(* actions from msat *) (* actions from msat *)
type msat_acts = (Msat.void, lit, Msat.void, A.Proof.t) Msat.acts type msat_acts = (Msat.void, lit, Msat.void, P.t) Msat.acts
(* the full argument to the congruence closure *) (* the full argument to the congruence closure *)
module CC_A = struct module CC_actions = struct
module A = A module T = T
module P = P
module Lit = Lit_ module Lit = Lit_
let cc_view = A.cc_view let cc_view = A.cc_view
module Actions = struct module Actions = struct
module T = T
module P = P
module Lit = Lit
type t = msat_acts type t = msat_acts
let[@inline] raise_conflict a lits pr = let[@inline] raise_conflict a lits pr =
a.Msat.acts_raise_conflict lits pr a.Msat.acts_raise_conflict lits pr
@ -81,21 +91,21 @@ module Make(A : ARG)
end end
end end
module CC = Sidekick_cc.Make(CC_A) module CC = Sidekick_cc.Make(CC_actions)
module Expl = CC.Expl module Expl = CC.Expl
module N = CC.N module N = CC.N
(** Internal solver, given to theories and to Msat *) (** Internal solver, given to theories and to Msat *)
module Solver_internal = struct module Solver_internal = struct
module A = A module T = T
module CC_A = CC_A module P = P
module Lit = Lit_ module Lit = Lit_
module CC = CC module CC = CC
module N = CC.N module N = CC.N
type term = T.t type term = Term.t
type ty = Ty.t type ty = Ty.t
type lit = Lit.t type lit = Lit.t
type term_state = T.state type term_state = Term.state
type th_states = type th_states =
| Ths_nil | Ths_nil
@ -112,33 +122,33 @@ module Make(A : ARG)
type t = { type t = {
tst: term_state; tst: term_state;
mutable hooks: hook list; mutable hooks: hook list;
cache: T.t T.Tbl.t; cache: Term.t Term.Tbl.t;
} }
and hook = t -> term -> term option and hook = t -> term -> term option
let create tst : t = {tst; hooks=[]; cache=T.Tbl.create 32;} let create tst : t = {tst; hooks=[]; cache=Term.Tbl.create 32;}
let[@inline] tst self = self.tst let[@inline] tst self = self.tst
let add_hook self f = self.hooks <- f :: self.hooks let add_hook self f = self.hooks <- f :: self.hooks
let clear self = T.Tbl.clear self.cache let clear self = Term.Tbl.clear self.cache
let normalize (self:t) (t:T.t) : T.t = let normalize (self:t) (t:Term.t) : Term.t =
(* compute and cache normal form of [t] *) (* compute and cache normal form of [t] *)
let rec aux t = let rec aux t =
match T.Tbl.find self.cache t with match Term.Tbl.find self.cache t with
| u -> u | u -> u
| exception Not_found -> | exception Not_found ->
let u = aux_rec t self.hooks in let u = aux_rec t self.hooks in
T.Tbl.add self.cache t u; Term.Tbl.add self.cache t u;
u u
(* try each function in [hooks] successively, and rewrite subterms *) (* try each function in [hooks] successively, and rewrite subterms *)
and aux_rec t hooks = match hooks with and aux_rec t hooks = match hooks with
| [] -> | [] ->
let u = T.map_shallow self.tst aux t in let u = Term.map_shallow self.tst aux t in
if T.equal t u then t else aux u if Term.equal t u then t else aux u
| h :: hooks_tl -> | h :: hooks_tl ->
match h self t with match h self t with
| None -> aux_rec t hooks_tl | None -> aux_rec t hooks_tl
| Some u when T.equal t u -> aux_rec t hooks_tl | Some u when Term.equal t u -> aux_rec t hooks_tl
| Some u -> aux u | Some u -> aux u
in in
aux t aux t
@ -146,7 +156,7 @@ module Make(A : ARG)
type simplify_hook = Simplify.hook type simplify_hook = Simplify.hook
type t = { type t = {
tst: T.state; (** state for managing terms *) tst: Term.state; (** state for managing terms *)
cc: CC.t lazy_t; (** congruence closure *) cc: CC.t lazy_t; (** congruence closure *)
stat: Stat.t; stat: Stat.t;
count_axiom: int Stat.counter; count_axiom: int Stat.counter;
@ -156,7 +166,7 @@ module Make(A : ARG)
mutable on_progress: unit -> unit; mutable on_progress: unit -> unit;
simp: Simplify.t; simp: Simplify.t;
mutable preprocess: preprocess_hook list; mutable preprocess: preprocess_hook list;
preprocess_cache: T.t T.Tbl.t; preprocess_cache: Term.t Term.Tbl.t;
mutable th_states : th_states; (** Set of theories *) mutable th_states : th_states; (** Set of theories *)
mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list; mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> actions -> lit Iter.t -> unit) list;
@ -179,44 +189,44 @@ module Make(A : ARG)
module Eq_class = CC.N module Eq_class = CC.N
module Expl = CC.Expl module Expl = CC.Expl
type proof = A.Proof.t type proof = P.t
let[@inline] cc (t:t) = Lazy.force t.cc let[@inline] cc (t:t) = Lazy.force t.cc
let[@inline] tst t = t.tst let[@inline] tst t = t.tst
let simplifier self = self.simp let simplifier self = self.simp
let simp_t self (t:T.t) : T.t = Simplify.normalize self.simp t let simp_t self (t:Term.t) : Term.t = Simplify.normalize self.simp t
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
let add_preprocess self f = self.preprocess <- f :: self.preprocess let add_preprocess self f = self.preprocess <- f :: self.preprocess
let[@inline] raise_conflict self acts c : 'a = let[@inline] raise_conflict self acts c : 'a =
Stat.incr self.count_conflict; Stat.incr self.count_conflict;
acts.Msat.acts_raise_conflict c A.Proof.default acts.Msat.acts_raise_conflict c P.default
let[@inline] propagate self acts p cs : unit = let[@inline] propagate self acts p cs : unit =
Stat.incr self.count_propagate; Stat.incr self.count_propagate;
acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), A.Proof.default)) acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), P.default))
let[@inline] propagate_l self acts p cs : unit = let[@inline] propagate_l self acts p cs : unit =
propagate self acts p (fun()->cs) propagate self acts p (fun()->cs)
let add_sat_clause_ self acts ~keep lits : unit = let add_sat_clause_ self acts ~keep lits : unit =
Stat.incr self.count_axiom; Stat.incr self.count_axiom;
acts.Msat.acts_add_clause ~keep lits A.Proof.default acts.Msat.acts_add_clause ~keep lits P.default
let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit = let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit =
let mk_lit t = Lit.atom self.tst t in let mk_lit t = Lit.atom self.tst t in
(* compute and cache normal form of [t] *) (* compute and cache normal form of [t] *)
let rec aux t = let rec aux t =
match T.Tbl.find self.preprocess_cache t with match Term.Tbl.find self.preprocess_cache t with
| u -> u | u -> u
| exception Not_found -> | exception Not_found ->
(* first, map subterms *) (* first, map subterms *)
let u = T.map_shallow self.tst aux t in let u = Term.map_shallow self.tst aux t in
(* then rewrite *) (* then rewrite *)
let u = aux_rec u self.preprocess in let u = aux_rec u self.preprocess in
T.Tbl.add self.preprocess_cache t u; Term.Tbl.add self.preprocess_cache t u;
u u
(* try each function in [hooks] successively *) (* try each function in [hooks] successively *)
and aux_rec t hooks = match hooks with and aux_rec t hooks = match hooks with
@ -227,7 +237,7 @@ module Make(A : ARG)
| Some u -> | Some u ->
Log.debugf 30 Log.debugf 30
(fun k->k "(@[msat-solver.preprocess.step@ :from %a@ :to %a@])" (fun k->k "(@[msat-solver.preprocess.step@ :from %a@ :to %a@])"
T.pp t T.pp u); Term.pp t Term.pp u);
aux u aux u
in in
let t = Lit.term lit |> simp_t self |> aux in let t = Lit.term lit |> simp_t self |> aux in
@ -339,7 +349,7 @@ module Make(A : ARG)
CC.mk_model (cc self) m CC.mk_model (cc self) m
*) *)
let create ~stat (tst:A.Term.state) () : t = let create ~stat (tst:Term.state) () : t =
let rec self = { let rec self = {
tst; tst;
cc = lazy ( cc = lazy (
@ -351,7 +361,7 @@ module Make(A : ARG)
simp=Simplify.create tst; simp=Simplify.create tst;
on_progress=(fun () -> ()); on_progress=(fun () -> ());
preprocess=[]; preprocess=[];
preprocess_cache=T.Tbl.create 32; preprocess_cache=Term.Tbl.create 32;
count_axiom = Stat.mk_int stat "solver.th-axioms"; count_axiom = Stat.mk_int stat "solver.th-axioms";
count_preprocess_clause = Stat.mk_int stat "solver.preprocess-clause"; count_preprocess_clause = Stat.mk_int stat "solver.preprocess-clause";
count_propagate = Stat.mk_int stat "solver.th-propagations"; count_propagate = Stat.mk_int stat "solver.th-propagations";
@ -439,8 +449,8 @@ module Make(A : ARG)
begin begin
let tst = Solver_internal.tst self.si in let tst = Solver_internal.tst self.si in
Sat_solver.assume self.solver [ Sat_solver.assume self.solver [
[Lit.atom tst @@ T.bool tst true]; [Lit.atom tst @@ Term.bool tst true];
] A.Proof.default; ] P.default;
end; end;
self self
@ -456,9 +466,9 @@ module Make(A : ARG)
mk_atom_lit_ self lit mk_atom_lit_ self lit
(* map boolean subterms to literals *) (* map boolean subterms to literals *)
let add_bool_subterms_ (self:t) (t:T.t) : unit = let add_bool_subterms_ (self:t) (t:Term.t) : unit =
T.iter_dag t Term.iter_dag t
|> Iter.filter (fun t -> Ty.is_bool @@ T.ty t) |> Iter.filter (fun t -> Ty.is_bool @@ Term.ty t)
|> Iter.filter |> Iter.filter
(fun t -> match A.cc_view t with (fun t -> match A.cc_view t with
| Sidekick_core.CC_view.Not _ -> false (* will process the subterm just later *) | Sidekick_core.CC_view.Not _ -> false (* will process the subterm just later *)
@ -466,7 +476,7 @@ module Make(A : ARG)
|> Iter.filter (fun t -> A.is_valid_literal t) |> Iter.filter (fun t -> A.is_valid_literal t)
|> Iter.iter |> Iter.iter
(fun sub -> (fun sub ->
Log.debugf 5 (fun k->k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" T.pp sub); Log.debugf 5 (fun k->k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" Term.pp sub);
(* ensure that msat has a boolean atom for [sub] *) (* ensure that msat has a boolean atom for [sub] *)
let atom = mk_atom_t_ self sub in let atom = mk_atom_t_ self sub in
(* also map [sub] to this atom in the congruence closure, for propagation *) (* also map [sub] to this atom in the congruence closure, for propagation *)
@ -485,7 +495,7 @@ module Make(A : ARG)
(* recursively add these sub-literals, so they're also properly processed *) (* recursively add these sub-literals, so they're also properly processed *)
Stat.incr self.si.count_preprocess_clause; Stat.incr self.si.count_preprocess_clause;
let atoms = List.map (mk_atom_lit self) lits in let atoms = List.map (mk_atom_lit self) lits in
Sat_solver.add_clause self.solver atoms A.Proof.default) Sat_solver.add_clause self.solver atoms P.default)
self.si lit self.si lit
let[@inline] mk_atom_t self ?sign t : Atom.t = let[@inline] mk_atom_t self ?sign t : Atom.t =
@ -507,28 +517,28 @@ module Make(A : ARG)
end [@@ocaml.warning "-37"] end [@@ocaml.warning "-37"]
(* just use terms as values *) (* just use terms as values *)
module Value = A.Term module Value = Term
module Model = struct module Model = struct
type t = type t =
| Empty | Empty
| Map of Value.t A.Term.Tbl.t | Map of Value.t Term.Tbl.t
let empty = Empty let empty = Empty
let mem = function let mem = function
| Empty -> fun _ -> false | Empty -> fun _ -> false
| Map tbl -> A.Term.Tbl.mem tbl | Map tbl -> Term.Tbl.mem tbl
let find = function let find = function
| Empty -> fun _ -> None | Empty -> fun _ -> None
| Map tbl -> A.Term.Tbl.get tbl | Map tbl -> Term.Tbl.get tbl
let eval = find let eval = find
let pp out = function let pp out = function
| Empty -> Fmt.string out "(model)" | Empty -> Fmt.string out "(model)"
| Map tbl -> | Map tbl ->
let pp_pair out (t,v) = let pp_pair out (t,v) =
Fmt.fprintf out "(@[<1>%a@ := %a@])" A.Term.pp t Value.pp v Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Value.pp v
in in
Fmt.fprintf out "(@[<hv>model@ %a@])" Fmt.fprintf out "(@[<hv>model@ %a@])"
(Util.pp_seq pp_pair) (A.Term.Tbl.to_seq tbl) (Util.pp_seq pp_pair) (Term.Tbl.to_seq tbl)
end end
type res = type res =
@ -551,7 +561,7 @@ module Make(A : ARG)
let add_clause (self:t) (c:Atom.t IArray.t) : unit = let add_clause (self:t) (c:Atom.t IArray.t) : unit =
Stat.incr self.count_clause; Stat.incr self.count_clause;
Sat_solver.add_clause_a self.solver (c:> Atom.t array) A.Proof.default Sat_solver.add_clause_a self.solver (c:> Atom.t array) P.default
let add_clause_l self c = add_clause self (IArray.of_list c) let add_clause_l self c = add_clause self (IArray.of_list c)
@ -562,13 +572,13 @@ module Make(A : ARG)
let mk_model (self:t) (lits:lit Iter.t) : Model.t = let mk_model (self:t) (lits:lit Iter.t) : Model.t =
Log.debug 1 "(smt.solver.mk-model)"; Log.debug 1 "(smt.solver.mk-model)";
let module M = A.Term.Tbl in let module M = Term.Tbl in
let m = M.create 128 in let m = M.create 128 in
let tst = self.si.tst in let tst = self.si.tst in
(* first, add all boolean *) (* first, add all boolean *)
lits lits
(fun {Lit.lit_term=t;lit_sign=sign} -> (fun {Lit.lit_term=t;lit_sign=sign} ->
M.replace m t (A.Term.bool tst sign)); M.replace m t (Term.bool tst sign));
(* then add CC classes *) (* then add CC classes *)
Solver_internal.CC.all_classes (Solver_internal.cc self.si) Solver_internal.CC.all_classes (Solver_internal.cc self.si)
(fun repr -> (fun repr ->

View file

@ -399,11 +399,11 @@ let conv_term = Conv.conv_term
(* instantiate solver here *) (* instantiate solver here *)
module Solver_arg = struct module Solver_arg = struct
include Sidekick_base_term module T = Sidekick_base_term
let cc_view = Term.cc_view let cc_view = Term.cc_view
let is_valid_literal _ = true let is_valid_literal _ = true
module Proof = struct module P = struct
type t = Default type t = Default
let default=Default let default=Default
let pp out _ = Fmt.string out "default" let pp out _ = Fmt.string out "default"
@ -643,7 +643,7 @@ let process_stmt
module Th_bool = Sidekick_th_bool_static.Make(struct module Th_bool = Sidekick_th_bool_static.Make(struct
module S = Solver module S = Solver
type term = S.A.Term.t type term = S.T.Term.t
include Form include Form
end) end)

View file

@ -3,9 +3,9 @@
open Sidekick_base_term open Sidekick_base_term
module Solver module Solver
: Sidekick_msat_solver.S with type A.Term.t = Term.t : Sidekick_msat_solver.S with type T.Term.t = Term.t
and type A.Term.state = Term.state and type T.Term.state = Term.state
and type A.Ty.t = Ty.t and type T.Ty.t = Ty.t
val th_bool : Solver.theory val th_bool : Solver.theory

View file

@ -15,12 +15,12 @@ type 'a bool_view =
module type ARG = sig module type ARG = sig
module S : Sidekick_core.SOLVER module S : Sidekick_core.SOLVER
type term = S.A.Term.t type term = S.T.Term.t
val view_as_bool : term -> term bool_view val view_as_bool : term -> term bool_view
(** Project the term into the boolean view *) (** Project the term into the boolean view *)
val mk_bool : S.A.Term.state -> term bool_view -> term val mk_bool : S.T.Term.state -> term bool_view -> term
(** Make a term from the given boolean view *) (** Make a term from the given boolean view *)
val check_congruence_classes : bool val check_congruence_classes : bool
@ -32,9 +32,9 @@ module type ARG = sig
module Gensym : sig module Gensym : sig
type t type t
val create : S.A.Term.state -> t val create : S.T.Term.state -> t
val fresh_term : t -> pre:string -> S.A.Ty.t -> term val fresh_term : t -> pre:string -> S.T.Ty.t -> term
(** Make a fresh term of the given type *) (** Make a fresh term of the given type *)
end end
end end
@ -44,7 +44,7 @@ module type S = sig
type state type state
val create : A.S.A.Term.state -> state val create : A.S.T.Term.state -> state
val simplify : state -> A.S.Solver_internal.simplify_hook val simplify : state -> A.S.Solver_internal.simplify_hook
(** Simplify given term *) (** Simplify given term *)
@ -57,8 +57,8 @@ end
module Make(A : ARG) : S with module A = A = struct module Make(A : ARG) : S with module A = A = struct
module A = A module A = A
module Ty = A.S.A.Ty module Ty = A.S.T.Ty
module T = A.S.A.Term module T = A.S.T.Term
module Lit = A.S.Solver_internal.Lit module Lit = A.S.Solver_internal.Lit
module SI = A.S.Solver_internal module SI = A.S.Solver_internal

View file

@ -8,7 +8,7 @@ let name = "th-cstor"
module type ARG = sig module type ARG = sig
module S : Sidekick_core.SOLVER module S : Sidekick_core.SOLVER
val view_as_cstor : S.A.Term.t -> (S.A.Fun.t, S.A.Term.t) cstor_view val view_as_cstor : S.T.Term.t -> (S.T.Fun.t, S.T.Term.t) cstor_view
end end
module type S = sig module type S = sig
@ -19,9 +19,9 @@ end
module Make(A : ARG) : S with module A = A = struct module Make(A : ARG) : S with module A = A = struct
module A = A module A = A
module SI = A.S.Solver_internal module SI = A.S.Solver_internal
module T = A.S.A.Term module T = A.S.T.Term
module N = SI.CC.N module N = SI.CC.N
module Fun = A.S.A.Fun module Fun = A.S.T.Fun
module Expl = SI.CC.Expl module Expl = SI.CC.Expl
type cstor_repr = { type cstor_repr = {