utils.py 21,1 ko
Newer Older
"""Provides some utilities widely used by other modules"""
norvig's avatar
norvig a validé

import bisect
Peter Norvig's avatar
Peter Norvig a validé
import collections
Peter Norvig's avatar
Peter Norvig a validé
import collections.abc
Peter Norvig's avatar
Peter Norvig a validé
import operator
import os.path
import random
Tarun Kumar Vangani's avatar
Tarun Kumar Vangani a validé
import math
C.G.Vedant's avatar
C.G.Vedant a validé
import functools
# ______________________________________________________________________________
# Functions on Sequences and Iterables


def sequence(iterable):
    """Coerce iterable to sequence, if it is not already one.

    Anything that is already a collections.abc.Sequence (list, tuple, str,
    range, ...) is returned as-is; any other iterable is consumed into a
    tuple."""
    return (iterable if isinstance(iterable, collections.abc.Sequence)
            else tuple(iterable))


def removeall(item, seq):
    """Return a copy of seq (or string) with all occurrences of item removed.

    For a str, every occurrence of the substring item is deleted and a str is
    returned; for any other iterable, a list of the elements != item is
    returned."""
    if isinstance(seq, str):
        return seq.replace(item, '')
    else:
        return [x for x in seq if x != item]

def unique(seq):
    """Remove duplicate elements from seq. Assumes hashable elements.

    Returns a new list holding the first occurrence of each distinct
    element, in the order the elements first appear in seq."""
    # dict keys are unique and (Python 3.7+) keep insertion order.
    return list(dict.fromkeys(seq))
def count(seq):
    """Count the number of items in sequence that are interpreted as true."""
    return sum(bool(x) for x in seq)


def product(numbers):
    """Return the product of the numbers, e.g. product([2, 3, 10]) == 60

    The product of an empty sequence is 1."""
    result = 1
    for x in numbers:
        result *= x
    return result
def first(iterable, default=None):
    """Return the first element of an iterable or the next element of a generator; or default.

    Indexable containers are read with iterable[0] (an empty one gives
    default). Anything that cannot be indexed -- iterators, generators, and
    also plain iterables such as sets or dict views -- is advanced once with
    next(); if it is already exhausted, default is returned."""
    try:
        return iterable[0]
    except IndexError:
        return default
    except TypeError:
        # iter() returns iterators/generators unchanged, so they still lose
        # exactly one element; non-iterator iterables (set, dict views) now
        # work instead of raising TypeError from next().
        return next(iter(iterable), default)
def is_in(elt, seq):
    """Similar to (elt in seq), but compares with 'is', not '=='.

    Returns True only if seq contains the very same object as elt."""
    return any(x is elt for x in seq)

def mode(data):
    """Return the most common data item. If there are ties, return any one of them.

    Raises ValueError if data is empty (there is no most common item)."""
    # most_common(1) is [] for empty data, so the unpacking raises ValueError.
    [(item, _)] = collections.Counter(data).most_common(1)
    return item

# ______________________________________________________________________________
# argmin and argmax


def identity(x):
    """Return x unchanged; used as the default key function below."""
    return x


# min/max already accept a key= function and return the element of the
# iterable that minimizes/maximizes key(element), i.e. argmin/argmax.
argmin = min
argmax = max
def argmin_random_tie(seq, key=identity):
    """Return an element of seq whose key value is smallest.

    seq is shuffled before the search, so among several elements that share
    the minimum key, the one returned is picked at random."""
    candidates = shuffled(seq)
    return argmin(candidates, key=key)
def argmax_random_tie(seq, key=identity):
    """Return an element with the highest key(element) score; break ties at random.

    seq is shuffled before the search, so among several elements that share
    the maximum key, the one returned is picked at random."""
    return argmax(shuffled(seq), key=key)
def shuffled(iterable):
    """Randomly shuffle a copy of iterable.

    Returns a new list; the input itself is left untouched."""
    items = list(iterable)
    random.shuffle(items)
    return items

# ______________________________________________________________________________
# Statistical and mathematical functions

def histogram(values, mode=0, bin_function=None):
    """Return a list of (value, count) pairs, summarizing the input values.
    Sorted by increasing value, or if mode=1, by decreasing count
    (ties on count broken by decreasing value).
    If bin_function is given, map it over values first."""
    if bin_function:
        values = map(bin_function, values)

    bins = {}
    for val in values:
        bins[val] = bins.get(val, 0) + 1

    if mode:
        return sorted(bins.items(), key=lambda x: (x[1], x[0]), reverse=True)
    else:
        return sorted(bins.items())


