"""
problog.constraint - Propositional constraints
----------------------------------------------
Data structures for specifying propositional constraints.
..
Part of the ProbLog distribution.
Copyright 2015 KU Leuven, DTAI Research Group
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .errors import InvalidValue
from .logic import Term, Constant
class Constraint(object):
    """Base class for propositional constraints.

    Subclasses are expected to override :func:`get_nodes`, :func:`as_clauses`
    and :func:`copy`; the versions defined here raise ``NotImplementedError``.
    """

    def get_nodes(self):
        """Return every node that takes part in this constraint.

        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("abstract method")

    def update_weights(self, weights, semiring):
        """Adjust the entries of the weight dictionary to reflect this constraint.

        The base implementation leaves ``weights`` untouched because a typical
        constraint does not influence weights.

        :param weights: dictionary of weights (see result of :func:`LogicFormula.extract_weights`)
        :param semiring: semiring to use for weight transformation
        """
        return None

    def is_true(self):
        """Tell whether the constraint is trivially satisfied.

        :return: ``False`` in the base class
        """
        return False

    def is_false(self):
        """Tell whether the constraint is trivially violated.

        :return: ``False`` in the base class
        """
        return False

    def is_nontrivial(self):
        """Tell whether the constraint is neither trivially true nor trivially false."""
        # ``or`` short-circuits exactly like the chained ``and``: is_false() is
        # only consulted when is_true() reports False.
        return not (self.is_true() or self.is_false())

    def as_clauses(self):
        """Express the constraint in CNF.

        :return: list of clauses where each clause is represented as a list of node keys
        :rtype: list[list[int]]
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("abstract method")

    def copy(self, rename=None):
        """Create a copy of the constraint, renaming nodes along the way.

        :param rename: node rename map (or None if no rename is required)
        :return: copy of the current constraint
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("abstract method")
class ConstraintAD(Constraint):
    """Annotated disjunction constraint (mutually exclusive with weight update).

    The constraint tracks the choice nodes of one group together with an
    optional extra node; :func:`update_weights` assigns that extra node the
    complement of the choice weights.
    """

    def __init__(self, group):
        """Create an empty constraint for the given group.

        :param group: group identifier; ``group[0]`` and the elements of
            ``group[1]`` are used to build the name of the extra choice atom
        """
        self.nodes = set()  # regular choice nodes of the group
        self.group = group
        self.extra_node = None  # node that receives the complement weight
        self.location = None  # reported as ``location`` of InvalidValue errors

    def __str__(self):
        return "annotated_disjunction(%s, %s)" % (list(self.nodes), self.extra_node)

    def get_nodes(self):
        """Get all nodes involved in this constraint.

        A fresh list is returned on every call, so callers never receive (and
        cannot accidentally mutate) the internal node set.

        :return: the choice nodes, followed by the extra node when one is set
        :rtype: list
        """
        nodes = list(self.nodes)
        if self.extra_node:
            nodes.append(self.extra_node)
        return nodes

    def is_true(self):
        """Checks whether the constraint is trivially true (at most one choice node)."""
        return len(self.nodes) <= 1

    def is_false(self):
        """Checks whether the constraint is trivially false (never the case here)."""
        return False

    def add(self, node, formula, cr_extra=True):
        """Add a node to the constraint from the given formula.

        :param node: node to add
        :param formula: formula from which the node is taken
        :param cr_extra: Create an extra_node when required (when it is None and this is the second atom of the group).
        :return: value of the node after constraint propagation
        """
        if node in self.nodes:
            return node

        atom = formula.get_node(node)
        is_extra = atom.is_extra

        # Best effort: remember where the disjunction was defined so that
        # weight errors can point at it.  Nodes or formulas lacking the
        # required attributes are silently skipped.
        try:
            if not self.location and atom.name and atom.name.args:
                if formula.database:
                    self.location = formula.database.lineno(
                        atom.name.args[-1].location
                    )
        except AttributeError:
            pass

        if formula.has_evidence_values() and not is_extra:
            # Propagate constraint: if one of the other nodes is True: this one is false
            for n in self.nodes:
                if formula.get_evidence_value(n) == formula.TRUE:
                    return formula.FALSE
            if formula.get_evidence_value(node) == formula.FALSE:
                return node
            elif formula.get_evidence_value(node) == formula.TRUE:
                # This node is known true: every node added before it is false.
                for n in self.nodes:
                    formula.set_evidence_value(n, formula.FALSE)

            if formula.semiring:
                sr = formula.semiring
                w = formula.get_weight(node, sr)
                for n in self.nodes:
                    w = sr.plus(w, formula.get_weight(n, sr))
                if sr.is_one(w):
                    # The weights add up to one, so one of the nodes has to be
                    # true.  If exactly one node is not known to be false,
                    # that node is set to true.
                    unknown = None
                    if formula.get_evidence_value(node) != formula.FALSE:
                        unknown = node
                    for n in self.nodes:
                        if formula.get_evidence_value(n) != formula.FALSE:
                            if unknown is not None:
                                # Second candidate found: nothing can be concluded.
                                unknown = None
                                break
                            else:
                                unknown = n
                    if unknown is not None:
                        formula.set_evidence_value(unknown, formula.TRUE)

        if is_extra:
            self.extra_node = node
        else:
            self.nodes.add(node)

        if cr_extra and len(self.nodes) > 1 and self.extra_node is None:
            # If there are two or more choices -> add extra choice node
            self._update_logic(formula)
        return node

    def as_clauses(self):
        """Represent the constraint as a list of clauses (CNF form).

        :return: one two-literal tuple per pair of nodes (mutual exclusion)
            followed by the list of all nodes (at least one must hold); an
            empty list when the constraint is trivial
        """
        if not self.is_nontrivial():
            return []
        # NOTE(review): extra_node is included unconditionally; if it is still
        # None at this point the negation below raises TypeError.
        nodes = list(self.nodes) + [self.extra_node]
        lines = []
        for i, n in enumerate(nodes):
            for m in nodes[i + 1:]:
                lines.append((-n, -m))  # mutually exclusive
        lines.append(nodes)  # pick one
        return lines

    def _choice_name(self):
        """Build the ``choice(...)`` term that names the extra node of this group."""
        return Term(
            "choice",
            Constant(self.group[0]),
            Term("e"),
            Term("null"),
            *self.group[1]
        )

    def _update_logic(self, formula):
        """Add extra information to the logic structure of the formula.

        Creates the extra choice atom in the formula when the constraint is
        non-trivial and stores it in ``extra_node``.

        :param formula: formula to update
        """
        if self.is_nontrivial():
            self.extra_node = formula.add_atom(
                ("%s_extra" % (self.group,)),
                True,
                name=self._choice_name(),
                group=self.group,
                is_extra=True,
            )

    def update_weights(self, weights, semiring):
        """Update the weights in the given dictionary according to the constraint.

        For every choice node the negative weight is replaced by
        ``semiring.ad_negate(pos, neg)``; the extra node receives the
        complement of the positive weights.  Nothing happens when the
        constraint is trivial.

        :param weights: dictionary of weights (see result of :func:`LogicFormula.extract_weights`)
        :param semiring: semiring to use for weight transformation
        :raises InvalidValue: if the complement cannot be computed or lies
            outside the domain of the semiring
        """
        if not self.is_nontrivial():
            return

        ws = []
        for n in self.nodes:
            pos, neg = weights.get(n, (semiring.one(), semiring.one()))
            weights[n] = (pos, semiring.ad_negate(pos, neg))
            ws.append(pos)

        name = self._choice_name()
        try:
            complement = semiring.ad_complement(ws, key=name)
            in_domain = semiring.in_domain(complement)
        except InvalidValue:
            # Reported below with this constraint's own message and location.
            in_domain = False
        if not in_domain:
            raise InvalidValue(
                "Sum of annotated disjunction weigths exceeds acceptable value",
                location=self.location,
            )

        weights[self.extra_node] = (
            complement,
            semiring.ad_negate(complement, semiring.one()),
        )

    def copy(self, rename=None):
        """Copy this constraint while applying the given node renaming.

        :param rename: node rename map (or None if no rename is required)
        :return: copy of the current constraint (group, nodes, extra node and
            location)
        """
        if rename is None:
            rename = {}
        result = ConstraintAD(self.group)
        result.nodes = set(rename.get(x, x) for x in self.nodes)
        result.extra_node = rename.get(self.extra_node, self.extra_node)
        # Keep the source location so errors raised by the copy stay traceable.
        result.location = self.location
        return result

    def check(self, values):
        """Check the constraint

        :param values: dictionary of values for nodes
        :return: True if constraint succeeds, False otherwise
        """
        if self.is_true():
            return True
        elif self.is_false():
            return False
        else:
            # Nodes without a value (missing or None) are ignored.
            known = (values.get(i) for i in self.get_nodes())
            actual = [v for v in known if v is not None]
            return sum(actual) == 1

    def propagate(self, values, weights, node=None):
        """Propagate the constraint over the given (partial) node values.

        ``values`` is updated in place: once one node has value 1.0 all other
        nodes are set to 0.0; when no node has value 1.0 the nodes that are
        not 0.0 get their weight renormalised over the remaining mass.

        :param values: dictionary of node values (modified in place)
        :param weights: weights of the nodes, indexed by node
        :param node: node that triggered propagation (or None)
        :return:
            - True: constraint satisfied
            - False: constraint violated
            - None: unknown (the given node is not part of this constraint)
        """
        nodes = self.get_nodes()
        if node is not None and node not in nodes:
            return None
        if self.is_true():
            return True
        if self.is_false():
            return False

        true_values = [i for i in nodes if values.get(i) == 1.0]
        if len(true_values) > 1:
            return False
        if len(true_values) == 1:
            # If there is a true value: set all the others to false
            v = true_values[0]
            for i in nodes:
                if i != v:
                    values[i] = 0.0
            return True

        false_values = {i for i in nodes if values.get(i) == 0.0}
        # NOTE(review): when the false nodes carry all of the weight, remain
        # is 0 and the division below raises ZeroDivisionError.
        remain = 1.0 - sum(weights[v] for v in false_values)
        for i in nodes:
            if i not in false_values:
                values[i] = weights[i] / remain
        return True
