mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
fix: integrate negation into CC; map boolean subterms to literals
This commit is contained in:
parent
866249deb1
commit
d58759aa8c
12 changed files with 104 additions and 34 deletions
|
|
@ -7,6 +7,7 @@ type ('f, 't, 'ts) view =
|
|||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
|
||||
|
|
@ -14,6 +15,7 @@ let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
|
|||
| Bool b -> Bool b
|
||||
| App_fun (f, args) -> App_fun (f_f f, f_ts args)
|
||||
| App_ho (f, args) -> App_ho (f_t f, f_ts args)
|
||||
| Not t -> Not (f_t t)
|
||||
| If (a,b,c) -> If (f_t a, f_t b, f_t c)
|
||||
| Eq (a,b) -> Eq (f_t a, f_t b)
|
||||
| Opaque t -> Opaque (f_t t)
|
||||
|
|
@ -23,6 +25,7 @@ let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit =
|
|||
| Bool _ -> ()
|
||||
| App_fun (f, args) -> f_f f; f_ts args
|
||||
| App_ho (f, args) -> f_t f; f_ts args
|
||||
| Not t -> f_t t
|
||||
| If (a,b,c) -> f_t a; f_t b; f_t c;
|
||||
| Eq (a,b) -> f_t a; f_t b
|
||||
| Opaque t -> f_t t
|
||||
|
|
|
|||
|
|
@ -263,13 +263,14 @@ module Make(A: ARG) = struct
|
|||
Fun.equal f1 f2 && CCList.equal N.equal l1 l2
|
||||
| App_ho (f1,l1), App_ho (f2,l2) ->
|
||||
N.equal f1 f2 && CCList.equal N.equal l1 l2
|
||||
| Not a, Not b -> N.equal a b
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2
|
||||
| Eq (a1,b1), Eq (a2,b2) ->
|
||||
N.equal a1 a2 && N.equal b1 b2
|
||||
| Opaque u1, Opaque u2 -> N.equal u1 u2
|
||||
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
|
||||
| Eq _, _ | Opaque _, _
|
||||
| Eq _, _ | Opaque _, _ | Not _, _
|
||||
-> false
|
||||
|
||||
let hash (s:t) : int =
|
||||
|
|
@ -281,6 +282,7 @@ module Make(A: ARG) = struct
|
|||
| Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b)
|
||||
| Opaque u -> H.combine2 50 (N.hash u)
|
||||
| If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c)
|
||||
| Not u -> H.combine2 70 (N.hash u)
|
||||
|
||||
let pp out = function
|
||||
| Bool b -> Fmt.bool out b
|
||||
|
|
@ -289,6 +291,7 @@ module Make(A: ARG) = struct
|
|||
| App_ho (f, []) -> N.pp out f
|
||||
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l
|
||||
| Opaque t -> N.pp out t
|
||||
| Not u -> Fmt.fprintf out "(@[not@ %a@])" N.pp u
|
||||
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b
|
||||
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c
|
||||
end
|
||||
|
|
@ -629,6 +632,7 @@ module Make(A: ARG) = struct
|
|||
let a = deref_sub a in
|
||||
let b = deref_sub b in
|
||||
return @@ Eq (a,b)
|
||||
| Not u -> return @@ Not (deref_sub u)
|
||||
| App_fun (f, args) ->
|
||||
let args = args |> Sequence.map deref_sub |> Sequence.to_list in
|
||||
if args<>[] then (
|
||||
|
|
@ -676,6 +680,16 @@ module Make(A: ARG) = struct
|
|||
let expl = Expl.mk_merge a b in
|
||||
merge_classes cc n (true_ cc) expl
|
||||
)
|
||||
| Some (Not u) ->
|
||||
(* [u = bool ==> not u = not bool] *)
|
||||
let r_u = find_ u in
|
||||
if N.equal r_u (true_ cc) then (
|
||||
let expl = Expl.mk_merge u (true_ cc) in
|
||||
merge_classes cc n (false_ cc) expl
|
||||
) else if N.equal r_u (false_ cc) then (
|
||||
let expl = Expl.mk_merge u (false_ cc) in
|
||||
merge_classes cc n (true_ cc) expl
|
||||
)
|
||||
| Some s0 ->
|
||||
(* update the signature by using [find] on each sub-node *)
|
||||
let s = update_sig s0 in
|
||||
|
|
|
|||
|
|
@ -81,13 +81,14 @@ module Make(A: TERM) = struct
|
|||
Fun.equal f1 f2 && CCList.equal Node.equal l1 l2
|
||||
| App_ho (f1,l1), App_ho (f2,l2) ->
|
||||
Node.equal f1 f2 && CCList.equal Node.equal l1 l2
|
||||
| Not n1, Not n2 -> Node.equal n1 n2
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2
|
||||
| Eq (a1,b1), Eq (a2,b2) ->
|
||||
Node.equal a1 a2 && Node.equal b1 b2
|
||||
| Opaque u1, Opaque u2 -> Node.equal u1 u2
|
||||
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
|
||||
| Eq _, _ | Opaque _, _
|
||||
| Eq _, _ | Opaque _, _ | Not _, _
|
||||
-> false
|
||||
|
||||
let hash (s:t) : int =
|
||||
|
|
@ -99,6 +100,7 @@ module Make(A: TERM) = struct
|
|||
| Eq (a,b) -> H.combine3 40 (Node.hash a) (Node.hash b)
|
||||
| Opaque u -> H.combine2 50 (Node.hash u)
|
||||
| If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c)
|
||||
| Not u -> H.combine2 70 (Node.hash u)
|
||||
|
||||
let pp out = function
|
||||
| Bool b -> Fmt.bool out b
|
||||
|
|
@ -107,6 +109,7 @@ module Make(A: TERM) = struct
|
|||
| App_ho (f, []) -> Node.pp out f
|
||||
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l
|
||||
| Opaque t -> Node.pp out t
|
||||
| Not u -> Fmt.fprintf out "(@[not@ %a@])" Node.pp u
|
||||
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b
|
||||
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c
|
||||
end
|
||||
|
|
@ -147,6 +150,7 @@ module Make(A: TERM) = struct
|
|||
| App_fun (_, args) -> args k
|
||||
| App_ho (f, args) -> k f; args k
|
||||
| Eq (a,b) -> k a; k b
|
||||
| Not u -> k u
|
||||
| If(a,b,c) -> k a; k b; k c
|
||||
|
||||
let rec add_t (self:t) (t:term) : node =
|
||||
|
|
@ -199,6 +203,7 @@ module Make(A: TERM) = struct
|
|||
let a = find_t_ self a in
|
||||
let b = find_t_ self b in
|
||||
return @@ Eq (a,b)
|
||||
| Not u -> return @@ Not (find_t_ self u)
|
||||
| App_fun (f, args) ->
|
||||
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
|
||||
if args<>[] then (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view =
|
|||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
module CC_types = CC_types
|
||||
|
|
|
|||
|
|
@ -145,6 +145,11 @@ let eval (m:t) (t:Term.t) : Value.t option =
|
|||
| V_bool false -> aux c
|
||||
| v -> Error.errorf "@[Model: wrong value@ for boolean %a@ %a@]" Term.pp a Value.pp v
|
||||
end
|
||||
| Not a ->
|
||||
begin match aux a with
|
||||
| V_bool b -> V_bool (not b)
|
||||
| v -> Error.errorf "@[Model: wrong value@ for boolean %a@ :val %a@]" Term.pp a Value.pp v
|
||||
end
|
||||
| Eq(a,b) ->
|
||||
let a = aux a in
|
||||
let b = aux b in
|
||||
|
|
|
|||
|
|
@ -170,8 +170,22 @@ let do_on_exit ~on_exit =
|
|||
List.iter (fun f->f()) on_exit;
|
||||
()
|
||||
|
||||
(* map boolean subterms to literals *)
|
||||
let add_bool_subterms_ (self:t) (t:Term.t) : unit =
|
||||
Term.iter_dag t
|
||||
|> Sequence.filter (fun t -> Ty.is_prop @@ Term.ty t)
|
||||
|> Sequence.filter
|
||||
(fun t -> match Term.view t with
|
||||
| Term.Not _ -> false (* will process the subterm just later *)
|
||||
| _ -> true)
|
||||
|> Sequence.iter
|
||||
(fun sub ->
|
||||
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub);
|
||||
ignore (mk_atom_t self sub : Sat_solver.atom))
|
||||
|
||||
let assume (self:t) (c:Lit.t IArray.t) : unit =
|
||||
let sat = solver self in
|
||||
IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c;
|
||||
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
|
||||
Sat_solver.add_clause_a sat c Proof_default
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ and 'a term_view =
|
|||
| App_cst of cst * 'a IArray.t (* full, first-order application *)
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
| Not of 'a
|
||||
|
||||
(* boolean literal *)
|
||||
and lit = {
|
||||
|
|
@ -165,6 +166,7 @@ let pp_term_view_gen ~pp_id ~pp_t out = function
|
|||
| Eq (a,b) -> Fmt.fprintf out "(@[<hv>=@ %a@ %a@])" pp_t a pp_t b
|
||||
| If (a, b, c) ->
|
||||
Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c
|
||||
| Not u -> Fmt.fprintf out "(@[not@ %a@])" pp_t u
|
||||
|
||||
let pp_term_top ~ids out t =
|
||||
let rec pp out t =
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ type 'a view = 'a term_view =
|
|||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
| Not of 'a
|
||||
|
||||
let[@inline] id t = t.term_id
|
||||
let[@inline] ty t = t.term_ty
|
||||
|
|
@ -67,6 +68,7 @@ let app_cst st f a =
|
|||
let[@inline] const st c = app_cst st c IArray.empty
|
||||
let[@inline] if_ st a b c = make st (Term_cell.if_ a b c)
|
||||
let[@inline] eq st a b = make st (Term_cell.eq a b)
|
||||
let[@inline] not_ st a = make st (Not a)
|
||||
|
||||
(* "eager" and, evaluating [a] first *)
|
||||
let and_eager st a b = if_ st a b (false_ st)
|
||||
|
|
@ -74,6 +76,7 @@ let and_eager st a b = if_ st a b (false_ st)
|
|||
(* might need to tranfer the negation from [t] to [sign] *)
|
||||
let abs tst t : t * bool = match view t with
|
||||
| Bool false -> true_ tst, false
|
||||
| Not u -> u, false
|
||||
| App_cst ({cst_view=Cst_def def; _}, args) ->
|
||||
def.abs ~self:t args (* TODO: pass state *)
|
||||
| _ -> t, true
|
||||
|
|
@ -93,6 +96,7 @@ let cc_view (t:t) =
|
|||
| App_cst (f,args) -> C.App_fun (f, IArray.to_seq args)
|
||||
| Eq (a,b) -> C.Eq (a, b)
|
||||
| If (a,b,c) -> C.If (a,b,c)
|
||||
| Not u -> C.Not u
|
||||
|
||||
module As_key = struct
|
||||
type t = term
|
||||
|
|
@ -105,17 +109,6 @@ module Map = CCMap.Make(As_key)
|
|||
module Set = CCSet.Make(As_key)
|
||||
module Tbl = CCHashtbl.Make(As_key)
|
||||
|
||||
let to_seq t yield =
|
||||
let rec aux t =
|
||||
yield t;
|
||||
match view t with
|
||||
| Bool _ -> ()
|
||||
| App_cst (_,a) -> IArray.iter aux a
|
||||
| Eq (a,b) -> aux a; aux b
|
||||
| If (a,b,c) -> aux a; aux b; aux c
|
||||
in
|
||||
aux t
|
||||
|
||||
(* return [Some] iff the term is an undefined constant *)
|
||||
let as_cst_undef (t:term): (cst * Ty.Fun.t) option =
|
||||
match view t with
|
||||
|
|
@ -124,6 +117,23 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option =
|
|||
|
||||
let pp = Solver_types.pp_term
|
||||
|
||||
module Iter_dag = struct
|
||||
type t = unit Tbl.t
|
||||
let create () : t = Tbl.create 16
|
||||
let iter_dag (self:t) t yield =
|
||||
let rec aux t =
|
||||
if not @@ Tbl.mem self t then (
|
||||
Tbl.add self t ();
|
||||
yield t;
|
||||
Term_cell.iter aux (view t)
|
||||
)
|
||||
in
|
||||
aux t
|
||||
end
|
||||
|
||||
let iter_dag t yield =
|
||||
let st = Iter_dag.create() in
|
||||
Iter_dag.iter_dag st t yield
|
||||
|
||||
(* TODO
|
||||
module T_arg = struct
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ type 'a view = 'a term_view =
|
|||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
| Not of 'a
|
||||
|
||||
val id : t -> int
|
||||
val view : t -> term view
|
||||
|
|
@ -33,11 +34,18 @@ val app_cst : state -> cst -> t IArray.t -> t
|
|||
val eq : state -> t -> t -> t
|
||||
val if_: state -> t -> t -> t -> t
|
||||
val and_eager : state -> t -> t -> t (* evaluate left argument first *)
|
||||
val not_ : state -> t -> t
|
||||
|
||||
(** Obtain unsigned version of [t], + the sign as a boolean *)
|
||||
val abs : state -> t -> t * bool
|
||||
|
||||
val to_seq : t -> t Sequence.t
|
||||
module Iter_dag : sig
|
||||
type t
|
||||
val create : unit -> t
|
||||
val iter_dag : t -> term -> term Sequence.t
|
||||
end
|
||||
|
||||
val iter_dag : t -> t Sequence.t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ type 'a view = 'a Solver_types.term_view =
|
|||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
| Not of 'a
|
||||
|
||||
type t = term view
|
||||
|
||||
|
|
@ -28,6 +29,7 @@ module Make_eq(A : ARG) = struct
|
|||
Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l)
|
||||
| Eq (a,b) -> Hash.combine3 12 (sub_hash a) (sub_hash b)
|
||||
| If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c)
|
||||
| Not u -> Hash.combine2 70 (sub_hash u)
|
||||
|
||||
(* equality that relies on physical equality of subterms *)
|
||||
let equal (a:A.t view) b : bool = match a, b with
|
||||
|
|
@ -37,7 +39,8 @@ module Make_eq(A : ARG) = struct
|
|||
| Eq(a1,b1), Eq(a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2
|
||||
| Bool _, _ | App_cst _, _ | If _, _ | Eq _, _
|
||||
| Not a, Not b -> sub_eq a b
|
||||
| Bool _, _ | App_cst _, _ | If _, _ | Eq _, _ | Not _, _
|
||||
-> false
|
||||
|
||||
let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp
|
||||
|
|
@ -64,12 +67,18 @@ let eq a b =
|
|||
Eq (a,b)
|
||||
)
|
||||
|
||||
let not_ t =
|
||||
match t.term_view with
|
||||
| Bool b -> Bool (not b)
|
||||
| Not u -> u.term_view
|
||||
| _ -> Not t
|
||||
|
||||
let if_ a b c =
|
||||
assert (Ty.equal b.term_ty c.term_ty);
|
||||
If (a,b,c)
|
||||
|
||||
let ty (t:t): Ty.t = match t with
|
||||
| Bool _ | Eq _ -> Ty.prop
|
||||
| Bool _ | Eq _ | Not _ -> Ty.prop
|
||||
| App_cst (f, args) ->
|
||||
begin match Cst.view f with
|
||||
| Cst_undef fty ->
|
||||
|
|
@ -95,6 +104,14 @@ let ty (t:t): Ty.t = match t with
|
|||
end
|
||||
| If (_,b,_) -> b.term_ty
|
||||
|
||||
let iter f view =
|
||||
match view with
|
||||
| Bool _ -> ()
|
||||
| App_cst (_,a) -> IArray.iter f a
|
||||
| Not u -> f u
|
||||
| Eq (a,b) -> f a; f b
|
||||
| If (a,b,c) -> f a; f b; f c
|
||||
|
||||
module Tbl = CCHashtbl.Make(struct
|
||||
type t = term view
|
||||
let equal = equal
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ type 'a view = 'a Solver_types.term_view =
|
|||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
| Not of 'a
|
||||
|
||||
type t = term view
|
||||
|
||||
|
|
@ -18,12 +19,15 @@ val const : cst -> t
|
|||
val app_cst : cst -> term IArray.t -> t
|
||||
val eq : term -> term -> t
|
||||
val if_ : term -> term -> term -> t
|
||||
val not_ : term -> t
|
||||
|
||||
val ty : t -> Ty.t
|
||||
(** Compute the type of this term cell. Not totally free *)
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val iter : ('a -> unit) -> 'a view -> unit
|
||||
|
||||
module type ARG = sig
|
||||
type t
|
||||
val hash : t -> int
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ type 'a view = 'a Bool_intf.view
|
|||
|
||||
exception Not_a_th_term
|
||||
|
||||
let id_not = ID.make "not"
|
||||
let id_and = ID.make "and"
|
||||
let id_or = ID.make "or"
|
||||
let id_imply = ID.make "=>"
|
||||
|
|
@ -23,9 +22,7 @@ let equal = T.equal
|
|||
let hash = T.hash
|
||||
|
||||
let view_id cst_id args =
|
||||
if ID.equal cst_id id_not && IArray.length args=1 then (
|
||||
B_not (IArray.get args 0)
|
||||
) else if ID.equal cst_id id_and then (
|
||||
if ID.equal cst_id id_and then (
|
||||
B_and args
|
||||
) else if ID.equal cst_id id_or then (
|
||||
B_or args
|
||||
|
|
@ -39,6 +36,7 @@ let view_id cst_id args =
|
|||
|
||||
let view_as_bool (t:T.t) : T.t view =
|
||||
match T.view t with
|
||||
| Not u -> B_not u
|
||||
| App_cst ({cst_id; _}, args) ->
|
||||
(try view_id cst_id args with Not_a_th_term -> B_atom t)
|
||||
| _ -> B_atom t
|
||||
|
|
@ -49,9 +47,7 @@ module C = struct
|
|||
|
||||
let abs ~self _a =
|
||||
match T.view self with
|
||||
| App_cst ({cst_id;_}, args) when ID.equal cst_id id_not && IArray.length args=1 ->
|
||||
(* [not a] --> [a, false] *)
|
||||
IArray.get args 0, false
|
||||
| Not u -> u, false
|
||||
| _ -> self, true
|
||||
|
||||
let eval id args =
|
||||
|
|
@ -77,7 +73,7 @@ module C = struct
|
|||
cst_view=Cst_def {
|
||||
pp=None; abs; ty=get_ty; relevant; do_cc; eval=eval id; }; }
|
||||
|
||||
let not = mk_cst id_not
|
||||
let not = T.not_
|
||||
let and_ = mk_cst id_and
|
||||
let or_ = mk_cst id_or
|
||||
let imply = mk_cst id_imply
|
||||
|
|
@ -116,17 +112,8 @@ let and_ st a b = and_l st [a;b]
|
|||
let or_ st a b = or_l st [a;b]
|
||||
let and_a st a = and_l st (IArray.to_list a)
|
||||
let or_a st a = or_l st (IArray.to_list a)
|
||||
|
||||
let eq = T.eq
|
||||
|
||||
let not_ st a =
|
||||
match as_id id_not a, T.view a with
|
||||
| _, Bool false -> T.true_ st
|
||||
| _, Bool true -> T.false_ st
|
||||
| Some args, _ ->
|
||||
assert (IArray.length args = 1);
|
||||
IArray.get args 0
|
||||
| None, _ -> T.app_cst st C.not (IArray.singleton a)
|
||||
let not_ = T.not_
|
||||
|
||||
let neq st a b = not_ st @@ eq st a b
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue