# Source code for problog.evaluator

"""
problog.evaluator - Common interface for evaluation
----------------------------------------------------

Provides common interface for evaluation of weighted logic formulas.

..
    Part of the ProbLog distribution.

    Copyright 2015 KU Leuven, DTAI Research Group

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
"""
import math

from .core import ProbLogObject, transform_allow_subclass
from .errors import InconsistentEvidenceError, InvalidValue, ProbLogError


class OperationNotSupported(ProbLogError):
    """Raised when a semiring is asked for an operation it does not implement
    (e.g. the default ``Semiring.negate``)."""

    def __init__(self):
        message = "This operation is not supported by this semiring"
        ProbLogError.__init__(self, message)
class Semiring(object):
    """Interface for weight manipulation.

    A semiring is a set R equipped with two binary operations '+' and 'x'.

    The semiring can use different representations for internal values and external values.
    For example, the LogProbability semiring uses probabilities [0, 1] as external values and
    uses the logarithm of these probabilities as internal values.

    Most methods take and return internal values. The exceptions are:

        - value, pos_value, neg_value: transform an external value to an internal value
        - result: transform an internal to an external value
        - result_zero, result_one: return an external value

    """

    def one(self):
        """Returns the identity element of the multiplication."""
        raise NotImplementedError()

    def is_one(self, value):
        """Tests whether the given value is the identity element of the multiplication."""
        # one() must be *called*: comparing against the bound method is always False,
        # which made the default normalize() reject even the identity constant.
        return value == self.one()

    def zero(self):
        """Returns the identity element of the addition."""
        raise NotImplementedError()

    def is_zero(self, value):
        """Tests whether the given value is the identity element of the addition."""
        return value == self.zero()

    def plus(self, a, b):
        """Computes the addition of the given values."""
        raise NotImplementedError()

    def times(self, a, b):
        """Computes the multiplication of the given values."""
        raise NotImplementedError()

    def negate(self, a):
        """Returns the negation. This operation is optional.

        For example, for probabilities return 1-a.

        :raise OperationNotSupported: if the semiring does not support this operation
        """
        raise OperationNotSupported()

    def value(self, a):
        """Transform the given external value into an internal value."""
        return float(a)

    def result(self, a, formula=None):
        """Transform the given internal value into an external value."""
        return a

    def normalize(self, a, z):
        """Normalizes the given value with the given normalization constant.

        For example, for probabilities, returns a/z.

        :raise OperationNotSupported: if z is not one and the semiring does not support
            this operation
        """
        if self.is_one(z):
            return a
        else:
            raise OperationNotSupported()

    def pos_value(self, a, key=None):
        """Extract the positive internal value for the given external value."""
        return self.value(a)

    def neg_value(self, a, key=None):
        """Extract the negative internal value for the given external value."""
        return self.negate(self.value(a))

    def result_zero(self):
        """Give the external representation of the identity element of the addition."""
        return self.result(self.zero())

    def result_one(self):
        """Give the external representation of the identity element of the multiplication."""
        return self.result(self.one())

    def is_dsp(self):
        """Indicates whether this semiring requires solving a disjoint sum problem."""
        return False

    def is_nsp(self):
        """Indicates whether this semiring requires solving a neutral sum problem."""
        return False

    def in_domain(self, a):
        """Checks whether the given (internal) value is valid."""
        return True

    def ad_complement(self, ws, key=None):
        """Return the negation of the sum of the given internal weights.

        :param ws: iterable of internal weights (summed starting from zero())
        :param key: not used by this implementation; presumably identifies the
            annotated disjunction for subclasses (TODO confirm)
        :raise OperationNotSupported: if the semiring does not support negation
        """
        s = self.zero()
        for w in ws:
            s = self.plus(s, w)
        return self.negate(s)

    def true(self, key=None):
        """Handle weight for deterministically true.

        :return: tuple (positive weight, negative weight) = (one, zero)
        """
        return self.one(), self.zero()

    def false(self, key=None):
        """Handle weight for deterministically false.

        :return: tuple (positive weight, negative weight) = (zero, one)
        """
        return self.zero(), self.one()

    def to_evidence(self, pos_weight, neg_weight, sign):
        """Convert the pos. and neg. weight (internal repr.) of a literal into the case
        where the literal is evidence.

        Note that the literal can be a negative atom regardless of the given sign.

        :param pos_weight: The current positive weight of the literal.
        :param neg_weight: The current negative weight of the literal.
        :param sign: Denotes whether the literal or its negation is evidence. sign > 0
            denotes the literal is evidence, otherwise its negation is evidence.
            Note: The literal itself can also still be a negative atom.
        :returns: A tuple of the positive and negative weight as if the literal was
            evidence. For example, for probability, returns (self.one(), self.zero())
            if sign else (self.zero(), self.one())
        """
        return (self.one(), self.zero()) if sign > 0 else (self.zero(), self.one())

    def ad_negate(self, pos_weight, neg_weight):
        """Negation in the context of an annotated disjunction.

        e.g. in a probabilistic context for 0.2::a ; 0.8::b, the negative label for both
        a and b is 1.0 such that model {a,-b} = 0.2 * 1.0 and {-a,b} = 1.0 * 0.8.
        For a, pos_weight would be 0.2 and neg_weight could be 0.8. The returned value
        is 1.0.

        :param pos_weight: The current positive weight of the literal (e.g. 0.2 or 0.8).
            Internal representation.
        :param neg_weight: The current negative weight of the literal (e.g. 0.8 or 0.2).
            Internal representation.
        :return: neg_weight corrected based on the given pos_weight, given the ad context
            (e.g. 1.0). Internal representation.
        """
        return self.one()

    @classmethod
    def create(cls, *, engine, database, **kwargs):
        """Create an instance of this semiring class. Used for sub-queries.

        :param engine: Engine in use.
        :param database: Database in use.
        :param kwargs: Keyword arguments passed from subquery
        """
        raise NotImplementedError()
