diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index af6d3761..94019621 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -114,14 +114,15 @@ let[@inline] same_class_t cc (t1:term)(t2:term): bool = let signature cc (t:term): node term_cell option = let find = find_tn cc in begin match Term.cell t with - | True | Builtin _ - -> None | App_cst (_, a) when IArray.is_empty a -> None | App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return - | If (a,b,c) -> If (find a, get_ cc b, get_ cc c) |> CCOpt.return - | Case (t, m) -> Case (find t, ID.Map.map (get_ cc) m) |> CCOpt.return | Custom {view;tc} -> Custom {tc; view=tc.tc_t_subst find view} |> CCOpt.return + | True + | Builtin _ + | If _ + | Case _ + -> None (* no congruence for these *) end (* find whether the given (parent) term corresponds to some signature @@ -139,7 +140,7 @@ let add_signature cc (t:term) (r:repr): unit = match signature cc t with | None -> () | Some s -> assert (CCOpt.map_or ~default:false (Signature.equal s) - (signature cc (r:>node).n_term)); + (signature cc r.n_term)); (* add, but only if not present already *) begin match Sig_tbl.get cc.signatures_tbl s with | None -> @@ -261,9 +262,37 @@ and update_combine cc = we try to ensure that [size ra <= size rb] in general, unless it clashes with the invariant that the representative must be a normal form if the class contains a normal form *) - let r_from, r_into = - if size_ ra > size_ rb then rb, ra else ra, rb + let must_solve, r_from, r_into = + match Term.is_semantic ra.n_term, Term.is_semantic rb.n_term with + | true, true -> + if size_ ra > size_ rb then true, rb, ra else true, ra, rb + | false, false -> + if size_ ra > size_ rb then false, rb, ra else false, ra, rb + | true, false -> false, rb, ra (* semantic ==> representative *) + | false, true -> false, ra, rb in + (* solve the equation, if both [ra] and [rb] are semantic. + The equation is between signatures, so as to canonize w.r.t the + current congruence before solving *) + if must_solve then ( + let t_a = ra.n_term and t_b = rb.n_term in + match signature cc t_a, signature cc t_b with + | Some (Custom t1), Some (Custom t2) -> + begin match t1.tc.tc_t_solve t1.view t2.view with + | Solve_ok {subst=l} -> + Log.debugf 5 + (fun k->k "(@[solve@ (@[= %a %a@])@ :yields (@[%a@])@])" + Term.pp t_a Term.pp t_b + (Util.pp_list @@ Util.pp_pair Equiv_class.pp Term.pp) l); + List.iter (fun (u1,u2) -> push_combine cc u1 (add cc u2) e_ab) l + | Solve_fail {expl} -> + Log.debugf 5 + (fun k->k "(@[solve-fail@ (@[= %a %a@])@ :expl %a@])" + Term.pp t_a Term.pp t_b Explanation.pp expl); + raise (Exn_unsat (Bag.return expl)) + end + | _ -> assert false + ); (* remove [ra.parents] from signature, put them into [st.pending] *) begin Bag.to_seq (r_from:>node).n_parents @@ -375,7 +404,8 @@ and add_new_term cc (t:term) : node = push_pending cc n; n -(* add [t=u] to the congruence closure, unconditionally (reduction relation) *) +(* TODO? *) +(* add [t=u] to the congruence closure, unconditionally (reduction relation) *) and[@inline] add_eqn (cc:t) (eqn:merge_op): unit = let t, u, expl = eqn in push_combine cc t u expl diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 99953868..8e4f2a60 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -61,7 +61,7 @@ and term_view_tc = { tc_t_equal : 'a. 'a CCEqual.t -> 'a term_view_custom CCEqual.t; tc_t_hash : 'a. 'a Hash.t -> 'a term_view_custom Hash.t; tc_t_ty : 'a. ('a -> ty) -> 'a term_view_custom -> ty; - tc_t_is_semantic : cc_node term_view_custom -> bool; (* is this a semantic term? semantic terms must be solvable *) + tc_t_is_semantic : 'a. 'a term_view_custom -> bool; (* is this a semantic term? semantic terms must be solvable *) tc_t_solve: cc_node term_view_custom -> cc_node term_view_custom -> solve_result; (* solve an equation between classes *) tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on immediate subterms *) tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on relevant immediate subterms *) diff --git a/src/smt/Term.ml b/src/smt/Term.ml index d5b776dd..465b2469 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -114,10 +114,18 @@ let fold_map_builtin let acc, b = f acc b in acc, B_imply (a, b) -let is_const t = match t.term_cell with +let[@inline] is_const t = match t.term_cell with | App_cst (_, a) -> IArray.is_empty a | _ -> false +let[@inline] is_custom t = match t.term_cell with + | Custom _ -> true + | _ -> false + +let[@inline] is_semantic t = match t.term_cell with + | Custom {view;tc} -> tc.tc_t_is_semantic view + | _ -> false + let map_builtin f b = let (), b = fold_map_builtin (fun () t -> (), f t) () b in b diff --git a/src/smt/Term.mli b/src/smt/Term.mli index de3d99eb..6ae651c3 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -51,6 +51,11 @@ val pp : t Fmt.printer val is_const : t -> bool +val is_custom : t -> bool + +val is_semantic : t -> bool +(** Custom term that is Shostak-ready (ie can be solved) *) + (* return [Some] iff the term is an undefined constant *) val as_cst_undef : t -> (cst * Ty.t) option diff --git a/src/smt/Util.ml b/src/smt/Util.ml index 576bebc8..83f814b2 100644 --- a/src/smt/Util.ml +++ b/src/smt/Util.ml @@ -12,6 +12,9 @@ let pp_sep sep out () = Format.fprintf out "%s@," sep let pp_list ?(sep=" ") pp out l = Fmt.list ~sep:(pp_sep sep) pp out l +let pp_pair ?(sep=" ") pp1 pp2 out t = + Fmt.pair ~sep:(pp_sep sep) pp1 pp2 out t + let pp_array ?(sep=" ") pp out l = Fmt.array ~sep:(pp_sep sep) pp out l diff --git a/src/smt/Util.mli b/src/smt/Util.mli index c6f77edf..f39267f0 100644 --- a/src/smt/Util.mli +++ b/src/smt/Util.mli @@ -9,6 +9,8 @@ val pp_list : ?sep:string -> 'a printer -> 'a list printer val pp_array : ?sep:string -> 'a printer -> 'a array printer +val pp_pair : ?sep:string -> 'a printer -> 'b printer -> ('a * 'b) printer + val pp_iarray : ?sep:string -> 'a CCFormat.printer -> 'a IArray.t CCFormat.printer exception Error of string diff --git a/src/smt/jbuild b/src/smt/jbuild index 137ba092..599e08d1 100644 --- a/src/smt/jbuild +++ b/src/smt/jbuild @@ -4,6 +4,6 @@ ((name CDCL_smt) (public_name cdcl.smt) (libraries (containers containers.data sequence cdcl)) - (flags (:standard -w +a-4-44-58-60@8 -color always -safe-string -short-paths)) + (flags (:standard -w +a-4-44-48-58-60@8 -color always -safe-string -short-paths)) (ocamlopt_flags (:standard -O3 -color always -unbox-closures -unbox-closures-factor 20))))