refactor(cc): use explicit actions in CC, not effectful functions

This commit is contained in:
Simon Cruanes 2022-07-22 21:26:21 -04:00
parent e37f66c394
commit 6da6284711
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 413 additions and 374 deletions

View file

@ -104,7 +104,7 @@ module Make (A : ARG) : S with module A = A = struct
module T = A.S.T.Term module T = A.S.T.Term
module Lit = A.S.Solver_internal.Lit module Lit = A.S.Solver_internal.Lit
module SI = A.S.Solver_internal module SI = A.S.Solver_internal
module N = SI.CC.Class module N = SI.CC.E_node
open struct open struct
module Pr = SI.Proof_trace module Pr = SI.Proof_trace
@ -121,7 +121,9 @@ module Make (A : ARG) : S with module A = A = struct
| Lit l -> [ l ] | Lit l -> [ l ]
| CC_eq (n1, n2) -> | CC_eq (n1, n2) ->
let r = SI.CC.explain_eq (SI.cc si) n1 n2 in let r = SI.CC.explain_eq (SI.cc si) n1 n2 in
assert (not (SI.CC.Resolved_expl.is_semantic r)); (* FIXME
assert (not (SI.CC.Resolved_expl.is_semantic r));
*)
r.lits r.lits
end end
@ -214,8 +216,8 @@ module Make (A : ARG) : S with module A = A = struct
in in
raise (Confl expl) raise (Confl expl)
)); ));
Ok (List.rev_append l1 l2) Ok (List.rev_append l1 l2, [])
with Confl expl -> Error expl with Confl expl -> Error (SI.CC.Handler_action.Conflict expl)
end end
module ST_exprs = Sidekick_cc_plugin.Make (Monoid_exprs) module ST_exprs = Sidekick_cc_plugin.Make (Monoid_exprs)
@ -798,15 +800,17 @@ module Make (A : ARG) : S with module A = A = struct
SI.on_final_check si (final_check_ st); SI.on_final_check si (final_check_ st);
SI.on_partial_check si (partial_check_ st); SI.on_partial_check si (partial_check_ st);
SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st); SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (fun (_, _, t) -> on_subterm st t); SI.on_cc_is_subterm si (fun (_, _, t) ->
SI.on_cc_pre_merge si (fun (cc, acts, n1, n2, expl) -> on_subterm st t;
[]);
SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) ->
match as_const_ (N.term n1), as_const_ (N.term n2) with match as_const_ (N.term n1), as_const_ (N.term n2) with
| Some q1, Some q2 when A.Q.(q1 <> q2) -> | Some q1, Some q2 when A.Q.(q1 <> q2) ->
(* classes with incompatible constants *) (* classes with incompatible constants *)
Log.debugf 30 (fun k -> Log.debugf 30 (fun k ->
k "(@[lra.merge-incompatible-consts@ %a@ %a@])" N.pp n1 N.pp n2); k "(@[lra.merge-incompatible-consts@ %a@ %a@])" N.pp n1 N.pp n2);
SI.CC.raise_conflict_from_expl cc acts expl Error (SI.CC.Handler_action.Conflict expl)
| _ -> ()); | _ -> Ok []);
SI.on_th_combination si (do_th_combination st); SI.on_th_combination si (do_th_combination st);
st st

View file

@ -248,21 +248,6 @@ module Make (A : ARG) :
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits
end end
type propagation_reason = unit -> lit list * step_id
type action =
| Act_merge of E_node.t * E_node.t * Expl.t
| Act_propagate of { lit: lit; reason: propagation_reason }
type conflict =
| Conflict of lit list * step_id
(** [raise_conflict (c,pr)] declares that [c] is a tautology of
the theory of congruence.
@param pr the proof of [c] being a tautology *)
| Conflict_expl of Expl.t
type actions_or_confl = (action list, conflict) result
(** A signature is a shallow term shape where immediate subterms (** A signature is a shallow term shape where immediate subterms
are representative *) are representative *)
module Signature = struct module Signature = struct
@ -319,9 +304,26 @@ module Make (A : ARG) :
module Sig_tbl = CCHashtbl.Make (Signature) module Sig_tbl = CCHashtbl.Make (Signature)
module T_tbl = CCHashtbl.Make (Term) module T_tbl = CCHashtbl.Make (Term)
type propagation_reason = unit -> lit list * step_id
module Handler_action = struct
type t =
| Act_merge of E_node.t * E_node.t * Expl.t
| Act_propagate of lit * propagation_reason
type conflict = Conflict of Expl.t [@@unboxed]
type or_conflict = (t list, conflict) result
end
module Result_action = struct
type t = Act_propagate of { lit: lit; reason: propagation_reason }
type conflict = Conflict of lit list * step_id
type or_conflict = (t list, conflict) result
end
type combine_task = type combine_task =
| CT_merge of e_node * e_node * explanation | CT_merge of e_node * e_node * explanation
| CT_act of action | CT_act of Handler_action.t
type t = { type t = {
tst: term_store; tst: term_store;
@ -344,14 +346,18 @@ module Make (A : ARG) :
true_: e_node lazy_t; true_: e_node lazy_t;
false_: e_node lazy_t; false_: e_node lazy_t;
mutable in_loop: bool; (* currently being modified? *) mutable in_loop: bool; (* currently being modified? *)
res_acts: action Vec.t; (* to return *) res_acts: Result_action.t Vec.t; (* to return *)
on_pre_merge: on_pre_merge:
(t * E_node.t * E_node.t * Expl.t, actions_or_confl) Event.Emitter.t; ( t * E_node.t * E_node.t * Expl.t,
on_post_merge: (t * E_node.t * E_node.t, action list) Event.Emitter.t; Handler_action.or_conflict )
on_new_term: (t * E_node.t * term, action list) Event.Emitter.t; Event.Emitter.t;
on_post_merge:
(t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t;
on_new_term: (t * E_node.t * term, Handler_action.t list) Event.Emitter.t;
on_conflict: (ev_on_conflict, unit) Event.Emitter.t; on_conflict: (ev_on_conflict, unit) Event.Emitter.t;
on_propagate: (t * lit * propagation_reason, action list) Event.Emitter.t; on_propagate:
on_is_subterm: (t * E_node.t * term, action list) Event.Emitter.t; (t * lit * propagation_reason, Handler_action.t list) Event.Emitter.t;
on_is_subterm: (t * E_node.t * term, Handler_action.t list) Event.Emitter.t;
count_conflict: int Stat.counter; count_conflict: int Stat.counter;
count_props: int Stat.counter; count_props: int Stat.counter;
count_merge: int Stat.counter; count_merge: int Stat.counter;
@ -451,10 +457,10 @@ module Make (A : ARG) :
Log.debugf 50 (fun k -> k "(@[<hv1>cc.push-pending@ %a@])" E_node.pp t); Log.debugf 50 (fun k -> k "(@[<hv1>cc.push-pending@ %a@])" E_node.pp t);
Vec.push self.pending t Vec.push self.pending t
let push_action self (a : action) : unit = Vec.push self.combine (CT_act a) let push_action self (a : Handler_action.t) : unit =
Vec.push self.combine (CT_act a)
let push_action_l self (l : action list) : unit = let push_action_l self (l : _ list) : unit = List.iter (push_action self) l
List.iter (push_action self) l
let merge_classes self t u e : unit = let merge_classes self t u e : unit =
if t != u && not (same_class t u) then ( if t != u && not (same_class t u) then (
@ -476,7 +482,7 @@ module Make (A : ARG) :
u.n_expl <- FL_some { next = n; expl = e_n_u }; u.n_expl <- FL_some { next = n; expl = e_n_u };
n.n_expl <- FL_none n.n_expl <- FL_none
exception E_confl of conflict exception E_confl of Result_action.conflict
let raise_conflict_ (cc : t) ~th (e : lit list) (p : step_id) : _ = let raise_conflict_ (cc : t) ~th (e : lit list) (p : step_id) : _ =
Profile.instant "cc.conflict"; Profile.instant "cc.conflict";
@ -824,10 +830,10 @@ module Make (A : ARG) :
and task_combine_ self = function and task_combine_ self = function
| CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab
| CT_act (Act_merge (t, u, e)) -> task_merge_ self t u e | CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e
| CT_act (Act_propagate _ as a) -> | CT_act (Handler_action.Act_propagate (lit, reason)) ->
(* will return this propagation to the caller *) (* will return this propagation to the caller *)
Vec.push self.res_acts a Vec.push self.res_acts (Result_action.Act_propagate { lit; reason })
(* main CC algo: merge equivalence classes in [st.combine]. (* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *) @raise Exn_unsat if merge fails *)
@ -900,7 +906,8 @@ module Make (A : ARG) :
Event.emit_iter self.on_pre_merge (self, r_into, r_from, expl) Event.emit_iter self.on_pre_merge (self, r_into, r_from, expl)
~f:(function ~f:(function
| Ok l -> push_action_l self l | Ok l -> push_action_l self l
| Error c -> raise (E_confl c)); | Error (Handler_action.Conflict expl) ->
raise_conflict_from_expl self expl);
(* TODO: merge plugin data here, _after_ the pre-merge hooks are called, (* TODO: merge plugin data here, _after_ the pre-merge hooks are called,
so they have a chance of observing pre-merge plugin data *) so they have a chance of observing pre-merge plugin data *)
@ -999,12 +1006,24 @@ module Make (A : ARG) :
let _, pr = lits_and_proof_of_expl self st in let _, pr = lits_and_proof_of_expl self st in
guard, pr guard, pr
in in
push_action self (Act_propagate { lit; reason }); Vec.push self.res_acts (Result_action.Act_propagate { lit; reason });
Event.emit_iter self.on_propagate (self, lit, reason) Event.emit_iter self.on_propagate (self, lit, reason)
~f:(push_action_l self); ~f:(push_action_l self);
Stat.incr self.count_props Stat.incr self.count_props
| _ -> ()) | _ -> ())
(* raise a conflict from an explanation, typically from an event handler.
Raises E_confl with a result conflict. *)
and raise_conflict_from_expl self (expl : Expl.t) : 'a =
Log.debugf 5 (fun k ->
k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
let st = Expl_state.create () in
explain_decompose_expl self st expl;
let lits, pr = lits_and_proof_of_expl self st in
let c = List.rev_map Lit.neg lits in
let th = st.th_lemmas <> [] in
raise_conflict_ self ~th c pr
let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t)
let push_level (self : t) : unit = let push_level (self : t) : unit =
@ -1053,19 +1072,6 @@ module Make (A : ARG) :
assert (not self.in_loop); assert (not self.in_loop);
Iter.iter (assert_lit self) lits Iter.iter (assert_lit self) lits
(* FIXME: remove?
(* raise a conflict *)
let raise_conflict_from_expl self (acts : actions_or_confl) expl =
Log.debugf 5 (fun k ->
k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
let st = Expl_state.create () in
explain_decompose_expl self st expl;
let lits, pr = lits_and_proof_of_expl self st in
let c = List.rev_map Lit.neg lits in
let th = st.th_lemmas <> [] in
raise_conflict_ self ~th c pr
*)
let merge self n1 n2 expl = let merge self n1 n2 expl =
assert (not self.in_loop); assert (not self.in_loop);
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
@ -1083,6 +1089,11 @@ module Make (A : ARG) :
(* FIXME: also need to return the proof? *) (* FIXME: also need to return the proof? *)
Expl_state.to_resolved_expl st Expl_state.to_resolved_expl st
let explain_expl (self : t) expl : Resolved_expl.t =
let expl_st = Expl_state.create () in
explain_decompose_expl self expl_st expl;
Expl_state.to_resolved_expl expl_st
let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge
let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge
let[@inline] on_new_term self = Event.of_emitter self.on_new_term let[@inline] on_new_term self = Event.of_emitter self.on_new_term
@ -1142,7 +1153,7 @@ module Make (A : ARG) :
in in
loop [] loop []
let check self : actions_or_confl = let check self : Result_action.or_conflict =
Log.debug 5 "(cc.check)"; Log.debug 5 "(cc.check)";
self.in_loop <- true; self.in_loop <- true;
let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in

View file

@ -63,8 +63,9 @@ module Make (M : MONOID_PLUGIN_ARG) :
else else
None None
let on_new_term cc n (t : term) : unit = let on_new_term cc n (t : term) : CC.Handler_action.t list =
(*Log.debugf 50 (fun k->k "(@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);*) (*Log.debugf 50 (fun k->k "(@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);*)
let acts = ref [] in
let maybe_m, l = M.of_term cc n t in let maybe_m, l = M.of_term cc n t in
(match maybe_m with (match maybe_m with
| Some v -> | Some v ->
@ -86,12 +87,14 @@ module Make (M : MONOID_PLUGIN_ARG) :
with Not_found -> with Not_found ->
Error.errorf "node %a has bitfield but no value" E_node.pp n_u Error.errorf "node %a has bitfield but no value" E_node.pp n_u
in in
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
| Error expl -> | Error (CC.Handler_action.Conflict expl) ->
Error.errorf Error.errorf
"when merging@ @[for node %a@],@ values %a and %a:@ conflict %a" "when merging@ @[for node %a@],@ values %a and %a:@ conflict %a"
E_node.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl E_node.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
| Ok m_u_merged -> | Ok (m_u_merged, merge_acts) ->
acts := List.rev_append merge_acts !acts;
Log.debugf 20 (fun k -> Log.debugf 20 (fun k ->
k k
"(@[monoid[%s].on-new-term.sub.merged@ :n %a@ :sub-t %a@ \ "(@[monoid[%s].on-new-term.sub.merged@ :n %a@ :sub-t %a@ \
@ -104,14 +107,15 @@ module Make (M : MONOID_PLUGIN_ARG) :
Cls_tbl.add values n_u m_u Cls_tbl.add values n_u m_u
)) ))
l; l;
() !acts
let iter_all : _ Iter.t = Cls_tbl.to_iter values let iter_all : _ Iter.t = Cls_tbl.to_iter values
let on_pre_merge cc n1 n2 e_n1_n2 : CC.actions = let on_pre_merge cc n1 n2 e_n1_n2 : CC.Handler_action.or_conflict =
let exception E of M.CC.conflict in let exception E of M.CC.Handler_action.conflict in
let acts = ref [] in
try try
match get n1, get n2 with (match get n1, get n2 with
| Some v1, Some v2 -> | Some v1, Some v2 ->
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
k k
@ -119,17 +123,19 @@ module Make (M : MONOID_PLUGIN_ARG) :
%a@ :val2 %a@])@])" %a@ :val2 %a@])@])"
M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2); M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2);
(match M.merge cc n1 v1 n2 v2 e_n1_n2 with (match M.merge cc n1 v1 n2 v2 e_n1_n2 with
| Ok v' -> | Ok (v', merge_acts) ->
acts := merge_acts;
Cls_tbl.remove values n2; Cls_tbl.remove values n2;
(* only keep repr *) (* only keep repr *)
Cls_tbl.add values n1 v' Cls_tbl.add values n1 v'
| Error expl -> raise (E (CC.Conflict_expl expl))) | Error c -> raise (E c))
| None, Some cr -> | None, Some cr ->
CC.set_bitfield cc field_has_value true n1; CC.set_bitfield cc field_has_value true n1;
Cls_tbl.add values n1 cr; Cls_tbl.add values n1 cr;
Cls_tbl.remove values n2 (* only keep reprs *) Cls_tbl.remove values n2 (* only keep reprs *)
| Some _, None -> () (* already there on the left *) | Some _, None -> () (* already there on the left *)
| None, None -> () | None, None -> ());
Ok !acts
with E c -> Error c with E c -> Error c
let pp out () : unit = let pp out () : unit =
@ -141,8 +147,8 @@ module Make (M : MONOID_PLUGIN_ARG) :
(* setup *) (* setup *)
let () = let () =
Event.on (CC.on_new_term cc) ~f:(fun (_, r, t) -> on_new_term cc r t); Event.on (CC.on_new_term cc) ~f:(fun (_, r, t) -> on_new_term cc r t);
Event.on (CC.on_pre_merge cc) ~f:(fun (_, acts, ra, rb, expl) -> Event.on (CC.on_pre_merge cc) ~f:(fun (_, ra, rb, expl) ->
on_pre_merge cc acts ra rb expl); on_pre_merge cc ra rb expl);
() ()
end end

View file

