Merge branch 'wip-defunctorize-terms'

This commit is contained in:
Simon Cruanes 2022-08-29 20:27:27 -04:00
commit 737a11504d
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
255 changed files with 13750 additions and 11880 deletions

2
.gitignore vendored
View file

@ -16,3 +16,5 @@ snapshots/
perf.*
.mypy_cache
*.gz
.git-blame-ignore-revs
*.json

View file

@ -22,6 +22,7 @@ build-dev:
clean:
@dune clean
@rm sidekick || true
test:
@dune runtest $(OPTS) --force --no-buffer
@ -32,31 +33,34 @@ DATE=$(shell date +%FT%H:%M)
snapshots:
@mkdir -p snapshots
$(TESTTOOL)-quick: snapshots
sidekick:
@ln -f -s _build/default/src/main/main.exe ./sidekick
$(TESTTOOL)-quick: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/quick-$(DATE).csv --task sidekick-smt-quick
$(TESTTOOL)-quick-proofs: snapshots
$(TESTTOOL)-quick-proofs: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/quick-$(DATE).csv --task sidekick-smt-quick-proofs --proof-dir out-proofs-$(DATE)/
$(TESTTOOL)-local: snapshots
$(TESTTOOL)-local: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/quick-$(DATE).csv --task sidekick-smt-local
$(TESTTOOL)-smt-QF_UF: snapshots
$(TESTTOOL)-smt-QF_UF: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_UF-$(DATE).csv --task sidekick-smt-nodir tests/QF_UF
$(TESTTOOL)-smt-QF_DT: snapshots
$(TESTTOOL)-smt-QF_DT: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_DT-$(DATE).csv --task sidekick-smt-nodir tests/QF_DT
$(TESTTOOL)-smt-QF_LRA: snapshots
$(TESTTOOL)-smt-QF_LRA: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_LRA-$(DATE).csv --task sidekick-smt-nodir tests/QF_LRA
$(TESTTOOL)-smt-QF_UFLRA: snapshots
$(TESTTOOL)-smt-QF_UFLRA: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_UFLRA-$(DATE).csv --task sidekick-smt-nodir tests/QF_UFLRA
$(TESTTOOL)-smt-QF_LIA: snapshots
$(TESTTOOL)-smt-QF_LIA: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_LRA-$(DATE).csv --task sidekick-smt-nodir tests/QF_LIA
$(TESTTOOL)-smt-QF_UFLIA: snapshots
$(TESTTOOL)-smt-QF_UFLIA: sidekick snapshots
$(TESTTOOL) run $(TESTOPTS) \
--csv snapshots/smt-QF_LRA-$(DATE).csv --task sidekick-smt-nodir tests/QF_UFLIA
@ -77,7 +81,7 @@ reindent:
@find src '(' -name '*.ml' -or -name '*.mli' ')' -print0 | xargs -0 echo "reindenting: "
@find src '(' -name '*.ml' -or -name '*.mli' ')' -print0 | xargs -0 ocp-indent -i
WATCH=@all
WATCH?=@all
watch:
dune build $(WATCH) -w $(OPTS)
#@dune build @all -w # TODO: once tests pass

View file

