Source code for problog.constraint

"""
problog.constraint - Propositional constraints
----------------------------------------------

Data structures for specifying propositional constraints.

..
    Part of the ProbLog distribution.

    Copyright 2015 KU Leuven, DTAI Research Group

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
"""

from .errors import InvalidValue
from .logic import Term, Constant


class Constraint(object):
    """Base class for a propositional constraint over formula nodes.

    Subclasses describe restrictions on the truth values of nodes and can
    express them as a list of CNF clauses.
    """

    def get_nodes(self):
        """Return every node that takes part in this constraint."""
        raise NotImplementedError("abstract method")

    def update_weights(self, weights, semiring):
        """Adjust ``weights`` in place to account for this constraint.

        The default implementation leaves the weights untouched; only
        constraints that carry weight information need to override it.

        :param weights: dictionary of weights (see result of
            :func:`LogicFormula.extract_weights`)
        :param semiring: semiring used for weight transformation
        """
        return None

    def is_true(self):
        """Return whether the constraint holds trivially (default: never)."""
        return False

    def is_false(self):
        """Return whether the constraint is trivially unsatisfiable (default: never)."""
        return False

    def is_nontrivial(self):
        """Return whether the constraint is neither trivially true nor trivially false."""
        return not (self.is_true() or self.is_false())

    def as_clauses(self):
        """Express the constraint in CNF.

        :return: list of clauses, each clause represented as a list of node keys
        :rtype: list[list[int]]
        """
        raise NotImplementedError("abstract method")

    def copy(self, rename=None):
        """Return a copy of this constraint with the node renaming applied.

        :param rename: node rename map, or None when no renaming is needed
        :return: the copied constraint
        """
        raise NotImplementedError("abstract method")