@ -5,11 +5,11 @@ open Sidekick_sigs_cc
module type EXTENDED_PLUGIN_BUILDER = sig module type EXTENDED_PLUGIN_BUILDER = sig
include MONOID_PLUGIN_BUILDER include MONOID_PLUGIN_BUILDER
val mem : t -> M.CC.Class.t -> bool val mem : t -> M.CC.E_node.t -> bool
(** Does the CC Class.t have a monoid value? *) (** Does the CC.E_node.t have a monoid value? *)
val get : t -> M.CC.Class.t -> M.t option val get : t -> M.CC.E_node.t -> M.t option
(** Get monoid value for this CC Class.t, if any *) (** Get monoid value for this CC.E_node.t, if any *)
val iter_all : t -> (M.CC.repr * M.t) Iter.t val iter_all : t -> (M.CC.repr * M.t) Iter.t

View file

@ -6,48 +6,6 @@ module type TERM = Sidekick_sigs_term.S
module type LIT = Sidekick_sigs_lit.S module type LIT = Sidekick_sigs_lit.S
module type PROOF_TRACE = Sidekick_sigs_proof_trace.S module type PROOF_TRACE = Sidekick_sigs_proof_trace.S
(** Actions provided to the congruence closure.
The congruence closure must be able to propagate literals when
it detects that they are true or false; it must also
be able to create conflicts when the set of (dis)equalities
is inconsistent *)
module type DYN_ACTIONS = sig
type term
type lit
type proof_trace
type step_id
val proof_trace : unit -> proof_trace
val raise_conflict : lit list -> step_id -> 'a
(** [raise_conflict c pr] declares that [c] is a tautology of
the theory of congruence. This does not return (it should raise an
exception).
@param pr the proof of [c] being a tautology *)
val raise_semantic_conflict : lit list -> (bool * term * term) list -> 'a
(** [raise_semantic_conflict lits same_val] declares that
the conjunction of all [lits] (literals true in current trail) and tuples
[{=,}, t_i, u_i] implies false.
The [{=,}, t_i, u_i] are pairs of terms with the same value (if [=] / true)
or distinct value (if [] / false)) in the current model.
This does not return. It should raise an exception.
*)
val propagate : lit -> reason:(unit -> lit list * step_id) -> unit
(** [propagate lit ~reason pr] declares that [reason() => lit]
is a tautology.
- [reason()] should return a list of literals that are currently true.
- [lit] should be a literal of interest (see {!CC_S.set_as_lit}).
This function might never be called, a congruence closure has the right
to not propagate and only trigger conflicts. *)
end
(** Arguments to a congruence closure's implementation *) (** Arguments to a congruence closure's implementation *)
module type ARG = sig module type ARG = sig
module T : TERM module T : TERM
@ -83,23 +41,17 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
type proof_trace = Proof_trace.t type proof_trace = Proof_trace.t
type step_id = Proof_trace.A.step_id type step_id = Proof_trace.A.step_id
type actions = (** E-node.
(module DYN_ACTIONS
with type term = T.Term.t
and type lit = Lit.t
and type proof_trace = proof_trace
and type step_id = step_id)
(** Actions available to the congruence closure *)
(** Equivalence classes.
An e-node is a node in the congruence closure that is contained
in some equivalence classe).
An equivalence class is a set of terms that are currently equal An equivalence class is a set of terms that are currently equal
in the partial model built by the solver. in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is The class is represented by a collection of nodes, one of which is
distinguished and is called the "representative". distinguished and is called the "representative".
All information pertaining to the whole equivalence class is stored All information pertaining to the whole equivalence class is stored
in this representative's Class.t. in its representative's {!E_node.t}.
When two classes become equal (are "merged"), one of the two When two classes become equal (are "merged"), one of the two
representatives is picked as the representative of the new class. representatives is picked as the representative of the new class.
@ -109,10 +61,9 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
representative. This information can be used when two classes are representative. This information can be used when two classes are
merged, to detect conflicts and solve equations à la Shostak. merged, to detect conflicts and solve equations à la Shostak.
*) *)
module Class : sig module E_node : sig
type t type t
(** An equivalent class, containing terms that are proved (** An E-node.
to be equal.
A value of type [t] points to a particular term, but see A value of type [t] points to a particular term, but see
{!find} to get the representative of the class. *) {!find} to get the representative of the class. *)
@ -125,14 +76,14 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
val equal : t -> t -> bool val equal : t -> t -> bool
(** Are two classes {b physically} equal? To check for (** Are two classes {b physically} equal? To check for
logical equality, use [CC.Class.equal (CC.find cc n1) (CC.find cc n2)] logical equality, use [CC.E_node.equal (CC.find cc n1) (CC.find cc n2)]
which checks for equality of representatives. *) which checks for equality of representatives. *)
val hash : t -> int val hash : t -> int
(** An opaque hash of this Class.t. *) (** An opaque hash of this E_node.t. *)
val is_root : t -> bool val is_root : t -> bool
(** Is the Class.t a root (ie the representative of its class)? (** Is the E_node.t a root (ie the representative of its class)?
See {!find} to get the root. *) See {!find} to get the root. *)
val iter_class : t -> t Iter.t val iter_class : t -> t Iter.t
@ -167,7 +118,7 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
include Sidekick_sigs.PRINT with type t := t include Sidekick_sigs.PRINT with type t := t
val mk_merge : Class.t -> Class.t -> t val mk_merge : E_node.t -> E_node.t -> t
(** Explanation: the nodes were explicitly merged *) (** Explanation: the nodes were explicitly merged *)
val mk_merge_t : term -> term -> t val mk_merge_t : term -> term -> t
@ -178,9 +129,6 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
or we merged [t] and [true] because of literal [t], or we merged [t] and [true] because of literal [t],
or [t] and [false] because of literal [¬t] *) or [t] and [false] because of literal [¬t] *)
val mk_same_value : Class.t -> Class.t -> t
(** The two classes have the same value during model construction *)
val mk_list : t list -> t val mk_list : t list -> t
(** Conjunction of explanations *) (** Conjunction of explanations *)
@ -217,73 +165,22 @@ module type ARGS_CLASSES_EXPL_EVENT = sig
However, we can also have merged classes because they have the same value However, we can also have merged classes because they have the same value
in the current model. *) in the current model. *)
module Resolved_expl : sig module Resolved_expl : sig
type t = { type t = { lits: lit list; pr: proof_trace -> step_id }
lits: lit list;
same_value: (Class.t * Class.t) list;
pr: proof_trace -> step_id;
}
include Sidekick_sigs.PRINT with type t := t include Sidekick_sigs.PRINT with type t := t
val is_semantic : t -> bool
(** [is_semantic expl] is [true] if there's at least one
pair in [expl.same_value]. *)
end end
type node = Class.t (** Per-node data *)
type e_node = E_node.t
(** A node of the congruence closure *) (** A node of the congruence closure *)
type repr = Class.t type repr = E_node.t
(** Node that is currently a representative. *) (** Node that is currently a representative. *)
type explanation = Expl.t type explanation = Expl.t
end end
(* TODO: can we have that meaningfully? the type of Class.t would depend on
the implementation, so it can't be pre-defined, but nor can it be accessed from
shortcuts from the inside. That means one cannot point to classes from outside
the opened module.
Potential solution:
- make Expl polymorphic and lift it to toplevel, like View
- do not expose Class, only Term-based API
(** The type for a congruence closure, as a first-class module *)
module type DYN = sig
include ARGS_CLASSES_EXPL_EVENT
include Sidekick_sigs.DYN_BACKTRACKABLE
val term_store : unit -> term_store
val proof : unit -> proof_trace
val find : node -> repr
val add_term : term -> node
val mem_term : term -> bool
val allocate_bitfield : descr:string -> Class.bitfield
val get_bitfield : Class.bitfield -> Class.t -> bool
val set_bitfield : Class.bitfield -> bool -> Class.t -> unit
val on_event : unit -> event Event.t
val set_as_lit : Class.t -> lit -> unit
val find_t : term -> repr
val add_iter : term Iter.t -> unit
val all_classes : repr Iter.t
val assert_lit : lit -> unit
val assert_lits : lit Iter.t -> unit
val explain_eq : Class.t -> Class.t -> Resolved_expl.t
val raise_conflict_from_expl : actions -> Expl.t -> 'a
val n_true : unit -> Class.t
val n_false : unit -> Class.t
val n_bool : bool -> Class.t
val merge : Class.t -> Class.t -> Expl.t -> unit
val merge_t : term -> term -> Expl.t -> unit
val set_model_value : term -> value -> unit
val with_model_mode : (unit -> 'a) -> 'a
val get_model_for_each_class : (repr * Class.t Iter.t * value) Iter.t
val check : actions -> unit
val push_level : unit -> unit
val pop_levels : int -> unit
val get_model : Class.t Iter.t Iter.t
end
*)
(** Main congruence closure signature. (** Main congruence closure signature.
The congruence closure handles the theory QF_UF (uninterpreted The congruence closure handles the theory QF_UF (uninterpreted
@ -312,18 +209,18 @@ module type S = sig
val term_store : t -> term_store val term_store : t -> term_store
val proof : t -> proof_trace val proof : t -> proof_trace
val find : t -> node -> repr val find : t -> e_node -> repr
(** Current representative *) (** Current representative *)
val add_term : t -> term -> node val add_term : t -> term -> e_node
(** Add the term to the congruence closure, if not present already. (** Add the term to the congruence closure, if not present already.
Will be backtracked. *) Will be backtracked. *)
val mem_term : t -> term -> bool val mem_term : t -> term -> bool
(** Returns [true] if the term is explicitly present in the congruence closure *) (** Returns [true] if the term is explicitly present in the congruence closure *)
val allocate_bitfield : t -> descr:string -> Class.bitfield val allocate_bitfield : t -> descr:string -> E_node.bitfield
(** Allocate a new node field (see {!Class.bitfield}). (** Allocate a new e_node field (see {!E_node.bitfield}).
This field descriptor is henceforth reserved for all nodes This field descriptor is henceforth reserved for all nodes
in this congruence closure, and can be set using {!set_bitfield} in this congruence closure, and can be set using {!set_bitfield}
@ -336,12 +233,62 @@ module type S = sig
for a given congruence closure (e.g. at most {!Sys.int_size} fields). for a given congruence closure (e.g. at most {!Sys.int_size} fields).
*) *)
val get_bitfield : t -> Class.bitfield -> Class.t -> bool val get_bitfield : t -> E_node.bitfield -> E_node.t -> bool
(** Access the bit field of the given node *) (** Access the bit field of the given e_node *)
val set_bitfield : t -> Class.bitfield -> bool -> Class.t -> unit val set_bitfield : t -> E_node.bitfield -> bool -> E_node.t -> unit
(** Set the bitfield for the node. This will be backtracked. (** Set the bitfield for the e_node. This will be backtracked.
See {!Class.bitfield}. *) See {!E_node.bitfield}. *)
type propagation_reason = unit -> lit list * step_id
(** Handler Actions
Actions that can be scheduled by event handlers. *)
module Handler_action : sig
type t =
| Act_merge of E_node.t * E_node.t * Expl.t
| Act_propagate of lit * propagation_reason
(* TODO:
- an action to modify data associated with a class
*)
type conflict = Conflict of Expl.t [@@unboxed]
type or_conflict = (t list, conflict) result
(** Actions or conflict scheduled by an event handler.
- [Ok acts] is a list of merges and propagations
- [Error confl] is a conflict to resolve.
*)
end
(** Result Actions.
Actions returned by the congruence closure after calling {!check}. *)
module Result_action : sig
type t =
| Act_propagate of { lit: lit; reason: propagation_reason }
(** [propagate (lit, reason)] declares that [reason() => lit]
is a tautology.
- [reason()] should return a list of literals that are currently true,
as well as a proof.
- [lit] should be a literal of interest (see {!S.set_as_lit}).
This function might never be called, a congruence closure has the right
to not propagate and only trigger conflicts. *)
type conflict =
| Conflict of lit list * step_id
(** [raise_conflict (c,pr)] declares that [c] is a tautology of
the theory of congruence.
@param pr the proof of [c] being a tautology *)
type or_conflict = (t list, conflict) result
end
(** {3 Events} (** {3 Events}
@ -349,18 +296,20 @@ module type S = sig
other plugins can subscribe. *) other plugins can subscribe. *)
(** Events emitted by the congruence closure when something changes. *) (** Events emitted by the congruence closure when something changes. *)
val on_pre_merge : t -> (t * actions * Class.t * Class.t * Expl.t) Event.t val on_pre_merge :
t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t
(** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1] (** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1]
and [n2] are merged with explanation [expl]. *) and [n2] are merged with explanation [expl]. *)
val on_post_merge : t -> (t * actions * Class.t * Class.t) Event.t val on_post_merge :
t -> (t * E_node.t * E_node.t, Handler_action.t list) Event.t
(** [ev_on_post_merge acts n1 n2] is emitted right after [n1] (** [ev_on_post_merge acts n1 n2] is emitted right after [n1]
and [n2] were merged. [find cc n1] and [find cc n2] will return and [n2] were merged. [find cc n1] and [find cc n2] will return
the same Class.t. *) the same E_node.t. *)
val on_new_term : t -> (t * Class.t * term) Event.t val on_new_term : t -> (t * E_node.t * term, Handler_action.t list) Event.t
(** [ev_on_new_term n t] is emitted whenever a new term [t] (** [ev_on_new_term n t] is emitted whenever a new term [t]
is added to the congruence closure. Its Class.t is [n]. *) is added to the congruence closure. Its E_node.t is [n]. *)
type ev_on_conflict = { cc: t; th: bool; c: lit list } type ev_on_conflict = { cc: t; th: bool; c: lit list }
(** Event emitted when a conflict occurs in the CC. (** Event emitted when a conflict occurs in the CC.
@ -370,27 +319,37 @@ module type S = sig
participating in the conflict are purely syntactic theories participating in the conflict are purely syntactic theories
like injectivity of constructors. *) like injectivity of constructors. *)
val on_conflict : t -> ev_on_conflict Event.t val on_conflict : t -> (ev_on_conflict, unit) Event.t
(** [ev_on_conflict {th; c}] is emitted when the congruence (** [ev_on_conflict {th; c}] is emitted when the congruence
closure triggers a conflict by asserting the tautology [c]. *) closure triggers a conflict by asserting the tautology [c]. *)
val on_propagate : t -> (t * lit * (unit -> lit list * step_id)) Event.t val on_propagate :
t -> (t * lit * (unit -> lit list * step_id), Handler_action.t list) Event.t
(** [ev_on_propagate lit reason] is emitted whenever [reason() => lit] (** [ev_on_propagate lit reason] is emitted whenever [reason() => lit]
is a propagated lemma. See {!CC_ACTIONS.propagate}. *) is a propagated lemma. See {!CC_ACTIONS.propagate}. *)
val on_is_subterm : t -> (t * Class.t * term) Event.t val on_is_subterm : t -> (t * E_node.t * term, Handler_action.t list) Event.t
(** [ev_on_is_subterm n t] is emitted when [n] is a subterm of (** [ev_on_is_subterm n t] is emitted when [n] is a subterm of
another Class.t for the first time. [t] is the term corresponding to another E_node.t for the first time. [t] is the term corresponding to
the Class.t [n]. This can be useful for theory combination. *) the E_node.t [n]. This can be useful for theory combination. *)
(** {3 Misc} *) (** {3 Misc} *)
val set_as_lit : t -> Class.t -> lit -> unit val n_true : t -> E_node.t
(** map the given node to a literal. *) (** Node for [true] *)
val n_false : t -> E_node.t
(** Node for [false] *)
val n_bool : t -> bool -> E_node.t
(** Node for either true or false *)
val set_as_lit : t -> E_node.t -> lit -> unit
(** map the given e_node to a literal. *)
val find_t : t -> term -> repr val find_t : t -> term -> repr
(** Current representative of the term. (** Current representative of the term.
@raise Class.t_found if the term is not already {!add}-ed. *) @raise E_node.t_found if the term is not already {!add}-ed. *)
val add_iter : t -> term Iter.t -> unit val add_iter : t -> term Iter.t -> unit
(** Add a sequence of terms to the congruence closure *) (** Add a sequence of terms to the congruence closure *)
@ -398,6 +357,36 @@ module type S = sig
val all_classes : t -> repr Iter.t val all_classes : t -> repr Iter.t
(** All current classes. This is costly, only use if there is no other solution *) (** All current classes. This is costly, only use if there is no other solution *)
val explain_eq : t -> E_node.t -> E_node.t -> Resolved_expl.t
(** Explain why the two nodes are equal.
Fails if they are not, in an unspecified way. *)
val explain_expl : t -> Expl.t -> Resolved_expl.t
(** Transform explanation into an actionable conflict clause *)
(* FIXME: remove
val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a
(** Raise a conflict with the given explanation.
It must be a theory tautology that [expl ==> absurd].
To be used in theories.
This fails in an unspecified way if the explanation, once resolved,
satisfies {!Resolved_expl.is_semantic}. *)
*)
val merge : t -> E_node.t -> E_node.t -> Expl.t -> unit
(** Merge these two nodes given this explanation.
It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *)
val merge_t : t -> term -> term -> Expl.t -> unit
(** Shortcut for adding + merging *)
(** {3 Main API *)
val assert_eq : t -> term -> term -> Expl.t -> unit
(** Assert that two terms are equal, using the given explanation. *)
val assert_lit : t -> lit -> unit val assert_lit : t -> lit -> unit
(** Given a literal, assume it in the congruence closure and propagate (** Given a literal, assume it in the congruence closure and propagate
its consequences. Will be backtracked. its consequences. Will be backtracked.
@ -407,45 +396,7 @@ module type S = sig
val assert_lits : t -> lit Iter.t -> unit val assert_lits : t -> lit Iter.t -> unit
(** Addition of many literals *) (** Addition of many literals *)
val explain_eq : t -> Class.t -> Class.t -> Resolved_expl.t val check : t -> Result_action.or_conflict
(** Explain why the two nodes are equal.
Fails if they are not, in an unspecified way. *)
val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a
(** Raise a conflict with the given explanation.
It must be a theory tautology that [expl ==> absurd].
To be used in theories.
This fails in an unspecified way if the explanation, once resolved,
satisfies {!Resolved_expl.is_semantic}. *)
val n_true : t -> Class.t
(** Node for [true] *)
val n_false : t -> Class.t
(** Node for [false] *)
val n_bool : t -> bool -> Class.t
(** Node for either true or false *)
val merge : t -> Class.t -> Class.t -> Expl.t -> unit
(** Merge these two nodes given this explanation.
It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *)
val merge_t : t -> term -> term -> Expl.t -> unit
(** Shortcut for adding + merging *)
val set_model_value : t -> term -> value -> unit
(** Set the value of a term in the model. *)
val with_model_mode : t -> (unit -> 'a) -> 'a
(** Enter model combination mode. *)
val get_model_for_each_class : t -> (repr * Class.t Iter.t * value) Iter.t
(** In model combination mode, obtain classes with their values. *)
val check : t -> actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the {!actions} to propagate literals, declare conflicts, etc. *) Will use the {!actions} to propagate literals, declare conflicts, etc. *)
@ -455,7 +406,7 @@ module type S = sig
val pop_levels : t -> int -> unit val pop_levels : t -> int -> unit
(** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) (** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *)
val get_model : t -> Class.t Iter.t Iter.t val get_model : t -> E_node.t Iter.t Iter.t
(** get all the equivalence classes so they can be merged in the model *) (** get all the equivalence classes so they can be merged in the model *)
end end
@ -485,8 +436,10 @@ module type MONOID_PLUGIN_ARG = sig
val name : string val name : string
(** name of the monoid structure (short) *) (** name of the monoid structure (short) *)
(* FIXME: for subs, return list of e_nodes, and assume of_term already
returned data for them. *)
val of_term : val of_term :
CC.t -> CC.Class.t -> CC.term -> t option * (CC.Class.t * t) list CC.t -> CC.E_node.t -> CC.term -> t option * (CC.E_node.t * t) list
(** [of_term n t], where [t] is the term annotating node [n], (** [of_term n t], where [t] is the term annotating node [n],
must return [maybe_m, l], where: must return [maybe_m, l], where:
@ -500,12 +453,12 @@ module type MONOID_PLUGIN_ARG = sig
val merge : val merge :
CC.t -> CC.t ->
CC.Class.t -> CC.E_node.t ->
t -> t ->
CC.Class.t -> CC.E_node.t ->
t -> t ->
CC.Expl.t -> CC.Expl.t ->
(t, CC.Expl.t) result (t * CC.Handler_action.t list, CC.Handler_action.conflict) result
(** Monoidal combination of two values. (** Monoidal combination of two values.
[merge cc n1 mon1 n2 mon2 expl] returns the result of merging [merge cc n1 mon1 n2 mon2 expl] returns the result of merging
@ -531,11 +484,11 @@ module type DYN_MONOID_PLUGIN = sig
val pp : unit Fmt.printer val pp : unit Fmt.printer
val mem : M.CC.Class.t -> bool val mem : M.CC.E_node.t -> bool
(** Does the CC Class.t have a monoid value? *) (** Does the CC E_node.t have a monoid value? *)
val get : M.CC.Class.t -> M.t option val get : M.CC.E_node.t -> M.t option
(** Get monoid value for this CC Class.t, if any *) (** Get monoid value for this CC E_node.t, if any *)
val iter_all : (M.CC.repr * M.t) Iter.t val iter_all : (M.CC.repr * M.t) Iter.t
end end

View file

@ -230,19 +230,22 @@ module type SOLVER_INTERNAL = sig
(** Add the given (signed) bool term to the SAT solver, so it gets assigned (** Add the given (signed) bool term to the SAT solver, so it gets assigned
a boolean value *) a boolean value *)
val cc_raise_conflict_expl : t -> theory_actions -> CC.Expl.t -> 'a val cc_find : t -> CC.E_node.t -> CC.E_node.t
(** Raise a conflict with the given congruence closure explanation.
it must be a theory tautology that [expl ==> absurd].
To be used in theories. *)
val cc_find : t -> CC.Class.t -> CC.Class.t
(** Find representative of the node *) (** Find representative of the node *)
val cc_are_equal : t -> term -> term -> bool val cc_are_equal : t -> term -> term -> bool
(** Are these two terms equal in the congruence closure? *) (** Are these two terms equal in the congruence closure? *)
val cc_resolve_expl : t -> CC.Expl.t -> lit list * step_id
(*
val cc_raise_conflict_expl : t -> theory_actions -> CC.Expl.t -> 'a
(** Raise a conflict with the given congruence closure explanation.
it must be a theory tautology that [expl ==> absurd].
To be used in theories. *)
val cc_merge : val cc_merge :
t -> theory_actions -> CC.Class.t -> CC.Class.t -> CC.Expl.t -> unit t -> theory_actions -> CC.E_node.t -> CC.E_node.t -> CC.Expl.t -> unit
(** Merge these two nodes in the congruence closure, given this explanation. (** Merge these two nodes in the congruence closure, given this explanation.
It must be a theory tautology that [expl ==> n1 = n2]. It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *) To be used in theories. *)
@ -250,8 +253,9 @@ module type SOLVER_INTERNAL = sig
val cc_merge_t : t -> theory_actions -> term -> term -> CC.Expl.t -> unit val cc_merge_t : t -> theory_actions -> term -> term -> CC.Expl.t -> unit
(** Merge these two terms in the congruence closure, given this explanation. (** Merge these two terms in the congruence closure, given this explanation.
See {!cc_merge} *) See {!cc_merge} *)
*)
val cc_add_term : t -> term -> CC.Class.t val cc_add_term : t -> term -> CC.E_node.t
(** Add/retrieve congruence closure node for this term. (** Add/retrieve congruence closure node for this term.
To be used in theories *) To be used in theories *)
@ -261,19 +265,22 @@ module type SOLVER_INTERNAL = sig
val on_cc_pre_merge : val on_cc_pre_merge :
t -> t ->
(CC.t * CC.actions * CC.Class.t * CC.Class.t * CC.Expl.t -> unit) -> (CC.t * CC.E_node.t * CC.E_node.t * CC.Expl.t ->
CC.Handler_action.or_conflict) ->
unit unit
(** Callback for when two classes containing data for this key are merged (called before) *) (** Callback for when two classes containing data for this key are merged (called before) *)
val on_cc_post_merge : val on_cc_post_merge :
t -> (CC.t * CC.actions * CC.Class.t * CC.Class.t -> unit) -> unit t -> (CC.t * CC.E_node.t * CC.E_node.t -> CC.Handler_action.t list) -> unit
(** Callback for when two classes containing data for this key are merged (called after)*) (** Callback for when two classes containing data for this key are merged (called after)*)
val on_cc_new_term : t -> (CC.t * CC.Class.t * term -> unit) -> unit val on_cc_new_term :
t -> (CC.t * CC.E_node.t * term -> CC.Handler_action.t list) -> unit
(** Callback to add data on terms when they are added to the congruence (** Callback to add data on terms when they are added to the congruence
closure *) closure *)
val on_cc_is_subterm : t -> (CC.t * CC.Class.t * term -> unit) -> unit val on_cc_is_subterm :
t -> (CC.t * CC.E_node.t * term -> CC.Handler_action.t list) -> unit
(** Callback for when a term is a subterm of another term in the (** Callback for when a term is a subterm of another term in the
congruence closure *) congruence closure *)
@ -281,7 +288,9 @@ module type SOLVER_INTERNAL = sig
(** Callback called on every CC conflict *) (** Callback called on every CC conflict *)
val on_cc_propagate : val on_cc_propagate :
t -> (CC.t * lit * (unit -> lit list * step_id) -> unit) -> unit t ->
(CC.t * lit * (unit -> lit list * step_id) -> CC.Handler_action.t list) ->
unit
(** Callback called on every CC propagation *) (** Callback called on every CC propagation *)
val on_partial_check : val on_partial_check :
@ -319,7 +328,7 @@ module type SOLVER_INTERNAL = sig
(** {3 Model production} *) (** {3 Model production} *)
type model_ask_hook = type model_ask_hook =
recurse:(t -> CC.Class.t -> term) -> t -> CC.Class.t -> term option recurse:(t -> CC.E_node.t -> term) -> t -> CC.E_node.t -> term option
(** A model-production hook to query values from a theory. (** A model-production hook to query values from a theory.
It takes the solver, a class, and returns It takes the solver, a class, and returns

View file

@ -134,7 +134,7 @@ module Make (A : ARG) :
end end
module CC = Sidekick_cc.Make (CC_arg) module CC = Sidekick_cc.Make (CC_arg)
module N = CC.Class module N = CC.E_node
module Model = struct module Model = struct
type t = Empty | Map of term Term.Tbl.t type t = Empty | Map of term Term.Tbl.t
@ -167,28 +167,30 @@ module Make (A : ARG) :
| DA_add_clause of { c: lit list; pr: step_id; keep: bool } | DA_add_clause of { c: lit list; pr: step_id; keep: bool }
| DA_add_lit of { default_pol: bool option; lit: lit } | DA_add_lit of { default_pol: bool option; lit: lit }
let mk_cc_acts_ (pr : P.t) (a : sat_acts) : CC.actions = (* TODO
let (module A) = a in let mk_cc_acts_ (pr : P.t) (a : sat_acts) : CC.actions =
let (module A) = a in
(module struct (module struct
module T = T module T = T
module Lit = Lit module Lit = Lit
type nonrec lit = lit type nonrec lit = lit
type nonrec term = term type nonrec term = term
type nonrec proof_trace = Proof_trace.t type nonrec proof_trace = Proof_trace.t
type nonrec step_id = step_id type nonrec step_id = step_id
let proof_trace () = pr let proof_trace () = pr
let[@inline] raise_conflict lits (pr : step_id) = A.raise_conflict lits pr let[@inline] raise_conflict lits (pr : step_id) = A.raise_conflict lits pr
let[@inline] raise_semantic_conflict lits semantic = let[@inline] raise_semantic_conflict lits semantic =
raise (Semantic_conflict { lits; semantic }) raise (Semantic_conflict { lits; semantic })
let[@inline] propagate lit ~reason = let[@inline] propagate lit ~reason =
let reason = Sidekick_sat.Consequence reason in let reason = Sidekick_sat.Consequence reason in
A.propagate lit reason A.propagate lit reason
end) end)
*)
(** Internal solver, given to theories and to Msat *) (** Internal solver, given to theories and to Msat *)
module Solver_internal = struct module Solver_internal = struct
@ -198,7 +200,7 @@ module Make (A : ARG) :
module P_core_rules = A.Rule_core module P_core_rules = A.Rule_core
module Lit = Lit module Lit = Lit
module CC = CC module CC = CC
module N = CC.Class module N = CC.E_node
type nonrec proof_trace = Proof_trace.t type nonrec proof_trace = Proof_trace.t
type nonrec step_id = step_id type nonrec step_id = step_id
@ -584,6 +586,11 @@ module Make (A : ARG) :
let n2 = cc_add_term self t2 in let n2 = cc_add_term self t2 in
N.equal (cc_find self n1) (cc_find self n2) N.equal (cc_find self n1) (cc_find self n2)
let cc_resolve_expl self e : lit list * _ =
let r = CC.explain_expl (cc self) e in
r.lits, r.pr self.proof
(*
let cc_merge self _acts n1 n2 e = CC.merge (cc self) n1 n2 e let cc_merge self _acts n1 n2 e = CC.merge (cc self) n1 n2 e
let cc_merge_t self acts t1 t2 e = let cc_merge_t self acts t1 t2 e =
@ -593,6 +600,7 @@ module Make (A : ARG) :
let cc_raise_conflict_expl self acts e = let cc_raise_conflict_expl self acts e =
let cc_acts = mk_cc_acts_ self.proof acts in let cc_acts = mk_cc_acts_ self.proof acts in
CC.raise_conflict_from_expl (cc self) cc_acts e CC.raise_conflict_from_expl (cc self) cc_acts e
*)
(** {2 Interface with the SAT solver} *) (** {2 Interface with the SAT solver} *)
@ -631,13 +639,16 @@ module Make (A : ARG) :
in in
let model = M.create 128 in let model = M.create 128 in
(* populate with information from the CC *) (* populate with information from the CC *)
CC.get_model_for_each_class cc (fun (_, ts, v) -> (* FIXME
Iter.iter CC.get_model_for_each_class cc (fun (_, ts, v) ->
(fun n -> Iter.iter
let t = N.term n in (fun n ->
M.replace model t v) let t = N.term n in
ts); M.replace model t v)
ts);
*)
(* complete model with theory specific values *) (* complete model with theory specific values *)
let complete_with f = let complete_with f =
@ -702,30 +713,45 @@ module Make (A : ARG) :
can merge classes, *) can merge classes, *)
let check_th_combination_ (self : t) (acts : theory_actions) : let check_th_combination_ (self : t) (acts : theory_actions) :
(Model.t, th_combination_conflict) result = (Model.t, th_combination_conflict) result =
(* FIXME
(* enter model mode, disabling most of congruence closure *)
CC.with_model_mode cc @@ fun () ->
let set_val (t, v) : unit =
Log.debugf 50 (fun k ->
k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" Term.pp t
Term.pp v);
CC.set_model_value cc t v
in
(* obtain assignments from the hook, and communicate them to the CC *)
let add_th_values f : unit =
let vals = f self acts in
Iter.iter set_val vals
in
try
List.iter add_th_values self.on_th_combination;
CC.check cc;
let m = mk_model_ self in
Ok m
with Semantic_conflict c -> Error c
*)
let m = mk_model_ self in
Ok m
(* call congruence closure, perform the actions it scheduled *)
let check_cc_with_acts_ (self : t) (acts : theory_actions) =
let (module A) = acts in
let cc = cc self in let cc = cc self in
let cc_acts = mk_cc_acts_ self.proof acts in match CC.check cc with
| Ok acts ->
(* entier model mode, disabling most of congruence closure *) List.iter
CC.with_model_mode cc @@ fun () -> (function
let set_val (t, v) : unit = | CC.Result_action.Act_propagate { lit; reason } ->
Log.debugf 50 (fun k -> let reason = Sidekick_sat.Consequence reason in
k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" Term.pp t A.propagate lit reason)
Term.pp v); acts
CC.set_model_value cc t v | Error (CC.Result_action.Conflict (lits, pr)) -> A.raise_conflict lits pr
in
(* obtain assignments from the hook, and communicate them to the CC *)
let add_th_values f : unit =
let vals = f self acts in
Iter.iter set_val vals
in
try
List.iter add_th_values self.on_th_combination;
CC.check cc cc_acts;
let m = mk_model_ self in
Ok m
with Semantic_conflict c -> Error c
(* handle a literal assumed by the SAT solver *) (* handle a literal assumed by the SAT solver *)
let assert_lits_ ~final (self : t) (acts : theory_actions) let assert_lits_ ~final (self : t) (acts : theory_actions)
@ -741,14 +767,13 @@ module Make (A : ARG) :
lits); lits);
(* transmit to CC *) (* transmit to CC *)
let cc = cc self in let cc = cc self in
let cc_acts = mk_cc_acts_ self.proof acts in
if not final then CC.assert_lits cc lits; if not final then CC.assert_lits cc lits;
(* transmit to theories. *) (* transmit to theories. *)
CC.check cc cc_acts; check_cc_with_acts_ self acts;
if final then ( if final then (
List.iter (fun f -> f self acts lits) self.on_final_check; List.iter (fun f -> f self acts lits) self.on_final_check;
CC.check cc cc_acts; check_cc_with_acts_ self acts;
(match check_th_combination_ self acts with (match check_th_combination_ self acts with
| Ok m -> self.last_model <- Some m | Ok m -> self.last_model <- Some m

View file

@ -23,7 +23,7 @@ module Make (A : ARG) : S with module A = A = struct
module A = A module A = A
module SI = A.S.Solver_internal module SI = A.S.Solver_internal
module T = A.S.T.Term module T = A.S.T.Term
module N = SI.CC.Class module N = SI.CC.E_node
module Fun = A.S.T.Fun module Fun = A.S.T.Fun
module Expl = SI.CC.Expl module Expl = SI.CC.Expl
@ -46,7 +46,7 @@ module Make (A : ARG) : S with module A = A = struct
Some { n; t; cstor; args }, [] Some { n; t; cstor; args }, []
| _ -> None, [] | _ -> None, []
let merge cc n1 v1 n2 v2 e_n1_n2 : _ result = let merge _cc n1 v1 n2 v2 e_n1_n2 : _ result =
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" name N.pp n1 k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" name N.pp n1
T.pp v1.t N.pp n2 T.pp v2.t); T.pp v1.t N.pp n2 T.pp v2.t);
@ -60,11 +60,16 @@ module Make (A : ARG) : S with module A = A = struct
if Fun.equal v1.cstor v2.cstor then ( if Fun.equal v1.cstor v2.cstor then (
(* same function: injectivity *) (* same function: injectivity *)
assert (CCArray.length v1.args = CCArray.length v2.args); assert (CCArray.length v1.args = CCArray.length v2.args);
CCArray.iter2 (fun u1 u2 -> SI.CC.merge cc u1 u2 expl) v1.args v2.args; let acts =
Ok v1 CCArray.map2
(fun u1 u2 -> SI.CC.Handler_action.Act_merge (u1, u2, expl))
v1.args v2.args
|> Array.to_list
in
Ok (v1, acts)
) else ) else
(* different function: disjointness *) (* different function: disjointness *)
Error expl Error (SI.CC.Handler_action.Conflict expl)
end end
module ST = Sidekick_cc_plugin.Make (Monoid) module ST = Sidekick_cc_plugin.Make (Monoid)

View file

@ -160,7 +160,7 @@ module Make (A : ARG) : S with module A = A = struct
module A = A module A = A
module SI = A.S.Solver_internal module SI = A.S.Solver_internal
module T = A.S.T.Term module T = A.S.T.Term
module N = SI.CC.Class module N = SI.CC.E_node
module Ty = A.S.T.Ty module Ty = A.S.T.Ty
module Expl = SI.CC.Expl module Expl = SI.CC.Expl
module Card = Compute_card (A) module Card = Compute_card (A)
@ -216,9 +216,11 @@ module Make (A : ARG) : S with module A = A = struct
in in
assert (CCArray.length c1.c_args = CCArray.length c2.c_args); assert (CCArray.length c1.c_args = CCArray.length c2.c_args);
let acts = ref [] in
Util.array_iteri2 c1.c_args c2.c_args ~f:(fun i u1 u2 -> Util.array_iteri2 c1.c_args c2.c_args ~f:(fun i u1 u2 ->
SI.CC.merge cc u1 u2 (expl_merge i)); acts :=
Ok c1 SI.CC.Handler_action.Act_merge (u1, u2, expl_merge i) :: !acts);
Ok (c1, !acts)
) else ( ) else (
(* different function: disjointness *) (* different function: disjointness *)
let expl = let expl =
@ -226,7 +228,7 @@ module Make (A : ARG) : S with module A = A = struct
mk_expl t1 t2 @@ Pr.add_step proof @@ A.P.lemma_cstor_distinct t1 t2 mk_expl t1 t2 @@ Pr.add_step proof @@ A.P.lemma_cstor_distinct t1 t2
in in
Error expl Error (SI.CC.Handler_action.Conflict expl)
) )
end end
@ -294,7 +296,7 @@ module Make (A : ARG) : S with module A = A = struct
pp v1 N.pp n2 pp v2); pp v1 N.pp n2 pp v2);
let parent_is_a = v1.parent_is_a @ v2.parent_is_a in let parent_is_a = v1.parent_is_a @ v2.parent_is_a in
let parent_select = v1.parent_select @ v2.parent_select in let parent_select = v1.parent_select @ v2.parent_select in
Ok { parent_is_a; parent_select } Ok ({ parent_is_a; parent_select }, [])
end end
module ST_cstors = Sidekick_cc_plugin.Make (Monoid_cstor) module ST_cstors = Sidekick_cc_plugin.Make (Monoid_cstor)
@ -394,7 +396,7 @@ module Make (A : ARG) : S with module A = A = struct
N_tbl.add self.to_decide_for_complete_model n () N_tbl.add self.to_decide_for_complete_model n ()
| _ -> () | _ -> ()
let on_new_term (self : t) ((cc, n, t) : _ * N.t * T.t) : unit = let on_new_term (self : t) ((cc, n, t) : _ * N.t * T.t) : _ list =
on_new_term_look_at_ty self n t; on_new_term_look_at_ty self n t;
(* might have to decide [t] *) (* might have to decide [t] *)
match A.view_as_data t with match A.view_as_data t with
@ -402,8 +404,10 @@ module Make (A : ARG) : S with module A = A = struct
let n_u = SI.CC.add_term cc u in let n_u = SI.CC.add_term cc u in
let repr_u = SI.CC.find cc n_u in let repr_u = SI.CC.find cc n_u in
(match ST_cstors.get self.cstors repr_u with (match ST_cstors.get self.cstors repr_u with
| None -> N_tbl.add self.to_decide repr_u () | None ->
(* needs to be decided *) (* needs to be decided *)
N_tbl.add self.to_decide repr_u ();
[]
| Some cstor -> | Some cstor ->
let is_true = A.Cstor.equal cstor.c_cstor c_t in let is_true = A.Cstor.equal cstor.c_cstor c_t in
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
@ -416,11 +420,14 @@ module Make (A : ARG) : S with module A = A = struct
@@ A.P.lemma_isa_cstor ~cstor_t:(N.term cstor.c_n) t @@ A.P.lemma_isa_cstor ~cstor_t:(N.term cstor.c_n) t
in in
let n_bool = SI.CC.n_bool cc is_true in let n_bool = SI.CC.n_bool cc is_true in
SI.CC.merge cc n n_bool let expl =
Expl.( Expl.(
mk_theory (N.term n) (N.term n_bool) mk_theory (N.term n) (N.term n_bool)
[ N.term n_u, N.term cstor.c_n, [ mk_merge n_u cstor.c_n ] ] [ N.term n_u, N.term cstor.c_n, [ mk_merge n_u cstor.c_n ] ]
pr)) pr)
in
let a = SI.CC.Handler_action.Act_merge (n, n_bool, expl) in
[ a ])
| T_select (c_t, i, u) -> | T_select (c_t, i, u) ->
let n_u = SI.CC.add_term cc u in let n_u = SI.CC.add_term cc u in
let repr_u = SI.CC.find cc n_u in let repr_u = SI.CC.find cc n_u in
@ -435,21 +442,28 @@ module Make (A : ARG) : S with module A = A = struct
Pr.add_step self.proof Pr.add_step self.proof
@@ A.P.lemma_select_cstor ~cstor_t:(N.term cstor.c_n) t @@ A.P.lemma_select_cstor ~cstor_t:(N.term cstor.c_n) t
in in
SI.CC.merge cc n u_i let expl =
Expl.( Expl.(
mk_theory (N.term n) (N.term u_i) mk_theory (N.term n) (N.term u_i)
[ N.term n_u, N.term cstor.c_n, [ mk_merge n_u cstor.c_n ] ] [ N.term n_u, N.term cstor.c_n, [ mk_merge n_u cstor.c_n ] ]
pr) pr)
| Some _ -> () in
| None -> N_tbl.add self.to_decide repr_u () (* needs to be decided *)) [ SI.CC.Handler_action.Act_merge (n, u_i, expl) ]
| T_cstor _ | T_other _ -> () | Some _ -> []
| None ->
(* needs to be decided *)
N_tbl.add self.to_decide repr_u ();
[])
| T_cstor _ | T_other _ -> []
let cstors_of_ty (ty : Ty.t) : A.Cstor.t Iter.t = let cstors_of_ty (ty : Ty.t) : A.Cstor.t Iter.t =
match A.as_datatype ty with match A.as_datatype ty with
| Ty_data { cstors } -> cstors | Ty_data { cstors } -> cstors
| _ -> assert false | _ -> assert false
let on_pre_merge (self : t) (cc, acts, n1, n2, expl) : unit = let on_pre_merge (self : t) (cc, n1, n2, expl) : _ result =
let exception E_confl of SI.CC.Expl.t in
let acts = ref [] in
let merge_is_a n1 (c1 : Monoid_cstor.t) n2 (is_a2 : Monoid_parents.is_a) = let merge_is_a n1 (c1 : Monoid_cstor.t) n2 (is_a2 : Monoid_parents.is_a) =
let is_true = A.Cstor.equal c1.c_cstor is_a2.is_a_cstor in let is_true = A.Cstor.equal c1.c_cstor is_a2.is_a_cstor in
Log.debugf 50 (fun k -> Log.debugf 50 (fun k ->
@ -463,18 +477,21 @@ module Make (A : ARG) : S with module A = A = struct
@@ A.P.lemma_isa_cstor ~cstor_t:(N.term c1.c_n) (N.term is_a2.is_a_n) @@ A.P.lemma_isa_cstor ~cstor_t:(N.term c1.c_n) (N.term is_a2.is_a_n)
in in
let n_bool = SI.CC.n_bool cc is_true in let n_bool = SI.CC.n_bool cc is_true in
SI.CC.merge cc is_a2.is_a_n n_bool let expl =
(Expl.mk_theory (N.term is_a2.is_a_n) (N.term n_bool) Expl.mk_theory (N.term is_a2.is_a_n) (N.term n_bool)
[ [
( N.term n1, ( N.term n1,
N.term n2, N.term n2,
[ [
Expl.mk_merge n1 c1.c_n; Expl.mk_merge n1 c1.c_n;
Expl.mk_merge n1 n2; Expl.mk_merge n1 n2;
Expl.mk_merge n2 is_a2.is_a_arg; Expl.mk_merge n2 is_a2.is_a_arg;
] ); ] );
] ]
pr) pr
in
let act = SI.CC.Handler_action.Act_merge (is_a2.is_a_n, n_bool, expl) in
acts := act :: !acts
in in
let merge_select n1 (c1 : Monoid_cstor.t) n2 (sel2 : Monoid_parents.select) let merge_select n1 (c1 : Monoid_cstor.t) n2 (sel2 : Monoid_parents.select)
= =
@ -488,18 +505,21 @@ module Make (A : ARG) : S with module A = A = struct
@@ A.P.lemma_select_cstor ~cstor_t:(N.term c1.c_n) (N.term sel2.sel_n) @@ A.P.lemma_select_cstor ~cstor_t:(N.term c1.c_n) (N.term sel2.sel_n)
in in
let u_i = CCArray.get c1.c_args sel2.sel_idx in let u_i = CCArray.get c1.c_args sel2.sel_idx in
SI.CC.merge cc sel2.sel_n u_i let expl =
(Expl.mk_theory (N.term sel2.sel_n) (N.term u_i) Expl.mk_theory (N.term sel2.sel_n) (N.term u_i)
[ [
( N.term n1, ( N.term n1,
N.term n2, N.term n2,
[ [
Expl.mk_merge n1 c1.c_n; Expl.mk_merge n1 c1.c_n;
Expl.mk_merge n1 n2; Expl.mk_merge n1 n2;
Expl.mk_merge n2 sel2.sel_arg; Expl.mk_merge n2 sel2.sel_arg;
] ); ] );
] ]
pr) pr
in
let act = SI.CC.Handler_action.Act_merge (sel2.sel_n, u_i, expl) in
acts := act :: !acts
) )
in in
let merge_c_p n1 n2 = let merge_c_p n1 n2 =
@ -514,9 +534,11 @@ module Make (A : ARG) : S with module A = A = struct
List.iter (fun is_a2 -> merge_is_a n1 c1 n2 is_a2) p2.parent_is_a; List.iter (fun is_a2 -> merge_is_a n1 c1 n2 is_a2) p2.parent_is_a;
List.iter (fun s2 -> merge_select n1 c1 n2 s2) p2.parent_select List.iter (fun s2 -> merge_select n1 c1 n2 s2) p2.parent_select
in in
merge_c_p n1 n2; try
merge_c_p n2 n1; merge_c_p n1 n2;
() merge_c_p n2 n1;
Ok !acts
with E_confl e -> Error (SI.CC.Handler_action.Conflict e)
module Acyclicity_ = struct module Acyclicity_ = struct
type repr = N.t type repr = N.t
@ -611,7 +633,8 @@ module Make (A : ARG) : S with module A = A = struct
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
k "(@[%s.acyclicity.raise_confl@ %a@ @[:path %a@]@])" name Expl.pp k "(@[%s.acyclicity.raise_confl@ %a@ @[:path %a@]@])" name Expl.pp
expl pp_path path); expl pp_path path);
SI.cc_raise_conflict_expl solver acts expl let lits, pr = SI.cc_resolve_expl solver expl in
SI.raise_conflict solver acts lits pr
| { flag = New; _ } as node_r -> | { flag = New; _ } as node_r ->
node_r.flag <- Open; node_r.flag <- Open;
let path = (n, node_r) :: path in let path = (n, node_r) :: path in
@ -642,7 +665,8 @@ module Make (A : ARG) : S with module A = A = struct
k "(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])" name T.pp u T.pp k "(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])" name T.pp u T.pp
rhs SI.Lit.pp lit); rhs SI.Lit.pp lit);
let pr = Pr.add_step self.proof @@ A.P.lemma_isa_sel t in let pr = Pr.add_step self.proof @@ A.P.lemma_isa_sel t in
SI.cc_merge_t solver acts u rhs (* merge [u] and [rhs] *)
SI.CC.merge_t (SI.cc solver) u rhs
(Expl.mk_theory u rhs (Expl.mk_theory u rhs
[ t, N.term (SI.CC.n_true @@ SI.cc solver), [ Expl.mk_lit lit ] ] [ t, N.term (SI.CC.n_true @@ SI.cc solver), [ Expl.mk_lit lit ] ]
pr) pr)

View file

@ -6,15 +6,15 @@ let nop_handler_ _ = assert false
module Emitter = struct module Emitter = struct
type nonrec ('a, 'b) t = ('a, 'b) t type nonrec ('a, 'b) t = ('a, 'b) t
let emit (self : (_, unit) t) x = Vec.iter self.h ~f:(fun h -> h x) let emit (self : (_, unit) t) x = Vec.rev_iter self.h ~f:(fun h -> h x)
let emit_collect (self : _ t) x : _ list = let emit_collect (self : _ t) x : _ list =
let l = ref [] in let l = ref [] in
Vec.iter self.h ~f:(fun h -> l := h x :: !l); Vec.rev_iter self.h ~f:(fun h -> l := h x :: !l);
!l !l
let emit_iter self x ~f = let emit_iter self x ~f =
Vec.iter self.h ~f:(fun h -> Vec.rev_iter self.h ~f:(fun h ->
let y = h x in let y = h x in
f y) f y)

View file

@ -24,3 +24,5 @@ module Stat = Stat
module Hash = Hash module Hash = Hash
module Profile = Profile module Profile = Profile
module Chunk_stack = Chunk_stack module Chunk_stack = Chunk_stack
let[@inline] ( let@ ) f x = f x