diff --git a/src/cc/CC_types.ml b/src/cc/CC_types.ml index d5e098f0..f0ff2363 100644 --- a/src/cc/CC_types.ml +++ b/src/cc/CC_types.ml @@ -7,6 +7,7 @@ type ('f, 't, 'ts) view = | App_ho of 't * 'ts | If of 't * 't * 't | Eq of 't * 't + | Not of 't | Opaque of 't (* do not enter *) let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = @@ -14,6 +15,7 @@ let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = | Bool b -> Bool b | App_fun (f, args) -> App_fun (f_f f, f_ts args) | App_ho (f, args) -> App_ho (f_t f, f_ts args) + | Not t -> Not (f_t t) | If (a,b,c) -> If (f_t a, f_t b, f_t c) | Eq (a,b) -> Eq (f_t a, f_t b) | Opaque t -> Opaque (f_t t) @@ -23,6 +25,7 @@ let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit = | Bool _ -> () | App_fun (f, args) -> f_f f; f_ts args | App_ho (f, args) -> f_t f; f_ts args + | Not t -> f_t t | If (a,b,c) -> f_t a; f_t b; f_t c; | Eq (a,b) -> f_t a; f_t b | Opaque t -> f_t t diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index 53f5be62..5a12f6fb 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -263,13 +263,14 @@ module Make(A: ARG) = struct Fun.equal f1 f2 && CCList.equal N.equal l1 l2 | App_ho (f1,l1), App_ho (f2,l2) -> N.equal f1 f2 && CCList.equal N.equal l1 l2 + | Not a, Not b -> N.equal a b | If (a1,b1,c1), If (a2,b2,c2) -> N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2 | Eq (a1,b1), Eq (a2,b2) -> N.equal a1 a2 && N.equal b1 b2 | Opaque u1, Opaque u2 -> N.equal u1 u2 | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ - | Eq _, _ | Opaque _, _ + | Eq _, _ | Opaque _, _ | Not _, _ -> false let hash (s:t) : int = @@ -281,6 +282,7 @@ module Make(A: ARG) = struct | Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b) | Opaque u -> H.combine2 50 (N.hash u) | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c) + | Not u -> H.combine2 70 (N.hash u) let pp out = function | Bool b -> Fmt.bool out b @@ -289,6 +291,7 @@ module Make(A: ARG) = struct | App_ho (f, []) -> N.pp out f | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l | Opaque t -> N.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" N.pp u | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c end @@ -629,6 +632,7 @@ module Make(A: ARG) = struct let a = deref_sub a in let b = deref_sub b in return @@ Eq (a,b) + | Not u -> return @@ Not (deref_sub u) | App_fun (f, args) -> let args = args |> Sequence.map deref_sub |> Sequence.to_list in if args<>[] then ( @@ -676,6 +680,16 @@ module Make(A: ARG) = struct let expl = Expl.mk_merge a b in merge_classes cc n (true_ cc) expl ) + | Some (Not u) -> + (* [u = bool ==> not u = not bool] *) + let r_u = find_ u in + if N.equal r_u (true_ cc) then ( + let expl = Expl.mk_merge u (true_ cc) in + merge_classes cc n (false_ cc) expl + ) else if N.equal r_u (false_ cc) then ( + let expl = Expl.mk_merge u (false_ cc) in + merge_classes cc n (true_ cc) expl + ) | Some s0 -> (* update the signature by using [find] on each sub-node *) let s = update_sig s0 in diff --git a/src/cc/Mini_cc.ml b/src/cc/Mini_cc.ml index ce0d7f56..c7ab15b5 100644 --- a/src/cc/Mini_cc.ml +++ b/src/cc/Mini_cc.ml @@ -81,13 +81,14 @@ module Make(A: TERM) = struct Fun.equal f1 f2 && CCList.equal Node.equal l1 l2 | App_ho (f1,l1), App_ho (f2,l2) -> Node.equal f1 f2 && CCList.equal Node.equal l1 l2 + | Not n1, Not n2 -> Node.equal n1 n2 | If (a1,b1,c1), If (a2,b2,c2) -> Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2 | Eq (a1,b1), Eq (a2,b2) -> Node.equal a1 a2 && Node.equal b1 b2 | Opaque u1, Opaque u2 -> Node.equal u1 u2 | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ - | Eq _, _ | Opaque _, _ + | Eq _, _ | Opaque _, _ | Not _, _ -> false let hash (s:t) : int = @@ -99,6 +100,7 @@ module Make(A: TERM) = struct | Eq (a,b) -> H.combine3 40 (Node.hash a) (Node.hash b) | Opaque u -> H.combine2 50 (Node.hash u) | If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c) + | Not u -> H.combine2 70 (Node.hash u) let pp out = function | Bool b -> Fmt.bool out b @@ -107,6 +109,7 @@ module Make(A: TERM) = struct | App_ho (f, []) -> Node.pp out f | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l | Opaque t -> Node.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" Node.pp u | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c end @@ -147,6 +150,7 @@ module Make(A: TERM) = struct | App_fun (_, args) -> args k | App_ho (f, args) -> k f; args k | Eq (a,b) -> k a; k b + | Not u -> k u | If(a,b,c) -> k a; k b; k c let rec add_t (self:t) (t:term) : node = @@ -199,6 +203,7 @@ module Make(A: TERM) = struct let a = find_t_ self a in let b = find_t_ self b in return @@ Eq (a,b) + | Not u -> return @@ Not (find_t_ self u) | App_fun (f, args) -> let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in if args<>[] then ( diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 6389618f..574fe5bd 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -5,6 +5,7 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view = | App_ho of 't * 'ts | If of 't * 't * 't | Eq of 't * 't + | Not of 't | Opaque of 't (* do not enter *) module CC_types = CC_types diff --git a/src/smt/Model.ml b/src/smt/Model.ml index fa6281e7..1c97ab05 100644 --- a/src/smt/Model.ml +++ b/src/smt/Model.ml @@ -145,6 +145,11 @@ let eval (m:t) (t:Term.t) : Value.t option = | V_bool false -> aux c | v -> Error.errorf "@[Model: wrong value@ for boolean %a@ %a@]" Term.pp a Value.pp v end + | Not a -> + begin match aux a with + | V_bool b -> V_bool (not b) + | v -> Error.errorf "@[Model: wrong value@ for boolean %a@ :val %a@]" Term.pp a Value.pp v + end | Eq(a,b) -> let a = aux a in let b = aux b in diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index 4ce0bbe7..f190107c 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -170,8 +170,22 @@ let do_on_exit ~on_exit = List.iter (fun f->f()) on_exit; () +(* map boolean subterms to literals *) +let add_bool_subterms_ (self:t) (t:Term.t) : unit = + Term.iter_dag t + |> Sequence.filter (fun t -> Ty.is_prop @@ Term.ty t) + |> Sequence.filter + (fun t -> match Term.view t with + | Term.Not _ -> false (* will process the subterm just later *) + | _ -> true) + |> Sequence.iter + (fun sub -> + Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub); + ignore (mk_atom_t self sub : Sat_solver.atom)) + let assume (self:t) (c:Lit.t IArray.t) : unit = let sat = solver self in + IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c; let c = IArray.to_array_map (Sat_solver.make_atom sat) c in Sat_solver.add_clause_a sat c Proof_default diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index de1ea2fe..db4f9e9c 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -22,6 +22,7 @@ and 'a term_view = | App_cst of cst * 'a IArray.t (* full, first-order application *) | Eq of 'a * 'a | If of 'a * 'a * 'a + | Not of 'a (* boolean literal *) and lit = { @@ -165,6 +166,7 @@ let pp_term_view_gen ~pp_id ~pp_t out = function | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" pp_t a pp_t b | If (a, b, c) -> Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c + | Not u -> Fmt.fprintf out "(@[not@ %a@])" pp_t u let pp_term_top ~ids out t = let rec pp out t = diff --git a/src/smt/Term.ml b/src/smt/Term.ml index 10f3adcd..0d9b0d40 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -12,6 +12,7 @@ type 'a view = 'a term_view = | App_cst of cst * 'a IArray.t | Eq of 'a * 'a | If of 'a * 'a * 'a + | Not of 'a let[@inline] id t = t.term_id let[@inline] ty t = t.term_ty @@ -67,6 +68,7 @@ let app_cst st f a = let[@inline] const st c = app_cst st c IArray.empty let[@inline] if_ st a b c = make st (Term_cell.if_ a b c) let[@inline] eq st a b = make st (Term_cell.eq a b) +let[@inline] not_ st a = make st (Not a) (* "eager" and, evaluating [a] first *) let and_eager st a b = if_ st a b (false_ st) @@ -74,6 +76,7 @@ let and_eager st a b = if_ st a b (false_ st) (* might need to tranfer the negation from [t] to [sign] *) let abs tst t : t * bool = match view t with | Bool false -> true_ tst, false + | Not u -> u, false | App_cst ({cst_view=Cst_def def; _}, args) -> def.abs ~self:t args (* TODO: pass state *) | _ -> t, true @@ -93,6 +96,7 @@ let cc_view (t:t) = | App_cst (f,args) -> C.App_fun (f, IArray.to_seq args) | Eq (a,b) -> C.Eq (a, b) | If (a,b,c) -> C.If (a,b,c) + | Not u -> C.Not u module As_key = struct type t = term @@ -105,17 +109,6 @@ module Map = CCMap.Make(As_key) module Set = CCSet.Make(As_key) module Tbl = CCHashtbl.Make(As_key) -let to_seq t yield = - let rec aux t = - yield t; - match view t with - | Bool _ -> () - | App_cst (_,a) -> IArray.iter aux a - | Eq (a,b) -> aux a; aux b - | If (a,b,c) -> aux a; aux b; aux c - in - aux t - (* return [Some] iff the term is an undefined constant *) let as_cst_undef (t:term): (cst * Ty.Fun.t) option = match view t with @@ -124,6 +117,23 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option = let pp = Solver_types.pp_term +module Iter_dag = struct + type t = unit Tbl.t + let create () : t = Tbl.create 16 + let iter_dag (self:t) t yield = + let rec aux t = + if not @@ Tbl.mem self t then ( + Tbl.add self t (); + yield t; + Term_cell.iter aux (view t) + ) + in + aux t +end + +let iter_dag t yield = + let st = Iter_dag.create() in + Iter_dag.iter_dag st t yield (* TODO module T_arg = struct diff --git a/src/smt/Term.mli b/src/smt/Term.mli index b7e751be..bd079892 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -12,6 +12,7 @@ type 'a view = 'a term_view = | App_cst of cst * 'a IArray.t | Eq of 'a * 'a | If of 'a * 'a * 'a + | Not of 'a val id : t -> int val view : t -> term view @@ -33,11 +34,18 @@ val app_cst : state -> cst -> t IArray.t -> t val eq : state -> t -> t -> t val if_: state -> t -> t -> t -> t val and_eager : state -> t -> t -> t (* evaluate left argument first *) +val not_ : state -> t -> t (** Obtain unsigned version of [t], + the sign as a boolean *) val abs : state -> t -> t * bool -val to_seq : t -> t Sequence.t +module Iter_dag : sig + type t + val create : unit -> t + val iter_dag : t -> term -> term Sequence.t +end + +val iter_dag : t -> t Sequence.t val pp : t Fmt.printer diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml index 85cd04b8..e46b54b4 100644 --- a/src/smt/Term_cell.ml +++ b/src/smt/Term_cell.ml @@ -8,6 +8,7 @@ type 'a view = 'a Solver_types.term_view = | App_cst of cst * 'a IArray.t | Eq of 'a * 'a | If of 'a * 'a * 'a + | Not of 'a type t = term view @@ -28,6 +29,7 @@ module Make_eq(A : ARG) = struct Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l) | Eq (a,b) -> Hash.combine3 12 (sub_hash a) (sub_hash b) | If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c) + | Not u -> Hash.combine2 70 (sub_hash u) (* equality that relies on physical equality of subterms *) let equal (a:A.t view) b : bool = match a, b with @@ -37,7 +39,8 @@ module Make_eq(A : ARG) = struct | Eq(a1,b1), Eq(a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 | If (a1,b1,c1), If (a2,b2,c2) -> sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2 - | Bool _, _ | App_cst _, _ | If _, _ | Eq _, _ + | Not a, Not b -> sub_eq a b + | Bool _, _ | App_cst _, _ | If _, _ | Eq _, _ | Not _, _ -> false let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp @@ -64,12 +67,18 @@ let eq a b = Eq (a,b) ) +let not_ t = + match t.term_view with + | Bool b -> Bool (not b) + | Not u -> u.term_view + | _ -> Not t + let if_ a b c = assert (Ty.equal b.term_ty c.term_ty); If (a,b,c) let ty (t:t): Ty.t = match t with - | Bool _ | Eq _ -> Ty.prop + | Bool _ | Eq _ | Not _ -> Ty.prop | App_cst (f, args) -> begin match Cst.view f with | Cst_undef fty -> @@ -95,6 +104,14 @@ let ty (t:t): Ty.t = match t with end | If (_,b,_) -> b.term_ty +let iter f view = + match view with + | Bool _ -> () + | App_cst (_,a) -> IArray.iter f a + | Not u -> f u + | Eq (a,b) -> f a; f b + | If (a,b,c) -> f a; f b; f c + module Tbl = CCHashtbl.Make(struct type t = term view let equal = equal diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli index c31e4c9e..47e2ad57 100644 --- a/src/smt/Term_cell.mli +++ b/src/smt/Term_cell.mli @@ -6,6 +6,7 @@ type 'a view = 'a Solver_types.term_view = | App_cst of cst * 'a IArray.t | Eq of 'a * 'a | If of 'a * 'a * 'a + | Not of 'a type t = term view @@ -18,12 +19,15 @@ val const : cst -> t val app_cst : cst -> term IArray.t -> t val eq : term -> term -> t val if_ : term -> term -> term -> t +val not_ : term -> t val ty : t -> Ty.t (** Compute the type of this term cell. Not totally free *) val pp : t Fmt.printer +val iter : ('a -> unit) -> 'a view -> unit + module type ARG = sig type t val hash : t -> int diff --git a/src/th-bool/Bool_term.ml b/src/th-bool/Bool_term.ml index e9a617e2..22dd0c99 100644 --- a/src/th-bool/Bool_term.ml +++ b/src/th-bool/Bool_term.ml @@ -14,7 +14,6 @@ type 'a view = 'a Bool_intf.view exception Not_a_th_term -let id_not = ID.make "not" let id_and = ID.make "and" let id_or = ID.make "or" let id_imply = ID.make "=>" @@ -23,9 +22,7 @@ let equal = T.equal let hash = T.hash let view_id cst_id args = - if ID.equal cst_id id_not && IArray.length args=1 then ( - B_not (IArray.get args 0) - ) else if ID.equal cst_id id_and then ( + if ID.equal cst_id id_and then ( B_and args ) else if ID.equal cst_id id_or then ( B_or args @@ -39,6 +36,7 @@ let view_id cst_id args = let view_as_bool (t:T.t) : T.t view = match T.view t with + | Not u -> B_not u | App_cst ({cst_id; _}, args) -> (try view_id cst_id args with Not_a_th_term -> B_atom t) | _ -> B_atom t @@ -49,9 +47,7 @@ module C = struct let abs ~self _a = match T.view self with - | App_cst ({cst_id;_}, args) when ID.equal cst_id id_not && IArray.length args=1 -> - (* [not a] --> [a, false] *) - IArray.get args 0, false + | Not u -> u, false | _ -> self, true let eval id args = @@ -77,7 +73,7 @@ module C = struct cst_view=Cst_def { pp=None; abs; ty=get_ty; relevant; do_cc; eval=eval id; }; } - let not = mk_cst id_not + let not = T.not_ let and_ = mk_cst id_and let or_ = mk_cst id_or let imply = mk_cst id_imply @@ -116,17 +112,8 @@ let and_ st a b = and_l st [a;b] let or_ st a b = or_l st [a;b] let and_a st a = and_l st (IArray.to_list a) let or_a st a = or_l st (IArray.to_list a) - let eq = T.eq - -let not_ st a = - match as_id id_not a, T.view a with - | _, Bool false -> T.true_ st - | _, Bool true -> T.false_ st - | Some args, _ -> - assert (IArray.length args = 1); - IArray.get args 0 - | None, _ -> T.app_cst st C.not (IArray.singleton a) +let not_ = T.not_ let neq st a b = not_ st @@ eq st a b