diff --git a/sat/res.ml b/sat/res.ml index ed029a71..ef2514c7 100644 --- a/sat/res.ml +++ b/sat/res.ml @@ -48,12 +48,12 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct let fresh_pcl_name () = incr _c; "P" ^ (string_of_int !_c) let clause_unit a = - try - H.find unit_learnt [a] - with Not_found -> - let new_c = St.(make_clause (fresh_pcl_name ()) [a] 1 true a.var.vpremise) in - H.add unit_learnt [a] new_c; - new_c + try + H.find unit_learnt [a] + with Not_found -> + let new_c = St.(make_clause (fresh_pcl_name ()) [a] 1 true a.var.vpremise) in + H.add unit_learnt [a] new_c; + new_c (* Printing functions *) let print_atom fmt a = @@ -91,57 +91,28 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct raise (Resolution_error "Input cause is a tautology"); res - (* Adding new proven clauses *) - let is_proved c = H.mem proof c - let is_proven c = is_proved (to_list c) - + (* Adding hyptoheses *) let is_unit_hyp = function - | [a] -> St.(a.var.level = 0 && a.var.reason = None && a.var.vpremise <> []) - | _ -> false + | [a] -> St.(a.var.level = 0 && a.var.reason = None && a.var.vpremise <> []) + | _ -> false - let unit_learnts a = - match St.(a.var.level, a.var.reason, a.var.vpremise) with - | 0, None, [] -> [clause_unit a] - | _ -> [] - - let need_clause (c, cl) = - if is_proved cl then - [], [] - else if not St.(c.learnt) || is_unit_hyp cl then begin + let is_proved (c, cl) = + if H.mem proof cl then + true + else if is_unit_hyp cl || not St.(c.learnt) then begin H.add proof cl Assumption; - [], [] + true end else - let l = - if List.length cl > 1 then - List.flatten (List.map unit_learnts cl) - else - [] - in - (* - Log.debug 0 "Need for : %s" St.(c.name); - List.iter (fun c -> - Log.debug 0 " premise: %s" St.(c.name)) St.(c.cpremise); - List.iter (fun c -> - Log.debug 0 " unit: %s" St.(c.name)) l; - *) - St.(c.cpremise), l + false - let rec diff_learnt acc l l' = - match l, l' with - | [], _ -> l' @ acc - | a :: r, b :: r' -> - if equal_atoms a b then - diff_learnt acc r r' - else - diff_learnt (b :: acc) l r' - | _ -> raise (Resolution_error "Impossible to derive correct clause") + let is_proven c = is_proved (c, to_list c) let add_res (c, cl_c) (d, cl_d) = - Log.debug 7 "Resolving clauses :"; - Log.debug 7 " %a" St.pp_clause c; - Log.debug 7 " %a" St.pp_clause d; - assert (is_proved cl_c); - assert (is_proved cl_d); + Log.debug 7 " Resolving clauses :"; + Log.debug 7 " %a" St.pp_clause c; + Log.debug 7 " %a" St.pp_clause d; + assert (is_proved (c, cl_c)); + assert (is_proved (c, cl_d)); let l = List.merge compare_atoms cl_c cl_d in let resolved, new_clause = resolve l in match resolved with @@ -153,14 +124,27 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct new_c, new_clause | _ -> raise (Resolution_error "Resolved to a tautology") - let add_clause cl l = (* We assume that all clauses in c.cpremise are already proved ! *) + let rec diff_learnt acc l l' = + match l, l' with + | [], _ -> l' @ acc + | a :: r, b :: r' -> + if equal_atoms a b then + diff_learnt acc r r' + else + diff_learnt (b :: acc) l r' + | _ -> raise (Resolution_error "Impossible to derive correct clause") + + let add_clause c cl l = (* We assume that all clauses in l are already proved ! *) match l with | a :: ((_ :: _) as r) -> + Log.debug 5 "Resolving (with history) %a" St.pp_clause c; let temp_c, temp_cl = List.fold_left add_res a r in + Log.debug 10 " Switching to unit resolutions"; let unit_to_use = diff_learnt [] cl temp_cl in let unit_r = List.map St.(fun a -> clause_unit a.neg, [a.neg]) unit_to_use in let new_c, new_cl = List.fold_left add_res (temp_c, temp_cl) unit_r in if not (equal_cl cl new_cl) then begin + (* We didn't get the expected clause, raise an error *) Log.debug 0 "Expected the following clauses to be equal :"; Log.debug 0 "expected : %s" (Log.on_fmt print_cl cl); Log.debug 0 "found : %a" St.pp_clause new_c; @@ -168,24 +152,28 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct end | _ -> assert false + let need_clause (c, cl) = + if is_proved (c, cl) then + [] + else + St.(c.cpremise) + let rec do_clause = function | [] -> () | c :: r -> let cl = to_list c in - let history, unit_to_learn = need_clause (c, cl) in - if history = [] then (* c is either an asusmption, or already proved *) - do_clause r - else + match need_clause (c, cl) with + | [] -> do_clause r + | history -> let history_cl = List.rev_map (fun c -> c, to_list c) history in - let to_prove = List.filter (fun (_, cl) -> not (is_proved cl)) history_cl in - let to_prove = (List.rev_map fst to_prove) @ unit_to_learn in - if to_prove = [] then begin - (* See wether we can prove c right now *) - add_clause cl history_cl; - do_clause r - end else - (* Or if we have to prove some other clauses first *) - do_clause (to_prove @ (c :: r)) + let to_prove = List.filter (fun (c, cl) -> not (is_proved (c, cl))) history_cl in + let to_prove = (List.rev_map fst to_prove) in + begin match to_prove with + | [] -> + add_clause c cl history_cl; + do_clause r + | _ -> do_clause (to_prove @ (c :: r)) + end let prove c = Log.debug 3 "Proving : %a" St.pp_clause c; @@ -195,17 +183,21 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct let rec prove_unsat_cl (c, cl) = match cl with | [] -> true | a :: r -> - try - Log.debug 2 "Eliminating %a in %a" St.pp_atom a St.pp_clause c; - let d = match St.(a.var.level, a.var.reason) with - | 0, Some d -> d - | 0, None -> clause_unit St.(a.neg) - | _ -> raise Exit - in - prove d; - let cl_d = to_list d in - prove_unsat_cl (add_res (c, cl) (d, cl_d)) - with Exit -> false + Log.debug 2 "Eliminating %a in %a" St.pp_atom a St.pp_clause c; + let d = match St.(a.var.level, a.var.reason) with + | 0, Some d -> d + | 0, None -> clause_unit St.(a.neg) + | _ -> raise Exit + in + prove d; + let cl_d = to_list d in + prove_unsat_cl (add_res (c, cl) (d, cl_d)) + + let prove_unsat_cl c = + try + prove_unsat_cl c + with Exit -> + false exception Cannot let assert_can_prove_unsat c = @@ -276,11 +268,11 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct Format.fprintf fmt "%s -> %s;@\n" id_c id_d let print_res_atom id fmt a = - Format.fprintf fmt "%s [label=\"%a\"]" id print_atom a + Format.fprintf fmt "%s [label=\"%a\"]" id print_atom a let print_res_node concl p1 p2 fmt atom = - let id = new_id () in - Format.fprintf fmt "%a;@\n%a%a%a" + let id = new_id () in + Format.fprintf fmt "%a;@\n%a%a%a" (print_res_atom id) atom (print_dot_edge (c_id concl)) id (print_dot_edge id) (c_id p1) @@ -303,8 +295,8 @@ module Make(St : Solver_types.S)(Proof : sig type proof end) = struct | Resolution (proof1, proof2, a) -> let aux fmt () = Format.fprintf fmt "