Newer
Older
"""Provides some utilities widely used by other modules."""
import operator
import os.path
import random
from itertools import chain, combinations
try: # math.inf was added in Python 3.5
from math import inf
except ImportError: # Python 3.4
inf = float('inf')
# ______________________________________________________________________________
"""Converts iterable to sequence, if it is not already one."""
return iterable if isinstance(iterable, collections.abc.Sequence) else tuple([iterable])
Donato Meoli
a validé
def remove_all(item, seq):
"""Return a copy of seq (or string) with all occurrences of item removed."""
if isinstance(seq, str):
Donato Meoli
a validé
elif isinstance(seq, set):
rest = seq.copy()
rest.remove(item)
return rest
"""Remove duplicate elements from seq. Assumes hashable elements."""
return list(set(seq))
"""Count the number of items in sequence that are interpreted as true."""
Donato Meoli
a validé
def multimap(items):
    """Given (key, val) pairs, return {key: [val, ....], ...}.

    Values sharing a key are collected in the order they appear in items.
    """
    result = collections.defaultdict(list)
    for key, val in items:
        result[key].append(val)
    # Return a plain dict so that looking up an unknown key raises KeyError
    # instead of silently inserting an empty list.
    return dict(result)
Donato Meoli
a validé
def multimap_items(mmap):
    """Lazily iterate a multimap as flat (key, val) pairs.

    A key bound to [v1, v2] produces (key, v1) and then (key, v2); a key bound
    to an empty list produces nothing.
    """
    for key, vals in mmap.items():
        yield from ((key, val) for val in vals)
Donato Meoli
a validé
def product(numbers):
    """Return the product of the numbers, e.g. product([2, 3, 10]) == 60.

    The product of an empty iterable is 1.
    """
    result = 1
    for x in numbers:
        result *= x
    return result
"""Return the first element of an iterable; or default."""
return next(iter(iterable), default)
Donato Meoli
a validé
def is_in(elt, seq):
    """Similar to (elt in seq), but compares with 'is', not '=='."""
    return any(x is elt for x in seq)
"""Return the most common data item. If there are ties, return any one of them."""
def power_set(iterable):
    """power_set([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Return a list of all non-empty subsets of iterable as tuples, smallest first.
    """
    # Materialize once so len() and combinations() see the same items, even
    # when iterable is a one-shot iterator.
    s = list(iterable)
    return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
Donato Meoli
a validé
def extend(s, var, val):
    """Copy dict s and extend it by setting var to val; return copy.

    s itself is never modified. The result is always a plain dict.
    """
    # Same result as the literal {**s, var: val}, but without eval() compiling
    # source text on every call, and it still parses on Python 3.4.
    s2 = dict(s)
    s2[var] = val
    return s2
Donato Meoli
a validé
# ______________________________________________________________________________
# argmin and argmax
def argmin_random_tie(seq, key=identity):
    """Return an element of seq whose key(element) is smallest.

    seq is first passed through shuffled() (presumably a random permutation),
    and min() keeps the first minimum it meets. As a result, ties are broken
    by that random order.
    """
    candidates = shuffled(seq)
    return min(candidates, key=key)
"""Return an element with highest fn(seq[i]) score; break ties at random."""
return max(shuffled(seq), key=key)
"""Randomly shuffle a copy of iterable."""
Surya Teja Cheedella
a validé
return items
# ______________________________________________________________________________
# Statistical and mathematical functions
def histogram(values, mode=0, bin_function=None):
    """Return a list of (value, count) pairs, summarizing the input values.

    Sorted by increasing value, or if mode=1, by decreasing count (ties broken
    by decreasing value). If bin_function is given, map it over values first.
    """
    # Only bin when a function is supplied: map(None, ...) raises TypeError.
    if bin_function is not None:
        values = map(bin_function, values)
    bins = {}
    for val in values:
        bins[val] = bins.get(val, 0) + 1
    if mode:
        return sorted(bins.items(), key=lambda x: (x[1], x[0]), reverse=True)
    return sorted(bins.items())
def dot_product(x, y):
    """Sum of pairwise products x[i] * y[i]; extra elements of the longer vector are ignored."""
    total = 0
    for a, b in zip(x, y):
        total += a * b
    return total
def element_wise_product(x, y):
    """Multiply vectors x and y component by component, returning a numpy array.

    An `assert` checks that the lengths match (AssertionError otherwise).
    """
    assert len(y) == len(x)
    products = np.multiply(x, y)
    return products