class SemiringProbability(Semiring):
    """Implementation of the semiring interface for probabilities.

    Internal and external values are both plain floats in [0, 1].
    """

    def one(self):
        return 1.0

    def zero(self):
        return 0.0

    def is_one(self, value):
        # Tolerant comparison to absorb floating point round-off.
        return 1.0 - 1e-12 < value < 1.0 + 1e-12

    def is_zero(self, value):
        return -1e-12 < value < 1e-12

    def plus(self, a, b):
        return a + b

    def times(self, a, b):
        return a * b

    def negate(self, a):
        return 1.0 - a

    def normalize(self, a, z):
        return a / z

    def value(self, a):
        """Convert an external value to a probability.

        :raise InvalidValue: if ``float(a)`` lies outside [0, 1] (1e-9 tolerance)
        """
        v = float(a)
        if 0.0 - 1e-9 <= v <= 1.0 + 1e-9:
            return v
        else:
            # Plain numbers/strings have no ``location`` attribute; reading it
            # directly raised AttributeError and hid the InvalidValue.
            raise InvalidValue(
                "Not a valid value for this semiring: '%s'" % a,
                location=getattr(a, "location", None),
            )

    def is_dsp(self):
        """Indicates whether this semiring requires solving a disjoint sum problem."""
        return True

    def in_domain(self, a):
        return 0.0 - 1e-9 <= a <= 1.0 + 1e-9

    @classmethod
    def create(cls, *, engine, database, **kwargs):
        return cls()
class SemiringLogProbability(SemiringProbability):
    """Implementation of the semiring interface for probabilities with logspace calculations.

    External values are probabilities; internal values are their natural logarithms
    (zero() is -inf, one() is 0.0).
    """

    inf, ninf = float("inf"), float("-inf")

    def one(self):
        return 0.0

    def zero(self):
        return self.ninf

    def is_zero(self, value):
        return value <= -1e100

    def is_one(self, value):
        return -1e-12 < value < 1e-12

    def plus(self, a, b):
        # log(exp(a) + exp(b)), factored around the larger operand so that the
        # exponent passed to exp() is never positive.
        if a < b:
            if a == self.ninf:
                return b
            return b + math.log1p(math.exp(a - b))
        else:
            if b == self.ninf:
                return a
            return a + math.log1p(math.exp(b - a))

    def times(self, a, b):
        return a + b

    def negate(self, a):
        """Return log(1 - exp(a)).

        :raise InvalidValue: if ``a`` is not a valid log-probability
        """
        if not self.in_domain(a):
            raise InvalidValue("Not a valid value for this semiring: '%s'" % a)
        if a > -1e-10:
            # exp(a) is (numerically) 1, so 1 - exp(a) is 0 -> log gives zero().
            return self.zero()
        return math.log1p(-math.exp(a))

    def value(self, a):
        """Convert an external probability to its logarithm.

        :raise InvalidValue: if ``float(a)`` lies outside [0, 1] (1e-9 tolerance)
        """
        v = float(a)
        if -1e-9 <= v < 1e-9:
            return self.zero()
        else:
            if 0.0 - 1e-9 <= v <= 1.0 + 1e-9:
                return math.log(v)
            else:
                # Plain numbers/strings have no ``location`` attribute; reading it
                # directly raised AttributeError and hid the InvalidValue.
                raise InvalidValue(
                    "Not a valid value for this semiring: '%s'" % a,
                    location=getattr(a, "location", None),
                )

    def result(self, a, formula=None):
        return math.exp(a)

    def normalize(self, a, z):
        # Assumes Z is in log
        return a - z

    def is_dsp(self):
        """Indicates whether this semiring requires solving a disjoint sum problem."""
        return True

    def in_domain(self, a):
        return a <= 1e-12
