utils.py 25 ko
Newer Older
"""Provides some utilities widely used by other modules."""
norvig's avatar
norvig a validé

import bisect
Peter Norvig's avatar
Peter Norvig a validé
import collections
Peter Norvig's avatar
Peter Norvig a validé
import collections.abc
Peter Norvig's avatar
Peter Norvig a validé
import heapq
import operator
import os.path
import random
Tarun Kumar Vangani's avatar
Tarun Kumar Vangani a validé
import math
C.G.Vedant's avatar
C.G.Vedant a validé
import functools
Donato Meoli's avatar
Donato Meoli a validé
from statistics import mean

import numpy as np
from itertools import chain, combinations

# ______________________________________________________________________________
Peter Norvig's avatar
Peter Norvig a validé
# Functions on Sequences and Iterables
MircoT's avatar
MircoT a validé

Peter Norvig's avatar
Peter Norvig a validé
def sequence(iterable):
    """Return iterable as-is when it is already a Sequence; otherwise wrap it in a 1-tuple.

    A non-sequence argument is wrapped as a single element, not expanded,
    e.g. sequence({1, 2}) == ({1, 2},).
    """
    if isinstance(iterable, collections.abc.Sequence):
        return iterable
    return (iterable,)
Robert Hönig's avatar
Robert Hönig a validé
    """Return a copy of seq (or string) with all occurrences of item removed."""
withal's avatar
withal a validé
        return seq.replace(item, '')
    elif isinstance(seq, set):
        rest = seq.copy()
        rest.remove(item)
        return rest
withal's avatar
withal a validé
        return [x for x in seq if x != item]
def unique(seq):
    """Remove duplicate elements from seq. Assumes hashable elements.

    Returns a new list containing the first occurrence of each element, in the
    order in which the elements first appear in seq.
    """
    # dict keys are unique and keep insertion order, so this deduplicates stably.
    return list(dict.fromkeys(seq))
def count(seq):
    """Return how many items of seq are truthy."""
    return sum(1 for item in seq if item)
Peter Norvig's avatar
Peter Norvig a validé
def multimap(items):
    """Given (key, val) pairs, return {key: [val, ....], ...}.

    Values for each key are kept in the order they are encountered. The result
    is a plain dict, so looking up an absent key raises KeyError rather than
    silently inserting an empty list.
    """
    result = collections.defaultdict(list)
    for (key, val) in items:
        result[key].append(val)
    return dict(result)

Peter Norvig's avatar
Peter Norvig a validé
def multimap_items(mmap):
    """Generate every (key, val) pair held in the multimap mmap, key by key."""
    for key, values in mmap.items():
        yield from ((key, value) for value in values)
    """Return the product of the numbers, e.g. product([2, 3, 10]) == 60"""
    result = 1
    for x in numbers:
        result *= x
utk1610's avatar
utk1610 a validé
    return result
def first(iterable, default=None):
    """Return the first element of an iterable; or default if it is empty."""
    return next(iter(iterable), default)
Chipe1's avatar
Chipe1 a validé
def is_in(elt, seq):
    """Like `elt in seq`, but tests identity ('is') rather than equality ('==')."""
    for candidate in seq:
        if candidate is elt:
            return True
    return False

def mode(data):
    """Return the most common data item. If there are ties, return any one of them.

    Raises ValueError if data is empty.
    """
    counts = collections.Counter(data)
    if not counts:
        # Same exception type as before, but with a meaningful message.
        raise ValueError('mode() arg is an empty sequence')
    # most_common(1) returns a one-element list holding an (item, frequency) pair.
    return counts.most_common(1)[0][0]