def dotproduct(X, Y):
    """Return the sum of the element-wise product of vectors X and Y.

    Like zip, stops at the shorter of the two vectors."""
    return sum(x * y for x, y in zip(X, Y))
def element_wise_product(X, Y):
    """Return vector as an element-wise product of vectors X and Y.

    X and Y must have the same length (checked with assert)."""
    assert len(X) == len(Y)
    return [x * y for x, y in zip(X, Y)]
def matrix_multiplication(X_M, *Y_M):
    """Return a matrix as a matrix-multiplication of X_M and arbitrary number of matrices *Y_M.

    Matrices are lists of rows. The product is formed left to right,
    ((X_M . Y_M[0]) . Y_M[1]) ...; with no *Y_M, X_M itself is returned.

    >>> matrix_multiplication([[1, 2, 3], [2, 3, 4]], [[3, 4], [1, 2], [1, 0]])
    [[8, 8], [13, 14]]
    """

    def _mat_mult(X_M, Y_M):
        """Return the product of the two matrices X_M and Y_M."""
        assert len(X_M[0]) == len(Y_M)
        # zip(*Y_M) yields the columns of Y_M; entry (i, j) of the product is
        # the dot product of row i of X_M with column j of Y_M.
        columns = list(zip(*Y_M))
        return [[sum(a * b for a, b in zip(row, col)) for col in columns]
                for row in X_M]

    result = X_M
    for Y in Y_M:
        result = _mat_mult(result, Y)
    return result

def vector_to_diagonal(v):
    """Converts a vector to a diagonal matrix with vector elements
    as the diagonal elements of the matrix; all other entries are 0."""
    size = len(v)
    return [[v[i] if i == j else 0 for j in range(size)] for i in range(size)]


def vector_add(a, b):
    """Component-wise addition of two vectors, returned as a tuple.

    Like zip, stops at the shorter of a and b."""
    return tuple(map(operator.add, a, b))

def scalar_vector_product(X, Y):
    """Scale every component of vector Y by the scalar X; returns a new list."""
    return list(map(lambda component: X * component, Y))

def scalar_matrix_product(X, Y):
    """Scale every entry of matrix Y (a sequence of rows) by the scalar X.

    Returns a new list of lists; each row is scaled by scalar_vector_product."""
    scaled_rows = []
    for row in Y:
        scaled_rows.append(scalar_vector_product(X, row))
    return scaled_rows
def inverse_matrix(X):
    """Inverse a given square matrix of size 2x2.

    Uses the closed form inv([[a, b], [c, d]]) = 1/det * [[d, -b], [-c, a]]
    with det = a*d - b*c. Raises AssertionError if X is not 2x2 or is
    singular (det == 0)."""
    assert len(X) == 2
    assert len(X[0]) == 2 and len(X[1]) == 2
    det = X[0][0] * X[1][1] - X[0][1] * X[1][0]
    assert det != 0
    inv_mat = scalar_matrix_product(1.0/det, [[X[1][1], -X[0][1]], [-X[1][0], X[0][0]]])
    return inv_mat


def probability(p):
    """Return true with probability p."""
    return p > random.uniform(0.0, 1.0)

def weighted_sample_with_replacement(n, seq, weights):
    """Pick n samples from seq at random, with replacement, with the
    probability of each element in proportion to its corresponding
    weight.

    Returns a list of n elements drawn from seq."""
    sample = weighted_sampler(seq, weights)
    return [sample() for _ in range(n)]
def weighted_sampler(seq, weights):
    """Return a random-sample function that picks from seq weighted by weights.

    weights[i] is the relative weight of seq[i]; each call of the returned
    zero-argument function draws one element of seq."""
    totals = []
    for w in weights:
        totals.append(w + totals[-1] if totals else w)
    # totals holds the cumulative weights; bisect locates a uniform draw in
    # [0, total) among them, so seq[i] is chosen with probability
    # weights[i] / total.
    return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
def rounder(numbers, d=4):
    """Round a single number, or sequence of numbers, to d decimal places.

    Containers are rebuilt with their own type (list, set, tuple, ...) and
    their elements rounded recursively, so nested sequences are supported."""
    if isinstance(numbers, (int, float)):
        return round(numbers, d)
    else:
        constructor = type(numbers)     # Can be list, set, tuple, etc.
        return constructor(rounder(n, d) for n in numbers)


