diff --git a/src/base-term/Base_types.ml b/src/base-term/Base_types.ml index d1803cd5..c139a3db 100644 --- a/src/base-term/Base_types.ml +++ b/src/base-term/Base_types.ml @@ -86,6 +86,7 @@ and data = { and cstor = { cstor_id: ID.t; cstor_is_a: ID.t; + mutable cstor_arity: int; cstor_args: select list lazy_t; cstor_ty_as_data: data; cstor_ty: ty lazy_t; @@ -946,6 +947,7 @@ module Cstor = struct type t = cstor = { cstor_id: ID.t; cstor_is_a: ID.t; + mutable cstor_arity: int; cstor_args: select list lazy_t; cstor_ty_as_data: data; cstor_ty: ty lazy_t; diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index b07da7a6..00899d0b 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -282,7 +282,12 @@ module Th_data = Sidekick_th_data.Make(struct | _ -> T_other t let mk_cstor tst c args : Term.t = Term.app_fun tst (Fun.cstor c) args - let mk_is_a tst c u : Term.t = Term.app_fun tst (Fun.is_a c) (IArray.singleton u) + let mk_is_a tst c u : Term.t = + if c.cstor_arity=0 then ( + Term.eq tst u (Term.const tst (Fun.cstor c)) + ) else ( + Term.app_fun tst (Fun.is_a c) (IArray.singleton u) + ) let ty_is_finite = Ty.finite let ty_set_is_finite = Ty.set_finite diff --git a/src/smtlib/Typecheck.ml b/src/smtlib/Typecheck.ml index 92887633..e5a565fc 100644 --- a/src/smtlib/Typecheck.ml +++ b/src/smtlib/Typecheck.ml @@ -392,6 +392,7 @@ and conv_statement_aux ctx (stmt:PA.statement) : Stmt.t list = cstor_id; cstor_is_a = ID.makef "(is _ %s)" cstor_name; (* every fun needs a name *) cstor_args=lazy (mk_selectors cstor); + cstor_arity=0; cstor_ty_as_data=data; cstor_ty=data.data_as_ty; } in @@ -428,7 +429,7 @@ and conv_statement_aux ctx (stmt:PA.statement) : Stmt.t list = (* now force definitions *) List.iter (fun {Data.data_cstors=lazy m;data_as_ty=lazy _;_} -> - ID.Map.iter (fun _ {Cstor.cstor_args=lazy _;_} -> ()) m; + ID.Map.iter (fun _ ({Cstor.cstor_args=lazy l;_} as r) -> r.cstor_arity <- List.length l) m; ()) l; [Stmt.Stmt_data l] diff --git a/src/th-data/Sidekick_th_data.ml b/src/th-data/Sidekick_th_data.ml index 1a5ce551..2e080b32 100644 --- a/src/th-data/Sidekick_th_data.ml +++ b/src/th-data/Sidekick_th_data.ml @@ -152,7 +152,8 @@ module Make(A : ARG) : S with module A = A = struct tst: T.state; cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *) cards: Card.t; (* remember finiteness *) - to_decide: bool ref N_tbl.t; (* set of terms to decide. true means already clausified *) + to_decide: unit N_tbl.t; (* set of terms to decide. *) + case_split_done: unit T.Tbl.t; (* set of terms for which case split is done *) (* TODO: also allocate a bit in CC to filter out quickly classes without cstors? *) (* TODO: bitfield for types with less than 62 cstors, to quickly detect conflict? *) } @@ -167,7 +168,7 @@ module Make(A : ARG) : S with module A = A = struct N_tbl.pop_levels self.to_decide n; () - (* TODO: select/is-a, with exhaustivity rule *) + (* TODO: select/is-a *) (* TODO: acyclicity *) (* attach data to constructor terms *) @@ -204,8 +205,8 @@ module Make(A : ARG) : S with module A = A = struct if Card.is_finite self.cards ty && not (N_tbl.mem self.to_decide n) then ( (* must decide this term *) Log.debugf 20 - (fun k->k "(@[%s.on-new-term.must-decide-finitey@ %a@])" name T.pp t); - N_tbl.add self.to_decide n (ref false); + (fun k->k "(@[%s.on-new-term.must-decide-finite-ty@ %a@])" name T.pp t); + N_tbl.add self.to_decide n (); ) | _ -> () @@ -249,37 +250,46 @@ module Make(A : ARG) : S with module A = A = struct | Ty_data {cstors} -> cstors | _ -> assert false + (* on final check, make sure we have done case split on all terms that + need it. *) let on_final_check (self:t) (solver:SI.t) (acts:SI.actions) _trail = let remaining_to_decide = N_tbl.to_iter self.to_decide - |> Iter.map (fun (n,r) -> SI.cc_find solver n, r) - |> Iter.filter (fun (n,r) -> not !r && not (N_tbl.mem self.cstors n)) + |> Iter.map (fun (n,_) -> SI.cc_find solver n) + |> Iter.filter + (fun n -> + not (N_tbl.mem self.cstors n) && + not (T.Tbl.mem self.case_split_done (N.term n))) |> Iter.to_rev_list in begin match remaining_to_decide with | [] -> () | l -> Log.debugf 10 - (fun k->k "(@[%s.must-decide@ %a@])" name - (Util.pp_list (Fmt.map fst N.pp)) l); + (fun k->k "(@[%s.final-check.must-decide@ %a@])" name (Util.pp_list N.pp) l); (* add clauses [∨_c is-c(t)] and [¬(is-a t) ∨ ¬(is-b t)] *) List.iter - (fun (n,r) -> - assert (not !r); + (fun n -> let t = N.term n in - let c = - cstors_of_ty (T.ty t) - |> Iter.map (fun c -> A.mk_is_a self.tst c t) - |> Iter.map (SI.mk_lit solver acts) - |> Iter.to_rev_list - in - r := true; - SI.add_clause_permanent solver acts c; - Iter.diagonal_l c - (fun (c1,c2) -> - SI.add_clause_permanent solver acts - [SI.Lit.neg c1; SI.Lit.neg c2]); - ()) + (* [t] might have been expanded already, in case of duplicates in [l] *) + if not @@ T.Tbl.mem self.case_split_done t then ( + T.Tbl.add self.case_split_done t (); + let c = + cstors_of_ty (T.ty t) + |> Iter.map (fun c -> A.mk_is_a self.tst c t) + |> Iter.map + (fun t -> + let lit = SI.mk_lit solver acts t in + (* TODO: set default polarity, depending on n° of args? *) + lit) + |> Iter.to_rev_list + in + SI.add_clause_permanent solver acts c; + Iter.diagonal_l c + (fun (c1,c2) -> + SI.add_clause_permanent solver acts + [SI.Lit.neg c1; SI.Lit.neg c2]); + )) l end; () @@ -289,6 +299,7 @@ module Make(A : ARG) : S with module A = A = struct tst=SI.tst solver; cstors=N_tbl.create ~size:32 (); to_decide=N_tbl.create ~size:16 (); + case_split_done=T.Tbl.create 16; cards=Card.create(); } in Log.debugf 1 (fun k->k "(setup :%s)" name); diff --git a/src/th-data/dune b/src/th-data/dune index 35e07fee..9dd86c15 100644 --- a/src/th-data/dune +++ b/src/th-data/dune @@ -4,5 +4,5 @@ (name Sidekick_th_data) (public_name sidekick.th-data) (libraries containers sidekick.core sidekick.util) - (flags :standard -open Sidekick_util)) + (flags :standard -open Sidekick_util -w -32)) ; TODO get warning back