perf(checker): optimize watch literals

This commit is contained in:
Simon Cruanes 2021-08-08 02:36:30 -04:00
parent ab6e574298
commit b4231d23c1
5 changed files with 65 additions and 41 deletions

View file

@ -3,6 +3,6 @@
(name main) (name main)
(public_name sidekick-checker) (public_name sidekick-checker)
(package sidekick-bin) (package sidekick-bin)
(libraries containers sidekick-bin.lib (libraries containers sidekick-bin.lib mtime mtime.clock.os
sidekick.util sidekick.tef sidekick.drup) sidekick.util sidekick.tef sidekick.drup)
(flags :standard -warn-error -a+8 -open Sidekick_util)) (flags :standard -warn-error -a+8 -open Sidekick_util))

View file

@ -68,14 +68,18 @@ let () =
Arg.parse opts (fun f -> files := f :: !files) "checker [opt]* [file]+"; Arg.parse opts (fun f -> files := f :: !files) "checker [opt]* [file]+";
begin match List.rev !files with let ok =
match List.rev !files with
| [pb; proof] -> | [pb; proof] ->
Log.debugf 1 (fun k->k"checker: problem `%s`, proof `%s`" pb proof); Log.debugf 1 (fun k->k"checker: problem `%s`, proof `%s`" pb proof);
let ok = check ~pb proof in check ~pb proof
if not ok then exit 1
| [proof] -> | [proof] ->
Log.debugf 1 (fun k->k"checker: proof `%s`" proof); Log.debugf 1 (fun k->k"checker: proof `%s`" proof);
let ok = check ?pb:None proof in check ?pb:None proof
| _ -> Error.errorf "expected <problem>? <proof>"
in
let t2 = Mtime_clock.elapsed () |> Mtime.Span.to_s in
Format.printf "c %s@." (if ok then "OK" else "FAIL");
Format.printf "c elapsed time: %.3fs@." t2;
if not ok then exit 1 if not ok then exit 1
| _ -> failwith "expected <problem>? <proof>"
end

View file

@ -13,6 +13,7 @@ module Atom : sig
val sign : t -> bool val sign : t -> bool
val pp : t Fmt.printer val pp : t Fmt.printer
val dummy : t
val of_int_unsafe : int -> t val of_int_unsafe : int -> t
module Map : CCMap.S with type key = t module Map : CCMap.S with type key = t
end = struct end = struct
@ -29,6 +30,7 @@ end = struct
let pp out x = let pp out x =
Fmt.fprintf out "%s%d" (if sign x then "+" else "-") (x lsr 1) Fmt.fprintf out "%s%d" (if sign x then "+" else "-") (x lsr 1)
let of_int_unsafe i = i let of_int_unsafe i = i
let dummy = 0
module Map = Util.Int_map module Map = Util.Int_map
end end
type atom = Atom.t type atom = Atom.t
@ -41,7 +43,7 @@ module Clause : sig
val size : t -> int val size : t -> int
val get : t -> int -> atom val get : t -> int -> atom
val iter : f:(atom -> unit) -> t -> unit val iter : f:(atom -> unit) -> t -> unit
val watches: t -> (atom * atom) option val watches: t -> atom * atom
val set_watches : t -> atom * atom -> unit val set_watches : t -> atom * atom -> unit
val pp : t Fmt.printer val pp : t Fmt.printer
val of_list : store -> atom list -> t val of_list : store -> atom list -> t
@ -56,7 +58,7 @@ end = struct
type t = { type t = {
id: int; id: int;
atoms: atom array; atoms: atom array;
mutable watches: (atom * atom) option; mutable watches: atom * atom;
} }
type store = { type store = {
mutable n: int; mutable n: int;
@ -66,12 +68,12 @@ end = struct
let size self = Array.length self.atoms let size self = Array.length self.atoms
let get self i = Array.get self.atoms i let get self i = Array.get self.atoms i
let watches self = self.watches let watches self = self.watches
let set_watches self w = self.watches <- Some w let set_watches self w = self.watches <- w
let iter ~f self = Array.iter f self.atoms let iter ~f self = Array.iter f self.atoms
let pp out (self:t) = let pp out (self:t) =
let pp_watches out = function let pp_watches out = function
| None -> () | (p,q) when p=Atom.dummy || q=Atom.dummy -> ()
| Some (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in | (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in
Fmt.fprintf out "(@[cl[%d]@ %a%a])" Fmt.fprintf out "(@[cl[%d]@ %a%a])"
self.id (Fmt.Dump.array Atom.pp) self.atoms pp_watches self.watches self.id (Fmt.Dump.array Atom.pp) self.atoms pp_watches self.watches
let of_list self atoms : t = let of_list self atoms : t =
@ -79,7 +81,7 @@ end = struct
let atoms = List.sort_uniq Atom.compare atoms |> Array.of_list in let atoms = List.sort_uniq Atom.compare atoms |> Array.of_list in
let id = self.n in let id = self.n in
self.n <- 1 + self.n; self.n <- 1 + self.n;
let c = {atoms; id; watches=None} in let c = {atoms; id; watches=Atom.dummy, Atom.dummy} in
c c
module As_key = struct module As_key = struct
type nonrec t=t type nonrec t=t
@ -161,7 +163,6 @@ end
Each event is checked by reverse-unit propagation on previous events. *) Each event is checked by reverse-unit propagation on previous events. *)
module Fwd_check : sig module Fwd_check : sig
type error = type error =
[ `Bad_steps of VecI32.t [ `Bad_steps of VecI32.t
| `No_empty_clause | `No_empty_clause
@ -173,7 +174,6 @@ module Fwd_check : sig
success. In case of error it returns [Error idxs] where [idxs] are the success. In case of error it returns [Error idxs] where [idxs] are the
indexes in the trace of the steps that failed. *) indexes in the trace of the steps that failed. *)
val check : Trace.t -> (unit, error) result val check : Trace.t -> (unit, error) result
end = struct end = struct
module ISet = CCSet.Make(CCInt) module ISet = CCSet.Make(CCInt)
@ -183,7 +183,7 @@ end = struct
trail: VecI32.t; (* current assignment *) trail: VecI32.t; (* current assignment *)
mutable trail_ptr : int; (* offset in trail for propagation *) mutable trail_ptr : int; (* offset in trail for propagation *)
active_clauses: unit Clause.Tbl.t; active_clauses: unit Clause.Tbl.t;
watches: Clause.Set.t Vec.t; (* atom -> clauses it watches *) watches: Clause.t Vec.t Vec.t; (* atom -> clauses it watches *)
errors: VecI32.t; errors: VecI32.t;
} }
@ -202,7 +202,7 @@ end = struct
Bitvec.ensure_size self.assign (a:atom:>int); Bitvec.ensure_size self.assign (a:atom:>int);
(* size: 2+atom, because: 1+atom makes atom valid, and if it's positive, (* size: 2+atom, because: 1+atom makes atom valid, and if it's positive,
2+atom is (¬atom)+1 *) 2+atom is (¬atom)+1 *)
Vec.ensure_size self.watches Clause.Set.empty (2+(a:atom:>int)); Vec.ensure_size_with self.watches Vec.create (2+(a:atom:>int));
() ()
let[@inline] is_true self (a:atom) : bool = let[@inline] is_true self (a:atom) : bool =
@ -215,12 +215,11 @@ end = struct
not (is_true self a) && not (is_false self a) not (is_true self a) && not (is_false self a)
let add_watch_ self (a:atom) (c:clause) = let add_watch_ self (a:atom) (c:clause) =
let set = Vec.get self.watches (a:atom:>int) in Vec.push (Vec.get self.watches (a:atom:>int)) c
Vec.set self.watches (a:atom:>int) (Clause.Set.add c set)
let remove_watch_ self (a:atom) (c:clause) = let remove_watch_ self (a:atom) idx =
let set = Vec.get self.watches (a:atom:>int) in let v = Vec.get self.watches (a:atom:>int) in
Vec.set self.watches (a:atom:>int) (Clause.Set.remove c set) Vec.fast_remove v idx
exception Conflict exception Conflict
@ -251,19 +250,19 @@ end = struct
try Clause.iter c ~f:(fun a -> if not (is_false self a) then raise Exit); true try Clause.iter c ~f:(fun a -> if not (is_false self a) then raise Exit); true
with Exit -> false with Exit -> false
type propagation_res =
| Keep
| Remove
(* do boolean propagation in [c], which is watched by the true literal [a] *) (* do boolean propagation in [c], which is watched by the true literal [a] *)
let propagate_in_clause_ (self:t) (a:atom) (c:clause) : unit = let propagate_in_clause_ (self:t) (a:atom) (c:clause) : propagation_res =
assert (is_true self a); assert (is_true self a);
let a1, a2 = let a1, a2 = Clause.watches c in
match Clause.watches c with
| None -> assert false
| Some tup -> tup
in
let na = Atom.neg a in let na = Atom.neg a in
(* [q] is the other literal in [c] such that [¬q] watches [c]. *) (* [q] is the other literal in [c] such that [¬q] watches [c]. *)
let q = if Atom.equal a1 na then a2 else (assert(a2==na); a1) in let q = if Atom.equal a1 na then a2 else (assert(a2==na); a1) in
try try
if is_true self q then () (* clause is satisfied *) if is_true self q then Keep (* clause is satisfied *)
else ( else (
let n_unassigned = ref 0 in let n_unassigned = ref 0 in
let unassigned_a = ref a in (* an unassigned atom, if [!n_unassigned > 0] *) let unassigned_a = ref a in (* an unassigned atom, if [!n_unassigned > 0] *)
@ -294,21 +293,28 @@ end = struct
let p = !unassigned_a in let p = !unassigned_a in
Log.debugf 30 (fun k->k"(@[propagate@ :atom %a@ :reason %a@])" Atom.pp p Clause.pp c); Log.debugf 30 (fun k->k"(@[propagate@ :atom %a@ :reason %a@])" Atom.pp p Clause.pp c);
set_atom_true self p; set_atom_true self p;
Keep
) else ( ) else (
(* at least 2 unassigned, just update the watch literal to [¬p] *) (* at least 2 unassigned, just update the watch literal to [¬p] *)
let p = !unassigned_a in let p = !unassigned_a in
assert (p <> q); assert (p <> q);
Clause.set_watches c (q, p); Clause.set_watches c (q, p);
remove_watch_ self a c;
add_watch_ self (Atom.neg p) c; add_watch_ self (Atom.neg p) c;
Remove
); );
) )
with with
| Is_sat -> () | Is_sat -> Keep
let propagate_atom_ self (a:atom) : unit = let propagate_atom_ self (a:atom) : unit =
let set = Vec.get self.watches (a:atom:>int) in let v = Vec.get self.watches (a:atom:>int) in
Clause.Set.iter (propagate_in_clause_ self a) set let i = ref 0 in
while !i < Vec.size v do
match propagate_in_clause_ self a (Vec.get v !i) with
| Keep -> incr i;
| Remove ->
remove_watch_ self a !i
done
(* perform boolean propagation in a fixpoint (* perform boolean propagation in a fixpoint
@raise Conflict if a clause is false *) @raise Conflict if a clause is false *)
@ -360,15 +366,21 @@ end = struct
let c0 = Clause.get c 0 in let c0 = Clause.get c 0 in
let c1 = Clause.get c 1 in let c1 = Clause.get c 1 in
assert (c0 <> c1); assert (c0 <> c1);
add_watch_ self (Atom.neg c0) c;
add_watch_ self (Atom.neg c1) c;
Clause.set_watches c (c0,c1); Clause.set_watches c (c0,c1);
(* make sure watches are valid *) (* make sure watches are valid *)
if is_false self c0 then ( if is_false self c0 then (
propagate_in_clause_ self (Atom.neg c0) c; match propagate_in_clause_ self (Atom.neg c0) c with
| Keep -> add_watch_ self (Atom.neg c0) c;
| Remove -> ()
) else (
add_watch_ self (Atom.neg c0) c
); );
if is_false self c1 then ( if is_false self c1 then (
propagate_in_clause_ self (Atom.neg c1) c; match propagate_in_clause_ self (Atom.neg c1) c with
| Keep -> add_watch_ self (Atom.neg c1) c;
| Remove -> ()
) else (
add_watch_ self (Atom.neg c1) c
) )
end; end;
() ()