def num_or_str(x):
    """The argument is a string; convert to a number if
       possible, or strip it.

    int is tried first, then float; otherwise the stripped string is
    returned."""
    try:
        return int(x)
    except ValueError:
        try:
            return float(x)
        except ValueError:
            return str(x).strip()
def normalize(dist):
    """Scale the values of dist so that they add up to 1.0.

    A dict is rescaled in place and returned; each resulting value is
    asserted to lie in [0, 1]. Any other sequence of numbers yields a new
    list of the scaled numbers."""
    if not isinstance(dist, dict):
        total = sum(dist)
        return [n / total for n in dist]
    total = sum(dist.values())
    for key in dist:
        dist[key] = dist[key] / total
        assert 0 <= dist[key] <= 1, "Probabilities must be between 0 and 1."
    return dist
def clip(x, lowest, highest):
    """Return x clipped to the range [lowest..highest]."""
    return max(lowest, min(x, highest))


def sigmoid_derivative(value):
    """Return value * (1 - value).

    When value == sigmoid(x), this equals the derivative of the sigmoid at x
    (the identity s' = s * (1 - s))."""
    complement = 1 - value
    return value * complement


def sigmoid(x):
    """Return activation value of x with sigmoid function, 1 / (1 + e**-x).

    For negative x the algebraically equal form e**x / (1 + e**x) is used,
    because math.exp(-x) raises OverflowError once -x exceeds about 709."""
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)


def step(x):
    """Return activation value of x with the (Heaviside) step function:
    1 when x >= 0, otherwise 0."""
    if x >= 0:
        return 1
    return 0
def gaussian(mean, st_dev, x):
    """Return the normal probability density with the given mean and
    standard deviation, evaluated at x."""
    coefficient = 1/(math.sqrt(2*math.pi)*st_dev)
    exponent = -0.5*(float(x-mean)/st_dev)**2
    return coefficient*math.e**exponent

try:  # math.isclose was added in Python 3.5; but we might be in 3.4
    from math import isclose
except ImportError:
    def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
        """Return true if numbers a and b are close to each other.

        a and b are close when their difference is at most rel_tol times the
        larger magnitude, or at most abs_tol."""
        return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)

# ______________________________________________________________________________
# Grid Functions


# Unit (dx, dy) headings; each is the previous one rotated 90 degrees
# counter-clockwise (with y pointing up): EAST -> NORTH -> WEST -> SOUTH.
EAST, NORTH, WEST, SOUTH = (1, 0), (0, 1), (-1, 0), (0, -1)
orientations = [EAST, NORTH, WEST, SOUTH]
# Index offsets into orientations, as used by turn_heading.
LEFT, RIGHT = +1, -1
turns = (LEFT, RIGHT)


def turn_heading(heading, inc, headings=orientations):
    """Return the heading inc positions after heading in the cyclic list headings."""
    position = headings.index(heading)
    return headings[(position + inc) % len(headings)]


def turn_right(heading):
    """Step heading one position backwards through orientations (a clockwise quarter turn)."""
    return turn_heading(heading, RIGHT)


def turn_left(heading):
    """Step heading one position forwards through orientations (a counter-clockwise quarter turn)."""
    return turn_heading(heading, LEFT)


def distance(a, b):
    """Euclidean distance between the two (x, y) points a and b."""
    (ax, ay), (bx, by) = a, b
    return math.hypot(ax - bx, ay - by)


def distance_squared(a, b):
    """Squared Euclidean distance between the (x, y) points a and b (no square root taken)."""
    (ax, ay), (bx, by) = a, b
    dx, dy = ax - bx, ay - by
    return dx**2 + dy**2


def vector_clip(vector, lowest, highest):
    """Clip each component vector[i] into [lowest[i], highest[i]].

    The result is rebuilt with the same type as vector (e.g. tuple or list)."""
    clipped = map(clip, vector, lowest, highest)
    return type(vector)(clipped)