class ClauseConstraint(Constraint):
    """A constraint specifying that a given clause should be true."""

    def __init__(self, nodes):
        """Create the constraint.

        :param nodes: node keys that make up the clause
        """
        self.nodes = nodes

    def as_clauses(self):
        """Represent the constraint as a list of clauses (CNF form).

        :return: list holding the single clause of this constraint
        """
        return [self.nodes]

    def copy(self, rename=None):
        """Copy this constraint while applying the given node renaming.

        :param rename: node rename map (or None if no rename is required)
        :return: copy of the current constraint
        """
        if rename is None:
            rename = {}
        # Build a real list: on Python 3 ``map`` returns a one-shot iterator,
        # which would be exhausted after its first traversal and would print
        # as ``<map object ...>``.
        return ClauseConstraint([rename.get(x, x) for x in self.nodes])

    def __str__(self):
        # The 1-tuple keeps a tuple-valued clause from being unpacked by ``%``.
        return "%s is true" % (self.nodes,)
class TrueConstraint(Constraint):
    """Constraint that forces a single node to be true."""

    def __init__(self, node):
        """Create the constraint.

        :param node: node that has to be true
        """
        self.node = node

    def get_nodes(self):
        """Return the constrained node wrapped in a list."""
        return [self.node]

    def as_clauses(self):
        """Return the CNF form: one clause containing only the constrained node."""
        return [[self.node]]

    def copy(self, rename=None):
        """Copy this constraint while applying the given node renaming.

        :param rename: node rename map (or None if no rename is required)
        :return: copy of the current constraint
        """
        # Without a rename map the node is carried over unchanged.
        target = self.node if rename is None else rename.get(self.node, self.node)
        return TrueConstraint(target)

    def __str__(self):
        return "%s is true" % self.node