From 51293bc66a5c0a8bc7bc33789ebf8b31f7676381 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 10 Aug 2021 00:13:37 -0400 Subject: [PATCH] feat(drup-check): functorize over atoms --- src/checker/drup_check.ml | 211 +++++++++++ src/checker/main.ml | 23 +- src/drup/sidekick_drup.ml | 716 +++++++++++++++++--------------------- 3 files changed, 540 insertions(+), 410 deletions(-) create mode 100644 src/checker/drup_check.ml diff --git a/src/checker/drup_check.ml b/src/checker/drup_check.ml new file mode 100644 index 00000000..efe2b677 --- /dev/null +++ b/src/checker/drup_check.ml @@ -0,0 +1,211 @@ + +module SDrup = Sidekick_drup + +module Atom : sig + include SDrup.ATOM with type t = private int + + val of_int : int -> t +end = struct + type t = int + type atom = t + let hash = CCHash.int + let equal : t -> t -> bool = (=) + let compare : t -> t -> int = compare + let[@inline] neg x = x lxor 1 + let[@inline] of_int x = + let v = abs x lsl 1 in + if x < 0 then neg v else v + let[@inline] sign x = (x land 1) = 0 + let[@inline] to_int x = (if sign x then 1 else -1) * (x lsr 1) + let pp out x = + Fmt.fprintf out "%s%d" (if sign x then "+" else "-") (x lsr 1) + let[@inline] of_int_unsafe i = i + let dummy = 0 + module Assign = struct + type t = Bitvec.t + let create = Bitvec.create + let ensure_size = Bitvec.ensure_size + let is_true = Bitvec.get + let[@inline] is_false self (a:atom) : bool = + is_true self (neg a) + let[@inline] is_unassigned self a = + not (is_true self a) && not (is_false self a) + let set = Bitvec.set + end + module Map = struct + type 'a t = 'a Vec.t + let create () = Vec.create () + let[@inline] ensure_has (self:_ t) a mk : unit = + (* size: 2+atom, because: 1+atom makes atom valid, and if it's positive, + 2+atom is (¬atom)+1 *) + Vec.ensure_size_with self mk (2+(a:atom:>int)) + let get = Vec.get + let set = Vec.set + end + module Stack = struct + include VecI32 + let create()=create() + end +end + +include SDrup.Make(Atom) + +(** A DRUP trace, as a series of operations *) +module Trace : sig + type t + + val create : Clause.store -> t + val cstore : t -> Clause.store + + val add_clause : t -> clause -> unit + val add_input_clause : t -> clause -> unit + val del_clause : t -> clause -> unit + + (** Operator on the set of clauses *) + type op = + | Input of clause + | Redundant of clause + | Delete of clause + + val iteri : t -> f:(int -> op -> unit) -> unit + val ops : t -> op Iter.t + val size : t -> int + val get : t -> int -> op + + val pp_op : op Fmt.printer + + val dump : out_channel -> t -> unit +end = struct + type op = + | Input of clause + | Redundant of clause + | Delete of clause + + type t = { + cstore: Clause.store; + ops: op Vec.t; + } + + let create cstore : t = + { cstore; ops=Vec.create() } + + let cstore self = self.cstore + let add_clause self c = Vec.push self.ops (Redundant c) + let add_input_clause self c = Vec.push self.ops (Input c) + let del_clause self c = Vec.push self.ops (Delete c) + let get self i = Vec.get self.ops i + let size self = Vec.size self.ops + let ops self = Vec.to_seq self.ops + let iteri self ~f = Vec.iteri f self.ops + + let pp_op out = function + | Input c -> Fmt.fprintf out "(@[Input %a@])" Clause.pp c + | Redundant c -> Fmt.fprintf out "(@[Redundant %a@])" Clause.pp c + | Delete c -> Fmt.fprintf out "(@[Delete %a@])" Clause.pp c + + let dump oc self : unit = + let fpf = Printf.fprintf in + let pp_c out c = Clause.iter c ~f:(fun a -> fpf oc "%d " (a:atom:>int)); in + Vec.iter + (function + | Input c -> fpf oc "i %a0\n" pp_c c; + | Redundant c -> fpf oc "%a0\n" pp_c c; + | Delete c -> fpf oc "d %a0\n" pp_c c; + ) + self.ops +end + +(** Forward checking. + + Each event is checked by reverse-unit propagation on previous events. *) +module Fwd_check : sig + type error = + [ `Bad_steps of VecI32.t + | `No_empty_clause + ] + + val pp_error : Trace.t -> error Fmt.printer + + (** [check tr] checks the trace and returns [Ok ()] in case of + success. In case of error it returns [Error idxs] where [idxs] are the + indexes in the trace of the steps that failed. *) + val check : Trace.t -> (unit, error) result +end = struct + module ISet = CCSet.Make(CCInt) + + type t = { + checker: Checker.t; + errors: VecI32.t; + } + let create cstore : t = { + checker=Checker.create cstore; + errors=VecI32.create(); + } + + (* check event, return [true] if it's valid *) + let check_op (self:t) i (op:Trace.op) : bool = + Profile.with_ "check-op" @@ fun() -> + Log.debugf 20 (fun k->k"(@[check-op :idx %d@ :op %a@])" i Trace.pp_op op); + + begin match op with + | Trace.Input c -> + Checker.add_clause self.checker c; + true + + | Trace.Redundant c -> + + let ok = Checker.is_valid_drup self.checker c in + Checker.add_clause self.checker c; (* now add clause *) + ok + + | Trace.Delete c -> + Checker.del_clause self.checker c; + true + + end + + type error = + [ `Bad_steps of VecI32.t + | `No_empty_clause + ] + + let pp_error trace out = function + | `No_empty_clause -> Fmt.string out "no empty clause found" + | `Bad_steps bad -> + let n0 = VecI32.get bad 0 in + Fmt.fprintf out + "@[checking failed on %d ops.@ @[<2>First failure is op[%d]:@ %a@]@]" + (VecI32.size bad) n0 + Trace.pp_op (Trace.get trace n0) + + let check trace : _ result = + let self = create (Trace.cstore trace) in + + (* check each event in turn *) + let has_false = ref false in + Trace.iteri trace + ~f:(fun i op -> + let ok = check_op self i op in + if ok then ( + Log.debugf 50 + (fun k->k"(@[check.step.ok@ :idx %d@ :op %a@])" i Trace.pp_op op); + + (* check if op adds the empty clause *) + begin match op with + | (Trace.Redundant c | Trace.Input c) when Clause.size c = 0 -> + has_false := true + | _ -> () + end; + ) else ( + Log.debugf 10 + (fun k->k"(@[check.step.fail@ :idx %d@ :op %a@])" i Trace.pp_op op); + VecI32.push self.errors i + )); + + Log.debugf 10 (fun k->k"found %d errors" (VecI32.size self.errors)); + if not !has_false then Error `No_empty_clause + else if VecI32.size self.errors > 0 then Error (`Bad_steps self.errors) + else Ok () +end + + diff --git a/src/checker/main.ml b/src/checker/main.ml index 9fc84cab..8e8b0f81 100644 --- a/src/checker/main.ml +++ b/src/checker/main.ml @@ -1,16 +1,15 @@ module BL = Sidekick_bin_lib -module SDrup = Sidekick_drup -let clause_of_int_l store atoms : SDrup.clause = +let clause_of_int_l store atoms : Drup_check.clause = atoms - |> CCList.map SDrup.Atom.of_int - |> SDrup.Clause.of_list store + |> CCList.map Drup_check.Atom.of_int + |> Drup_check.Clause.of_list store let check ?pb proof : bool = Profile.with_ "check" @@ fun() -> - let cstore = SDrup.Clause.create() in - let trace = SDrup.Trace.create cstore in + let cstore = Drup_check.Clause.create() in + let trace = Drup_check.Trace.create cstore in (* add problem to trace, if provided *) begin match pb with @@ -23,7 +22,7 @@ let check ?pb proof : bool = BL.Dimacs_parser.iter parser_ (fun atoms -> let c = clause_of_int_l cstore atoms in - SDrup.Trace.add_input_clause trace c)) + Drup_check.Trace.add_input_clause trace c)) | Some f -> (* TODO: handle .cnf.gz *) Error.errorf "unknown problem file extension '%s'" (Filename.extension f) @@ -40,21 +39,21 @@ let check ?pb proof : bool = (function | BL.Drup_parser.Add c -> let c = clause_of_int_l cstore c in - SDrup.Trace.add_clause trace c + Drup_check.Trace.add_clause trace c | BL.Drup_parser.Delete c -> let c = clause_of_int_l cstore c in - SDrup.Trace.del_clause trace c)) + Drup_check.Trace.del_clause trace c)) | f -> (* TODO: handle .drup.gz *) Error.errorf "unknown proof file extension '%s'" (Filename.extension f) end; (* check proof *) - Log.debugf 1 (fun k->k"checking proof (%d steps)" (SDrup.Trace.size trace)); - begin match SDrup.Fwd_check.check trace with + Log.debugf 1 (fun k->k"checking proof (%d steps)" (Drup_check.Trace.size trace)); + begin match Drup_check.Fwd_check.check trace with | Ok () -> true | Error err -> - Format.eprintf "%a@." (SDrup.Fwd_check.pp_error trace) err; + Format.eprintf "%a@." (Drup_check.Fwd_check.pp_error trace) err; false end diff --git a/src/drup/sidekick_drup.ml b/src/drup/sidekick_drup.ml index b4ba8423..932f4ea2 100644 --- a/src/drup/sidekick_drup.ml +++ b/src/drup/sidekick_drup.ml @@ -1,11 +1,16 @@ +(** DRUP trace checker. + + This module provides a checker for DRUP traces, including step-by-step + checking for traces that interleave DRUP steps with other kinds of steps. +*) + module Fmt = CCFormat module VecI32 = VecI32 -module Atom : sig - type t = private int - val of_int : int -> t - val to_int : t -> int +(** Signature for boolean atoms *) +module type ATOM = sig + type t val equal : t -> t -> bool val compare : t -> t -> int val hash : t -> int @@ -14,349 +19,321 @@ module Atom : sig val pp : t Fmt.printer val dummy : t - val of_int_unsafe : int -> t - module Map : CCMap.S with type key = t -end = struct - type t = int - let hash = CCHash.int - let equal : t -> t -> bool = (=) - let compare : t -> t -> int = compare - let neg x = x lxor 1 - let of_int x = - let v = abs x lsl 1 in - if x < 0 then neg v else v - let sign x = (x land 1) = 0 - let to_int x = (if sign x then 1 else -1) * (x lsr 1) - let pp out x = - Fmt.fprintf out "%s%d" (if sign x then "+" else "-") (x lsr 1) - let of_int_unsafe i = i - let dummy = 0 - module Map = Util.Int_map -end -type atom = Atom.t -(** Boolean clauses *) -module Clause : sig - type store - val create : unit -> store - type t - val size : t -> int - val get : t -> int -> atom - val iter : f:(atom -> unit) -> t -> unit - val watches: t -> atom * atom - val set_watches : t -> atom * atom -> unit - val pp : t Fmt.printer - val of_list : store -> atom list -> t - module Set : CCSet.S with type elt = t - module Tbl : CCHashtbl.S with type key = t -end = struct - module I_arr_tbl = CCHashtbl.Make(struct - type t = atom array - let equal = CCEqual.(array Atom.equal) - let hash = CCHash.(array Atom.hash) - end) - type t = { - id: int; - atoms: atom array; - mutable watches: atom * atom; - } - type store = { - mutable n: int; - } - let create(): store = - { n=0; } - let[@inline] size self = Array.length self.atoms - let[@inline] get self i = Array.get self.atoms i - let[@inline] watches self = self.watches - let[@inline] set_watches self w = self.watches <- w - let[@inline] iter ~f self = - for i=0 to Array.length self.atoms-1 do - f (Array.unsafe_get self.atoms i) - done - let pp out (self:t) = - let pp_watches out = function - | (p,q) when p=Atom.dummy || q=Atom.dummy -> () - | (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in - Fmt.fprintf out "(@[cl[%d]@ %a%a])" - self.id (Fmt.Dump.array Atom.pp) self.atoms pp_watches self.watches - let of_list self atoms : t = - (* normalize + find in table *) - let atoms = List.sort_uniq Atom.compare atoms |> Array.of_list in - let id = self.n in - self.n <- 1 + self.n; - let c = {atoms; id; watches=Atom.dummy, Atom.dummy} in - c - module As_key = struct - type nonrec t=t - let[@inline] hash a = CCHash.int a.id - let[@inline] equal a b = a.id = b.id - let[@inline] compare a b = compare a.id b.id + type atom = t + + module Assign : sig + type t + val create : unit -> t + + val ensure_size : t -> atom -> unit + val set : t -> atom -> bool -> unit + val is_true : t -> atom -> bool + val is_false : t -> atom -> bool + val is_unassigned : t -> atom -> bool end - module Set = CCSet.Make(As_key) - module Tbl = CCHashtbl.Make(As_key) -end -type clause = Clause.t -(** A DRUP trace, as a series of operations *) -module Trace : sig - type t + module Map : sig + type 'a t + val create : unit -> 'a t + val ensure_has : 'a t -> atom -> (unit -> 'a) -> unit + val get : 'a t -> atom -> 'a + end - val create : Clause.store -> t - val cstore : t -> Clause.store - - val add_clause : t -> clause -> unit - val add_input_clause : t -> clause -> unit - val del_clause : t -> clause -> unit - - (** Operator on the set of clauses *) - type op = - | Input of clause - | Redundant of clause - | Delete of clause - - val iteri : t -> f:(int -> op -> unit) -> unit - val ops : t -> op Iter.t - val size : t -> int - val get : t -> int -> op - - val pp_op : op Fmt.printer - - val dump : out_channel -> t -> unit -end = struct - type op = - | Input of clause - | Redundant of clause - | Delete of clause - - type t = { - cstore: Clause.store; - ops: op Vec.t; - } - - let create cstore : t = - { cstore; ops=Vec.create() } - - let cstore self = self.cstore - let add_clause self c = Vec.push self.ops (Redundant c) - let add_input_clause self c = Vec.push self.ops (Input c) - let del_clause self c = Vec.push self.ops (Delete c) - let get self i = Vec.get self.ops i - let size self = Vec.size self.ops - let ops self = Vec.to_seq self.ops - let iteri self ~f = Vec.iteri f self.ops - - let pp_op out = function - | Input c -> Fmt.fprintf out "(@[Input %a@])" Clause.pp c - | Redundant c -> Fmt.fprintf out "(@[Redundant %a@])" Clause.pp c - | Delete c -> Fmt.fprintf out "(@[Delete %a@])" Clause.pp c - - let dump oc self : unit = - let fpf = Printf.fprintf in - let pp_c out c = Clause.iter c ~f:(fun a -> fpf oc "%d " (Atom.to_int a)); in - Vec.iter - (function - | Input c -> fpf oc "i %a0\n" pp_c c; - | Redundant c -> fpf oc "%a0\n" pp_c c; - | Delete c -> fpf oc "d %a0\n" pp_c c; - ) - self.ops + module Stack : sig + type t + val create : unit -> t + val get : t -> int -> atom + val set : t -> int -> atom -> unit + val push : t -> atom -> unit + val size : t -> int + val shrink : t -> int -> unit + val to_iter : t -> atom Iter.t + end end -(** Forward checking. +(* TODO: resolution proof construction, optionally *) - Each event is checked by reverse-unit propagation on previous events. *) -module Fwd_check : sig - type error = - [ `Bad_steps of VecI32.t - | `No_empty_clause - ] +(* TODO: backward checking + pruning of traces *) - val pp_error : Trace.t -> error Fmt.printer +(** An instance of the checker *) +module type S = sig + type atom - (** [check tr] checks the trace and returns [Ok ()] in case of - success. In case of error it returns [Error idxs] where [idxs] are the - indexes in the trace of the steps that failed. *) - val check : Trace.t -> (unit, error) result -end = struct - module ISet = CCSet.Make(CCInt) + module Clause : sig - type t = { - cstore: Clause.store; - assign: Bitvec.t; (* atom -> is_true(atom) *) - trail: VecI32.t; (* current assignment *) - mutable trail_ptr : int; (* offset in trail for propagation *) - active_clauses: unit Clause.Tbl.t; - watches: Clause.t Vec.t Vec.t; (* atom -> clauses it watches *) - errors: VecI32.t; - } + type store + val create : unit -> store - let create cstore : t = - { trail=VecI32.create(); - trail_ptr = 0; - cstore; - active_clauses=Clause.Tbl.create 32; - assign=Bitvec.create(); - watches=Vec.create(); - errors=VecI32.create(); + type t + + val size : t -> int + + val get : t -> int -> atom + + val iter : f:(atom -> unit) -> t -> unit + + val pp : t Fmt.printer + + val of_list : store -> atom list -> t + end + type clause = Clause.t + + module Checker : sig + type t + + val create : Clause.store -> t + + val add_clause : t -> Clause.t -> unit + + val is_valid_drup : t -> Clause.t -> bool + + val del_clause : t -> Clause.t -> unit + end +end + +module[@inline] Make(A : ATOM) + : S with type atom = A.t += struct + module Atom = A + type atom = Atom.t + + (** Boolean clauses *) + module Clause : sig + type store + val create : unit -> store + type t + val size : t -> int + val id : t -> int + val get : t -> int -> atom + val iter : f:(atom -> unit) -> t -> unit + val watches: t -> atom * atom + val set_watches : t -> atom * atom -> unit + val pp : t Fmt.printer + val of_list : store -> atom list -> t + module Set : CCSet.S with type elt = t + module Tbl : CCHashtbl.S with type key = t + end = struct + type t = { + id: int; + atoms: atom array; + mutable watches: atom * atom; + } + type store = { + mutable n: int; + } + let create(): store = + { n=0; } + let[@inline] id self = self.id + let[@inline] size self = Array.length self.atoms + let[@inline] get self i = Array.get self.atoms i + let[@inline] watches self = self.watches + let[@inline] set_watches self w = self.watches <- w + let[@inline] iter ~f self = + for i=0 to Array.length self.atoms-1 do + f (Array.unsafe_get self.atoms i) + done + let pp out (self:t) = + let pp_watches out = function + | (p,q) when p=Atom.dummy || q=Atom.dummy -> () + | (p,q) -> Fmt.fprintf out "@ :watches (%a,%a)" Atom.pp p Atom.pp q in + Fmt.fprintf out "(@[cl[%d]@ %a%a])" + self.id (Fmt.Dump.array Atom.pp) self.atoms pp_watches self.watches + let of_list self atoms : t = + (* normalize + find in table *) + let atoms = List.sort_uniq Atom.compare atoms |> Array.of_list in + let id = self.n in + self.n <- 1 + self.n; + let c = {atoms; id; watches=Atom.dummy, Atom.dummy} in + c + module As_key = struct + type nonrec t=t + let[@inline] hash a = CCHash.int a.id + let[@inline] equal a b = a.id = b.id + let[@inline] compare a b = compare a.id b.id + end + module Set = CCSet.Make(As_key) + module Tbl = CCHashtbl.Make(As_key) + end + type clause = Clause.t + + (** Forward proof checker. + + Each event is checked by reverse-unit propagation on previous events. *) + module Checker : sig + type t + val create : Clause.store -> t + val add_clause : t -> Clause.t -> unit + val is_valid_drup : t -> Clause.t -> bool + val del_clause : t -> Clause.t -> unit + end = struct + type t = { + cstore: Clause.store; + assign: Atom.Assign.t; (* atom -> is_true(atom) *) + trail: Atom.Stack.t; (* current assignment *) + mutable trail_ptr : int; (* offset in trail for propagation *) + active_clauses: unit Clause.Tbl.t; + watches: Clause.t Vec.t Atom.Map.t; (* atom -> clauses it watches *) } - (* ensure data structures are big enough to handle [a] *) - let ensure_atom_ self (a:atom) = - Bitvec.ensure_size self.assign (a:atom:>int); - (* size: 2+atom, because: 1+atom makes atom valid, and if it's positive, - 2+atom is (¬atom)+1 *) - Vec.ensure_size_with self.watches Vec.create (2+(a:atom:>int)); - () + let create cstore : t = + { trail=Atom.Stack.create(); + trail_ptr = 0; + cstore; + active_clauses=Clause.Tbl.create 32; + assign=Atom.Assign.create(); + watches=Atom.Map.create(); + } - let[@inline] is_true self (a:atom) : bool = - Bitvec.get self.assign (a:atom:>int) + (* ensure data structures are big enough to handle [a] *) + let ensure_atom_ self (a:atom) = + Atom.Assign.ensure_size self.assign a; + (* size: 2+atom, because: 1+atom makes atom valid, and if it's positive, + 2+atom is (¬atom)+1 *) + Atom.Map.ensure_has self.watches a (fun _ -> Vec.create ()); + () - let[@inline] is_false self (a:atom) : bool = - is_true self (Atom.neg a) + let[@inline] is_true self (a:atom) : bool = + Atom.Assign.is_true self.assign a + let[@inline] is_false self (a:atom) : bool = + Atom.Assign.is_false self.assign a + let[@inline] is_unassigned self a = + Atom.Assign.is_unassigned self.assign a - let is_unassigned self a = - not (is_true self a) && not (is_false self a) + let add_watch_ self (a:atom) (c:clause) = + Vec.push (Atom.Map.get self.watches a) c - let add_watch_ self (a:atom) (c:clause) = - Vec.push (Vec.get self.watches (a:atom:>int)) c + let remove_watch_ self (a:atom) idx = + let v = Atom.Map.get self.watches a in + Vec.fast_remove v idx - let remove_watch_ self (a:atom) idx = - let v = Vec.get self.watches (a:atom:>int) in - Vec.fast_remove v idx + exception Conflict - exception Conflict + let raise_conflict_ self a = + Log.debugf 5 (fun k->k"conflict on atom %a" Atom.pp a); + raise Conflict - let raise_conflict_ self a = - Log.debugf 5 (fun k->k"conflict on atom %a" Atom.pp a); - raise Conflict - - (* set atom to true *) - let set_atom_true (self:t) (a:atom) : unit = - if is_true self a then () - else if is_false self a then raise_conflict_ self a - else ( - Bitvec.set self.assign (a:atom:>int) true; - VecI32.push self.trail (a:atom:>int) - ) - - (* print the trail *) - let pp_trail_ out self = - let pp_a out i = Atom.pp out (Atom.of_int_unsafe i) in - Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_a) (VecI32.to_iter self.trail) - - exception Found_watch of atom - exception Is_sat - exception Is_undecided - - (* check if [c] is false in current trail *) - let c_is_false_ self c = - try Clause.iter c ~f:(fun a -> if not (is_false self a) then raise Exit); true - with Exit -> false - - type propagation_res = - | Keep - | Remove - - (* do boolean propagation in [c], which is watched by the true literal [a] *) - let propagate_in_clause_ (self:t) (a:atom) (c:clause) : propagation_res = - assert (is_true self a); - let a1, a2 = Clause.watches c in - let na = Atom.neg a in - (* [q] is the other literal in [c] such that [¬q] watches [c]. *) - let q = if Atom.equal a1 na then a2 else (assert(a2==na); a1) in - try - if is_true self q then Keep (* clause is satisfied *) + (* set atom to true *) + let[@inline] set_atom_true (self:t) (a:atom) : unit = + if is_true self a then () + else if is_false self a then raise_conflict_ self a else ( - let n_unassigned = ref 0 in - let unassigned_a = ref a in (* an unassigned atom, if [!n_unassigned > 0] *) - if not (is_false self q) then unassigned_a := q; - begin - try - Clause.iter c - ~f:(fun ai -> - if is_true self ai then raise Is_sat (* no watch update *) - else if is_unassigned self ai then ( - incr n_unassigned; - if q <> ai then unassigned_a := ai; - if !n_unassigned >= 2 then raise Is_undecided; (* early exit *) - ); - ) - with Is_undecided -> () - end; - - if !n_unassigned = 0 then ( - (* if we reach this point it means no literal is true, and none is - unassigned. So they're all false and we have a conflict. *) - assert (is_false self q); - raise_conflict_ self a; - ) else if !n_unassigned = 1 then ( - (* no lit is true, only one is unassigned: propagate it. - no need to update the watches as the clause is satisfied. *) - assert (is_unassigned self !unassigned_a); - let p = !unassigned_a in - Log.debugf 30 (fun k->k"(@[propagate@ :atom %a@ :reason %a@])" Atom.pp p Clause.pp c); - set_atom_true self p; - Keep - ) else ( - (* at least 2 unassigned, just update the watch literal to [¬p] *) - let p = !unassigned_a in - assert (p <> q); - Clause.set_watches c (q, p); - add_watch_ self (Atom.neg p) c; - Remove - ); + Atom.Assign.set self.assign a true; + Atom.Stack.push self.trail a ) - with - | Is_sat -> Keep - let propagate_atom_ self (a:atom) : unit = - let v = Vec.get self.watches (a:atom:>int) in - let i = ref 0 in - while !i < Vec.size v do - match propagate_in_clause_ self a (Vec.get v !i) with - | Keep -> incr i; - | Remove -> - remove_watch_ self a !i - done + (* print the trail *) + let pp_trail_ out self = + Fmt.fprintf out "(@[%a@])" (Fmt.iter Atom.pp) (Atom.Stack.to_iter self.trail) - (* perform boolean propagation in a fixpoint - @raise Conflict if a clause is false *) - let bcp_fixpoint_ (self:t) : unit = - Profile.with_ "bcp-fixpoint" @@ fun() -> - while self.trail_ptr < VecI32.size self.trail do - let a = Atom.of_int_unsafe (VecI32.get self.trail self.trail_ptr) in - Log.debugf 50 (fun k->k"(@[bcp@ :atom %a@])" Atom.pp a); - self.trail_ptr <- 1 + self.trail_ptr; - propagate_atom_ self a; - done + exception Found_watch of atom + exception Is_sat + exception Is_undecided - (* calls [f] and then restore trail to what it was *) - let with_restore_trail_ self f = - let trail_size0 = VecI32.size self.trail in - let ptr0 = self.trail_ptr in + (* check if [c] is false in current trail *) + let c_is_false_ self c = + try Clause.iter c ~f:(fun a -> if not (is_false self a) then raise Exit); true + with Exit -> false - let restore () = - (* unassign new literals *) - for i=trail_size0 to VecI32.size self.trail - 1 do - let a = Atom.of_int_unsafe (VecI32.get self.trail i) in - assert (is_true self a); - Bitvec.set self.assign (a:atom:>int) false; - done; + type propagation_res = + | Keep + | Remove - (* remove literals from trail *) - VecI32.shrink self.trail trail_size0; - self.trail_ptr <- ptr0 - in + (* do boolean propagation in [c], which is watched by the true literal [a] *) + let propagate_in_clause_ (self:t) (a:atom) (c:clause) : propagation_res = + assert (is_true self a); + let a1, a2 = Clause.watches c in + let na = Atom.neg a in + (* [q] is the other literal in [c] such that [¬q] watches [c]. *) + let q = if Atom.equal a1 na then a2 else (assert(a2==na); a1) in + try + if is_true self q then Keep (* clause is satisfied *) + else ( + let n_unassigned = ref 0 in + let unassigned_a = ref a in (* an unassigned atom, if [!n_unassigned > 0] *) + if not (is_false self q) then unassigned_a := q; + begin + try + Clause.iter c + ~f:(fun ai -> + if is_true self ai then raise Is_sat (* no watch update *) + else if is_unassigned self ai then ( + incr n_unassigned; + if q <> ai then unassigned_a := ai; + if !n_unassigned >= 2 then raise Is_undecided; (* early exit *) + ); + ) + with Is_undecided -> () + end; - CCFun.finally ~h:restore ~f + if !n_unassigned = 0 then ( + (* if we reach this point it means no literal is true, and none is + unassigned. So they're all false and we have a conflict. *) + assert (is_false self q); + raise_conflict_ self a; + ) else if !n_unassigned = 1 then ( + (* no lit is true, only one is unassigned: propagate it. + no need to update the watches as the clause is satisfied. *) + assert (is_unassigned self !unassigned_a); + let p = !unassigned_a in + Log.debugf 30 (fun k->k"(@[propagate@ :atom %a@ :reason %a@])" Atom.pp p Clause.pp c); + set_atom_true self p; + Keep + ) else ( + (* at least 2 unassigned, just update the watch literal to [¬p] *) + let p = !unassigned_a in + assert (p <> q); + Clause.set_watches c (q, p); + add_watch_ self (Atom.neg p) c; + Remove + ); + ) + with + | Is_sat -> Keep - (* check event, return [true] if it's valid *) - let check_op (self:t) i (op:Trace.op) : bool = - Profile.with_ "check-op" @@ fun() -> - Log.debugf 20 (fun k->k"(@[check-op :idx %d@ :op %a@])" i Trace.pp_op op); + let propagate_atom_ self (a:atom) : unit = + let v = Atom.Map.get self.watches a in + let i = ref 0 in + while !i < Vec.size v do + match propagate_in_clause_ self a (Vec.get v !i) with + | Keep -> incr i; + | Remove -> + remove_watch_ self a !i + done + + (* perform boolean propagation in a fixpoint + @raise Conflict if a clause is false *) + let bcp_fixpoint_ (self:t) : unit = + Profile.with_ "bcp-fixpoint" @@ fun() -> + while self.trail_ptr < Atom.Stack.size self.trail do + let a = Atom.Stack.get self.trail self.trail_ptr in + Log.debugf 50 (fun k->k"(@[bcp@ :atom %a@])" Atom.pp a); + self.trail_ptr <- 1 + self.trail_ptr; + propagate_atom_ self a; + done + + (* calls [f] and then restore trail to what it was *) + let with_restore_trail_ self f = + let trail_size0 = Atom.Stack.size self.trail in + let ptr0 = self.trail_ptr in + + let restore () = + (* unassign new literals *) + for i=trail_size0 to Atom.Stack.size self.trail - 1 do + let a = Atom.Stack.get self.trail i in + assert (is_true self a); + Atom.Assign.set self.assign a false; + done; + + (* remove literals from trail *) + Atom.Stack.shrink self.trail trail_size0; + self.trail_ptr <- ptr0 + in + + CCFun.finally ~h:restore ~f (* add clause to the state *) - let add_c_ (c:Clause.t) = + let add_clause (self:t) (c:Clause.t) = Log.debugf 50 (fun k->k"(@[add-clause@ %a@])" Clause.pp c); Clause.iter c ~f:(ensure_atom_ self); Clause.Tbl.add self.active_clauses c (); @@ -387,94 +364,37 @@ end = struct ) end; () - in - - match op with - | Trace.Input c -> - add_c_ c; - true - - | Trace.Redundant c -> + let is_valid_drup (self:t) (c:Clause.t) : bool = (* negate [c], pushing each atom on trail, and see if we get [Conflict] by pure propagation *) - let ok = - try - with_restore_trail_ self @@ fun () -> - Clause.iter c - ~f:(fun a -> - if is_true self a then raise_notrace Conflict; (* tautology *) - let a' = Atom.neg a in - if is_true self a' then () else ( - set_atom_true self a' - )); - bcp_fixpoint_ self; + try + with_restore_trail_ self @@ fun () -> + Clause.iter c + ~f:(fun a -> + if is_true self a then raise_notrace Conflict; (* tautology *) + let a' = Atom.neg a in + if is_true self a' then () else ( + set_atom_true self a' + )); + bcp_fixpoint_ self; - (* - (* slow sanity check *) - Clause.Tbl.iter - (fun c () -> - if c_is_false_ self c then - Log.debugf 0 (fun k->k"clause is false: %a" Clause.pp c)) - self.active_clauses; - *) + (* + (* slow sanity check *) + Clause.Tbl.iter + (fun c () -> + if c_is_false_ self c then + Log.debugf 0 (fun k->k"clause is false: %a" Clause.pp c)) + self.active_clauses; + *) - false - with Conflict -> - true - in + false + with Conflict -> + true - (* now add clause *) - add_c_ c; - ok + let del_clause (_self:t) (_c:Clause.t) : unit = + () (* TODO *) + end - | Trace.Delete _c -> - true (* TODO: actually remove the clause *) - - type error = - [ `Bad_steps of VecI32.t - | `No_empty_clause - ] - - let pp_error trace out = function - | `No_empty_clause -> Fmt.string out "no empty clause found" - | `Bad_steps bad -> - let n0 = VecI32.get bad 0 in - Fmt.fprintf out - "@[checking failed on %d ops.@ @[<2>First failure is op[%d]:@ %a@]@]" - (VecI32.size bad) n0 - Trace.pp_op (Trace.get trace n0) - - let check trace : _ result = - let self = create (Trace.cstore trace) in - - (* check each event in turn *) - let has_false = ref false in - Trace.iteri trace - ~f:(fun i op -> - let ok = check_op self i op in - if ok then ( - Log.debugf 50 - (fun k->k"(@[check.step.ok@ :idx %d@ :op %a@])" i Trace.pp_op op); - - (* check if op adds the empty clause *) - begin match op with - | (Trace.Redundant c | Trace.Input c) when Clause.size c = 0 -> - has_false := true - | _ -> () - end; - ) else ( - Log.debugf 10 - (fun k->k"(@[check.step.fail@ :idx %d@ :op %a@])" i Trace.pp_op op); - Log.debugf 50 (fun k->k"(@[trail: %a@])" pp_trail_ self); - VecI32.push self.errors i - )); - - Log.debugf 10 (fun k->k"found %d errors" (VecI32.size self.errors)); - if not !has_false then Error `No_empty_clause - else if VecI32.size self.errors > 0 then Error (`Bad_steps self.errors) - else Ok () end - -(* TODO: backward checking + pruning of traces *)