notebook.py 41,2 ko
Newer Older
from inspect import getsource

from utils import argmax, argmin
from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity
from logic import parse_definite_clause, standardize_variables, unify, subst
from IPython.display import HTML, display
Anthony Marakis's avatar
Anthony Marakis a validé
from collections import Counter, defaultdict
import numpy as np
from PIL import Image
Anthony Marakis's avatar
Anthony Marakis a validé
import time
# ______________________________________________________________________________
# Magic Words
def pseudocode(algorithm):
    """Fetch the pseudocode of the given algorithm and return it as Markdown.

    Spaces in `algorithm` are replaced by hyphens to form the file name of
    a ``.md`` document in the aima-pseudocode GitHub repository.  The first
    line of the downloaded text is dropped, the remainder is stripped and a
    '#' is prepended before it is wrapped in an IPython Markdown object.

    Errors raised by urlopen (e.g. for an unknown algorithm name or when
    the network is unavailable) propagate to the caller.
    """
    from urllib.request import urlopen
    from IPython.display import Markdown

    algorithm = algorithm.replace(' ', '-')
    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # Close the HTTP response deterministically instead of leaking the handle.
    with urlopen(url) as response:
        md = response.read().decode('utf-8')
    # Drop the first line of the document, keep the rest.
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)


def psource(*functions):
    """Show the source code of the given function(s).

    The individual sources are joined with one blank line between them.
    When pygments can be imported the result is displayed as
    syntax-highlighted HTML; otherwise the plain text is printed.
    """
    sources = [getsource(fn) for fn in functions]
    source_code = '\n\n'.join(sources)
    try:
        from pygments import highlight
        from pygments.lexers import PythonLexer
        from pygments.formatters import HtmlFormatter

        lexer = PythonLexer()
        formatter = HtmlFormatter(full=True)
        highlighted = highlight(source_code, lexer, formatter)
        display(HTML(highlighted))

    except ImportError:
        # pygments is optional: fall back to unstyled output.
        print(source_code)
# ______________________________________________________________________________
# Iris Visualization
    """Plots the iris dataset in a 3D plot.
    The three axes are given by i, j and k,
    which correspond to three of the four iris features."""
    # NOTE(review): the `def` header that owns this body is not present in
    # this view of the file.  i, j and k are used below as indices into
    # `features`, so they are presumably its parameters -- TODO confirm.
    # NOTE(review): Axes3D is never referenced by name below; presumably it
    # is imported only so that projection='3d' is available -- TODO confirm.
    from mpl_toolkits.mplot3d import Axes3D

    # Start from matplotlib's default settings, then create one 3D axes.
    plt.rcParams.update(plt.rcParamsDefault)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # `buckets` is indexed by class name below; every entry is iterated as
    # a collection of rows that are themselves indexed by feature position.
    iris = DataSet(name="iris")
    buckets = iris.split_values_by_classes()

    # Axis labels for the three selected feature indices.
    features = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
    f1, f2, f3 = features[i], features[j], features[k]

    # For each class: the values of features i, j and k (prefixes a, b, c).
    a_setosa = [v[i] for v in buckets["setosa"]]
    b_setosa = [v[j] for v in buckets["setosa"]]
    c_setosa = [v[k] for v in buckets["setosa"]]

    a_virginica = [v[i] for v in buckets["virginica"]]
    b_virginica = [v[j] for v in buckets["virginica"]]
    c_virginica = [v[k] for v in buckets["virginica"]]

    a_versicolor = [v[i] for v in buckets["versicolor"]]
    b_versicolor = [v[j] for v in buckets["versicolor"]]
    c_versicolor = [v[k] for v in buckets["versicolor"]]


    # One scatter series per class: (colour, marker, x values, y values,
    # z values).  The loop names sl/sw/pl simply hold the values of features
    # i, j and k; they are not tied to any particular iris feature.
    for c, m, sl, sw, pl in [('b', 's', a_setosa, b_setosa, c_setosa),
                             ('g', '^', a_virginica, b_virginica, c_virginica),
                             ('r', 'o', a_versicolor, b_versicolor, c_versicolor)]:
        ax.scatter(sl, sw, pl, c=c, marker=m)

    ax.set_xlabel(f1)
    ax.set_ylabel(f2)
    ax.set_zlabel(f3)

    plt.show()