# ______________________________________________________________________________
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.
    If slot is specified, store result in that slot of first argument.
    If slot is false, use lru_cache for caching the values.

    Returns the memoized callable."""
    if slot:
        @functools.wraps(fn)
        def memoized_fn(obj, *args):
            # The value is stored on obj under attribute `slot`, so fn runs at
            # most once per obj; later calls ignore *args.
            if hasattr(obj, slot):
                return getattr(obj, slot)
            else:
                val = fn(obj, *args)
                setattr(obj, slot, val)
                return val
    else:
        # lru_cache already copies fn's name/docstring onto its wrapper.
        memoized_fn = functools.lru_cache(maxsize=maxsize)(fn)

    return memoized_fn

def name(obj):
    """Try to find some reasonable name for the object.

    Tries, in order, obj.name, obj.__name__ and the name of obj's class,
    returning the first truthy one; falls back to str(obj)."""
    return (getattr(obj, 'name', 0) or getattr(obj, '__name__', 0) or
            getattr(getattr(obj, '__class__', 0), '__name__', 0) or
            str(obj))


def isnumber(x):
    """Is x a number?

    Duck-typed: anything that defines __int__ counts (str does not)."""
    return hasattr(x, '__int__')


def issequence(x):
    """Is x a sequence?"""
    return isinstance(x, collections.abc.Sequence)
def print_table(table, header=None, sep='   ', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.
    header, if specified, will be printed as the first row.
    numfmt is the format for all numbers; you might want e.g. '{:.2f}'.
    (If you want different formats in different columns,
    don't use print_table.) sep is the separator between columns.

    Columns whose entry in the first data row is a number are right-justified,
    all others left-justified. The caller's table is not modified."""
    justs = ['rjust' if isnumber(x) else 'ljust' for x in table[0]]

    if header:
        # Build a new list instead of inserting into the caller's table.
        table = [header] + list(table)

    table = [[numfmt.format(x) if isnumber(x) else x for x in row]
             for row in table]

    # Width of each column = longest string form of any cell in it.
    sizes = [max(len(str(cell)) for cell in column) for column in zip(*table)]

    for row in table:
        print(sep.join(getattr(str(x), j)(size)
                       for (j, size, x) in zip(justs, sizes, row)))
def open_data(name, mode='r'):
    """Open the file `name` in the aima-data directory next to this module.

    mode is passed through to open() (it was previously ignored). The caller
    owns the returned file object and should close it, e.g. with `with`."""
    aima_root = os.path.dirname(__file__)
    aima_file = os.path.join(aima_root, 'aima-data', name)
    return open(aima_file, mode=mode)
# ______________________________________________________________________________
# Expressions

# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names

class Expr(object):
    """A mathematical expression with an operator and 0 or more arguments.
    op is a str like '+' or 'sin'; args are Expressions.
    Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
    Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary."""

    def __init__(self, op, *args):
        self.op = str(op)
        self.args = args

    # Operator overloads: each builds a new Expr rather than computing a value.
    def __neg__(self):
        return Expr('-', self)

    def __pos__(self):
        return Expr('+', self)

    def __invert__(self):
        return Expr('~', self)

    def __add__(self, rhs):
        return Expr('+', self, rhs)

    def __sub__(self, rhs):
        return Expr('-', self, rhs)

    def __mul__(self, rhs):
        return Expr('*', self, rhs)

    def __pow__(self, rhs):
        return Expr('**', self, rhs)

    def __mod__(self, rhs):
        return Expr('%', self, rhs)

    def __and__(self, rhs):
        return Expr('&', self, rhs)

    def __xor__(self, rhs):
        return Expr('^', self, rhs)

    def __rshift__(self, rhs):
        return Expr('>>', self, rhs)

    def __lshift__(self, rhs):
        return Expr('<<', self, rhs)

    def __truediv__(self, rhs):
        return Expr('/', self, rhs)

    def __floordiv__(self, rhs):
        return Expr('//', self, rhs)

    def __matmul__(self, rhs):
        return Expr('@', self, rhs)

    def __or__(self, rhs):
        """Allow both P | Q, and P |'==>'| Q."""
        # An Expression operand is an ordinary '|'; anything else (such as the
        # str '==>') starts an infix operator that PartialExpr.__or__ finishes.
        if isinstance(rhs, Expression):
            return Expr('|', self, rhs)
        else:
            return PartialExpr(rhs, self)

    # Reverse operator overloads
    def __radd__(self, lhs):
        return Expr('+', lhs, self)

    def __rsub__(self, lhs):
        return Expr('-', lhs, self)

    def __rmul__(self, lhs):
        return Expr('*', lhs, self)

    def __rdiv__(self, lhs):  # Python 2 name; Python 3 uses __rtruediv__
        return Expr('/', lhs, self)

    def __rpow__(self, lhs):
        return Expr('**', lhs, self)

    def __rmod__(self, lhs):
        return Expr('%', lhs, self)

    def __rand__(self, lhs):
        return Expr('&', lhs, self)

    def __rxor__(self, lhs):
        return Expr('^', lhs, self)

    def __ror__(self, lhs):
        return Expr('|', lhs, self)

    def __rrshift__(self, lhs):
        return Expr('>>', lhs, self)

    def __rlshift__(self, lhs):
        return Expr('<<', lhs, self)

    def __rtruediv__(self, lhs):
        return Expr('/', lhs, self)

    def __rfloordiv__(self, lhs):
        return Expr('//', lhs, self)

    def __rmatmul__(self, lhs):
        return Expr('@', lhs, self)

    def __call__(self, *args):
        "Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."
        if self.args:
            raise ValueError('can only do a call for a Symbol, not an Expr')
        else:
            return Expr(self.op, *args)

    # Equality and repr
    def __eq__(self, other):
        "'x == y' evaluates to True or False; does not build an Expr."
        return (isinstance(other, Expr)
                and self.op == other.op
                and self.args == other.args)

    def __hash__(self):
        # Consistent with __eq__: equal op and args give equal hashes.
        return hash(self.op) ^ hash(self.args)

    def __repr__(self):
        op = self.op
        args = [str(arg) for arg in self.args]
        if op.isidentifier():       # f(x) or f(x, y)
            return '{}({})'.format(op, ', '.join(args)) if args else op
        elif len(args) == 1:        # -x or -(x + 1)
            return op + args[0]
        else:                       # (x - y)
            opp = (' ' + op + ' ')
            return '(' + opp.join(args) + ')'

# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.

# Tuples of types, suitable as the second argument of isinstance().
Number = (int, float, complex)
Expression = (Expr, Number)
def Symbol(name):
    """A Symbol is just an Expr with no args."""
    return Expr(name)
def symbols(names):
    """Return a tuple of Symbols; names is a comma/whitespace delimited str.

    e.g. symbols('x, y z') gives (Symbol('x'), Symbol('y'), Symbol('z'))."""
    return tuple(Symbol(name) for name in names.replace(',', ' ').split())
def subexpressions(x):
    """Yield the subexpressions of an Expression (including x itself).

    Traversal is pre-order: x first, then each argument's subexpressions."""
    yield x
    if isinstance(x, Expr):
        for arg in x.args:
            yield from subexpressions(arg)
def arity(expression):
    """The number of sub-expressions in this expression.

    An Expr's arity is len(expression.args); anything else (a number) is 0."""
    if isinstance(expression, Expr):
        return len(expression.args)
    else:  # expression is a number
        return 0

# For operators that are not defined in Python, we allow new InfixOps:

class PartialExpr:
    """Given P |'==>'| Q, first form PartialExpr('==>', P), then combine with Q."""
    def __init__(self, op, lhs):
        self.op, self.lhs = op, lhs

    def __or__(self, rhs):
        # Second '|' of P |'op'| Q: build the finished Expr(op, P, Q).
        return Expr(self.op, self.lhs, rhs)

    def __repr__(self):
        return "PartialExpr('{}', {})".format(self.op, self.lhs)
def expr(x):
    """Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix |'==>'|, as are <== and <=>.
    Any non-str argument (e.g. an existing Expression) is returned unchanged.
    Example:
    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)
    """
    if isinstance(x, str):
        # SECURITY NOTE: eval executes arbitrary Python code -- only pass
        # trusted strings. Unknown names are resolved through
        # defaultkeydict(Symbol), so each identifier becomes a Symbol.
        return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol))
    else:
        return x
infix_ops = '==> <== <=>'.split()