@ -38,7 +38,7 @@ OCaml prompt):
# #show Sidekick_base;;
module Sidekick_base :
sig
module Base_types = Sidekick_base__.Base_types
module Types_ = Sidekick_base__.Types_
...
end
```
@ -75,34 +75,28 @@ We're going to use these libraries:
main Solver, along with a few theories. Let us peek into it now:
```ocaml
# #require "sidekick-base.solver";;
# #show Sidekick_base_solver;;
module Sidekick_base_solver :
# #require "sidekick-base";;
# #show Sidekick_base.Solver;;
module Solver = Sidekick_base__.Solver
module Solver = Sidekick_base.Solver
module Solver :
sig
module Solver_arg : sig ... end
module Solver : sig ... end
module Th_data : sig ... end
module Th_bool : sig ... end
module Gensym : sig ... end
module Th_lra : sig ... end
val th_bool : Solver.theory
val th_data : Solver.theory
val th_lra : Solver.theory
end
type t = Solver.t
...
```
Let's bring more all these things into scope, and install some printers
for legibility:
```ocaml
# open Sidekick_core;;
# open Sidekick_base;;
# open Sidekick_base_solver;;
# open Sidekick_smt_solver;;
# #install_printer Term.pp;;
# #install_printer Lit.pp;;
# #install_printer Ty.pp;;
# #install_printer Fun.pp;;
# #install_printer Const.pp;;
# #install_printer Model.pp;;
# #install_printer Solver.Model.pp;;
```
## First steps in solving
@ -117,30 +111,24 @@ All terms in sidekick live in a store, which is necessary for _hashconsing_
in alternative implementations.)
```ocaml
# let tstore = Term.create ();;
# let tstore = Term.Store.create ();;
val tstore : Term.store = <abstr>
# Term.store_size tstore;;
- : int = 2
# Term.Store.size tstore;;
- : int = 0
```
Interesting, there are already two terms that are predefined.
Let's peek at them:
Let's look at some basic terms we can build immediately.
```ocaml
# let all_terms_init =
Term.store_iter tstore |> Iter.to_list |> List.sort Term.compare;;
val all_terms_init : Term.t list = [true; false]
# Term.true_ tstore;;
- : Term.t = true
- : Term.term = true
# (* check it's the same term *)
Term.(equal (true_ tstore) (List.hd all_terms_init));;
- : bool = true
# Term.false_ tstore;;
- : Term.term = false
# Term.(equal (false_ tstore) (List.hd all_terms_init));;
- : bool = false
# Term.eq tstore (Term.true_ tstore) (Term.false_ tstore);;
- : Term.term = (= Bool true false)
```
Cool. Similarly, we need to manipulate types.
@ -151,57 +139,60 @@ In general we'd need to carry around a type store as well.
The only predefined type is _bool_, the type of booleans:
```ocaml
# Ty.bool ();;
- : Ty.t = Bool
# Ty.bool tstore;;
- : Term.term = Bool
```
Now we can define new terms and constants. Let's try to define
a few boolean constants named "p", "q", "r":
```ocaml
# let p = Term.const_undefined tstore (ID.make "p") @@ Ty.bool();;
val p : Term.t = p
# let q = Term.const_undefined tstore (ID.make "q") @@ Ty.bool();;
val q : Term.t = q
# let r = Term.const_undefined tstore (ID.make "r") @@ Ty.bool();;
val r : Term.t = r
# let p = Uconst.uconst_of_str tstore "p" [] @@ Ty.bool tstore;;
val p : Term.term = p
# let q = Uconst.uconst_of_str tstore "q" [] @@ Ty.bool tstore;;
val q : Term.term = q
# let r = Uconst.uconst_of_str tstore "r" [] @@ Ty.bool tstore;;
val r : Term.term = r
# Term.ty p;;
- : Ty.t = Bool
- : Term.term = Bool
# Term.equal p q;;
- : bool = false
# Term.view p;;
- : Term.t Term.view = Sidekick_base.Term.App_fun (p/3, [||])
- : Term.view = Sidekick_base.Term.E_const p
# Term.store_iter tstore |> Iter.to_list |> List.sort Term.compare;;
- : Term.t list = [true; false; p; q; r]
# Term.equal p p;;
- : bool = true
```
We can now build formulas from these.
```ocaml
# let p_eq_q = Term.eq tstore p q;;
val p_eq_q : Term.t = (= p q)
val p_eq_q : Term.term = (= Bool p q)
# let p_imp_r = Form.imply tstore p r;;
val p_imp_r : Term.t = (=> p r)
val p_imp_r : Term.term = (=> p r)
```
### Using a solver.
We can create a solver by passing `Solver.create` a term store
and a type store (which in our case is simply `() : unit`).
and a proof trace (here, `Proof_trace.dummy` because we don't care about
proofs).
A list of theories can be added initially, or later using
`Solver.add_theory`.
```ocaml
# let solver = Solver.create ~theories:[th_bool] ~proof:(Proof.empty) tstore () ();;
val solver : Solver.t = <abstr>
# let proof = Proof_trace.dummy;;
val proof : Proof_trace.t = <abstr>
# let solver = Solver.create_default ~theories:[th_bool_static] ~proof tstore ();;
val solver : solver = <abstr>
# Solver.add_theory;;
- : Solver.t -> Solver.theory -> unit = <fun>
- : solver -> theory -> unit = <fun>
```
Alright, let's do some solving now ⚙️. We're going to assert
@ -211,18 +202,18 @@ We start with `p = q`.
```ocaml
# p_eq_q;;
- : Term.t = (= p q)
- : Term.term = (= Bool p q)
# Solver.assert_term solver p_eq_q;;
- : unit = ()
# Solver.solve ~assumptions:[] solver;;
- : Solver.res =
Sidekick_base_solver.Solver.Sat
Sidekick_smt_solver.Solver.Sat
(model
(true := true)
(false := false)
(p := true)
(false := $@c[0])
(q := true)
((= p q) := true))
((= Bool p q) := true)
(true := true)
(p := true))
```
It is satisfiable, and we got a model where "p" and "q" are both false.
@ -238,8 +229,8 @@ whether the assertions and hypotheses are satisfiable together.
~assumptions:[Solver.mk_lit_t solver p;
Solver.mk_lit_t solver q ~sign:false];;
- : Solver.res =
Sidekick_base_solver.Solver.Unsat
{Sidekick_base_solver.Solver.unsat_core = <fun>; unsat_proof_step = <fun>}
Sidekick_smt_solver.Solver.Unsat
{Sidekick_smt_solver.Solver.unsat_core = <fun>; unsat_step_id = <fun>}
```
Here it's unsat, because we asserted "p = q", and then assumed "p"
@ -253,40 +244,40 @@ Note that this doesn't affect satisfiability without assumptions:
```ocaml
# Solver.solve ~assumptions:[] solver;;
- : Solver.res =
Sidekick_base_solver.Solver.Sat
Sidekick_smt_solver.Solver.Sat
(model
(false := $@c[0])
(q := false)
((= Bool p q) := true)
(true := true)
(false := false)
(p := true)
(q := true)
((= p q) := true))
(p := false))
```
We can therefore add more formulas and see where it leads us.
```ocaml
# p_imp_r;;
- : Term.t = (=> p r)
- : Term.term = (=> p r)
# Solver.assert_term solver p_imp_r;;
- : unit = ()
# Solver.solve ~assumptions:[] solver;;
- : Solver.res =
Sidekick_base_solver.Solver.Sat
Sidekick_smt_solver.Solver.Sat
(model
(true := true)
(false := false)
(p := true)
(q := true)
(false := $@c[0])
(q := false)
(r := true)
((= p q) := true)
((=> p r) := true))
((= Bool p q) := true)
((or r (not p) false) := true)
(true := true)
(p := false))
```
Still satisfiable, but now we see `r` in the model, too. And now:
```ocaml
# let q_imp_not_r = Form.imply tstore q (Form.not_ tstore r);;
val q_imp_not_r : Term.t = (=> q (not r))
val q_imp_not_r : Term.term = (=> q (not r))
# Solver.assert_term solver q_imp_not_r;;
- : unit = ()
@ -295,8 +286,8 @@ val q_imp_not_r : Term.t = (=> q (not r))
# Solver.solve ~assumptions:[] solver;;
- : Solver.res =
Sidekick_base_solver.Solver.Unsat
{Sidekick_base_solver.Solver.unsat_core = <fun>; unsat_proof_step = <fun>}
Sidekick_smt_solver.Solver.Unsat
{Sidekick_smt_solver.Solver.unsat_core = <fun>; unsat_step_id = <fun>}
```
This time we got _unsat_ and there is no way of undoing it.
@ -310,25 +301,25 @@ We can solve linear real arithmetic problems as well.
Let's create a new solver and add the theory of reals to it.
```ocaml
# let solver = Solver.create ~theories:[th_bool; th_lra] ~proof:(Proof.empty) tstore () ();;
val solver : Solver.t = <abstr>
# let solver = Solver.create_default ~theories:[th_bool_static; th_lra] ~proof tstore ();;
val solver : solver = <abstr>
```
Create a few arithmetic constants.
```ocaml
# let real = Ty.real ();;
val real : Ty.t = Real
# let a = Term.const_undefined tstore (ID.make "a") real;;
val a : Term.t = a
# let b = Term.const_undefined tstore (ID.make "b") real;;
val b : Term.t = b
# let real = Ty.real tstore;;
val real : Term.term = Real
# let a = Uconst.uconst_of_str tstore "a" [] real;;
val a : Term.term = a
# let b = Uconst.uconst_of_str tstore "b" [] real;;
val b : Term.term = b
# Term.ty a;;
- : Ty.t = Real
- : Term.term = Real
# let a_leq_b = Term.LRA.(leq tstore a b);;
val a_leq_b : Term.t = (<= a b)
# let a_leq_b = LRA_term.leq tstore a b;;
val a_leq_b : Term.term = (<= a b)
```
We can play with assertions now:
@ -338,31 +329,39 @@ We can play with assertions now:
- : unit = ()
# Solver.solve ~assumptions:[] solver;;
- : Solver.res =
Sidekick_base_solver.Solver.Sat
Sidekick_smt_solver.Solver.Sat
(model
(true := true)
(false := false)
(a := 0)
((+ a) := $@c[0])
(0 := 0)
(false := $@c[5])
(b := 0)
((<= (+ a (* -1 b)) 0) := true)
(_sk_lra__le_comb0 := 0))
((+ a ((* -1) b)) := $@c[7])
((<= (+ a ((* -1) b))) := $@c[3])
((* -1) := $@c[6])
((<= (+ a ((* -1) b)) 0) := true)
(((* -1) b) := $@c[1])
(<= := $@c[2])
($_le_comb[0] := 0)
(+ := $@c[4])
(true := true))
# let a_geq_1 = Term.LRA.(geq tstore a (const tstore (Q.of_int 1)));;
val a_geq_1 : Term.t = (>= a 1)
# let b_leq_half = Term.LRA.(leq tstore b (const tstore (Q.of_string "1/2")));;
val b_leq_half : Term.t = (<= b 1/2)
# let a_geq_1 = LRA_term.geq tstore a (LRA_term.const tstore (Q.of_int 1));;
val a_geq_1 : Term.term = (>= a 1)
# let b_leq_half = LRA_term.(leq tstore b (LRA_term.const tstore (Q.of_string "1/2")));;
val b_leq_half : Term.term = (<= b 1/2)
# let res = Solver.solve solver
~assumptions:[Solver.mk_lit_t solver p;
Solver.mk_lit_t solver a_geq_1;
Solver.mk_lit_t solver b_leq_half];;
val res : Solver.res =
Sidekick_base_solver.Solver.Unsat
{Sidekick_base_solver.Solver.unsat_core = <fun>; unsat_proof_step = <fun>}
Sidekick_smt_solver.Solver.Unsat
{Sidekick_smt_solver.Solver.unsat_core = <fun>; unsat_step_id = <fun>}
# match res with Solver.Unsat {unsat_core=us; _} -> us() |> Iter.to_list | _ -> assert false;;
- : Proof.lit list = [(>= a 1); (<= b 1/2)]
- : Proof_trace.lit list = [(>= a 1); (<= b 1/2)]
```
This just showed that `a=1, b=1/2, a>=b` is unsatisfiable.
@ -378,41 +377,39 @@ We can define function symbols, not just constants. Let's also define `u`,
an uninterpreted type.
```ocaml
# let u = Ty.atomic_uninterpreted (ID.make "u");;
val u : Ty.t = u/9
# let u = Ty.uninterpreted_str tstore "u";;
val u : Term.term = u
# let u1 = Term.const_undefined tstore (ID.make "u1") u;;
val u1 : Term.t = u1
# let u2 = Term.const_undefined tstore (ID.make "u2") u;;
val u2 : Term.t = u2
# let u3 = Term.const_undefined tstore (ID.make "u3") u;;
val u3 : Term.t = u3
# let u1 = Uconst.uconst_of_str tstore "u1" [] u;;
val u1 : Term.term = u1
# let u2 = Uconst.uconst_of_str tstore "u2" [] u;;
val u2 : Term.term = u2
# let u3 = Uconst.uconst_of_str tstore "u3" [] u;;
val u3 : Term.term = u3
# let f1 = Fun.mk_undef' (ID.make "f1") [u] u;;
val f1 : Fun.t = f1/13
# Fun.view f1;;
- : Fun.view =
Sidekick_base.Fun.Fun_undef
{Sidekick_base.Base_types.fun_ty_args = [u/9]; fun_ty_ret = u/9}
# let f1 = Uconst.uconst_of_str tstore "f1" [u] u;;
val f1 : Term.term = f1
# Term.view f1;;
- : Term.view = Sidekick_base.Term.E_const f1
# let f1_u1 = Term.app_fun_l tstore f1 [u1];;
val f1_u1 : Term.t = (f1 u1)
# let f1_u1 = Term.app_l tstore f1 [u1];;
val f1_u1 : Term.term = (f1 u1)
# Term.ty f1_u1;;
- : Ty.t = u/9
- : Term.term = u
# Term.view f1_u1;;
- : Term.t Term.view = Sidekick_base.Term.App_fun (f1/13, [|u1|])
- : Term.view = Sidekick_base.Term.E_app (f1, u1)
```
Anyway, Sidekick knows how to reason about functions.
```ocaml
# let solver = Solver.create ~theories:[] ~proof:(Proof.empty) tstore () ();;
val solver : Solver.t = <abstr>
# let solver = Solver.create_default ~theories:[] ~proof tstore ();;
val solver : solver = <abstr>
# (* helper *)
let appf1 x = Term.app_fun_l tstore f1 x;;
val appf1 : Term.t list -> Term.t = <fun>
let appf1 x = Term.app_l tstore f1 x;;
val appf1 : Term.term list -> Term.term = <fun>
# Solver.assert_term solver (Term.eq tstore u2 (appf1 [u1]));;
- : unit = ()
@ -427,14 +424,14 @@ val appf1 : Term.t list -> Term.t = <fun>
# Solver.solve solver
~assumptions:[Solver.mk_lit_t solver ~sign:false (Term.eq tstore u1 (appf1[u1]))];;
- : Solver.res =
Sidekick_base_solver.Solver.Unsat
{Sidekick_base_solver.Solver.unsat_core = <fun>; unsat_proof_step = <fun>}
Sidekick_smt_solver.Solver.Unsat
{Sidekick_smt_solver.Solver.unsat_core = <fun>; unsat_step_id = <fun>}
# Solver.solve solver
~assumptions:[Solver.mk_lit_t solver ~sign:false (Term.eq tstore u2 u3)];;
- : Solver.res =
Sidekick_base_solver.Solver.Unsat
{Sidekick_base_solver.Solver.unsat_core = <fun>; unsat_proof_step = <fun>}
Sidekick_smt_solver.Solver.Unsat
{Sidekick_smt_solver.Solver.unsat_core = <fun>; unsat_step_id = <fun>}
```
Assuming: `f1(u1)=u2, f1(u2)=u3, f1^2(u1)=u1, f1^3(u1)=u1`,

2
dune
View file

@ -4,5 +4,5 @@
(_
(flags :standard -warn-error -a+8+9 -w +a-4-32-40-41-42-44-48-70 -color
always -strict-sequence -safe-string -short-paths)
(ocamlopt_flags :standard -O3 -color always -unbox-closures
(ocamlopt_flags :standard -O3 -color always -inline 30 -unbox-closures
-unbox-closures-factor 20)))

View file

@ -1,9 +1,6 @@
(** {1 simple sudoku solver} *)
(** simple sudoku solver *)
module Fmt = CCFormat
module Vec = Sidekick_util.Vec
module Log = Sidekick_util.Log
module Profile = Sidekick_util.Profile
open Sidekick_util
let errorf msg = Fmt.kasprintf failwith msg
@ -144,82 +141,84 @@ module B_ref = Sidekick_util.Backtrackable_ref
module Solver : sig
type t
val create : Grid.t -> t
val create : stat:Stat.t -> Grid.t -> t
val solve : t -> Grid.t option
end = struct
open Sidekick_sat.Solver_intf
open Sidekick_core
(* formulas *)
module F = struct
type t = bool * int * int * Cell.t
type Const.view += Cell_is of { x: int; y: int; value: Cell.t }
let equal (sign1, x1, y1, c1) (sign2, x2, y2, c2) =
sign1 = sign2 && x1 = x2 && y1 = y2 && Cell.equal c1 c2
let ops =
(module struct
let pp out = function
| Cell_is { x; y; value } ->
Fmt.fprintf out "(%d:%d=%a)" x y Cell.pp value
| _ -> ()
let hash (sign, x, y, c) =
CCHash.(combine4 (bool sign) (int x) (int y) (Cell.hash c))
let hash = function
| Cell_is { x; y; value } ->
Hash.(combine3 (int x) (int y) (Cell.hash value))
| _ -> assert false
let pp out (sign, x, y, c) =
Fmt.fprintf out "[@[(%d,%d) %s %a@]]" x y
(if sign then
"="
else
"!=")
Cell.pp c
let equal a b =
match a, b with
| Cell_is a, Cell_is b ->
a.x = b.x && a.y = b.y && Cell.equal a.value b.value
| _ -> false
end : Const.DYN_OPS)
let neg (sign, x, y, c) = not sign, x, y, c
module Sat = Sidekick_sat
let norm_sign ((sign, _, _, _) as f) =
if sign then
f, true
else
neg f, false
let mk_cell tst x y value : Term.t =
Term.const tst
@@ Const.make (Cell_is { x; y; value }) ops ~ty:(Term.bool tst)
let make sign x y (c : Cell.t) : t = sign, x, y, c
end
let mk_cell_lit ?sign tst x y value : Lit.t =
Lit.atom ?sign tst @@ mk_cell tst x y value
module Theory = struct
type proof = unit
type proof_step = unit
module Theory : sig
type t
module Lit = F
val grid : t -> Grid.t
val create : stat:Stat.t -> Term.store -> Grid.t -> t
val to_plugin : t -> Sat.plugin
end = struct
type t = {
tst: Term.store;
grid: Grid.t B_ref.t;
stat_check_full: int Stat.counter;
stat_conflict: int Stat.counter;
}
type lit = Lit.t
module Proof = Sidekick_sat.Proof_dummy.Make (Lit)
type t = { grid: Grid.t B_ref.t }
let create g : t = { grid = B_ref.create g }
let[@inline] grid self : Grid.t = B_ref.get self.grid
let[@inline] set_grid self g : unit = B_ref.set self.grid g
let push_level self = B_ref.push_level self.grid
let pop_levels self n = B_ref.pop_levels self.grid n
let pp_c_ = Fmt.(list ~sep:(return "@ ")) F.pp
let pp_c_ = Fmt.(list ~sep:(return "@ ")) Lit.pp
let[@inline] logs_conflict kind c : unit =
Log.debugf 4 (fun k -> k "(@[conflict.%s@ %a@])" kind pp_c_ c)
(* check that all cells are full *)
let check_full_ (self : t) (acts : (Lit.t, proof, proof_step) acts) : unit =
(*Profile.with_ "check-full" @@ fun () ->*)
let check_full_ (self : t) (acts : Sat.acts) : unit =
(*let@ () = Profile.with_ "check-full" in*)
let (module A) = acts in
Grid.all_cells (grid self) (fun (x, y, c) ->
if Cell.is_empty c then (
Stat.incr self.stat_check_full;
let c =
CCList.init 9 (fun c -> F.make true x y (Cell.make (c + 1)))
CCList.init 9 (fun c ->
mk_cell_lit self.tst x y (Cell.make (c + 1)))
in
Log.debugf 4 (fun k -> k "(@[add-clause@ %a@])" pp_c_ c);
A.add_clause ~keep:true c ()
A.add_clause ~keep:true c Proof_trace.dummy_step_id
))
(* check constraints *)
let check_ (self : t) (acts : (Lit.t, proof, proof_step) acts) : unit =
(*Profile.with_ "check-constraints" @@ fun () ->*)
let check_ (self : t) (acts : Sat.acts) : unit =
(*let@ () = Profile.with_ "check-constraints" in*)
Log.debugf 4 (fun k ->
k "(@[sudoku.check@ @[:g %a@]@])" Grid.pp (B_ref.get self.grid));
let (module A) = acts in
let[@inline] all_diff kind f =
let[@inline] all_diff c_kind f =
let pairs =
f (grid self)
|> Iter.flat_map (fun set ->
@ -230,9 +229,15 @@ end = struct
pairs (fun ((x1, y1, c1), (x2, y2, c2)) ->
if Cell.equal c1 c2 then (
assert (x1 <> x2 || y1 <> y2);
let c = [ F.make false x1 y1 c1; F.make false x2 y2 c2 ] in
logs_conflict ("all-diff." ^ kind) c;
A.raise_conflict c ()
let c =
[
mk_cell_lit self.tst ~sign:false x1 y1 c1;
mk_cell_lit self.tst ~sign:false x2 y2 c2;
]
in
Stat.incr self.stat_conflict;
logs_conflict c_kind c;
A.raise_conflict c Proof_trace.dummy_step_id
))
in
all_diff "rows" Grid.rows;
@ -240,69 +245,98 @@ end = struct
all_diff "squares" Grid.squares;
()
let trail_ (acts : (Lit.t, proof, proof_step) acts) =
let trail_ (acts : Sat.acts) =
let (module A) = acts in
A.iter_assumptions
(* update current grid with the given slice *)
let add_slice (self : t) (acts : (Lit.t, proof, proof_step) acts) : unit =
let add_slice (self : t) (acts : Sat.acts) : unit =
let (module A) = acts in
trail_ acts (function
| false, _, _, _ -> ()
| true, x, y, c ->
assert (Cell.is_full c);
let grid = grid self in
let c' = Grid.get grid x y in
if Cell.is_empty c' then
set_grid self (Grid.set grid x y c)
else if Cell.neq c c' then (
(* conflict: at most one value *)
let c = [ F.make false x y c; F.make false x y c' ] in
logs_conflict "at-most-one" c;
A.raise_conflict c ()
))
trail_ acts (fun lit ->
match Lit.sign lit, Term.view (Lit.term lit) with
| true, E_const { Const.c_view = Cell_is { x; y; value = c }; _ } ->
assert (Cell.is_full c);
let grid = grid self in
let c' = Grid.get grid x y in
if Cell.is_empty c' then
set_grid self (Grid.set grid x y c)
else if Cell.neq c c' then (
(* conflict: at most one value *)
let c =
[
mk_cell_lit self.tst ~sign:false x y c;
mk_cell_lit self.tst ~sign:false x y c';
]
in
logs_conflict "at-most-one" c;
A.raise_conflict c Proof_trace.dummy_step_id
)
| _ -> ())
let partial_check (self : t) acts : unit =
(* let@ () = Profile.with_ "partial-check" in*)
Log.debugf 4 (fun k ->
k "(@[sudoku.partial-check@ :trail [@[%a@]]@])" (Fmt.list F.pp)
(trail_ acts |> Iter.to_list));
k "(@[sudoku.partial-check@ :trail [@[%a@]]@])" (Fmt.iter Lit.pp)
(trail_ acts));
add_slice self acts;
check_ self acts
let final_check (self : t) acts : unit =
(*let@ () = Profile.with_ "final-check" in*)
Log.debugf 4 (fun k -> k "(@[sudoku.final-check@])");
check_full_ self acts;
check_ self acts
let create ~stat tst g : t =
{
tst;
grid = B_ref.create g;
stat_check_full = Stat.mk_int stat "sudoku.check-cell-full";
stat_conflict = Stat.mk_int stat "sudoku.conflict";
}
let to_plugin (self : t) : Sat.plugin =
Sat.mk_plugin_cdcl_t
~push_level:(fun () -> B_ref.push_level self.grid)
~pop_levels:(fun n -> B_ref.pop_levels self.grid n)
~partial_check:(partial_check self) ~final_check:(final_check self) ()
end
module S = Sidekick_sat.Make_cdcl_t (Theory)
type t = { grid0: Grid.t; solver: S.t }
type t = { grid0: Grid.t; tst: Term.store; theory: Theory.t; solver: Sat.t }
let solve (self : t) : _ option =
Profile.with_ "sudoku.solve" @@ fun () ->
let@ () = Profile.with_ "sudoku.solve" in
let assumptions =
Grid.all_cells self.grid0
|> Iter.filter (fun (_, _, c) -> Cell.is_full c)
|> Iter.map (fun (x, y, c) -> F.make true x y c)
|> Iter.map (fun (x, y, c) -> mk_cell_lit self.tst x y c)
|> Iter.to_rev_list
in
Log.debugf 2 (fun k ->
k "(@[sudoku.solve@ :assumptions %a@])" (Fmt.Dump.list F.pp) assumptions);
k "(@[sudoku.solve@ :assumptions %a@])" (Fmt.Dump.list Lit.pp)
assumptions);
let r =
match S.solve self.solver ~assumptions with
| S.Sat _ -> Some (Theory.grid (S.theory self.solver))
| S.Unsat _ -> None
match Sat.solve self.solver ~assumptions with
| Sat.Sat _ -> Some (Theory.grid self.theory)
| Sat.Unsat _ -> None
in
(* TODO: print some stats *)
r
let create g : t =
{ solver = S.create ~proof:() (Theory.create g); grid0 = g }
let create ~stat g : t =
let tst = Term.Store.create () in
let theory = Theory.create ~stat tst g in
let plugin : Sat.plugin = Theory.to_plugin theory in
{
tst;
solver = Sat.create ~stat ~proof:Proof_trace.dummy plugin;
theory;
grid0 = g;
}
end
let solve_grid (g : Grid.t) : Grid.t option =
let s = Solver.create g in
let solve_grid ~stat (g : Grid.t) : Grid.t option =
let s = Solver.create ~stat g in
Solver.solve s
module type CHRONO = sig
@ -318,8 +352,8 @@ let chrono ~pp_time : (module CHRONO) =
end in
(module M)
let solve_file ~pp_time file =
Profile.with_ "solve-file" @@ fun () ->
let solve_file ~use_stats ~pp_time file =
let@ () = Profile.with_ "solve-file" in
let open (val chrono ~pp_time) in
Format.printf "solve grids in file %S@." file;
@ -342,7 +376,8 @@ let solve_file ~pp_time file =
Format.printf
"@[<v>@,#########################@,@[<2>solve grid:@ %a@]@]@." Grid.pp g;
let open (val chrono ~pp_time) in
match solve_grid g with
let stat = Stat.create () in
(match solve_grid ~stat g with
| None -> Format.printf "no solution%t@." pp_elapsed
| Some g' when not @@ Grid.is_full g' ->
errorf "grid %a@ is not full" Grid.pp g'
@ -353,28 +388,34 @@ let solve_file ~pp_time file =
g
| Some g' ->
Format.printf "@[<v>@[<2>solution%t:@ %a@]@,###################@]@."
pp_elapsed Grid.pp g')
pp_elapsed Grid.pp g');
if use_stats then Fmt.printf "stats: %a@." Stat.pp stat)
grids;
Format.printf "@.solved %d grids%t@." (List.length grids) pp_elapsed;
()
let () =
Sidekick_tef.with_setup @@ fun () ->
let@ () = Sidekick_tef.with_setup in
Fmt.set_color_default true;
let files = ref [] in
let debug = ref 0 in
let pp_time = ref true in
let use_stats = ref false in
let opts =
[
"--debug", Arg.Set_int debug, " debug";
"-d", Arg.Set_int debug, " debug";
"--no-time", Arg.Clear pp_time, " do not print solve time";
"--stat", Arg.Set use_stats, " print statistics";
]
|> Arg.align
in
Arg.parse opts (fun f -> files := f :: !files) "sudoku_solve [options] <file>";
Log.set_debug !debug;
try List.iter (fun f -> solve_file ~pp_time:!pp_time f) !files
try
List.iter
(fun f -> solve_file ~pp_time:!pp_time ~use_stats:!use_stats f)
!files
with Failure msg | Invalid_argument msg ->
Format.printf "@{<Red>Error@}:@.%s@." msg;
exit 1

3
sidekick.sh Executable file
View file

@ -0,0 +1,3 @@
#!/bin/sh
OPTS="--profile=release --display=quiet"
exec dune exec $OPTS ./src/main/main.exe -- $@

View file

@ -1,9 +0,0 @@
(library
(name sidekick_base_solver)
(public_name sidekick-base.solver)
(synopsis "Instantiation of solver and theories for Sidekick_base")
(libraries sidekick-base sidekick.core sidekick.smt-solver
sidekick.th-bool-static sidekick.mini-cc sidekick.th-data
sidekick.arith-lra sidekick.zarith)
(flags :standard -warn-error -a+8 -safe-string -color always -open
Sidekick_util))

View file

@ -1,146 +0,0 @@
(** SMT Solver and Theories for [Sidekick_base].
This contains instances of the SMT solver, and theories,
from {!Sidekick_core}, using data structures from
{!Sidekick_base}. *)
open! Sidekick_base
(** Argument to the SMT solver *)
module Solver_arg = struct
module T = Sidekick_base.Solver_arg
module Lit = Sidekick_base.Lit
let cc_view = Term.cc_view
let mk_eq = Term.eq
let is_valid_literal _ = true
module P = Sidekick_base.Proof
type proof = P.t
type proof_step = P.proof_step
end
module Solver = Sidekick_smt_solver.Make (Solver_arg)
(** SMT solver, obtained from {!Sidekick_smt_solver} *)
(** Theory of datatypes *)
module Th_data = Sidekick_th_data.Make (struct
module S = Solver
open! Base_types
open! Sidekick_th_data
module Proof = Proof
module Cstor = Cstor
let as_datatype ty =
match Ty.view ty with
| Ty_atomic { def = Ty_data data; _ } ->
Ty_data { cstors = Lazy.force data.data.data_cstors |> ID.Map.values }
| Ty_atomic { def = _; args; finite = _ } ->
Ty_app { args = Iter.of_list args }
| Ty_bool | Ty_real | Ty_int -> Ty_app { args = Iter.empty }
let view_as_data t =
match Term.view t with
| Term.App_fun ({ fun_view = Fun.Fun_cstor c; _ }, args) -> T_cstor (c, args)
| Term.App_fun ({ fun_view = Fun.Fun_select sel; _ }, args) ->
assert (CCArray.length args = 1);
T_select (sel.select_cstor, sel.select_i, CCArray.get args 0)
| Term.App_fun ({ fun_view = Fun.Fun_is_a c; _ }, args) ->
assert (CCArray.length args = 1);
T_is_a (c, CCArray.get args 0)
| _ -> T_other t
let mk_eq = Term.eq
let mk_cstor tst c args : Term.t = Term.app_fun tst (Fun.cstor c) args
let mk_sel tst c i u = Term.app_fun tst (Fun.select_idx c i) [| u |]
let mk_is_a tst c u : Term.t =
if c.cstor_arity = 0 then
Term.eq tst u (Term.const tst (Fun.cstor c))
else
Term.app_fun tst (Fun.is_a c) [| u |]
let ty_is_finite = Ty.finite
let ty_set_is_finite = Ty.set_finite
module P = Proof
end)
(** Reducing boolean formulas to clauses *)
module Th_bool = Sidekick_th_bool_static.Make (struct
module S = Solver
type term = S.T.Term.t
include Form
let lemma_bool_tauto = Proof.lemma_bool_tauto
let lemma_bool_c = Proof.lemma_bool_c
let lemma_bool_equiv = Proof.lemma_bool_equiv
let lemma_ite_true = Proof.lemma_ite_true
let lemma_ite_false = Proof.lemma_ite_false
end)
module Gensym = struct
type t = { tst: Term.store; mutable fresh: int }
let create tst : t = { tst; fresh = 0 }
let tst self = self.tst
let copy s = { s with tst = s.tst }
let fresh_term (self : t) ~pre (ty : Ty.t) : Term.t =
let name = Printf.sprintf "_sk_lra_%s%d" pre self.fresh in
self.fresh <- 1 + self.fresh;
let id = ID.make name in
Term.const self.tst @@ Fun.mk_undef_const id ty
end
(** Theory of Linear Rational Arithmetic *)
module Th_lra = Sidekick_arith_lra.Make (struct
module S = Solver
module T = Term
module Z = Sidekick_zarith.Int
module Q = Sidekick_zarith.Rational
type term = S.T.Term.t
type ty = S.T.Ty.t
module LRA = Sidekick_arith_lra
let mk_eq = Form.eq
let mk_lra store l =
match l with
| LRA.LRA_other x -> x
| LRA.LRA_pred (p, x, y) -> T.lra store (Pred (p, x, y))
| LRA.LRA_op (op, x, y) -> T.lra store (Op (op, x, y))
| LRA.LRA_const c -> T.lra store (Const c)
| LRA.LRA_mult (c, x) -> T.lra store (Mult (c, x))
let mk_bool = T.bool
let rec view_as_lra t =
match T.view t with
| T.LRA l ->
let module LRA = Sidekick_arith_lra in
(match l with
| Const c -> LRA.LRA_const c
| Pred (p, a, b) -> LRA.LRA_pred (p, a, b)
| Op (op, a, b) -> LRA.LRA_op (op, a, b)
| Mult (c, x) -> LRA.LRA_mult (c, x)
| To_real x -> view_as_lra x
| Var x -> LRA.LRA_other x)
| T.Eq (a, b) when Ty.equal (T.ty a) (Ty.real ()) -> LRA.LRA_pred (Eq, a, b)
| _ -> LRA.LRA_other t
let ty_lra _st = Ty.real ()
let has_ty_real t = Ty.equal (T.ty t) (Ty.real ())
let lemma_lra = Proof.lemma_lra
module Gensym = Gensym
end)
let th_bool : Solver.theory = Th_bool.theory
let th_data : Solver.theory = Th_data.theory
let th_lra : Solver.theory = Th_lra.theory

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,7 @@
(** {1 Configuration} *)
(** Configuration *)
type 'a sequence = ('a -> unit) -> unit
module Key = Het.Key
module Key = CCHet.Key
type pair = Het.pair = Pair : 'a Key.t * 'a -> pair
type pair = CCHet.pair = Pair : 'a Key.t * 'a -> pair
include CCHet.Map
include Het.Map

View file

@ -1,6 +1,4 @@
(** {1 Configuration} *)
type 'a sequence = ('a -> unit) -> unit
(** Configuration *)
module Key : sig
type 'a t
@ -26,9 +24,9 @@ val find_exn : 'a Key.t -> t -> 'a
type pair = Pair : 'a Key.t * 'a -> pair
val iter : (pair -> unit) -> t -> unit
val to_iter : t -> pair sequence
val of_iter : pair sequence -> t
val add_iter : t -> pair sequence -> t
val to_iter : t -> pair Iter.t
val of_iter : pair Iter.t -> t
val add_iter : t -> pair Iter.t -> t
val add_list : t -> pair list -> t
val of_list : pair list -> t
val to_list : t -> pair list

148
src/base/Data_ty.ml Normal file
View file

@ -0,0 +1,148 @@
open Types_
type select = Types_.select = {
select_id: ID.t;
select_cstor: cstor;
select_ty: ty lazy_t;
select_i: int;
}
type cstor = Types_.cstor = {
cstor_id: ID.t;
cstor_is_a: ID.t;
mutable cstor_arity: int;
cstor_args: select list lazy_t;
cstor_ty_as_data: data;
cstor_ty: ty lazy_t;
}
type t = data = {
data_id: ID.t;
data_cstors: cstor ID.Map.t lazy_t;
data_as_ty: ty lazy_t;
}
let pp out d = ID.pp out d.data_id
let equal a b = ID.equal a.data_id b.data_id
let hash a = ID.hash a.data_id
(** Datatype selectors.
A selector is a kind of function that allows to obtain an argument
of a given constructor. *)
module Select = struct
type t = Types_.select = {
select_id: ID.t;
select_cstor: cstor;
select_ty: ty lazy_t;
select_i: int;
}
let ty sel = Lazy.force sel.select_ty
let equal a b =
ID.equal a.select_id b.select_id
&& ID.equal a.select_cstor.cstor_id b.select_cstor.cstor_id
&& a.select_i = b.select_i
let hash a =
Hash.combine4 1952 (ID.hash a.select_id)
(ID.hash a.select_cstor.cstor_id)
(Hash.int a.select_i)
let pp out self =
Fmt.fprintf out "select.%a[%d]" ID.pp self.select_cstor.cstor_id
self.select_i
end
(** Datatype constructors.
A datatype has one or more constructors, each of which is a special
kind of function symbol. Constructors are injective and pairwise distinct. *)
module Cstor = struct
type t = cstor
let hash c = ID.hash c.cstor_id
let ty_args c = Lazy.force c.cstor_args |> List.map Select.ty
let select_idx c i =
let (lazy sels) = c.cstor_args in
if i >= List.length sels then invalid_arg "cstor.select_idx: out of bound";
List.nth sels i
let equal a b = ID.equal a.cstor_id b.cstor_id
let pp out c = ID.pp out c.cstor_id
end
type Const.view +=
| Data of data
| Cstor of cstor
| Select of select
| Is_a of cstor
let ops =
(module struct
let pp out = function
| Data d -> pp out d
| Cstor c -> Cstor.pp out c
| Select s -> Select.pp out s
| Is_a c -> Fmt.fprintf out "(_ is %a)" Cstor.pp c
| _ -> assert false
let equal a b =
match a, b with
| Data a, Data b -> equal a b
| Cstor a, Cstor b -> Cstor.equal a b
| Select a, Select b -> Select.equal a b
| Is_a a, Is_a b -> Cstor.equal a b
| _ -> false
let hash = function
| Data d -> Hash.combine2 592 (hash d)
| Cstor c -> Hash.combine2 593 (Cstor.hash c)
| Select s -> Hash.combine2 594 (Select.hash s)
| Is_a c -> Hash.combine2 595 (Cstor.hash c)
| _ -> assert false
end : Const.DYN_OPS)
let data tst d : Term.t =
Term.const tst @@ Const.make (Data d) ops ~ty:(Term.type_ tst)
let cstor tst c : Term.t =
let ty_ret = Lazy.force c.cstor_ty in
let ty_args =
List.map (fun s -> Lazy.force s.select_ty) (Lazy.force c.cstor_args)
in
let ty = Term.arrow_l tst ty_args ty_ret in
Term.const tst @@ Const.make (Cstor c) ops ~ty
let select tst s : Term.t =
let ty_ret = Lazy.force s.select_ty in
let ty_arg = data tst s.select_cstor.cstor_ty_as_data in
let ty = Term.arrow tst ty_arg ty_ret in
Term.const tst @@ Const.make (Select s) ops ~ty
let is_a tst c : Term.t =
let ty_arg = Lazy.force c.cstor_ty in
let ty = Term.arrow tst ty_arg (Term.bool tst) in
Term.const tst @@ Const.make (Is_a c) ops ~ty
let as_data t =
match Term.view t with
| E_const { Const.c_view = Data d; _ } -> Some d
| _ -> None
let as_cstor t =
match Term.view t with
| E_const { Const.c_view = Cstor c; _ } -> Some c
| _ -> None
let as_select t =
match Term.view t with
| E_const { Const.c_view = Select s; _ } -> Some s
| _ -> None
let as_is_a t =
match Term.view t with
| E_const { Const.c_view = Is_a c; _ } -> Some c
| _ -> None

59
src/base/Data_ty.mli Normal file
View file

@ -0,0 +1,59 @@
open Types_
type select = Types_.select = {
select_id: ID.t;
select_cstor: cstor;
select_ty: ty lazy_t;
select_i: int;
}
type cstor = Types_.cstor = {
cstor_id: ID.t;
cstor_is_a: ID.t;
mutable cstor_arity: int;
cstor_args: select list lazy_t;
cstor_ty_as_data: data;
cstor_ty: ty lazy_t;
}
type t = data = {
data_id: ID.t;
data_cstors: cstor ID.Map.t lazy_t;
data_as_ty: ty lazy_t;
}
type Const.view +=
private
| Data of data
| Cstor of cstor
| Select of select
| Is_a of cstor
include Sidekick_sigs.EQ_HASH_PRINT with type t := t
module Select : sig
type t = select
include Sidekick_sigs.EQ_HASH_PRINT with type t := t
end
module Cstor : sig
type t = cstor
val ty_args : t -> ty list
val select_idx : t -> int -> select
include Sidekick_sigs.EQ_HASH_PRINT with type t := t
end
val data : Term.store -> t -> Term.t
val cstor : Term.store -> cstor -> Term.t
val select : Term.store -> select -> Term.t
val is_a : Term.store -> cstor -> Term.t
(* TODO: select_ : store -> cstor -> int -> term *)
val as_data : ty -> data option
val as_select : term -> select option
val as_cstor : term -> cstor option
val as_is_a : term -> cstor option

View file

@ -1,58 +1,128 @@
(** Formulas (boolean terms).
open Sidekick_core
module T = Term
This module defines function symbols, constants, and views
to manipulate boolean formulas in {!Sidekick_base}.
This is useful to have the ability to use boolean connectives instead
of being limited to clauses; by using {!Sidekick_th_bool_static},
the formulas are turned into clauses automatically for you.
*)
type term = Term.t
module T = Base_types.Term
module Ty = Base_types.Ty
module Fun = Base_types.Fun
module Value = Base_types.Value
open Sidekick_th_bool_static
type 'a view = 'a Sidekick_core.Bool_view.t =
| B_bool of bool
| B_not of 'a
| B_and of 'a list
| B_or of 'a list
| B_imply of 'a * 'a
| B_equiv of 'a * 'a
| B_xor of 'a * 'a
| B_eq of 'a * 'a
| B_neq of 'a * 'a
| B_ite of 'a * 'a * 'a
| B_atom of 'a
exception Not_a_th_term
type Const.view += C_and | C_or | C_imply
let id_and = ID.make "and"
let id_or = ID.make "or"
let id_imply = ID.make "=>"
let ops : Const.ops =
(module struct
let pp out = function
| C_and -> Fmt.string out "and"
| C_or -> Fmt.string out "or"
| C_imply -> Fmt.string out "=>"
| _ -> assert false
let view_id fid args =
if ID.equal fid id_and then
B_and (CCArray.to_iter args)
else if ID.equal fid id_or then
B_or (CCArray.to_iter args)
else if ID.equal fid id_imply && CCArray.length args >= 2 then (
(* conclusion is stored last *)
let len = CCArray.length args in
B_imply
(Iter.of_array args |> Iter.take (len - 1), CCArray.get args (len - 1))
) else
raise_notrace Not_a_th_term
let equal a b =
match a, b with
| C_and, C_and | C_or, C_or | C_imply, C_imply -> true
| _ -> false
let view_as_bool (t : T.t) : (T.t, _) bool_view =
match T.view t with
| Bool b -> B_bool b
| Not u -> B_not u
| Eq (a, b) when Ty.is_bool (T.ty a) -> B_equiv (a, b)
| Ite (a, b, c) -> B_ite (a, b, c)
| App_fun ({ fun_id; _ }, args) ->
(try view_id fun_id args with Not_a_th_term -> B_atom t)
let hash = function
| C_and -> Hash.int 425
| C_or -> Hash.int 426
| C_imply -> Hash.int 427
| _ -> assert false
end)
(* ### view *)
let view (t : T.t) : T.t view =
let hd, args = T.unfold_app t in
match T.view hd, args with
| E_const { Const.c_view = T.C_true; _ }, [] -> B_bool true
| E_const { Const.c_view = T.C_false; _ }, [] -> B_bool false
| E_const { Const.c_view = T.C_not; _ }, [ a ] -> B_not a
| E_const { Const.c_view = T.C_eq; _ }, [ _ty; a; b ] ->
if Ty.is_bool a then
B_equiv (a, b)
else
B_eq (a, b)
| E_const { Const.c_view = T.C_ite; _ }, [ _ty; a; b; c ] -> B_ite (a, b, c)
| E_const { Const.c_view = C_imply; _ }, [ a; b ] -> B_imply (a, b)
| E_app_fold { f; args; acc0 }, [] ->
(match T.view f, T.view acc0 with
| ( E_const { Const.c_view = C_and; _ },
E_const { Const.c_view = T.C_true; _ } ) ->
B_and args
| ( E_const { Const.c_view = C_or; _ },
E_const { Const.c_view = T.C_false; _ } ) ->
B_or args
| _ -> B_atom t)
| _ -> B_atom t
module Funs = struct
let get_ty _ _ = Ty.bool ()
let ty2b_ tst =
let bool = Term.bool tst in
Term.arrow_l tst [ bool; bool ] bool
let abs ~self _a =
match T.view self with
| Not u -> u, false
| _ -> self, true
let c_and tst : Const.t = Const.make C_and ops ~ty:(ty2b_ tst)
let c_or tst : Const.t = Const.make C_or ops ~ty:(ty2b_ tst)
let c_imply tst : Const.t = Const.make C_imply ops ~ty:(ty2b_ tst)
(* no congruence closure for boolean terms *)
let relevant _id _ _ = false
let and_l tst = function
| [] -> T.true_ tst
| [ x ] -> x
| l ->
Term.app_fold tst l ~f:(Term.const tst @@ c_and tst) ~acc0:(T.true_ tst)
let or_l tst = function
| [] -> T.false_ tst
| [ x ] -> x
| l ->
Term.app_fold tst l ~f:(Term.const tst @@ c_or tst) ~acc0:(T.false_ tst)
let bool = Term.bool_val
let and_ tst a b = and_l tst [ a; b ]
let or_ tst a b = or_l tst [ a; b ]
let imply tst a b : Term.t = T.app_l tst (T.const tst @@ c_imply tst) [ a; b ]
let eq = T.eq
let not_ = T.not
let ite = T.ite
let neq st a b = not_ st @@ eq st a b
let imply_l tst xs y = List.fold_right (imply tst) xs y
let equiv tst a b =
if (not (T.is_bool (T.ty a))) || not (T.is_bool (T.ty b)) then
failwith "Form.equiv: takes boolean arguments";
T.eq tst a b
let xor tst a b = not_ tst (equiv tst a b)
let distinct_l tst l =
match l with
| [] | [ _ ] -> T.true_ tst
| l ->
(* turn into [and_{i<j} t_i != t_j] *)
let cs = CCList.diagonal l |> List.map (fun (a, b) -> neq tst a b) in
and_l tst cs
let mk_of_view tst = function
| B_bool b -> T.bool_val tst b
| B_atom t -> t
| B_and l -> and_l tst l
| B_or l -> or_l tst l
| B_imply (a, b) -> imply tst a b
| B_ite (a, b, c) -> ite tst a b c
| B_equiv (a, b) -> equiv tst a b
| B_xor (a, b) -> not_ tst (equiv tst a b)
| B_eq (a, b) -> T.eq tst a b
| B_neq (a, b) -> not_ tst (T.eq tst a b)
| B_not t -> not_ tst t
(*
let eval id args =
let open Value in
match view_id id args with
@ -79,126 +149,4 @@ module Funs = struct
| B_opaque_bool t -> Error.errorf "cannot evaluate opaque bool %a" pp t
| B_not _ | B_and _ | B_or _ | B_imply _ ->
Error.errorf "non boolean value in boolean connective"
let mk_fun ?(do_cc = false) id : Fun.t =
{
fun_id = id;
fun_view =
Fun_def { pp = None; abs; ty = get_ty; relevant; do_cc; eval = eval id };
}
let and_ = mk_fun id_and
let or_ = mk_fun id_or
let imply = mk_fun id_imply
let ite = T.ite
end
let as_id id (t : T.t) : T.t array option =
match T.view t with
| App_fun ({ fun_id; _ }, args) when ID.equal id fun_id -> Some args
| _ -> None
(* flatten terms of the given ID *)
let flatten_id op sign (l : T.t list) : T.t list =
CCList.flat_map
(fun t ->
match as_id op t with
| Some args -> CCArray.to_list args
| None when (sign && T.is_true t) || ((not sign) && T.is_false t) ->
[] (* idempotent *)
| None -> [ t ])
l
let and_l st l =
match flatten_id id_and true l with
| [] -> T.true_ st
| l when List.exists T.is_false l -> T.false_ st
| [ x ] -> x
| args -> T.app_fun st Funs.and_ (CCArray.of_list args)
let or_l st l =
match flatten_id id_or false l with
| [] -> T.false_ st
| l when List.exists T.is_true l -> T.true_ st
| [ x ] -> x
| args -> T.app_fun st Funs.or_ (CCArray.of_list args)
let and_ st a b = and_l st [ a; b ]
let or_ st a b = or_l st [ a; b ]
let and_a st a = and_l st (CCArray.to_list a)
let or_a st a = or_l st (CCArray.to_list a)
let eq = T.eq
let not_ = T.not_
let ite st a b c =
match T.view a with
| T.Bool ba ->
if ba then
b
else
c
| _ -> T.ite st a b c
let equiv st a b =
if T.equal a b then
T.true_ st
else if T.is_true a then
b
else if T.is_true b then
a
else if T.is_false a then
not_ st b
else if T.is_false b then
not_ st a
else
T.eq st a b
let neq st a b = not_ st @@ eq st a b
let imply_a st xs y =
if Array.length xs = 0 then
y
else
T.app_fun st Funs.imply (CCArray.append xs [| y |])
let imply_l st xs y =
match xs with
| [] -> y
| _ -> imply_a st (CCArray.of_list xs) y
let imply st a b = imply_a st [| a |] b
let xor st a b = not_ st (equiv st a b)
let distinct_l tst l =
match l with
| [] | [ _ ] -> T.true_ tst
| l ->
(* turn into [and_{i<j} t_i != t_j] *)
let cs = CCList.diagonal l |> List.map (fun (a, b) -> neq tst a b) in
and_l tst cs
let mk_bool st = function
| B_bool b -> T.bool st b
| B_atom t -> t
| B_and l -> and_a st l
| B_or l -> or_a st l
| B_imply (a, b) -> imply_a st a b
| B_ite (a, b, c) -> ite st a b c
| B_equiv (a, b) -> equiv st a b
| B_xor (a, b) -> not_ st (equiv st a b)
| B_eq (a, b) -> T.eq st a b
| B_neq (a, b) -> not_ st (T.eq st a b)
| B_not t -> not_ st t
| B_opaque_bool t -> t
module Gensym = struct
type t = { tst: T.store; mutable fresh: int }
let create tst : t = { tst; fresh = 0 }
let fresh_term (self : t) ~pre (ty : Ty.t) : T.t =
let name = Printf.sprintf "_tseitin_%s%d" pre self.fresh in
self.fresh <- 1 + self.fresh;
let id = ID.make name in
T.const self.tst @@ Fun.mk_undef_const id ty
end
*)

49
src/base/Form.mli Normal file
View file

@ -0,0 +1,49 @@
(** Formulas (boolean terms).
This module defines function symbols, constants, and views
to manipulate boolean formulas in {!Sidekick_base}.
This is useful to have the ability to use boolean connectives instead
of being limited to clauses; by using {!Sidekick_th_bool_static},
the formulas are turned into clauses automatically for you.
*)
open Types_
type term = Term.t
type 'a view = 'a Sidekick_core.Bool_view.t =
| B_bool of bool
| B_not of 'a
| B_and of 'a list
| B_or of 'a list
| B_imply of 'a * 'a
| B_equiv of 'a * 'a
| B_xor of 'a * 'a
| B_eq of 'a * 'a
| B_neq of 'a * 'a
| B_ite of 'a * 'a * 'a
| B_atom of 'a
val view : term -> term view
val bool : Term.store -> bool -> term
val not_ : Term.store -> term -> term
val and_ : Term.store -> term -> term -> term
val or_ : Term.store -> term -> term -> term
val eq : Term.store -> term -> term -> term
val neq : Term.store -> term -> term -> term
val imply : Term.store -> term -> term -> term
val equiv : Term.store -> term -> term -> term
val xor : Term.store -> term -> term -> term
val ite : Term.store -> term -> term -> term -> term
val distinct_l : Term.store -> term list -> term
(* *)
val and_l : Term.store -> term list -> term
val or_l : Term.store -> term list -> term
val imply_l : Term.store -> term list -> term -> term
val mk_of_view : Term.store -> term view -> term
(* TODO?
val make : Term.store -> (term, term list) view -> term
*)

View file

@ -74,58 +74,6 @@ let pair_of_e_pair (E_pair (k, e)) =
| K.Store v -> Pair (k, v)
| _ -> assert false
module Tbl = struct
module M = Hashtbl.Make (struct
type t = int
let equal (i : int) j = i = j
let hash (i : int) = Hashtbl.hash i
end)
type t = exn_pair M.t
let create ?(size = 16) () = M.create size
let mem t k = M.mem t (Key.id k)
let find_exn (type a) t (k : a Key.t) : a =
let module K = (val k) in
let (E_pair (_, v)) = M.find t K.id in
match v with
| K.Store v -> v
| _ -> assert false
let find t k = try Some (find_exn t k) with Not_found -> None
let add_pair_ t p =
let (Pair (k, v)) = p in
let module K = (val k) in
let p = E_pair (k, K.Store v) in
M.replace t K.id p
let add t k v = add_pair_ t (Pair (k, v))
let remove (type a) t (k : a Key.t) =
let module K = (val k) in
M.remove t K.id
let length t = M.length t
let iter f t = M.iter (fun _ pair -> f (pair_of_e_pair pair)) t
let to_iter t yield = iter yield t
let to_list t = M.fold (fun _ p l -> pair_of_e_pair p :: l) t []
let add_list t l = List.iter (add_pair_ t) l
let add_iter t seq = seq (add_pair_ t)
let of_list l =
let t = create () in
add_list t l;
t
let of_iter seq =
let t = create () in
add_iter t seq;
t
end
module Map = struct
module M = Map.Make (struct
type t = int

View file

@ -1,5 +1,3 @@
(* This file is free software, part of containers. See file "license" for more details. *)
(** {1 Associative containers with Heterogeneous Values}
This is similar to {!CCMixtbl}, but the injection is directly used as
@ -21,29 +19,6 @@ end
type pair = Pair : 'a Key.t * 'a -> pair
(** {2 Imperative table indexed by [Key]} *)
module Tbl : sig
type t
val create : ?size:int -> unit -> t
val mem : t -> _ Key.t -> bool
val add : t -> 'a Key.t -> 'a -> unit
val remove : t -> _ Key.t -> unit
val length : t -> int
val find : t -> 'a Key.t -> 'a option
val find_exn : t -> 'a Key.t -> 'a
(** @raise Not_found if the key is not in the table. *)
val iter : (pair -> unit) -> t -> unit
val to_iter : t -> pair iter
val of_iter : pair iter -> t
val add_iter : t -> pair iter -> unit
val add_list : t -> pair list -> unit
val of_list : pair list -> t
val to_list : t -> pair list
end
(** {2 Immutable map} *)
module Map : sig
type t

View file

@ -16,13 +16,13 @@ let to_string id = id.name
let equal a b = a.id = b.id
let compare a b = CCInt.compare a.id b.id
let hash a = CCHash.int a.id
let pp out a = Format.fprintf out "%s/%d" a.name a.id
let pp_full out a = Format.fprintf out "%s/%d" a.name a.id
let pp_name out a = CCFormat.string out a.name
let pp = pp_name
let to_string_full a = Printf.sprintf "%s/%d" a.name a.id
module AsKey = struct
type t_ = t
type t = t_
type nonrec t = t
let equal = equal
let compare = compare

View file

@ -37,12 +37,10 @@ val to_string : t -> string
val to_string_full : t -> string
(** Printer name and unique counter for this ID. *)
include Intf.EQ with type t := t
include Intf.ORD with type t := t
include Intf.HASH with type t := t
include Intf.PRINT with type t := t
include Sidekick_sigs.EQ_ORD_HASH_PRINT with type t := t
val pp_name : t CCFormat.printer
val pp_full : t CCFormat.printer
module Map : CCMap.S with type key = t
module Set : CCSet.S with type elt = t

70
src/base/LIA_term.ml Normal file
View file

@ -0,0 +1,70 @@
open struct
let hash_z = Z.hash
end
module LIA_pred = LRA_term.Pred
module LIA_op = LRA_term.Op
module LIA_view = struct
type 'a t =
| LRA_pred of LIA_pred.t * 'a * 'a
| LRA_op of LIA_op.t * 'a * 'a
| LRA_mult of Z.t * 'a
| LRA_const of Z.t
| LRA_other of 'a
let map ~f_c f (l : _ t) : _ t =
match l with
| LRA_pred (p, a, b) -> LRA_pred (p, f a, f b)
| LRA_op (p, a, b) -> LRA_op (p, f a, f b)
| LRA_mult (n, a) -> LRA_mult (f_c n, f a)
| LRA_const c -> LRA_const (f_c c)
| LRA_other x -> LRA_other (f x)
let iter f l : unit =
match l with
| LRA_pred (_, a, b) | LRA_op (_, a, b) ->
f a;
f b
| LRA_mult (_, x) | LRA_other x -> f x
| LRA_const _ -> ()
let pp ~pp_t out = function
| LRA_pred (p, a, b) ->
Fmt.fprintf out "(@[%a@ %a@ %a@])" LRA_term.Pred.pp p pp_t a pp_t b
| LRA_op (p, a, b) ->
Fmt.fprintf out "(@[%a@ %a@ %a@])" LRA_term.Op.pp p pp_t a pp_t b
| LRA_mult (n, x) -> Fmt.fprintf out "(@[*@ %a@ %a@])" Z.pp_print n pp_t x
| LRA_const n -> Z.pp_print out n
| LRA_other x -> pp_t out x
let hash ~sub_hash = function
| LRA_pred (p, a, b) ->
Hash.combine4 81 (Hash.poly p) (sub_hash a) (sub_hash b)
| LRA_op (p, a, b) ->
Hash.combine4 82 (Hash.poly p) (sub_hash a) (sub_hash b)
| LRA_mult (n, x) -> Hash.combine3 83 (hash_z n) (sub_hash x)
| LRA_const n -> Hash.combine2 84 (hash_z n)
| LRA_other x -> sub_hash x
let equal ~sub_eq l1 l2 =
match l1, l2 with
| LRA_pred (p1, a1, b1), LRA_pred (p2, a2, b2) ->
p1 = p2 && sub_eq a1 a2 && sub_eq b1 b2
| LRA_op (p1, a1, b1), LRA_op (p2, a2, b2) ->
p1 = p2 && sub_eq a1 a2 && sub_eq b1 b2
| LRA_const a1, LRA_const a2 -> Z.equal a1 a2
| LRA_mult (n1, x1), LRA_mult (n2, x2) -> Z.equal n1 n2 && sub_eq x1 x2
| LRA_other x1, LRA_other x2 -> sub_eq x1 x2
| (LRA_pred _ | LRA_op _ | LRA_const _ | LRA_mult _ | LRA_other _), _ ->
false
(* convert the whole structure to reals *)
let to_lra f l : _ LRA_term.View.t =
match l with
| LRA_pred (p, a, b) -> LRA_term.View.LRA_pred (p, f a, f b)
| LRA_op (op, a, b) -> LRA_term.View.LRA_op (op, f a, f b)
| LRA_mult (c, x) -> LRA_term.View.LRA_mult (Q.of_bigint c, f x)
| LRA_const x -> LRA_term.View.LRA_const (Q.of_bigint x)
| LRA_other v -> LRA_term.View.LRA_other (f v)
end

176
src/base/LRA_term.ml Normal file
View file

@ -0,0 +1,176 @@
open Sidekick_core
module T = Term
open struct
let hash_z = Z.hash
let[@inline] hash_q q = CCHash.combine2 (hash_z (Q.num q)) (hash_z (Q.den q))
end
module Pred = struct
type t = Sidekick_th_lra.Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
let to_string = function
| Lt -> "<"
| Leq -> "<="
| Neq -> "!=_LRA"
| Eq -> "=_LRA"
| Gt -> ">"
| Geq -> ">="
let equal : t -> t -> bool = ( = )
let hash : t -> int = Hashtbl.hash
let pp out p = Fmt.string out (to_string p)
end
module Op = struct
type t = Sidekick_th_lra.op = Plus | Minus
let to_string = function
| Plus -> "+"
| Minus -> "-"
let equal : t -> t -> bool = ( = )
let hash : t -> int = Hashtbl.hash
let pp out p = Fmt.string out (to_string p)
end
module View = struct
include Sidekick_th_lra
type 'a t = (Q.t, 'a) lra_view
let map ~f_c f (l : _ t) : _ t =
match l with
| LRA_pred (p, a, b) -> LRA_pred (p, f a, f b)
| LRA_op (p, a, b) -> LRA_op (p, f a, f b)
| LRA_mult (n, a) -> LRA_mult (f_c n, f a)
| LRA_const c -> LRA_const (f_c c)
| LRA_other x -> LRA_other (f x)
let iter f l : unit =
match l with
| LRA_pred (_, a, b) | LRA_op (_, a, b) ->
f a;
f b
| LRA_mult (_, x) | LRA_other x -> f x
| LRA_const _ -> ()
let pp ~pp_t out = function
| LRA_pred (p, a, b) ->
Fmt.fprintf out "(@[%s@ %a@ %a@])" (Pred.to_string p) pp_t a pp_t b
| LRA_op (p, a, b) ->
Fmt.fprintf out "(@[%s@ %a@ %a@])" (Op.to_string p) pp_t a pp_t b
| LRA_mult (n, x) -> Fmt.fprintf out "(@[*@ %a@ %a@])" Q.pp_print n pp_t x
| LRA_const q -> Q.pp_print out q
| LRA_other x -> pp_t out x
let hash ~sub_hash = function
| LRA_pred (p, a, b) ->
Hash.combine4 81 (Hash.poly p) (sub_hash a) (sub_hash b)
| LRA_op (p, a, b) ->
Hash.combine4 82 (Hash.poly p) (sub_hash a) (sub_hash b)
| LRA_mult (n, x) -> Hash.combine3 83 (hash_q n) (sub_hash x)
| LRA_const q -> Hash.combine2 84 (hash_q q)
| LRA_other x -> sub_hash x
let equal ~sub_eq l1 l2 =
match l1, l2 with
| LRA_pred (p1, a1, b1), LRA_pred (p2, a2, b2) ->
p1 = p2 && sub_eq a1 a2 && sub_eq b1 b2
| LRA_op (p1, a1, b1), LRA_op (p2, a2, b2) ->
p1 = p2 && sub_eq a1 a2 && sub_eq b1 b2
| LRA_const a1, LRA_const a2 -> Q.equal a1 a2
| LRA_mult (n1, x1), LRA_mult (n2, x2) -> Q.equal n1 n2 && sub_eq x1 x2
| LRA_other x1, LRA_other x2 -> sub_eq x1 x2
| (LRA_pred _ | LRA_op _ | LRA_const _ | LRA_mult _ | LRA_other _), _ ->
false
end
type term = Term.t
type ty = Term.t
type Const.view += Const of Q.t | Pred of Pred.t | Op of Op.t | Mult_by of Q.t
let ops : Const.ops =
(module struct
let pp out = function
| Const q -> Q.pp_print out q
| Pred p -> Pred.pp out p
| Op o -> Op.pp out o
| Mult_by q -> Fmt.fprintf out "(* %a)" Q.pp_print q
| _ -> assert false
let equal a b =
match a, b with
| Const a, Const b -> Q.equal a b
| Pred a, Pred b -> Pred.equal a b
| Op a, Op b -> Op.equal a b
| Mult_by a, Mult_by b -> Q.equal a b
| _ -> false
let hash = function
| Const q -> Sidekick_zarith.Rational.hash q
| Pred p -> Pred.hash p
| Op o -> Op.hash o
| Mult_by q -> Hash.(combine2 135 (Sidekick_zarith.Rational.hash q))
| _ -> assert false
end)
let real tst = Ty.real tst
let has_ty_real t = Ty.is_real (T.ty t)
let const tst q : term =
Term.const tst (Const.make (Const q) ops ~ty:(real tst))
let mult_by tst q t : term =
let ty_c = Term.arrow tst (real tst) (real tst) in
let c = Term.const tst (Const.make (Mult_by q) ops ~ty:ty_c) in
Term.app tst c t
let pred tst p t1 t2 : term =
match p with
| Pred.Eq -> T.eq tst t1 t2
| Pred.Neq -> T.not tst (T.eq tst t1 t2)
| _ ->
let ty = Term.(arrow_l tst [ real tst; real tst ] (Term.bool tst)) in
let p = Term.const tst (Const.make (Pred p) ops ~ty) in
Term.app_l tst p [ t1; t2 ]
let leq tst a b = pred tst Pred.Leq a b
let lt tst a b = pred tst Pred.Lt a b
let geq tst a b = pred tst Pred.Geq a b
let gt tst a b = pred tst Pred.Gt a b
let eq tst a b = pred tst Pred.Eq a b
let neq tst a b = pred tst Pred.Neq a b
let op tst op t1 t2 : term =
let ty = Term.(arrow_l tst [ real tst; real tst ] (real tst)) in
let p = Term.const tst (Const.make (Op op) ops ~ty) in
Term.app_l tst p [ t1; t2 ]
let plus tst a b = op tst Op.Plus a b
let minus tst a b = op tst Op.Minus a b
let view (t : term) : _ View.t =
let f, args = Term.unfold_app t in
match T.view f, args with
| T.E_const { Const.c_view = T.C_eq; _ }, [ _; a; b ] when has_ty_real a ->
View.LRA_pred (Pred.Eq, a, b)
| T.E_const { Const.c_view = T.C_not; _ }, [ u ] ->
(* might be not-eq *)
let f, args = Term.unfold_app u in
(match T.view f, args with
| T.E_const { Const.c_view = T.C_eq; _ }, [ _; a; b ] when has_ty_real a ->
View.LRA_pred (Pred.Neq, a, b)
| _ -> View.LRA_other t)
| T.E_const { Const.c_view = Const q; _ }, [] -> View.LRA_const q
| T.E_const { Const.c_view = Pred p; _ }, [ a; b ] -> View.LRA_pred (p, a, b)
| T.E_const { Const.c_view = Op op; _ }, [ a; b ] -> View.LRA_op (op, a, b)
| T.E_const { Const.c_view = Mult_by q; _ }, [ a ] -> View.LRA_mult (q, a)
| _ -> View.LRA_other t
let term_of_view store = function
| View.LRA_const q -> const store q
| View.LRA_mult (n, t) -> mult_by store n t
| View.LRA_pred (p, a, b) -> pred store p a b
| View.LRA_op (o, a, b) -> op store o a b
| View.LRA_other x -> x

57
src/base/LRA_term.mli Normal file
View file

@ -0,0 +1,57 @@
open Sidekick_core
module Pred : sig
type t = Sidekick_th_lra.Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
include Sidekick_sigs.EQ_HASH_PRINT with type t := t
end
module Op : sig
type t = Sidekick_th_lra.op = Plus | Minus
include Sidekick_sigs.EQ_HASH_PRINT with type t := t
end
module View : sig
type ('num, 'a) lra_view = ('num, 'a) Sidekick_th_lra.lra_view =
| LRA_pred of Pred.t * 'a * 'a
| LRA_op of Op.t * 'a * 'a
| LRA_mult of 'num * 'a
| LRA_const of 'num
| LRA_other of 'a
type 'a t = (Q.t, 'a) Sidekick_th_lra.lra_view
val map : f_c:(Q.t -> Q.t) -> ('a -> 'b) -> 'a t -> 'b t
val iter : ('a -> unit) -> 'a t -> unit
val pp : pp_t:'a Fmt.printer -> 'a t Fmt.printer
val hash : sub_hash:('a -> int) -> 'a t -> int
val equal : sub_eq:('a -> 'b -> bool) -> 'a t -> 'b t -> bool
end
type term = Term.t
type ty = Term.t
val term_of_view : Term.store -> term View.t -> term
val real : Term.store -> ty
val has_ty_real : term -> bool
val pred : Term.store -> Pred.t -> term -> term -> term
val mult_by : Term.store -> Q.t -> term -> term
val op : Term.store -> Op.t -> term -> term -> term
val const : Term.store -> Q.t -> term
(** {2 Helpers} *)
val leq : Term.store -> term -> term -> term
val lt : Term.store -> term -> term -> term
val geq : Term.store -> term -> term -> term
val gt : Term.store -> term -> term -> term
val eq : Term.store -> term -> term -> term
val neq : Term.store -> term -> term -> term
val plus : Term.store -> term -> term -> term
val minus : Term.store -> term -> term -> term
(** {2 View} *)
val view : term -> term View.t
(** View as LRA *)

View file

@ -1 +0,0 @@
include Sidekick_lit.Make (Solver_arg)

View file

@ -1 +0,0 @@
include Sidekick_core.LIT with module T = Solver_arg

View file

@ -1,246 +0,0 @@
(* This file is free software. See file "license" for more details. *)
open! Base_types
module Val_map = struct
module M = CCMap.Make (CCInt)
module Key = struct
type t = Value.t list
let equal = CCList.equal Value.equal
let hash = Hash.list Value.hash
end
type key = Key.t
type 'a t = (key * 'a) list M.t
let empty = M.empty
let is_empty m = M.cardinal m = 0
let cardinal = M.cardinal
let find k m =
try Some (CCList.assoc ~eq:Key.equal k @@ M.find (Key.hash k) m)
with Not_found -> None
let add k v m =
let h = Key.hash k in
let l = M.get_or ~default:[] h m in
let l = CCList.Assoc.set ~eq:Key.equal k v l in
M.add h l m
let to_iter m yield = M.iter (fun _ l -> List.iter yield l) m
end
module Fun_interpretation = struct
type t = { cases: Value.t Val_map.t; default: Value.t }
let default fi = fi.default
let cases_list fi = Val_map.to_iter fi.cases |> Iter.to_rev_list
let make ~default l : t =
let m =
List.fold_left (fun m (k, v) -> Val_map.add k v m) Val_map.empty l
in
{ cases = m; default }
end
type t = { values: Value.t Term.Map.t; funs: Fun_interpretation.t Fun.Map.t }
let empty : t = { values = Term.Map.empty; funs = Fun.Map.empty }
(* FIXME: ues this to allocate a default value for each sort
(* get or make a default value for this type *)
let rec get_ty_default (ty:Ty.t) : Value.t =
match Ty.view ty with
| Ty_prop -> Value.true_
| Ty_atomic { def = Ty_uninterpreted _;_} ->
(* domain element *)
Ty_tbl.get_or_add ty_tbl ~k:ty
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
| Ty_atomic { def = Ty_def d; args; _} ->
(* ask the theory for a default value *)
Ty_tbl.get_or_add ty_tbl ~k:ty
~f:(fun _ty ->
let vals = List.map get_ty_default args in
d.default_val vals)
in
*)
let[@inline] mem t m = Term.Map.mem t m.values
let[@inline] find t m = Term.Map.get t m.values
let add t v m : t =
match Term.Map.find t m.values with
| v' ->
if not @@ Value.equal v v' then
Error.errorf
"@[Model: incompatible values for term %a@ :previous %a@ :new %a@]"
Term.pp t Value.pp v Value.pp v';
m
| exception Not_found -> { m with values = Term.Map.add t v m.values }
let add_fun c v m : t =
match Fun.Map.find c m.funs with
| _ ->
Error.errorf "@[Model: function %a already has an interpretation@]" Fun.pp c
| exception Not_found -> { m with funs = Fun.Map.add c v m.funs }
(* merge two models *)
let merge m1 m2 : t =
let values =
Term.Map.merge_safe m1.values m2.values ~f:(fun t o ->
match o with
| `Left v | `Right v -> Some v
| `Both (v1, v2) ->
if Value.equal v1 v2 then
Some v1
else
Error.errorf
"@[Model: incompatible values for term %a@ :previous %a@ :new \
%a@]"
Term.pp t Value.pp v1 Value.pp v2)
and funs =
Fun.Map.merge_safe m1.funs m2.funs ~f:(fun c o ->
match o with
| `Left v | `Right v -> Some v
| `Both _ ->
Error.errorf "cannot merge the two interpretations of function %a"
Fun.pp c)
in
{ values; funs }
let add_funs fs m : t = merge { values = Term.Map.empty; funs = fs } m
let pp out { values; funs } =
let module FI = Fun_interpretation in
let pp_tv out (t, v) =
Fmt.fprintf out "(@[%a@ := %a@])" Term.pp t Value.pp v
in
let pp_fun_entry out (vals, ret) =
Format.fprintf out "(@[%a@ := %a@])" (Fmt.Dump.list Value.pp) vals Value.pp
ret
in
let pp_fun out ((c, fi) : Fun.t * FI.t) =
Format.fprintf out "(@[<hov>%a :default %a@ %a@])" Fun.pp c Value.pp
fi.FI.default
(Fmt.list ~sep:(Fmt.return "@ ") pp_fun_entry)
(FI.cases_list fi)
in
Fmt.fprintf out "(@[model@ @[:terms (@[<hv>%a@])@]@ @[:funs (@[<hv>%a@])@]@])"
(Fmt.iter ~sep:Fmt.(return "@ ") pp_tv)
(Term.Map.to_iter values)
(Fmt.iter ~sep:Fmt.(return "@ ") pp_fun)
(Fun.Map.to_iter funs)
exception No_value
let eval (m : t) (t : Term.t) : Value.t option =
let module FI = Fun_interpretation in
let rec aux t =
match Term.view t with
| Bool b -> Value.bool b
| Not a ->
(match aux a with
| V_bool b -> V_bool (not b)
| v ->
Error.errorf "@[Model: wrong value@ for boolean %a@ :val %a@]" Term.pp a
Value.pp v)
| Ite (a, b, c) ->
(match aux a with
| V_bool true -> aux b
| V_bool false -> aux c
| v ->
Error.errorf "@[Model: wrong value@ for boolean %a@ :val %a@]" Term.pp a
Value.pp v)
| Eq (a, b) ->
let a = aux a in
let b = aux b in
if Value.equal a b then
Value.true_
else
Value.false_
| LRA _l ->
assert false
(* TODO: evaluation
begin match l with
| LRA_pred (p, a, b) ->
| LRA_op (_, _, _)|LRA_const _|LRA_other _ -> assert false
end
*)
| LIA _l -> assert false (* TODO *)
| App_fun (c, args) ->
(match Fun.view c, (args : _ array :> _ array) with
| Fun_def udef, _ ->
(* use builtin interpretation function *)
let args = CCArray.map aux args in
udef.eval args
| Fun_cstor c, _ -> Value.cstor_app c (Util.array_to_list_map aux args)
| Fun_select s, [| u |] ->
(match aux u with
| V_cstor { c; args } when Cstor.equal c s.select_cstor ->
List.nth args s.select_i
| v_u ->
Error.errorf "cannot eval selector %a@ on %a" Term.pp t Value.pp v_u)
| Fun_is_a c1, [| u |] ->
(match aux u with
| V_cstor { c = c2; args = _ } -> Value.bool (Cstor.equal c1 c2)
| v_u ->
Error.errorf "cannot eval is-a %a@ on %a" Term.pp t Value.pp v_u)
| Fun_select _, _ -> Error.errorf "bad selector term %a" Term.pp t
| Fun_is_a _, _ -> Error.errorf "bad is-a term %a" Term.pp t
| Fun_undef _, _ ->
(try Term.Map.find t m.values
with Not_found ->
(match Fun.Map.find c m.funs with
| fi ->
let args = CCArray.map aux args |> CCArray.to_list in
(match Val_map.find args fi.FI.cases with
| None -> fi.FI.default
| Some v -> v)
| exception Not_found ->
raise No_value (* no particular interpretation *))))
in
try Some (aux t) with No_value -> None
(* TODO: get model from each theory, then complete it as follows based on types
let mk_model (cc:t) (m:A.Model.t) : A.Model.t =
let module Model = A.Model in
let module Value = A.Value in
Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc);
let t_tbl = N_tbl.create 32 in
(* populate [repr -> value] table *)
T_tbl.values cc.tbl
(fun r ->
if N.is_root r then (
(* find a value in the class, if any *)
let v =
N.iter_class r
|> Iter.find_map (fun n -> Model.eval m n.n_term)
in
let v = match v with
| Some v -> v
| None ->
if same_class r (true_ cc) then Value.true_
else if same_class r (false_ cc) then Value.false_
else Value.fresh r.n_term
in
N_tbl.add t_tbl r v
));
(* now map every term to its representative's value *)
let pairs =
T_tbl.values cc.tbl
|> Iter.map
(fun n ->
let r = find_ n in
let v =
try N_tbl.find t_tbl r
with Not_found ->
Error.errorf "didn't allocate a value for repr %a" N.pp r
in
n.n_term, v)
in
let m = Iter.fold (fun m (t,v) -> Model.add t v m) m pairs in
Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m);
m
*)

View file

@ -1,56 +0,0 @@
(* This file is free software. See file "license" for more details. *)
(** Models
A model is a solution to the satisfiability question, created by the
SMT solver when it proves the formula to be {b satisfiable}.
A model gives a value to each term of the original formula(s), in
such a way that the formula(s) is true when the term is replaced by its
value.
*)
open Base_types
module Val_map : sig
type key = Value.t list
type 'a t
val empty : 'a t
val is_empty : _ t -> bool
val cardinal : _ t -> int
val find : key -> 'a t -> 'a option
val add : key -> 'a -> 'a t -> 'a t
end
(** Model for function symbols.
Function models are a finite map from argument tuples to values,
accompanied with a default value that every other argument tuples
map to. In other words, it's of the form:
[lambda x y. if (x=vx1,y=vy1) then v1 else if then else vdefault]
*)
module Fun_interpretation : sig
type t = { cases: Value.t Val_map.t; default: Value.t }
val default : t -> Value.t
val cases_list : t -> (Value.t list * Value.t) list
val make : default:Value.t -> (Value.t list * Value.t) list -> t
end
type t = { values: Value.t Term.Map.t; funs: Fun_interpretation.t Fun.Map.t }
(** Model *)
val empty : t
(** Empty model *)
val add : Term.t -> Value.t -> t -> t
val mem : Term.t -> t -> bool
val find : Term.t -> t -> Value.t option
val merge : t -> t -> t
val pp : t CCFormat.printer
val eval : t -> Term.t -> Value.t option
(** [eval m t] tries to evaluate term [t] in the model.
If it succeeds, the value is returned, otherwise [None] is. *)

View file

@ -1,41 +0,0 @@
open Base_types
type lit = Lit.t
type term = Term.t
type t = unit
type proof_step = unit
type proof_rule = t -> proof_step
module Step_vec = Vec_unit
let create () : t = ()
let with_proof _ _ = ()
let enabled (_pr : t) = false
let del_clause _ _ (_pr : t) = ()
let emit_redundant_clause _ ~hyps:_ _ = ()
let emit_input_clause _ _ = ()
let define_term _ _ _ = ()
let emit_unsat _ _ = ()
let proof_p1 _ _ (_pr : t) = ()
let proof_r1 _ _ (_pr : t) = ()
let proof_res ~pivot:_ _ _ (_pr : t) = ()
let emit_unsat_core _ (_pr : t) = ()
let lemma_preprocess _ _ ~using:_ (_pr : t) = ()
let lemma_true _ _ = ()
let lemma_cc _ _ = ()
let lemma_rw_clause _ ~res:_ ~using:_ (_pr : t) = ()
let with_defs _ _ (_pr : t) = ()
let lemma_lra _ _ = ()
let lemma_bool_tauto _ _ = ()
let lemma_bool_c _ _ _ = ()
let lemma_bool_equiv _ _ _ = ()
let lemma_ite_true ~ite:_ _ = ()
let lemma_ite_false ~ite:_ _ = ()
let lemma_isa_cstor ~cstor_t:_ _ (_pr : t) = ()
let lemma_select_cstor ~cstor_t:_ _ (_pr : t) = ()
let lemma_isa_split _ _ (_pr : t) = ()
let lemma_isa_sel _ (_pr : t) = ()
let lemma_isa_disj _ _ (_pr : t) = ()
let lemma_cstor_inj _ _ _ (_pr : t) = ()
let lemma_cstor_distinct _ _ (_pr : t) = ()
let lemma_acyclicity _ (_pr : t) = ()

View file

@ -1,29 +0,0 @@
(** Dummy proof module that does nothing. *)
open Base_types
include
Sidekick_core.PROOF
with type t = private unit
and type proof_step = private unit
and type lit = Lit.t
and type term = Term.t
type proof_rule = t -> proof_step
val create : unit -> t
val lemma_lra : Lit.t Iter.t -> proof_rule
include
Sidekick_th_data.PROOF
with type proof := t
and type proof_step := proof_step
and type lit := Lit.t
and type term := Term.t
include
Sidekick_th_bool_static.PROOF
with type proof := t
and type proof_step := proof_step
and type lit := Lit.t
and type term := Term.t

View file

@ -8,7 +8,7 @@ type t = P.t
module type CONV_ARG = sig
val proof : Proof.t
val unsat : Proof.proof_step
val unsat : Proof.step_id
end
module Make_lazy_tbl (T : sig
@ -318,7 +318,7 @@ end = struct
P.composite_a steps
end
let of_proof (self : Proof.t) ~(unsat : Proof.proof_step) : P.t =
let of_proof (self : Proof.t) ~(unsat : Proof.step_id) : P.t =
let module C = Conv (struct
let proof = self
let unsat = unsat

View file

@ -4,7 +4,7 @@
type t
val of_proof : Proof.t -> unsat:Proof.proof_step -> t
val of_proof : Proof.t -> unsat:Proof.step_id -> t
type out_format = Sidekick_quip.out_format = Sexp | CSexp

View file

@ -30,27 +30,13 @@ end
(* a step is just a unique integer ID.
The actual step is stored in the chunk_stack. *)
type proof_step = Proof_ser.ID.t
type step_id = Proof_ser.ID.t
type term_id = Proof_ser.ID.t
type lit = Lit.t
type term = Term.t
type t = {
mutable enabled: bool;
buf: Buffer.t;
out: Proof_ser.Bare.Encode.t;
mutable storage: Storage.t;
dispose: unit -> unit;
mutable steps_writer: CS.Writer.t;
mutable next_id: int;
map_term: term_id Term.Tbl.t; (* term -> proof ID *)
map_fun: term_id Fun.Tbl.t;
}
type proof_rule = t -> proof_step
module Step_vec = struct
type elt = proof_step
type elt = step_id
type t = elt Vec.t
let get = Vec.get
@ -71,6 +57,18 @@ module Step_vec = struct
let to_iter = Vec.to_iter
end
type t = {
mutable enabled: bool;
buf: Buffer.t;
out: Proof_ser.Bare.Encode.t;
mutable storage: Storage.t;
dispose: unit -> unit;
mutable steps_writer: CS.Writer.t;
mutable next_id: int;
map_term: term_id Term.Tbl.t; (* term -> proof ID *)
map_fun: term_id Fun.Tbl.t;
}
let disable (self : t) : unit =
self.enabled <- false;
self.storage <- Storage.No_store;
@ -114,7 +112,7 @@ let create ?(config = Config.default) () : t =
let empty = create ~config:Config.empty ()
let iter_steps_backward (self : t) = Storage.iter_steps_backward self.storage
let dummy_step : proof_step = Int32.min_int
let dummy_step : step_id = Int32.min_int
let[@inline] enabled (self : t) = self.enabled
(* allocate a unique ID to refer to an event in the trace *)
@ -178,119 +176,178 @@ let emit_lit_ (self : t) (lit : Lit.t) : term_id =
else
Int32.neg t
let emit_ (self : t) f : proof_step =
if enabled self then (
let view = f () in
let id = alloc_id self in
emit_step_ self { PS.Step.id; view };
id
) else
dummy_step
let emit_no_return_ (self : t) f : unit =
if enabled self then (
let view = f () in
emit_step_ self { PS.Step.id = -1l; view }
)
let[@inline] emit_redundant_clause lits ~hyps (self : t) =
emit_ self @@ fun () ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
let clause = Proof_ser.{ Clause.lits } in
let hyps = Iter.to_array hyps in
PS.Step_view.Step_rup { res = clause; hyps }
let emit_unsat c (self : t) : unit =
emit_no_return_ self @@ fun () -> PS.(Step_view.Step_unsat { Step_unsat.c })
let emit_input_clause (lits : Lit.t Iter.t) (self : t) =
emit_ self @@ fun () ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
PS.(Step_view.Step_input { Step_input.c = { Clause.lits } })
(** What a rule can return. It can return an existing step, or ask to create
a new one. *)
type rule_res = R_new of PS.Step_view.t | R_old of step_id
let define_term t u (self : t) =
emit_ self @@ fun () ->
let t = emit_term_ self t and u = emit_term_ self u in
PS.(Step_view.Expr_def { Expr_def.c = t; rhs = u })
type rule = t -> rule_res
let proof_p1 rw_with c (self : t) =
emit_ self @@ fun () ->
PS.(Step_view.Step_proof_p1 { Step_proof_p1.c; rw_with })
let proof_r1 unit c (self : t) =
emit_ self @@ fun () -> PS.(Step_view.Step_proof_r1 { Step_proof_r1.c; unit })
let proof_res ~pivot c1 c2 (self : t) =
emit_ self @@ fun () ->
let pivot = emit_term_ self pivot in
PS.(Step_view.Step_proof_res { Step_proof_res.c1; c2; pivot })
let lemma_preprocess t u ~using (self : t) =
emit_ self @@ fun () ->
let t = emit_term_ self t and u = emit_term_ self u in
let using = using |> Iter.to_array in
PS.(Step_view.Step_preprocess { Step_preprocess.t; u; using })
let lemma_true t (self : t) =
emit_ self @@ fun () ->
let t = emit_term_ self t in
PS.(Step_view.Step_true { Step_true.true_ = t })
let lemma_cc lits (self : t) =
emit_ self @@ fun () ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
PS.(Step_view.Step_cc { Step_cc.eqns = lits })
let lemma_rw_clause c ~res ~using (self : t) =
let emit_rule_ (self : t) (f : rule) : step_id =
if enabled self then (
let using = Iter.to_array using in
if Array.length using = 0 then
c
(* useless step *)
else
emit_ self @@ fun () ->
let lits = Iter.map (emit_lit_ self) res |> Iter.to_array in
let res = Proof_ser.{ Clause.lits } in
PS.(Step_view.Step_clause_rw { Step_clause_rw.c; res; using })
match f self with
| R_old id -> id
| R_new view ->
let id = alloc_id self in
emit_step_ self { PS.Step.id; view };
id
) else
dummy_step
(* TODO *)
let with_defs _ _ (_pr : t) = dummy_step
module Proof_trace = struct
module A = struct
type nonrec step_id = step_id
type nonrec rule = rule
module Step_vec = Step_vec
end
type nonrec t = t
let enabled = enabled
let add_step = emit_rule_
let[@inline] add_unsat self id = emit_unsat id self
let delete _ _ = ()
end
let r_new v = R_new v
let r_old id = R_old id
module Rule_sat = struct
type nonrec lit = lit
type nonrec step_id = step_id
type nonrec rule = rule
let sat_redundant_clause lits ~hyps : rule =
fun self ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
let clause = Proof_ser.{ Clause.lits } in
let hyps = Iter.to_array hyps in
r_new @@ PS.Step_view.Step_rup { res = clause; hyps }
let sat_input_clause (lits : Lit.t Iter.t) : rule =
fun self ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
r_new @@ PS.(Step_view.Step_input { Step_input.c = { Clause.lits } })
(* TODO *)
let sat_unsat_core _ (_pr : t) = r_old dummy_step
end
module Rule_core = struct
type nonrec term = term
type nonrec step_id = step_id
type nonrec rule = rule
type nonrec lit = lit
let sat_redundant_clause lits ~hyps : rule =
fun self ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
let clause = Proof_ser.{ Clause.lits } in
let hyps = Iter.to_array hyps in
r_new @@ PS.Step_view.Step_rup { res = clause; hyps }
let define_term t u : rule =
fun self ->
let t = emit_term_ self t and u = emit_term_ self u in
r_new @@ PS.(Step_view.Expr_def { Expr_def.c = t; rhs = u })
let proof_p1 rw_with c : rule =
fun _self ->
r_new @@ PS.(Step_view.Step_proof_p1 { Step_proof_p1.c; rw_with })
let proof_r1 unit c : rule =
fun _self -> r_new @@ PS.(Step_view.Step_proof_r1 { Step_proof_r1.c; unit })
let proof_res ~pivot c1 c2 : rule =
fun self ->
let pivot = emit_term_ self pivot in
r_new @@ PS.(Step_view.Step_proof_res { Step_proof_res.c1; c2; pivot })
let lemma_preprocess t u ~using : rule =
fun self ->
let t = emit_term_ self t and u = emit_term_ self u in
let using = using |> Iter.to_array in
r_new @@ PS.(Step_view.Step_preprocess { Step_preprocess.t; u; using })
let lemma_true t : rule =
fun self ->
let t = emit_term_ self t in
r_new @@ PS.(Step_view.Step_true { Step_true.true_ = t })
let lemma_cc lits : rule =
fun self ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
r_new @@ PS.(Step_view.Step_cc { Step_cc.eqns = lits })
let lemma_rw_clause c ~res ~using : rule =
fun self ->
let using = Iter.to_array using in
if Array.length using = 0 then
r_old c
(* useless step *)
else (
let lits = Iter.map (emit_lit_ self) res |> Iter.to_array in
let res = Proof_ser.{ Clause.lits } in
r_new @@ PS.(Step_view.Step_clause_rw { Step_clause_rw.c; res; using })
)
(* TODO *)
let with_defs _ _ (_pr : t) = r_old dummy_step
end
(* not useful *)
let del_clause _ _ (_pr : t) = ()
(* TODO *)
let emit_unsat_core _ (_pr : t) = dummy_step
module Rule_bool = struct
type nonrec term = term
type nonrec lit = lit
type nonrec rule = rule
let emit_unsat c (self : t) : unit =
emit_no_return_ self @@ fun () -> PS.(Step_view.Step_unsat { Step_unsat.c })
let lemma_bool_tauto lits : rule =
fun self ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
r_new @@ PS.(Step_view.Step_bool_tauto { Step_bool_tauto.lits })
let lemma_bool_tauto lits (self : t) =
emit_ self @@ fun () ->
let lits = Iter.map (emit_lit_ self) lits |> Iter.to_array in
PS.(Step_view.Step_bool_tauto { Step_bool_tauto.lits })
let lemma_bool_c rule (ts : Term.t list) : rule =
fun self ->
let exprs = Util.array_of_list_map (emit_term_ self) ts in
r_new @@ PS.(Step_view.Step_bool_c { Step_bool_c.exprs; rule })
let lemma_bool_c rule (ts : Term.t list) (self : t) =
emit_ self @@ fun () ->
let exprs = ts |> Util.array_of_list_map (emit_term_ self) in
PS.(Step_view.Step_bool_c { Step_bool_c.exprs; rule })
let lemma_bool_equiv _ _ _ = r_old dummy_step
let lemma_ite_true ~ite:_ _ = r_old dummy_step
let lemma_ite_false ~ite:_ _ = r_old dummy_step
end
(* TODO *)
let lemma_lra _ _ = dummy_step
let lemma_relax_to_lra _ _ = dummy_step
let lemma_lia _ _ = dummy_step
let lemma_bool_equiv _ _ _ = dummy_step
let lemma_ite_true ~ite:_ _ = dummy_step
let lemma_ite_false ~ite:_ _ = dummy_step
let lemma_isa_cstor ~cstor_t:_ _ (_pr : t) = dummy_step
let lemma_select_cstor ~cstor_t:_ _ (_pr : t) = dummy_step
let lemma_isa_split _ _ (_pr : t) = dummy_step
let lemma_isa_sel _ (_pr : t) = dummy_step
let lemma_isa_disj _ _ (_pr : t) = dummy_step
let lemma_cstor_inj _ _ _ (_pr : t) = dummy_step
let lemma_cstor_distinct _ _ (_pr : t) = dummy_step
let lemma_acyclicity _ (_pr : t) = dummy_step
let lemma_lra _ _ = r_old dummy_step
let lemma_relax_to_lra _ _ = r_old dummy_step
let lemma_lia _ _ = r_old dummy_step
module Rule_data = struct
type nonrec lit = lit
type nonrec rule = rule
type nonrec term = term
let lemma_isa_cstor ~cstor_t:_ _ (_pr : t) = r_old dummy_step
let lemma_select_cstor ~cstor_t:_ _ (_pr : t) = r_old dummy_step
let lemma_isa_split _ _ (_pr : t) = r_old dummy_step
let lemma_isa_sel _ (_pr : t) = r_old dummy_step
let lemma_isa_disj _ _ (_pr : t) = r_old dummy_step
let lemma_cstor_inj _ _ _ (_pr : t) = r_old dummy_step
let lemma_cstor_distinct _ _ (_pr : t) = r_old dummy_step
let lemma_acyclicity _ (_pr : t) = r_old dummy_step
end
module Unsafe_ = struct
let[@inline] id_of_proof_step_ (p : proof_step) : proof_step = p
let[@inline] id_of_proof_step_ (p : step_id) : step_id = p
end

View file

@ -28,39 +28,42 @@ end
(** {2 Main Proof API} *)
type t
module Proof_trace : Sidekick_core.PROOF_TRACE
type t = Proof_trace.t
(** A container for the whole proof *)
type proof_step
(** A proof step in the trace.
type step_id = Proof_trace.A.step_id
type rule = Proof_trace.A.rule
The proof will store all steps, and at the end when we find the empty clause
we can filter them to keep only the relevant ones. *)
module Rule_sat :
Sidekick_core.SAT_PROOF_RULES
with type rule = rule
and type lit = Lit.t
and type step_id = step_id
include
Sidekick_core.PROOF
with type t := t
and type proof_step := proof_step
module Rule_core :
Sidekick_core.PROOF_CORE
with type rule = rule
and type lit = Lit.t
and type term = Term.t
and type step_id = step_id
val lemma_lra : Lit.t Iter.t -> rule
val lemma_relax_to_lra : Lit.t Iter.t -> rule
val lemma_lia : Lit.t Iter.t -> rule
module Rule_data :
Sidekick_th_data.PROOF_RULES
with type rule = rule
and type lit = Lit.t
and type term = Term.t
val lemma_lra : Lit.t Iter.t -> proof_rule
val lemma_relax_to_lra : Lit.t Iter.t -> proof_rule
val lemma_lia : Lit.t Iter.t -> proof_rule
include
Sidekick_th_data.PROOF
with type proof := t
and type proof_step := proof_step
and type lit := Lit.t
and type term := Term.t
include
Sidekick_th_bool_static.PROOF
with type proof := t
and type proof_step := proof_step
and type lit := Lit.t
and type term := Term.t
module Rule_bool :
Sidekick_th_bool_static.PROOF_RULES
with type rule = rule
and type lit = Lit.t
and type term = Term.t
(** {2 Creation} *)
@ -83,5 +86,5 @@ val iter_steps_backward : t -> Proof_ser.Step.t Iter.t
a dummy backend. *)
module Unsafe_ : sig
val id_of_proof_step_ : proof_step -> Proof_ser.ID.t
val id_of_proof_step_ : step_id -> Proof_ser.ID.t
end

View file

@ -1,4 +1,4 @@
(** {1 Sidekick base}
(** Sidekick base
This library is a starting point for writing concrete implementations
of SMT solvers with Sidekick.
@ -6,7 +6,7 @@
It provides a representation of terms, boolean formulas,
linear arithmetic expressions, datatypes for the functors in Sidekick.
In addition, it has a notion of {{!Base_types.Statement} Statement}.
In addition, it has a notion of {{!Statement.t} Statement}.
Statements are instructions
for the SMT solver to do something, such as: define a new constant,
declare a new constant, assert a formula as being true,
@ -14,32 +14,31 @@
etc. Logic formats such as SMT-LIB 2.6 are in fact based on a similar
notion of statements, and a [.smt2] files contains a list of statements.
*)
*)
module Base_types = Base_types
module Types_ = Types_
module Term = Term
module Const = Sidekick_core.Const
module Ty = Ty
module ID = ID
module Fun = Base_types.Fun
module Stat = Stat
module Model = Model
module Term = Base_types.Term
module Value = Base_types.Value
module Term_cell = Base_types.Term_cell
module Ty = Base_types.Ty
module Statement = Base_types.Statement
module Data = Base_types.Data
module Select = Base_types.Select
module Form = Form
module LRA_view = Base_types.LRA_view
module LRA_pred = Base_types.LRA_pred
module LRA_op = Base_types.LRA_op
module LIA_view = Base_types.LIA_view
module LIA_pred = Base_types.LIA_pred
module LIA_op = Base_types.LIA_op
module Solver_arg = Solver_arg
module Lit = Lit
module Proof_dummy = Proof_dummy
module Proof = Proof
module Proof_quip = Proof_quip
module Data_ty = Data_ty
module Cstor = Data_ty.Cstor
module Select = Data_ty.Select
module Statement = Statement
module Solver = Solver
module Uconst = Uconst
module Config = Config
module LRA_term = LRA_term
module Th_data = Th_data
module Th_bool = Th_bool
module Th_lra = Th_lra
module Th_uf = Th_uf
(* re-export *)
module IArray = IArray
let k_th_bool_config = Th_bool.k_config
let th_bool = Th_bool.theory
let th_bool_dyn : Solver.theory = Th_bool.theory_dyn
let th_bool_static : Solver.theory = Th_bool.theory_static
let th_data : Solver.theory = Th_data.theory
let th_lra : Solver.theory = Th_lra.theory
let th_uf : Solver.theory = Th_uf.theory

10
src/base/Solver.ml Normal file
View file

@ -0,0 +1,10 @@
include Sidekick_smt_solver.Solver
let default_arg =
(module struct
let view_as_cc = Term.view_as_cc
let is_valid_literal _ = true
end : Sidekick_smt_solver.Sigs.ARG)
let create_default ?stat ?size ~proof ~theories tst : t =
create default_arg ?stat ?size ~proof ~theories tst ()

View file

@ -1,4 +0,0 @@
open! Base_types
module Term = Term
module Fun = Fun
module Ty = Ty

View file

@ -1,15 +0,0 @@
(** Concrete implementation of {!Sidekick_core.TERM}
this module gathers most definitions above in a form
that is compatible with what Sidekick expects for terms, functions, etc.
*)
open Base_types
include
Sidekick_core.TERM
with type Term.t = Term.t
and type Fun.t = Fun.t
and type Ty.t = Ty.t
and type Term.store = Term.store
and type Ty.store = Ty.store

47
src/base/Statement.ml Normal file
View file

@ -0,0 +1,47 @@
open Types_
type t = statement =
| Stmt_set_logic of string
| Stmt_set_option of string list
| Stmt_set_info of string * string
| Stmt_data of data list
| Stmt_ty_decl of ID.t * int (* new atomic cstor *)
| Stmt_decl of ID.t * ty list * ty
| Stmt_define of definition list
| Stmt_assert of term
| Stmt_assert_clause of term list
| Stmt_check_sat of (bool * term) list
| Stmt_get_model
| Stmt_get_value of term list
| Stmt_exit
(** Pretty print a statement *)
let pp out = function
| Stmt_set_logic s -> Fmt.fprintf out "(set-logic %s)" s
| Stmt_set_option l ->
Fmt.fprintf out "(@[set-logic@ %a@])" (Util.pp_list Fmt.string) l
| Stmt_set_info (a, b) -> Fmt.fprintf out "(@[set-info@ %s@ %s@])" a b
| Stmt_check_sat [] -> Fmt.string out "(check-sat)"
| Stmt_check_sat l ->
let pp_pair out (b, t) =
if b then
Term.pp_debug out t
else
Fmt.fprintf out "(@[not %a@])" Term.pp_debug t
in
Fmt.fprintf out "(@[check-sat-assuming@ (@[%a@])@])" (Fmt.list pp_pair) l
| Stmt_ty_decl (s, n) -> Fmt.fprintf out "(@[declare-sort@ %a %d@])" ID.pp s n
| Stmt_decl (id, args, ret) ->
Fmt.fprintf out "(@[<1>declare-fun@ %a (@[%a@])@ %a@])" ID.pp id
(Util.pp_list Ty.pp) args Ty.pp ret
| Stmt_assert t -> Fmt.fprintf out "(@[assert@ %a@])" Term.pp_debug t
| Stmt_assert_clause c ->
Fmt.fprintf out "(@[assert-clause@ %a@])" (Util.pp_list Term.pp_debug) c
| Stmt_exit -> Fmt.string out "(exit)"
| Stmt_data l ->
Fmt.fprintf out "(@[declare-datatypes@ %a@])" (Util.pp_list Data_ty.pp) l
| Stmt_get_model -> Fmt.string out "(get-model)"
| Stmt_get_value l ->
Fmt.fprintf out "(@[get-value@ (@[%a@])@])" (Util.pp_list Term.pp_debug) l
| Stmt_define _ -> assert false
(* TODO *)

24
src/base/Statement.mli Normal file
View file

@ -0,0 +1,24 @@
(** Statements.
A statement is an instruction for the SMT solver to do something,
like asserting that a formula is true, declaring a new constant,
or checking satisfiabilty of the current set of assertions. *)
open Types_
type t = statement =
| Stmt_set_logic of string
| Stmt_set_option of string list
| Stmt_set_info of string * string
| Stmt_data of data list
| Stmt_ty_decl of ID.t * int (* new atomic cstor *)
| Stmt_decl of ID.t * ty list * ty
| Stmt_define of definition list
| Stmt_assert of term
| Stmt_assert_clause of term list
| Stmt_check_sat of (bool * term) list
| Stmt_get_model
| Stmt_get_value of term list
| Stmt_exit
include Sidekick_sigs.PRINT with type t := t

3
src/base/Term.ml Normal file
View file

@ -0,0 +1,3 @@
include Sidekick_core.Term
let view_as_cc = Sidekick_core.Default_cc_view.view_as_cc

67
src/base/Ty.ml Normal file
View file

@ -0,0 +1,67 @@
(** Core types *)
include Sidekick_core.Term
open Types_
let pp = pp_debug
type Const.view += Ty of ty_view
type data = Types_.data
let ops_ty : Const.ops =
(module struct
let pp out = function
| Ty ty ->
(match ty with
| Ty_real -> Fmt.string out "Real"
| Ty_int -> Fmt.string out "Int"
| Ty_uninterpreted { id; _ } -> ID.pp out id)
| _ -> ()
let equal a b =
match a, b with
| Ty a, Ty b ->
(match a, b with
| Ty_int, Ty_int | Ty_real, Ty_real -> true
| Ty_uninterpreted u1, Ty_uninterpreted u2 -> ID.equal u1.id u2.id
| (Ty_real | Ty_int | Ty_uninterpreted _), _ -> false)
| _ -> false
let hash = function
| Ty a ->
(match a with
| Ty_real -> Hash.int 2
| Ty_int -> Hash.int 3
| Ty_uninterpreted u -> Hash.combine2 10 (ID.hash u.id))
| _ -> assert false
end)
open struct
let mk_ty0 tst view =
let ty = Term.type_ tst in
Term.const tst @@ Const.make (Ty view) ops_ty ~ty
end
(* TODO: handle polymorphic constants *)
let int tst : ty = mk_ty0 tst Ty_int
let real tst : ty = mk_ty0 tst Ty_real
let is_real t =
match Term.view t with
| E_const { Const.c_view = Ty Ty_real; _ } -> true
| _ -> false
let is_int t =
match Term.view t with
| E_const { Const.c_view = Ty Ty_int; _ } -> true
| _ -> false
let uninterpreted tst id : t =
mk_ty0 tst (Ty_uninterpreted { id; finite = false })
let uninterpreted_str tst s : t = uninterpreted tst (ID.make s)
let is_uninterpreted (self : t) =
match view self with
| E_const { Const.c_view = Ty (Ty_uninterpreted _); _ } -> true
| _ -> false

28
src/base/Ty.mli Normal file
View file

@ -0,0 +1,28 @@
open Types_
include module type of struct
include Term
end
type t = ty
type data = Types_.data
include Sidekick_sigs.EQ_ORD_HASH_PRINT with type t := t
val bool : store -> t
val real : store -> t
val int : store -> t
val uninterpreted : store -> ID.t -> t
val uninterpreted_str : store -> string -> t
val is_uninterpreted : t -> bool
val is_real : t -> bool
val is_int : t -> bool
(* TODO: separate functor?
val finite : t -> bool
val set_finite : t -> bool -> unit
val args : t -> ty list
val ret : t -> ty
val arity : t -> int
val unfold : t -> ty list * ty
*)

54
src/base/Uconst.ml Normal file
View file

@ -0,0 +1,54 @@
open Types_
type ty = Term.t
type t = Types_.uconst = { uc_id: ID.t; uc_ty: ty }
let[@inline] id self = self.uc_id
let[@inline] ty self = self.uc_ty
let equal a b = ID.equal a.uc_id b.uc_id
let compare a b = ID.compare a.uc_id b.uc_id
let hash a = ID.hash a.uc_id
let pp out c = ID.pp out c.uc_id
type Const.view += Uconst of t
let ops =
(module struct
let pp out = function
| Uconst c -> pp out c
| _ -> assert false
let equal a b =
match a, b with
| Uconst a, Uconst b -> equal a b
| _ -> false
let hash = function
| Uconst c -> Hash.combine2 522660 (hash c)
| _ -> assert false
end : Const.DYN_OPS)
let[@inline] make uc_id uc_ty : t = { uc_id; uc_ty }
let uconst tst (self : t) : Term.t =
Term.const tst @@ Const.make (Uconst self) ops ~ty:self.uc_ty
let uconst_of_id tst id ty = uconst tst @@ make id ty
let uconst_of_id' tst id args ret =
let ty = Term.arrow_l tst args ret in
uconst_of_id tst id ty
let uconst_of_str tst name args ret : term =
uconst_of_id' tst (ID.make name) args ret
module As_key = struct
type nonrec t = t
let compare = compare
let equal = equal
let hash = hash
end
module Map = CCMap.Make (As_key)
module Tbl = CCHashtbl.Make (As_key)

24
src/base/Uconst.mli Normal file
View file

@ -0,0 +1,24 @@
(** Uninterpreted constants *)
open Types_
type ty = Term.t
type t = Types_.uconst = { uc_id: ID.t; uc_ty: ty }
val id : t -> ID.t
val ty : t -> ty
type Const.view += private Uconst of t
include Sidekick_sigs.EQ_ORD_HASH_PRINT with type t := t
val make : ID.t -> ty -> t
(** Make a new uninterpreted function. *)
val uconst : Term.store -> t -> Term.t
val uconst_of_id : Term.store -> ID.t -> ty -> Term.t
val uconst_of_id' : Term.store -> ID.t -> ty list -> ty -> Term.t
val uconst_of_str : Term.store -> string -> ty list -> ty -> Term.t
module Map : CCMap.S with type key = t
module Tbl : CCHashtbl.S with type key = t

View file

@ -2,7 +2,7 @@
(name sidekick_base)
(public_name sidekick-base)
(synopsis "Base term definitions for the standalone SMT solver and library")
(libraries containers iter sidekick.core sidekick.util sidekick.lit
sidekick-base.proof-trace sidekick.quip sidekick.arith-lra
sidekick.th-bool-static sidekick.th-data sidekick.zarith zarith)
(flags :standard -w -32 -open Sidekick_util))
(libraries containers iter sidekick.core sidekick.util sidekick.smt-solver
sidekick.cc sidekick.quip sidekick.th-lra sidekick.th-bool-static
sidekick.th-bool-dyn sidekick.th-data sidekick.zarith zarith)
(flags :standard -w +32 -open Sidekick_util))

25
src/base/th_bool.ml Normal file
View file

@ -0,0 +1,25 @@
(** Reducing boolean formulas to clauses *)
let k_config : [ `Dyn | `Static ] Config.Key.t = Config.Key.create ()
let theory_static : Solver.theory =
Sidekick_th_bool_static.theory
(module struct
let view_as_bool = Form.view
let mk_bool = Form.mk_of_view
end : Sidekick_th_bool_static.ARG)
let theory_dyn : Solver.theory =
Sidekick_th_bool_dyn.theory
(module struct
let view_as_bool = Form.view
let mk_bool = Form.mk_of_view
end : Sidekick_th_bool_static.ARG)
let theory (conf : Config.t) : Solver.theory =
match Config.find k_config conf with
| Some `Dyn -> theory_dyn
| Some `Static -> theory_static
| None ->
(* default *)
theory_static

79
src/base/th_data.ml Normal file
View file

@ -0,0 +1,79 @@
(** Theory of datatypes *)
open Sidekick_core
let arg =
(module struct
module S = Solver
open! Sidekick_th_data
open Data_ty
module Cstor = Cstor
(* TODO: we probably want to make sure cstors are not polymorphic?!
maybe work on a type/cstor that's applied to pre-selected variables,
like [Map A B] with [A],[B] used for the whole type *)
let unfold_pi t =
let rec unfold acc t =
match Term.view t with
| Term.E_pi (_, ty, bod) -> unfold (ty :: acc) bod
| _ -> List.rev acc, t
in
unfold [] t
let as_datatype ty : _ data_ty_view =
let args, ret = unfold_pi ty in
if args <> [] then
Ty_arrow (args, ret)
else (
match Data_ty.as_data ty, Term.view ty with
| Some d, _ ->
let cstors = Lazy.force d.data_cstors in
let cstors = ID.Map.fold (fun _ c l -> c :: l) cstors [] in
Ty_data { cstors }
| None, E_app (a, b) -> Ty_other { sub = [ a; b ] }
| None, E_pi (_, a, b) -> Ty_other { sub = [ a; b ] }
| ( None,
( E_const _ | E_var _ | E_type _ | E_bound_var _ | E_lam _
| E_app_fold _ ) ) ->
Ty_other { sub = [] }
)
let view_as_data t : _ data_view =
let h, args = Term.unfold_app t in
match
Data_ty.as_cstor h, Data_ty.as_select h, Data_ty.as_is_a h, args
with
| Some c, _, _, _ ->
(* TODO: check arity? store it in [c] ? *)
T_cstor (c, args)
| None, Some sel, _, [ arg ] ->
T_select (sel.select_cstor, sel.select_i, arg)
| None, None, Some c, [ arg ] -> T_is_a (c, arg)
| _ -> T_other t
let mk_eq = Term.eq
let mk_cstor tst c args : Term.t = Term.app_l tst (Data_ty.cstor tst c) args
let mk_sel tst c i u =
Term.app_l tst (Data_ty.select tst @@ Data_ty.Cstor.select_idx c i) [ u ]
let mk_is_a tst c u : Term.t =
if c.cstor_arity = 0 then
Term.eq tst u (Data_ty.cstor tst c)
else
Term.app_l tst (Data_ty.is_a tst c) [ u ]
(* NOTE: maybe finiteness should be part of the core typeclass for
type consts? or we have a registry for infinite types? *)
let rec ty_is_finite ty =
match Term.view ty with
| E_const { Const.c_view = Uconst.Uconst _; _ } -> true
| E_const { Const.c_view = Data_ty.Data _d; _ } -> true (* TODO: ?? *)
| E_pi (_, a, b) -> ty_is_finite a && ty_is_finite b
| _ -> true
let ty_set_is_finite _ _ = () (* TODO: remove, use a weak table instead *)
end : Sidekick_th_data.ARG)
let theory = Sidekick_th_data.make arg

21
src/base/th_lra.ml Normal file
View file

@ -0,0 +1,21 @@
(** Theory of Linear Rational Arithmetic *)
open Sidekick_core
module T = Term
module Q = Sidekick_zarith.Rational
open LRA_term
let mk_eq = Form.eq
let mk_bool = T.bool
let theory : Solver.theory =
Sidekick_th_lra.theory
(module struct
module Z = Sidekick_zarith.Int
module Q = Sidekick_zarith.Rational
let ty_real = LRA_term.real
let has_ty_real = LRA_term.has_ty_real
let view_as_lra = LRA_term.view
let mk_lra = LRA_term.term_of_view
end : Sidekick_th_lra.ARG)

24
src/base/th_uf.ml Normal file
View file

@ -0,0 +1,24 @@
(** Theory of uninterpreted functions *)
open Sidekick_core
open Sidekick_smt_solver
open struct
module SI = Solver_internal
let on_is_subterm ~th_id (solver : SI.t) (_, _, t) : _ list =
let f, args = Term.unfold_app t in
(match Term.view f, args with
| Term.E_const { Const.c_view = Uconst.Uconst _; _ }, _ :: _ ->
SI.claim_term solver ~th_id t
| _ -> ());
[]
end
let theory : Theory.t =
Theory.make ~name:"uf"
~create_and_setup:(fun ~id:th_id solver ->
SI.on_cc_is_subterm solver (on_is_subterm ~th_id solver);
())
()

77
src/base/types_.ml Normal file
View file

@ -0,0 +1,77 @@
include Sidekick_core
(* FIXME
module Proof_ser = Sidekick_base_proof_trace.Proof_ser
module Storage = Sidekick_base_proof_trace.Storage
*)
type term = Term.t
type ty = Term.t
type value = Term.t
type uconst = { uc_id: ID.t; uc_ty: ty }
(** Uninterpreted constant. *)
type ty_view =
| Ty_int
| Ty_real
| Ty_uninterpreted of { id: ID.t; mutable finite: bool }
(* TODO: remove (lives in Data_ty now)
| Ty_data of { data: data }
*)
and data = {
data_id: ID.t;
data_cstors: cstor ID.Map.t lazy_t;
data_as_ty: ty lazy_t;
}
and cstor = {
cstor_id: ID.t;
cstor_is_a: ID.t;
mutable cstor_arity: int;
cstor_args: select list lazy_t;
cstor_ty_as_data: data;
cstor_ty: ty lazy_t;
}
and select = {
select_id: ID.t;
select_cstor: cstor;
select_ty: ty lazy_t;
select_i: int;
}
(* FIXME: just use terms; introduce a Const.view for V_element
(** Semantic values, used for models (and possibly model-constructing calculi) *)
type value_view =
| V_element of { id: ID.t; ty: ty }
(** a named constant, distinct from any other constant *)
| V_cstor of { c: cstor; args: value list }
| V_custom of {
view: value_custom_view;
pp: value_custom_view Fmt.printer;
eq: value_custom_view -> value_custom_view -> bool;
hash: value_custom_view -> int;
} (** Custom value *)
| V_real of Q.t
and value_custom_view = ..
*)
type definition = ID.t * ty * term
type statement =
| Stmt_set_logic of string
| Stmt_set_option of string list
| Stmt_set_info of string * string
| Stmt_data of data list
| Stmt_ty_decl of ID.t * int (* new atomic cstor *)
| Stmt_decl of ID.t * ty list * ty
| Stmt_define of definition list
| Stmt_assert of term
| Stmt_assert_clause of term list
| Stmt_check_sat of (bool * term) list
| Stmt_get_model
| Stmt_get_value of term list
| Stmt_exit

961
src/cc/CC.ml Normal file
View file

@ -0,0 +1,961 @@
open Types_
type view_as_cc = Term.t -> (Const.t, Term.t, Term.t list) CC_view.t
type e_node = E_node.t
(** A node of the congruence closure *)
type repr = E_node.t
(** Node that is currently a representative. *)
type explanation = Expl.t
type bitfield = Bits.field
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n : e_node) : repr =
let n2 = n.n_root in
assert (E_node.is_root n2);
n2
let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool =
E_node.equal (find_ n1) (find_ n2)
let[@inline] find _ n = find_ n
module Sig_tbl = CCHashtbl.Make (Signature)
module T_tbl = Term.Tbl
type propagation_reason = unit -> Lit.t list * Proof_term.step_id
module Handler_action = struct
type t =
| Act_merge of E_node.t * E_node.t * Expl.t
| Act_propagate of Lit.t * propagation_reason
type conflict = Conflict of Expl.t [@@unboxed]
type or_conflict = (t list, conflict) result
end
module Result_action = struct
type t = Act_propagate of { lit: Lit.t; reason: propagation_reason }
type conflict = Conflict of Lit.t list * Proof_term.step_id
type or_conflict = (t list, conflict) result
end
type combine_task =
| CT_merge of e_node * e_node * explanation
| CT_act of Handler_action.t
type t = {
view_as_cc: view_as_cc;
tst: Term.store;
stat: Stat.t;
proof: Proof_trace.t;
tbl: e_node T_tbl.t; (* internalization [term -> e_node] *)
signatures_tbl: e_node Sig_tbl.t;
(* map a signature to the corresponding e_node in some equivalence class.
A signature is a [term_cell] in which every immediate subterm
that participates in the congruence/evaluation relation
is normalized (i.e. is its own representative).
The critical property is that all members of an equivalence class
that have the same "shape" (including head symbol)
have the same signature *)
pending: e_node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
bitgen: Bits.bitfield_gen;
field_marked_explain: Bits.field;
(* used to mark traversed nodes when looking for a common ancestor *)
true_: e_node lazy_t;
false_: e_node lazy_t;
mutable in_loop: bool; (* currently being modified? *)
res_acts: Result_action.t Vec.t; (* to return *)
on_pre_merge:
( t * E_node.t * E_node.t * Expl.t,
Handler_action.or_conflict )
Event.Emitter.t;
on_pre_merge2:
( t * E_node.t * E_node.t * Expl.t,
Handler_action.or_conflict )
Event.Emitter.t;
on_post_merge:
(t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t;
on_new_term: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t;
on_conflict: (ev_on_conflict, unit) Event.Emitter.t;
on_propagate:
(t * Lit.t * propagation_reason, Handler_action.t list) Event.Emitter.t;
on_is_subterm: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t;
count_conflict: int Stat.counter;
count_props: int Stat.counter;
count_merge: int Stat.counter;
}
(* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according
to the current explanation. That allows not to prove some equality
several times.
See "fast congruence closure and extensions", Nieuwenhuis&al, page 14 *)
and ev_on_conflict = { cc: t; th: bool; c: Lit.t list }
let[@inline] size_ (r : repr) = r.n_size
let[@inline] n_true self = Lazy.force self.true_
let[@inline] n_false self = Lazy.force self.false_
let n_bool self b =
if b then
n_true self
else
n_false self
let[@inline] term_store self = self.tst
let[@inline] proof self = self.proof
let[@inline] stat self = self.stat
let allocate_bitfield self ~descr : bitfield =
Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr);
Bits.mk_field self.bitgen
let[@inline] on_backtrack self f : unit =
Backtrack_stack.push_if_nonzero_level self.undo f
let[@inline] set_bitfield_ f b t = t.n_bits <- Bits.set f b t.n_bits
let[@inline] get_bitfield_ field n = Bits.get field n.n_bits
let[@inline] get_bitfield _cc field n = get_bitfield_ field n
let set_bitfield self field b n =
let old = get_bitfield self field n in
if old <> b then (
on_backtrack self (fun () -> set_bitfield_ field old n);
set_bitfield_ field b n
)
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (self : t) (t : Term.t) : bool = T_tbl.mem self.tbl t
module Debug_ = struct
(* print full state *)
let pp out (self : t) : unit =
let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in
let pp_root out n =
if E_node.is_root n then
Fmt.string out " :is-root"
else
Fmt.fprintf out "@ :root %a" E_node.pp n.n_root
in
let pp_expl out n =
match n.n_expl with
| FL_none -> ()
| FL_some e ->
Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp
e.expl
in
let pp_n out n =
Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp_debug n.n_term pp_root n pp_next
n pp_expl n
and pp_sig_e out (s, n) =
Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n pp_root
n
in
Fmt.fprintf out
"(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig-tbl@ %a@])@])"
(Util.pp_iter ~sep:" " pp_n)
(T_tbl.values self.tbl)
(Util.pp_iter ~sep:" " pp_sig_e)
(Sig_tbl.to_iter self.signatures_tbl)
end
(* compute up-to-date signature *)
let update_sig (s : signature) : Signature.t =
CC_view.map_view s ~f_f:(fun x -> x) ~f_t:find_ ~f_ts:(List.map find_)
(* find whether the given (parent) term corresponds to some signature
in [signatures_] *)
let[@inline] find_signature cc (s : signature) : repr option =
Sig_tbl.get cc.signatures_tbl s
(* add to signature table. Assume it's not present already *)
let add_signature self (s : signature) (n : e_node) : unit =
assert (not @@ Sig_tbl.mem self.signatures_tbl s);
Log.debugf 50 (fun k ->
k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n);
on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s);
Sig_tbl.add self.signatures_tbl s n
let push_pending self (n : E_node.t) : unit =
if Option.is_some n.n_sig0 then (
Log.debugf 50 (fun k -> k "(@[<hv1>cc.push-pending@ %a@])" E_node.pp n);
Vec.push self.pending n
)
let[@inline] push_action self (a : Handler_action.t) : unit =
Vec.push self.combine (CT_act a)
let[@inline] push_action_l self (l : _ list) : unit =
List.iter (push_action self) l
let merge_classes self t u e : unit =
if t != u && not (same_class t u) then (
Log.debugf 50 (fun k ->
k "(@[<hv1>cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t E_node.pp
u Expl.pp e);
Vec.push self.combine @@ CT_merge (t, u, e)
)
(* re-root the explanation tree of the equivalence class of [n]
so that it points to [n].
postcondition: [n.n_expl = None] *)
let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit =
match n.n_expl with
| FL_none -> () (* already root *)
| FL_some { next = u; expl = e_n_u } ->
(* reroot to [u], then invert link between [u] and [n] *)
reroot_expl self u;
u.n_expl <- FL_some { next = n; expl = e_n_u };
n.n_expl <- FL_none
exception E_confl of Result_action.conflict
let raise_conflict_ (cc : t) ~th (e : Lit.t list) (p : Proof_term.step_id) : _ =
Profile.instant "cc.conflict";
(* clear tasks queue *)
Vec.clear cc.pending;
Vec.clear cc.combine;
Event.emit cc.on_conflict { cc; th; c = e };
Stat.incr cc.count_conflict;
Vec.clear cc.res_acts;
raise (E_confl (Conflict (e, p)))
let[@inline] all_classes self : repr Iter.t =
T_tbl.values self.tbl |> Iter.filter E_node.is_root
(* find the closest common ancestor of [a] and [b] in the proof forest.
Precond:
- [a] and [b] are in the same class
- no e_node has the flag [field_marked_explain] on
Invariants:
- if [n] is marked, then all the predecessors of [n]
from [a] or [b] are marked too.
*)
let find_common_ancestor self (a : e_node) (b : e_node) : e_node =
(* catch up to the other e_node *)
let rec find1 a =
if get_bitfield_ self.field_marked_explain a then
a
else (
match a.n_expl with
| FL_none -> assert false
| FL_some r -> find1 r.next
)
in
let rec find2 a b =
if E_node.equal a b then
a
else if get_bitfield_ self.field_marked_explain a then
a
else if get_bitfield_ self.field_marked_explain b then
b
else (
set_bitfield_ self.field_marked_explain true a;
set_bitfield_ self.field_marked_explain true b;
match a.n_expl, b.n_expl with
| FL_some r1, FL_some r2 -> find2 r1.next r2.next
| FL_some r, FL_none -> find1 r.next
| FL_none, FL_some r -> find1 r.next
| FL_none, FL_none ->
(* no common ancestor *)
assert false
)
in
(* cleanup tags on nodes traversed in [find2] *)
let rec cleanup_ n =
if get_bitfield_ self.field_marked_explain n then (
set_bitfield_ self.field_marked_explain false n;
match n.n_expl with
| FL_none -> ()
| FL_some { next; _ } -> cleanup_ next
)
in
let n = find2 a b in
cleanup_ a;
cleanup_ b;
n
module Expl_state = struct
type t = {
mutable lits: Lit.t list;
mutable th_lemmas:
(Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list;
}
let create () : t = { lits = []; th_lemmas = [] }
let[@inline] copy self : t = { self with lits = self.lits }
let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits
let[@inline] add_th (self : t) lit hyps pr : unit =
self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas
let merge self other =
let { lits = o_lits; th_lemmas = o_lemmas } = other in
self.lits <- List.rev_append o_lits self.lits;
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas;
()
(* proof of [\/_i ¬lits[i]] *)
let proof_of_th_lemmas (self : t) (proof : Proof_trace.t) : Proof_term.step_id
=
let p_lits1 = List.rev_map Lit.neg self.lits in
let p_lits2 =
self.th_lemmas |> List.rev_map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u)
in
let p_cc =
Proof_trace.add_step proof @@ fun () ->
Proof_core.lemma_cc (List.rev_append p_lits1 p_lits2)
in
let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) =
(* pr_th: [sub_proofs |- t=u].
now resolve away [sub_proofs] to get literals that were
asserted in the congruence closure *)
let pr_th =
List.fold_left
(fun pr_th (lit_i, hyps_i) ->
(* [hyps_i |- lit_i] *)
let lemma_i =
Proof_trace.add_step proof @@ fun () ->
Proof_core.lemma_cc (lit_i :: List.rev_map Lit.neg hyps_i)
in
(* resolve [lit_i] away. *)
Proof_trace.add_step proof @@ fun () ->
Proof_core.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th)
pr_th sub_proofs
in
Proof_trace.add_step proof @@ fun () ->
Proof_core.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr
in
(* resolve with theory proofs responsible for some merges, if any. *)
List.fold_left resolve_with_th_proof p_cc self.th_lemmas
let to_resolved_expl (self : t) : Resolved_expl.t =
(* FIXME: package the th lemmas too *)
let { lits; th_lemmas = _ } = self in
let s2 = copy self in
let pr proof = proof_of_th_lemmas s2 proof in
{ Resolved_expl.lits; pr }
end
(* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose_expl self (st : Expl_state.t) (e : explanation) : unit
=
Log.debugf 5 (fun k -> k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with
| E_trivial -> ()
| E_congruence (n1, n2) ->
(match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Const.equal f1 f2);
assert (List.length a1 = List.length a2);
List.iter2 (explain_equal_rec_ self st) a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
explain_equal_rec_ self st f1 f2;
explain_equal_rec_ self st a1 a2
| Some (If (a1, b1, c1)), Some (If (a2, b2, c2)) ->
explain_equal_rec_ self st a1 a2;
explain_equal_rec_ self st b1 b2;
explain_equal_rec_ self st c1 c2
| _ -> assert false)
| E_lit lit -> Expl_state.add_lit st lit
| E_theory (t, u, expl_sets, pr) ->
let sub_proofs =
List.map
(fun (t_i, u_i, expls_i) ->
let lit_i = Lit.make_eq self.tst t_i u_i in
(* use a separate call to [explain_expls] for each set *)
let sub = explain_expls self expls_i in
Expl_state.merge st sub;
lit_i, sub.lits)
expl_sets
in
let lit_t_u = Lit.make_eq self.tst t u in
Expl_state.add_th st lit_t_u sub_proofs pr
| E_merge (a, b) -> explain_equal_rec_ self st a b
| E_merge_t (a, b) ->
(* find nodes for [a] and [b] on the fly *)
(match T_tbl.find self.tbl a, T_tbl.find self.tbl b with
| a, b -> explain_equal_rec_ self st a b
| exception Not_found ->
Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp_debug a
Term.pp_debug b)
| E_and (a, b) ->
explain_decompose_expl self st a;
explain_decompose_expl self st b
and explain_expls self (es : explanation list) : Expl_state.t =
let st = Expl_state.create () in
List.iter (explain_decompose_expl self st) es;
st
and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) :
unit =
if a != b then (
Log.debugf 5 (fun k ->
k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b);
assert (E_node.equal (find_ a) (find_ b));
let ancestor = find_common_ancestor cc a b in
explain_along_path cc st a ancestor;
explain_along_path cc st b ancestor
)
(* explain why [a = target], where [a -> ... -> target] in the
proof forest *)
and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) :
unit =
let rec aux n =
if n != target then (
match n.n_expl with
| FL_none -> assert false
| FL_some { next = next_a; expl } ->
(* prove [a = next_n] *)
explain_decompose_expl self st expl;
(* now prove [next_a = target] *)
aux next_a
)
in
aux a
(* add a term *)
let[@inline] rec add_term_rec_ self t : e_node =
match T_tbl.find self.tbl t with
| n -> n
| exception Not_found -> add_new_term_ self t
(* add [t] when not present already *)
and add_new_term_ self (t : Term.t) : e_node =
assert (not @@ mem self t);
Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp_debug t);
let n = E_node.Internal_.make t in
(* register sub-terms, add [t] to their parent list, and return the
corresponding initial signature *)
let sig0 = compute_sig0 self n in
n.n_sig0 <- sig0;
(* remove term when we backtrack *)
on_backtrack self (fun () ->
Log.debugf 30 (fun k -> k "(@[cc.remove-term@ %a@])" Term.pp_debug t);
T_tbl.remove self.tbl t);
(* add term to the table *)
T_tbl.add self.tbl t n;
if Option.is_some sig0 then
(* [n] might be merged with other equiv classes *)
push_pending self n;
Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self);
n
(* compute the initial signature of the given e_node [n] *)
and compute_sig0 (self : t) (n : e_node) : Signature.t option =
(* add sub-term to [cc], and register [n] to its parents.
Note that we return the exact sub-term, to get proper
explanations, but we add to the sub-term's root's parent list. *)
let deref_sub (u : Term.t) : e_node =
let sub = add_term_rec_ self u in
(* add [n] to [sub.root]'s parent list *)
(let sub_r = find_ sub in
let old_parents = sub_r.n_parents in
if Bag.is_empty old_parents then
(* first time it has parents: tell watchers that this is a subterm *)
Event.emit_iter self.on_is_subterm (self, sub, u) ~f:(push_action_l self);
on_backtrack self (fun () -> sub_r.n_parents <- old_parents);
sub_r.n_parents <- Bag.cons n sub_r.n_parents);
sub
in
let[@inline] return x = Some x in
match self.view_as_cc n.n_term with
| Bool _ | Opaque _ -> None
| Eq (a, b) ->
let a = deref_sub a in
let b = deref_sub b in
return @@ CC_view.Eq (a, b)
| Not u -> return @@ CC_view.Not (deref_sub u)
| App_fun (f, args) ->
let args = List.map deref_sub args in
if args <> [] then
return @@ CC_view.App_fun (f, args)
else
None
| App_ho (f, a) ->
let f = deref_sub f in
let a = deref_sub a in
return @@ CC_view.App_ho (f, a)
| If (a, b, c) -> return @@ CC_view.If (deref_sub a, deref_sub b, deref_sub c)
let[@inline] add_term self t : e_node = add_term_rec_ self t
let mem_term = mem
let set_as_lit self (n : e_node) (lit : Lit.t) : unit =
match n.n_as_lit with
| Some _ -> ()
| None ->
Log.debugf 15 (fun k ->
k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit);
on_backtrack self (fun () -> n.n_as_lit <- None);
n.n_as_lit <- Some lit
(* is [n] true or false? *)
let n_is_bool_value (self : t) n : bool =
E_node.equal n (n_true self) || E_node.equal n (n_false self)
(* gather a pair [lits, pr], where [lits] is the set of
asserted literals needed in the explanation (which is useful for
the SAT solver), and [pr] is a proof, including sub-proofs for theory
merges. *)
let lits_and_proof_of_expl (self : t) (st : Expl_state.t) :
Lit.t list * Proof_term.step_id =
let { Expl_state.lits; th_lemmas = _ } = st in
let pr = Expl_state.proof_of_th_lemmas st self.proof in
lits, pr
(* main CC algo: add terms from [pending] to the signature table,
check for collisions *)
let rec update_tasks (self : t) : unit =
while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do
while not @@ Vec.is_empty self.pending do
task_pending_ self (Vec.pop_exn self.pending)
done;
while not @@ Vec.is_empty self.combine do
task_combine_ self (Vec.pop_exn self.combine)
done
done
and task_pending_ self (n : e_node) : unit =
(* check if some parent collided *)
match n.n_sig0 with
| None -> () (* no-op *)
| Some (Eq (a, b)) ->
(* if [a=b] is now true, merge [(a=b)] and [true] *)
if a != b && same_class a b then (
let expl = Expl.mk_merge a b in
Log.debugf 5 (fun k ->
k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a
E_node.pp b);
merge_classes self n (n_true self) expl
)
| Some (Not u) ->
(* [u = bool ==> not u = not bool] *)
let r_u = find_ u in
if E_node.equal r_u (n_true self) then (
let expl = Expl.mk_merge u (n_true self) in
merge_classes self n (n_false self) expl
) else if E_node.equal r_u (n_false self) then (
let expl = Expl.mk_merge u (n_false self) in
merge_classes self n (n_true self) expl
)
| Some s0 ->
(* update the signature by using [find] on each sub-e_node *)
let s = update_sig s0 in
(match find_signature self s with
| None ->
(* add to the signature table [sig(n) --> n] *)
add_signature self s n
| Some u when E_node.equal n u -> ()
| Some u ->
(* [t1] and [t2] must be applications of the same symbol to
arguments that are pairwise equal *)
assert (n != u);
let expl = Expl.mk_congruence n u in
merge_classes self n u expl)
and task_combine_ self = function
| CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab
| CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e
| CT_act (Handler_action.Act_propagate (lit, reason)) ->
(* will return this propagation to the caller *)
Vec.push self.res_acts (Result_action.Act_propagate { lit; reason })
(* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *)
and task_merge_ self a b e_ab : unit =
let ra = find_ a in
let rb = find_ b in
if not @@ E_node.equal ra rb then (
assert (E_node.is_root ra);
assert (E_node.is_root rb);
Stat.incr self.count_merge;
(* check we're not merging [true] and [false] *)
if
(E_node.equal ra (n_true self) && E_node.equal rb (n_false self))
|| (E_node.equal rb (n_true self) && E_node.equal ra (n_false self))
then (
Log.debugf 5 (fun k ->
k
"(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \
%a@ :t2 %a@]@ :e_ab %a@])"
E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab);
let th = ref false in
(* TODO:
C1: Proof_trace.true_neq_false
C2: lemma [lits |- true=false] (and resolve on theory proofs)
C3: r1 C1 C2
*)
let expl_st = Expl_state.create () in
explain_decompose_expl self expl_st e_ab;
explain_equal_rec_ self expl_st a ra;
explain_equal_rec_ self expl_st b rb;
(* regular conflict *)
let lits, pr = lits_and_proof_of_expl self expl_st in
raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr
);
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
keep values as representative *)
let r_from, r_into =
if n_is_bool_value self ra then
rb, ra
else if n_is_bool_value self rb then
ra, rb
else if size_ ra > size_ rb then
rb, ra
else
ra, rb
in
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 =
if E_node.equal r1 (n_true self) then
propagate_bools self r2 t2 r1 t1 e_ab true
else if E_node.equal r1 (n_false self) then
propagate_bools self r2 t2 r1 t1 e_ab false
in
merge_bool ra a rb b;
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k ->
k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp r_into);
(* call [on_pre_merge] functions, and merge theory data items *)
(* explanation is [a=ra & e_ab & b=rb] *)
(let expl = Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] in
let handle_act = function
| Ok l -> push_action_l self l
| Error (Handler_action.Conflict expl) ->
raise_conflict_from_expl self expl
in
Event.emit_iter self.on_pre_merge
(self, r_into, r_from, expl)
~f:handle_act;
Event.emit_iter self.on_pre_merge2
(self, r_into, r_from, expl)
~f:handle_act);
(* TODO: merge plugin data here, _after_ the pre-merge hooks are called,
so they have a chance of observing pre-merge plugin data *)
((* parents might have a different signature, check for collisions *)
E_node.iter_parents r_from (fun parent -> push_pending self parent);
(* for each e_node in [r_from]'s class, make it point to [r_into] *)
E_node.iter_class r_from (fun u ->
assert (u.n_root == r_from);
u.n_root <- r_into);
(* capture current state *)
let r_into_old_parents = r_into.n_parents in
let r_into_old_bits = r_into.n_bits in
(* swap [into.next] and [from.next], merging the classes *)
E_node.swap_next r_into r_from;
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
r_into.n_size <- r_into.n_size + r_from.n_size;
r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits;
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
on_backtrack self (fun () ->
Log.debugf 30 (fun k ->
k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from
E_node.pp r_into);
r_into.n_bits <- r_into_old_bits;
(* un-merge the classes *)
E_node.swap_next r_into r_from;
r_into.n_parents <- r_into_old_parents;
(* NOTE: this must come after the restoration of [next] pointers,
otherwise we'd iterate on too big a class *)
E_node.Internal_.iter_class_ r_from (fun u -> u.n_root <- r_from);
r_into.n_size <- r_into.n_size - r_from.n_size));
(* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a]
and [b], not their roots. *)
reroot_expl self a;
assert (a.n_expl = FL_none);
on_backtrack self (fun () ->
(* on backtracking, link may be inverted, but we delete the one
that bridges between [a] and [b] *)
match a.n_expl, b.n_expl with
| FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none
| _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none
| _ -> assert false);
a.n_expl <- FL_some { next = b; expl = e_ab };
(* call [on_post_merge] *)
Event.emit_iter self.on_post_merge (self, r_into, r_from)
~f:(push_action_l self)
)
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver
and which is not the one initially merged.
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *)
let half_expl_and_pr =
lazy
(let st = Expl_state.create () in
explain_decompose_expl self st e_12;
explain_equal_rec_ self st r2 t2;
st)
in
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
contains at least one lit *)
E_node.iter_class r1 (fun u1 ->
(* propagate if:
- [u1] is a proper literal
- [t2 != r2], because that can only happen
after an explicit merge (no way to obtain that by propagation)
*)
match E_node.as_lit u1 with
| Some lit when not (E_node.equal r2 t2) ->
let lit =
if sign then
lit
else
Lit.neg lit
in
(* apply sign *)
Log.debugf 5 (fun k -> k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let (lazy st) = half_expl_and_pr in
let st = Expl_state.copy st in
(* do not modify shared st *)
explain_equal_rec_ self st u1 t1;
(* propagate only if this doesn't depend on some semantic values *)
let reason () =
(* true literals explaining why t1=t2 *)
let guard = st.lits in
(* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *)
Expl_state.add_lit st (Lit.neg lit);
let _, pr = lits_and_proof_of_expl self st in
guard, pr
in
Vec.push self.res_acts (Result_action.Act_propagate { lit; reason });
Event.emit_iter self.on_propagate (self, lit, reason)
~f:(push_action_l self);
Stat.incr self.count_props
| _ -> ())
(* raise a conflict from an explanation, typically from an event handler.
Raises E_confl with a result conflict. *)
and raise_conflict_from_expl self (expl : Expl.t) : 'a =
Log.debugf 5 (fun k ->
k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
let st = Expl_state.create () in
explain_decompose_expl self st expl;
let lits, pr = lits_and_proof_of_expl self st in
let c = List.rev_map Lit.neg lits in
let th = st.th_lemmas <> [] in
raise_conflict_ self ~th c pr
let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t)
let push_level (self : t) : unit =
assert (not self.in_loop);
Backtrack_stack.push_level self.undo
let pop_levels (self : t) n : unit =
assert (not self.in_loop);
Vec.clear self.pending;
Vec.clear self.combine;
Log.debugf 15 (fun k ->
k "(@[cc.pop-levels %d@ :n-lvls %d@])" n
(Backtrack_stack.n_levels self.undo));
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ());
()
let assert_eq self t u expl : unit =
assert (not self.in_loop);
let t = add_term self t in
let u = add_term self u in
(* merge [a] and [b] *)
merge_classes self t u expl
(* assert that this boolean literal holds.
if a lit is [= a b], merge [a] and [b];
otherwise merge the atom with true/false *)
let assert_lit self lit : unit =
assert (not self.in_loop);
let t = Lit.term lit in
Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit);
let sign = Lit.sign lit in
match self.view_as_cc t with
| Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit)
| _ ->
(* equate t and true/false *)
let rhs = n_bool self sign in
let n = add_term self t in
(* TODO: ensure that this is O(1).
basically, just have [n] point to true/false and thus acquire
the corresponding value, so its superterms (like [ite]) can evaluate
properly *)
(* TODO: use oriented merge (force direction [n -> rhs]) *)
merge_classes self n rhs (Expl.mk_lit lit)
let[@inline] assert_lits self lits : unit =
assert (not self.in_loop);
Iter.iter (assert_lit self) lits
let merge self n1 n2 expl =
assert (not self.in_loop);
Log.debugf 5 (fun k ->
k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 E_node.pp
n2 Expl.pp expl);
assert (Term.equal (Term.ty n1.n_term) (Term.ty n2.n_term));
merge_classes self n1 n2 expl
let merge_t self t1 t2 expl =
merge self (add_term self t1) (add_term self t2) expl
let explain_eq self n1 n2 : Resolved_expl.t =
let st = Expl_state.create () in
explain_equal_rec_ self st n1 n2;
(* FIXME: also need to return the proof? *)
Expl_state.to_resolved_expl st
let explain_expl (self : t) expl : Resolved_expl.t =
let expl_st = Expl_state.create () in
explain_decompose_expl self expl_st expl;
Expl_state.to_resolved_expl expl_st
let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge
let[@inline] on_pre_merge2 self = Event.of_emitter self.on_pre_merge2
let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge
let[@inline] on_new_term self = Event.of_emitter self.on_new_term
let[@inline] on_conflict self = Event.of_emitter self.on_conflict
let[@inline] on_propagate self = Event.of_emitter self.on_propagate
let[@inline] on_is_subterm self = Event.of_emitter self.on_is_subterm
let create_ ?(stat = Stat.global) ?(size = `Big) (tst : Term.store)
(proof : Proof_trace.t) ~view_as_cc : t =
let size =
match size with
| `Small -> 128
| `Big -> 2048
in
let bitgen = Bits.mk_gen () in
let field_marked_explain = Bits.mk_field bitgen in
let rec cc =
{
view_as_cc;
tst;
proof;
stat;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
bitgen;
on_pre_merge = Event.Emitter.create ();
on_pre_merge2 = Event.Emitter.create ();
on_post_merge = Event.Emitter.create ();
on_new_term = Event.Emitter.create ();
on_conflict = Event.Emitter.create ();
on_propagate = Event.Emitter.create ();
on_is_subterm = Event.Emitter.create ();
pending = Vec.create ();
combine = Vec.create ();
undo = Backtrack_stack.create ();
true_;
false_;
in_loop = false;
res_acts = Vec.create ();
field_marked_explain;
count_conflict = Stat.mk_int stat "cc.conflicts";
count_props = Stat.mk_int stat "cc.propagations";
count_merge = Stat.mk_int stat "cc.merges";
}
and true_ = lazy (add_term cc (Term.true_ tst))
and false_ = lazy (add_term cc (Term.false_ tst)) in
ignore (Lazy.force true_ : e_node);
ignore (Lazy.force false_ : e_node);
cc
let[@inline] find_t self t : repr =
let n = T_tbl.find self.tbl t in
find_ n
let pop_acts_ self =
let l = Vec.to_list self.res_acts in
Vec.clear self.res_acts;
l
let check self : Result_action.or_conflict =
Log.debug 5 "(cc.check)";
self.in_loop <- true;
let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in
try
update_tasks self;
let l = pop_acts_ self in
Ok l
with E_confl c -> Error c
let check_inv_enabled_ = true (* XXX NUDGE *)
(* check some internal invariants *)
let check_inv_ (self : t) : unit =
if check_inv_enabled_ then (
Log.debug 2 "(cc.check-invariants)";
all_classes self
|> Iter.flat_map E_node.iter_class
|> Iter.iter (fun n ->
match n.n_sig0 with
| None -> ()
| Some s ->
let s' = update_sig s in
let ok =
match find_signature self s' with
| None -> false
| Some r -> E_node.equal r n.n_root
in
if not ok then
Log.debugf 0 (fun k ->
k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])"
E_node.pp n Signature.pp s Signature.pp s'))
)
(* model: return all the classes *)
let get_model (self : t) : repr Iter.t Iter.t =
check_inv_ self;
all_classes self |> Iter.map E_node.iter_class
(** Arguments to a congruence closure's implementation *)
module type ARG = sig
val view_as_cc : view_as_cc
(** View the Term.t through the lens of the congruence closure *)
end
module type BUILD = sig
val create :
?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t
(** Create a new congruence closure.
@param term_store used to be able to create new terms. All terms
interacting with this congruence closure must belong in this term state
as well.
*)
end
module Make (A : ARG) : BUILD = struct
let create ?stat ?size tst proof : t =
create_ ?stat ?size tst proof ~view_as_cc:A.view_as_cc
end
module Default = Make (Sidekick_core.Default_cc_view)
let create (module A : ARG) ?stat ?size tst proof : t =
create_ ?stat ?size tst proof ~view_as_cc:A.view_as_cc
let create_default = Default.create

305
src/cc/CC.mli Normal file
View file

@ -0,0 +1,305 @@
(** Main congruence closure type. *)
open Sidekick_core
type e_node = E_node.t
(** A node of the congruence closure *)
type repr = E_node.t
(** Node that is currently a representative. *)
type explanation = Expl.t
type bitfield = Bits.field
(** A field in the bitfield of this node. This should only be
allocated when a theory is initialized.
Bitfields are accessed using preallocated keys.
See {!allocate_bitfield}.
All fields are initially 0, are backtracked automatically,
and are merged automatically when classes are merged. *)
(** Main congruence closure signature.
The congruence closure handles the theory QF_UF (uninterpreted
function symbols).
It is also responsible for {i theory combination}, and provides
a general framework for equality reasoning that other
theories piggyback on.
For example, the theory of datatypes relies on the congruence closure
to do most of the work, and "only" adds injectivity/disjointness/acyclicity
lemmas when needed.
Similarly, a theory of arrays would hook into the congruence closure and
assert (dis)equalities as needed.
*)
type t
(** The congruence closure object.
It contains a fair amount of state and is mutable
and backtrackable. *)
(** {3 Accessors} *)
val term_store : t -> Term.store
val proof : t -> Proof_trace.t
val stat : t -> Stat.t
val find : t -> e_node -> repr
(** Current representative *)
val add_term : t -> Term.t -> e_node
(** Add the Term.t to the congruence closure, if not present already.
Will be backtracked. *)
val mem_term : t -> Term.t -> bool
(** Returns [true] if the Term.t is explicitly present in the congruence closure *)
val allocate_bitfield : t -> descr:string -> bitfield
(** Allocate a new e_node field (see {!E_node.bitfield}).
This field descriptor is henceforth reserved for all nodes
in this congruence closure, and can be set using {!set_bitfield}
for each class_ individually.
This can be used to efficiently store some metadata on nodes
(e.g. "is there a numeric value in the class"
or "is there a constructor Term.t in the class").
There may be restrictions on how many distinct fields are allocated
for a given congruence closure (e.g. at most {!Sys.int_size} fields).
*)
val get_bitfield : t -> bitfield -> E_node.t -> bool
(** Access the bit field of the given e_node *)
val set_bitfield : t -> bitfield -> bool -> E_node.t -> unit
(** Set the bitfield for the e_node. This will be backtracked.
See {!E_node.bitfield}. *)
type propagation_reason = unit -> Lit.t list * Proof_term.step_id
(** Handler Actions
Actions that can be scheduled by event handlers. *)
module Handler_action : sig
type t =
| Act_merge of E_node.t * E_node.t * Expl.t
| Act_propagate of Lit.t * propagation_reason
(* TODO:
- an action to modify data associated with a class
*)
type conflict = Conflict of Expl.t [@@unboxed]
type or_conflict = (t list, conflict) result
(** Actions or conflict scheduled by an event handler.
- [Ok acts] is a list of merges and propagations
- [Error confl] is a conflict to resolve.
*)
end
(** Result Actions.
Actions returned by the congruence closure after calling {!check}. *)
module Result_action : sig
type t =
| Act_propagate of { lit: Lit.t; reason: propagation_reason }
(** [propagate (Lit.t, reason)] declares that [reason() => Lit.t]
is a tautology.
- [reason()] should return a list of literals that are currently true,
as well as a proof.
- [Lit.t] should be a literal of interest (see {!S.set_as_lit}).
This function might never be called, a congruence closure has the right
to not propagate and only trigger conflicts. *)
type conflict =
| Conflict of Lit.t list * Proof_term.step_id
(** [raise_conflict (c,pr)] declares that [c] is a tautology of
the theory of congruence.
@param pr the proof of [c] being a tautology *)
type or_conflict = (t list, conflict) result
end
(** {3 Events}
Events triggered by the congruence closure, to which
other plugins can subscribe. *)
(** Events emitted by the congruence closure when something changes. *)
val on_pre_merge :
t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t
(** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1]
and [n2] are merged with explanation [expl]. *)
val on_pre_merge2 :
t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t
(** Second phase of "on pre merge". This runs after {!on_pre_merge}
and is used by Plugins. {b NOTE}: Plugin state might be observed as already
changed in these handlers. *)
val on_post_merge :
t -> (t * E_node.t * E_node.t, Handler_action.t list) Event.t
(** [ev_on_post_merge acts n1 n2] is emitted right after [n1]
and [n2] were merged. [find cc n1] and [find cc n2] will return
the same E_node.t. *)
val on_new_term : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t
(** [ev_on_new_term n t] is emitted whenever a new Term.t [t]
is added to the congruence closure. Its E_node.t is [n]. *)
type ev_on_conflict = { cc: t; th: bool; c: Lit.t list }
(** Event emitted when a conflict occurs in the CC.
[th] is true if the explanation for this conflict involves
at least one "theory" explanation; i.e. some of the equations
participating in the conflict are purely syntactic theories
like injectivity of constructors. *)
val on_conflict : t -> (ev_on_conflict, unit) Event.t
(** [ev_on_conflict {th; c}] is emitted when the congruence
closure triggers a conflict by asserting the tautology [c]. *)
val on_propagate :
t ->
( t * Lit.t * (unit -> Lit.t list * Proof_term.step_id),
Handler_action.t list )
Event.t
(** [ev_on_propagate Lit.t reason] is emitted whenever [reason() => Lit.t]
is a propagated lemma. See {!CC_ACTIONS.propagate}. *)
val on_is_subterm : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t
(** [ev_on_is_subterm n t] is emitted when [n] is a subterm of
another E_node.t for the first time. [t] is the Term.t corresponding to
the E_node.t [n]. This can be useful for theory combination. *)
(** {3 Misc} *)
val n_true : t -> E_node.t
(** Node for [true] *)
val n_false : t -> E_node.t
(** Node for [false] *)
val n_bool : t -> bool -> E_node.t
(** Node for either true or false *)
val set_as_lit : t -> E_node.t -> Lit.t -> unit
(** map the given e_node to a literal. *)
val find_t : t -> Term.t -> repr
(** Current representative of the Term.t.
@raise E_node.t_found if the Term.t is not already {!add}-ed. *)
val add_iter : t -> Term.t Iter.t -> unit
(** Add a sequence of terms to the congruence closure *)
val all_classes : t -> repr Iter.t
(** All current classes. This is costly, only use if there is no other solution *)
val explain_eq : t -> E_node.t -> E_node.t -> Resolved_expl.t
(** Explain why the two nodes are equal.
Fails if they are not, in an unspecified way. *)
val explain_expl : t -> Expl.t -> Resolved_expl.t
(** Transform explanation into an actionable conflict clause *)
(* FIXME: remove
val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a
(** Raise a conflict with the given explanation.
It must be a theory tautology that [expl ==> absurd].
To be used in theories.
This fails in an unspecified way if the explanation, once resolved,
satisfies {!Resolved_expl.is_semantic}. *)
*)
val merge : t -> E_node.t -> E_node.t -> Expl.t -> unit
(** Merge these two nodes given this explanation.
It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *)
val merge_t : t -> Term.t -> Term.t -> Expl.t -> unit
(** Shortcut for adding + merging *)
(** {3 Main API *)
val assert_eq : t -> Term.t -> Term.t -> Expl.t -> unit
(** Assert that two terms are equal, using the given explanation. *)
val assert_lit : t -> Lit.t -> unit
(** Given a literal, assume it in the congruence closure and propagate
its consequences. Will be backtracked.
Useful for the theory combination or the SAT solver's functor *)
val assert_lits : t -> Lit.t Iter.t -> unit
(** Addition of many literals *)
val check : t -> Result_action.or_conflict
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the {!actions} to propagate literals, declare conflicts, etc. *)
val push_level : t -> unit
(** Push backtracking level *)
val pop_levels : t -> int -> unit
(** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *)
val get_model : t -> E_node.t Iter.t Iter.t
(** get all the equivalence classes so they can be merged in the model *)
type view_as_cc = Term.t -> (Const.t, Term.t, Term.t list) CC_view.t
(** Arguments to a congruence closure's implementation *)
module type ARG = sig
val view_as_cc : view_as_cc
(** View the Term.t through the lens of the congruence closure *)
end
module type BUILD = sig
val create :
?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t
(** Create a new congruence closure.
@param term_store used to be able to create new terms. All terms
interacting with this congruence closure must belong in this term state
as well.
*)
end
module Make (_ : ARG) : BUILD
val create :
(module ARG) ->
?stat:Stat.t ->
?size:[ `Small | `Big ] ->
Term.store ->
Proof_trace.t ->
t
(** Create a new congruence closure.
@param term_store used to be able to create new terms. All terms
interacting with this congruence closure must belong in this term state
as well.
*)
val create_default :
?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t
(** Same as {!create} but with the default CC view *)
(**/**)
module Debug_ : sig
val pp : t Fmt.printer
(** Print the whole CC *)
end
(**/**)

File diff suppressed because it is too large Load diff

View file

@ -1,13 +1,17 @@
(** {2 Congruence Closure} *)
(** Congruence Closure Implementation *)
open Sidekick_core
module type DYN_MONOID_PLUGIN = Sigs_plugin.DYN_MONOID_PLUGIN
module type MONOID_PLUGIN_ARG = Sigs_plugin.MONOID_PLUGIN_ARG
module type MONOID_PLUGIN_BUILDER = Sigs_plugin.MONOID_PLUGIN_BUILDER
module type S = Sidekick_core.CC_S
module View = Sidekick_core.CC_view
module E_node = E_node
module Expl = Expl
module Signature = Signature
module Resolved_expl = Resolved_expl
module Plugin = Plugin
module CC = CC
module Make (A : CC_ARG) :
S
with module T = A.T
and module Lit = A.Lit
and type proof = A.proof
and type proof_step = A.proof_step
and module Actions = A.Actions
include module type of struct
include CC
end

26
src/cc/bits.ml Normal file
View file

@ -0,0 +1,26 @@
type bitfield_gen = int ref
let max_width = Sys.word_size - 2
let mk_gen () = ref 0
type t = int
type field = int
let empty : t = 0
let mk_field (gen : bitfield_gen) : field =
let n = !gen in
if n > max_width then Error.errorf "maximum number of CC bitfields reached";
incr gen;
1 lsl n
let[@inline] get field x = x land field <> 0
let[@inline] set field b x =
if b then
x lor field
else
x land lnot field
let merge = ( lor )
let equal : t -> t -> bool = CCEqual.poly

13
src/cc/bits.mli Normal file
View file

@ -0,0 +1,13 @@
(** Basic bitfield *)
type t = private int
type field
type bitfield_gen
val empty : t
val equal : t -> t -> bool
val mk_field : bitfield_gen -> field
val mk_gen : unit -> bitfield_gen
val get : field -> t -> bool
val set : field -> bool -> t -> t
val merge : t -> t -> t

View file

@ -1,5 +1,7 @@
(library
(name Sidekick_cc)
(public_name sidekick.cc)
(libraries containers iter sidekick.core sidekick.util)
(flags :standard -warn-error -a+8 -w -32 -open Sidekick_util))
(synopsis "main congruence closure implementation")
(private_modules signature)
(libraries containers iter sidekick.sigs sidekick.core sidekick.util)
(flags :standard -open Sidekick_util))

53
src/cc/e_node.ml Normal file
View file

@ -0,0 +1,53 @@
open Types_
type t = e_node
let[@inline] equal (n1 : t) n2 = n1 == n2
let[@inline] hash n = Term.hash n.n_term
let[@inline] term n = n.n_term
let[@inline] pp out n = Term.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit
let make (t : Term.t) : t =
let rec n =
{
n_term = t;
n_sig0 = None;
n_bits = Bits.empty;
n_parents = Bag.empty;
n_as_lit = None;
(* TODO: provide a method to do it *)
n_root = n;
n_expl = FL_none;
n_next = n;
n_size = 1;
}
in
n
let[@inline] is_root (n : e_node) : bool = n.n_root == n
(* traverse the equivalence class of [n] *)
let iter_class_ (n_start : e_node) : e_node Iter.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n_start then aux u.n_next
in
aux n_start
let[@inline] iter_class n = iter_class_ n
let[@inline] iter_parents (n : e_node) : e_node Iter.t =
assert (is_root n);
Bag.to_iter n.n_parents
let[@inline] swap_next n1 n2 : unit =
let tmp = n1.n_next in
n1.n_next <- n2.n_next;
n2.n_next <- tmp
module Internal_ = struct
let iter_class_ = iter_class_
let make = make
end

65
src/cc/e_node.mli Normal file
View file

@ -0,0 +1,65 @@
(** E-node.
An e-node is a node in the congruence closure that is contained
in some equivalence classe).
An equivalence class is a set of terms that are currently equal
in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is
distinguished and is called the "representative".
All information pertaining to the whole equivalence class is stored
in its representative's {!E_node.t}.
When two classes become equal (are "merged"), one of the two
representatives is picked as the representative of the new class.
The new class contains the union of the two old classes' nodes.
We also allow theories to store additional information in the
representative. This information can be used when two classes are
merged, to detect conflicts and solve equations à la Shostak.
*)
open Types_
type t = Types_.e_node
(** An E-node.
A value of type [t] points to a particular Term.t, but see
{!find} to get the representative of the class. *)
include Sidekick_sigs.PRINT with type t := t
val term : t -> Term.t
(** Term contained in this equivalence class.
If [is_root n], then [Term.t n] is the class' representative Term.t. *)
val equal : t -> t -> bool
(** Are two classes {b physically} equal? To check for
logical equality, use [CC.E_node.equal (CC.find cc n1) (CC.find cc n2)]
which checks for equality of representatives. *)
val hash : t -> int
(** An opaque hash of this E_node.t. *)
val is_root : t -> bool
(** Is the E_node.t a root (ie the representative of its class)?
See {!find} to get the root. *)
val iter_class : t -> t Iter.t
(** Traverse the congruence class.
Precondition: [is_root n] (see {!find} below) *)
val iter_parents : t -> t Iter.t
(** Traverse the parents of the class.
Precondition: [is_root n] (see {!find} below) *)
val as_lit : t -> Lit.t option
val swap_next : t -> t -> unit
(** Swap the next pointer of each node. If their classes were disjoint,
they are now unioned. *)
module Internal_ : sig
val iter_class_ : t -> t Iter.t
val make : Term.t -> t
end

50
src/cc/expl.ml Normal file
View file

@ -0,0 +1,50 @@
open Types_
type t = explanation
let rec pp out (e : explanation) =
match e with
| E_trivial -> Fmt.string out "reduction"
| E_lit lit -> Lit.pp out lit
| E_congruence (n1, n2) ->
Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2
| E_merge (a, b) ->
Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b
| E_merge_t (a, b) ->
Fmt.fprintf out "(@[<hv>merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp_debug a
Term.pp_debug b
| E_theory (t, u, es, _) ->
Fmt.fprintf out "(@[th@ :t `%a`@ :u `%a`@ :expl_sets %a@])" Term.pp_debug t
Term.pp_debug u
(Util.pp_list
@@ Fmt.Dump.triple Term.pp_debug Term.pp_debug (Fmt.Dump.list pp))
es
| E_and (a, b) -> Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
let mk_trivial : t = E_trivial
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2)
let[@inline] mk_merge a b : t =
if E_node.equal a b then
mk_trivial
else
E_merge (a, b)
let[@inline] mk_merge_t a b : t =
if Term.equal a b then
mk_trivial
else
E_merge_t (a, b)
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr)
let rec mk_list l =
match l with
| [] -> mk_trivial
| [ x ] -> x
| E_trivial :: tl -> mk_list tl
| x :: y ->
(match mk_list y with
| E_trivial -> x
| y' -> E_and (x, y'))