# ______________________________________________________________________________
# MNIST
def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
    """Load the training and test sets of an MNIST-style dataset.

    Four uncompressed IDX files are read from `path`:
    train-images-idx3-ubyte, train-labels-idx1-ubyte,
    t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte.

    Parameters:
        path: directory that contains the four files.
        fashion: when true, `path` is ignored and
            "aima-data/MNIST/Fashion" is read instead.

    Returns:
        A tuple (train_img, train_lbl, test_img, test_lbl).  The image
        arrays are int16 with one flattened image (rows*cols values) per
        row; the label arrays are int8.

    Raises:
        OSError if a file cannot be opened, struct.error if a header is
        too short, ValueError if a file holds fewer items than its header
        announces.

    Side effect: matplotlib's rcParams are reset to their defaults and then
    configured for displaying grayscale images.
    """
    import os
    import struct
    import numpy as np

    if fashion:
        path = "aima-data/MNIST/Fashion"

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    def read_images(filename):
        """Return the images of one IDX image file as a (count, rows*cols) int16 array."""
        # `with` guarantees the file is closed even when parsing fails.
        with open(os.path.join(path, filename), "rb") as img_file:
            # Header: four big-endian unsigned 32-bit integers.
            _magic_nr, size, rows, cols = struct.unpack(">IIII", img_file.read(16))
            pixels = np.frombuffer(img_file.read(), dtype=np.uint8,
                                   count=size * rows * cols)
        # One row per image.  Each set is shaped by its OWN dimensions (the
        # training rows used to be reshaped with the test set's column count).
        return pixels.reshape((size, rows * cols)).astype(np.int16)

    def read_labels(filename):
        """Return the labels of one IDX label file as an int8 array."""
        with open(os.path.join(path, filename), "rb") as lbl_file:
            # Header: two big-endian unsigned 32-bit integers.
            _magic_nr, size = struct.unpack(">II", lbl_file.read(8))
            labels = np.frombuffer(lbl_file.read(), dtype=np.int8, count=size)
        # frombuffer yields a read-only view; astype returns a writable copy.
        return labels.astype(np.int8)

    train_img = read_images("train-images-idx3-ubyte")
    train_lbl = read_labels("train-labels-idx1-ubyte")
    test_img = read_images("t10k-images-idx3-ubyte")
    test_lbl = read_labels("t10k-labels-idx1-ubyte")

    return (train_img, train_lbl, test_img, test_lbl)


# Display names of the ten classes of each dataset; the position in the list
# is the numeric label it stands for.
digit_classes = list(map(str, range(10)))
fashion_classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def show_MNIST(labels, images, samples=8, fashion=False):
    """Plot randomly chosen example images, one column per class.

    For every class, `samples` distinct images whose label equals the class
    index are picked at random and drawn as 28x28 pictures from top to
    bottom; the class name is used as the title of the first row.
    `fashion` selects the apparel class names instead of the digit names.
    """
    classes = fashion_classes if fashion else digit_classes
    num_classes = len(classes)

    for y, cls in enumerate(classes):
        matching = np.nonzero([label == y for label in labels])[0]
        chosen = np.random.choice(matching, samples, replace=False)
        for row, idx in enumerate(chosen):
            # Subplots are numbered row by row, starting at 1.
            plt.subplot(samples, num_classes, row * num_classes + y + 1)
            plt.imshow(images[idx].reshape((28, 28)))
            plt.axis("off")
            if row == 0:
                plt.title(cls)

    plt.show()