def powerset(iterable):
    """Return every non-empty subset of iterable as a tuple, smallest subsets first.

    powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    elements = list(iterable)
    subsets = []
    for size in range(1, len(elements) + 1):
        subsets.extend(combinations(elements, size))
    return subsets
def extend(s, var, val):
    """Return a shallow copy of the mapping s with s[var] set to val; s itself is untouched."""
    extended = s.copy()
    extended[var] = val
    return extended


Peter Norvig's avatar
Peter Norvig a validé
# ______________________________________________________________________________
# argmin and argmax

Peter Norvig's avatar
Peter Norvig a validé
def identity(x):
    """Return x unchanged; used as the default key function below."""
    return x


# With a key= argument, min/max already return the element that minimizes or
# maximizes key, i.e. they act as argmin/argmax.
argmin = min
argmax = max
Peter Norvig's avatar
Peter Norvig a validé
def argmin_random_tie(seq, key=identity):
    """Return an element of seq with the smallest key value; ties are broken at random."""
    candidates = shuffled(seq)
    return argmin(candidates, key=key)
Peter Norvig's avatar
Peter Norvig a validé
def argmax_random_tie(seq, key=identity):
    """Return an element of seq with the highest key(element) score; break ties at random."""
    # Shuffling first makes max() pick a random element among the tied maxima.
    return argmax(shuffled(seq), key=key)
Peter Norvig's avatar
Peter Norvig a validé
def shuffled(iterable):
    """Return a new list with the elements of iterable in random order.

    The input itself is not modified.
    """
    items = list(iterable)
    random.shuffle(items)
    return items
# ______________________________________________________________________________
# Statistical and mathematical functions

def histogram(values, mode=0, bin_function=None):
    """Return a list of (value, count) pairs, summarizing the input values.

    Sorted by increasing value, or if mode=1 (any truthy mode), by decreasing
    count, with ties broken by decreasing value.
    If bin_function is given, map it over values first.
    """
    if bin_function:
        values = map(bin_function, values)
    bins = collections.Counter(values)
    if mode:
        return sorted(bins.items(), key=lambda x: (x[1], x[0]), reverse=True)
    return sorted(bins.items())
    """Return the sum of the element-wise product of vectors X and Y."""
    return sum(x * y for x, y in zip(X, Y))
def element_wise_product(X, Y):
    """Return vector as an element-wise product of vectors X and Y.

    Raises AssertionError if X and Y differ in length.
    """
    assert len(X) == len(Y)
    return [x * y for x, y in zip(X, Y)]
def matrix_multiplication(X_M, *Y_M):
    """Return a matrix as a matrix-multiplication of X_M and arbitrary number of matrices *Y_M.

    Matrices are lists of rows. The product is taken left to right:
    X_M @ Y_M[0] @ Y_M[1] @ ... With no extra matrices, X_M is returned.

    >>> matrix_multiplication([[1, 2, 3], [2, 3, 4]], [[3, 4], [1, 2], [1, 0]])
    [[8, 8], [13, 14]]
    """

    def _mat_mult(X_M, Y_M):
        """Return the product of two matrices X_M (a x b) and Y_M (b x c)."""
        assert len(X_M[0]) == len(Y_M)
        columns = list(zip(*Y_M))
        return [[sum(x * y for x, y in zip(row, col)) for col in columns]
                for row in X_M]

    result = X_M
    for Y in Y_M:
        result = _mat_mult(result, Y)
    return result

def vector_to_diagonal(v):
    """Return the square matrix (list of rows) with v[i] at position (i, i) and 0 elsewhere."""
    size = len(v)
    return [[v[row] if row == col else 0 for col in range(size)]
            for row in range(size)]
    """Component-wise addition of two vectors."""
    return tuple(map(operator.add, a, b))

def scalar_vector_product(X, Y):
    """Return the list obtained by multiplying every element of vector Y by scalar X."""
    return list(map(lambda y: X * y, Y))

def scalar_matrix_product(X, Y):
    """Return the matrix obtained by multiplying every entry of matrix Y by scalar X."""
    return [[X * entry for entry in row] for row in Y]
def inverse_matrix(X):
    """Inverse a given square matrix of size 2x2.

    Uses the closed form inv([[a, b], [c, d]]) = (1 / det) * [[d, -b], [-c, a]]
    with det = a*d - b*c, and returns the result as a list of rows.
    Raises AssertionError if X is not 2x2 or is singular (det == 0).
    """
    assert len(X) == 2
    assert len(X[0]) == 2
    det = X[0][0] * X[1][1] - X[0][1] * X[1][0]
    assert det != 0
    scale = 1.0 / det
    adjugate = [[X[1][1], -X[0][1]], [-X[1][0], X[0][0]]]
    inv_mat = [[scale * entry for entry in row] for row in adjugate]
    return inv_mat
    """Return true with probability p."""
    return p > random.uniform(0.0, 1.0)

def weighted_sample_with_replacement(n, seq, weights):
    """Pick n samples from seq at random, with replacement, with the
    probability of each element in proportion to its corresponding
    weight."""
    # Build the sampler once so its cumulative-weight table is reused for all n draws.
    sample = weighted_sampler(seq, weights)
    return [sample() for _ in range(n)]
withal's avatar
withal a validé

withal's avatar
withal a validé
def weighted_sampler(seq, weights):
    """Return a random-sample function that picks from seq weighted by weights.

    Each call of the returned function picks seq[i] with probability
    weights[i] / sum(weights).
    """
    totals = []
    for w in weights:
        totals.append(w + totals[-1] if totals else w)
    # random.uniform(0, total) may return total itself; an unbounded bisect
    # would then yield len(seq) and raise IndexError, so cap the search at the
    # last index (the same approach random.choices uses).
    last = len(totals) - 1
    return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]), 0, last)]
Peter Norvig's avatar
Peter Norvig a validé
def weighted_choice(choices):
    """A weighted version of random.choice.

    choices is a sequence of (choice, weight) pairs; returns the selected pair,
    picked with probability proportional to its weight, or None if choices is
    empty.
    """
    # NOTE: should be replaced by random.choices if we port to Python 3.6
    total = sum(w for _, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c, w
        upto += w
    return None
Anthony Marakis's avatar
Anthony Marakis a validé
def rounder(numbers, d=4):
    """Round a single number, or sequence of numbers, to d decimal places.

    Containers (list, tuple, set, ...) are rebuilt with their own type, and
    nested containers are rounded recursively.
    """
    if isinstance(numbers, (int, float)):
        return round(numbers, d)
    else:
        constructor = type(numbers)  # Can be list, set, tuple, etc.
        return constructor(rounder(n, d) for n in numbers)
    """The argument is a string; convert to a number if possible, or strip it."""
withal's avatar
withal a validé
        return int(x)
withal's avatar
withal a validé
            return float(x)
withal's avatar
withal a validé
            return str(x).strip()
Donato Meoli's avatar
Donato Meoli a validé
def euclidean_distance(X, Y):
    """Return the Euclidean (L2) distance between X and Y, pairing elements as zip() does."""
    squared_gaps = [(x - y) ** 2 for x, y in zip(X, Y)]
    return math.sqrt(sum(squared_gaps))


def cross_entropy_loss(X, Y):
    """Return the mean binary cross-entropy of predictions Y against targets X.

    Both log terms are evaluated for every pair, so each prediction must lie
    strictly between 0 and 1 (otherwise math.log raises ValueError). An empty X
    raises ZeroDivisionError.
    """
    n = len(X)
    total = 0
    for x, y in zip(X, Y):
        total += x * math.log(y) + (1 - x) * math.log(1 - y)
    return (-1.0 / n) * total


def rms_error(X, Y):
    """Return the root-mean-square error between X and Y."""
    squared_errors = ((x - y) ** 2 for x, y in zip(X, Y))
    return math.sqrt(mean(squared_errors))


def ms_error(X, Y):
    """Return the mean squared error between X and Y."""
    return mean([(a - b) ** 2 for a, b in zip(X, Y)])


def mean_error(X, Y):
    """Return the mean absolute error between X and Y."""
    return mean([abs(a - b) for a, b in zip(X, Y)])


def manhattan_distance(X, Y):
    """Return the Manhattan (L1) distance between X and Y."""
    total = 0
    for a, b in zip(X, Y):
        total += abs(a - b)
    return total


def mean_boolean_error(X, Y):
    """Return the fraction of positions at which X and Y disagree."""
    return mean([a != b for a, b in zip(X, Y)])
Donato Meoli's avatar
Donato Meoli a validé


def hamming_distance(X, Y):
    """Return the number of positions at which X and Y differ."""
    return sum(map(operator.ne, X, Y))


def normalize(dist):
    """Scale the numbers in dist so that they sum to 1.0.

    A dict is normalized in place (and also returned); each resulting value is
    asserted to lie in [0, 1]. Any other iterable of numbers yields a new list.
    Raises ZeroDivisionError if the values sum to zero.
    """
    if not isinstance(dist, dict):
        total = sum(dist)
        return [(n / total) for n in dist]
    total = sum(dist.values())
    for key in dist:
        dist[key] = dist[key] / total
        assert 0 <= dist[key] <= 1  # Probabilities must be between 0 and 1
    return dist
C.G.Vedant's avatar
C.G.Vedant a validé
def norm(X, n=2):
    """Return the n-norm of vector X, computed as (sum of x ** n) ** (1 / n).

    No absolute value is taken, so for odd n the terms keep their sign; in
    particular norm(X, 1) is the plain (signed) sum of X.
    """
    total = sum(x ** n for x in X)
    return total ** (1 / n)
def random_weights(min_value, max_value, num_weights):
    """Return num_weights values drawn with random.uniform(min_value, max_value)."""
    weights = []
    for _ in range(num_weights):
        weights.append(random.uniform(min_value, max_value))
    return weights


withal's avatar
withal a validé
def clip(x, lowest, highest):
    """Return x clipped to the range [lowest..highest]."""
    return max(lowest, min(x, highest))


def sigmoid_derivative(value):
    """Return value * (1 - value); presumably value is a sigmoid output, giving the sigmoid's slope."""
    complement = 1 - value
    return value * complement


