mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-10 13:14:09 -05:00
refactor: change the functor stack
This commit is contained in:
parent
94ba04a49e
commit
7d8589accd
10 changed files with 213 additions and 179 deletions
|
|
@ -1,3 +1,4 @@
|
|||
open Sidekick_core
|
||||
module View = Sidekick_core.CC_view
|
||||
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t =
|
||||
|
|
@ -9,21 +10,27 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t =
|
|||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
module type ARG = Sidekick_core.CC_ARG
|
||||
module type S = Sidekick_core.CC_S
|
||||
|
||||
module Make(CC_A: ARG) = struct
|
||||
module A = CC_A
|
||||
type term = A.A.Term.t
|
||||
type term_state = A.A.Term.state
|
||||
type lit = A.Lit.t
|
||||
type fun_ = A.A.Fun.t
|
||||
type proof = A.A.Proof.t
|
||||
type actions = A.Actions.t
|
||||
module Make (A: CC_ARG)
|
||||
: S with module T = A.T
|
||||
and module Lit = A.Lit
|
||||
and module P = A.P
|
||||
and module Actions = A.Actions
|
||||
= struct
|
||||
module T = A.T
|
||||
module P = A.P
|
||||
module Lit = A.Lit
|
||||
module Actions = A.Actions
|
||||
type term = T.Term.t
|
||||
type term_state = T.Term.state
|
||||
type lit = Lit.t
|
||||
type fun_ = T.Fun.t
|
||||
type proof = P.t
|
||||
type actions = Actions.t
|
||||
|
||||
module T = A.A.Term
|
||||
module Fun = A.A.Fun
|
||||
module Lit = CC_A.Lit
|
||||
module Term = T.Term
|
||||
module Fun = T.Fun
|
||||
|
||||
module Bits : sig
|
||||
type t = private int
|
||||
|
|
@ -94,9 +101,9 @@ module Make(CC_A: ARG) = struct
|
|||
type t = node
|
||||
|
||||
let[@inline] equal (n1:t) n2 = n1 == n2
|
||||
let[@inline] hash n = T.hash n.n_term
|
||||
let[@inline] hash n = Term.hash n.n_term
|
||||
let[@inline] term n = n.n_term
|
||||
let[@inline] pp out n = T.pp out n.n_term
|
||||
let[@inline] pp out n = Term.pp out n.n_term
|
||||
let[@inline] as_lit n = n.n_as_lit
|
||||
|
||||
let make (t:term) : t =
|
||||
|
|
@ -158,7 +165,7 @@ module Make(CC_A: ARG) = struct
|
|||
| E_lit lit -> Lit.pp out lit
|
||||
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
|
||||
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
|
||||
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b
|
||||
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b
|
||||
| E_and (a,b) ->
|
||||
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
|
||||
|
||||
|
|
@ -167,7 +174,7 @@ module Make(CC_A: ARG) = struct
|
|||
let[@inline] mk_merge a b : t =
|
||||
assert (same_class a b);
|
||||
if N.equal a b then mk_reduction else E_merge (a,b)
|
||||
let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b)
|
||||
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b)
|
||||
let[@inline] mk_lit l : t = E_lit l
|
||||
|
||||
let rec mk_list l =
|
||||
|
|
@ -227,7 +234,7 @@ module Make(CC_A: ARG) = struct
|
|||
end
|
||||
|
||||
module Sig_tbl = CCHashtbl.Make(Signature)
|
||||
module T_tbl = CCHashtbl.Make(T)
|
||||
module T_tbl = CCHashtbl.Make(Term)
|
||||
|
||||
type combine_task =
|
||||
| CT_merge of node * node * explanation
|
||||
|
|
@ -299,7 +306,7 @@ module Make(CC_A: ARG) = struct
|
|||
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl
|
||||
in
|
||||
let pp_n out n =
|
||||
Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n
|
||||
Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n
|
||||
and pp_sig_e out (s,n) =
|
||||
Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n
|
||||
in
|
||||
|
|
@ -357,7 +364,7 @@ module Make(CC_A: ARG) = struct
|
|||
Vec.clear cc.combine;
|
||||
List.iter (fun f -> f cc e) cc.on_conflict;
|
||||
Stat.incr cc.count_conflict;
|
||||
CC_A.Actions.raise_conflict acts e A.A.Proof.default
|
||||
Actions.raise_conflict acts e P.default
|
||||
|
||||
let[@inline] all_classes cc : repr Iter.t =
|
||||
T_tbl.values cc.tbl
|
||||
|
|
@ -440,7 +447,7 @@ module Make(CC_A: ARG) = struct
|
|||
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
|
||||
| a, b -> explain_pair cc acc a b
|
||||
| exception Not_found ->
|
||||
Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b
|
||||
Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b
|
||||
end
|
||||
| E_and (a,b) ->
|
||||
let acc = explain_decompose cc acc a in
|
||||
|
|
@ -477,7 +484,7 @@ module Make(CC_A: ARG) = struct
|
|||
(* add [t] to [cc] when not present already *)
|
||||
and add_new_term_ cc (t:term) : node =
|
||||
assert (not @@ mem cc t);
|
||||
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t);
|
||||
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
|
||||
let n = N.make t in
|
||||
(* register sub-terms, add [t] to their parent list, and return the
|
||||
corresponding initial signature *)
|
||||
|
|
@ -486,7 +493,7 @@ module Make(CC_A: ARG) = struct
|
|||
(* remove term when we backtrack *)
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t);
|
||||
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t);
|
||||
T_tbl.remove cc.tbl t);
|
||||
(* add term to the table *)
|
||||
T_tbl.add cc.tbl t n;
|
||||
|
|
@ -514,7 +521,7 @@ module Make(CC_A: ARG) = struct
|
|||
sub
|
||||
in
|
||||
let[@inline] return x = Some x in
|
||||
match CC_A.cc_view n.n_term with
|
||||
match A.cc_view n.n_term with
|
||||
| Bool _ | Opaque _ -> None
|
||||
| Eq (a,b) ->
|
||||
let a = deref_sub a in
|
||||
|
|
@ -733,7 +740,7 @@ module Make(CC_A: ARG) = struct
|
|||
in
|
||||
List.iter (fun f -> f cc lit reason) cc.on_propagate;
|
||||
Stat.incr cc.count_props;
|
||||
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default
|
||||
Actions.propagate acts lit ~reason P.default
|
||||
| _ -> ())
|
||||
|
||||
module Debug_ = struct
|
||||
|
|
@ -766,7 +773,7 @@ module Make(CC_A: ARG) = struct
|
|||
let t = Lit.term lit in
|
||||
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
|
||||
let sign = Lit.sign lit in
|
||||
begin match CC_A.cc_view t with
|
||||
begin match A.cc_view t with
|
||||
| Eq (a,b) when sign ->
|
||||
let a = add_term cc a in
|
||||
let b = add_term cc b in
|
||||
|
|
@ -837,9 +844,9 @@ module Make(CC_A: ARG) = struct
|
|||
count_props=Stat.mk_int stat "cc.propagations";
|
||||
count_merge=Stat.mk_int stat "cc.merges";
|
||||
} and true_ = lazy (
|
||||
add_term cc (T.bool tst true)
|
||||
add_term cc (Term.bool tst true)
|
||||
) and false_ = lazy (
|
||||
add_term cc (T.bool tst false)
|
||||
add_term cc (Term.bool tst false)
|
||||
)
|
||||
in
|
||||
ignore (Lazy.force true_ : node);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
(** {2 Congruence Closure} *)
|
||||
|
||||
module type ARG = Sidekick_core.CC_ARG
|
||||
open Sidekick_core
|
||||
module type S = Sidekick_core.CC_S
|
||||
|
||||
module Make(A: ARG) : S with module A = A
|
||||
module Make (A: CC_ARG)
|
||||
: S with module T = A.T
|
||||
and module Lit = A.Lit
|
||||
and module P = A.P
|
||||
and module Actions = A.Actions
|
||||
|
|
|
|||
|
|
@ -89,48 +89,58 @@ module type TERM = sig
|
|||
end
|
||||
end
|
||||
|
||||
module type TERM_PROOF = sig
|
||||
include TERM
|
||||
module type PROOF = sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
|
||||
module Proof : sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
val default : t
|
||||
end
|
||||
|
||||
val default : t
|
||||
end
|
||||
module type LIT = sig
|
||||
module T : TERM
|
||||
type t
|
||||
|
||||
val term : t -> T.Term.t
|
||||
val sign : t -> bool
|
||||
val neg : t -> t
|
||||
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module type CC_ACTIONS = sig
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
module Lit : LIT with module T = T
|
||||
type t
|
||||
|
||||
val raise_conflict : t -> Lit.t list -> P.t -> 'a
|
||||
|
||||
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list) -> P.t -> unit
|
||||
end
|
||||
|
||||
module type CC_ARG = sig
|
||||
module A : TERM_PROOF
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
module Lit : LIT with module T = T
|
||||
module Actions : CC_ACTIONS with module T=T and module P = P and module Lit = Lit
|
||||
|
||||
val cc_view : A.Term.t -> (A.Fun.t, A.Term.t, A.Term.t Iter.t) CC_view.t
|
||||
val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
|
||||
(** View the term through the lens of the congruence closure *)
|
||||
|
||||
module Lit : sig
|
||||
type t
|
||||
val term : t -> A.Term.t
|
||||
val sign : t -> bool
|
||||
val neg : t -> t
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Actions : sig
|
||||
type t
|
||||
|
||||
val raise_conflict : t -> Lit.t list -> A.Proof.t -> 'a
|
||||
|
||||
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list) -> A.Proof.t -> unit
|
||||
end
|
||||
end
|
||||
|
||||
module type CC_S = sig
|
||||
module A : CC_ARG
|
||||
type term_state = A.A.Term.state
|
||||
type term = A.A.Term.t
|
||||
type fun_ = A.A.Fun.t
|
||||
type lit = A.Lit.t
|
||||
type proof = A.A.Proof.t
|
||||
type actions = A.Actions.t
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
module Lit : LIT with module T = T
|
||||
module Actions : CC_ACTIONS with module T = T and module Lit = Lit and module P = P
|
||||
type term_state = T.Term.state
|
||||
type term = T.Term.t
|
||||
type fun_ = T.Fun.t
|
||||
type lit = Lit.t
|
||||
type proof = P.t
|
||||
type actions = Actions.t
|
||||
|
||||
type t
|
||||
(** Global state of the congruence closure *)
|
||||
|
|
@ -305,14 +315,13 @@ end
|
|||
|
||||
(** A view of the solver from a theory's point of view *)
|
||||
module type SOLVER_INTERNAL = sig
|
||||
module A : TERM_PROOF
|
||||
module CC_A : CC_ARG with module A = A
|
||||
module CC : CC_S with module A = CC_A
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
|
||||
type ty = A.Ty.t
|
||||
type term = A.Term.t
|
||||
type term_state = A.Term.state
|
||||
type proof = A.Proof.t
|
||||
type ty = T.Ty.t
|
||||
type term = T.Term.t
|
||||
type term_state = T.Term.state
|
||||
type proof = P.t
|
||||
|
||||
(** {3 Main type for a solver} *)
|
||||
type t
|
||||
|
|
@ -320,26 +329,25 @@ module type SOLVER_INTERNAL = sig
|
|||
|
||||
val tst : t -> term_state
|
||||
|
||||
val cc : t -> CC.t
|
||||
(** Congruence closure for this solver *)
|
||||
|
||||
(** {3 Literals}
|
||||
|
||||
A literal is a (preprocessed) term along with its sign.
|
||||
It is directly manipulated by the SAT solver.
|
||||
*)
|
||||
module Lit : sig
|
||||
type t
|
||||
val term : t -> term
|
||||
val sign : t -> bool
|
||||
val neg : t -> t
|
||||
module Lit : LIT with module T = T
|
||||
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
type lit = Lit.t
|
||||
|
||||
(** {2 Congruence Closure} *)
|
||||
|
||||
module CC : CC_S
|
||||
with module T = T
|
||||
and module P = P
|
||||
and module Lit = Lit
|
||||
|
||||
val cc : t -> CC.t
|
||||
(** Congruence closure for this solver *)
|
||||
|
||||
(** {3 Simplifiers} *)
|
||||
|
||||
module Simplify : sig
|
||||
|
|
@ -368,11 +376,11 @@ module type SOLVER_INTERNAL = sig
|
|||
|
||||
(** {3 hooks for the theory} *)
|
||||
|
||||
type actions = CC_A.Actions.t
|
||||
type actions
|
||||
|
||||
val propagate : t -> actions -> lit -> reason:(unit -> lit list) -> A.Proof.t -> unit
|
||||
val propagate : t -> actions -> lit -> reason:(unit -> lit list) -> proof -> unit
|
||||
|
||||
val raise_conflict : t -> actions -> lit list -> A.Proof.t -> 'a
|
||||
val raise_conflict : t -> actions -> lit list -> proof -> 'a
|
||||
(** Give a conflict clause to the solver *)
|
||||
|
||||
val propagate: t -> actions -> lit -> (unit -> lit list) -> unit
|
||||
|
|
@ -472,17 +480,22 @@ end
|
|||
|
||||
(** Public view of the solver *)
|
||||
module type SOLVER = sig
|
||||
module A : TERM_PROOF
|
||||
module CC_A : CC_ARG with module A = A
|
||||
module Solver_internal : SOLVER_INTERNAL with module A = A and module CC_A = CC_A
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
module Lit : LIT with module T = T
|
||||
module Solver_internal
|
||||
: SOLVER_INTERNAL
|
||||
with module T = T
|
||||
and module P = P
|
||||
and module Lit = Lit
|
||||
(** Internal solver, available to theories. *)
|
||||
|
||||
type t
|
||||
type solver = t
|
||||
type term = A.Term.t
|
||||
type ty = A.Ty.t
|
||||
type lit = Solver_internal.Lit.t
|
||||
type lemma = A.Proof.t
|
||||
type term = T.Term.t
|
||||
type ty = T.Ty.t
|
||||
type lit = Lit.t
|
||||
type lemma = P.t
|
||||
|
||||
(** {3 A theory}
|
||||
|
||||
|
|
@ -590,7 +603,7 @@ module type SOLVER = sig
|
|||
(** {3 Main API} *)
|
||||
|
||||
val stats : t -> Stat.t
|
||||
val tst : t -> A.Term.state
|
||||
val tst : t -> T.Term.state
|
||||
|
||||
val create :
|
||||
?stat:Stat.t ->
|
||||
|
|
@ -598,7 +611,7 @@ module type SOLVER = sig
|
|||
(* TODO? ?config:Config.t -> *)
|
||||
?store_proof:bool ->
|
||||
theories:theory list ->
|
||||
A.Term.state ->
|
||||
T.Term.state ->
|
||||
unit ->
|
||||
t
|
||||
(** Create a new solver.
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
module CC_view = Sidekick_core.CC_view
|
||||
|
||||
module type ARG = sig
|
||||
include Sidekick_core.TERM
|
||||
module T : Sidekick_core.TERM
|
||||
|
||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
|
||||
val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
|
|
@ -26,11 +26,11 @@ end
|
|||
module Make(A: ARG) = struct
|
||||
open CC_view
|
||||
|
||||
module Fun = A.Fun
|
||||
module T = A.Term
|
||||
type fun_ = A.Fun.t
|
||||
module Fun = A.T.Fun
|
||||
module T = A.T.Term
|
||||
type fun_ = A.T.Fun.t
|
||||
type term = T.t
|
||||
type term_state = A.Term.state
|
||||
type term_state = T.state
|
||||
|
||||
module T_tbl = CCHashtbl.Make(T)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@
|
|||
module CC_view = Sidekick_core.CC_view
|
||||
|
||||
module type ARG = sig
|
||||
include Sidekick_core.TERM
|
||||
module T : Sidekick_core.TERM
|
||||
|
||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
|
||||
val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
|
|
@ -38,6 +38,6 @@ module type S = sig
|
|||
end
|
||||
|
||||
module Make(A: ARG)
|
||||
: S with type term = A.Term.t
|
||||
and type fun_ = A.Fun.t
|
||||
and type term_state = A.Term.state
|
||||
: S with type term = A.T.Term.t
|
||||
and type fun_ = A.T.Fun.t
|
||||
and type term_state = A.T.Term.state
|
||||
|
|
|
|||
|
|
@ -5,11 +5,13 @@ module Log = Msat.Log
|
|||
module IM = Util.Int_map
|
||||
|
||||
module type ARG = sig
|
||||
include Sidekick_core.TERM_PROOF
|
||||
open Sidekick_core
|
||||
module T : TERM
|
||||
module P : PROOF
|
||||
|
||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t
|
||||
val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
|
||||
|
||||
val is_valid_literal : Term.t -> bool
|
||||
val is_valid_literal : T.Term.t -> bool
|
||||
(** Is this a valid boolean literal? (e.g. is it a closed term, not inside
|
||||
a quantifier) *)
|
||||
end
|
||||
|
|
@ -17,16 +19,20 @@ end
|
|||
module type S = Sidekick_core.SOLVER
|
||||
|
||||
module Make(A : ARG)
|
||||
: S with module A = A
|
||||
: S
|
||||
with module T = A.T
|
||||
and module P = A.P
|
||||
= struct
|
||||
module A = A
|
||||
module T = A.Term
|
||||
module Ty = A.Ty
|
||||
type term = T.t
|
||||
module T = A.T
|
||||
module P = A.P
|
||||
module Ty = T.Ty
|
||||
module Term = T.Term
|
||||
type term = Term.t
|
||||
type ty = Ty.t
|
||||
type lemma = A.Proof.t
|
||||
type lemma = P.t
|
||||
|
||||
module Lit_ = struct
|
||||
module T = T
|
||||
type t = {
|
||||
lit_term: term;
|
||||
lit_sign : bool
|
||||
|
|
@ -39,21 +45,21 @@ module Make(A : ARG)
|
|||
let make ~sign t = {lit_sign=sign; lit_term=t}
|
||||
|
||||
let atom tst ?(sign=true) (t:term) : t =
|
||||
let t, sign' = T.abs tst t in
|
||||
let t, sign' = Term.abs tst t in
|
||||
let sign = if not sign' then not sign else sign in
|
||||
make ~sign t
|
||||
|
||||
let equal a b =
|
||||
a.lit_sign = b.lit_sign &&
|
||||
T.equal a.lit_term b.lit_term
|
||||
Term.equal a.lit_term b.lit_term
|
||||
|
||||
let hash a =
|
||||
let sign = a.lit_sign in
|
||||
CCHash.combine3 2 (CCHash.bool sign) (T.hash a.lit_term)
|
||||
CCHash.combine3 2 (CCHash.bool sign) (Term.hash a.lit_term)
|
||||
|
||||
let pp out l =
|
||||
if l.lit_sign then T.pp out l.lit_term
|
||||
else Format.fprintf out "(@[@<1>¬@ %a@])" T.pp l.lit_term
|
||||
if l.lit_sign then Term.pp out l.lit_term
|
||||
else Format.fprintf out "(@[@<1>¬@ %a@])" Term.pp l.lit_term
|
||||
|
||||
let apply_sign t s = if s then t else neg t
|
||||
let norm_sign l = if l.lit_sign then l, true else neg l, false
|
||||
|
|
@ -63,15 +69,19 @@ module Make(A : ARG)
|
|||
type lit = Lit_.t
|
||||
|
||||
(* actions from msat *)
|
||||
type msat_acts = (Msat.void, lit, Msat.void, A.Proof.t) Msat.acts
|
||||
type msat_acts = (Msat.void, lit, Msat.void, P.t) Msat.acts
|
||||
|
||||
(* the full argument to the congruence closure *)
|
||||
module CC_A = struct
|
||||
module A = A
|
||||
module CC_actions = struct
|
||||
module T = T
|
||||
module P = P
|
||||
module Lit = Lit_
|
||||
let cc_view = A.cc_view
|
||||
|
||||
module Actions = struct
|
||||
module T = T
|
||||
module P = P
|
||||
module Lit = Lit
|
||||
type t = msat_acts
|
||||
let[@inline] raise_conflict a lits pr =
|
||||
a.Msat.acts_raise_conflict lits pr
|
||||
|
|
@ -81,21 +91,21 @@ module Make(A : ARG)
|
|||
end
|
||||
end
|
||||
|
||||
module CC = Sidekick_cc.Make(CC_A)
|
||||
module CC = Sidekick_cc.Make(CC_actions)
|
||||
module Expl = CC.Expl
|
||||
module N = CC.N
|
||||
|
||||
(** Internal solver, given to theories and to Msat *)
|
||||
module Solver_internal = struct
|
||||
module A = A
|
||||
module CC_A = CC_A
|
||||
module T = T
|
||||
module P = P
|
||||
module Lit = Lit_
|
||||
module CC = CC
|
||||
module N = CC.N
|
||||
type term = T.t
|
||||
type term = Term.t
|
||||
type ty = Ty.t
|
||||
type lit = Lit.t
|
||||
type term_state = T.state
|
||||
type term_state = Term.state
|
||||
|
||||
type th_states =
|
||||
| Ths_nil
|
||||
|
|
@ -112,33 +122,33 @@ module Make(A : ARG)
|
|||
type t = {
|
||||
tst: term_state;
|
||||
mutable hooks: hook list;
|
||||
cache: T.t T.Tbl.t;
|
||||
cache: Term.t Term.Tbl.t;
|
||||
}
|
||||
and hook = t -> term -> term option
|
||||
|
||||
let create tst : t = {tst; hooks=[]; cache=T.Tbl.create 32;}
|
||||
let create tst : t = {tst; hooks=[]; cache=Term.Tbl.create 32;}
|
||||
let[@inline] tst self = self.tst
|
||||
let add_hook self f = self.hooks <- f :: self.hooks
|
||||
let clear self = T.Tbl.clear self.cache
|
||||
let clear self = Term.Tbl.clear self.cache
|
||||
|
||||
let normalize (self:t) (t:T.t) : T.t =
|
||||
let normalize (self:t) (t:Term.t) : Term.t =
|
||||
(* compute and cache normal form of [t] *)
|
||||
let rec aux t =
|
||||
match T.Tbl.find self.cache t with
|
||||
match Term.Tbl.find self.cache t with
|
||||
| u -> u
|
||||
| exception Not_found ->
|
||||
let u = aux_rec t self.hooks in
|
||||
T.Tbl.add self.cache t u;
|
||||
Term.Tbl.add self.cache t u;
|
||||
u
|
||||
(* try each function in [hooks] successively, and rewrite subterms *)
|
||||
and aux_rec t hooks = match hooks with
|
||||
| [] ->
|
||||
let u = T.map_shallow self.tst aux t in
|
||||
if T.equal t u then t else aux u
|
||||
let u = Term.map_shallow self.tst aux t in
|
||||
if Term.equal t u then t else aux u
|
||||
| h :: hooks_tl ->
|
||||
match h self t with
|
||||
| None -> aux_rec t hooks_tl
|
||||
| Some u when T.equal t u -> aux_rec t hooks_tl
|
||||
| Some u when Term.equal t u -> aux_rec t hooks_tl
|
||||
| Some u -> aux u
|
||||
in
|
||||
aux t
|
||||
|
|
@ -146,7 +156,7 @@ module Make(A : ARG)
|
|||
type simplify_hook = Simplify.hook
|
||||
|
||||
type t = {
|
||||
tst: T.state; (** state for managing terms *)
|
||||
tst: Term.state; (** state for managing terms *)
|
||||
cc: CC.t lazy_t; (** congruence closure *)
|
||||
stat: Stat.t;
|
||||
count_axiom: int Stat.counter;
|
||||
|
|
@ -156,7 +166,7 @@ module Make(A : ARG)
|
|||
mutable on_progress: unit -> unit;
|
||||
simp: Simplify.t;
|
||||
mutable preprocess: preprocess_hook list;
|
||||
preprocess_cache: T.t T.Tbl.t;
|
||||
preprocess_cache: Term.t Term.Tbl.t;
|
||||
mutable th_states : th_states; (** Set of theories *)
|
||||
mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list;
|
||||
mutable on_final_check: (t -> actions -> lit Iter.t -> unit) list;
|
||||
|
|
@ -179,44 +189,44 @@ module Make(A : ARG)
|
|||
module Eq_class = CC.N
|
||||
module Expl = CC.Expl
|
||||
|
||||
type proof = A.Proof.t
|
||||
type proof = P.t
|
||||
|
||||
let[@inline] cc (t:t) = Lazy.force t.cc
|
||||
let[@inline] tst t = t.tst
|
||||
|
||||
let simplifier self = self.simp
|
||||
let simp_t self (t:T.t) : T.t = Simplify.normalize self.simp t
|
||||
let simp_t self (t:Term.t) : Term.t = Simplify.normalize self.simp t
|
||||
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
|
||||
|
||||
let add_preprocess self f = self.preprocess <- f :: self.preprocess
|
||||
|
||||
let[@inline] raise_conflict self acts c : 'a =
|
||||
Stat.incr self.count_conflict;
|
||||
acts.Msat.acts_raise_conflict c A.Proof.default
|
||||
acts.Msat.acts_raise_conflict c P.default
|
||||
|
||||
let[@inline] propagate self acts p cs : unit =
|
||||
Stat.incr self.count_propagate;
|
||||
acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), A.Proof.default))
|
||||
acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), P.default))
|
||||
|
||||
let[@inline] propagate_l self acts p cs : unit =
|
||||
propagate self acts p (fun()->cs)
|
||||
|
||||
let add_sat_clause_ self acts ~keep lits : unit =
|
||||
Stat.incr self.count_axiom;
|
||||
acts.Msat.acts_add_clause ~keep lits A.Proof.default
|
||||
acts.Msat.acts_add_clause ~keep lits P.default
|
||||
|
||||
let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit =
|
||||
let mk_lit t = Lit.atom self.tst t in
|
||||
(* compute and cache normal form of [t] *)
|
||||
let rec aux t =
|
||||
match T.Tbl.find self.preprocess_cache t with
|
||||
match Term.Tbl.find self.preprocess_cache t with
|
||||
| u -> u
|
||||
| exception Not_found ->
|
||||
(* first, map subterms *)
|
||||
let u = T.map_shallow self.tst aux t in
|
||||
let u = Term.map_shallow self.tst aux t in
|
||||
(* then rewrite *)
|
||||
let u = aux_rec u self.preprocess in
|
||||
T.Tbl.add self.preprocess_cache t u;
|
||||
Term.Tbl.add self.preprocess_cache t u;
|
||||
u
|
||||
(* try each function in [hooks] successively *)
|
||||
and aux_rec t hooks = match hooks with
|
||||
|
|
@ -227,7 +237,7 @@ module Make(A : ARG)
|
|||
| Some u ->
|
||||
Log.debugf 30
|
||||
(fun k->k "(@[msat-solver.preprocess.step@ :from %a@ :to %a@])"
|
||||
T.pp t T.pp u);
|
||||
Term.pp t Term.pp u);
|
||||
aux u
|
||||
in
|
||||
let t = Lit.term lit |> simp_t self |> aux in
|
||||
|
|
@ -339,7 +349,7 @@ module Make(A : ARG)
|
|||
CC.mk_model (cc self) m
|
||||
*)
|
||||
|
||||
let create ~stat (tst:A.Term.state) () : t =
|
||||
let create ~stat (tst:Term.state) () : t =
|
||||
let rec self = {
|
||||
tst;
|
||||
cc = lazy (
|
||||
|
|
@ -351,7 +361,7 @@ module Make(A : ARG)
|
|||
simp=Simplify.create tst;
|
||||
on_progress=(fun () -> ());
|
||||
preprocess=[];
|
||||
preprocess_cache=T.Tbl.create 32;
|
||||
preprocess_cache=Term.Tbl.create 32;
|
||||
count_axiom = Stat.mk_int stat "solver.th-axioms";
|
||||
count_preprocess_clause = Stat.mk_int stat "solver.preprocess-clause";
|
||||
count_propagate = Stat.mk_int stat "solver.th-propagations";
|
||||
|
|
@ -439,8 +449,8 @@ module Make(A : ARG)
|
|||
begin
|
||||
let tst = Solver_internal.tst self.si in
|
||||
Sat_solver.assume self.solver [
|
||||
[Lit.atom tst @@ T.bool tst true];
|
||||
] A.Proof.default;
|
||||
[Lit.atom tst @@ Term.bool tst true];
|
||||
] P.default;
|
||||
end;
|
||||
self
|
||||
|
||||
|
|
@ -456,9 +466,9 @@ module Make(A : ARG)
|
|||
mk_atom_lit_ self lit
|
||||
|
||||
(* map boolean subterms to literals *)
|
||||
let add_bool_subterms_ (self:t) (t:T.t) : unit =
|
||||
T.iter_dag t
|
||||
|> Iter.filter (fun t -> Ty.is_bool @@ T.ty t)
|
||||
let add_bool_subterms_ (self:t) (t:Term.t) : unit =
|
||||
Term.iter_dag t
|
||||
|> Iter.filter (fun t -> Ty.is_bool @@ Term.ty t)
|
||||
|> Iter.filter
|
||||
(fun t -> match A.cc_view t with
|
||||
| Sidekick_core.CC_view.Not _ -> false (* will process the subterm just later *)
|
||||
|
|
@ -466,7 +476,7 @@ module Make(A : ARG)
|
|||
|> Iter.filter (fun t -> A.is_valid_literal t)
|
||||
|> Iter.iter
|
||||
(fun sub ->
|
||||
Log.debugf 5 (fun k->k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" T.pp sub);
|
||||
Log.debugf 5 (fun k->k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" Term.pp sub);
|
||||
(* ensure that msat has a boolean atom for [sub] *)
|
||||
let atom = mk_atom_t_ self sub in
|
||||
(* also map [sub] to this atom in the congruence closure, for propagation *)
|
||||
|
|
@ -485,7 +495,7 @@ module Make(A : ARG)
|
|||
(* recursively add these sub-literals, so they're also properly processed *)
|
||||
Stat.incr self.si.count_preprocess_clause;
|
||||
let atoms = List.map (mk_atom_lit self) lits in
|
||||
Sat_solver.add_clause self.solver atoms A.Proof.default)
|
||||
Sat_solver.add_clause self.solver atoms P.default)
|
||||
self.si lit
|
||||
|
||||
let[@inline] mk_atom_t self ?sign t : Atom.t =
|
||||
|
|
@ -507,28 +517,28 @@ module Make(A : ARG)
|
|||
end [@@ocaml.warning "-37"]
|
||||
|
||||
(* just use terms as values *)
|
||||
module Value = A.Term
|
||||
module Value = Term
|
||||
|
||||
module Model = struct
|
||||
type t =
|
||||
| Empty
|
||||
| Map of Value.t A.Term.Tbl.t
|
||||
| Map of Value.t Term.Tbl.t
|
||||
let empty = Empty
|
||||
let mem = function
|
||||
| Empty -> fun _ -> false
|
||||
| Map tbl -> A.Term.Tbl.mem tbl
|
||||
| Map tbl -> Term.Tbl.mem tbl
|
||||
let find = function
|
||||
| Empty -> fun _ -> None
|
||||
| Map tbl -> A.Term.Tbl.get tbl
|
||||
| Map tbl -> Term.Tbl.get tbl
|
||||
let eval = find
|
||||
let pp out = function
|
||||
| Empty -> Fmt.string out "(model)"
|
||||
| Map tbl ->
|
||||
let pp_pair out (t,v) =
|
||||
Fmt.fprintf out "(@[<1>%a@ := %a@])" A.Term.pp t Value.pp v
|
||||
Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Value.pp v
|
||||
in
|
||||
Fmt.fprintf out "(@[<hv>model@ %a@])"
|
||||
(Util.pp_seq pp_pair) (A.Term.Tbl.to_seq tbl)
|
||||
(Util.pp_seq pp_pair) (Term.Tbl.to_seq tbl)
|
||||
end
|
||||
|
||||
type res =
|
||||
|
|
@ -551,7 +561,7 @@ module Make(A : ARG)
|
|||
|
||||
let add_clause (self:t) (c:Atom.t IArray.t) : unit =
|
||||
Stat.incr self.count_clause;
|
||||
Sat_solver.add_clause_a self.solver (c:> Atom.t array) A.Proof.default
|
||||
Sat_solver.add_clause_a self.solver (c:> Atom.t array) P.default
|
||||
|
||||
let add_clause_l self c = add_clause self (IArray.of_list c)
|
||||
|
||||
|
|
@ -562,13 +572,13 @@ module Make(A : ARG)
|
|||
|
||||
let mk_model (self:t) (lits:lit Iter.t) : Model.t =
|
||||
Log.debug 1 "(smt.solver.mk-model)";
|
||||
let module M = A.Term.Tbl in
|
||||
let module M = Term.Tbl in
|
||||
let m = M.create 128 in
|
||||
let tst = self.si.tst in
|
||||
(* first, add all boolean *)
|
||||
lits
|
||||
(fun {Lit.lit_term=t;lit_sign=sign} ->
|
||||
M.replace m t (A.Term.bool tst sign));
|
||||
M.replace m t (Term.bool tst sign));
|
||||
(* then add CC classes *)
|
||||
Solver_internal.CC.all_classes (Solver_internal.cc self.si)
|
||||
(fun repr ->
|
||||
|
|
|
|||
|
|
@ -399,11 +399,11 @@ let conv_term = Conv.conv_term
|
|||
|
||||
(* instantiate solver here *)
|
||||
module Solver_arg = struct
|
||||
include Sidekick_base_term
|
||||
module T = Sidekick_base_term
|
||||
|
||||
let cc_view = Term.cc_view
|
||||
let is_valid_literal _ = true
|
||||
module Proof = struct
|
||||
module P = struct
|
||||
type t = Default
|
||||
let default=Default
|
||||
let pp out _ = Fmt.string out "default"
|
||||
|
|
@ -643,7 +643,7 @@ let process_stmt
|
|||
|
||||
module Th_bool = Sidekick_th_bool_static.Make(struct
|
||||
module S = Solver
|
||||
type term = S.A.Term.t
|
||||
type term = S.T.Term.t
|
||||
include Form
|
||||
end)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@
|
|||
open Sidekick_base_term
|
||||
|
||||
module Solver
|
||||
: Sidekick_msat_solver.S with type A.Term.t = Term.t
|
||||
and type A.Term.state = Term.state
|
||||
and type A.Ty.t = Ty.t
|
||||
: Sidekick_msat_solver.S with type T.Term.t = Term.t
|
||||
and type T.Term.state = Term.state
|
||||
and type T.Ty.t = Ty.t
|
||||
|
||||
val th_bool : Solver.theory
|
||||
|
||||
|
|
|
|||
|
|
@ -15,12 +15,12 @@ type 'a bool_view =
|
|||
module type ARG = sig
|
||||
module S : Sidekick_core.SOLVER
|
||||
|
||||
type term = S.A.Term.t
|
||||
type term = S.T.Term.t
|
||||
|
||||
val view_as_bool : term -> term bool_view
|
||||
(** Project the term into the boolean view *)
|
||||
|
||||
val mk_bool : S.A.Term.state -> term bool_view -> term
|
||||
val mk_bool : S.T.Term.state -> term bool_view -> term
|
||||
(** Make a term from the given boolean view *)
|
||||
|
||||
val check_congruence_classes : bool
|
||||
|
|
@ -32,9 +32,9 @@ module type ARG = sig
|
|||
module Gensym : sig
|
||||
type t
|
||||
|
||||
val create : S.A.Term.state -> t
|
||||
val create : S.T.Term.state -> t
|
||||
|
||||
val fresh_term : t -> pre:string -> S.A.Ty.t -> term
|
||||
val fresh_term : t -> pre:string -> S.T.Ty.t -> term
|
||||
(** Make a fresh term of the given type *)
|
||||
end
|
||||
end
|
||||
|
|
@ -44,7 +44,7 @@ module type S = sig
|
|||
|
||||
type state
|
||||
|
||||
val create : A.S.A.Term.state -> state
|
||||
val create : A.S.T.Term.state -> state
|
||||
|
||||
val simplify : state -> A.S.Solver_internal.simplify_hook
|
||||
(** Simplify given term *)
|
||||
|
|
@ -57,8 +57,8 @@ end
|
|||
|
||||
module Make(A : ARG) : S with module A = A = struct
|
||||
module A = A
|
||||
module Ty = A.S.A.Ty
|
||||
module T = A.S.A.Term
|
||||
module Ty = A.S.T.Ty
|
||||
module T = A.S.T.Term
|
||||
module Lit = A.S.Solver_internal.Lit
|
||||
module SI = A.S.Solver_internal
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ let name = "th-cstor"
|
|||
|
||||
module type ARG = sig
|
||||
module S : Sidekick_core.SOLVER
|
||||
val view_as_cstor : S.A.Term.t -> (S.A.Fun.t, S.A.Term.t) cstor_view
|
||||
val view_as_cstor : S.T.Term.t -> (S.T.Fun.t, S.T.Term.t) cstor_view
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
|
|
@ -19,9 +19,9 @@ end
|
|||
module Make(A : ARG) : S with module A = A = struct
|
||||
module A = A
|
||||
module SI = A.S.Solver_internal
|
||||
module T = A.S.A.Term
|
||||
module T = A.S.T.Term
|
||||
module N = SI.CC.N
|
||||
module Fun = A.S.A.Fun
|
||||
module Fun = A.S.T.Fun
|
||||
module Expl = SI.CC.Expl
|
||||
|
||||
type cstor_repr = {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue