perf(checker): optimize watch literals

This commit is contained in:
Simon Cruanes 2021-08-08 02:36:30 -04:00
parent ab6e574298
commit b4231d23c1
5 changed files with 65 additions and 41 deletions

View file

@ -3,6 +3,6 @@
(name main)
(public_name sidekick-checker)
(package sidekick-bin)
(libraries containers sidekick-bin.lib
(libraries containers sidekick-bin.lib mtime mtime.clock.os
sidekick.util sidekick.tef sidekick.drup)
(flags :standard -warn-error -a+8 -open Sidekick_util))

View file

@ -68,14 +68,18 @@ let () =
Arg.parse opts (fun f -> files := f :: !files) "checker [opt]* [file]+";
begin match List.rev !files with
let ok =
match List.rev !files with
| [pb; proof] ->
Log.debugf 1 (fun k->k"checker: problem `%s`, proof `%s`" pb proof);
let ok = check ~pb proof in
if not ok then exit 1
check ~pb proof
| [proof] ->
Log.debugf 1 (fun k->k"checker: proof `%s`" proof);
let ok = check ?pb:None proof in
if not ok then exit 1
| _ -> failwith "expected <problem>? <proof>"
end
check ?pb:None proof
| _ -> Error.errorf "expected <problem>? <proof>"
in
let t2 = Mtime_clock.elapsed () |> Mtime.Span.to_s in
Format.printf "c %s@." (if ok then "OK" else "FAIL");
Format.printf "c elapsed time: %.3fs@." t2;
if not ok then exit 1

View file