def sigmoid(x):
    """Return activation value of x with sigmoid function, 1 / (1 + e**-x).

    Evaluated in a numerically stable form: the argument passed to math.exp is
    never positive, so inputs of large magnitude (e.g. x = -1000) no longer
    raise OverflowError.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For negative x, e**-x could overflow; use the equivalent e**x / (1 + e**x).
    z = math.exp(x)
    return z / (1 + z)

def relu_derivative(value):
    """Return the ReLU derivative at value: 1 for value > 0, else 0."""
    return 1 if value > 0 else 0

def elu(x, alpha=0.01):
    """Return the ELU activation: x when x > 0, otherwise alpha * (e**x - 1)."""
    if x > 0:
        return x
    return alpha * (math.exp(x) - 1)
    return 1 if value > 0 else alpha * math.exp(value)

def tanh(x):
    """Return the hyperbolic tangent activation of x (a scalar or a NumPy array)."""
    return np.tanh(x)

def tanh_derivative(value):
    """Return 1 - value**2; presumably value is a tanh output, giving the tanh slope."""
    squared = value ** 2
    return 1 - squared
    return x if x > 0 else alpha * x


def leaky_relu_derivative(value, alpha=0.01):
    """Return the leaky-ReLU slope at value: 1 when value > 0, otherwise alpha."""
    if value > 0:
        return 1
    return alpha
Nouman Ahmed's avatar
Nouman Ahmed a validé
def relu(x):
    """Return the ReLU activation of x: x when x > 0, otherwise 0."""
    return x if x > 0 else 0

def relu_derivative(value):
    """Return the ReLU slope at value: 1 when value > 0, otherwise 0."""
    if value > 0:
        return 1
    return 0
def step(x):
    """Return the step activation of x: 1 when x >= 0, otherwise 0."""
    return int(x >= 0)
def gaussian(mean, st_dev, x):
    """Return the normal probability density at x for the given mean and standard deviation."""
    coefficient = 1 / (math.sqrt(2 * math.pi) * st_dev)
    z = float(x - mean) / st_dev
    return coefficient * math.e ** (-0.5 * z ** 2)
Peter Norvig's avatar
Peter Norvig a validé
try:  # math.isclose exists from Python 3.5 on; fall back to a local copy on 3.4.
    from math import isclose
except ImportError:
    def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
        """Return True if a and b differ by at most rel_tol (relative) or abs_tol (absolute)."""
        tolerance = max(rel_tol * max(abs(a), abs(b)), abs_tol)
        return abs(a - b) <= tolerance

def truncated_svd(X, num_val=2, max_iter=1000):
    """Compute the first num_val components of an SVD of X by power iteration.

    X is an m x n matrix given as a list of m rows. It is embedded in the
    symmetric (m + n) x (m + n) matrix A = [[0, X], [X^T, 0]], and power
    iteration is run on A once per component. Before each step, the vectors
    already found are projected out. Every iterate is split into its first m
    entries and its last n entries, and each part is normalized separately.

    Args:
        X: list of m rows, each a list of n numbers.
        num_val: number of components to extract.
        max_iter: maximum number of power-iteration steps per component;
            iteration stops earlier once two successive iterates differ by at
            most 1e-10 (2-norm).

    Returns:
        (eivec_m, eivec_n, eivals): three lists of length num_val. eivec_m[k]
        holds the first m entries of the k-th vector, eivec_n[k] the last n
        entries, and eivals[k] the corresponding non-negative value.
        NOTE(review): these presumably correspond to the left singular
        vectors, right singular vectors and singular values of X — confirm.

    Starting vectors come from random.random(), so results can vary between
    runs unless the random module is seeded.
    """

    def normalize_vec(X, n=2):
        """Normalize two parts (:m and m:) of the vector."""
        # `n` here is the norm order (it shadows the outer column count n);
        # `m` is read from the enclosing scope. A zero part divides by zero.
        X_m = X[:m]
        X_n = X[m:]
        norm_X_m = norm(X_m, n)
        Y_m = [x / norm_X_m for x in X_m]
        norm_X_n = norm(X_n, n)
        Y_n = [x / norm_X_n for x in X_n]
        return Y_m + Y_n

    def remove_component(X):
        """Remove components of already obtained eigen vectors from X."""
        # For each stored vector, subtract the projection of the matching part
        # of X onto it (the stored parts were normalized by normalize_vec).
        X_m = X[:m]
        X_n = X[m:]
        for eivec in eivec_m:
            coeff = dot_product(X_m, eivec)
            X_m = [x1 - coeff * x2 for x1, x2 in zip(X_m, eivec)]
        for eivec in eivec_n:
            coeff = dot_product(X_n, eivec)
            X_n = [x1 - coeff * x2 for x1, x2 in zip(X_n, eivec)]
        return X_m + X_n

    m, n = len(X), len(X[0])
    # Symmetric block matrix: X in the top-right block, X transposed in the
    # bottom-left block, zeros elsewhere.
    A = [[0] * (n + m) for _ in range(n + m)]
    for i in range(m):
        for j in range(n):
            A[i][m + j] = A[m + j][i] = X[i][j]

    # Read by the closures above; they must exist before the first call below.
    eivec_m = []
    eivec_n = []
    eivals = []

    for _ in range(num_val):
        # From here on X is the current iterate, no longer the input matrix.
        X = [random.random() for _ in range(m + n)]
        X = remove_component(X)
        X = normalize_vec(X)

        for i in range(max_iter):
            old_X = X
            # Multiply A by X treated as a column vector, then flatten again.
            X = matrix_multiplication(A, [[x] for x in X])
            X = [x[0] for x in X]
            X = remove_component(X)
            X = normalize_vec(X)
            # check for convergence
            if norm([x1 - x2 for x1, x2 in zip(old_X, X)]) <= 1e-10:
                break

        projected_X = matrix_multiplication(A, [[x] for x in X])
        projected_X = [x[0] for x in projected_X]
        # norm(v, 1) is the plain signed sum of v (norm takes no abs), so this
        # ratio keeps its sign and a negative value can be detected below.
        new_eigenvalue = norm(projected_X, 1) / norm(X, 1)
        ev_m = X[:m]
        ev_n = X[m:]
        if new_eigenvalue < 0:
            # Report a non-negative value; flip the m-part to compensate.
            new_eigenvalue = -new_eigenvalue
            ev_m = [-ev_m_i for ev_m_i in ev_m]
        eivals.append(new_eigenvalue)
        eivec_m.append(ev_m)
        eivec_n.append(ev_n)
    return eivec_m, eivec_n, eivals


# ______________________________________________________________________________
# Grid Functions


# Unit (dx, dy) headings; each entry is the previous one rotated by +90 degrees.
EAST, NORTH, WEST, SOUTH = (1, 0), (0, 1), (-1, 0), (0, -1)
orientations = [EAST, NORTH, WEST, SOUTH]
# Index offsets into orientations for turning.
LEFT, RIGHT = +1, -1
turns = (LEFT, RIGHT)


def turn_heading(heading, inc, headings=orientations):
    """Return the heading inc positions after heading in headings, wrapping around cyclically."""
    position = headings.index(heading) + inc
    return headings[position % len(headings)]


def turn_right(heading):
    """Return the heading one step to the right (clockwise) of heading."""
    return turn_heading(heading, inc=RIGHT)


def turn_left(heading):
    """Return the heading one step to the left (counter-clockwise) of heading."""
    return turn_heading(heading, inc=LEFT)


def distance(a, b):
    """Return the Euclidean distance between the 2-D points a and b, each an (x, y) pair."""
    (ax, ay), (bx, by) = a, b
    return math.hypot(ax - bx, ay - by)


def distance_squared(a, b):
    """Return the squared Euclidean distance between the 2-D points a and b."""
    (ax, ay), (bx, by) = a, b
    dx, dy = ax - bx, ay - by
    return dx ** 2 + dy ** 2


def vector_clip(vector, lowest, highest):
    """Clip each vector[i] into [lowest[i], highest[i]] and return the result
    as a container of the same type as vector."""
    clipped = (max(lo, min(v, hi)) for v, lo, hi in zip(vector, lowest, highest))
    return type(vector)(clipped)


# ______________________________________________________________________________
Peter Norvig's avatar
Peter Norvig a validé
    """Dependency injection of temporary values for global functions/classes/etc.
    E.g., `with injection(DataBase=MockDataBase): ...`"""

    def __init__(self, **kwds):
Peter Norvig's avatar
Peter Norvig a validé
        self.new = kwds
Peter Norvig's avatar
Peter Norvig a validé
        self.old = {v: globals()[v] for v in self.new}
        globals().update(self.new)

    def __exit__(self, type, value, traceback):
Peter Norvig's avatar
Peter Norvig a validé
        globals().update(self.old)

def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.

    If slot is specified, store result in that slot (attribute) of the first
    argument; once set, that stored value is returned regardless of the other
    arguments. If slot is false, use functools.lru_cache (holding at most
    maxsize entries) for caching the values; all arguments must then be
    hashable. The returned function carries fn's name and docstring.
    """
    if slot:
        @functools.wraps(fn)
        def memoized_fn(obj, *args):
            if hasattr(obj, slot):
                return getattr(obj, slot)
            else:
                val = fn(obj, *args)
                setattr(obj, slot, val)
                return val
    else:
        @functools.lru_cache(maxsize=maxsize)
        @functools.wraps(fn)
        def memoized_fn(*args):
            return fn(*args)

    return memoized_fn