47
src/cc/expl.mli Normal file
View file

@ -0,0 +1,47 @@
(** Explanations
Explanations are specialized proofs, created by the congruence closure
when asked to justify why two terms are equal. *)
open Types_
type t = Types_.explanation
include Sidekick_sigs.PRINT with type t := t
val mk_merge : E_node.t -> E_node.t -> t
(** Explanation: the nodes were explicitly merged *)
val mk_merge_t : Term.t -> Term.t -> t
(** Explanation: the terms were explicitly merged *)
val mk_lit : Lit.t -> t
(** Explanation: we merged [t] and [u] because of literal [t=u],
or we merged [t] and [true] because of literal [t],
or [t] and [false] because of literal [¬t] *)
val mk_list : t list -> t
(** Conjunction of explanations *)
val mk_congruence : E_node.t -> E_node.t -> t
val mk_theory :
Term.t -> Term.t -> (Term.t * Term.t * t list) list -> Proof_term.step_id -> t
(** [mk_theory t u expl_sets pr] builds a theory explanation for
why [|- t=u]. It depends on sub-explanations [expl_sets] which
are tuples [ (t_i, u_i, expls_i) ] where [expls_i] are
explanations that justify [t_i = u_i] in the current congruence closure.
The proof [pr] is the theory lemma, of the form
[ (t_i = u_i)_i |- t=u ].
It is resolved against each [expls_i |- t_i=u_i] obtained from
[expl_sets], on pivot [t_i=u_i], to obtain a proof of [Gamma |- t=u]
where [Gamma] is a subset of the literals asserted into the congruence
closure.
For example for the lemma [a=b] deduced by injectivity
from [Some a=Some b] in the theory of datatypes,
the arguments would be
[a, b, [Some a, Some b, mk_merge_t (Some a)(Some b)], pr]
where [pr] is the injectivity lemma [Some a=Some b |- a=b].
*)