class ConstraintAD(Constraint):
    """Annotated disjunction constraint (mutually exclusive with weight update).

    The constraint collects the choice nodes of one annotated disjunction.
    Once the group has two or more choices, an extra node is created whose
    weight is the semiring complement of the choices' weights.  In CNF form,
    exactly one of the choices and the extra node is true.
    """

    def __init__(self, group):
        # Regular (non-extra) choice nodes of the disjunction.
        self.nodes = set()
        # Group identifier; group[0] and the elements of group[1] are used to
        # build the ``choice`` term that names the extra node.
        self.group = group
        # Node carrying the complement weight (created by _update_logic).
        self.extra_node = None
        # Source line of the disjunction, used in error messages.
        self.location = None

    def __str__(self):
        return "annotated_disjunction(%s, %s)" % (list(self.nodes), self.extra_node)

    def _choice_term(self):
        """Build the ``choice`` term that names the extra node of this group."""
        return Term(
            "choice", Constant(self.group[0]), Term("e"), Term("null"), *self.group[1]
        )

    def get_nodes(self):
        """Return the choice nodes, followed by the extra node if it is set.

        A new list is returned in every case, so callers never receive (and
        cannot modify) the internal node set.
        """
        nodes = list(self.nodes)
        if self.extra_node:
            nodes.append(self.extra_node)
        return nodes

    def is_true(self):
        """A group with at most one choice imposes no restriction."""
        return len(self.nodes) <= 1

    def is_false(self):
        return False

    def add(self, node, formula, cr_extra=True):
        """Add a node to the constraint from the given formula.

        :param node: node to add
        :param formula: formula from which the node is taken
        :param cr_extra: Create an extra_node when required (when it is None and
            this is the second atom of the group).
        :return: value of the node after constraint propagation
        """
        if node in self.nodes:
            return node

        is_extra = formula.get_node(node).is_extra

        # Best effort: remember the source line of the disjunction for error
        # messages; nodes without a usable name/location are skipped.
        try:
            if (
                not self.location
                and formula.get_node(node).name
                and formula.get_node(node).name.args
            ):
                if formula.database:
                    self.location = formula.database.lineno(
                        formula.get_node(node).name.args[-1].location
                    )
        except AttributeError:
            pass

        if formula.has_evidence_values() and not is_extra:
            # Propagate constraint: if one of the other nodes is True: this one is false
            for n in self.nodes:
                if formula.get_evidence_value(n) == formula.TRUE:
                    return formula.FALSE
            if formula.get_evidence_value(node) == formula.FALSE:
                return node
            elif formula.get_evidence_value(node) == formula.TRUE:
                # This node is True, so every other choice must be False.
                for n in self.nodes:
                    formula.set_evidence_value(n, formula.FALSE)

            if formula.semiring:
                # If the weights of all choices (this one included) sum to one
                # and exactly one of them is not known to be False, that one
                # is set to True.
                sr = formula.semiring
                w = formula.get_weight(node, sr)
                for n in self.nodes:
                    w = sr.plus(w, formula.get_weight(n, sr))
                if sr.is_one(w):
                    unknown = None
                    if formula.get_evidence_value(node) != formula.FALSE:
                        unknown = node
                    for n in self.nodes:
                        if formula.get_evidence_value(n) != formula.FALSE:
                            if unknown is not None:
                                # A second candidate: nothing can be concluded.
                                unknown = None
                                break
                            else:
                                unknown = n
                    if unknown is not None:
                        formula.set_evidence_value(unknown, formula.TRUE)

        if is_extra:
            self.extra_node = node
        else:
            self.nodes.add(node)
        if cr_extra and len(self.nodes) > 1 and self.extra_node is None:
            # If there are two or more choices -> add extra choice node
            self._update_logic(formula)
        return node

    def as_clauses(self):
        """Encode the constraint as CNF clauses.

        Every pair of options (the choices plus the extra node) gets a clause
        forbidding both to be true.  A final clause requires at least one
        option to be true.

        :return: list of clauses (empty when the constraint is trivial)
        """
        if not self.is_nontrivial():
            return []
        nodes = list(self.nodes) + [self.extra_node]
        lines = []
        for i, n in enumerate(nodes):
            for m in nodes[i + 1 :]:
                lines.append((-n, -m))  # mutually exclusive
        lines.append(nodes)  # pick one
        return lines

    def _update_logic(self, formula):
        """Add the extra choice node to the formula (only if nontrivial).

        :param formula: formula to update
        """
        if self.is_nontrivial():
            self.extra_node = formula.add_atom(
                ("%s_extra" % (self.group,)),
                True,
                name=self._choice_term(),
                group=self.group,
                is_extra=True,
            )

    def update_weights(self, weights, semiring):
        """Rewrite the weights of the choices and set the extra node's weight.

        For every choice, the negative weight becomes
        ``semiring.ad_negate(pos, neg)``.  The extra node receives the
        complement of the choices' positive weights.

        :param weights: dictionary of weights, updated in place
        :param semiring: semiring to use for weight transformation
        :raises InvalidValue: if the semiring rejects the complement or the
            complement lies outside the semiring's domain
        """
        if not self.is_nontrivial():
            return
        ws = []
        for n in self.nodes:
            pos, neg = weights.get(n, (semiring.one(), semiring.one()))
            weights[n] = (pos, semiring.ad_negate(pos, neg))
            ws.append(pos)
        name = self._choice_term()
        try:
            complement = semiring.ad_complement(ws, key=name)
            valid = semiring.in_domain(complement)
        except InvalidValue:
            # Re-raised below so that the error carries this AD's location.
            valid = False
        if not valid:
            raise InvalidValue(
                "Sum of annotated disjunction weigths exceeds acceptable value",
                location=self.location,
            )
        weights[self.extra_node] = (
            complement,
            semiring.ad_negate(complement, semiring.one()),
        )

    def copy(self, rename=None):
        """Copy the constraint, renaming nodes through ``rename``.

        The source location is copied too, so that errors raised from the
        copy still point at the original disjunction.
        """
        if rename is None:
            rename = {}
        result = ConstraintAD(self.group)
        result.nodes = {rename.get(x, x) for x in self.nodes}
        result.extra_node = rename.get(self.extra_node, self.extra_node)
        result.location = self.location
        return result

    def check(self, values):
        """Check the constraint

        Values that are None (unknown) are ignored; the remaining values must
        sum to exactly one.

        :param values: dictionary of values for nodes
        :return: True if constraint succeeds, False otherwise
        """
        if self.is_true():
            return True
        elif self.is_false():
            return False
        else:
            actual = [
                values.get(i) for i in self.get_nodes() if values.get(i) is not None
            ]
            return sum(actual) == 1

    def propagate(self, values, weights, node=None):
        """Propagate the constraint over the partial assignment ``values``.

        ``values`` is updated in place.  If exactly one option has value 1.0,
        all other options are set to 0.0.  Otherwise every option whose value
        is not 0.0 receives its weight renormalized over those options.

        :param values: dictionary mapping nodes to values (modified in place)
        :param weights: dictionary mapping nodes to weights
        :param node: if given, only propagate when this node takes part in the
            constraint
        :return:
            - True: constraint satisfied
            - False: constraint violated
            - None: ``node`` does not take part in this constraint
        """
        nodes = self.get_nodes()
        if node is not None and node not in nodes:
            return None
        if self.is_true():
            return True
        elif self.is_false():
            return False

        true_values = [i for i in nodes if values.get(i) == 1.0]
        if len(true_values) == 1:
            # Exactly one option is true: all other options must be false.
            v = true_values[0]
            for i in nodes:
                if i != v:
                    values[i] = 0.0
            return True
        elif len(true_values) > 1:
            # Several options are true: mutual exclusion is violated.
            return False

        false_values = {i for i in nodes if values.get(i) == 0.0}
        remaining = [i for i in nodes if i not in false_values]
        remain = 1.0 - sum(weights[v] for v in false_values)
        if not remaining or remain <= 0.0:
            # No option can still be true.  Either all are false, or the open
            # options carry no weight, and renormalizing would divide by zero.
            return False
        for i in remaining:
            values[i] = weights[i] / remain
        return True
class ClauseConstraint(Constraint):
    """A constraint specifying that a given clause should be true.

    :param nodes: sequence of node keys (literals) forming the clause
    """

    def __init__(self, nodes):
        self.nodes = nodes

    def get_nodes(self):
        """Return the literals of the clause as a new list."""
        return list(self.nodes)

    def as_clauses(self):
        """Return the clause itself as the only CNF clause."""
        return [self.nodes]

    def copy(self, rename=None):
        """Copy the clause, renaming literals through ``rename``.

        The renamed literals are stored in a list.  A lazy ``map`` object
        would be exhausted after its first iteration, leaving the copy broken.
        """
        if rename is None:
            rename = {}
        return ClauseConstraint([rename.get(x, x) for x in self.nodes])

    def __str__(self):
        # Wrap in a 1-tuple so tuple-valued clauses are formatted as a whole
        # instead of being spread over the format arguments (TypeError).
        return "%s is true" % (self.nodes,)
class TrueConstraint(Constraint):
    """Constraint that forces a single node to be true.

    :param node: key of the node that must be true
    """

    def __init__(self, node):
        self.node = node

    def get_nodes(self):
        """Return a one-element list holding the constrained node."""
        return [self.node]

    def as_clauses(self):
        """Return a single unit clause containing the node."""
        unit_clause = [self.node]
        return [unit_clause]

    def copy(self, rename=None):
        """Return a copy, mapping the node through ``rename`` when given."""
        mapping = {} if rename is None else rename
        return TrueConstraint(mapping.get(self.node, self.node))

    def __str__(self):
        return "%s is true" % self.node