def matrix_multiplication(x, *y):
    """Chain matrix products left to right: x @ y[0] @ y[1] @ ...

    With no extra matrices, x itself is returned unchanged.
    """
    acc = x
    for factor in y:
        acc = np.matmul(acc, factor)
    return acc
def vector_add(a, b):
    """Add two vectors component by component; the tuple result has the shorter length."""
    return tuple(u + v for u, v in zip(a, b))
def scalar_vector_product(x, y):
    """Scale vector y by scalar x via np.multiply; the result is a numpy array."""
    scaled = np.multiply(x, y)
    return scaled
def probability(p):
    """Return True with probability p, by comparing p against a uniform draw from [0, 1]."""
    draw = random.uniform(0.0, 1.0)
    return draw < p
def weighted_sample_with_replacement(n, seq, weights):
    """Pick n samples from seq at random, with replacement, with the
    probability of each element in proportion to its corresponding
    weight.

    Returns a list of length n.
    """
    # Running totals of the weights. A uniform draw in [0, total] is mapped
    # back to the index of the bucket it falls into with a binary search.
    totals = []
    for w in weights:
        totals.append(w + totals[-1] if totals else w)
    return [seq[bisect.bisect(totals, random.uniform(0, totals[-1]))] for _ in range(n)]
def weighted_choice(choices):
    """A weighted version of random.choice.

    choices is a collection of (choice, weight) pairs and is iterated twice.
    Returns the selected (choice, weight) pair, or None if no pair reaches the
    drawn threshold (for example, when choices is empty).
    """
    # NOTE: should be replaced by random.choices if we port to Python 3.6
    threshold = random.uniform(0, sum(weight for _, weight in choices))
    cumulative = 0
    for choice, weight in choices:
        cumulative += weight
        if cumulative >= threshold:
            return choice, weight
Donato Meoli
a validé
"""Round a single number, or sequence of numbers, to d decimal places."""
if isinstance(numbers, (int, float)):
return round(numbers, d)
else:
constructor = type(numbers) # Can be list, set, tuple, etc.
return constructor(rounder(n, d) for n in numbers)
Donato Meoli
a validé
def num_or_str(x): # TODO: rename as `atom`
"""The argument is a string; convert to a number if possible, or strip it."""
except ValueError:
try:
except ValueError:
def euclidean_distance(x, y):
    """Straight-line (L2) distance between x and y over their common length."""
    squared_gap = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.sqrt(squared_gap)
def cross_entropy_loss(x, y):
    """Mean binary cross-entropy between targets x and predictions y.

    Both log terms are always evaluated, so every prediction must lie strictly
    between 0 and 1; otherwise math.log raises ValueError.
    """
    terms = (t * math.log(p) + (1 - t) * math.log(1 - p) for t, p in zip(x, y))
    return (-1.0 / len(x)) * sum(terms)
def rms_error(x, y):
    """Root-mean-square error: the square root of ms_error(x, y)."""
    mse = ms_error(x, y)
    return math.sqrt(mse)
def ms_error(x, y):
    """Mean of the squared differences between paired elements of x and y."""
    return mean((a - b) ** 2 for a, b in zip(x, y))
def mean_error(x, y):
    """Mean absolute difference between paired elements of x and y."""
    return mean(abs(a - b) for a, b in zip(x, y))
def manhattan_distance(x, y):
    """City-block (L1) distance: the sum of absolute coordinate differences."""
    total = 0
    for a, b in zip(x, y):
        total += abs(a - b)
    return total
def mean_boolean_error(x, y):
    """Fraction of paired positions where x and y disagree."""
    mismatches = [a != b for a, b in zip(x, y)]
    return mean(mismatches)
def hamming_distance(x, y):
    """Number of paired positions at which x and y differ."""
    count = 0
    for a, b in zip(x, y):
        count += a != b
    return count
"""Multiply each number by a constant such that the sum is 1.0"""
if isinstance(dist, dict):
total = sum(dist.values())
for key in dist:
dist[key] = dist[key] / total
assert 0 <= dist[key] <= 1 # probabilities must be between 0 and 1
return dist
total = sum(dist)
return [(n / total) for n in dist]
def norm(x, ord=2):
    """Return the ord-norm of vector x via np.linalg.norm (ord=2 is Euclidean)."""
    return np.linalg.norm(x, ord=ord)