def show_ave_MNIST(labels, images, fashion=False):
    """Plot the average image of every class side by side.

    For each class the number of images carrying that label is printed,
    their element-wise mean is computed and shown as a 28x28 picture titled
    with the class name.  `fashion` switches both the printed item type and
    the class names from digits to apparel.
    """
    if fashion:
        item_type, classes = "Apparel", fashion_classes
    else:
        item_type, classes = "Digit", digit_classes

    num_classes = len(classes)

    for y, cls in enumerate(classes):
        members = np.nonzero([label == y for label in labels])[0]
        print(item_type, y, ":", len(members), "images.")

        stacked = np.vstack([images[i] for i in members])
        ave_img = np.mean(stacked, axis=0)

        plt.subplot(1, num_classes, y + 1)
        plt.imshow(ave_img.reshape((28, 28)))
        plt.axis("off")
        plt.title(cls)

    plt.show()

# ______________________________________________________________________________
# MDP


def make_plot_grid_step_function(columns, rows, U_over_time):
    """Build the single-argument plotting function that ipywidgets needs.

    ipywidgets' interactive functions take one parameter, so the grid size
    and the utilities are captured here and the returned function only
    receives the iteration to display.
    """

    def plot_grid_step(iteration):
        # Cells without an entry for this iteration are shown as 0.
        utilities = defaultdict(int, U_over_time[iteration])
        # Highest row first, so the output looks like the book.
        grid = [[utilities[(column, row)] for column in range(columns)]
                for row in reversed(range(rows))]
        fig = plt.imshow(grid, cmap=plt.cm.bwr, interpolation='nearest')

        plt.axis('off')
        axes = fig.axes
        axes.get_xaxis().set_visible(False)
        axes.get_yaxis().set_visible(False)

        # Write every value, with two decimals, centred on its cell.
        for y, line in enumerate(grid):
            for x, value in enumerate(line):
                axes.text(x, y, "{0:.2f}".format(value), va='center', ha='center')

        plt.show()

    return plot_grid_step

def make_visualize(slider):
    """Take a slider as input and return a callback function
    for timer and animation.

    The callback receives (Visualize, time_step).  Only when Visualize is
    exactly True does it assign every integer from slider.min to slider.max
    (inclusive) to slider.value, sleeping time_step seconds after each one.
    """

    def visualize_callback(Visualize, time_step):
        if Visualize is not True:
            return
        first, last = slider.min, slider.max
        for position in range(first, last + 1):
            slider.value = position
            time.sleep(float(time_step))

    return visualize_callback

# ______________________________________________________________________________
# HTML/JS template emitted once per Canvas object.  Canvas.__init__ fills the
# placeholders with: {0} the canvas id (also the prefix of the JS object
# "<id>_canvas_object"), {1} the width, {2} the height and {3} the Python
# variable name handed to click_callback.
_canvas = """
<script type="text/javascript" src="./js/canvas.js"></script>
<div>
<canvas id="{0}" width="{1}" height="{2}" style="background:rgba(158, 167, 184, 0.2);" onclick='click_callback(this, event, "{3}")'></canvas>
</div>

<script> var {0}_canvas_object = new Canvas("{0}");</script>
"""  # noqa