@ -13,6 +13,7 @@ module Atom : sig
val sign : t -> bool
val pp : t Fmt.printer
val dummy : t
val of_int_unsafe : int -> t
module Map : CCMap.S with type key = t
end = struct
@ -29,6 +30,7 @@ end = struct
let pp out x =
Fmt.fprintf out "%s%d" (if sign x then "+" else "-") (x lsr 1)
let of_int_unsafe i = i
let dummy = 0
module Map = Util.Int_map
end
type atom = Atom.t
@ -41,7 +43,7 @@ module Clause : sig
val size : t -> int
val get : t -> int -> atom
val iter : f:(atom -> unit) -> t -> unit
val watches: t -> (atom * atom) option
val watches: t -> atom * atom
val set_watches : t -> atom * atom -> unit
val pp : t Fmt.printer
val of_list : store -> atom list -> t
@ -56,7 +58,7 @@ end = struct
type t = {
id: int;
atoms: atom array;
mutable watches: (atom * atom) option;
mutable watches: atom * atom;
}
type store = {
mutable n: int;
@ -66,12 +68,12 @@ end = struct
let size self = Array.length self.atoms
let get self i = Array.get self.atoms i
let watches self = self.watches
let set_watches self w = self.watches <- Some w
let set_watches self w = self.watches <- w
let iter ~f self = Array.iter f self.atoms
let pp out (self:t) =
let pp_watches out = function
| None -> ()
| Some (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in
| (p,q) when p=Atom.dummy || q=Atom.dummy -> ()
| (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in
Fmt.fprintf out "(@[cl[%d]@ %a%a])"
self.id (Fmt.Dump.array Atom.pp) self.atoms pp_watches self.watches
let of_list self atoms : t =
@ -79,7 +81,7 @@ end = struct
let atoms = List.sort_uniq Atom.compare atoms |> Array.of_list in
let id = self.n in
self.n <- 1 + self.n;
let c = {atoms; id; watches=None} in
let c = {atoms; id; watches=Atom.dummy, Atom.dummy} in
c
module As_key = struct
type nonrec t=t
@ -161,7 +163,6 @@ end
Each event is checked by reverse-unit propagation on previous events. *)
module Fwd_check : sig
type error =
[ `Bad_steps of VecI32.t
| `No_empty_clause
@ -173,7 +174,6 @@ module Fwd_check : sig
success. In case of error it returns [Error idxs] where [idxs] are the
indexes in the trace of the steps that failed. *)
val check : Trace.t -> (unit, error) result
end = struct
module ISet = CCSet.Make(CCInt)
@ -183,7 +183,7 @@ end = struct
trail: VecI32.t; (* current assignment *)
mutable trail_ptr : int; (* offset in trail for propagation *)
active_clauses: unit Clause.Tbl.t;
watches: Clause.Set.t Vec.t; (* atom -> clauses it watches *)
watches: Clause.t Vec.t Vec.t; (* atom -> clauses it watches *)
errors: VecI32.t;
}
@ -202,7 +202,7 @@ end = struct
Bitvec.ensure_size self.assign (a:atom:>int);
(* size: 2+atom, because: 1+atom makes atom valid, and if it's positive,
2+atom is (¬atom)+1 *)
Vec.ensure_size self.watches Clause.Set.empty (2+(a:atom:>int));
Vec.ensure_size_with self.watches Vec.create (2+(a:atom:>int));
()
let[@inline] is_true self (a:atom) : bool =
@ -215,12 +215,11 @@ end = struct
not (is_true self a) && not (is_false self a)
let add_watch_ self (a:atom) (c:clause) =
let set = Vec.get self.watches (a:atom:>int) in
Vec.set self.watches (a:atom:>int) (Clause.Set.add c set)
Vec.push (Vec.get self.watches (a:atom:>int)) c
let remove_watch_ self (a:atom) (c:clause) =
let set = Vec.get self.watches (a:atom:>int) in
Vec.set self.watches (a:atom:>int) (Clause.Set.remove c set)
let remove_watch_ self (a:atom) idx =
let v = Vec.get self.watches (a:atom:>int) in
Vec.fast_remove v idx
exception Conflict
@ -251,19 +250,19 @@ end = struct
try Clause.iter c ~f:(fun a -> if not (is_false self a) then raise Exit); true
with Exit -> false
type propagation_res =
| Keep
| Remove
(* do boolean propagation in [c], which is watched by the true literal [a] *)
let propagate_in_clause_ (self:t) (a:atom) (c:clause) : unit =
let propagate_in_clause_ (self:t) (a:atom) (c:clause) : propagation_res =
assert (is_true self a);
let a1, a2 =
match Clause.watches c with
| None -> assert false
| Some tup -> tup
in
let a1, a2 = Clause.watches c in
let na = Atom.neg a in
(* [q] is the other literal in [c] such that [¬q] watches [c]. *)
let q = if Atom.equal a1 na then a2 else (assert(a2==na); a1) in
try
if is_true self q then () (* clause is satisfied *)
if is_true self q then Keep (* clause is satisfied *)
else (
let n_unassigned = ref 0 in
let unassigned_a = ref a in (* an unassigned atom, if [!n_unassigned > 0] *)
@ -294,21 +293,28 @@ end = struct
let p = !unassigned_a in
Log.debugf 30 (fun k->k"(@[propagate@ :atom %a@ :reason %a@])" Atom.pp p Clause.pp c);
set_atom_true self p;
Keep
) else (
(* at least 2 unassigned, just update the watch literal to [¬p] *)
let p = !unassigned_a in
assert (p <> q);
Clause.set_watches c (q, p);
remove_watch_ self a c;
add_watch_ self (Atom.neg p) c;
Remove
);
)
with
| Is_sat -> ()
| Is_sat -> Keep
let propagate_atom_ self (a:atom) : unit =
let set = Vec.get self.watches (a:atom:>int) in
Clause.Set.iter (propagate_in_clause_ self a) set
let v = Vec.get self.watches (a:atom:>int) in
let i = ref 0 in
while !i < Vec.size v do
match propagate_in_clause_ self a (Vec.get v !i) with
| Keep -> incr i;
| Remove ->
remove_watch_ self a !i
done
(* perform boolean propagation in a fixpoint
@raise Conflict if a clause is false *)
@ -360,15 +366,21 @@ end = struct
let c0 = Clause.get c 0 in
let c1 = Clause.get c 1 in
assert (c0 <> c1);
add_watch_ self (Atom.neg c0) c;
add_watch_ self (Atom.neg c1) c;
Clause.set_watches c (c0,c1);
(* make sure watches are valid *)
if is_false self c0 then (
propagate_in_clause_ self (Atom.neg c0) c;
match propagate_in_clause_ self (Atom.neg c0) c with
| Keep -> add_watch_ self (Atom.neg c0) c;
| Remove -> ()
) else (
add_watch_ self (Atom.neg c0) c
);
if is_false self c1 then (
propagate_in_clause_ self (Atom.neg c1) c;
match propagate_in_clause_ self (Atom.neg c1) c with
| Keep -> add_watch_ self (Atom.neg c1) c;
| Remove -> ()
) else (
add_watch_ self (Atom.neg c1) c
)
end;
()

View file

@ -48,15 +48,20 @@ let resize_ t x size =
let ensure_cap_ self x n =
if n > Array.length self.data then (
let new_size = max n (2 * Array.length self.data) in
resize_ self x new_size
resize_ self (x()) new_size
)
let ensure_size self x n =
ensure_cap_ self x n;
let ensure_size_with self f n =
ensure_cap_ self f n;
if n > self.sz then (
for i=self.sz to n-1 do
self.data.(i) <- f();
done;
self.sz <- n
)
let ensure_size self x n = ensure_size_with self (fun() -> x) n
(* grow the array *)
let[@inline never] grow_to_double_size t x : unit =
let new_size =

View file

@ -31,6 +31,9 @@ val clear : 'a t -> unit
val ensure_size : 'a t -> 'a -> int -> unit
(** ensure size is at least [n] *)
val ensure_size_with : 'a t -> (unit -> 'a) -> int -> unit
(** ensure size is at least [n] *)
val shrink : 'a t -> int -> unit
(** [shrink vec sz] resets size of [vec] to [sz].
Assumes [sz >=0 && sz <= size vec] *)