167
src/cc/plugin.ml Normal file
View file

@ -0,0 +1,167 @@
open Types_
open Sigs_plugin
module type EXTENDED_PLUGIN_BUILDER = sig
include MONOID_PLUGIN_BUILDER
val mem : t -> E_node.t -> bool
(** Does the CC.E_node.t have a monoid value? *)
val get : t -> E_node.t -> M.t option
(** Get monoid value for this CC.E_node.t, if any *)
val iter_all : t -> (CC.repr * M.t) Iter.t
include Sidekick_sigs.BACKTRACKABLE0 with type t := t
include Sidekick_sigs.PRINT with type t := t
end
module Make (M : MONOID_PLUGIN_ARG) :
EXTENDED_PLUGIN_BUILDER with module M = M = struct
module M = M
module Cls_tbl = Backtrackable_tbl.Make (E_node)
module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M
type t = (module DYN_PL_FOR_M)
module Make (A : sig
val size : int option
val cc : CC.t
end) : DYN_PL_FOR_M = struct
module M = M
module CC = CC
open A
(* plugin's state *)
let plugin_st = M.create cc
(* repr -> value for the class *)
let values : M.t Cls_tbl.t = Cls_tbl.create ?size ()
(* bit in CC to filter out quickly classes without value *)
let field_has_value : CC.bitfield =
CC.allocate_bitfield ~descr:("monoid." ^ M.name ^ ".has-value") cc
let push_level () = Cls_tbl.push_level values
let pop_levels n = Cls_tbl.pop_levels values n
let n_levels () = Cls_tbl.n_levels values
let mem n =
let res = CC.get_bitfield cc field_has_value n in
assert (
if res then
Cls_tbl.mem values n
else
true);
res
let get n =
if CC.get_bitfield cc field_has_value n then
Cls_tbl.get values n
else
None
let on_new_term cc n (t : Term.t) : CC.Handler_action.t list =
(*Log.debugf 50 (fun k->k "(@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);*)
let acts = ref [] in
let maybe_m, l = M.of_term cc plugin_st n t in
(match maybe_m with
| Some v ->
Log.debugf 20 (fun k ->
k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" M.name E_node.pp
n M.pp v);
CC.set_bitfield cc field_has_value true n;
Cls_tbl.add values n v
| None -> ());
List.iter
(fun (n_u, m_u) ->
Log.debugf 20 (fun k ->
k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
M.name E_node.pp n E_node.pp n_u M.pp m_u);
let n_u = CC.find cc n_u in
if CC.get_bitfield cc field_has_value n_u then (
let m_u' =
try Cls_tbl.find values n_u
with Not_found ->
Error.errorf "node %a has bitfield but no value" E_node.pp n_u
in
match M.merge cc plugin_st n_u m_u n_u m_u' (Expl.mk_list []) with
| Error (CC.Handler_action.Conflict expl) ->
Error.errorf
"when merging@ @[for node %a@],@ values %a and %a:@ conflict %a"
E_node.pp n_u M.pp m_u M.pp m_u' Expl.pp expl
| Ok (m_u_merged, merge_acts) ->
acts := List.rev_append merge_acts !acts;
Log.debugf 20 (fun k ->
k
"(@[monoid[%s].on-new-term.sub.merged@ :n %a@ :sub-t %a@ \
:value %a@])"
M.name E_node.pp n E_node.pp n_u M.pp m_u_merged);
Cls_tbl.add values n_u m_u_merged
) else (
(* just add to [n_u] *)
CC.set_bitfield cc field_has_value true n_u;
Cls_tbl.add values n_u m_u
))
l;
!acts
let iter_all : _ Iter.t = Cls_tbl.to_iter values
let on_pre_merge cc n1 n2 e_n1_n2 : CC.Handler_action.or_conflict =
let exception E of CC.Handler_action.conflict in
let acts = ref [] in
try
(match get n1, get n2 with
| Some v1, Some v2 ->
Log.debugf 5 (fun k ->
k
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 \
%a@ :val2 %a@])@])"
M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2);
(match M.merge cc plugin_st n1 v1 n2 v2 e_n1_n2 with
| Ok (v', merge_acts) ->
acts := merge_acts;
Cls_tbl.remove values n2;
(* only keep repr *)
Cls_tbl.add values n1 v'
| Error c -> raise (E c))
| None, Some cr ->
CC.set_bitfield cc field_has_value true n1;
Cls_tbl.add values n1 cr;
Cls_tbl.remove values n2 (* only keep reprs *)
| Some _, None -> () (* already there on the left *)
| None, None -> ());
Ok !acts
with E c -> Error c
let pp out () : unit =
let pp_e out (t, v) =
Fmt.fprintf out "(@[%a@ :has %a@])" E_node.pp t M.pp v
in
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) iter_all
let () =
(* hook into the CC's events *)
Event.on (CC.on_new_term cc) ~f:(fun (_, r, t) -> on_new_term cc r t);
Event.on (CC.on_pre_merge2 cc) ~f:(fun (_, ra, rb, expl) ->
on_pre_merge cc ra rb expl);
()
end
let create_and_setup ?size (cc : CC.t) : t =
(module Make (struct
let size = size
let cc = cc
end))
let push_level ((module P) : t) = P.push_level ()
let pop_levels ((module P) : t) n = P.pop_levels n
let n_levels ((module P) : t) = P.n_levels ()
let mem ((module P) : t) t = P.mem t
let get ((module P) : t) t = P.get t
let iter_all ((module P) : t) = P.iter_all
let pp out ((module P) : t) = P.pp out ()
end

21
src/cc/plugin.mli Normal file
View file

@ -0,0 +1,21 @@
(** Congruence Closure Plugin *)
open Sigs_plugin
module type EXTENDED_PLUGIN_BUILDER = sig
include MONOID_PLUGIN_BUILDER
val mem : t -> E_node.t -> bool
(** Does the CC.E_node.t have a monoid value? *)
val get : t -> E_node.t -> M.t option
(** Get monoid value for this CC.E_node.t, if any *)
val iter_all : t -> (CC.repr * M.t) Iter.t
include Sidekick_sigs.BACKTRACKABLE0 with type t := t
include Sidekick_sigs.PRINT with type t := t
end
(** Create a plugin builder from the given per-class monoid *)
module Make (M : MONOID_PLUGIN_ARG) : EXTENDED_PLUGIN_BUILDER with module M = M

5
src/cc/plugin/dune Normal file
View file

@ -0,0 +1,5 @@
(library
(name Sidekick_cc_plugin)
(public_name sidekick.cc.plugin)
(libraries containers iter sidekick.sigs sidekick.cc sidekick.util)
(flags :standard -w +32 -open Sidekick_util))

6
src/cc/resolved_expl.ml Normal file
View file

@ -0,0 +1,6 @@
open Types_
type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id }
let pp out (self : t) =
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits

17
src/cc/resolved_expl.mli Normal file
View file

@ -0,0 +1,17 @@
(** Resolved explanations.
The congruence closure keeps explanations for why terms are in the same
class. However these are represented in a compact, cheap form.
To use these explanations we need to {b resolve} them into a
resolved explanation, typically a list of
literals that are true in the current trail and are responsible for
merges.
However, we can also have merged classes because they have the same value
in the current model. *)
open Types_
type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id }
include Sidekick_sigs.PRINT with type t := t

48
src/cc/signature.ml Normal file
View file

@ -0,0 +1,48 @@
(** A signature is a shallow term shape where immediate subterms
are representative *)
open Sidekick_core.CC_view
open Types_
type t = signature
let equal (s1 : t) s2 : bool =
let open CC_view in
s1 == s2
||
match s1, s2 with
| Bool b1, Bool b2 -> b1 = b2
| App_fun (f1, []), App_fun (f2, []) -> Const.equal f1 f2
| App_fun (f1, l1), App_fun (f2, l2) ->
Const.equal f1 f2 && CCList.equal E_node.equal l1 l2
| App_ho (f1, a1), App_ho (f2, a2) -> E_node.equal f1 f2 && E_node.equal a1 a2
| Not a, Not b -> E_node.equal a b
| If (a1, b1, c1), If (a2, b2, c2) ->
E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2
| Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2
| Opaque u1, Opaque u2 -> E_node.equal u1 u2
| (Bool _ | App_fun _ | App_ho _ | If _ | Eq _ | Opaque _ | Not _), _ -> false
let hash (s : t) : int =
let module H = CCHash in
match s with
| Bool b -> H.combine2 10 (H.bool b)
| App_fun (f, l) -> H.combine3 20 (Const.hash f) (H.list E_node.hash l)
| App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a)
| Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b)
| Opaque u -> H.combine2 50 (E_node.hash u)
| If (a, b, c) ->
H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c)
| Not u -> H.combine2 70 (E_node.hash u)
let[@inline never] pp out = function
| Bool b -> Fmt.bool out b
| App_fun (f, []) -> Const.pp out f
| App_fun (f, l) ->
Fmt.fprintf out "(@[%a@ %a@])" Const.pp f (Util.pp_list E_node.pp) l
| App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a
| Opaque t -> E_node.pp out t
| Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u
| Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b
| If (a, b, c) ->
Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b E_node.pp c