def expr_handle_infix_ops(x):
    """Given a str, return a new str with ==> replaced by |'==>'|, etc.
    >>> expr_handle_infix_ops('P ==> Q')
    "P |'==>'| Q"
    """
    # Wrapping op as |'op'| makes Python parse it as two '|' operations around
    # a str, which Expr.__or__ / PartialExpr.__or__ turn into Expr(op, ...).
    for op in infix_ops:
        x = x.replace(op, '|' + repr(op) + '|')
    return x
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.
    >>> d = defaultkeydict(len); d['four']
    4
    """
    def __missing__(self, key):
        # Unlike defaultdict, the factory receives the missing key; the
        # computed value is stored before it is returned.
        value = self.default_factory(key)
        self[key] = value
        return value
class hashabledict(dict):
    """A dict that can be hashed and ordered through its sorted items.

    The hash is computed from tuple(sorted(self.items())), so it changes when
    the dict is mutated: do not mutate an instance while it is used as a dict
    key or set member. Keys must be mutually orderable for the sort.
    Comparisons order instances by the same sorted item tuples and assert
    that the other operand is also a hashabledict."""

    def __tuplify__(self):
        return tuple(sorted(self.items()))

    def __hash__(self):
        return hash(self.__tuplify__())

    def _compare(self, odict, relation):
        """Apply relation (an operator.* comparison) to both sorted item tuples."""
        assert isinstance(odict, hashabledict)
        return relation(self.__tuplify__(), odict.__tuplify__())

    def __lt__(self, odict):
        return self._compare(odict, operator.lt)

    def __gt__(self, odict):
        return self._compare(odict, operator.gt)

    def __le__(self, odict):
        return self._compare(odict, operator.le)

    def __ge__(self, odict):
        return self._compare(odict, operator.ge)


# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue

# TODO: queue.PriorityQueue
# TODO: Priority queues may not belong here -- see treatment in search.py
class Queue:
    """Queue is an abstract class/interface. There are three types:
        Stack(): A Last In First Out Queue.
        FIFOQueue(): A First In First Out Queue.
        PriorityQueue(order, f): Queue in sorted order (default min-first).
    Each type supports the following methods and functions:
        q.append(item)  -- add an item to the queue
        q.extend(items) -- equivalent to: for item in items: q.append(item)
        q.pop()         -- return the top item from the queue
        len(q)          -- number of items in q (also q.__len__())
        item in q       -- does q contain item?
    Note that isinstance(Stack(), Queue) is false, because we implement stacks
    as lists.  If Python ever gets interfaces, Queue will be an interface."""

    def __init__(self):
        raise NotImplementedError

    def extend(self, items):
        """Append each element of items, in order, via self.append."""
        for item in items:
            self.append(item)


def Stack():
    """Return an empty list, suitable as a Last-In-First-Out Queue."""
    return []


class FIFOQueue(Queue):
    """A First-In-First-Out Queue.

    If maxlen is given (and nonzero), append/extend raise an Exception rather
    than grow the queue past it. Items passed to the constructor go straight
    to collections.deque, which keeps only the last maxlen of them."""

    def __init__(self, maxlen=None, items=()):
        # Tuple default avoids a shared mutable default; deque copies items.
        self.queue = collections.deque(items, maxlen)

    def append(self, item):
        if not self.queue.maxlen or len(self.queue) < self.queue.maxlen:
            self.queue.append(item)
        else:
            raise Exception('FIFOQueue is full')

    def extend(self, items):
        if not self.queue.maxlen:
            self.queue.extend(items)
            return
        # The capacity check needs len(items), so materialize iterators and
        # generators first (len() on them would raise TypeError).
        items = list(items)
        if len(self.queue) + len(items) <= self.queue.maxlen:
            self.queue.extend(items)
        else:
            raise Exception('FIFOQueue max length exceeded')

    def pop(self):
        if len(self.queue) > 0:
            return self.queue.popleft()
        else:
            raise Exception('FIFOQueue is empty')

    def __len__(self):
        return len(self.queue)

    def __contains__(self, item):
        return item in self.queue


class PriorityQueue(Queue):
    """A queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first. If order is min, the item with minimum f(x) is
    returned first; if order is max, then it is the item with maximum f(x).
    Also supports dict-like lookup.

    Entries are kept in self.A as (f(item), item) pairs sorted ascending;
    items with equal f(item) are compared to each other, so they must be
    orderable."""

    def __init__(self, order=min, f=lambda x: x):
        self.A = []
        self.order = order
        self.f = f

    def append(self, item):
        bisect.insort(self.A, (self.f(item), item))

    def __len__(self):
        return len(self.A)

    def pop(self):
        if self.order == min:
            return self.A.pop(0)[1]
        else:
            return self.A.pop()[1]

    def __contains__(self, item):
        return any(item == pair[1] for pair in self.A)

    def __getitem__(self, key):
        """Return the stored item equal to key, or None if there is none."""
        for _, item in self.A:
            if item == key:
                return item

    def __delitem__(self, key):
        """Remove every stored item equal to key."""
        # Rebuild in place: popping while enumerating skipped the element
        # that slid into each removed slot.
        self.A[:] = [(value, item) for value, item in self.A
                     if not item == key]
# ______________________________________________________________________________
# Useful Shorthands


class Bool(int):
    """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'"""

    def __repr__(self):
        return 'T' if self else 'F'

    __str__ = __repr__


T = Bool(True)
F = Bool(False)