class SemiringSymbolic(Semiring):
    """Implementation of the semiring interface for probabilities using symbolic calculations.

    Internal values are strings holding arithmetic expressions; "0" and "1" are the
    additive and multiplicative identities and are simplified away where possible.
    """

    def one(self):
        return "1"

    def zero(self):
        return "0"

    def plus(self, a, b):
        if a == "0":
            return b
        elif b == "0":
            return a
        else:
            return "(%s + %s)" % (a, b)

    def times(self, a, b):
        if a == "0" or b == "0":
            return "0"
        elif a == "1":
            return b
        elif b == "1":
            return a
        else:
            return "%s*%s" % (a, b)

    def negate(self, a):
        if a == "0":
            return "1"
        elif a == "1":
            return "0"
        else:
            return "(1-%s)" % a

    def value(self, a):
        return str(a)

    def normalize(self, a, z):
        """Return the expression ``a / z``.

        A denominator containing an unparenthesised product or quotient (as built by
        times() or normalize()) is wrapped in parentheses: otherwise ``a / b*c`` would
        read as ``(a / b)*c``.
        """
        if z == "1":
            return a
        denominator = str(z)
        if "*" in denominator or "/" in denominator:
            denominator = "(%s)" % denominator
        return "%s / %s" % (a, denominator)

    def is_dsp(self):
        """Indicates whether this semiring requires solving a disjoint sum problem."""
        return True

    @classmethod
    def create(cls, *, engine, database, **kwargs):
        return cls()
class Evaluatable(ProbLogObject):
    def evidence_all(self):
        """Iterate over all evidence of this formula.

        Must be implemented by subclasses. get_evaluator() consumes it as
        ``(name, index, value)`` triples.
        """
        raise NotImplementedError()

    def _create_evaluator(self, semiring, weights, **kwargs):
        """Create a new evaluator.

        :param semiring: semiring to use
        :param weights: weights to use (replace weights defined in formula)
        :return: evaluator
        :rtype: Evaluator
        """
        raise NotImplementedError("Evaluatable._create_evaluator is an abstract method")

    def get_evaluator(
        self, semiring=None, evidence=None, weights=None, keep_evidence=False, **kwargs
    ):
        """Get an evaluator for computing queries on this formula.

        It creates a new evaluator and initializes it with the given or predefined
        evidence.

        :param semiring: semiring to use (default: SemiringLogProbability)
        :param evidence: evidence values (override values defined in formula)
        :type evidence: dict(Term, bool)
        :param weights: weights to use
        :param keep_evidence: when ``evidence`` is given, still add the formula's own
            evidence for names that do not appear in ``evidence``
        :param kwargs: forwarded to _create_evaluator
        :return: evaluator for this formula
        :raise InconsistentEvidenceError: if evidence contradicts a deterministic node
        """
        if semiring is None:
            semiring = SemiringLogProbability()
        evaluator = self._create_evaluator(semiring, weights, **kwargs)

        for ev_name, ev_index, ev_value in self.evidence_all():
            if ev_index == 0 and ev_value > 0:
                pass  # true evidence is deterministically true
            elif ev_index is None and ev_value < 0:
                pass  # false evidence is deterministically false
            elif ev_index == 0 and ev_value < 0:
                raise InconsistentEvidenceError(
                    source="evidence(" + str(ev_name) + ",false)"
                )  # true evidence is false
            elif ev_index is None and ev_value > 0:
                raise InconsistentEvidenceError(
                    source="evidence(" + str(ev_name) + ",true)"
                )  # false evidence is true
            elif evidence is None:
                # ev_value == 0 would yield literal 0 (ev_value * ev_index): skip it.
                if ev_value != 0:
                    evaluator.add_evidence(ev_value * ev_index)
            else:
                # Only the lookup belongs in the try: a KeyError raised by
                # add_evidence must not be mistaken for "name not overridden".
                try:
                    value = evidence[ev_name]
                except KeyError:
                    # Same ev_value == 0 guard as the evidence-is-None branch.
                    if keep_evidence and ev_value != 0:
                        evaluator.add_evidence(ev_value * ev_index)
                    continue
                if value is None:
                    pass
                elif value:
                    evaluator.add_evidence(ev_index)
                else:
                    evaluator.add_evidence(-ev_index)

        evaluator.propagate()
        return evaluator

    def evaluate(
        self, index=None, semiring=None, evidence=None, weights=None, **kwargs
    ):
        """Evaluate a set of nodes.

        :param index: node to evaluate (default: all queries)
        :param semiring: use the given semiring
        :param evidence: use the given evidence values (overrides formula)
        :param weights: use the given weights (overrides formula)
        :return: The result of the evaluation expressed as an external value of the
            semiring. If index is ``None`` (all queries) then the result is a dictionary
            of name to value.
        """
        evaluator = self.get_evaluator(semiring, evidence, weights, **kwargs)
        if index is None:
            # Probability of each labeled node given the evidence.
            return {
                name: evaluator.evaluate(node)
                for name, node, _label in evaluator.formula.labeled()
            }
        return evaluator.evaluate(index)