Donato Meoli
a validé
def random_weights(min_value, max_value, num_weights):
    """Return a list of num_weights floats, each drawn uniformly between min_value and max_value."""
    weights = []
    for _ in range(num_weights):
        weights.append(random.uniform(min_value, max_value))
    return weights
"""Return x clipped to the range [lowest..highest]."""
def sigmoid_derivative(value):
    """Return value * (1 - value).

    This is the sigmoid slope written in terms of the sigmoid's own output, so
    value is presumably already sigmoid(x); verify this at call sites.
    """
    complement = 1 - value
    return value * complement
"""Return activation value of x with sigmoid function."""
return 1 / (1 + math.exp(-x))
return x if x > 0 else alpha * (math.exp(x) - 1)
Donato Meoli
a validé
def elu_derivative(value, alpha=0.01):
    """Slope of ELU: 1 for positive value, otherwise alpha * exp(value)."""
    if value > 0:
        return 1
    return alpha * math.exp(value)
Donato Meoli
a validé
Donato Meoli
a validé
return np.tanh(x)
return 1 - (value ** 2)
Donato Meoli
a validé
def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: pass positive x through, scale anything else by alpha."""
    if x > 0:
        return x
    return alpha * x
def leaky_relu_derivative(value, alpha=0.01):
    """Slope of leaky ReLU: 1 for positive value, alpha otherwise."""
    if value > 0:
        return 1
    return alpha
Donato Meoli
a validé
Donato Meoli
a validé
return max(0, x)
return 1 if value > 0 else 0
Donato Meoli
a validé
def step(x):
    """Heaviside step activation: 1 when x >= 0, otherwise 0."""
    if x >= 0:
        return 1
    return 0
def gaussian(mean, st_dev, x):
    """Normal probability density at x for the given mean and standard deviation."""
    z = float(x - mean) / st_dev
    coefficient = 1 / (math.sqrt(2 * math.pi) * st_dev)
    return coefficient * math.e ** (-0.5 * z ** 2)
def linear_kernel(x, y=None):
    """Gram matrix of inner products np.dot(x, y.T); y defaults to x."""
    other = x if y is None else y
    return np.dot(x, other.T)
def polynomial_kernel(x, y=None, degree=2.0):
    """Polynomial kernel (1 + x . y.T) ** degree; y defaults to x."""
    other = x if y is None else y
    base = 1.0 + np.dot(x, other.T)
    return base ** degree
def rbf_kernel(x, y=None, gamma=None):
    """Radial-basis function kernel (aka squared-exponential kernel).

    Entry (i, j) is exp(-gamma * ||x_i - y_j||^2), with the squared distance
    expanded as -2 x_i.y_j + |x_i|^2 + |y_j|^2. y defaults to x, and gamma
    defaults to 1 / (number of columns of x).
    """
    if y is None:
        y = x
    if gamma is None:
        gamma = 1.0 / x.shape[1]  # 1.0 / n_features
    cross = -2.0 * np.dot(x, y.T)
    x_sq = np.sum(x * x, axis=1).reshape((-1, 1))  # column vector of row norms
    y_sq = np.sum(y * y, axis=1).reshape((1, -1))  # row vector of row norms
    return np.exp(-gamma * (cross + x_sq + y_sq))
try:  # math.isclose was added in Python 3.5
    from math import isclose
except ImportError:  # Python 3.4
    def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
        """Return true if numbers a and b are close to each other."""
        return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
# ______________________________________________________________________________
# Grid Functions
# Unit (dx, dy) steps, listed counter-clockwise from EAST (with y pointing up).
EAST, NORTH, WEST, SOUTH = (1, 0), (0, 1), (-1, 0), (0, -1)
orientations = [EAST, NORTH, WEST, SOUTH]
# Index offsets into `orientations`: +1 moves counter-clockwise (left), -1 clockwise (right).
LEFT, RIGHT = +1, -1
turns = (LEFT, RIGHT)
def turn_heading(heading, inc, headings=orientations):
    """Step inc positions through headings starting at heading, wrapping around.

    Raises ValueError if heading is not in headings.
    """
    position = headings.index(heading)
    return headings[(position + inc) % len(headings)]
def turn_right(heading):
    """Return the heading one RIGHT step (-1) from heading."""
    return turn_heading(heading, inc=RIGHT)
def turn_left(heading):
    """Return the heading one LEFT step (+1) from heading."""
    return turn_heading(heading, inc=LEFT)
def distance(a, b):
    """The distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    return math.hypot(ax - bx, ay - by)