97
src/cc/sigs_plugin.ml Normal file
View file

@ -0,0 +1,97 @@
open Types_
(* TODO: full EGG, also have a function to update the value when
the subterms (produced in [of_term]) are updated *)
(** Data attached to the congruence closure classes.
This helps theories keeping track of some state for each class.
The state of a class is the monoidal combination of the state for each
Term.t in the class (for example, the set of terms in the
class whose head symbol is a datatype constructor). *)
module type MONOID_PLUGIN_ARG = sig
type t
(** Some type with a monoid structure *)
include Sidekick_sigs.PRINT with type t := t
type state
val create : CC.t -> state
(** Initialize state from the congruence closure *)
val name : string
(** name of the monoid structure (short) *)
(* FIXME: for subs, return list of e_nodes, and assume of_term already
returned data for them. *)
val of_term :
CC.t -> state -> E_node.t -> Term.t -> t option * (E_node.t * t) list
(** [of_term n t], where [t] is the Term.t annotating node [n],
must return [maybe_m, l], where:
- [maybe_m = Some m] if [t] has monoid value [m];
otherwise [maybe_m=None]
- [l] is a list of [(u, m_u)] where each [u]'s Term.t
is a direct subterm of [t]
and [m_u] is the monoid value attached to [u].
*)
val merge :
CC.t ->
state ->
E_node.t ->
t ->
E_node.t ->
t ->
Expl.t ->
(t * CC.Handler_action.t list, CC.Handler_action.conflict) result
(** Monoidal combination of two values.
[merge cc n1 mon1 n2 mon2 expl] returns the result of merging
monoid values [mon1] (for class [n1]) and [mon2] (for class [n2])
when [n1] and [n2] are merged with explanation [expl].
@return [Ok mon] if the merge is acceptable, annotating the class of [n1 n2];
or [Error expl'] if the merge is unsatisfiable. [expl'] can then be
used to trigger a conflict and undo the merge.
*)
end
(** Stateful plugin holding a per-equivalence-class monoid.
Helps keep track of monoid state per equivalence class.
A theory might use one or more instance(s) of this to
aggregate some theory-specific state over all terms, with
the information of what terms are already known to be equal
potentially saving work for the theory. *)
module type DYN_MONOID_PLUGIN = sig
module M : MONOID_PLUGIN_ARG
include Sidekick_sigs.DYN_BACKTRACKABLE
val pp : unit Fmt.printer
val mem : E_node.t -> bool
(** Does the CC E_node.t have a monoid value? *)
val get : E_node.t -> M.t option
(** Get monoid value for this CC E_node.t, if any *)
val iter_all : (CC.repr * M.t) Iter.t
end
(** Builder for a plugin.
The builder takes a congruence closure, and instantiate the
plugin on it. *)
module type MONOID_PLUGIN_BUILDER = sig
module M : MONOID_PLUGIN_ARG
module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M
type t = (module DYN_PL_FOR_M)
val create_and_setup : ?size:int -> CC.t -> t
(** Create a new monoid state *)
end