class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.
    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed.

    Drawing methods do not draw immediately: each one queues a JavaScript
    call through execute(), and update() sends the queued calls to the
    browser.  The *_n variants take coordinates normalized to [0, 1] and
    scale them by the canvas width/height."""

    def __init__(self, varname, width=800, height=600, cid=None):
        """Emit the canvas HTML.

        varname -- name of the Python variable that holds this object
        width, height -- canvas size in pixels
        cid -- id of the HTML canvas element; defaults to varname
        """
        self.name = varname
        self.cid = cid or varname
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        self.exec_list = []  # JavaScript statements waiting for update()
        display_html(self.html)

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        """Override this method to handle mouse movement at position (x, y)"""
        raise NotImplementedError

    def execute(self, exec_str):
        """Stores the command to be executed to a list which is used later during update()

        A non-string command is reported (printed and shown as an alert)
        and is not queued."""
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Received invalid execution command format")
            # Nothing valid to queue.  Returning here also avoids the
            # TypeError that concatenating a non-string below would raise.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1
        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font("{0}")'.format(font))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y); filled when 'fill' is true, outlined otherwise"""
        if fill:
            self.execute('fill_text("{0}", {1}, {2})'.format(txt, x, y))
        else:
            self.execute('stroke_text("{0}", {1}, {2})'.format(txt, x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        display_html('<script>alert("{0}")</script>'.format(message))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        exec_code = "<script>\n" + '\n'.join(self.exec_list) + "\n</script>"
        # Empty the queue before displaying so commands are never sent twice.
        self.exec_list = []
        display_html(exec_code)


def display_html(html_string):
    """Wrap the given string in an IPython HTML object and display it."""
    html_object = HTML(html_string)
    display(html_object)


################################################################################

class Canvas_TicTacToe(Canvas):
    """Play a 3x3 TicTacToe game on HTML canvas"""

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        """Set up a new game and draw the empty board.

        player_1 and player_2 must each be 'human', 'random' or
        'alphabeta'; anything else raises TypeError.
        """
        valid_players = ('human', 'random', 'alphabeta')
        if not (player_1 in valid_players and player_2 in valid_players):
            raise TypeError("Players must be one of {}".format(valid_players))
        Canvas.__init__(self, varname, width, height, cid)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        self.turn = 0
        self.strokeWidth(5)
        self.players = (player_1, player_2)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game by one move, or restart it once it is over.

        A human player moves to the clicked cell (clicks on illegal cells
        are ignored); for the other player types any click triggers their
        move.  After the game has ended only a click inside the restart
        button has an effect.
        """
        game = self.ttt
        player = self.players[self.turn]
        if game.terminal_test(self.state):
            on_restart = (0.55 <= x/self.width <= 0.95
                          and 6/7 <= y/self.height <= 6/7+1/8)
            if on_restart:
                self.state = game.initial
                self.turn = 0
                self.draw_board()
            return

        if player == 'human':
            # Map the pixel position to a 1-based (column, row) cell.
            move = (int(3*x/self.width) + 1, int(3*y/(self.height*6/7)) + 1)
            if move not in game.actions(self.state):
                # Invalid move
                return
        elif player == 'alphabeta':
            move = alphabeta_player(game, self.state)
        else:
            move = random_player(game, self.state)
        self.state = game.result(self.state, move)
        self.turn = 1 - self.turn
        self.draw_board()

    def draw_board(self):
        """Redraw the grid, all marks and the status line below the board."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1/20
        # The board uses the top 6/7 of the canvas: two horizontal grid
        # lines first, then two vertical ones.
        for frac in (1/3, 2/3):
            self.line_n(0 + offset, frac*6/7, 1 - offset, frac*6/7)
        for frac in (1/3, 2/3):
            self.line_n(frac, (0 + offset)*6/7, frac, (1 - offset)*6/7)

        board = self.state.board
        for cell, symbol in board.items():
            if symbol == 'X':
                self.draw_x(cell)
            elif symbol == 'O':
                self.draw_o(cell)

        def three_in_a_row(cells):
            """True when all the cells are occupied and hold the same mark."""
            return (all(cell in board for cell in cells)
                    and len({board[cell] for cell in cells}) == 1)

        if not self.ttt.terminal_test(self.state):
            # Print which player's turn it is
            self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
                        offset, 6/7 + offset)
        else:
            # End game message
            utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
            if utility == 0:
                self.text_n('Game Draw!', offset, 6/7 + offset)
            else:
                winner = 'O' if utility < 0 else 'X'
                self.text_n('Player {} wins!'.format(winner), offset, 6/7 + offset)
                # Find the 3 and draw a line
                red, green = ((255, 0), (0, 255))[self.turn]
                self.stroke(red, green, 0)
                for i in range(3):
                    column = [(i + 1, j + 1) for j in range(3)]
                    if three_in_a_row(column):
                        self.line_n(i/3 + 1/6, offset*6/7, i/3 + 1/6, (1 - offset)*6/7)
                    row = [(j + 1, i + 1) for j in range(3)]
                    if three_in_a_row(row):
                        self.line_n(offset, (i/3 + 1/6)*6/7, 1 - offset, (i/3 + 1/6)*6/7)
                main_diagonal = [(i + 1, i + 1) for i in range(3)]
                if three_in_a_row(main_diagonal):
                    self.line_n(offset, offset*6/7, 1 - offset, (1 - offset)*6/7)
                anti_diagonal = [(i + 1, 3 - i) for i in range(3)]
                if three_in_a_row(anti_diagonal):
                    self.line_n(offset, (1 - offset)*6/7, 1 - offset, offset*6/7)
            # restart button
            self.fill(0, 0, 255)
            self.rect_n(0.5 + offset, 6/7, 0.4, 1/8)
            self.fill(0, 0, 0)
            self.text_n('Restart', 0.5 + 2*offset, 13/14)

        self.update()

    def draw_x(self, position):
        """Draw a green cross in the cell given by its 1-based position."""
        self.stroke(0, 255, 0)
        x, y = (coord - 1 for coord in position)
        offset = 1/15
        left = x/3 + offset
        right = x/3 + 1/3 - offset
        top = (y/3 + offset)*6/7
        bottom = (y/3 + 1/3 - offset)*6/7
        self.line_n(left, top, right, bottom)
        self.line_n(right, top, left, bottom)

    def draw_o(self, position):
        """Draw a red circle in the cell given by its 1-based position."""
        self.stroke(255, 0, 0)
        x, y = (coord - 1 for coord in position)
        centre_x = x/3 + 1/6
        centre_y = (y/3 + 1/6)*6/7
        self.arc_n(centre_x, centre_y, 1/9, 0, 360)


class Canvas_minimax(Canvas):
    """Minimax for Fig52Extended on HTML canvas

    The tree has 40 nodes in four rows of 1, 3, 9 and 27 nodes; the
    children of node i are the nodes 3*i + 1 .. 3*i + 3.  The whole search
    is run once and recorded in self.change_list; every mouse click then
    replays the recorded events up to the next pause and redraws the tree.

    Recorded events:
        ('a', node)       push node onto the stack of highlighted nodes
        ('e', node)       mark node as explored (its utility gets shown)
        ('l', (node, j))  draw the edge from node to its j-th child thick
        ('p',)            pop the stack of highlighted nodes
        ('h',)            pause until the next click
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        """Create the canvas; util_list supplies the utilities of the leaf
        nodes 13..39 in order."""
        Canvas.__init__(self, varname, width, height, cid)
        self.utils = dict(zip(range(13, 40), util_list))
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40  # normalized side length of a node box
        self.node_pos = {}
        # Lay the rows out top to bottom; row i holds 3**i evenly spaced nodes.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       self.l/2 + (self.l + (1 - 5*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        self.explored = set(self.utils)  # the leaves are known from the start
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def minimax(self, node):
        """Run minimax from node, store the backed-up utilities in
        self.utils and record the visualization events in self.change_list.
        Returns the utility of node."""
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            max_a = argmax(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # max_node - 3*node - 1 is the index (0..2) of the chosen child.
            self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            min_a = argmin(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            # min_node - 3*node - 1 is the index (0..2) of the chosen child.
            self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator that replays the recorded events, yielding at every pause.

        The search itself runs when the generator is first advanced."""
        self.minimax(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay the events up to the next pause (if any remain) and redraw."""
        try:
            next(self.stack_manager)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Redraw the whole tree in its current state."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            # Explored nodes are white, the others grey.
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            # Outline of the box.
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            # From the bottom centre of the parent to the top centre of each child.
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                # Edges leaving nodes 1-3 are red, all the others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()


class Canvas_alphabeta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas

    The tree has 40 nodes in four rows of 1, 3, 9 and 27 nodes; the
    children of node i are the nodes 3*i + 1 .. 3*i + 3.  The search is run
    once and recorded in self.change_list; every mouse click then replays
    the recorded events up to the next pause and redraws the tree.

    Recorded events:
        ('a', node)          push node onto the stack of highlighted nodes
        ('ab', node, x, y)   store (x, y) as the pair displayed above node
        ('e', node)          mark node as explored
        ('l', (node, j))     draw the edge from node to its j-th child thick
        ('p',)               pop the stack of highlighted nodes
        ('h',)               pause until the next click
    """
    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        """Create the canvas; util_list supplies the utilities of the leaf
        nodes 13..39 in order."""
        Canvas.__init__(self, varname, width, height, cid)
        self.utils = {node:util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        # Normalized side length of a node box.
        self.l = 1/40
        self.node_pos = {}
        # Lay the rows out top to bottom; row i holds 3**i evenly spaced nodes.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       3*self.l/2 + (self.l + (1 - 6*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        # The leaves are known from the start.
        self.explored = {node for node in self.utils}
        self.pruned = set()
        # node -> pair of values displayed above the node while it is on the stack.
        self.ab = {}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def alphabeta_search(self, node):
        """Run alpha-beta search from node, store the computed values in
        self.utils and record the visualization events in self.change_list.
        Returns the value of node."""
        game = self.game
        player = game.to_move(node)

        # Functions used by alphabeta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                # A leaf is only highlighted for one step.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -infinity
            self.change_list.append(('a', node))
            # The running value v is recorded in the first slot of the pair.
            self.change_list.append(('ab',node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    max_node = game.result(node, a)
                    self.change_list.append(('ab',node, v, beta))
                if v >= beta:
                    # Cut-off: the remaining children are not searched.
                    # Note that self.pruned is updated while the search
                    # runs, not when the events are replayed.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                # max_node - 3*node - 1 is the index (0..2) of the best child.
                self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e',node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                # A leaf is only highlighted for one step.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = infinity
            self.change_list.append(('a', node))
            # The running value v is recorded in the second slot of the pair.
            self.change_list.append(('ab',node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    min_node = game.result(node, a)
                    self.change_list.append(('ab',node, alpha, v))
                if v <= alpha:
                    # Cut-off: the remaining children are not searched.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                # min_node - 3*node - 1 is the index (0..2) of the best child.
                self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e',node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -infinity, infinity)

    def stack_manager_gen(self):
        """Generator that replays the recorded events, yielding at every pause.

        The search itself runs when the generator is first advanced."""
        self.alphabeta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay the events up to the next pause (if any remain) and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Redraw the whole tree in its current state."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # alpha > beta
            if node not in self.explored and self.ab[node][0] > self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            # Explored nodes are white, or dark when they were cut off;
            # unexplored nodes are grey.
            if node in self.explored:
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            # Outline of the box.
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            # No value is written into nodes that were cut off.
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            # From the bottom centre of the parent to the top centre of each child.
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                # Edges leaving nodes 1-3 are red, all the others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l/2, y - self.l/10)
                self.text_n(beta, x + self.l, y - self.l/10)
        self.update()


class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas

    Runs backward chaining for `query` against `kb`, keeps the first proof
    found and draws it as a layered tree of goals. Clicking a node shows that
    node's text in the strip at the bottom of the canvas. When no proof is
    found, the drawing area is filled red instead.
    """

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        """Search for a proof of `query` in `kb` and draw the result.

        varname, width, height, cid -- forwarded unchanged to Canvas.__init__.
        kb -- knowledge base; must provide fetch_rules_for_goal(goal).
        query -- the goal to prove.
        """
        Canvas.__init__(self, varname, width, height, cid)
        self.kb = kb
        self.query = query
        # Size of a node box: self.l is its height, self.b its width
        # (same units as the coordinates passed to the *_n drawing calls,
        # presumably fractions of the canvas size).
        self.l = 1/20
        self.b = 3*self.l
        # Empty layout defaults, so mouse_click() and draw_table() also work
        # when there is no proof and make_table() is never called.
        self.table = []
        self.pos = {}
        self.edges = set()
        # Only the first proof is displayed, so take just that one from the
        # generator instead of enumerating every proof.
        first_proof = next(self.fol_bc_ask(), None)
        self.valid = first_proof is not None
        if self.valid:
            graph = first_proof[0][0]
            s = first_proof[1]
            # Re-apply the answer substitution until nothing changes, so
            # chained bindings are fully resolved in the displayed goals.
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        self.context = None
        self.draw_table()

    def fol_bc_ask(self):
        """Return a generator over the proofs of self.query in self.kb.

        Each item is a pair (proof, theta). `proof` is a one-element list
        [(goal, subproofs)], where `subproofs` is the concatenated list of
        (goal, subproofs) entries for the premises of the rule used, and
        `theta` is the substitution accumulated along that proof.
        """
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            # Try every rule that could conclude `goal`.
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for subproofs, theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)):
                    yield ([(goal, subproofs)], theta1)

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                # The substitution handed in is None: yield nothing.
                return
            if not goals:
                yield ([], theta)
                return
            first, rest = goals[0], goals[1:]
            for first_proof, theta1 in fol_bc_or(KB, subst(theta, first), theta):
                for rest_proof, theta2 in fol_bc_and(KB, rest, theta1):
                    yield (first_proof + rest_proof, theta2)

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof tree `graph` and store the result on self.

        graph -- a (goal, children) pair, each child having the same shape.

        Sets:
          self.table -- table[depth] lists the goals at that depth, in
                        depth-first visiting order.
          self.pos   -- maps (depth, index) to the (x, y) corner of the
                        node's box.
          self.edges -- set of (x1, y1, x2, y2) segments, one per
                        parent/child link.
        """
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            # Register `node` in its depth row, then recurse into children.
            if len(table) <= depth:
                table.append([])
            column = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, column), child_id))
            return (depth, column)

        dfs(graph, 0)
        # Rows share 0.85 of the height and the nodes of one row share 0.95
        # of the width, both after a 0.025 margin; each box is centred
        # inside its cell.
        y_off = 0.85/len(table)
        for i, row in enumerate(table):
            x_off = 0.95/len(row)
            for j, _ in enumerate(row):
                pos[(i, j)] = (0.025 + j*x_off + (x_off - self.b)/2, 0.025 + i*y_off + (y_off - self.l)/2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            # From the middle of the parent's y + l edge to the middle of
            # the child's y edge.
            edges.add((x1 + self.b/2, y1 + self.l, x2 + self.b/2, y2))

        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the node whose box contains the click, then redraw.

        x, y are divided by self.width / self.height before being compared
        with the stored node positions. A click outside every box leaves the
        current selection unchanged.
        """
        x, y = x/self.width, y/self.height
        for node in self.pos:
            xs, ys = self.pos[node]
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Redraw everything: proof tree (or red fill) and the text strip."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            # no proof was found
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        if self.context:
            caption = self.table[self.context[0]][self.context[1]]
        else:
            caption = "Click for text"
        self.text_n(caption, 0.025, 0.975)
        self.update()
############################################################################################################

#####################           Functions to assist plotting in search.ipynb            ####################

############################################################################################################
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib import lines

from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import time
from search import GraphProblem, romania_map

def show_map(graph_data, node_colors=None):
    """Plot the graph described by graph_data with networkx and matplotlib.

    graph_data -- dict with the keys 'graph_dict', 'node_colors',
                  'node_positions', 'node_label_positions' and 'edge_weights'.
    node_colors -- optional mapping from node to colour; when omitted (or
                   empty) graph_data['node_colors'] is used instead.
    """
    graph = nx.Graph(graph_data['graph_dict'])
    if not node_colors:
        node_colors = graph_data['node_colors']
    positions = graph_data['node_positions']
    label_positions = graph_data['node_label_positions']
    weights = graph_data['edge_weights']

    # set the size of the plot
    plt.figure(figsize=(18, 13))

    # draw the graph (both nodes and edges) at the given node positions
    nx.draw(graph,
            pos={name: positions[name] for name in graph.nodes()},
            node_color=[node_colors[name] for name in graph.nodes()],
            linewidths=0.3,
            edgecolors='k')

    # draw labels for nodes and put a white bounding box behind each of them
    label_handles = nx.draw_networkx_labels(graph, pos=label_positions, font_size=14)
    for handle in label_handles.values():
        handle.set_bbox(dict(facecolor='white', edgecolor='none'))

    # add edge labels to the graph
    nx.draw_networkx_edge_labels(graph, pos=positions, edge_labels=weights, font_size=14)

    # add a legend: one filled circle per node state
    legend_entries = (
        ("white", 'Un-explored'),
        ("orange", 'Frontier'),
        ("red", 'Currently Exploring'),
        ("gray", 'Explored'),
        ("green", 'Final Solution'),
    )
    legend_handles = tuple(
        lines.Line2D([], [], color=shade, marker='o', markersize=15, markerfacecolor=shade)
        for shade, _ in legend_entries
    )
    legend_labels = tuple(caption for _, caption in legend_entries)
    plt.legend(legend_handles, legend_labels,
               numpoints=1, prop={'size': 16}, loc=(.8, .75))

    # show the plot. No need to use in notebooks. nx.draw will show the graph itself.
    plt.show()
    
## helper functions for visualisations
   
def final_path_colors(initial_node_colors, problem, solution):
    """Return a node_colors dict with the final path painted green.

    The result is a copy of initial_node_colors in which problem.initial and
    every node in solution map to "green"; the input mapping is left
    untouched.
    """
    painted = dict(initial_node_colors)
    # the starting node first, then the nodes of the solution, all in green
    painted.update(dict.fromkeys([problem.initial, *solution], "green"))
    return painted

def display_visual(graph_data, user_input, algorithm=None, problem=None):
    initial_node_colors = graph_data['node_colors']
    if user_input == False:
        def slider_callback(iteration):
            # don't show graph for the first time running the cell calling this function
            try:
                show_map(graph_data, node_colors=all_node_colors[iteration])
            except:
                pass
        def visualize_callback(Visualize):
            if Visualize is True:
                button.value = False
                
                global all_node_colors
                
                iterations, all_node_colors, node = algorithm(problem)
                solution = node.solution()
                all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))
                
                slider.max = len(all_node_colors) - 1
                
                for i in range(slider.max + 1):
                    slider.value = i
                     #time.sleep(.5)
        
        slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
        slider_visual = widgets.interactive(slider_callback, iteration=slider)
        display(slider_visual)

        button = widgets.ToggleButton(value=False)
        button_visual = widgets.interactive(visualize_callback, Visualize=button)
        display(button_visual)
    
    if user_input == True:
        node_colors = dict(initial_node_colors)
        if isinstance(algorithm, dict):
            assert set(algorithm.keys()).issubset({"Breadth First Tree Search",
                                                       "Depth First Tree Search", 
                                                       "Breadth First Search", 
                                                       "Depth First Graph Search", 
                                                       "Uniform Cost Search", 
                                                       "A-star Search"})
            algo_dropdown = widgets.Dropdown(description="Search algorithm: ",
                                             options=sorted(list(algorithm.keys())),
                                             value="Breadth First Tree Search")