def distance_squared(a, b):
    """The square of the distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    dx, dy = ax - bx, ay - by
    return dx ** 2 + dy ** 2
def vector_clip(vector, lowest, highest):
    """Clip each element of vector between the matching elements of lowest and highest.

    clip() is applied position by position, and the result is rebuilt with
    vector's own type.
    """
    clipped = map(clip, vector, lowest, highest)
    return type(vector)(clipped)
# ______________________________________________________________________________
# Misc Functions
class injection:
    """Dependency injection of temporary values for global functions/classes/etc.
    E.g., `with injection(DataBase=MockDataBase): ...`

    Every injected name must already exist in this module's globals, because
    the previous value is read with globals()[name]. The previous values are
    written back when the with-block exits, including when it exits by raising.
    """

    def __init__(self, **kwds):
        self.new = kwds  # name -> temporary replacement value

    def __enter__(self):
        self.old = {v: globals()[v] for v in self.new}
        globals().update(self.new)

    def __exit__(self, type, value, traceback):
        # Restore the saved originals. Returning None lets any exception propagate.
        globals().update(self.old)
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.
    If slot is specified, store result in that slot of first argument.
    If slot is false, use lru_cache for caching the values.

    Note: in slot mode only the first argument is the cache key. Once
    obj.<slot> exists, its value is returned regardless of the other arguments.
    """
    if slot:
        def memoized_fn(obj, *args):
            if hasattr(obj, slot):
                return getattr(obj, slot)
            val = fn(obj, *args)
            setattr(obj, slot, val)
            return val
    else:
        @functools.lru_cache(maxsize=maxsize)
        def memoized_fn(*args):
            return fn(*args)
    return memoized_fn