View file

@ -48,15 +48,20 @@ let resize_ t x size =
let ensure_cap_ self x n = let ensure_cap_ self x n =
if n > Array.length self.data then ( if n > Array.length self.data then (
let new_size = max n (2 * Array.length self.data) in let new_size = max n (2 * Array.length self.data) in
resize_ self x new_size resize_ self (x()) new_size
) )
let ensure_size self x n = let ensure_size_with self f n =
ensure_cap_ self x n; ensure_cap_ self f n;
if n > self.sz then ( if n > self.sz then (
for i=self.sz to n-1 do
self.data.(i) <- f();
done;
self.sz <- n self.sz <- n
) )
let ensure_size self x n = ensure_size_with self (fun() -> x) n
(* grow the array *) (* grow the array *)
let[@inline never] grow_to_double_size t x : unit = let[@inline never] grow_to_double_size t x : unit =
let new_size = let new_size =

View file

@ -31,6 +31,9 @@ val clear : 'a t -> unit
val ensure_size : 'a t -> 'a -> int -> unit val ensure_size : 'a t -> 'a -> int -> unit
(** ensure size is at least [n] *) (** ensure size is at least [n] *)
val ensure_size_with : 'a t -> (unit -> 'a) -> int -> unit
(** ensure size is at least [n] *)
val shrink : 'a t -> int -> unit val shrink : 'a t -> int -> unit
(** [shrink vec sz] resets size of [vec] to [sz]. (** [shrink vec sz] resets size of [vec] to [sz].
Assumes [sz >=0 && sz <= size vec] *) Assumes [sz >=0 && sz <= size vec] *)