wip: reimplement a fourier motzkin module, from scratch

This commit is contained in:
Simon Cruanes 2020-10-10 00:00:20 -04:00
parent c67e44e654
commit 9783c3ae1b
4 changed files with 269 additions and 85 deletions

View file

@ -4,7 +4,7 @@ module Fmt = CCFormat
module CC_view = Sidekick_core.CC_view
type lra_pred = Sidekick_lra.pred = Lt | Leq | Eq | Neq | Geq | Gt
type lra_pred = Sidekick_lra.FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq
type lra_op = Sidekick_lra.op = Plus | Minus
type 'a lra_view = 'a Sidekick_lra.lra_view =
@ -894,6 +894,8 @@ end = struct
| Not u -> u, false
| App_fun ({fun_view=Fun_def def; _}, args) ->
def.abs ~self:t args (* TODO: pass state *)
| LRA (LRA_pred (Neq, a, b)) ->
lra tst (LRA_pred (Eq,a,b)), false (* != is just not eq *)
| _ -> t, true
let[@inline] is_true t = match view t with Bool true -> true | _ -> false

View file

@ -6,7 +6,9 @@
open Sidekick_core
type pred = Lt | Leq | Eq | Neq | Geq | Gt
module FM = Fourier_motzkin
type pred = FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq
type op = Plus | Minus
type 'a lra_view =
@ -25,24 +27,6 @@ let map_view f (l:_ lra_view) : _ lra_view =
| LRA_other x -> LRA_other (f x)
end
(* TODO: upstream *)
let neg_pred = function
| Leq -> Gt
| Lt -> Geq
| Eq -> Neq
| Neq -> Eq
| Geq -> Lt
| Gt -> Leq
let pred_to_funarith = function
| Leq -> `Leq
| Lt -> `Lt
| Geq -> `Geq
| Gt -> `Gt
| Eq -> `Eq
| Neq -> `Neq
module type ARG = sig
module S : Sidekick_core.SOLVER
@ -81,47 +65,30 @@ module Make(A : ARG) : S with module A = A = struct
module Lit = A.S.Solver_internal.Lit
module SI = A.S.Solver_internal
type simp_var =
| V_fresh of int
| V_t of T.t
(* the fourier motzkin module *)
module FM_A = FM.Make(struct
module T = T
type tag = Lit.t
end)
(** Simplex variables *)
module Simp_vars = struct
type t = simp_var
let compare a b =
match a, b with
| V_fresh i, V_fresh j -> CCInt.compare i j
| V_fresh _, V_t _ -> -1
| V_t _, V_fresh _ -> 1
| V_t t1, V_t t2 -> T.compare t1 t2
let pp out = function
| V_fresh i -> Fmt.fprintf out "$fresh_%d" i
| V_t t -> T.pp out t
module Fresh = struct
type t = int ref
let create() : t = ref 0
let fresh n = V_fresh (CCRef.get_then_incr n)
end
end
module Simplex = Funarith_zarith.Simplex.Make_full(Simp_vars)
module LE = Simplex.L.Expr
module LComb = Simplex.L.Comb
module Constr = Simplex.L.Constr
(* linear expressions *)
module LE = FM_A.LE
type state = {
tst: T.state;
simps: T.t T.Tbl.t; (* cache *)
simplex: Simplex.t;
gensym: A.Gensym.t;
neq_encoded: unit T.Tbl.t;
(* if [a != b] asserted and not in this table, add clause [a = b \/ a<b \/ a>b] *)
mutable t_defs: (T.t * LE.t) list; (* term definitions *)
pred_defs: (pred * LE.t * LE.t) T.Tbl.t; (* predicate definitions *)
}
let create tst : state =
{ tst; simps=T.Tbl.create 128;
{ tst;
simps=T.Tbl.create 128;
gensym=A.Gensym.create tst;
simplex=Simplex.create();
neq_encoded=T.Tbl.create 16;
t_defs=[];
pred_defs=T.Tbl.create 16;
}
@ -179,7 +146,7 @@ module Make(A : ARG) : S with module A = A = struct
let rec as_linexp (t:T.t) : LE.t =
let open LE.Infix in
match A.view_as_lra t with
| LRA_other _ -> LE.of_list Q.zero [Q.one, V_t t]
| LRA_other _ -> LE.var t
| LRA_pred _ ->
Error.errorf "type error: in linexp, LRA predicate %a" T.pp t
| LRA_op (op, t1, t2) ->
@ -192,7 +159,7 @@ module Make(A : ARG) : S with module A = A = struct
| LRA_mult (n, x) ->
let t = as_linexp x in
LE.( n * t )
| LRA_const q -> LE.of_const q
| LRA_const q -> LE.const q
(* TODO: keep the linexps until they're asserted;
TODO: but use simplification in preprocess
@ -220,18 +187,43 @@ module Make(A : ARG) : S with module A = A = struct
Some proxy
| LRA_const _ | LRA_other _ -> None
let final_check_ (self:state) _si (_acts:SI.actions) (trail:_ Iter.t) : unit =
(* partial check: just ensure [a != b] triggers the clause
[a=b \/ a<b \/ a>b] *)
let partial_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit =
let tst = self.tst in
begin
trail
|> Iter.filter (fun lit -> not (Lit.sign lit))
|> Iter.filter_map
(fun lit ->
let t = Lit.term lit in
match A.view_as_lra t with
| LRA_pred (Eq, a, b) when not (T.Tbl.mem self.neq_encoded t) ->
Some (lit, a,b)
| _ -> None)
|> Iter.iter
(fun (lit,a,b) ->
let c = [
Lit.abs lit;
SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, a, b)));
SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, b, a)));
] in
SI.add_clause_permanent si acts c;
T.Tbl.add self.neq_encoded (Lit.term lit) ();
)
end
let final_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit =
Log.debug 5 "(th-lra.final-check)";
let simplex = Simplex.create() in
let fm = FM_A.create() in
(* first, add definitions *)
begin
List.iter
(fun (t,le) ->
let open LE.Infix in
let c =
Constr.of_expr (le - LE.of_comb (LComb.monomial1 (V_t t))) `Eq
in
Simplex.add_constr simplex c)
let le = le - LE.var t in
let c = FM_A.Constr.mk ?tag:None Eq (LE.var t) le in
FM_A.assert_c fm c)
self.t_defs
end;
(* add trail *)
@ -245,35 +237,24 @@ module Make(A : ARG) : S with module A = A = struct
| exception Not_found -> ()
| (pred, a, b) ->
let open LE.Infix in
let e = a - b in
let pred = if sign then pred else neg_pred pred in
let pred = match pred_to_funarith pred with
| `Neq -> Sidekick_util.Error.errorf "cannot handle negative LEQ equality"
| (`Eq | `Geq | `Gt | `Leq | `Lt) as p -> p
in
let c = Constr.of_expr e pred in
Simplex.add_constr simplex c;
let pred = if sign then pred else FM.Pred.neg pred in
let c = FM_A.Constr.mk ~tag:lit pred a b in
FM_A.assert_c fm c;
end)
end;
Log.debug 5 "lra: call simplex";
begin match Simplex.solve simplex with
| Simplex.Solution _ ->
Log.debug 5 "lra: simplex returns SAT";
() (* TODO: model combination *)
| Simplex.Unsatisfiable cert ->
Log.debugf 5 (fun k->k"lra: simplex returns UNSAT@ with cert %a" Simplex.pp_cert cert);
(* find what terms are involved *)
let asserts =
cert.Simplex.cert_expr
|> Iter.of_list
|> Iter.filter_map
(function
| V_t -> Some t
| V_fresh _ -> None)
|> Iter.to_rev_list
in
Simplex.cert
() (* TODO: produce conflict *)
Log.debug 5 "lra: call arith solver";
begin match FM_A.solve fm with
| FM_A.Sat ->
Log.debug 5 "lra: solver returns SAT";
() (* TODO: get a model + model combination *)
| FM_A.Unsat lits ->
(* we tagged assertions with their lit, so the certificate being an
unsat core translates directly into a conflict clause *)
Log.debugf 5 (fun k->k"lra: solver returns UNSAT@ with cert %a"
(Fmt.Dump.list Lit.pp) lits);
let confl = List.rev_map Lit.neg lits in
(* TODO: produce and store a proper LRA resolution proof *)
SI.raise_conflict si acts confl SI.P.default
end;
()
@ -282,6 +263,7 @@ module Make(A : ARG) : S with module A = A = struct
let st = create (SI.tst si) in
(* TODO SI.add_simplifier si (simplify st); *)
SI.add_preprocess si (preproc_lra st);
SI.on_partial_check si (partial_check_ st);
SI.on_final_check si (final_check_ st);
(* SI.add_preprocess si (cnf st); *)
(* TODO: theory combination *)

View file

@ -4,4 +4,4 @@
(public_name sidekick.th-lra)
(optional) ; only if deps present
(flags :standard -warn-error -a+8 -open Sidekick_util)
(libraries containers sidekick.core zarith funarith.zarith funarith))
(libraries containers sidekick.core zarith))

View file

@ -0,0 +1,200 @@
module type ARG = sig
(** terms *)
module T : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val compare : t -> t -> int
val pp : t Fmt.printer
end
type tag
end
module Pred : sig
type t = Lt | Leq | Geq | Gt | Neq | Eq
val neg : t -> t
val pp : t Fmt.printer
val to_string : t -> string
end = struct
type t = Lt | Leq | Geq | Gt | Neq | Eq
let to_string = function
| Lt -> "<"
| Leq -> "<="
| Eq -> "="
| Neq -> "!="
| Gt -> ">"
| Geq -> ">="
let neg = function
| Leq -> Gt
| Lt -> Geq
| Eq -> Neq
| Neq -> Eq
| Geq -> Lt
| Gt -> Leq
let pp out p = Fmt.string out (to_string p)
end
module type S = sig
module A : ARG
type t
type expr = A.T.t
module LE : sig
type t
val const : Q.t -> t
val var : expr -> t
module Infix : sig
val (+) : t -> t -> t
val (-) : t -> t -> t
val ( * ) : Q.t -> t -> t
end
include module type of Infix
val pp : t Fmt.printer
end
(** {3 Arithmetic constraint} *)
module Constr : sig
type t = {
pred: Pred.t;
le: LE.t;
tag: A.tag option;
}
val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t
val pp : t Fmt.printer
end
val create : unit -> t
val assert_c : t -> Constr.t -> unit
type res =
| Sat
| Unsat of A.tag list
val solve : t -> res
end
module Make(A : ARG)
: S with module A = A
= struct
module A = A
module T = A.T
module T_set = CCSet.Make(A.T)
module T_map = CCMap.Make(A.T)
type expr = A.T.t
module LE = struct
module M = T_map
type t = {
le: Q.t M.t;
const: Q.t;
}
let const x : t = {const=x; le=M.empty}
let var x : t = {const=Q.zero; le=M.singleton x Q.one}
let (+) a b : t =
{const = Q.(a.const + b.const);
le=M.merge_safe a.le b.le
~f:(fun _ -> function
| `Left x | `Right x -> Some x
| `Both (x,y) ->
let z = Q.(x + y) in
if Q.sign z = 0 then None else Some z)
}
let (-) a b : t =
{const = Q.(a.const - b.const);
le=M.merge_safe a.le b.le
~f:(fun _ -> function
| `Left x -> Some x
| `Right x -> Some (Q.neg x)
| `Both (x,y) ->
let z = Q.(x - y) in
if Q.sign z = 0 then None else Some z)
}
let ( * ) x a : t =
if Q.sign x = 0 then const Q.zero
else (
{const=Q.( a.const * x );
le=M.map (fun y -> Q.(x * y)) a.le
}
)
module Infix = struct
let (+) = (+)
let (-) = (-)
let ( * ) = ( * )
end
let vars self = T_map.keys self.le
let pp out (self:t) : unit =
let pp_pair out (e,q) =
if Q.equal Q.one q then T.pp out e
else Fmt.fprintf out "%a * %a" Q.pp_print q T.pp e
in
Fmt.fprintf out "(@[%a@ + %a@])"
Q.pp_print self.const (Util.pp_iter ~sep:" + " pp_pair) (M.to_iter self.le)
end
module Constr = struct
type t = {
pred: Pred.t;
le: LE.t;
tag: A.tag option;
}
let pp out (c:t) : unit =
Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])"
LE.pp c.le (Pred.to_string c.pred)
let mk ?tag pred l1 l2 : t =
{pred; tag; le=LE.(l1 - l2); }
end
type t = {
mutable cs: Constr.t list;
mutable all_vars: T_set.t;
}
let create () : t = {
cs=[];
all_vars=T_set.empty;
}
let assert_c (self:t) c : unit =
self.cs <- c :: self.cs;
self.all_vars <- c.Constr.le |> LE.vars |> T_set.add_iter self.all_vars;
()
(* TODO: be able to provide a model for SAT *)
type res =
| Sat
| Unsat of A.tag list
let solve (self:t) : res =
Log.debugf 5
(fun k->k"(@[FM.solve@ %a@])" (Util.pp_list Constr.pp) self.cs);
assert false
end