39
src/cc/types_.ml Normal file
View file

@ -0,0 +1,39 @@
include Sidekick_core
type e_node = {
n_term: Term.t;
mutable n_sig0: signature option; (* initial signature *)
mutable n_bits: Bits.t; (* bitfield for various properties *)
mutable n_parents: e_node Bag.t; (* parent terms of this node *)
mutable n_root: e_node;
(* representative of congruence class (itself if a representative) *)
mutable n_next: e_node; (* pointer to next element of congruence class *)
mutable n_size: int; (* size of the class *)
mutable n_as_lit: Lit.t option;
(* TODO: put into payload? and only in root? *)
mutable n_expl: explanation_forest_link;
(* the rooted forest for explanations *)
}
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative. *)
and signature = (Const.t, e_node, e_node list) CC_view.t
and explanation_forest_link =
| FL_none
| FL_some of { next: e_node; expl: explanation }
(* atomic explanation in the congruence closure *)
and explanation =
| E_trivial (* by pure reduction, tautologically equal *)
| E_lit of Lit.t (* because of this literal *)
| E_merge of e_node * e_node
| E_merge_t of Term.t * Term.t
| E_congruence of e_node * e_node (* caused by normal congruence *)
| E_and of explanation * explanation
| E_theory of
Term.t
* Term.t
* (Term.t * Term.t * explanation list) list
* Proof_term.step_id

9
src/core-logic/bvar.ml Normal file
View file

@ -0,0 +1,9 @@
open Types_
type t = bvar = { bv_idx: int; bv_ty: term }
let equal (v1 : t) v2 = v1.bv_idx = v2.bv_idx && Term_.equal v1.bv_ty v2.bv_ty
let hash v = H.combine2 (H.int v.bv_idx) (Term_.hash v.bv_ty)
let pp out v = Fmt.fprintf out "bv[%d]" v.bv_idx
let[@inline] ty self = self.bv_ty
let make i ty : t = { bv_idx = i; bv_ty = ty }

10
src/core-logic/bvar.mli Normal file
View file

@ -0,0 +1,10 @@
(** Bound variable *)
open Types_
type t = bvar = { bv_idx: int; bv_ty: term }
include EQ_HASH_PRINT with type t := t
val make : int -> term -> t
val ty : t -> term

29
src/core-logic/const.ml Normal file
View file

@ -0,0 +1,29 @@
open Types_
type view = const_view = ..
module type DYN_OPS = sig
val pp : view Fmt.printer
val equal : view -> view -> bool
val hash : view -> int
end
type ops = (module DYN_OPS)
type t = const = { c_view: view; c_ops: ops; c_ty: term }
let[@inline] view self = self.c_view
let[@inline] ty self = self.c_ty
let equal (a : t) b =
let (module O) = a.c_ops in
O.equal a.c_view b.c_view && Term_.equal a.c_ty b.c_ty
let hash (a : t) : int =
let (module O) = a.c_ops in
H.combine2 (O.hash a.c_view) (Term_.hash a.c_ty)
let pp out (a : t) =
let (module O) = a.c_ops in
O.pp out a.c_view
let make c_view c_ops ~ty:c_ty : t = { c_view; c_ops; c_ty }

22
src/core-logic/const.mli Normal file
View file

@ -0,0 +1,22 @@
(** Constants.
Constants are logical symbols, defined by the user thanks to an open type *)
open Types_
type view = const_view = ..
module type DYN_OPS = sig
val pp : view Fmt.printer
val equal : view -> view -> bool
val hash : view -> int
end
type ops = (module DYN_OPS)
type t = const = { c_view: view; c_ops: ops; c_ty: term }
val view : t -> view
val make : view -> ops -> ty:term -> t
val ty : t -> term
include EQ_HASH_PRINT with type t := t

7
src/core-logic/dune Normal file
View file

@ -0,0 +1,7 @@
(library
(name sidekick_core_logic)
(public_name sidekick.core-logic)
(synopsis "Core AST for logic terms in the calculus of constructions")
(private_modules types_)
(flags :standard -w +32 -open Sidekick_sigs -open Sidekick_util)
(libraries containers iter sidekick.sigs sidekick.util))

View file

@ -0,0 +1,10 @@
module Term = Term
module Var = Var
module Bvar = Bvar
module Const = Const
module Subst = Subst
module T_builtins = T_builtins
module Store = Term.Store
(* TODO: move to separate library? *)
module Str_const = Str_const

View file

@ -0,0 +1,21 @@
open Types_
type const_view += Str of string
let ops : Const.ops =
(module struct
let pp out = function
| Str s -> Fmt.string out s
| _ -> assert false
let equal a b =
match a, b with
| Str s1, Str s2 -> s1 = s2
| _ -> false
let hash = function
| Str s -> CCHash.string s
| _ -> assert false
end)
let make name ~ty : Const.t = Const.make (Str name) ops ~ty

View file

@ -0,0 +1,10 @@
(** Basic string constants.
These constants are a string name, coupled with a type.
*)
open Types_
type const_view += private Str of string
val make : string -> ty:term -> const

25
src/core-logic/subst.ml Normal file
View file

@ -0,0 +1,25 @@
open Types_
module M = Var_.Map
type t = subst
let empty = { m = M.empty }
let is_empty self = M.is_empty self.m
let add v t self = { m = M.add v t self.m }
let pp out (self : t) =
if is_empty self then
Fmt.string out "(subst)"
else (
let pp_pair out (v, t) =
Fmt.fprintf out "(@[%a := %a@])" Var.pp v !Term_.pp_debug_ t
in
Fmt.fprintf out "(@[subst@ %a@])" (Util.pp_iter pp_pair) (M.to_iter self.m)
)
let of_list l = { m = M.of_list l }
let of_iter it = { m = M.of_iter it }
let to_iter self = M.to_iter self.m
let apply (store : Term.store) ~recursive (self : t) (e : term) : term =
Term.Internal_.subst_ store ~recursive e self

15
src/core-logic/subst.mli Normal file
View file

@ -0,0 +1,15 @@
(** Substitutions *)
open Types_
type t = subst
include PRINT with type t := t
val empty : t
val is_empty : t -> bool
val of_list : (var * term) list -> t
val of_iter : (var * term) Iter.t -> t
val to_iter : t -> (var * term) Iter.t
val add : var -> term -> t -> t
val apply : Term.store -> recursive:bool -> t -> term -> term

View file

@ -0,0 +1,116 @@
open Types_
open Term
type const_view += C_bool | C_eq | C_ite | C_not | C_true | C_false
let ops : const_ops =
(module struct
let equal a b =
match a, b with
| C_bool, C_bool
| C_eq, C_eq
| C_ite, C_ite
| C_not, C_not
| C_true, C_true
| C_false, C_false ->
true
| _ -> false
let hash = function
| C_bool -> CCHash.int 167
| C_eq -> CCHash.int 168
| C_ite -> CCHash.int 169
| C_not -> CCHash.int 170
| C_true -> CCHash.int 171
| C_false -> CCHash.int 172
| _ -> assert false
let pp out = function
| C_bool -> Fmt.string out "Bool"
| C_eq -> Fmt.string out "="
| C_ite -> Fmt.string out "ite"
| C_not -> Fmt.string out "not"
| C_true -> Fmt.string out "true"
| C_false -> Fmt.string out "false"
| _ -> assert false
end)
let bool store = const store @@ Const.make C_bool ops ~ty:(type_ store)
let true_ store = const store @@ Const.make C_true ops ~ty:(bool store)
let false_ store = const store @@ Const.make C_false ops ~ty:(bool store)
let bool_val store b =
if b then
true_ store
else
false_ store
let c_eq store =
let type_ = type_ store in
let v = bvar_i store 0 ~ty:type_ in
let ty =
DB.pi_db ~var_name:"A" store ~var_ty:type_
@@ arrow_l store [ v; v ] (bool store)
in
const store @@ Const.make C_eq ops ~ty
let c_ite store =
let type_ = type_ store in
let v = bvar_i store 0 ~ty:type_ in
let ty =
DB.pi_db ~var_name:"A" store ~var_ty:type_
@@ arrow_l store [ bool store; v; v ] v
in
const store @@ Const.make C_ite ops ~ty
let c_not store =
let b = bool store in
let ty = arrow store b b in
const store @@ Const.make C_not ops ~ty
let eq store a b =
if equal a b then
true_ store
else (
let a, b =
if compare a b <= 0 then
a, b
else
b, a
in
app_l store (c_eq store) [ ty a; a; b ]
)
let ite store a b c = app_l store (c_ite store) [ ty b; a; b; c ]
let not store a =
(* turn [not (not u)] into [u] *)
match view a with
| E_app ({ view = E_const { c_view = C_not; _ }; _ }, u) -> u
| E_const { c_view = C_true; _ } -> false_ store
| E_const { c_view = C_false; _ } -> true_ store
| _ -> app store (c_not store) a
let is_bool t =
match view t with
| E_const { c_view = C_bool; _ } -> true
| _ -> false
let is_eq t =
match view t with
| E_const { c_view = C_eq; _ } -> true
| _ -> false
let rec abs tst t =
match view t with
| E_app ({ view = E_const { c_view = C_not; _ }; _ }, u) ->
let sign, v = abs tst u in
Stdlib.not sign, v
| E_const { c_view = C_false; _ } -> false, true_ tst
| _ -> true, t
let as_bool_val t =
match Term.view t with
| Term.E_const { c_view = C_true; _ } -> Some true
| Term.E_const { c_view = C_false; _ } -> Some false
| _ -> None

View file

@ -0,0 +1,35 @@
(** Core builtins *)
open Types_
open Term
type const_view += C_bool | C_eq | C_ite | C_not | C_true | C_false
val bool : store -> t
val c_not : store -> t
val c_eq : store -> t
val c_ite : store -> t
val true_ : store -> t
val false_ : store -> t
val bool_val : store -> bool -> t
val eq : store -> t -> t -> t
(** [eq a b] is [a = b] *)
val not : store -> t -> t
val ite : store -> t -> t -> t -> t
(** [ite a b c] is [if a then b else c] *)
val is_eq : t -> bool
val is_bool : t -> bool
val abs : store -> t -> bool * t
(** [abs t] returns an "absolute value" for the term, along with the
sign of [t].
The idea is that we want to turn [not a] into [(false, a)],
or [(a != b)] into [(false, a=b)]. For terms without a negation this
should return [(true, t)]. *)
val as_bool_val : t -> bool option

690
src/core-logic/term.ml Normal file
View file

