cc: use backtrackable table

This commit is contained in:
Simon Cruanes 2022-02-17 22:03:18 -05:00
parent d153c80ca5
commit cdc5d160a7
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -265,6 +265,7 @@ module Make (A: CC_ARG)
module Sig_tbl = CCHashtbl.Make(Signature) module Sig_tbl = CCHashtbl.Make(Signature)
module T_tbl = CCHashtbl.Make(Term) module T_tbl = CCHashtbl.Make(Term)
module T_b_tbl = Backtrackable_tbl.Make(Term)
type combine_task = type combine_task =
| CT_merge of node * node * explanation | CT_merge of node * node * explanation
@ -289,10 +290,10 @@ module Make (A: CC_ARG)
pending: node Vec.t; pending: node Vec.t;
combine: combine_task Vec.t; combine: combine_task Vec.t;
t_to_val: (node*value) T_tbl.t; t_to_val: (node*value) T_b_tbl.t;
(* [repr -> (t,val)] where [repr = t] (* [repr -> (t,val)] where [repr = t] and [t := val] in the model *)
and [t := val] in the model *)
val_to_t: node T_tbl.t; (* [val -> t] where [t := val] in the model *) val_to_t: node T_b_tbl.t; (* [val -> t] where [t := val] in the model *)
undo: (unit -> unit) Backtrack_stack.t; undo: (unit -> unit) Backtrack_stack.t;
bitgen: Bits.bitfield_gen; bitgen: Bits.bitfield_gen;
@ -778,7 +779,7 @@ module Make (A: CC_ARG)
(* - if repr(n) has value [v], do nothing (* - if repr(n) has value [v], do nothing
- else if repr(n) has value [v'], semantic conflict - else if repr(n) has value [v'], semantic conflict
- else add [repr(n) -> (n,v)] to cc.t_to_val *) - else add [repr(n) -> (n,v)] to cc.t_to_val *)
begin match T_tbl.find_opt cc.t_to_val repr_n.n_term with begin match T_b_tbl.get cc.t_to_val repr_n.n_term with
| Some (n', v') when not (Term.equal v v') -> | Some (n', v') when not (Term.equal v v') ->
(* semantic conflict *) (* semantic conflict *)
let expl = [Expl.mk_merge n n'] in let expl = [Expl.mk_merge n n'] in
@ -798,16 +799,14 @@ module Make (A: CC_ARG)
| Some _ -> () | Some _ -> ()
| None -> | None ->
T_tbl.add cc.t_to_val repr_n.n_term (n, v); T_b_tbl.add cc.t_to_val repr_n.n_term (n, v);
on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val repr_n.n_term);
end; end;
(* now for the reverse map, look in self.val_to_t for [v]. (* now for the reverse map, look in self.val_to_t for [v].
- if present, push a merge command with Expl.mk_same_value - if present, push a merge command with Expl.mk_same_value
- if not, add [v -> n] *) - if not, add [v -> n] *)
begin match T_tbl.find_opt cc.val_to_t v with begin match T_b_tbl.get cc.val_to_t v with
| None -> | None ->
T_tbl.add cc.val_to_t v n; T_b_tbl.add cc.val_to_t v n;
on_backtrack cc (fun () -> T_tbl.remove cc.val_to_t v);
| Some n' when not (same_class n n') -> | Some n' when not (same_class n n') ->
merge_classes cc n n' (Expl.mk_same_value n n') merge_classes cc n n' (Expl.mk_same_value n n')
@ -930,13 +929,12 @@ module Make (A: CC_ARG)
(* check for semantic values, update the one of [r_into] (* check for semantic values, update the one of [r_into]
if [r_from] has a value *) if [r_from] has a value *)
begin match T_tbl.find_opt cc.t_to_val r_from.n_term with begin match T_b_tbl.get cc.t_to_val r_from.n_term with
| None -> () | None -> ()
| Some (n_from, v_from) -> | Some (n_from, v_from) ->
begin match T_tbl.find_opt cc.t_to_val r_into.n_term with begin match T_b_tbl.get cc.t_to_val r_into.n_term with
| None -> | None ->
T_tbl.add cc.t_to_val r_into.n_term (n_from,v_from); T_b_tbl.add cc.t_to_val r_into.n_term (n_from,v_from);
on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val r_into.n_term);
| Some (n_into,v_into) when not (Term.equal v_from v_into) -> | Some (n_into,v_into) when not (Term.equal v_from v_into) ->
(* semantic conflict, including [n_from != n_into] in model *) (* semantic conflict, including [n_from != n_into] in model *)
@ -1039,7 +1037,9 @@ module Make (A: CC_ARG)
() ()
let[@inline] push_level (self:t) : unit = let[@inline] push_level (self:t) : unit =
Backtrack_stack.push_level self.undo Backtrack_stack.push_level self.undo;
T_b_tbl.push_level self.t_to_val;
T_b_tbl.push_level self.val_to_t
let pop_levels (self:t) n : unit = let pop_levels (self:t) n : unit =
Vec.clear self.pending; Vec.clear self.pending;
@ -1047,6 +1047,8 @@ module Make (A: CC_ARG)
Log.debugf 15 Log.debugf 15
(fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo)); (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
T_b_tbl.pop_levels self.t_to_val n;
T_b_tbl.pop_levels self.val_to_t n;
() ()
@ -1143,8 +1145,8 @@ module Make (A: CC_ARG)
tbl = T_tbl.create size; tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size; signatures_tbl = Sig_tbl.create size;
bitgen; bitgen;
t_to_val=T_tbl.create 32; t_to_val=T_b_tbl.create ~size:32 ();
val_to_t=T_tbl.create 32; val_to_t=T_b_tbl.create ~size:32 ();
model_mode=false; model_mode=false;
on_pre_merge; on_pre_merge;
on_post_merge; on_post_merge;