"""Try to find some reasonable name for the object."""
return (getattr(obj, 'name', 0) or getattr(obj, '__name__', 0) or
getattr(getattr(obj, '__class__', 0), '__name__', 0) or
str(obj))
def isnumber(x):
def issequence(x):
def print_table(table, header=None, sep=' ', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.
    header, if specified, will be printed as the first row.
    numfmt is the format for all numbers; you might want e.g. '{:.2f}'.
    (If you want different formats in different columns,
    don't use print_table.) sep is the separator between columns.

    The caller's table is not modified.
    """
    # Column justification comes from the first data row (not the header):
    # numbers are right-justified, everything else left-justified.
    justs = ['rjust' if isnumber(x) else 'ljust' for x in table[0]]
    # Build a new list instead of inserting into the caller's table.
    rows = [header] + list(table) if header else list(table)
    rows = [[numfmt.format(x) if isnumber(x) else x for x in row] for row in rows]
    sizes = [max(len(str(cell)) for cell in column) for column in zip(*rows)]
    for row in rows:
        print(sep.join(getattr(str(x), j)(size) for (j, size, x) in zip(justs, sizes, row)))
def open_data(name, mode='r'):
    """Open file `name` inside the 'aima-data' directory beside this module.

    Returns the open file object; the caller is responsible for closing it.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'aima-data')
    return open(os.path.join(data_dir, name), mode=mode)
def failure_test(algorithm, tests):
    """Grades the given algorithm based on how many tests it passes.
    Most algorithms have arbitrary output on correct execution, which is difficult
    to check for correctness. On the other hand, a lot of algorithms output something
    particular on fail (for example, False, or None).
    tests is a list with each element in the form: (values, failure_output).

    A test passes when algorithm(values) != failure_output. The return value
    is the mean pass rate.
    """
    from statistics import mean
    passed = [int(algorithm(values) != failure) for values, failure in tests]
    return mean(passed)
# ______________________________________________________________________________
# Expressions
# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names
"""A mathematical expression with an operator and 0 or more arguments.
op is a str like '+' or 'sin'; args are Expressions.
Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary."""
Surya Teja Cheedella
a validé
def __init__(self, op, *args):
Surya Teja Cheedella
a validé
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
def __neg__(self):
    """-self builds Expr('-', self)."""
    return Expr('-', self)
def __pos__(self):
    """+self builds Expr('+', self)."""
    return Expr('+', self)
def __invert__(self):
    """~self builds Expr('~', self)."""
    return Expr('~', self)
def __add__(self, rhs):
    """self + rhs builds Expr('+', self, rhs)."""
    return Expr('+', self, rhs)
def __sub__(self, rhs):
    """self - rhs builds Expr('-', self, rhs)."""
    return Expr('-', self, rhs)
def __mul__(self, rhs):
    """self * rhs builds Expr('*', self, rhs)."""
    return Expr('*', self, rhs)
def __pow__(self, rhs):
    """self ** rhs builds Expr('**', self, rhs)."""
    return Expr('**', self, rhs)
def __mod__(self, rhs):
    """self % rhs builds Expr('%', self, rhs)."""
    return Expr('%', self, rhs)
def __and__(self, rhs):
    """self & rhs builds Expr('&', self, rhs)."""
    return Expr('&', self, rhs)
def __xor__(self, rhs):
    """self ^ rhs builds Expr('^', self, rhs)."""
    return Expr('^', self, rhs)
def __rshift__(self, rhs):
    """self >> rhs builds Expr('>>', self, rhs)."""
    return Expr('>>', self, rhs)
def __lshift__(self, rhs):
    """self << rhs builds Expr('<<', self, rhs)."""
    return Expr('<<', self, rhs)
def __truediv__(self, rhs):
    """self / rhs builds Expr('/', self, rhs)."""
    return Expr('/', self, rhs)
def __floordiv__(self, rhs):
    """self // rhs builds Expr('//', self, rhs)."""
    return Expr('//', self, rhs)
def __matmul__(self, rhs):
    """self @ rhs builds Expr('@', self, rhs)."""
    return Expr('@', self, rhs)
if isinstance(rhs, Expression):
return Expr('|', self, rhs)
Surya Teja Cheedella
a validé
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
def __radd__(self, lhs):
    """lhs + self builds Expr('+', lhs, self)."""
    return Expr('+', lhs, self)
def __rsub__(self, lhs):
    """lhs - self builds Expr('-', lhs, self)."""
    return Expr('-', lhs, self)
def __rmul__(self, lhs):
    """lhs * self builds Expr('*', lhs, self)."""
    return Expr('*', lhs, self)
def __rdiv__(self, lhs):
    """Python 2 spelling of reflected '/'; Python 3 dispatches to __rtruediv__ instead."""
    return Expr('/', lhs, self)
def __rpow__(self, lhs):
    """lhs ** self builds Expr('**', lhs, self)."""
    return Expr('**', lhs, self)
def __rmod__(self, lhs):
    """lhs % self builds Expr('%', lhs, self)."""
    return Expr('%', lhs, self)
def __rand__(self, lhs):
    """lhs & self builds Expr('&', lhs, self)."""
    return Expr('&', lhs, self)
def __rxor__(self, lhs):
    """lhs ^ self builds Expr('^', lhs, self)."""
    return Expr('^', lhs, self)
def __ror__(self, lhs):
    """lhs | self builds Expr('|', lhs, self)."""
    return Expr('|', lhs, self)
def __rrshift__(self, lhs):
    """lhs >> self builds Expr('>>', lhs, self)."""
    return Expr('>>', lhs, self)
def __rlshift__(self, lhs):
    """lhs << self builds Expr('<<', lhs, self)."""
    return Expr('<<', lhs, self)
def __rtruediv__(self, lhs):
    """lhs / self builds Expr('/', lhs, self)."""
    return Expr('/', lhs, self)
def __rfloordiv__(self, lhs):
    """lhs // self builds Expr('//', lhs, self)."""
    return Expr('//', lhs, self)
def __rmatmul__(self, lhs):
    """lhs @ self builds Expr('@', lhs, self)."""
    return Expr('@', lhs, self)
Surya Teja Cheedella
a validé
def __call__(self, *args):
    """Call: if 'f' is a Symbol, then f(0) == Expr('f', 0).

    Raises ValueError if self already has arguments, i.e. is not a Symbol.
    """
    if self.args:
        raise ValueError('Can only do a call for a Symbol, not an Expr')
    else:
        return Expr(self.op, *args)
Surya Teja Cheedella
a validé
def __eq__(self, other):
    """'x == y' evaluates to True or False; does not build an Expr.

    Two Exprs are equal when their op and args are equal; any non-Expr compares unequal.
    """
    return isinstance(other, Expr) and self.op == other.op and self.args == other.args
Surya Teja Cheedella
a validé
Donato Meoli
a validé
def __lt__(self, other):
    """Order Exprs by their string form; a non-Expr other always gives False."""
    if not isinstance(other, Expr):
        return False
    return str(self) < str(other)
Donato Meoli
a validé
def __hash__(self):
    """Hash built from op and args, the same fields __eq__ compares."""
    op_hash = hash(self.op)
    args_hash = hash(self.args)
    return op_hash ^ args_hash
Surya Teja Cheedella
a validé
if op.isidentifier(): # f(x) or f(x, y)
return '{}({})'.format(op, ', '.join(args)) if args else op
elif len(args) == 1: # -x or -(x + 1)
opp = (' ' + op + ' ')
return '(' + opp.join(args) + ')'
# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.
Number = (int, float, complex)  # tuple of numeric types, usable directly as an isinstance() target
"""A Symbol is just an Expr with no args."""
"""Return a tuple of Symbols; names is a comma/whitespace delimited str."""
return tuple(Symbol(name) for name in names.replace(',', ' ').split())
"""Yield the subexpressions of an Expression (including x itself)."""
yield x
if isinstance(x, Expr):
for arg in x.args:
yield from subexpressions(arg)
"""The number of sub-expressions in this expression."""
if isinstance(expression, Expr):
return len(expression.args)
else: # expression is a number
# For operators that are not defined in Python, we allow new InfixOps:
class PartialExpr:
    """Given 'P |'==>'| Q, first form PartialExpr('==>', P), then combine with Q.

    Or-ing this partial expression with a right operand completes it as
    Expr(op, lhs, rhs).
    """

    def __init__(self, op, lhs):
        self.op = op    # infix operator string, e.g. '==>'
        self.lhs = lhs  # left operand captured so far

    def __or__(self, rhs):
        """Finish the expression with its right-hand operand."""
        return Expr(self.op, self.lhs, rhs)

    def __repr__(self):
        return "PartialExpr('{}', {})".format(self.op, self.lhs)
def expr(x):
    """Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix |'==>'|, as are <== and <=>.
    If x is already an Expression, it is returned unchanged. Example:
    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)

    Any non-str argument is returned as-is. Security note: a str argument is
    passed to eval(), so never call this on untrusted text.
    """
    if not isinstance(x, str):
        return x
    # Unknown identifiers resolve through defaultkeydict, which builds a Symbol for each name.
    return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol))