@transform_allow_subclass
class EvaluatableDSP(Evaluatable):
    """Interface for evaluatable formulae."""

    def __init__(self):
        # Explicit base call (not super()) so cooperative MRO of subclasses that
        # also inherit from other formula classes is not affected.
        Evaluatable.__init__(self)
class Evaluator(object):
    """Generic evaluator.

    Stores the formula, the semiring and a list of evidence literals. Actual
    evaluation (propagate, evaluate, ...) is left to subclasses.
    """

    # noinspection PyUnusedLocal
    def __init__(self, formula, semiring, weights, **kwargs):
        self.formula = formula
        self.given_weights = weights
        self.weights = {}
        self.__evidence = []
        self.__semiring = semiring

    @property
    def semiring(self):
        """Semiring used by this evaluator."""
        return self.__semiring

    def propagate(self):
        """Propagate changes in weight or evidence values."""
        raise NotImplementedError("Evaluator.propagate() is an abstract method.")

    def evaluate(self, index):
        """Compute the value of the given node."""
        raise NotImplementedError("abstract method")

    def evaluate_evidence(self):
        """Compute the value of the evidence (abstract)."""
        raise NotImplementedError("abstract method")

    def evaluate_fact(self, node):
        """Evaluate fact.

        :param node: fact to evaluate
        :return: weight of the fact (as semiring result value)
        """
        raise NotImplementedError("abstract method")

    def add_evidence(self, node):
        """Add evidence (a node index; negative for negated evidence)."""
        self.__evidence.append(node)

    def has_evidence(self):
        """Checks whether there is active evidence."""
        return bool(self.__evidence)

    def set_evidence(self, index, value):
        """Set value for evidence node.

        :param index: index of evidence node
        :param value: value of evidence. True if the evidence is positive, False otherwise.
        """
        raise NotImplementedError("abstract method")

    def set_weight(self, index, pos, neg):
        """Set weight of a node.

        :param index: index of node
        :param pos: positive weight (as semiring internal value)
        :param neg: negative weight (as semiring internal value)
        """
        raise NotImplementedError("abstract method")

    def clear_evidence(self):
        """Clear all evidence."""
        # Rebind rather than clear in place: iterators returned by evidence()
        # keep seeing the old list.
        self.__evidence = []

    def evidence(self):
        """Iterate over evidence."""
        return iter(self.__evidence)
