From 6e9e95c2339c459e9828f6d128a2dbcf797d13b4 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 26 May 2019 23:20:47 -0500 Subject: [PATCH] wip: functorize everything --- dune-project | 1 + src/{smt => base-term}/Ast.ml | 0 src/{smt => base-term}/Ast.mli | 0 src/base-term/Base_types.ml | 174 ++++ src/{smt => base-term}/Config.ml | 0 src/{smt => base-term}/Config.mli | 0 src/{smt => base-term}/Cst.ml | 2 +- src/{smt => base-term}/Cst.mli | 2 +- src/{smt => base-term}/Hash.ml | 0 src/{smt => base-term}/Hash.mli | 0 src/{smt => base-term}/Hashcons.ml | 0 src/{smt => base-term}/ID.ml | 0 src/{smt => base-term}/ID.mli | 0 src/{smt => base-term}/Lit.ml | 6 +- src/{smt => base-term}/Lit.mli | 5 +- src/{smt => base-term}/Model.ml | 44 +- src/{smt => base-term}/Model.mli | 0 src/{smt => base-term}/Term.ml | 6 +- src/{smt => base-term}/Term.mli | 4 +- src/{smt => base-term}/Term_cell.ml | 6 +- src/{smt => base-term}/Term_cell.mli | 4 +- src/{smt => base-term}/Ty.ml | 6 +- src/{smt => base-term}/Ty.mli | 8 +- src/{smt => base-term}/Ty_card.ml | 2 +- src/{smt => base-term}/Ty_card.mli | 2 +- src/{smt => base-term}/Value.ml | 10 +- src/{smt => base-term}/Value.mli | 2 +- src/base-term/dune | 7 + src/cc/Congruence_closure.ml | 922 -------------------- src/cc/Congruence_closure.mli | 13 - src/cc/Congruence_closure_intf.ml | 301 ------- src/cc/Sidekick_cc.ml | 889 ++++++++++++++++++- src/cc/Sidekick_cc.mli | 6 + src/cc/dune | 7 +- src/core/Sidekick_core.ml | 40 +- src/main/dune | 5 +- src/{cc => mini-cc}/Mini_cc.ml | 9 +- src/{cc => mini-cc}/Mini_cc.mli | 5 +- src/mini-cc/dune | 7 + src/{smt => msat-solver}/CC.ml | 0 src/{smt => msat-solver}/CC.mli | 0 src/{smt => msat-solver}/DESIGN.md | 0 src/{smt => msat-solver}/Sidekick_smt.ml | 0 src/{smt => msat-solver}/Solver.ml | 0 src/{smt => msat-solver}/Solver.mli | 0 src/{smt => msat-solver}/Solver_types.ml | 5 - src/{smt => msat-solver}/Theory.ml | 0 src/{smt => msat-solver}/Theory_combine.ml | 0 src/{smt => msat-solver}/Theory_combine.mli | 0 src/msat-solver/dune | 7 + src/msat-solver/th_key.ml.bak | 145 +++ src/smt/dune | 10 - src/smtlib/dune | 5 +- src/th-bool/dune | 2 +- src/th-cstor/dune | 2 +- src/th-distinct/Sidekick_th_distinct.ml | 33 +- src/th-distinct/dune | 2 +- src/util/dune | 6 +- 58 files changed, 1343 insertions(+), 1369 deletions(-) rename src/{smt => base-term}/Ast.ml (100%) rename src/{smt => base-term}/Ast.mli (100%) create mode 100644 src/base-term/Base_types.ml rename src/{smt => base-term}/Config.ml (100%) rename src/{smt => base-term}/Config.mli (100%) rename src/{smt => base-term}/Cst.ml (98%) rename src/{smt => base-term}/Cst.mli (96%) rename src/{smt => base-term}/Hash.ml (100%) rename src/{smt => base-term}/Hash.mli (100%) rename src/{smt => base-term}/Hashcons.ml (100%) rename src/{smt => base-term}/ID.ml (100%) rename src/{smt => base-term}/ID.mli (100%) rename src/{smt => base-term}/Lit.ml (86%) rename src/{smt => base-term}/Lit.mli (84%) rename src/{smt => base-term}/Model.ml (79%) rename src/{smt => base-term}/Model.mli (100%) rename src/{smt => base-term}/Term.ml (97%) rename src/{smt => base-term}/Term.mli (94%) rename src/{smt => base-term}/Term_cell.ml (95%) rename src/{smt => base-term}/Term_cell.mli (92%) rename src/{smt => base-term}/Ty.ml (96%) rename src/{smt => base-term}/Ty.mli (85%) rename src/{smt => base-term}/Ty_card.ml (96%) rename src/{smt => base-term}/Ty_card.mli (89%) rename src/{smt => base-term}/Value.ml (50%) rename src/{smt => base-term}/Value.mli (92%) create mode 100644 src/base-term/dune delete mode 100644 src/cc/Congruence_closure.ml delete mode 100644 src/cc/Congruence_closure.mli delete mode 100644 src/cc/Congruence_closure_intf.ml create mode 100644 src/cc/Sidekick_cc.mli rename src/{cc => mini-cc}/Mini_cc.ml (98%) rename src/{cc => mini-cc}/Mini_cc.mli (91%) create mode 100644 src/mini-cc/dune rename src/{smt => msat-solver}/CC.ml (100%) rename src/{smt => msat-solver}/CC.mli (100%) rename src/{smt => msat-solver}/DESIGN.md (100%) rename src/{smt => msat-solver}/Sidekick_smt.ml (100%) rename src/{smt => msat-solver}/Solver.ml (100%) rename src/{smt => msat-solver}/Solver.mli (100%) rename src/{smt => msat-solver}/Solver_types.ml (97%) rename src/{smt => msat-solver}/Theory.ml (100%) rename src/{smt => msat-solver}/Theory_combine.ml (100%) rename src/{smt => msat-solver}/Theory_combine.mli (100%) create mode 100644 src/msat-solver/dune create mode 100644 src/msat-solver/th_key.ml.bak delete mode 100644 src/smt/dune diff --git a/dune-project b/dune-project index 977e7d75..04fa6f89 100644 --- a/dune-project +++ b/dune-project @@ -1,2 +1,3 @@ (lang dune 1.1) (using menhir 1.0) +(using fmt 1.1) diff --git a/src/smt/Ast.ml b/src/base-term/Ast.ml similarity index 100% rename from src/smt/Ast.ml rename to src/base-term/Ast.ml diff --git a/src/smt/Ast.mli b/src/base-term/Ast.mli similarity index 100% rename from src/smt/Ast.mli rename to src/base-term/Ast.mli diff --git a/src/base-term/Base_types.ml b/src/base-term/Base_types.ml new file mode 100644 index 00000000..4b8a306b --- /dev/null +++ b/src/base-term/Base_types.ml @@ -0,0 +1,174 @@ + +module Vec = Msat.Vec +module Log = Msat.Log +module Fmt = CCFormat + +(* main term cell. *) +type term = { + mutable term_id: int; (* unique ID *) + mutable term_ty: ty; + term_view : term term_view; +} + +(* term shallow structure *) +and 'a term_view = + | Bool of bool + | App_cst of cst * 'a IArray.t (* full, first-order application *) + | Eq of 'a * 'a + | If of 'a * 'a * 'a + | Not of 'a + +(* boolean literal *) +and lit = { + lit_term: term; + lit_sign: bool; +} + +and cst = { + cst_id: ID.t; + cst_view: cst_view; +} + +and cst_view = + | Cst_undef of fun_ty (* simple undefined constant *) + | Cst_def of { + pp : 'a. ('a Fmt.printer -> 'a IArray.t Fmt.printer) option; + abs : self:term -> term IArray.t -> term * bool; (* remove the sign? *) + do_cc: bool; (* participate in congruence closure? *) + relevant : 'a. ID.t -> 'a IArray.t -> int -> bool; (* relevant argument? *) + ty : ID.t -> term IArray.t -> ty; (* compute type *) + eval: value IArray.t -> value; (* evaluate term *) + } +(** Methods on the custom term view whose arguments are ['a]. + Terms must be printable, and provide some additional theory handles. + + - [relevant] must return a subset of [args] (possibly the same set). + The terms it returns will be activated and evaluated whenever possible. + Terms in [args \ relevant args] are considered for + congruence but not for evaluation. +*) + +(** Function type *) +and fun_ty = { + fun_ty_args: ty list; + fun_ty_ret: ty; +} + +(** Hashconsed type *) +and ty = { + mutable ty_id: int; + ty_view: ty_view; +} + +and ty_view = + | Ty_prop + | Ty_atomic of { + def: ty_def; + args: ty list; + card: ty_card lazy_t; + } + +and ty_def = + | Ty_uninterpreted of ID.t + | Ty_def of { + id: ID.t; + pp: ty Fmt.printer -> ty list Fmt.printer; + default_val: value list -> value; (* default value of this type *) + card: ty list -> ty_card; + } + +and ty_card = + | Finite + | Infinite + +(** Semantic values, used for models (and possibly model-constructing calculi) *) +and value = + | V_bool of bool + | V_element of { + id: ID.t; + ty: ty; + } (** a named constant, distinct from any other constant *) + | V_custom of { + view: value_custom_view; + pp: value_custom_view Fmt.printer; + eq: value_custom_view -> value_custom_view -> bool; + hash: value_custom_view -> int; + } (** Custom value *) + +and value_custom_view = .. + +let[@inline] term_equal_ (a:term) b = a==b +let[@inline] term_hash_ a = a.term_id +let[@inline] term_cmp_ a b = CCInt.compare a.term_id b.term_id + +let cmp_lit a b = + let c = CCBool.compare a.lit_sign b.lit_sign in + if c<>0 then c + else term_cmp_ a.lit_term b.lit_term + +let cst_compare a b = ID.compare a.cst_id b.cst_id + +let hash_lit a = + let sign = a.lit_sign in + Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term) + +let pp_cst out a = ID.pp out a.cst_id +let id_of_cst a = a.cst_id + +let[@inline] eq_ty a b = a.ty_id = b.ty_id + +let eq_value a b = match a, b with + | V_bool a, V_bool b -> a=b + | V_element e1, V_element e2 -> + ID.equal e1.id e2.id && eq_ty e1.ty e2.ty + | V_custom x1, V_custom x2 -> + x1.eq x1.view x2.view + | V_bool _, _ | V_element _, _ | V_custom _, _ + -> false + +let hash_value a = match a with + | V_bool a -> Hash.bool a + | V_element e -> ID.hash e.id + | V_custom x -> x.hash x.view + +let pp_value out = function + | V_bool b -> Fmt.bool out b + | V_element e -> ID.pp out e.id + | V_custom c -> c.pp out c.view + +let pp_db out (i,_) = Format.fprintf out "%%%d" i + +let rec pp_ty out t = match t.ty_view with + | Ty_prop -> Fmt.string out "prop" + | Ty_atomic {def=Ty_uninterpreted id; args=[]; _} -> ID.pp out id + | Ty_atomic {def=Ty_uninterpreted id; args; _} -> + Fmt.fprintf out "(@[%a@ %a@])" ID.pp id (Util.pp_list pp_ty) args + | Ty_atomic {def=Ty_def def; args; _} -> def.pp pp_ty out args + +let pp_term_view_gen ~pp_id ~pp_t out = function + | Bool true -> Fmt.string out "true" + | Bool false -> Fmt.string out "false" + | App_cst ({cst_view=Cst_def {pp=Some pp_custom;_};_},l) -> pp_custom pp_t out l + | App_cst (c, a) when IArray.is_empty a -> + pp_id out (id_of_cst c) + | App_cst (f,l) -> + Fmt.fprintf out "(@[<1>%a@ %a@])" pp_id (id_of_cst f) (Util.pp_iarray pp_t) l + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" pp_t a pp_t b + | If (a, b, c) -> + Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c + | Not u -> Fmt.fprintf out "(@[not@ %a@])" pp_t u + +let pp_term_top ~ids out t = + let rec pp out t = + pp_rec out t; + (* FIXME if Config.pp_hashcons then Format.fprintf out "/%d" t.term_id; *) + and pp_rec out t = pp_term_view_gen ~pp_id ~pp_t:pp_rec out t.term_view + and pp_id = if ids then ID.pp else ID.pp_name in + pp out t + +let pp_term = pp_term_top ~ids:false +let pp_term_view = pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:pp_term + +let pp_lit out l = + if l.lit_sign then pp_term out l.lit_term + else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term diff --git a/src/smt/Config.ml b/src/base-term/Config.ml similarity index 100% rename from src/smt/Config.ml rename to src/base-term/Config.ml diff --git a/src/smt/Config.mli b/src/base-term/Config.mli similarity index 100% rename from src/smt/Config.mli rename to src/base-term/Config.mli diff --git a/src/smt/Cst.ml b/src/base-term/Cst.ml similarity index 98% rename from src/smt/Cst.ml rename to src/base-term/Cst.ml index d0cd406d..105b704e 100644 --- a/src/smt/Cst.ml +++ b/src/base-term/Cst.ml @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type view = cst_view type t = cst diff --git a/src/smt/Cst.mli b/src/base-term/Cst.mli similarity index 96% rename from src/smt/Cst.mli rename to src/base-term/Cst.mli index 3596e250..bbf2322a 100644 --- a/src/smt/Cst.mli +++ b/src/base-term/Cst.mli @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type view = cst_view type t = cst diff --git a/src/smt/Hash.ml b/src/base-term/Hash.ml similarity index 100% rename from src/smt/Hash.ml rename to src/base-term/Hash.ml diff --git a/src/smt/Hash.mli b/src/base-term/Hash.mli similarity index 100% rename from src/smt/Hash.mli rename to src/base-term/Hash.mli diff --git a/src/smt/Hashcons.ml b/src/base-term/Hashcons.ml similarity index 100% rename from src/smt/Hashcons.ml rename to src/base-term/Hashcons.ml diff --git a/src/smt/ID.ml b/src/base-term/ID.ml similarity index 100% rename from src/smt/ID.ml rename to src/base-term/ID.ml diff --git a/src/smt/ID.mli b/src/base-term/ID.mli similarity index 100% rename from src/smt/ID.mli rename to src/base-term/ID.mli diff --git a/src/smt/Lit.ml b/src/base-term/Lit.ml similarity index 86% rename from src/smt/Lit.ml rename to src/base-term/Lit.ml index 20e996ca..f14fe935 100644 --- a/src/smt/Lit.ml +++ b/src/base-term/Lit.ml @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type t = lit = { lit_term: term; @@ -27,8 +27,8 @@ let[@inline] equal a b = compare a b = 0 let pp = pp_lit let print = pp -let norm l = - if l.lit_sign then l, Msat.Solver_intf.Same_sign else neg l, Msat.Solver_intf.Negated +let apply_sign t s = if s then t else neg t +let norm_sign l = if l.lit_sign then l, true else neg l, false module Set = CCSet.Make(struct type t = lit let compare=compare end) module Tbl = CCHashtbl.Make(struct type t = lit let equal=equal let hash=hash end) diff --git a/src/smt/Lit.mli b/src/base-term/Lit.mli similarity index 84% rename from src/smt/Lit.mli rename to src/base-term/Lit.mli index a6b17f1e..af82e940 100644 --- a/src/smt/Lit.mli +++ b/src/base-term/Lit.mli @@ -1,6 +1,6 @@ (** {2 Literals} *) -open Solver_types +open Base_types type t = lit = { lit_term: term; @@ -18,7 +18,8 @@ val compare : t -> t -> int val equal : t -> t -> bool val print : t Fmt.printer val pp : t Fmt.printer -val norm : t -> t * Msat.Solver_intf.negated +val apply_sign : t -> bool -> t +val norm_sign : t -> t * bool module Set : CCSet.S with type elt = t module Tbl : CCHashtbl.S with type key = t diff --git a/src/smt/Model.ml b/src/base-term/Model.ml similarity index 79% rename from src/smt/Model.ml rename to src/base-term/Model.ml index 200e7c33..5e822c5a 100644 --- a/src/smt/Model.ml +++ b/src/base-term/Model.ml @@ -3,7 +3,7 @@ (** {1 Model} *) -open Solver_types +open Base_types module Val_map = struct module M = CCIntMap @@ -176,3 +176,45 @@ let eval (m:t) (t:Term.t) : Value.t option = in try Some (aux t) with No_value -> None + +(* TODO: get model from each theory, then complete it as follows based on types + let mk_model (cc:t) (m:A.Model.t) : A.Model.t = + let module Model = A.Model in + let module Value = A.Value in + Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc); + let t_tbl = N_tbl.create 32 in + (* populate [repr -> value] table *) + T_tbl.values cc.tbl + (fun r -> + if N.is_root r then ( + (* find a value in the class, if any *) + let v = + N.iter_class r + |> Iter.find_map (fun n -> Model.eval m n.n_term) + in + let v = match v with + | Some v -> v + | None -> + if same_class r (true_ cc) then Value.true_ + else if same_class r (false_ cc) then Value.false_ + else Value.fresh r.n_term + in + N_tbl.add t_tbl r v + )); + (* now map every term to its representative's value *) + let pairs = + T_tbl.values cc.tbl + |> Iter.map + (fun n -> + let r = find_ n in + let v = + try N_tbl.find t_tbl r + with Not_found -> + Error.errorf "didn't allocate a value for repr %a" N.pp r + in + n.n_term, v) + in + let m = Iter.fold (fun m (t,v) -> Model.add t v m) m pairs in + Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m); + m + *) diff --git a/src/smt/Model.mli b/src/base-term/Model.mli similarity index 100% rename from src/smt/Model.mli rename to src/base-term/Model.mli diff --git a/src/smt/Term.ml b/src/base-term/Term.ml similarity index 97% rename from src/smt/Term.ml rename to src/base-term/Term.ml index 0d9b0d40..a3c9ca63 100644 --- a/src/smt/Term.ml +++ b/src/base-term/Term.ml @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type t = term = { mutable term_id : int; @@ -89,7 +89,7 @@ let[@inline] is_const t = match view t with | _ -> false let cc_view (t:t) = - let module C = Sidekick_cc in + let module C = Sidekick_core.CC_view in match view t with | Bool b -> C.Bool b | App_cst (f,_) when not (Cst.do_cc f) -> C.Opaque t (* skip *) @@ -115,7 +115,7 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option = | App_cst (c, a) when IArray.is_empty a -> Cst.as_undefined c | _ -> None -let pp = Solver_types.pp_term +let pp = Base_types.pp_term module Iter_dag = struct type t = unit Tbl.t diff --git a/src/smt/Term.mli b/src/base-term/Term.mli similarity index 94% rename from src/smt/Term.mli rename to src/base-term/Term.mli index 5fefbaf3..4c39eabc 100644 --- a/src/smt/Term.mli +++ b/src/base-term/Term.mli @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type t = term = { mutable term_id : int; @@ -55,7 +55,7 @@ val is_true : t -> bool val is_false : t -> bool val is_const : t -> bool -val cc_view : t -> (cst,t,t Iter.t) Sidekick_cc.view +val cc_view : t -> (cst,t,t Iter.t) Sidekick_core.CC_view.t (* return [Some] iff the term is an undefined constant *) val as_cst_undef : t -> (cst * Ty.Fun.t) option diff --git a/src/smt/Term_cell.ml b/src/base-term/Term_cell.ml similarity index 95% rename from src/smt/Term_cell.ml rename to src/base-term/Term_cell.ml index e46b54b4..02903cbf 100644 --- a/src/smt/Term_cell.ml +++ b/src/base-term/Term_cell.ml @@ -1,9 +1,9 @@ -open Solver_types +open Base_types (* TODO: normalization of {!term_cell} for use in signatures? *) -type 'a view = 'a Solver_types.term_view = +type 'a view = 'a Base_types.term_view = | Bool of bool | App_cst of cst * 'a IArray.t | Eq of 'a * 'a @@ -43,7 +43,7 @@ module Make_eq(A : ARG) = struct | Bool _, _ | App_cst _, _ | If _, _ | Eq _, _ | Not _, _ -> false - let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp + let pp = Base_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp end[@@inline] include Make_eq(struct diff --git a/src/smt/Term_cell.mli b/src/base-term/Term_cell.mli similarity index 92% rename from src/smt/Term_cell.mli rename to src/base-term/Term_cell.mli index 47e2ad57..73bd5101 100644 --- a/src/smt/Term_cell.mli +++ b/src/base-term/Term_cell.mli @@ -1,7 +1,7 @@ -open Solver_types +open Base_types -type 'a view = 'a Solver_types.term_view = +type 'a view = 'a Base_types.term_view = | Bool of bool | App_cst of cst * 'a IArray.t | Eq of 'a * 'a diff --git a/src/smt/Ty.ml b/src/base-term/Ty.ml similarity index 96% rename from src/smt/Ty.ml rename to src/base-term/Ty.ml index cbef42b0..d6e7fd40 100644 --- a/src/smt/Ty.ml +++ b/src/base-term/Ty.ml @@ -1,9 +1,9 @@ -open Solver_types +open Base_types type t = ty -type view = Solver_types.ty_view -type def = Solver_types.ty_def +type view = Base_types.ty_view +type def = Base_types.ty_def let[@inline] id t = t.ty_id let[@inline] view t = t.ty_view diff --git a/src/smt/Ty.mli b/src/base-term/Ty.mli similarity index 85% rename from src/smt/Ty.mli rename to src/base-term/Ty.mli index 7976e1ed..bdebcfd7 100644 --- a/src/smt/Ty.mli +++ b/src/base-term/Ty.mli @@ -1,11 +1,11 @@ (** {1 Hashconsed Types} *) -open Solver_types +open Base_types -type t = Solver_types.ty -type view = Solver_types.ty_view -type def = Solver_types.ty_def +type t = Base_types.ty +type view = Base_types.ty_view +type def = Base_types.ty_def val id : t -> int val view : t -> view diff --git a/src/smt/Ty_card.ml b/src/base-term/Ty_card.ml similarity index 96% rename from src/smt/Ty_card.ml rename to src/base-term/Ty_card.ml index a6e91975..75004d42 100644 --- a/src/smt/Ty_card.ml +++ b/src/base-term/Ty_card.ml @@ -1,5 +1,5 @@ -open Solver_types +open Base_types type t = ty_card diff --git a/src/smt/Ty_card.mli b/src/base-term/Ty_card.mli similarity index 89% rename from src/smt/Ty_card.mli rename to src/base-term/Ty_card.mli index 478333b3..4a23e646 100644 --- a/src/smt/Ty_card.mli +++ b/src/base-term/Ty_card.mli @@ -1,7 +1,7 @@ (** {1 Type Cardinality} *) -type t = Solver_types.ty_card +type t = Base_types.ty_card val (+) : t -> t -> t val ( * ) : t -> t -> t diff --git a/src/smt/Value.ml b/src/base-term/Value.ml similarity index 50% rename from src/smt/Value.ml rename to src/base-term/Value.ml index 9057db36..740e7794 100644 --- a/src/smt/Value.ml +++ b/src/base-term/Value.ml @@ -1,19 +1,19 @@ (** {1 Value} *) -open Solver_types +open Base_types type t = value let true_ = V_bool true let false_ = V_bool false -let bool v = V_bool v +let[@inline] bool v = if v then true_ else false_ let mk_elt id ty : t = V_element {id; ty} -let is_bool = function V_bool _ -> true | _ -> false -let is_true = function V_bool true -> true | _ -> false -let is_false = function V_bool false -> true | _ -> false +let[@inline] is_bool = function V_bool _ -> true | _ -> false +let[@inline] is_true = function V_bool true -> true | _ -> false +let[@inline] is_false = function V_bool false -> true | _ -> false let equal = eq_value let hash = hash_value diff --git a/src/smt/Value.mli b/src/base-term/Value.mli similarity index 92% rename from src/smt/Value.mli rename to src/base-term/Value.mli index 5bfadde6..4cc1574a 100644 --- a/src/smt/Value.mli +++ b/src/base-term/Value.mli @@ -3,7 +3,7 @@ Semantic value *) -type t = Solver_types.value +type t = Base_types.value val true_ : t val false_ : t diff --git a/src/base-term/dune b/src/base-term/dune new file mode 100644 index 00000000..2c19bc8b --- /dev/null +++ b/src/base-term/dune @@ -0,0 +1,7 @@ +(library + (name sidekick_base_term) + (public_name sidekick.base-term) + (synopsis "Basic term definitions for the standalone SMT solver") + (libraries containers containers.data + sidekick.core sidekick.util zarith) + (flags :standard -open Sidekick_util)) diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml deleted file mode 100644 index 99b5f0ff..00000000 --- a/src/cc/Congruence_closure.ml +++ /dev/null @@ -1,922 +0,0 @@ - -open Congruence_closure_intf - -module type ARG = Congruence_closure_intf.ARG -module type S = Congruence_closure_intf.S - -module Bits = CCBitField.Make() - -let field_is_pending = Bits.mk_field() -(** true iff the node is in the [cc.pending] queue *) - -let field_usr1 = Bits.mk_field() -(** General purpose *) - -let field_usr2 = Bits.mk_field() -(** General purpose *) - -let () = Bits.freeze() - -module Make(A: ARG) = struct - type term = A.Term.t - type term_state = A.Term.state - type lit = A.Lit.t - type fun_ = A.Fun.t - type proof = A.Proof.t - type model = A.Model.t - type th_data = A.Data.t - - (** Actions available to the theory *) - type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts - - module T = A.Term - module Fun = A.Fun - module Key = Key - - (** A node of the congruence closure. - An equivalence class is represented by its "root" element, - the representative. *) - type node = { - n_term: term; - mutable n_sig0: signature option; (* initial signature *) - mutable n_bits: Bits.t; (* bitfield for various properties *) - mutable n_parents: node Bag.t; (* parent terms of this node *) - mutable n_root: node; (* representative of congruence class (itself if a representative) *) - mutable n_next: node; (* pointer to next element of congruence class *) - mutable n_size: int; (* size of the class *) - mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) - mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) - mutable n_th_data: th_data; (* theory data *) - } - - and signature = (fun_, node, node list) view - - and explanation_forest_link = - | FL_none - | FL_some of { - next: node; - expl: explanation; - } - - (* atomic explanation in the congruence closure *) - and explanation = - | E_reduction (* by pure reduction, tautologically equal *) - | E_lit of lit (* because of this literal *) - | E_merge of node * node - | E_merge_t of term * term - | E_congruence of node * node (* caused by normal congruence *) - | E_and of explanation * explanation - - type repr = node - type conflict = lit list - - module N = struct - type t = node - - let[@inline] equal (n1:t) n2 = n1 == n2 - let[@inline] hash n = T.hash n.n_term - let[@inline] term n = n.n_term - let[@inline] pp out n = T.pp out n.n_term - let[@inline] as_lit n = n.n_as_lit - let[@inline] th_data n = n.n_th_data - - let make (t:term) : t = - let rec n = { - n_term=t; - n_sig0= None; - n_bits=Bits.empty; - n_parents=Bag.empty; - n_as_lit=None; (* TODO: provide a method to do it *) - n_root=n; - n_expl=FL_none; - n_next=n; - n_size=1; - n_th_data=A.Data.empty; - } in - n - - let[@inline] is_root (n:node) : bool = n.n_root == n - - (* traverse the equivalence class of [n] *) - let iter_class_ (n:node) : node Iter.t = - fun yield -> - let rec aux u = - yield u; - if u.n_next != n then aux u.n_next - in - aux n - - let[@inline] iter_class n = - assert (is_root n); - iter_class_ n - - let[@inline] iter_parents (n:node) : node Iter.t = - assert (is_root n); - Bag.to_seq n.n_parents - - let[@inline] get_field f t = Bits.get f t.n_bits - let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits - - let[@inline] get_field_usr1 t = get_field field_usr1 t - let[@inline] set_field_usr1 t b = set_field field_usr1 b t - let[@inline] get_field_usr2 t = get_field field_usr2 t - let[@inline] set_field_usr2 t b = set_field field_usr2 b t - end - - module N_tbl = CCHashtbl.Make(N) - - module Expl = struct - type t = explanation - - let rec pp out (e:explanation) = match e with - | E_reduction -> Fmt.string out "reduction" - | E_lit lit -> A.Lit.pp out lit - | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 - | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b - | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b - | E_and (a,b) -> - Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b - - let mk_reduction : t = E_reduction - let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) - let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) - let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b) - let[@inline] mk_lit l : t = E_lit l - - let rec mk_list l = - match l with - | [] -> mk_reduction - | [x] -> x - | E_reduction :: tl -> mk_list tl - | x :: y -> - match mk_list y with - | E_reduction -> x - | y' -> E_and (x,y') - end - - (** A signature is a shallow term shape where immediate subterms - are representative *) - module Signature = struct - type t = signature - let equal (s1:t) s2 : bool = - match s1, s2 with - | Bool b1, Bool b2 -> b1=b2 - | App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2 - | App_fun (f1,l1), App_fun (f2,l2) -> - Fun.equal f1 f2 && CCList.equal N.equal l1 l2 - | App_ho (f1,l1), App_ho (f2,l2) -> - N.equal f1 f2 && CCList.equal N.equal l1 l2 - | Not a, Not b -> N.equal a b - | If (a1,b1,c1), If (a2,b2,c2) -> - N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2 - | Eq (a1,b1), Eq (a2,b2) -> - N.equal a1 a2 && N.equal b1 b2 - | Opaque u1, Opaque u2 -> N.equal u1 u2 - | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ - | Eq _, _ | Opaque _, _ | Not _, _ - -> false - - let hash (s:t) : int = - let module H = CCHash in - match s with - | Bool b -> H.combine2 10 (H.bool b) - | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l) - | App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l) - | Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b) - | Opaque u -> H.combine2 50 (N.hash u) - | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c) - | Not u -> H.combine2 70 (N.hash u) - - let pp out = function - | Bool b -> Fmt.bool out b - | App_fun (f, []) -> Fun.pp out f - | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l - | App_ho (f, []) -> N.pp out f - | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l - | Opaque t -> N.pp out t - | Not u -> Fmt.fprintf out "(@[not@ %a@])" N.pp u - | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b - | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c - end - - module Sig_tbl = CCHashtbl.Make(Signature) - module T_tbl = CCHashtbl.Make(T) - - type combine_task = - | CT_merge of node * node * explanation - - type t = { - tst: term_state; - tbl: node T_tbl.t; - (* internalization [term -> node] *) - signatures_tbl : node Sig_tbl.t; - (* map a signature to the corresponding node in some equivalence class. - A signature is a [term_cell] in which every immediate subterm - that participates in the congruence/evaluation relation - is normalized (i.e. is its own representative). - The critical property is that all members of an equivalence class - that have the same "shape" (including head symbol) - have the same signature *) - pending: node Vec.t; - combine: combine_task Vec.t; - undo: (unit -> unit) Backtrack_stack.t; - mutable on_merge: ev_on_merge list; - mutable on_new_term: ev_on_new_term list; - mutable ps_lits: lit list; (* TODO: thread it around instead? *) - (* proof state *) - ps_queue: (node*node) Vec.t; - (* pairs to explain *) - true_ : node lazy_t; - false_ : node lazy_t; - stat: Stat.t; - count_conflict: int Stat.counter; - count_merge: int Stat.counter; - } - (* TODO: an additional union-find to keep track, for each term, - of the terms they are known to be equal to, according - to the current explanation. That allows not to prove some equality - several times. - See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) - - and ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit - and ev_on_new_term = t -> N.t -> term -> th_data -> th_data option - - let[@inline] size_ (r:repr) = r.n_size - let[@inline] true_ cc = Lazy.force cc.true_ - let[@inline] false_ cc = Lazy.force cc.false_ - let[@inline] term_state cc = cc.tst - - let[@inline] on_backtrack cc f : unit = - Backtrack_stack.push_if_nonzero_level cc.undo f - - (* check if [t] is in the congruence closure. - Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) - let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t - - (* find representative, recursively *) - let[@unroll 2] rec find_rec (n:node) : repr = - if n==n.n_root then ( - n - ) else ( - let root = find_rec n.n_root in - if root != n.n_root then ( - n.n_root <- root; (* path compression *) - ); - root - ) - - (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n:node) : repr = - if n == n.n_root then n else find_rec n.n_root - - let[@inline] same_class (n1:node)(n2:node): bool = - N.equal (find_ n1) (find_ n2) - - let[@inline] find _ n = find_ n - - (* print full state *) - let pp_full out (cc:t) : unit = - let pp_next out n = - Fmt.fprintf out "@ :next %a" N.pp n.n_next in - let pp_root out n = - if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in - let pp_expl out n = match n.n_expl with - | FL_none -> () - | FL_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl - in - let pp_n out n = - Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n - and pp_sig_e out (s,n) = - Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n - in - Fmt.fprintf out - "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ %a@])@])" - (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) - (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) - - (* compute up-to-date signature *) - let update_sig (s:signature) : Signature.t = - Congruence_closure_intf.map_view s - ~f_f:(fun x->x) - ~f_t:find_ - ~f_ts:(List.map find_) - - (* find whether the given (parent) term corresponds to some signature - in [signatures_] *) - let[@inline] find_signature cc (s:signature) : repr option = - Sig_tbl.get cc.signatures_tbl s - - let add_signature cc (s:signature) (n:node) : unit = - (* add, but only if not present already *) - match Sig_tbl.find cc.signatures_tbl s with - | exception Not_found -> - Log.debugf 15 - (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n); - on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); - Sig_tbl.add cc.signatures_tbl s n; - | r' -> - assert (same_class n r'); - () - - let push_pending cc t : unit = - if not @@ N.get_field field_is_pending t then ( - Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); - N.set_field field_is_pending true t; - Vec.push cc.pending t - ) - - let merge_classes cc t u e : unit = - Log.debugf 5 - (fun k->k "(@[cc.push_combine@ %a ~@ %a@ :expl %a@])" - N.pp t N.pp u Expl.pp e); - Vec.push cc.combine @@ CT_merge (t,u,e) - - (* re-root the explanation tree of the equivalence class of [n] - so that it points to [n]. - postcondition: [n.n_expl = None] *) - let rec reroot_expl (cc:t) (n:node): unit = - let old_expl = n.n_expl in - begin match old_expl with - | FL_none -> () (* already root *) - | FL_some {next=u; expl=e_n_u} -> - reroot_expl cc u; - u.n_expl <- FL_some {next=n; expl=e_n_u}; - n.n_expl <- FL_none; - end - - let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ = - (* clear tasks queue *) - Vec.iter (N.set_field field_is_pending false) cc.pending; - Vec.clear cc.pending; - Vec.clear cc.combine; - let c = List.rev_map A.Lit.neg e in - Stat.incr cc.count_conflict; - acts.Msat.acts_raise_conflict c A.Proof.default - - let[@inline] all_classes cc : repr Iter.t = - T_tbl.values cc.tbl - |> Iter.filter N.is_root - - (* TODO: use markers and lockstep iteration instead *) - (* distance from [t] to its root in the proof forest *) - let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with - | FL_none -> 0 - | FL_some {next=t'; _} -> 1 + distance_to_root t' - - (* TODO: new bool flag on nodes + stepwise progress + cleanup *) - (* find the closest common ancestor of [a] and [b] in the proof forest *) - let find_common_ancestor (a:node) (b:node) : node = - let d_a = distance_to_root a in - let d_b = distance_to_root b in - (* drop [n] nodes in the path from [t] to its root *) - let rec drop_ n t = - if n=0 then t - else match t.n_expl with - | FL_none -> assert false - | FL_some {next=t'; _} -> drop_ (n-1) t' - in - (* reduce to the problem where [a] and [b] have the same distance to root *) - let a, b = - if d_a > d_b then drop_ (d_a-d_b) a, b - else if d_a < d_b then a, drop_ (d_b-d_a) b - else a, b - in - (* traverse stepwise until a==b *) - let rec aux_same_dist a b = - if a==b then a - else match a.n_expl, b.n_expl with - | FL_none, _ | _, FL_none -> assert false - | FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b' - in - aux_same_dist a b - - let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) - let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits - - let ps_clear (cc:t) = - cc.ps_lits <- []; - Vec.clear cc.ps_queue; - () - - (* decompose explanation [e] of why [n1 = n2] *) - let rec decompose_explain cc (e:explanation) : unit = - Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); - match e with - | E_reduction -> () - | E_congruence (n1, n2) -> - begin match n1.n_sig0, n2.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> - assert (Fun.equal f1 f2); - assert (List.length a1 = List.length a2); - List.iter2 (ps_add_obligation cc) a1 a2; - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> - assert (List.length a1 = List.length a2); - ps_add_obligation cc f1 f2; - List.iter2 (ps_add_obligation cc) a1 a2; - | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> - ps_add_obligation cc a1 a2; - ps_add_obligation cc b1 b2; - ps_add_obligation cc c1 c2; - | _ -> - assert false - end - | E_lit lit -> ps_add_lit cc lit - | E_merge (a,b) -> ps_add_obligation cc a b - | E_merge_t (a,b) -> - (* find nodes for [a] and [b] on the fly *) - begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with - | a, b -> ps_add_obligation cc a b - | exception Not_found -> - Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b - end - | E_and (a,b) -> decompose_explain cc a; decompose_explain cc b - - (* explain why [a = parent_a], where [a -> ... -> parent_a] in the - proof forest *) - let explain_along_path ps (a:node) (parent_a:node) : unit = - let rec aux n = - if n != parent_a then ( - match n.n_expl with - | FL_none -> assert false - | FL_some {next=next_n; expl=expl} -> - decompose_explain ps expl; - (* now prove [next_n = parent_a] *) - aux next_n - ) - in aux a - - (* find explanation *) - let explain_loop (cc : t) : lit list = - while not (Vec.is_empty cc.ps_queue) do - let a, b = Vec.pop cc.ps_queue in - Log.debugf 5 - (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); - assert (N.equal (find_ a) (find_ b)); - let c = find_common_ancestor a b in - explain_along_path cc a c; - explain_along_path cc b c; - done; - cc.ps_lits - - let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list = - ps_clear cc; - cc.ps_lits <- init; - ps_add_obligation cc n1 n2; - explain_loop cc - - let explain_unfold ?(init=[]) cc (e:explanation) : lit list = - ps_clear cc; - cc.ps_lits <- init; - decompose_explain cc e; - explain_loop cc - - (* add a term *) - let [@inline] rec add_term_rec_ cc t : node = - try T_tbl.find cc.tbl t - with Not_found -> add_new_term_ cc t - - (* add [t] to [cc] when not present already *) - and add_new_term_ cc (t:term) : node = - assert (not @@ mem cc t); - Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t); - let n = N.make t in - (* register sub-terms, add [t] to their parent list, and return the - corresponding initial signature *) - let sig0 = compute_sig0 cc n in - n.n_sig0 <- sig0; - (* remove term when we backtrack *) - on_backtrack cc - (fun () -> - Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t); - T_tbl.remove cc.tbl t); - (* add term to the table *) - T_tbl.add cc.tbl t n; - if CCOpt.is_some sig0 then ( - (* [n] might be merged with other equiv classes *) - push_pending cc n; - ); - (* initial theory data *) - let th_data = - List.fold_left - (fun data f -> - match f cc n t data with - | None -> data - | Some d -> d) - A.Data.empty cc.on_new_term - in - n.n_th_data <- th_data; - n - - (* compute the initial signature of the given node *) - and compute_sig0 (self:t) (n:node) : Signature.t option = - (* add sub-term to [cc], and register [n] to its parents *) - let deref_sub (u:term) : node = - let sub = add_term_rec_ self u in - (* add [n] to [sub.root]'s parent list *) - begin - let sub = find_ sub in - let old_parents = sub.n_parents in - on_backtrack self (fun () -> sub.n_parents <- old_parents); - sub.n_parents <- Bag.cons n sub.n_parents; - end; - sub - in - let[@inline] return x = Some x in - match T.cc_view n.n_term with - | Bool _ | Opaque _ -> None - | Eq (a,b) -> - let a = deref_sub a in - let b = deref_sub b in - return @@ Eq (a,b) - | Not u -> return @@ Not (deref_sub u) - | App_fun (f, args) -> - let args = args |> Iter.map deref_sub |> Iter.to_list in - if args<>[] then ( - return @@ App_fun (f, args) - ) else None - | App_ho (f, args) -> - let args = args |> Iter.map deref_sub |> Iter.to_list in - return @@ App_ho (deref_sub f, args) - | If (a,b,c) -> - return @@ If (deref_sub a, deref_sub b, deref_sub c) - - let[@inline] add_term cc t : node = add_term_rec_ cc t - - let set_as_lit cc (n:node) (lit:lit) : unit = - match n.n_as_lit with - | Some _ -> () - | None -> - Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit); - on_backtrack cc (fun () -> n.n_as_lit <- None); - n.n_as_lit <- Some lit - - let[@inline] n_is_bool (self:t) n : bool = - N.equal n (true_ self) || N.equal n (false_ self) - - (* main CC algo: add terms from [pending] to the signature table, - check for collisions *) - let rec update_tasks (cc:t) (acts:sat_actions) : unit = - while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do - while not @@ Vec.is_empty cc.pending do - task_pending_ cc (Vec.pop cc.pending); - done; - while not @@ Vec.is_empty cc.combine do - task_combine_ cc acts (Vec.pop cc.combine); - done; - done - - and task_pending_ cc (n:node) : unit = - N.set_field field_is_pending false n; - (* check if some parent collided *) - begin match n.n_sig0 with - | None -> () (* no-op *) - | Some (Eq (a,b)) -> - (* if [a=b] is now true, merge [(a=b)] and [true] *) - if same_class a b then ( - let expl = Expl.mk_merge a b in - merge_classes cc n (true_ cc) expl - ) - | Some (Not u) -> - (* [u = bool ==> not u = not bool] *) - let r_u = find_ u in - if N.equal r_u (true_ cc) then ( - let expl = Expl.mk_merge u (true_ cc) in - merge_classes cc n (false_ cc) expl - ) else if N.equal r_u (false_ cc) then ( - let expl = Expl.mk_merge u (false_ cc) in - merge_classes cc n (true_ cc) expl - ) - | Some s0 -> - (* update the signature by using [find] on each sub-node *) - let s = update_sig s0 in - match find_signature cc s with - | None -> - (* add to the signature table [sig(n) --> n] *) - add_signature cc s n - | Some u when n == u -> () - | Some u -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - assert (n != u); - let expl = Expl.mk_congruence n u in - merge_classes cc n u expl - end - - and[@inline] task_combine_ cc acts = function - | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab - - (* main CC algo: merge equivalence classes in [st.combine]. - @raise Exn_unsat if merge fails *) - and task_merge_ cc acts a b e_ab : unit = - let ra = find_ a in - let rb = find_ b in - if not @@ N.equal ra rb then ( - assert (N.is_root ra); - assert (N.is_root rb); - Stat.incr cc.count_merge; - (* check we're not merging [true] and [false] *) - if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) || - (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then ( - Log.debugf 5 - (fun k->k "(@[cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])" - N.pp ra N.pp rb Expl.pp e_ab); - let lits = explain_unfold cc e_ab in - let lits = explain_eq_n ~init:lits cc a ra in - let lits = explain_eq_n ~init:lits cc b rb in - raise_conflict cc acts lits - ); - (* We will merge [r_from] into [r_into]. - we try to ensure that [size ra <= size rb] in general, but always - keep values as representative *) - let r_from, r_into = - if n_is_bool cc ra then rb, ra - else if n_is_bool cc rb then ra, rb - else if size_ ra > size_ rb then rb, ra - else ra, rb - in - (* when merging terms with [true] or [false], possibly propagate them to SAT *) - let merge_bool r1 t1 r2 t2 = - if N.equal r1 (true_ cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab true - ) else if N.equal r1 (false_ cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab false - ) - in - merge_bool ra a rb b; - merge_bool rb b ra a; - (* perform [union r_from r_into] *) - Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); - (* call [on_merge] functions, and merge theory data items *) - begin - let th_into = r_into.n_th_data in - let th_from = r_from.n_th_data in - let new_data = A.Data.merge th_into th_from in - (* restore old data, if it changed *) - if new_data != th_into then ( - on_backtrack cc (fun () -> r_into.n_th_data <- th_into); - ); - r_into.n_th_data <- new_data; - (* explanation is [a=ra & e_ab & b=rb] *) - let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in - List.iter (fun f -> f cc r_into th_into r_from th_from expl) cc.on_merge; - end; - begin - (* parents might have a different signature, check for collisions *) - N.iter_parents r_from - (fun parent -> push_pending cc parent); - (* for each node in [r_from]'s class, make it point to [r_into] *) - N.iter_class r_from - (fun u -> - assert (u.n_root == r_from); - u.n_root <- r_into); - (* now merge the classes *) - let r_into_old_next = r_into.n_next in - let r_from_old_next = r_from.n_next in - let r_into_old_parents = r_into.n_parents in - r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; - (* on backtrack, unmerge classes and restore the pointers to [r_from] *) - on_backtrack cc - (fun () -> - Log.debugf 15 - (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" - N.pp r_from N.pp r_into); - r_into.n_next <- r_into_old_next; - r_from.n_next <- r_from_old_next; - r_into.n_parents <- r_into_old_parents; - N.iter_class_ r_from (fun u -> u.n_root <- r_from); - ); - (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; - end; - (* update explanations (a -> b), arbitrarily. - Note that here we merge the classes by adding a bridge between [a] - and [b], not their roots. *) - begin - reroot_expl cc a; - assert (a.n_expl = FL_none); - (* on backtracking, link may be inverted, but we delete the one - that bridges between [a] and [b] *) - on_backtrack cc - (fun () -> match a.n_expl, b.n_expl with - | FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none - | _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none - | _ -> assert false); - a.n_expl <- FL_some {next=b; expl=e_ab}; - end; - ) - - (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] - in the equiv class of [r1] that is a known literal back to the SAT solver - and which is not the one initially merged. - We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) - and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = - (* explanation for [t1 =e= t2 = r2] *) - let half_expl = lazy ( - let expl = explain_unfold cc e_12 in - explain_eq_n ~init:expl cc r2 t2 - ) in - (* TODO: flag per class, `or`-ed on merge, to indicate if the class - contains at least one lit *) - N.iter_class r1 - (fun u1 -> - (* propagate if: - - [u1] is a proper literal - - [t2 != r2], because that can only happen - after an explicit merge (no way to obtain that by propagation) - *) - match N.as_lit u1 with - | Some lit when not (N.equal r2 t2) -> - let lit = if sign then lit else A.Lit.neg lit in (* apply sign *) - Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit); - (* complete explanation with the [u1=t1] chunk *) - let expl () = - let e = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in - e, A.Proof.default in - let reason = Msat.Consequence expl in - acts.Msat.acts_propagate lit reason - | _ -> ()) - - module Theory = struct - type cc = t - - (* raise a conflict *) - let raise_conflict cc expl = - Log.debugf 5 - (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - merge_classes cc (true_ cc) (false_ cc) expl - - let merge cc n1 n2 expl = - Log.debugf 5 - (fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl); - merge_classes cc n1 n2 expl - - let add_term = add_term - end - - let check_invariants_ (cc:t) = - Log.debug 5 "(cc.check-invariants)"; - Log.debugf 15 (fun k-> k "%a" pp_full cc); - assert (T.equal (T.bool cc.tst true) (true_ cc).n_term); - assert (T.equal (T.bool cc.tst false) (false_ cc).n_term); - assert (not @@ same_class (true_ cc) (false_ cc)); - assert (Vec.is_empty cc.combine); - assert (Vec.is_empty cc.pending); - (* check that subterms are internalized *) - T_tbl.iter - (fun t n -> - assert (T.equal t n.n_term); - assert (not @@ N.get_field field_is_pending n); - assert (N.equal n.n_root n.n_next.n_root); - (* check proper signature. - note that some signatures in the sig table can be obsolete (they - were not removed) but there must be a valid, up-to-date signature for - each term *) - begin match CCOpt.map update_sig n.n_sig0 with - | None -> () - | Some s -> - Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s); - (* add, but only if not present already *) - begin match Sig_tbl.find cc.signatures_tbl s with - | exception Not_found -> assert false - | repr_s -> assert (same_class n repr_s) - end - end; - ) - cc.tbl; - () - - let[@inline] check_invariants (cc:t) : unit = - if Util._CHECK_INVARIANTS then check_invariants_ cc - - let add_seq cc seq = - seq (fun t -> ignore @@ add_term_rec_ cc t); - () - - let[@inline] push_level (self:t) : unit = - Backtrack_stack.push_level self.undo - - let pop_levels (self:t) n : unit = - Vec.iter (N.set_field field_is_pending false) self.pending; - Vec.clear self.pending; - Vec.clear self.combine; - Log.debugf 15 - (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo)); - Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); - () - - (* assert that this boolean literal holds. - if a lit is [= a b], merge [a] and [b]; - otherwise merge the atom with true/false *) - let assert_lit cc lit : unit = - let t = A.Lit.term lit in - Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit); - let sign = A.Lit.sign lit in - begin match T.cc_view t with - | Eq (a,b) when sign -> - let a = add_term cc a in - let b = add_term cc b in - (* merge [a] and [b] *) - merge_classes cc a b (Expl.mk_lit lit) - | _ -> - (* equate t and true/false *) - let rhs = if sign then true_ cc else false_ cc in - let n = add_term cc t in - (* TODO: ensure that this is O(1). - basically, just have [n] point to true/false and thus acquire - the corresponding value, so its superterms (like [ite]) can evaluate - properly *) - merge_classes cc n rhs (Expl.mk_lit lit) - end - - let[@inline] assert_lits cc lits : unit = - Iter.iter (assert_lit cc) lits - - let assert_eq cc t1 t2 (e:lit list) : unit = - let expl = Expl.mk_list @@ List.rev_map Expl.mk_lit e in - let n1 = add_term cc t1 in - let n2 = add_term cc t2 in - merge_classes cc n1 n2 expl - - let on_merge cc f = cc.on_merge <- f :: cc.on_merge - let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term - - let create ?(stat=Stat.global) - ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = - let size = match size with `Small -> 128 | `Big -> 2048 in - let rec cc = { - tst; - tbl = T_tbl.create size; - signatures_tbl = Sig_tbl.create size; - on_merge; - on_new_term; - pending=Vec.create(); - combine=Vec.create(); - ps_lits=[]; - undo=Backtrack_stack.create(); - ps_queue=Vec.create(); - true_; - false_; - stat; - count_conflict=Stat.mk_int stat "cc.conflicts"; - count_merge=Stat.mk_int stat "cc.merges"; - } and true_ = lazy ( - add_term cc (T.bool tst true) - ) and false_ = lazy ( - add_term cc (T.bool tst false) - ) - in - ignore (Lazy.force true_ : node); - ignore (Lazy.force false_ : node); - cc - - let[@inline] find_t cc t : repr = - let n = T_tbl.find cc.tbl t in - find_ n - - let[@inline] check cc acts : unit = - Log.debug 5 "(cc.check)"; - update_tasks cc acts - - (* model: map each uninterpreted equiv class to some ID *) - let mk_model (cc:t) (m:A.Model.t) : A.Model.t = - let module Model = A.Model in - let module Value = A.Value in - Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc); - let t_tbl = N_tbl.create 32 in - (* populate [repr -> value] table *) - T_tbl.values cc.tbl - (fun r -> - if N.is_root r then ( - (* find a value in the class, if any *) - let v = - N.iter_class r - |> Iter.find_map (fun n -> Model.eval m n.n_term) - in - let v = match v with - | Some v -> v - | None -> - if same_class r (true_ cc) then Value.true_ - else if same_class r (false_ cc) then Value.false_ - else Value.fresh r.n_term - in - N_tbl.add t_tbl r v - )); - (* now map every term to its representative's value *) - let pairs = - T_tbl.values cc.tbl - |> Iter.map - (fun n -> - let r = find_ n in - let v = - try N_tbl.find t_tbl r - with Not_found -> - Error.errorf "didn't allocate a value for repr %a" N.pp r - in - n.n_term, v) - in - let m = Iter.fold (fun m (t,v) -> Model.add t v m) m pairs in - Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m); - m -end diff --git a/src/cc/Congruence_closure.mli b/src/cc/Congruence_closure.mli deleted file mode 100644 index 717e327e..00000000 --- a/src/cc/Congruence_closure.mli +++ /dev/null @@ -1,13 +0,0 @@ -(** {2 Congruence Closure} *) - -module type ARG = Congruence_closure_intf.ARG -module type S = Congruence_closure_intf.S - -module Make(A: ARG) - : S with type term = A.Term.t - and type lit = A.Lit.t - and type fun_ = A.Fun.t - and type term_state = A.Term.state - and type proof = A.Proof.t - and type model = A.Model.t - and type th_data = A.Data.t diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml deleted file mode 100644 index 25e83f13..00000000 --- a/src/cc/Congruence_closure_intf.ml +++ /dev/null @@ -1,301 +0,0 @@ - -(** {1 Types used by the congruence closure} *) - -type ('f, 't, 'ts) view = - | Bool of bool - | App_fun of 'f * 'ts - | App_ho of 't * 'ts - | If of 't * 't * 't - | Eq of 't * 't - | Not of 't - | Opaque of 't (* do not enter *) - -let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = - match v with - | Bool b -> Bool b - | App_fun (f, args) -> App_fun (f_f f, f_ts args) - | App_ho (f, args) -> App_ho (f_t f, f_ts args) - | Not t -> Not (f_t t) - | If (a,b,c) -> If (f_t a, f_t b, f_t c) - | Eq (a,b) -> Eq (f_t a, f_t b) - | Opaque t -> Opaque (f_t t) - -let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit = - match v with - | Bool _ -> () - | App_fun (f, args) -> f_f f; f_ts args - | App_ho (f, args) -> f_t f; f_ts args - | Not t -> f_t t - | If (a,b,c) -> f_t a; f_t b; f_t c; - | Eq (a,b) -> f_t a; f_t b - | Opaque t -> f_t t - -module type TERM = sig - module Fun : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - end - - module Term : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - type state - - val bool : state -> bool -> t - - (** View the term through the lens of the congruence closure *) - val cc_view : t -> (Fun.t, t, t Iter.t) view - end -end - -module type TERM_LIT = sig - include TERM - - module Lit : sig - type t - val neg : t -> t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - val sign : t -> bool - val term : t -> Term.t - end -end - -module type ARG = sig - include TERM_LIT - - module Proof : sig - type t - val pp : t Fmt.printer - - val default : t - (* TODO: to give more details - val cc_lemma : unit -> t - *) - end - - module Ty : sig - type t - - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - end - - module Value : sig - type t - - val pp : t Fmt.printer - - val fresh : Term.t -> t - - val true_ : t - val false_ : t - end - - module Model : sig - type t - - val pp : t Fmt.printer - - val eval : t -> Term.t -> Value.t option - (** Evaluate the term in the current model *) - - val add : Term.t -> Value.t -> t -> t - end - - (** Monoid embedded in every node *) - module Data : sig - type t - - val empty : t - - val merge : t -> t -> t - end -end - -module type S = sig - type term_state - type term - type fun_ - type lit - type proof - type model - type th_data - - type t - (** Global state of the congruence closure *) - - (** An equivalence class is a set of terms that are currently equal - in the partial model built by the solver. - The class is represented by a collection of nodes, one of which is - distinguished and is called the "representative". - - All information pertaining to the whole equivalence class is stored - in this representative's node. - - When two classes become equal (are "merged"), one of the two - representatives is picked as the representative of the new class. - The new class contains the union of the two old classes' nodes. - - We also allow theories to store additional information in the - representative. This information can be used when two classes are - merged, to detect conflicts and solve equations à la Shostak. - *) - module N : sig - type t - - val term : t -> term - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - val is_root : t -> bool - (** Is the node a root (ie the representative of its class)? *) - - val iter_class : t -> t Iter.t - (** Traverse the congruence class. - Invariant: [is_root n] (see {!find} below) *) - - val iter_parents : t -> t Iter.t - (** Traverse the parents of the class. - Invariant: [is_root n] (see {!find} below) *) - - val th_data : t -> th_data - (** Access theory data for this node *) - - val get_field_usr1 : t -> bool - val set_field_usr1 : t -> bool -> unit - - val get_field_usr2 : t -> bool - val set_field_usr2 : t -> bool -> unit - end - - module Expl : sig - type t - val pp : t Fmt.printer - - val mk_merge : N.t -> N.t -> t - val mk_merge_t : term -> term -> t - val mk_lit : lit -> t - val mk_list : t list -> t - end - - type node = N.t - (** A node of the congruence closure *) - - type repr = N.t - (** Node that is currently a representative *) - - type explanation = Expl.t - - type conflict = lit list - - (** Accessors *) - - val term_state : t -> term_state - - val find : t -> node -> repr - (** Current representative *) - - val add_term : t -> term -> node - (** Add the term to the congruence closure, if not present already. - Will be backtracked. *) - - (** Actions available to the theory *) - type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts - - module Theory : sig - type cc = t - - val raise_conflict : cc -> Expl.t -> unit - (** Raise a conflict with the given explanation - it must be a theory tautology that [expl ==> absurd]. - To be used in theories. *) - - val merge : cc -> N.t -> N.t -> Expl.t -> unit - (** Merge these two nodes given this explanation. - It must be a theory tautology that [expl ==> n1 = n2]. - To be used in theories. *) - - val add_term : cc -> term -> N.t - (** Add/retrieve node for this term. - To be used in theories *) - end - - type ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit - type ev_on_new_term = t -> N.t -> term -> th_data -> th_data option - - val create : - ?stat:Stat.t -> - ?on_merge:ev_on_merge list -> - ?on_new_term:ev_on_new_term list -> - ?size:[`Small | `Big] -> - term_state -> - t - (** Create a new congruence closure. *) - - val on_merge : t -> ev_on_merge -> unit - (** Add a function to be called when two classes are merged *) - - val on_new_term : t -> ev_on_new_term -> unit - (** Add a function to be called when a new node is created *) - - val set_as_lit : t -> N.t -> lit -> unit - (** map the given node to a literal. *) - - val find_t : t -> term -> repr - (** Current representative of the term. - @raise Not_found if the term is not already {!add}-ed. *) - - val add_seq : t -> term Iter.t -> unit - (** Add a sequence of terms to the congruence closure *) - - val all_classes : t -> repr Iter.t - (** All current classes. This is costly, only use if there is no other solution *) - - val assert_lit : t -> lit -> unit - (** Given a literal, assume it in the congruence closure and propagate - its consequences. Will be backtracked. - - Useful for the theory combination or the SAT solver's functor *) - - val assert_lits : t -> lit Iter.t -> unit - (** Addition of many literals *) - - val assert_eq : t -> term -> term -> lit list -> unit - (** merge the given terms with some explanations *) - - (* TODO: remove and move into its own library as a micro theory - val assert_distinct : t -> term list -> neq:term -> lit -> unit - (** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct - because [lit] is true - precond: [u = distinct l] *) - *) - - val check : t -> sat_actions -> unit - (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. - Will use the [sat_actions] to propagate literals, declare conflicts, etc. *) - - val push_level : t -> unit - (** Push backtracking level *) - - val pop_levels : t -> int -> unit - (** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) - - val mk_model : t -> model -> model - (** Enrich a model by mapping terms to their representative's value, - if any. Otherwise map the representative to a fresh value *) - - (**/**) - val check_invariants : t -> unit - val pp_full : t Fmt.printer - (**/**) -end diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index f0bee810..25178ede 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -1,5 +1,7 @@ -type ('f, 't, 'ts) view = ('f, 't, 'ts) Congruence_closure_intf.view = +module View = Sidekick_core.CC_view + +type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t = | Bool of bool | App_fun of 'f * 'ts | App_ho of 't * 'ts @@ -8,12 +10,883 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) Congruence_closure_intf.view = | Not of 't | Opaque of 't (* do not enter *) -(** Parameter for the congruence closure *) -module type TERM_LIT = Congruence_closure_intf.TERM_LIT -module type ARG = Congruence_closure_intf.ARG -module type S = Congruence_closure.S +module type ARG = Sidekick_core.CC_ARG +module type S = Sidekick_core.CC_S -module Mini_cc = Mini_cc -module Congruence_closure = Congruence_closure -module Make = Congruence_closure.Make +module Bits = CCBitField.Make() +let field_is_pending = Bits.mk_field() +(** true iff the node is in the [cc.pending] queue *) + +let field_usr1 = Bits.mk_field() +(** General purpose *) + +let field_usr2 = Bits.mk_field() +(** General purpose *) + +let () = Bits.freeze() + +module Make(A: ARG) = struct + module A = A + type term = A.Term.t + type term_state = A.Term.state + type lit = A.Lit.t + type fun_ = A.Fun.t + type proof = A.Proof.t + type th_data = A.Data.t + type actions = A.Actions.t + + module T = A.Term + module Fun = A.Fun + + (** A node of the congruence closure. + An equivalence class is represented by its "root" element, + the representative. *) + type node = { + n_term: term; + mutable n_sig0: signature option; (* initial signature *) + mutable n_bits: Bits.t; (* bitfield for various properties *) + mutable n_parents: node Bag.t; (* parent terms of this node *) + mutable n_root: node; (* representative of congruence class (itself if a representative) *) + mutable n_next: node; (* pointer to next element of congruence class *) + mutable n_size: int; (* size of the class *) + mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) + mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) + mutable n_th_data: th_data; (* theory data *) + } + + and signature = (fun_, node, node list) view + + and explanation_forest_link = + | FL_none + | FL_some of { + next: node; + expl: explanation; + } + + (* atomic explanation in the congruence closure *) + and explanation = + | E_reduction (* by pure reduction, tautologically equal *) + | E_lit of lit (* because of this literal *) + | E_merge of node * node + | E_merge_t of term * term + | E_congruence of node * node (* caused by normal congruence *) + | E_and of explanation * explanation + + type repr = node + type conflict = lit list + + module N = struct + type t = node + + let[@inline] equal (n1:t) n2 = n1 == n2 + let[@inline] hash n = T.hash n.n_term + let[@inline] term n = n.n_term + let[@inline] pp out n = T.pp out n.n_term + let[@inline] as_lit n = n.n_as_lit + let[@inline] th_data n = n.n_th_data + + let make (t:term) : t = + let rec n = { + n_term=t; + n_sig0= None; + n_bits=Bits.empty; + n_parents=Bag.empty; + n_as_lit=None; (* TODO: provide a method to do it *) + n_root=n; + n_expl=FL_none; + n_next=n; + n_size=1; + n_th_data=A.Data.empty; + } in + n + + let[@inline] is_root (n:node) : bool = n.n_root == n + + (* traverse the equivalence class of [n] *) + let iter_class_ (n:node) : node Iter.t = + fun yield -> + let rec aux u = + yield u; + if u.n_next != n then aux u.n_next + in + aux n + + let[@inline] iter_class n = + assert (is_root n); + iter_class_ n + + let[@inline] iter_parents (n:node) : node Iter.t = + assert (is_root n); + Bag.to_seq n.n_parents + + let[@inline] get_field f t = Bits.get f t.n_bits + let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits + + let[@inline] get_field_usr1 t = get_field field_usr1 t + let[@inline] set_field_usr1 t b = set_field field_usr1 b t + let[@inline] get_field_usr2 t = get_field field_usr2 t + let[@inline] set_field_usr2 t b = set_field field_usr2 b t + end + + module N_tbl = CCHashtbl.Make(N) + + module Expl = struct + type t = explanation + + let rec pp out (e:explanation) = match e with + | E_reduction -> Fmt.string out "reduction" + | E_lit lit -> A.Lit.pp out lit + | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 + | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b + | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b + | E_and (a,b) -> + Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b + + let mk_reduction : t = E_reduction + let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) + let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) + let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b) + let[@inline] mk_lit l : t = E_lit l + + let rec mk_list l = + match l with + | [] -> mk_reduction + | [x] -> x + | E_reduction :: tl -> mk_list tl + | x :: y -> + match mk_list y with + | E_reduction -> x + | y' -> E_and (x,y') + end + + (** A signature is a shallow term shape where immediate subterms + are representative *) + module Signature = struct + type t = signature + let equal (s1:t) s2 : bool = + match s1, s2 with + | Bool b1, Bool b2 -> b1=b2 + | App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2 + | App_fun (f1,l1), App_fun (f2,l2) -> + Fun.equal f1 f2 && CCList.equal N.equal l1 l2 + | App_ho (f1,l1), App_ho (f2,l2) -> + N.equal f1 f2 && CCList.equal N.equal l1 l2 + | Not a, Not b -> N.equal a b + | If (a1,b1,c1), If (a2,b2,c2) -> + N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2 + | Eq (a1,b1), Eq (a2,b2) -> + N.equal a1 a2 && N.equal b1 b2 + | Opaque u1, Opaque u2 -> N.equal u1 u2 + | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ + | Eq _, _ | Opaque _, _ | Not _, _ + -> false + + let hash (s:t) : int = + let module H = CCHash in + match s with + | Bool b -> H.combine2 10 (H.bool b) + | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l) + | App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l) + | Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b) + | Opaque u -> H.combine2 50 (N.hash u) + | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c) + | Not u -> H.combine2 70 (N.hash u) + + let pp out = function + | Bool b -> Fmt.bool out b + | App_fun (f, []) -> Fun.pp out f + | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l + | App_ho (f, []) -> N.pp out f + | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l + | Opaque t -> N.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" N.pp u + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b + | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c + end + + module Sig_tbl = CCHashtbl.Make(Signature) + module T_tbl = CCHashtbl.Make(T) + + type combine_task = + | CT_merge of node * node * explanation + + type t = { + tst: term_state; + tbl: node T_tbl.t; + (* internalization [term -> node] *) + signatures_tbl : node Sig_tbl.t; + (* map a signature to the corresponding node in some equivalence class. + A signature is a [term_cell] in which every immediate subterm + that participates in the congruence/evaluation relation + is normalized (i.e. is its own representative). + The critical property is that all members of an equivalence class + that have the same "shape" (including head symbol) + have the same signature *) + pending: node Vec.t; + combine: combine_task Vec.t; + undo: (unit -> unit) Backtrack_stack.t; + mutable on_merge: ev_on_merge list; + mutable on_new_term: ev_on_new_term list; + mutable ps_lits: lit list; (* TODO: thread it around instead? *) + (* proof state *) + ps_queue: (node*node) Vec.t; + (* pairs to explain *) + true_ : node lazy_t; + false_ : node lazy_t; + stat: Stat.t; + count_conflict: int Stat.counter; + count_merge: int Stat.counter; + } + (* TODO: an additional union-find to keep track, for each term, + of the terms they are known to be equal to, according + to the current explanation. That allows not to prove some equality + several times. + See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) + + and ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit + and ev_on_new_term = t -> N.t -> term -> th_data -> th_data option + + let[@inline] size_ (r:repr) = r.n_size + let[@inline] true_ cc = Lazy.force cc.true_ + let[@inline] false_ cc = Lazy.force cc.false_ + let[@inline] term_state cc = cc.tst + + let[@inline] on_backtrack cc f : unit = + Backtrack_stack.push_if_nonzero_level cc.undo f + + (* check if [t] is in the congruence closure. + Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) + let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t + + (* find representative, recursively *) + let[@unroll 2] rec find_rec (n:node) : repr = + if n==n.n_root then ( + n + ) else ( + let root = find_rec n.n_root in + if root != n.n_root then ( + n.n_root <- root; (* path compression *) + ); + root + ) + + (* non-recursive, inlinable function for [find] *) + let[@inline] find_ (n:node) : repr = + if n == n.n_root then n else find_rec n.n_root + + let[@inline] same_class (n1:node)(n2:node): bool = + N.equal (find_ n1) (find_ n2) + + let[@inline] find _ n = find_ n + + (* print full state *) + let pp_full out (cc:t) : unit = + let pp_next out n = + Fmt.fprintf out "@ :next %a" N.pp n.n_next in + let pp_root out n = + if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in + let pp_expl out n = match n.n_expl with + | FL_none -> () + | FL_some e -> + Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl + in + let pp_n out n = + Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n + and pp_sig_e out (s,n) = + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n + in + Fmt.fprintf out + "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ %a@])@])" + (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) + (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) + + (* compute up-to-date signature *) + let update_sig (s:signature) : Signature.t = + View.map_view s + ~f_f:(fun x->x) + ~f_t:find_ + ~f_ts:(List.map find_) + + (* find whether the given (parent) term corresponds to some signature + in [signatures_] *) + let[@inline] find_signature cc (s:signature) : repr option = + Sig_tbl.get cc.signatures_tbl s + + let add_signature cc (s:signature) (n:node) : unit = + (* add, but only if not present already *) + match Sig_tbl.find cc.signatures_tbl s with + | exception Not_found -> + Log.debugf 15 + (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n); + on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); + Sig_tbl.add cc.signatures_tbl s n; + | r' -> + assert (same_class n r'); + () + + let push_pending cc t : unit = + if not @@ N.get_field field_is_pending t then ( + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); + N.set_field field_is_pending true t; + Vec.push cc.pending t + ) + + let merge_classes cc t u e : unit = + Log.debugf 5 + (fun k->k "(@[cc.push_combine@ %a ~@ %a@ :expl %a@])" + N.pp t N.pp u Expl.pp e); + Vec.push cc.combine @@ CT_merge (t,u,e) + + (* re-root the explanation tree of the equivalence class of [n] + so that it points to [n]. + postcondition: [n.n_expl = None] *) + let rec reroot_expl (cc:t) (n:node): unit = + let old_expl = n.n_expl in + begin match old_expl with + | FL_none -> () (* already root *) + | FL_some {next=u; expl=e_n_u} -> + reroot_expl cc u; + u.n_expl <- FL_some {next=n; expl=e_n_u}; + n.n_expl <- FL_none; + end + + let raise_conflict (cc:t) (acts:actions) (e:conflict) : _ = + (* clear tasks queue *) + Vec.iter (N.set_field field_is_pending false) cc.pending; + Vec.clear cc.pending; + Vec.clear cc.combine; + Stat.incr cc.count_conflict; + A.Actions.raise_conflict acts e A.Proof.default + + let[@inline] all_classes cc : repr Iter.t = + T_tbl.values cc.tbl + |> Iter.filter N.is_root + + (* TODO: use markers and lockstep iteration instead *) + (* distance from [t] to its root in the proof forest *) + let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with + | FL_none -> 0 + | FL_some {next=t'; _} -> 1 + distance_to_root t' + + (* TODO: new bool flag on nodes + stepwise progress + cleanup *) + (* find the closest common ancestor of [a] and [b] in the proof forest *) + let find_common_ancestor (a:node) (b:node) : node = + let d_a = distance_to_root a in + let d_b = distance_to_root b in + (* drop [n] nodes in the path from [t] to its root *) + let rec drop_ n t = + if n=0 then t + else match t.n_expl with + | FL_none -> assert false + | FL_some {next=t'; _} -> drop_ (n-1) t' + in + (* reduce to the problem where [a] and [b] have the same distance to root *) + let a, b = + if d_a > d_b then drop_ (d_a-d_b) a, b + else if d_a < d_b then a, drop_ (d_b-d_a) b + else a, b + in + (* traverse stepwise until a==b *) + let rec aux_same_dist a b = + if a==b then a + else match a.n_expl, b.n_expl with + | FL_none, _ | _, FL_none -> assert false + | FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b' + in + aux_same_dist a b + + let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) + let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits + + let ps_clear (cc:t) = + cc.ps_lits <- []; + Vec.clear cc.ps_queue; + () + + (* decompose explanation [e] of why [n1 = n2] *) + let rec decompose_explain cc (e:explanation) : unit = + Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); + match e with + | E_reduction -> () + | E_congruence (n1, n2) -> + begin match n1.n_sig0, n2.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Fun.equal f1 f2); + assert (List.length a1 = List.length a2); + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + assert (List.length a1 = List.length a2); + ps_add_obligation cc f1 f2; + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> + ps_add_obligation cc a1 a2; + ps_add_obligation cc b1 b2; + ps_add_obligation cc c1 c2; + | _ -> + assert false + end + | E_lit lit -> ps_add_lit cc lit + | E_merge (a,b) -> ps_add_obligation cc a b + | E_merge_t (a,b) -> + (* find nodes for [a] and [b] on the fly *) + begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with + | a, b -> ps_add_obligation cc a b + | exception Not_found -> + Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b + end + | E_and (a,b) -> decompose_explain cc a; decompose_explain cc b + + (* explain why [a = parent_a], where [a -> ... -> parent_a] in the + proof forest *) + let explain_along_path ps (a:node) (parent_a:node) : unit = + let rec aux n = + if n != parent_a then ( + match n.n_expl with + | FL_none -> assert false + | FL_some {next=next_n; expl=expl} -> + decompose_explain ps expl; + (* now prove [next_n = parent_a] *) + aux next_n + ) + in aux a + + (* find explanation *) + let explain_loop (cc : t) : lit list = + while not (Vec.is_empty cc.ps_queue) do + let a, b = Vec.pop cc.ps_queue in + Log.debugf 5 + (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); + assert (N.equal (find_ a) (find_ b)); + let c = find_common_ancestor a b in + explain_along_path cc a c; + explain_along_path cc b c; + done; + cc.ps_lits + + let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list = + ps_clear cc; + cc.ps_lits <- init; + ps_add_obligation cc n1 n2; + explain_loop cc + + let explain_unfold ?(init=[]) cc (e:explanation) : lit list = + ps_clear cc; + cc.ps_lits <- init; + decompose_explain cc e; + explain_loop cc + + (* add a term *) + let [@inline] rec add_term_rec_ cc t : node = + try T_tbl.find cc.tbl t + with Not_found -> add_new_term_ cc t + + (* add [t] to [cc] when not present already *) + and add_new_term_ cc (t:term) : node = + assert (not @@ mem cc t); + Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t); + let n = N.make t in + (* register sub-terms, add [t] to their parent list, and return the + corresponding initial signature *) + let sig0 = compute_sig0 cc n in + n.n_sig0 <- sig0; + (* remove term when we backtrack *) + on_backtrack cc + (fun () -> + Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t); + T_tbl.remove cc.tbl t); + (* add term to the table *) + T_tbl.add cc.tbl t n; + if CCOpt.is_some sig0 then ( + (* [n] might be merged with other equiv classes *) + push_pending cc n; + ); + (* initial theory data *) + let th_data = + List.fold_left + (fun data f -> + match f cc n t data with + | None -> data + | Some d -> d) + A.Data.empty cc.on_new_term + in + n.n_th_data <- th_data; + n + + (* compute the initial signature of the given node *) + and compute_sig0 (self:t) (n:node) : Signature.t option = + (* add sub-term to [cc], and register [n] to its parents *) + let deref_sub (u:term) : node = + let sub = find_ @@ add_term_rec_ self u in + (* add [n] to [sub.root]'s parent list *) + begin + let old_parents = sub.n_parents in + on_backtrack self (fun () -> sub.n_parents <- old_parents); + sub.n_parents <- Bag.cons n sub.n_parents; + end; + sub + in + let[@inline] return x = Some x in + match T.cc_view n.n_term with + | Bool _ | Opaque _ -> None + | Eq (a,b) -> + let a = deref_sub a in + let b = deref_sub b in + return @@ Eq (a,b) + | Not u -> return @@ Not (deref_sub u) + | App_fun (f, args) -> + let args = args |> Iter.map deref_sub |> Iter.to_list in + if args<>[] then ( + return @@ App_fun (f, args) + ) else None + | App_ho (f, args) -> + let args = args |> Iter.map deref_sub |> Iter.to_list in + return @@ App_ho (deref_sub f, args) + | If (a,b,c) -> + return @@ If (deref_sub a, deref_sub b, deref_sub c) + + let[@inline] add_term cc t : node = add_term_rec_ cc t + + let set_as_lit cc (n:node) (lit:lit) : unit = + match n.n_as_lit with + | Some _ -> () + | None -> + Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit); + on_backtrack cc (fun () -> n.n_as_lit <- None); + n.n_as_lit <- Some lit + + let n_is_bool (self:t) n : bool = + N.equal n (true_ self) || N.equal n (false_ self) + + (* main CC algo: add terms from [pending] to the signature table, + check for collisions *) + let rec update_tasks (cc:t) (acts:actions) : unit = + while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do + while not @@ Vec.is_empty cc.pending do + task_pending_ cc (Vec.pop cc.pending); + done; + while not @@ Vec.is_empty cc.combine do + task_combine_ cc acts (Vec.pop cc.combine); + done; + done + + and task_pending_ cc (n:node) : unit = + N.set_field field_is_pending false n; + (* check if some parent collided *) + begin match n.n_sig0 with + | None -> () (* no-op *) + | Some (Eq (a,b)) -> + (* if [a=b] is now true, merge [(a=b)] and [true] *) + if same_class a b then ( + let expl = Expl.mk_merge a b in + merge_classes cc n (true_ cc) expl + ) + | Some (Not u) -> + (* [u = bool ==> not u = not bool] *) + let r_u = find_ u in + if N.equal r_u (true_ cc) then ( + let expl = Expl.mk_merge u (true_ cc) in + merge_classes cc n (false_ cc) expl + ) else if N.equal r_u (false_ cc) then ( + let expl = Expl.mk_merge u (false_ cc) in + merge_classes cc n (true_ cc) expl + ) + | Some s0 -> + (* update the signature by using [find] on each sub-node *) + let s = update_sig s0 in + match find_signature cc s with + | None -> + (* add to the signature table [sig(n) --> n] *) + add_signature cc s n + | Some u when n == u -> () + | Some u -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = Expl.mk_congruence n u in + merge_classes cc n u expl + end + + and[@inline] task_combine_ cc acts = function + | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab + + (* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) + and task_merge_ cc acts a b e_ab : unit = + let ra = find_ a in + let rb = find_ b in + if not @@ N.equal ra rb then ( + assert (N.is_root ra); + assert (N.is_root rb); + Stat.incr cc.count_merge; + (* check we're not merging [true] and [false] *) + if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) || + (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then ( + Log.debugf 5 + (fun k->k "(@[cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])" + N.pp ra N.pp rb Expl.pp e_ab); + let lits = explain_unfold cc e_ab in + let lits = explain_eq_n ~init:lits cc a ra in + let lits = explain_eq_n ~init:lits cc b rb in + raise_conflict cc acts lits + ); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, but always + keep values as representative *) + let r_from, r_into = + if n_is_bool cc ra then rb, ra + else if n_is_bool cc rb then ra, rb + else if size_ ra > size_ rb then rb, ra + else ra, rb + in + (* when merging terms with [true] or [false], possibly propagate them to SAT *) + let merge_bool r1 t1 r2 t2 = + if N.equal r1 (true_ cc) then ( + propagate_bools cc acts r2 t2 r1 t1 e_ab true + ) else if N.equal r1 (false_ cc) then ( + propagate_bools cc acts r2 t2 r1 t1 e_ab false + ) + in + merge_bool ra a rb b; + merge_bool rb b ra a; + (* perform [union r_from r_into] *) + Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + (* call [on_merge] functions, and merge theory data items *) + begin + let th_into = r_into.n_th_data in + let th_from = r_from.n_th_data in + let new_data = A.Data.merge th_into th_from in + (* restore old data, if it changed *) + if new_data != th_into then ( + on_backtrack cc (fun () -> r_into.n_th_data <- th_into); + ); + r_into.n_th_data <- new_data; + (* explanation is [a=ra & e_ab & b=rb] *) + let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in + List.iter (fun f -> f cc r_into th_into r_from th_from expl) cc.on_merge; + end; + begin + (* parents might have a different signature, check for collisions *) + N.iter_parents r_from + (fun parent -> push_pending cc parent); + (* for each node in [r_from]'s class, make it point to [r_into] *) + N.iter_class r_from + (fun u -> + assert (u.n_root == r_from); + u.n_root <- r_into); + (* now merge the classes *) + let r_into_old_next = r_into.n_next in + let r_from_old_next = r_from.n_next in + let r_into_old_parents = r_into.n_parents in + r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; + (* on backtrack, unmerge classes and restore the pointers to [r_from] *) + on_backtrack cc + (fun () -> + Log.debugf 15 + (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" + N.pp r_from N.pp r_into); + r_into.n_next <- r_into_old_next; + r_from.n_next <- r_from_old_next; + r_into.n_parents <- r_into_old_parents; + N.iter_class_ r_from (fun u -> u.n_root <- r_from); + ); + (* swap [into.next] and [from.next], merging the classes *) + r_into.n_next <- r_from_old_next; + r_from.n_next <- r_into_old_next; + end; + (* update explanations (a -> b), arbitrarily. + Note that here we merge the classes by adding a bridge between [a] + and [b], not their roots. *) + begin + reroot_expl cc a; + assert (a.n_expl = FL_none); + (* on backtracking, link may be inverted, but we delete the one + that bridges between [a] and [b] *) + on_backtrack cc + (fun () -> match a.n_expl, b.n_expl with + | FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none + | _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none + | _ -> assert false); + a.n_expl <- FL_some {next=b; expl=e_ab}; + end; + ) + + (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] + in the equiv class of [r1] that is a known literal back to the SAT solver + and which is not the one initially merged. + We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) + and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = + (* explanation for [t1 =e= t2 = r2] *) + let half_expl = lazy ( + let expl = explain_unfold cc e_12 in + explain_eq_n ~init:expl cc r2 t2 + ) in + (* TODO: flag per class, `or`-ed on merge, to indicate if the class + contains at least one lit *) + N.iter_class r1 + (fun u1 -> + (* propagate if: + - [u1] is a proper literal + - [t2 != r2], because that can only happen + after an explicit merge (no way to obtain that by propagation) + *) + match N.as_lit u1 with + | Some lit when not (N.equal r2 t2) -> + let lit = if sign then lit else A.Lit.neg lit in (* apply sign *) + Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit); + (* complete explanation with the [u1=t1] chunk *) + let reason yield = + let e = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in + List.iter yield e + in + A.Actions.propagate acts lit ~reason A.Proof.default + | _ -> ()) + + module Theory = struct + type cc = t + + (* raise a conflict *) + let raise_conflict cc expl = + Log.debugf 5 + (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); + merge_classes cc (true_ cc) (false_ cc) expl + + let merge cc n1 n2 expl = + Log.debugf 5 + (fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl); + merge_classes cc n1 n2 expl + + let add_term = add_term + end + + let check_invariants_ (cc:t) = + Log.debug 5 "(cc.check-invariants)"; + Log.debugf 15 (fun k-> k "%a" pp_full cc); + assert (T.equal (T.bool cc.tst true) (true_ cc).n_term); + assert (T.equal (T.bool cc.tst false) (false_ cc).n_term); + assert (not @@ same_class (true_ cc) (false_ cc)); + assert (Vec.is_empty cc.combine); + assert (Vec.is_empty cc.pending); + (* check that subterms are internalized *) + T_tbl.iter + (fun t n -> + assert (T.equal t n.n_term); + assert (not @@ N.get_field field_is_pending n); + assert (N.equal n.n_root n.n_next.n_root); + (* check proper signature. + note that some signatures in the sig table can be obsolete (they + were not removed) but there must be a valid, up-to-date signature for + each term *) + begin match CCOpt.map update_sig n.n_sig0 with + | None -> () + | Some s -> + Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s); + (* add, but only if not present already *) + begin match Sig_tbl.find cc.signatures_tbl s with + | exception Not_found -> assert false + | repr_s -> assert (same_class n repr_s) + end + end; + ) + cc.tbl; + () + + module Debug_ = struct + let[@inline] check_invariants (cc:t) : unit = + if Util._CHECK_INVARIANTS then check_invariants_ cc + let pp out _ = Fmt.string out "cc" + end + + let add_seq cc seq = + seq (fun t -> ignore @@ add_term_rec_ cc t); + () + + let[@inline] push_level (self:t) : unit = + Backtrack_stack.push_level self.undo + + let pop_levels (self:t) n : unit = + Vec.iter (N.set_field field_is_pending false) self.pending; + Vec.clear self.pending; + Vec.clear self.combine; + Log.debugf 15 + (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo)); + Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); + () + + (* assert that this boolean literal holds. + if a lit is [= a b], merge [a] and [b]; + otherwise merge the atom with true/false *) + let assert_lit cc lit : unit = + let t = A.Lit.term lit in + Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit); + let sign = A.Lit.sign lit in + begin match T.cc_view t with + | Eq (a,b) when sign -> + let a = add_term cc a in + let b = add_term cc b in + (* merge [a] and [b] *) + merge_classes cc a b (Expl.mk_lit lit) + | _ -> + (* equate t and true/false *) + let rhs = if sign then true_ cc else false_ cc in + let n = add_term cc t in + (* TODO: ensure that this is O(1). + basically, just have [n] point to true/false and thus acquire + the corresponding value, so its superterms (like [ite]) can evaluate + properly *) + merge_classes cc n rhs (Expl.mk_lit lit) + end + + let[@inline] assert_lits cc lits : unit = + Iter.iter (assert_lit cc) lits + + let assert_eq cc t1 t2 (e:lit list) : unit = + let expl = Expl.mk_list @@ List.rev_map Expl.mk_lit e in + let n1 = add_term cc t1 in + let n2 = add_term cc t2 in + merge_classes cc n1 n2 expl + + let on_merge cc f = cc.on_merge <- f :: cc.on_merge + let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term + + let create ?(stat=Stat.global) + ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = + let size = match size with `Small -> 128 | `Big -> 2048 in + let rec cc = { + tst; + tbl = T_tbl.create size; + signatures_tbl = Sig_tbl.create size; + on_merge; + on_new_term; + pending=Vec.create(); + combine=Vec.create(); + ps_lits=[]; + undo=Backtrack_stack.create(); + ps_queue=Vec.create(); + true_; + false_; + stat; + count_conflict=Stat.mk_int stat "cc.conflicts"; + count_merge=Stat.mk_int stat "cc.merges"; + } and true_ = lazy ( + add_term cc (T.bool tst true) + ) and false_ = lazy ( + add_term cc (T.bool tst false) + ) + in + ignore (Lazy.force true_ : node); + ignore (Lazy.force false_ : node); + cc + + let[@inline] find_t cc t : repr = + let n = T_tbl.find cc.tbl t in + find_ n + + let[@inline] check cc acts : unit = + Log.debug 5 "(cc.check)"; + update_tasks cc acts + + (* model: return all the classes *) + let get_model (cc:t) : repr Iter.t Iter.t = + all_classes cc |> Iter.map N.iter_class +end diff --git a/src/cc/Sidekick_cc.mli b/src/cc/Sidekick_cc.mli new file mode 100644 index 00000000..07259b23 --- /dev/null +++ b/src/cc/Sidekick_cc.mli @@ -0,0 +1,6 @@ +(** {2 Congruence Closure} *) + +module type ARG = Sidekick_core.CC_ARG +module type S = Sidekick_core.CC_S + +module Make(A: ARG) : S with module A = A diff --git a/src/cc/dune b/src/cc/dune index 0fa44a8f..c8bf6e36 100644 --- a/src/cc/dune +++ b/src/cc/dune @@ -3,8 +3,5 @@ (library (name Sidekick_cc) (public_name sidekick.cc) - (libraries containers containers.data msat iter sidekick.util) - (flags :standard -warn-error -a+8 - -color always -safe-string -short-paths -open Sidekick_util) - (ocamlopt_flags :standard -O3 -color always - -unbox-closures -unbox-closures-factor 20)) + (libraries containers containers.data iter sidekick.core sidekick.util) + (flags :standard -open Sidekick_util)) diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index fa69b106..53fdc782 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -75,8 +75,12 @@ module type TERM_LIT = sig val hash : t -> int val pp : t Fmt.printer - val sign : t -> bool val term : t -> Term.t + val sign : t -> bool + val abs : t -> t + val apply_sign : t -> bool -> t + val norm_sign : t -> t * bool + (** Invariant: if [u, sign = norm_sign t] then [apply_sign u sign = t] *) end end @@ -88,7 +92,8 @@ module type CC_ARG = sig val pp : t Fmt.printer val default : t - (* TODO: to give more details + (* TODO: to give more details? or make this extensible? + or have a generative function for new proof cstors? val cc_lemma : unit -> t *) end @@ -104,44 +109,17 @@ module type CC_ARG = sig (** Monoid embedded in every node *) module Data : sig type t - val empty : t - val merge : t -> t -> t end module Actions : sig type t - val raise_conflict : t -> Lit.t list -> 'a + val raise_conflict : t -> Lit.t list -> Proof.t -> 'a - val propagate : t -> Lit.t -> reason:Lit.t Iter.t -> unit + val propagate : t -> Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit end - - (* TODO: instead, provide model as a `equiv_class Iter.t`, for the - benefit of $whatever_theory_combination_method? - module Value : sig - type t - - val pp : t Fmt.printer - - val fresh : Term.t -> t - - val true_ : t - val false_ : t - end - - module Model : sig - type t - - val pp : t Fmt.printer - - val eval : t -> Term.t -> Value.t option - (** Evaluate the term in the current model *) - - val add : Term.t -> Value.t -> t -> t - end - *) end module type CC_S = sig diff --git a/src/main/dune b/src/main/dune index 8d9cf100..6d74cb43 100644 --- a/src/main/dune +++ b/src/main/dune @@ -4,8 +4,9 @@ (name main) (public_name sidekick) (package sidekick) - (libraries containers iter result msat sidekick.smt sidekick.smtlib - sidekick.smt.th-ite sidekick.dimacs) + (libraries containers iter result msat sidekick.core + sidekick.base-term sidekick.msat-solver sidekick.smtlib + sidekick.smt.th-ite sidekick.dimacs) (flags :standard -w +a-4-42-44-48-50-58-32-60@8 -safe-string -color always -open Sidekick_util) (ocamlopt_flags :standard -O3 -color always diff --git a/src/cc/Mini_cc.ml b/src/mini-cc/Mini_cc.ml similarity index 98% rename from src/cc/Mini_cc.ml rename to src/mini-cc/Mini_cc.ml index 9b5a37dd..66a000a5 100644 --- a/src/cc/Mini_cc.ml +++ b/src/mini-cc/Mini_cc.ml @@ -1,11 +1,10 @@ -open Congruence_closure_intf - type res = | Sat | Unsat -module type TERM = Congruence_closure_intf.TERM +module CC_view = Sidekick_core.CC_view +module type TERM = Sidekick_core.TERM module type S = sig type term @@ -26,6 +25,8 @@ end module Make(A: TERM) = struct + open CC_view + module Fun = A.Fun module T = A.Term type fun_ = A.Fun.t @@ -42,7 +43,7 @@ module Make(A: TERM) = struct mutable n_root: node; } - type signature = (fun_, node, node list) view + type signature = (fun_, node, node list) CC_view.t module Node = struct type t = node diff --git a/src/cc/Mini_cc.mli b/src/mini-cc/Mini_cc.mli similarity index 91% rename from src/cc/Mini_cc.mli rename to src/mini-cc/Mini_cc.mli index 6f96c723..10afc715 100644 --- a/src/cc/Mini_cc.mli +++ b/src/mini-cc/Mini_cc.mli @@ -6,13 +6,12 @@ It just decides the satisfiability of a set of (dis)equations. *) -open Congruence_closure_intf - type res = | Sat | Unsat -module type TERM = Congruence_closure_intf.TERM +module CC_view = Sidekick_core.CC_view +module type TERM = Sidekick_core.TERM module type S = sig type term diff --git a/src/mini-cc/dune b/src/mini-cc/dune new file mode 100644 index 00000000..e32584a1 --- /dev/null +++ b/src/mini-cc/dune @@ -0,0 +1,7 @@ + + +(library + (name Sidekick_mini_cc) + (public_name sidekick.mini-cc) + (libraries containers iter sidekick.core sidekick.util) + (flags :standard -open Sidekick_util)) diff --git a/src/smt/CC.ml b/src/msat-solver/CC.ml similarity index 100% rename from src/smt/CC.ml rename to src/msat-solver/CC.ml diff --git a/src/smt/CC.mli b/src/msat-solver/CC.mli similarity index 100% rename from src/smt/CC.mli rename to src/msat-solver/CC.mli diff --git a/src/smt/DESIGN.md b/src/msat-solver/DESIGN.md similarity index 100% rename from src/smt/DESIGN.md rename to src/msat-solver/DESIGN.md diff --git a/src/smt/Sidekick_smt.ml b/src/msat-solver/Sidekick_smt.ml similarity index 100% rename from src/smt/Sidekick_smt.ml rename to src/msat-solver/Sidekick_smt.ml diff --git a/src/smt/Solver.ml b/src/msat-solver/Solver.ml similarity index 100% rename from src/smt/Solver.ml rename to src/msat-solver/Solver.ml diff --git a/src/smt/Solver.mli b/src/msat-solver/Solver.mli similarity index 100% rename from src/smt/Solver.mli rename to src/msat-solver/Solver.mli diff --git a/src/smt/Solver_types.ml b/src/msat-solver/Solver_types.ml similarity index 97% rename from src/smt/Solver_types.ml rename to src/msat-solver/Solver_types.ml index db4f9e9c..3c95c320 100644 --- a/src/smt/Solver_types.ml +++ b/src/msat-solver/Solver_types.ml @@ -4,11 +4,6 @@ module Log = Msat.Log module Fmt = CCFormat -(* for objects that are expanded on demand only *) -type 'a lazily_expanded = - | Lazy_some of 'a - | Lazy_none - (* main term cell. *) type term = { mutable term_id: int; (* unique ID *) diff --git a/src/smt/Theory.ml b/src/msat-solver/Theory.ml similarity index 100% rename from src/smt/Theory.ml rename to src/msat-solver/Theory.ml diff --git a/src/smt/Theory_combine.ml b/src/msat-solver/Theory_combine.ml similarity index 100% rename from src/smt/Theory_combine.ml rename to src/msat-solver/Theory_combine.ml diff --git a/src/smt/Theory_combine.mli b/src/msat-solver/Theory_combine.mli similarity index 100% rename from src/smt/Theory_combine.mli rename to src/msat-solver/Theory_combine.mli diff --git a/src/msat-solver/dune b/src/msat-solver/dune new file mode 100644 index 00000000..d774e2eb --- /dev/null +++ b/src/msat-solver/dune @@ -0,0 +1,7 @@ + +(library + (name Sidekick_msat_solver) + (public_name sidekick.msat-solver) + (libraries containers containers.data iter + sidekick.core sidekick.util sidekick.cc msat zarith) + (flags :standard -open Sidekick_util)) diff --git a/src/msat-solver/th_key.ml.bak b/src/msat-solver/th_key.ml.bak new file mode 100644 index 00000000..cd8c7194 --- /dev/null +++ b/src/msat-solver/th_key.ml.bak @@ -0,0 +1,145 @@ + + +module type S = sig + type ('term,'lit,'a) t + (** An access key for theories which have per-class data ['a] *) + + val create : + ?pp:'a Fmt.printer -> + name:string -> + eq:('a -> 'a -> bool) -> + merge:('a -> 'a -> 'a) -> + unit -> + ('term,'lit,'a) t + (** Generative creation of keys for the given theory data. + + @param eq : Equality. This is used to optimize backtracking info. + + @param merge : + [merge d1 d2] is called when merging classes with data [d1] and [d2] + respectively. The theory should already have checked that the merge + is compatible, and this produces the combined data for terms in the + merged class. + @param name name of the theory which owns this data + @param pp a printer for the data + *) + + val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool + (** Checks if two keys are equal (generatively) *) + + val pp : _ t Fmt.printer + (** Prints the name of the key. *) +end + + +(** Custom keys for theory data. + This imitates the classic tricks for heterogeneous maps + https://blog.janestreet.com/a-universal-type/ + + It needs to form a commutative monoid where values are persistent so + they can be restored during backtracking. +*) +module Key = struct + module type KEY_IMPL = sig + type term + type lit + type t + val id : int + val name : string + val pp : t Fmt.printer + val equal : t -> t -> bool + val merge : t -> t -> t + exception Store of t + end + + type ('term,'lit,'a) t = + (module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a) + + let n_ = ref 0 + + let create (type term)(type lit)(type d) + ?(pp=fun out _ -> Fmt.string out "") + ~name ~eq ~merge () : (term,lit,d) t = + let module K = struct + type nonrec term = term + type nonrec lit = lit + type t = d + let id = !n_ + let name = name + let pp = pp + let merge = merge + let equal = eq + exception Store of d + end in + incr n_; + (module K) + + let[@inline] id + : type term lit a. (term,lit,a) t -> int + = fun (module K) -> K.id + + let[@inline] equal + : type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool + = fun (module K1) (module K2) -> K1.id = K2.id + + let pp + : type term lit a. (term,lit,a) t Fmt.printer + = fun out (module K) -> Fmt.string out K.name +end + + + +(* + (** Map for theory data associated with representatives *) + module K_map = struct + type 'a key = (term,lit,'a) Key.t + type pair = Pair : 'a key * exn -> pair + + type t = pair IM.t + + let empty = IM.empty + + let[@inline] mem k t = IM.mem (Key.id k) t + + let find (type a) (k : a key) (self:t) : a option = + let (module K) = k in + match IM.find K.id self with + | Pair (_, K.Store v) -> Some v + | _ -> None + | exception Not_found -> None + + let add (type a) (k : a key) (v:a) (self:t) : t = + let (module K) = k in + IM.add K.id (Pair (k, K.Store v)) self + + let remove (type a) (k: a key) self : t = + let (module K) = k in + IM.remove K.id self + + let equal (m1:t) (m2:t) : bool = + IM.equal + (fun p1 p2 -> + let Pair ((module K1), v1) = p1 in + let Pair ((module K2), v2) = p2 in + assert (K1.id = K2.id); + match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false) + m1 m2 + + let merge ~f_both (m1:t) (m2:t) : t = + IM.merge + (fun _ p1 p2 -> + match p1, p2 with + | None, None -> None + | Some v, None + | None, Some v -> Some v + | Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) -> + match pair1, pair2 with + | K1.Store v1, K1.Store v2 -> + f_both K1.id pair1 pair2; (* callback for checking compat *) + let v12 = K1.merge v1 v2 in (* merge content *) + Some (Pair (key1, K1.Store v12)) + | _ -> assert false + ) + m1 m2 + end + *) diff --git a/src/smt/dune b/src/smt/dune deleted file mode 100644 index 0d2c2890..00000000 --- a/src/smt/dune +++ /dev/null @@ -1,10 +0,0 @@ - -(library - (name Sidekick_smt) - (public_name sidekick.smt) - (libraries containers containers.data iter - sidekick.util sidekick.cc msat zarith) - (flags :standard -warn-error -a+8 - -color always -safe-string -short-paths -open Sidekick_util) - (ocamlopt_flags :standard -O3 -color always - -unbox-closures -unbox-closures-factor 20)) diff --git a/src/smtlib/dune b/src/smtlib/dune index 7f1b23d5..d8c89385 100644 --- a/src/smtlib/dune +++ b/src/smtlib/dune @@ -2,8 +2,9 @@ (library (name sidekick_smtlib) (public_name sidekick.smtlib) - (libraries containers zarith msat sidekick.smt sidekick.util - sidekick.smt.th-bool sidekick.smt.th-distinct msat.backend) + (libraries containers zarith msat sidekick.core sidekick.util + sidekick.msat-solver sidekick.base-term + sidekick.smt.th-bool sidekick.smt.th-distinct msat.backend) (flags :standard -open Sidekick_util)) (menhir (modules Parser)) diff --git a/src/th-bool/dune b/src/th-bool/dune index 248a759c..8d8b3005 100644 --- a/src/th-bool/dune +++ b/src/th-bool/dune @@ -1,6 +1,6 @@ (library (name Sidekick_th_bool) (public_name sidekick.smt.th-bool) - (libraries containers sidekick.smt) + (libraries containers sidekick.core sidekick.util) (flags :standard -open Sidekick_util)) diff --git a/src/th-cstor/dune b/src/th-cstor/dune index 667e77f2..b4e89620 100644 --- a/src/th-cstor/dune +++ b/src/th-cstor/dune @@ -2,6 +2,6 @@ (library (name Sidekick_th_cstor) (public_name sidekick.smt.th-cstor) - (libraries containers sidekick.smt) + (libraries containers sidekick.core sidekick.util) (flags :standard -open Sidekick_util)) diff --git a/src/th-distinct/Sidekick_th_distinct.ml b/src/th-distinct/Sidekick_th_distinct.ml index 94aa6343..ca3588d2 100644 --- a/src/th-distinct/Sidekick_th_distinct.ml +++ b/src/th-distinct/Sidekick_th_distinct.ml @@ -1,25 +1,10 @@ -module Term = Sidekick_smt.Term -module Theory = Sidekick_smt.Theory - module type ARG = sig - module T : sig - type t - type state - val pp : t Fmt.printer - val equal : t -> t -> bool - val hash : t -> int - val as_distinct : t -> t Iter.t option - val mk_eq : state -> t -> t -> t - end - module Lit : sig - type t - val term : t -> T.t - val neg : t -> t - val sign : t -> bool - val compare : t -> t -> int - val atom : T.state -> ?sign:bool -> T.t -> t - val pp : t Fmt.printer + include Sidekick_core.TERM_LIT + + module Arg_distinct : sig + val as_distinct : Term.t -> Term.t Iter.t option + val mk_eq : Term.state -> Term.t -> Term.t -> Term.t end end @@ -28,8 +13,12 @@ module type S = sig type term_state type lit - type data - val key : (term, lit, data) Sidekick_cc.Key.t + module Data : sig + type t + val empty : t + val merge : t -> t -> t + end + val th : Sidekick_smt.Theory.t end diff --git a/src/th-distinct/dune b/src/th-distinct/dune index 57fcb854..e8237c6c 100644 --- a/src/th-distinct/dune +++ b/src/th-distinct/dune @@ -2,6 +2,6 @@ (library (name Sidekick_th_distinct) (public_name sidekick.smt.th-distinct) - (libraries containers sidekick.smt) + (libraries containers sidekick.core sidekick.util) (flags :standard -open Sidekick_util)) diff --git a/src/util/dune b/src/util/dune index 34504bb6..a8802b38 100644 --- a/src/util/dune +++ b/src/util/dune @@ -1,8 +1,4 @@ (library (name sidekick_util) (public_name sidekick.util) - (libraries containers iter msat) - (flags :standard -w +a-4-42-44-48-50-58-32-60@8 -color always -safe-string) - (ocamlopt_flags :standard -O3 -bin-annot - -unbox-closures -unbox-closures-factor 20) - ) + (libraries containers iter msat))