"""Given a str, return a new str with ==> replaced by |'==>'|, etc.
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.
    >>> d = defaultkeydict(len); d['four']
    4
    """

    def __missing__(self, key):
        # Compute the default from the key itself, then cache it in the dict.
        value = self.default_factory(key)
        self[key] = value
        return value
Surya Teja Cheedella
a validé
"""Allows hashing by representing a dictionary as tuple of key:value pairs.
May cause problems as the hash value may change during runtime."""
# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue
# Stack and FIFOQueue are implemented as list and collection.deque
# PriorityQueue is implemented here
class PriorityQueue:
    """A Queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first.
    If order is 'min', the item with minimum f(x) is
    returned first; if order is 'max', then it is the item with maximum f(x).
    Also supports dict-like lookup.

    Entries are stored on self.heap as (priority, item) tuples, where priority
    is f(item), or -f(item) for order='max'. When priorities tie, the items
    themselves are compared.
    """

    def __init__(self, order='min', f=lambda x: x):
        self.heap = []
        if order == 'min':
            self.f = f
        elif order == 'max':  # now item with max f(x)
            self.f = lambda x: -f(x)  # will be popped first
        else:
            raise ValueError("Order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item (with min or max f(x) value)
        depending on the order."""
        if self.heap:
            return heapq.heappop(self.heap)[1]
        else:
            raise Exception('Trying to pop from empty PriorityQueue.')

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if the key is in PriorityQueue."""
        return any(item == key for _, item in self.heap)

    def __getitem__(self, key):
        """Returns the first value associated with key in PriorityQueue.
        Raises KeyError if key is not present.

        The value is the stored priority (negated f(key) when order='max').
        """
        for value, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first occurrence of key."""
        try:
            del self.heap[[item == key for _, item in self.heap].index(True)]
        except ValueError:
            raise KeyError(str(key) + " is not in the priority queue") from None
        # Removing an arbitrary slot can break the heap invariant, so restore it.
        heapq.heapify(self.heap)
# ______________________________________________________________________________
# Useful Shorthands
class Bool(int):
"""Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'."""
__str__ = __repr__ = lambda self: 'T' if self else 'F'