def name(obj):
    """Return a reasonable display name for obj.

    Tries obj.name, then obj.__name__, then the __name__ of obj's class, and
    returns the first truthy one; otherwise falls back to str(obj).
    """
    candidate = getattr(obj, 'name', 0)
    if candidate:
        return candidate
    candidate = getattr(obj, '__name__', 0)
    if candidate:
        return candidate
    candidate = getattr(getattr(obj, '__class__', 0), '__name__', 0)
    if candidate:
        return candidate
    return str(obj)
    """Is x a number?"""
    return hasattr(x, '__int__')
    """Is x a sequence?"""
Peter Norvig's avatar
Peter Norvig a validé
    return isinstance(x, collections.abc.Sequence)
def print_table(table, header=None, sep='   ', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.
    header, if specified, will be printed as the first row.
    numfmt is the format for all numbers; you might want e.g. '{:.2f}'.
    (If you want different formats in different columns,
    don't use print_table.) sep is the separator between columns.

    Columns whose entry in the first data row is a number are right-justified;
    all others are left-justified. The caller's table is not modified.
    """
    justs = ['rjust' if isnumber(x) else 'ljust' for x in table[0]]

    # Build a new list rather than inserting the header into the caller's table.
    rows = [header] + list(table) if header else list(table)

    rows = [[numfmt.format(x) if isnumber(x) else x for x in row]
            for row in rows]

    # Width of each column = length of its longest rendered cell.
    sizes = [max(map(len, column)) for column in zip(*[map(str, row) for row in rows])]

    for row in rows:
        print(sep.join(getattr(str(x), j)(size) for (j, size, x) in zip(justs, sizes, row)))
def open_data(name, mode='r'):
    """Open the file `name` inside the 'aima-data' directory next to this module and return the file object."""
    data_dir = os.path.join(os.path.dirname(__file__), 'aima-data')
    return open(os.path.join(data_dir, name), mode=mode)
def failure_test(algorithm, tests):
    """Grades the given algorithm based on how many tests it passes.

    Most algorithms have arbitrary output on correct execution, which is difficult
    to check for correctness. On the other hand, a lot of algorithms output something
    particular on fail (for example, False, or None).
    tests is a list with each element in the form: (values, failure_output).
    Returns the fraction of tests in which algorithm(values) differs from
    failure_output (1 means it never produced the failure output).
    """
    # `mean` is statistics.mean, imported at module level.
    return mean(int(algorithm(x) != y) for x, y in tests)


Peter Norvig's avatar
Peter Norvig a validé
# ______________________________________________________________________________
# Expressions

# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names

Donato Meoli's avatar
Donato Meoli a validé
class Expr:
    """A mathematical expression with an operator and 0 or more arguments.
    op is a str like '+' or 'sin'; args are Expressions.
    Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
    Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary."""

    def __init__(self, op, *args):
        self.op = str(op)
        self.args = args  # a tuple, so it can be hashed in __hash__

    # Operator overloads
    def __neg__(self):
        return Expr('-', self)

    def __pos__(self):
        return Expr('+', self)

    def __invert__(self):
        return Expr('~', self)

    def __add__(self, rhs):
        return Expr('+', self, rhs)

    def __sub__(self, rhs):
        return Expr('-', self, rhs)

    def __mul__(self, rhs):
        return Expr('*', self, rhs)

    def __pow__(self, rhs):
        return Expr('**', self, rhs)

    def __mod__(self, rhs):
        return Expr('%', self, rhs)

    def __and__(self, rhs):
        return Expr('&', self, rhs)

    def __xor__(self, rhs):
        return Expr('^', self, rhs)

    def __rshift__(self, rhs):
        return Expr('>>', self, rhs)

    def __lshift__(self, rhs):
        return Expr('<<', self, rhs)

    def __truediv__(self, rhs):
        return Expr('/', self, rhs)

    def __floordiv__(self, rhs):
        return Expr('//', self, rhs)

    def __matmul__(self, rhs):
        return Expr('@', self, rhs)

    def __or__(self, rhs):
        """Allow both P | Q, and P |'==>'| Q."""
        if isinstance(rhs, Expression):
            return Expr('|', self, rhs)
        else:
            # rhs is an operator string such as '==>': start an infix expression.
            return PartialExpr(rhs, self)

    # Reverse operator overloads
    def __radd__(self, lhs):
        return Expr('+', lhs, self)

    def __rsub__(self, lhs):
        return Expr('-', lhs, self)

    def __rmul__(self, lhs):
        return Expr('*', lhs, self)

    def __rdiv__(self, lhs):  # Python 2 name; Python 3 uses __rtruediv__ below
        return Expr('/', lhs, self)

    def __rpow__(self, lhs):
        return Expr('**', lhs, self)

    def __rmod__(self, lhs):
        return Expr('%', lhs, self)

    def __rand__(self, lhs):
        return Expr('&', lhs, self)

    def __rxor__(self, lhs):
        return Expr('^', lhs, self)

    def __ror__(self, lhs):
        return Expr('|', lhs, self)

    def __rrshift__(self, lhs):
        return Expr('>>', lhs, self)

    def __rlshift__(self, lhs):
        return Expr('<<', lhs, self)

    def __rtruediv__(self, lhs):
        return Expr('/', lhs, self)

    def __rfloordiv__(self, lhs):
        return Expr('//', lhs, self)

    def __rmatmul__(self, lhs):
        return Expr('@', lhs, self)

    def __call__(self, *args):
        """Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."""
        if self.args:
            raise ValueError('can only do a call for a Symbol, not an Expr')
        else:
            return Expr(self.op, *args)

    # Equality and repr
    def __eq__(self, other):
        """'x == y' evaluates to True or False; does not build an Expr."""
        return isinstance(other, Expr) and self.op == other.op and self.args == other.args

    def __lt__(self, other):
        """Order Exprs by their printed form."""
        return isinstance(other, Expr) and str(self) < str(other)

    def __hash__(self):
        return hash(self.op) ^ hash(self.args)

    def __repr__(self):
        op = self.op
        args = [str(arg) for arg in self.args]
        if op.isidentifier():  # f(x) or f(x, y)
            return '{}({})'.format(op, ', '.join(args)) if args else op
        elif len(args) == 1:  # -x or -(x + 1)
            return op + args[0]
        else:  # (x - y)
            opp = (' ' + op + ' ')
            return '(' + opp.join(args) + ')'

Peter Norvig's avatar
Peter Norvig a validé
# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.

Number = (int, float, complex)

# isinstance() accepts this nested tuple, so isinstance(v, Expression) is True
# for any Expr or any of the Number types.
Expression = (Expr, Number)

Peter Norvig's avatar
Peter Norvig a validé
def Symbol(name):
    """A Symbol is just an Expr with no args."""
    return Expr(name)

Peter Norvig's avatar
Peter Norvig a validé
def symbols(names):
    """Return a tuple of Symbols; names is a comma/whitespace delimited str."""
    # Turn commas into spaces so that split() handles both kinds of separator.
    return tuple(Symbol(name) for name in names.replace(',', ' ').split())

Peter Norvig's avatar
Peter Norvig a validé
def subexpressions(x):
    """Yield the subexpressions of an Expression (including x itself).

    x comes first, followed by the subexpressions of each argument, depth-first
    and left to right.
    """
    yield x
    if isinstance(x, Expr):
        for arg in x.args:
            yield from subexpressions(arg)

Peter Norvig's avatar
Peter Norvig a validé
def arity(expression):
    """The number of sub-expressions in this expression.

    Anything that is not an Expr (e.g. a number) has arity 0.
    """
    if isinstance(expression, Expr):
        return len(expression.args)
    else:  # expression is a number
        return 0

Peter Norvig's avatar
Peter Norvig a validé
# For operators that are not defined in Python, we allow new InfixOps:

Peter Norvig's avatar
Peter Norvig a validé
class PartialExpr:
    """A half-built infix expression: evaluating P |'==>'| Q first creates
    PartialExpr('==>', P), and the second | combines it with Q."""

    def __init__(self, op, lhs):
        self.op = op
        self.lhs = lhs

    def __or__(self, rhs):
        return Expr(self.op, self.lhs, rhs)

    def __repr__(self):
        return "PartialExpr('{}', {})".format(self.op, self.lhs)
Peter Norvig's avatar
Peter Norvig a validé
def expr(x):
    """Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix |'==>'|, as are <== and <=>.
    If x is not a str (e.g. it is already an Expression), it is returned unchanged. Example:
    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)
    """
    # NOTE(review): eval() executes arbitrary Python code; only pass trusted strings.
    return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol)) if isinstance(x, str) else x
Peter Norvig's avatar
Peter Norvig a validé
infix_ops = ['==>', '<==', '<=>']  # operators rewritten by expr_handle_infix_ops, in this order

Peter Norvig's avatar
Peter Norvig a validé
def expr_handle_infix_ops(x):
    """Given a str, return a new str with ==> replaced by |'==>'|, etc.

    >>> expr_handle_infix_ops('P ==> Q')
    "P |'==>'| Q"
    """
    # Wrapping the operator in | ... | lets Python's own | operator build it via PartialExpr.
    for op in infix_ops:
        x = x.replace(op, '|' + repr(op) + '|')
    return x

Peter Norvig's avatar
Peter Norvig a validé
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.
    >>> d = defaultkeydict(len); d['four']
    4
    """

    def __missing__(self, key):
        # Store the computed default so later lookups of key find it directly.
        self[key] = result = self.default_factory(key)
        return result
class hashabledict(dict):
    """A dict that can be hashed, e.g. used as a set element or as a dict key.

    The hash is computed from the current (key, value) pairs, ignoring their
    order, so equal dicts hash equally; all values must be hashable.
    May cause problems as the hash value may change during runtime: mutating
    the dict after it is stored in a set or used as a key breaks lookups.
    """

    def __hash__(self):
        # frozenset: order-independent, matching dict equality semantics.
        return hash(frozenset(self.items()))

# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue
# Stack and FIFOQueue are implemented as list and collections.deque
# PriorityQueue is implemented here
class PriorityQueue:
    """A Queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first.
    If order is 'min', the item with minimum f(x) is
    returned first; if order is 'max', then it is the item with maximum f(x).
    Also supports dict-like lookup."""

    def __init__(self, order='min', f=lambda x: x):
        self.heap = []  # list of (priority, item) pairs maintained by heapq
        if order == 'min':
            self.f = f
        elif order == 'max':  # now item with max f(x)
            self.f = lambda x: -f(x)  # will be popped first
        else:
            raise ValueError("order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item (with min or max f(x) value)
        depending on the order."""
        if self.heap:
            return heapq.heappop(self.heap)[1]
        else:
            raise Exception('Trying to pop from empty PriorityQueue.')

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if the key is in PriorityQueue."""
        return any(item == key for _, item in self.heap)

    def __getitem__(self, key):
        """Returns the first value associated with key in PriorityQueue.
        The value is the stored priority: f(key), negated when order='max'.
        Raises KeyError if key is not present."""
        for value, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first occurrence of key."""
        try:
            del self.heap[[item == key for _, item in self.heap].index(True)]
        except ValueError:
            raise KeyError(str(key) + " is not in the priority queue")
        # Deleting from the middle breaks the heap invariant; restore it.
        heapq.heapify(self.heap)
# ______________________________________________________________________________
# Useful Shorthands


class Bool(int):
    """An int subclass for truth values that displays as 'T'/'F' rather than 'True'/'False'."""

    def __repr__(self):
        return 'T' if self else 'F'

    __str__ = __repr__


T, F = Bool(True), Bool(False)