class FormulaEvaluator(object):
    """Standard evaluator for boolean formula.

    Fact weights are looked up in ``_fact_weights`` (index -> (positive, negative));
    weights of other nodes are computed from their children and cached in
    ``_computed_weights``.
    """

    def __init__(self, formula, semiring, weights=None):
        self._computed_weights = {}
        self._semiring = semiring
        self._formula = formula
        self._fact_weights = {}
        if weights is not None:
            self.set_weights(weights)

    @property
    def semiring(self):
        return self._semiring

    @property
    def formula(self):
        return self._formula

    def set_weights(self, weights):
        """Set known weights (and drop all cached computed weights).

        :param weights: dictionary of weights
        :return:
        """
        self._computed_weights.clear()
        self._fact_weights = weights

    def get_weight(self, index):
        """Get the weight of the node with the given index.

        :param index: integer or formula.TRUE or formula.FALSE
        :return: weight of the node
        """
        formula = self.formula
        if index == formula.TRUE:
            return self.semiring.one()
        if index == formula.FALSE:
            return self.semiring.zero()

        if index < 0:
            fact = self._fact_weights.get(abs(index))
            if fact is not None:
                return fact[1]
            # Not a fact: this only works if the semiring supports negation!
            return self.semiring.negate(self.get_weight(-index))

        fact = self._fact_weights.get(index)
        if fact is not None:
            return fact[0]
        cached = self._computed_weights.get(index)
        if cached is None:
            cached = self.compute_weight(index)
            self._computed_weights[index] = cached
        return cached

    def propagate(self):
        self._fact_weights = self.formula.extract_weights(
            self.semiring, self._fact_weights
        )

    def evaluate(self, index):
        return self.semiring.result(self.get_weight(index), self.formula)

    def compute_weight(self, index):
        """Compute the weight of the node with the given index.

        :param index: integer or formula.TRUE or formula.FALSE
        :return: weight of the node
        :raise TypeError: for a node type other than atom/conj/disj
        """
        formula = self.formula
        if index == formula.TRUE:
            return self.semiring.one()
        if index == formula.FALSE:
            return self.semiring.zero()

        node = formula.get_node(abs(index))
        kind = type(node).__name__
        if kind == "atom":
            return self.semiring.one()

        # Children are evaluated before the node type is checked.
        child_weights = [self.get_weight(child) for child in node.children]
        if kind == "conj":
            combine, acc = self.semiring.times, self.semiring.one()
        elif kind == "disj":
            combine, acc = self.semiring.plus, self.semiring.zero()
        else:
            raise TypeError("Unexpected node type: '%s'." % kind)
        for weight in child_weights:
            acc = combine(acc, weight)
        return acc
class FormulaEvaluatorNSP(FormulaEvaluator):
    """Evaluator for boolean formula that addresses the Neutral Sum Problem.

    Every weight is paired with the set of absolute fact indices it depends on; facts
    that a (sub)formula does not mention are multiplied in as (positive + negative).
    """

    def __init__(self, formula, semiring, weights=None):
        FormulaEvaluator.__init__(self, formula, semiring, weights)

    def _times_unused(self, weight, unused):
        """Multiply ``weight`` by plus(pos, neg) of each fact index in ``unused``."""
        semiring = self.semiring
        for fact in unused:
            pos_weight, _ = self.get_weight(fact)
            neg_weight, _ = self.get_weight(-fact)
            weight = semiring.times(weight, semiring.plus(pos_weight, neg_weight))
        return weight

    def get_weight(self, index):
        """Get the weight of the node with the given index.

        :param index: integer or formula.TRUE or formula.FALSE
        :return: weight of the node and the set of abs(literals) involved
        """
        formula = self.formula
        if index == formula.TRUE:
            return self.semiring.one(), set()
        if index == formula.FALSE:
            return self.semiring.zero(), set()

        if index < 0:
            fact = self._fact_weights.get(-index)
            if fact is not None:
                return fact[1], {abs(index)}
            # Not a fact: this only works if the semiring supports negation!
            positive, used = self.get_weight(-index)
            return self.semiring.negate(positive), used

        fact = self._fact_weights.get(index)
        if fact is not None:
            return fact[0], {abs(index)}
        cached = self._computed_weights.get(index)
        if cached is None:
            cached = self.compute_weight(index)
            self._computed_weights[index] = cached
        return cached

    def evaluate(self, index):
        weight, used = self.get_weight(index)
        unused = set(self._fact_weights.keys()) - used
        return self.semiring.result(self._times_unused(weight, unused), self.formula)

    def compute_weight(self, index):
        """Compute the weight of the node with the given index.

        :param index: integer or formula.TRUE or formula.FALSE
        :return: weight of the node and the set of abs(literals) involved
        :raise TypeError: for a node type other than atom/conj/disj
        """
        node = self.formula.get_node(index)
        kind = type(node).__name__
        if kind == "atom":
            return self.semiring.one(), {index}

        child_results = [self.get_weight(child) for child in node.children]
        if kind == "conj":
            acc = self.semiring.one()
            used = set()
            for child_weight, child_used in child_results:
                used |= child_used
                acc = self.semiring.times(acc, child_weight)
            return acc, used
        if kind == "disj":
            acc = self.semiring.zero()
            used = set()
            for _, child_used in child_results:
                used |= child_used
            # Bring every disjunct to the same variable set before adding.
            for child_weight, child_used in child_results:
                acc = self.semiring.plus(
                    acc, self._times_unused(child_weight, used - child_used)
                )
            return acc, used
        raise TypeError("Unexpected node type: '%s'." % kind)