@ -0,0 +1,690 @@
open Types_
type nonrec var = var
type nonrec bvar = bvar
type nonrec term = term
type view = term_view =
| E_type of int
| E_var of var
| E_bound_var of bvar
| E_const of const
| E_app of term * term
| E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * term * term
| E_pi of string * term * term
type t = term
(* 5 bits in [t.id] are used to store which store it belongs to, so we have
a chance of detecting when the user passes a term to the wrong store *)
let store_id_bits = 5
(* mask to access the store id *)
let store_id_mask = (1 lsl store_id_bits) - 1
include Term_
let[@inline] view (e : term) : view = e.view
let[@inline] db_depth e = e.flags lsr (1 + store_id_bits)
let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1
let[@inline] store_uid e : int = e.flags land store_id_mask
let[@inline] is_closed e : bool = db_depth e == 0
(* slow path *)
let[@inline never] ty_force_delayed_ e f =
let ty = f () in
e.ty <- T_ty ty;
ty
let[@inline] ty e : term =
match e.ty with
| T_ty t -> t
| T_ty_delayed f -> ty_force_delayed_ e f
(* open an application *)
let unfold_app (e : term) : term * term list =
let[@unroll 1] rec aux acc e =
match e.view with
| E_app (f, a) -> aux (a :: acc) f
| _ -> e, acc
in
aux [] e
let[@inline] is_const e =
match e.view with
| E_const _ -> true
| _ -> false
let[@inline] is_app e =
match e.view with
| E_app _ -> true
| _ -> false
(* debug printer *)
let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let rec loop k ~depth names out e =
let pp' = loop k ~depth:(depth + 1) names in
(match e.view with
| E_type 0 -> Fmt.string out "Type"
| E_type i -> Fmt.fprintf out "Type(%d)" i
| E_var v -> Fmt.string out v.v_name
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
| E_bound_var v ->
let idx = v.bv_idx in
(match CCList.nth_opt names idx with
| Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx
| _ -> Fmt.fprintf out "_[%d]" idx)
| E_const c -> Const.pp out c
| (E_app _ | E_lam _) when depth > max_depth ->
Fmt.fprintf out "@<1>…/%d" e.id
| E_app _ ->
let f, args = unfold_app e in
Fmt.fprintf out "(%a@ %a)" pp' f (Util.pp_list pp') args
| E_lam ("", _ty, bod) ->
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_app_fold { f; args; acc0 } ->
Fmt.fprintf out "(@[%a" pp' f;
List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args;
Fmt.fprintf out "@ %a" pp' acc0;
Fmt.fprintf out "@])"
| E_lam (n, _ty, bod) ->
Fmt.fprintf out "(@[\\%s:@[%a@].@ %a@])" n pp' _ty
(loop (k + 1) ~depth:(depth + 1) (n :: names))
bod
| E_pi (_, ty, bod) when is_closed bod ->
(* actually just an arrow *)
Fmt.fprintf out "(@[%a@ -> %a@])"
(loop k ~depth:(depth + 1) names)
ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_pi ("", _ty, bod) ->
Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_pi (n, _ty, bod) ->
Fmt.fprintf out "(@[Pi %s:@[%a@].@ %a@])" n pp' _ty
(loop (k + 1) ~depth:(depth + 1) (n :: names))
bod);
if pp_ids then Fmt.fprintf out "/%d" e.id
in
Fmt.fprintf out "@[%a@]" (loop 0 ~depth:0 []) e
let pp_debug = expr_pp_with_ ~pp_ids:false ~max_depth:max_int
let pp_debug_with_ids = expr_pp_with_ ~pp_ids:true ~max_depth:max_int
let () = pp_debug_ := pp_debug
module AsKey = struct
type nonrec t = term
let equal = equal
let compare = compare
let hash = hash
end
module Map = CCMap.Make (AsKey)
module Set = CCSet.Make (AsKey)
module Tbl = CCHashtbl.Make (AsKey)
module Hcons = Hashcons.Make (struct
type nonrec t = term
let equal a b =
match a.view, b.view with
| E_type i, E_type j -> i = j
| E_const c1, E_const c2 -> Const.equal c1 c2
| E_var v1, E_var v2 -> Var.equal v1 v2
| E_bound_var v1, E_bound_var v2 -> Bvar.equal v1 v2
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
| E_app_fold a1, E_app_fold a2 ->
equal a1.f a2.f && equal a1.acc0 a2.acc0
&& CCList.equal equal a1.args a2.args
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _
| E_app_fold _ | E_lam _ | E_pi _ ),
_ ) ->
false
let hash e : int =
match e.view with
| E_type i -> H.combine2 10 (H.int i)
| E_const c -> H.combine2 20 (Const.hash c)
| E_var v -> H.combine2 30 (Var.hash v)
| E_bound_var v -> H.combine2 40 (Bvar.hash v)
| E_app (f, a) -> H.combine3 50 (hash f) (hash a)
| E_app_fold a ->
H.combine4 55 (hash a.f) (hash a.acc0) (Hash.list hash a.args)
| E_lam (_, ty, bod) -> H.combine3 60 (hash ty) (hash bod)
| E_pi (_, ty, bod) -> H.combine3 70 (hash ty) (hash bod)
let set_id t id =
assert (t.id == -1);
t.id <- id
end)
module Store = struct
type t = { (* unique ID for this store *)
s_uid: int; s_exprs: Hcons.t }
(* TODO: use atomic? CCAtomic? *)
let n = ref 0
let size self = Hcons.size self.s_exprs
let create ?(size = 256) () : t =
(* store id, modulo 2^5 *)
let s_uid = !n land store_id_mask in
incr n;
{ s_uid; s_exprs = Hcons.create ~size () }
(* check that [e] belongs in this store *)
let[@inline] check_e_uid (self : t) (e : term) =
assert (self.s_uid == store_uid e)
end
type store = Store.t
let iter_shallow ~f (e : term) : unit =
match e.view with
| E_type _ -> ()
| _ ->
f false (ty e);
(match e.view with
| E_const _ -> ()
| E_type _ -> assert false
| E_var v -> f false v.v_ty
| E_bound_var v -> f false v.bv_ty
| E_app (hd, a) ->
f false hd;
f false a
| E_app_fold { f = fold_f; args; acc0 } ->
f false fold_f;
f false acc0;
List.iter (fun u -> f false u) args
| E_lam (_, tyv, bod) | E_pi (_, tyv, bod) ->
f false tyv;
f true bod)
let map_shallow_ ~make ~f (e : term) : term =
match view e with
| E_type _ | E_const _ -> e
| E_var v ->
let v_ty = f false v.v_ty in
if v_ty == v.v_ty then
e
else
make (E_var { v with v_ty })
| E_bound_var v ->
let ty' = f false v.bv_ty in
if v.bv_ty == ty' then
e
else
make (E_bound_var { v with bv_ty = ty' })
| E_app (hd, a) ->
let hd' = f false hd in
let a' = f false a in
if a == a' && hd == hd' then
e
else
make (E_app (f false hd, f false a))
| E_app_fold { f = fold_f; args = l; acc0 } ->
let fold_f' = f false fold_f in
let l' = List.map (fun u -> f false u) l in
let acc0' = f false acc0 in
if equal fold_f fold_f' && equal acc0 acc0' && CCList.equal equal l l' then
e
else
make (E_app_fold { f = fold_f'; args = l'; acc0 = acc0' })
| E_lam (n, tyv, bod) ->
let tyv' = f false tyv in
let bod' = f true bod in
if tyv == tyv' && bod == bod' then
e
else
make (E_lam (n, tyv', bod'))
| E_pi (n, tyv, bod) ->
let tyv' = f false tyv in
let bod' = f true bod in
if tyv == tyv' && bod == bod' then
e
else
make (E_pi (n, tyv', bod'))
exception IsSub
let[@inline] is_type e =
match e.view with
| E_type _ -> true
| _ -> false
let[@inline] is_a_type (t : t) = is_type (ty t)
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
let rec loop e =
if not (Tbl.mem seen e) then (
Tbl.add seen e ();
if iter_ty && not (is_type e) then loop (ty e);
f e;
iter_shallow e ~f:(fun _ u -> loop u)
)
in
loop e
exception E_exit
let exists_shallow ~f e : bool =
try
iter_shallow e ~f:(fun b x -> if f b x then raise_notrace E_exit);
false
with E_exit -> true
let for_all_shallow ~f e : bool =
try
iter_shallow e ~f:(fun b x -> if not (f b x) then raise_notrace E_exit);
true
with E_exit -> false
let contains e ~sub : bool =
try
iter_dag ~iter_ty:true e ~f:(fun e' ->
if equal e' sub then raise_notrace IsSub);
false
with IsSub -> true
let free_vars_iter e : var Iter.t =
fun yield ->
iter_dag ~iter_ty:true e ~f:(fun e' ->
match view e' with
| E_var v -> yield v
| _ -> ())
let free_vars ?(init = Var.Set.empty) e : Var.Set.t =
let set = ref init in
free_vars_iter e (fun v -> set := Var.Set.add v !set);
!set
module Make_ = struct
let compute_db_depth_ e : int =
if is_type e then
0
else (
let d1 = db_depth @@ ty e in
let d2 =
match view e with
| E_type _ | E_const _ | E_var _ -> 0
| E_bound_var v -> v.bv_idx + 1
| E_app (a, b) -> max (db_depth a) (db_depth b)
| E_app_fold { f; acc0; args } ->
let m = max (db_depth f) (db_depth acc0) in
List.fold_left (fun x u -> max x (db_depth u)) m args
| E_lam (_, ty, bod) | E_pi (_, ty, bod) ->
max (db_depth ty) (max 0 (db_depth bod - 1))
in
max d1 d2
)
let compute_has_fvars_ e : bool =
if is_type e then
false
else
has_fvars (ty e)
||
match view e with
| E_var _ -> true
| E_type _ | E_bound_var _ | E_const _ -> false
| E_app (a, b) -> has_fvars a || has_fvars b
| E_app_fold { f; acc0; args } ->
has_fvars f || has_fvars acc0 || List.exists has_fvars args
| E_lam (_, ty, bod) | E_pi (_, ty, bod) -> has_fvars ty || has_fvars bod
let universe_ (e : term) : int =
match e.view with
| E_type i -> i
| _ -> assert false
let[@inline] universe_of_ty_ (e : term) : int =
match e.view with
| E_type i -> i + 1
| _ -> universe_ (ty e)
module T_int_tbl = CCHashtbl.Make (struct
type t = term * int
let equal (t1, k1) (t2, k2) = equal t1 t2 && k1 == k2
let hash (t, k) = H.combine3 27 (hash t) (H.int k)
end)
(* shift open bound variables of [e] by [n] *)
let db_shift_ ~make (e : term) (n : int) =
let rec loop e k : term =
if is_closed e then
e
else if is_type e then
e
else (
match view e with
| E_bound_var bv ->
if bv.bv_idx >= k then
make (E_bound_var (Bvar.make (bv.bv_idx + n) bv.bv_ty))
else
e
| _ ->
map_shallow_ e ~make ~f:(fun inbind u ->
loop u
(if inbind then
k + 1
else
k))
)
in
assert (n >= 0);
if n = 0 || is_closed e then
e
else
loop e 0
(* replace DB0 in [e] with [u] *)
let db_0_replace_ ~make e ~by:u : term =
let cache_ = T_int_tbl.create 8 in
(* recurse in subterm [e], under [k] intermediate binders (so any
bound variable under k is bound by them) *)
let rec aux e k : term =
if is_type e then
e
else if db_depth e < k then
e
else (
match view e with
| E_const _ -> e
| E_bound_var bv when bv.bv_idx = k ->
(* replace [bv] with [u], and shift [u] to account for the
[k] intermediate binders we traversed to get to [bv] *)
db_shift_ ~make u k
| _ ->
(* use the cache *)
(try T_int_tbl.find cache_ (e, k)
with Not_found ->
let r =
map_shallow_ e ~make ~f:(fun inb u ->
aux u
(if inb then
k + 1
else
k))
in
T_int_tbl.add cache_ (e, k) r;
r)
)
in
if is_closed e then
e
else
aux e 0
let compute_ty_ store ~make (view : view) : term =
match view with
| E_var v -> Var.ty v
| E_bound_var v -> Bvar.ty v
| E_type i -> make (E_type (i + 1))
| E_const c ->
let ty = Const.ty c in
Store.check_e_uid store ty;
if not (is_closed ty) then
Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp
c pp_debug ty;
ty
| E_lam (name, ty_v, bod) ->
Store.check_e_uid store ty_v;
Store.check_e_uid store bod;
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
let ty_bod = ty bod in
make (E_pi (name, ty_v, ty_bod))
| E_app (f, a) ->
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
is [ty_bod_f[x := a]] *)
Store.check_e_uid store f;
Store.check_e_uid store a;
let ty_f = ty f in
let ty_a = ty a in
(match ty_f.view with
| E_pi (_, ty_arg_f, ty_bod_f) ->
(* check that the expected type matches *)
if not (equal ty_arg_f ty_a) then
Error.errorf
"@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \
`%a`@ @[actual: `%a`@]@]"
pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids
ty_a;
db_0_replace_ ~make ty_bod_f ~by:a
| _ ->
Error.errorf
"@[<2>cannot apply %a@ (to %a),@ must have Pi type, but actual type \
is %a@]"
pp_debug f pp_debug a pp_debug ty_f)
| E_app_fold { args = []; _ } -> assert false
| E_app_fold { f; args = a0 :: other_args as args; acc0 } ->
Store.check_e_uid store f;
Store.check_e_uid store acc0;
List.iter (Store.check_e_uid store) args;
let ty_result = ty acc0 in
let ty_a0 = ty a0 in
(* check that all arguments have the same type *)
List.iter
(fun a' ->
let ty' = ty a' in
if not (equal ty_a0 ty') then
Error.errorf
"app_fold: arguments %a@ and %a@ have incompatible types" pp_debug
a0 pp_debug a')
other_args;
(* check that [f a0 acc0] has type [ty_result] *)
let app1 = make (E_app (make (E_app (f, a0)), acc0)) in
if not (equal (ty app1) ty_result) then
Error.errorf
"app_fold: single application `%a`@ has type `%a`,@ but should have \
type %a"
pp_debug app1 pp_debug (ty app1) pp_debug ty_result;
ty_result
| E_pi (_, ty, bod) ->
(* TODO: check the actual triplets for COC *)
(*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
Store.check_e_uid store ty;
Store.check_e_uid store bod;
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in
make (E_type u)
let ty_assert_false_ () = assert false
(* hashconsing + computing metadata + computing type (for new terms) *)
let rec make_ (store : store) view : term =
let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in
let e2 = Hcons.hashcons store.s_exprs e in
if e == e2 then (
(* new term, compute metadata *)
assert (store.s_uid land store_id_mask == store.s_uid);
(* first, compute type *)
(match e.view with
| E_type i ->
(* cannot force type now, as it's an infinite tower of types.
Instead we will produce the type on demand. *)
let get_ty () = make_ store (E_type (i + 1)) in
e.ty <- T_ty_delayed get_ty
| _ ->
let ty = compute_ty_ store ~make:(make_ store) view in
e.ty <- T_ty ty);
let has_fvars = compute_has_fvars_ e in
e2.flags <-
(compute_db_depth_ e lsl (1 + store_id_bits))
lor (if has_fvars then
1 lsl store_id_bits
else
0)
lor store.s_uid;
Store.check_e_uid store e2
);
e2
let type_of_univ store i : term = make_ store (E_type i)
let type_ store : term = type_of_univ store 0
let var store v : term = make_ store (E_var v)
let var_str store name ~ty : term = var store (Var.make name ty)
let bvar store v : term = make_ store (E_bound_var v)
let bvar_i store i ~ty : term = make_ store (E_bound_var (Bvar.make i ty))
let const store c : term = make_ store (E_const c)
let app store f a = make_ store (E_app (f, a))
let app_l store f l = List.fold_left (app store) f l
let app_fold store ~f ~acc0 args : t =
match args with
| [] -> acc0
| _ -> make_ store (E_app_fold { f; acc0; args })
type cache = t T_int_tbl.t
let create_cache : int -> cache = T_int_tbl.create
(* general substitution, compatible with DB indices. We use this
also to abstract on a free variable, because it subsumes it and
it's better to minimize the number of DB indices manipulations *)
let replace_ ?(cache = create_cache 8) ~make ~recursive e0 ~f : t =
let rec loop k e =
if is_type e then
e
else if not (has_fvars e) then
(* no free variables, cannot change *)
e
else (
try T_int_tbl.find cache (e, k)
with Not_found ->
let r = loop_uncached_ k e in
T_int_tbl.add cache (e, k) r;
r
)
and loop_uncached_ k (e : t) : t =
match f ~recurse:(loop k) e with
| None ->
map_shallow_ e ~make ~f:(fun inb u ->
loop
(if inb then
k + 1
else
k)
u)
| Some u ->
let u = db_shift_ ~make u k in
if recursive then
loop 0 u
else
u
in
loop 0 e0
let subst_ ~make ~recursive e0 (subst : subst) : t =
if Var_.Map.is_empty subst.m then
e0
else
replace_ ~make ~recursive e0 ~f:(fun ~recurse e ->
match view e with
| E_var v ->
(* first, subst in type *)
let v = { v with v_ty = recurse v.v_ty } in
Var_.Map.find_opt v subst.m
| _ -> None)
module DB = struct
let subst_db0 store e ~by : t = db_0_replace_ ~make:(make_ store) e ~by
let shift store t ~by : t =
assert (by >= 0);
db_shift_ ~make:(make_ store) t by
let lam_db ?(var_name = "") store ~var_ty bod : term =
make_ store (E_lam (var_name, var_ty, bod))
let pi_db ?(var_name = "") store ~var_ty bod : term =
make_ store (E_pi (var_name, var_ty, bod))
let abs_on (store : store) (v : var) (e : term) : term =
Store.check_e_uid store v.v_ty;
Store.check_e_uid store e;
if not (is_closed v.v_ty) then
Error.errorf "cannot abstract on variable@ with non closed type %a"
pp_debug v.v_ty;
let db0 = bvar store (Bvar.make 0 v.v_ty) in
let body = db_shift_ ~make:(make_ store) e 1 in
subst_ ~make:(make_ store) ~recursive:false body
{ m = Var_.Map.singleton v db0 }
end
let lam store v bod : term =
let bod' = DB.abs_on store v bod in
DB.lam_db ~var_name:(Var.name v) store ~var_ty:(Var.ty v) bod'
let pi store v bod : term =
let bod' = DB.abs_on store v bod in
DB.pi_db ~var_name:(Var.name v) store ~var_ty:(Var.ty v) bod'
let arrow store a b : term =
let b' = DB.shift store b ~by:1 in
DB.pi_db store ~var_ty:a b'
let arrow_l store args ret = List.fold_right (arrow store) args ret
(* find a name that doesn't capture a variable of [e] *)
let pick_name_ (name0 : string) (e : term) : string =
let rec loop i =
let name =
if i = 0 then
name0
else
Printf.sprintf "%s%d" name0 i
in
if free_vars_iter e |> Iter.exists (fun v -> v.v_name = name) then
loop (i + 1)
else
name
in
loop 0
let open_lambda store e : _ option =
match view e with
| E_lam (name, ty, bod) ->
let name = pick_name_ name bod in
let v = Var.make name ty in
let bod' = DB.subst_db0 store bod ~by:(var store v) in
Some (v, bod')
| _ -> None
let open_lambda_exn store e =
match open_lambda store e with
| Some tup -> tup
| None -> Error.errorf "open-lambda: term is not a lambda:@ %a" pp_debug e
end
include Make_
let map_shallow store ~f e : t = map_shallow_ ~make:(make_ store) ~f e
(* re-export some internal things *)
module Internal_ = struct
type nonrec cache = cache
let create_cache = create_cache
let replace_ ?cache store ~recursive t ~f =
replace_ ?cache ~make:(make_ store) ~recursive t ~f
let subst_ store ~recursive t subst =
subst_ ~make:(make_ store) ~recursive t subst
end

185
src/core-logic/term.mli Normal file
View file

@ -0,0 +1,185 @@
(** Core logic terms.
The core terms are expressions in the calculus of constructions,
with no universe polymorphism nor cumulativity. It should be fast, with hashconsing;
and simple enough (no inductives, no universe trickery).
It is intended to be the foundation for user-level terms and types and formulas.
*)
open Types_
type nonrec var = var
type nonrec bvar = bvar
type nonrec term = term
type t = term
(** A term, in the calculus of constructions *)
type store
(** The store for terms.
The store is responsible for allocating unique IDs to terms, and
enforcing their hashconsing (so that syntactic equality is just a pointer
comparison). *)
(** View.
A view is the shape of the root node of a term. *)
type view = term_view =
| E_type of int
| E_var of var
| E_bound_var of bvar
| E_const of const
| E_app of t * t
| E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * t * t
| E_pi of string * t * t
include EQ_ORD_HASH with type t := t
val pp_debug : t Fmt.printer
val pp_debug_with_ids : t Fmt.printer
(** {2 Containers} *)
include WITH_SET_MAP_TBL with type t := t
(** {2 Utils} *)
val view : t -> view
val unfold_app : t -> t * t list
val is_app : t -> bool
val is_const : t -> bool
val iter_dag : ?seen:unit Tbl.t -> iter_ty:bool -> f:(t -> unit) -> t -> unit
(** [iter_dag t ~f] calls [f] once on each subterm of [t], [t] included.
It must {b not} traverse [t] as a tree, but rather as a
perfectly shared DAG.
For example, in:
{[
let x = 2 in
let y = f x x in
let z = g y x in
z = z
]}
the DAG has the following nodes:
{[ n1: 2
n2: f n1 n1
n3: g n2 n1
n4: = n3 n3
]}
*)
val iter_shallow : f:(bool -> t -> unit) -> t -> unit
(** [iter_shallow f e] iterates on immediate subterms of [e],
calling [f trdb e'] for each subterm [e'], with [trdb = true] iff
[e'] is directly under a binder. *)
val map_shallow : store -> f:(bool -> t -> t) -> t -> t
val exists_shallow : f:(bool -> t -> bool) -> t -> bool
val for_all_shallow : f:(bool -> t -> bool) -> t -> bool
val contains : t -> sub:t -> bool
val free_vars_iter : t -> var Iter.t
val free_vars : ?init:Var.Set.t -> t -> Var.Set.t
val is_type : t -> bool
(** [is_type t] is true iff [view t] is [Type _] *)
val is_a_type : t -> bool
(** [is_a_type t] is true if [is_ty (ty t)] *)
val is_closed : t -> bool
(** Is the term closed (all bound variables are paired with a binder)?
time: O(1) *)
val has_fvars : t -> bool
(** Does the term contain free variables?
time: O(1) *)
val ty : t -> t
(** Return the type of this term. *)
(** {2 Creation} *)
module Store : sig
type t = store
val create : ?size:int -> unit -> t
val size : t -> int
end
val type_ : store -> t
val type_of_univ : store -> int -> t
val var : store -> var -> t
val var_str : store -> string -> ty:t -> t
val bvar : store -> bvar -> t
val bvar_i : store -> int -> ty:t -> t
val const : store -> const -> t
val app : store -> t -> t -> t
val app_l : store -> t -> t list -> t
val app_fold : store -> f:t -> acc0:t -> t list -> t
val lam : store -> var -> t -> t
val pi : store -> var -> t -> t
val arrow : store -> t -> t -> t
val arrow_l : store -> t list -> t -> t
val open_lambda : store -> t -> (var * t) option
val open_lambda_exn : store -> t -> var * t
(** De bruijn indices *)
module DB : sig
val lam_db : ?var_name:string -> store -> var_ty:t -> t -> t
(** [lam_db store ~var_ty bod] is [\ _:var_ty. bod]. Not DB shifting is done. *)
val pi_db : ?var_name:string -> store -> var_ty:t -> t -> t
(** [pi_db store ~var_ty bod] is [pi _:var_ty. bod]. Not DB shifting is done. *)
val subst_db0 : store -> t -> by:t -> t
(** [subst_db0 store t ~by] replaces bound variable 0 in [t] with
the term [by]. This is useful, for example, to implement beta-reduction.
For example, with [t] being [_[0] = (\x. _[2] _[1] x[0])],
[subst_db0 store t ~by:"hello"] is ["hello" = (\x. _[2] "hello" x[0])].
*)
val shift : store -> t -> by:int -> t
(** [shift store t ~by] shifts all bound variables in [t] that are not
closed on, by amount [by] (which must be >= 0).
For example, with term [t] being [\x. _[1] _[2] x[0]],
[shift store t ~by:5] is [\x. _[6] _[7] x[0]]. *)
val abs_on : store -> var -> t -> t
(** [abs_on store v t] is the term [t[v := _[0]]]. It replaces [v] with
the bound variable with the same type as [v], and the DB index 0,
and takes care of shifting if [v] occurs under binders.
For example, [abs_on store x (\y. x+y)] is [\y. _[1] y].
*)
end
(**/**)
module Internal_ : sig
type cache
val create_cache : int -> cache
val subst_ : store -> recursive:bool -> t -> subst -> t
val replace_ :
?cache:cache ->
store ->
recursive:bool ->
t ->
f:(recurse:(t -> t) -> t -> t option) ->
t
end
(**/**)

74
src/core-logic/types_.ml Normal file
View file

@ -0,0 +1,74 @@
module H = CCHash
type const_view = ..
module type DYN_CONST_OPS = sig
val pp : const_view Fmt.printer
val equal : const_view -> const_view -> bool
val hash : const_view -> int
end
type const_ops = (module DYN_CONST_OPS)
type term_view =
| E_type of int
| E_var of var
| E_bound_var of bvar
| E_const of const
| E_app of term * term
| E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * term * term
| E_pi of string * term * term
and var = { v_name: string; v_ty: term }
and bvar = { bv_idx: int; bv_ty: term }
and const = { c_view: const_view; c_ops: const_ops; c_ty: term }
and term = {
view: term_view;
(* computed on demand *)
mutable ty: term_ty_;
mutable id: int;
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
mutable flags: int;
}
and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term)
module Term_ = struct
let[@inline] equal (e1 : term) e2 : bool = e1 == e2
let[@inline] hash (e : term) = H.int e.id
let[@inline] compare (e1 : term) e2 : int = CCInt.compare e1.id e2.id
let pp_debug_ : term Fmt.printer ref = ref (fun _ _ -> assert false)
end
module Var_ = struct
let[@inline] equal v1 v2 =
v1.v_name = v2.v_name && Term_.equal v1.v_ty v2.v_ty
let[@inline] hash v1 = H.combine3 5 (H.string v1.v_name) (Term_.hash v1.v_ty)
let compare a b : int =
if Term_.equal a.v_ty b.v_ty then
String.compare a.v_name b.v_name
else
compare a.v_ty b.v_ty
module AsKey = struct
type nonrec t = var
let equal = equal
let compare = compare
let hash = hash
end
module Map = CCMap.Make (AsKey)
module Set = CCSet.Make (AsKey)
module Tbl = CCHashtbl.Make (AsKey)
end
type subst = { m: term Var_.Map.t } [@@unboxed]

14
src/core-logic/var.ml Normal file
View file

@ -0,0 +1,14 @@
open Types_
type t = var = { v_name: string; v_ty: term }
include Var_
let[@inline] name v = v.v_name
let[@inline] ty self = self.v_ty
let[@inline] pp out v1 = Fmt.string out v1.v_name
let make v_name v_ty : t = { v_name; v_ty }
let makef fmt ty = Fmt.kasprintf (fun s -> make s ty) fmt
let pp_with_ty out v =
Fmt.fprintf out "(@[%s :@ %a@])" v.v_name !Term_.pp_debug_ v.v_ty

Some files were not shown because too many files have changed in this diff Show more