notebook.py 32,6 ko
Newer Older
from inspect import getsource

from utils import argmax, argmin
from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity
from logic import parse_definite_clause, standardize_variables, unify, subst
from IPython.display import HTML, display
Anthony Marakis's avatar
Anthony Marakis a validé
from collections import Counter, defaultdict
import numpy as np
Anthony Marakis's avatar
Anthony Marakis a validé
import time
# ______________________________________________________________________________
# Magic Words


def pseudocode(algorithm):
    """Fetch the pseudocode for the given algorithm and return it as Markdown.

    The algorithm name is turned into a file name by replacing spaces with
    dashes; the matching ``.md`` file is downloaded from the aima-pseudocode
    repository. The first line of the file is dropped and the remainder is
    prefixed with '#' before being wrapped in an IPython ``Markdown`` object.

    Requires network access; errors raised by ``urlopen`` propagate."""
    from urllib.request import urlopen
    from IPython.display import Markdown

    algorithm = algorithm.replace(' ', '-')
    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # The context manager closes the HTTP response even if read/decode fails.
    with urlopen(url) as f:
        md = f.read().decode('utf-8')
    # Drop the first line, then prefix the remaining text with '#'.
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)


def psource(*functions):
    """Show the source code of the given function(s).

    The sources are joined with a blank line between them. With pygments
    available the result is displayed as highlighted HTML; if importing it
    fails, the plain source text is printed instead."""
    source_code = '\n\n'.join(map(getsource, functions))
    try:
        from pygments import highlight
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer

        formatter = HtmlFormatter(full=True)
        highlighted = highlight(source_code, PythonLexer(), formatter)
        display(HTML(highlighted))

    except ImportError:
        print(source_code)


# ______________________________________________________________________________
# Iris Visualization
    """Plots the iris dataset in a 3D plot.
    The three axes are given by i, j and k,
    which correspond to three of the four iris features."""
    from mpl_toolkits.mplot3d import Axes3D

    plt.rcParams.update(plt.rcParamsDefault)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    iris = DataSet(name="iris")
    buckets = iris.split_values_by_classes()

    features = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
    f1, f2, f3 = features[i], features[j], features[k]

    a_setosa = [v[i] for v in buckets["setosa"]]
    b_setosa = [v[j] for v in buckets["setosa"]]
    c_setosa = [v[k] for v in buckets["setosa"]]

    a_virginica = [v[i] for v in buckets["virginica"]]
    b_virginica = [v[j] for v in buckets["virginica"]]
    c_virginica = [v[k] for v in buckets["virginica"]]

    a_versicolor = [v[i] for v in buckets["versicolor"]]
    b_versicolor = [v[j] for v in buckets["versicolor"]]
    c_versicolor = [v[k] for v in buckets["versicolor"]]


    for c, m, sl, sw, pl in [('b', 's', a_setosa, b_setosa, c_setosa),
                             ('g', '^', a_virginica, b_virginica, c_virginica),
                             ('r', 'o', a_versicolor, b_versicolor, c_versicolor)]:
        ax.scatter(sl, sw, pl, c=c, marker=m)

    ax.set_xlabel(f1)
    ax.set_ylabel(f2)
    ax.set_zlabel(f3)

    plt.show()

# ______________________________________________________________________________
# MNIST


def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
    """Load the MNIST train/test images and labels from IDX files.

    path: directory holding the four ``*-ubyte`` files.
    fashion: when true, ``path`` is overridden with "aima-data/MNIST/Fashion".

    Returns (train_img, train_lbl, test_img, test_lbl). Image arrays are
    int16 with one flattened image (rows*cols values) per row; label arrays
    are int8.

    Side effect: resets matplotlib's rcParams to their defaults and then sets
    the figure size, image interpolation and colour map used for plotting.

    Raises ValueError if a file holds fewer items than its header announces."""
    import os
    import struct
    import array
    import numpy as np

    if fashion:
        path = "aima-data/MNIST/Fashion"

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    def read_images(filename):
        """Read one image file into a (count, rows*cols) int16 array."""
        with open(os.path.join(path, filename), "rb") as img_file:
            # Header: four big-endian unsigned ints (magic, count, rows, cols).
            _magic, count, rows, cols = struct.unpack(">IIII", img_file.read(16))
            pixels = array.array("B", img_file.read())
        # Each image is reshaped with its OWN file's dimensions; reshape raises
        # ValueError when the payload is shorter than count*rows*cols.
        return np.array(pixels[:count*rows*cols], dtype=np.int16).reshape((count, rows*cols))

    def read_labels(filename):
        """Read one label file into a (count,) int8 array."""
        with open(os.path.join(path, filename), "rb") as lbl_file:
            # Header: two big-endian unsigned ints (magic, count).
            _magic, count = struct.unpack(">II", lbl_file.read(8))
            raw = array.array("b", lbl_file.read())
        if len(raw) < count:
            raise ValueError("{} holds {} labels, header announces {}".format(
                filename, len(raw), count))
        return np.array(raw[:count], dtype=np.int8)

    train_img = read_images("train-images-idx3-ubyte")
    train_lbl = read_labels("train-labels-idx1-ubyte")
    test_img = read_images("t10k-images-idx3-ubyte")
    test_lbl = read_labels("t10k-labels-idx1-ubyte")

    return (train_img, train_lbl, test_img, test_lbl)


# Display names for the ten digit labels; position in the list is the label.
digit_classes = list(map(str, range(10)))

# Display names for the ten apparel labels; position in the list is the label.
fashion_classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]


def show_MNIST(labels, images, samples=8, fashion=False):
    """Plot a grid of randomly chosen example images, one column per class.

    For every class, `samples` distinct images carrying that label are drawn
    at random and stacked in the class's column; the class name is the title
    of the first image in the column. Each image is reshaped to 28x28."""
    classes = fashion_classes if fashion else digit_classes
    n_columns = len(classes)

    for label, class_name in enumerate(classes):
        matching = np.nonzero([lbl == label for lbl in labels])[0]
        chosen = np.random.choice(matching, samples, replace=False)
        for grid_row, image_index in enumerate(chosen):
            # Subplots are numbered row by row, starting at 1.
            plt.subplot(samples, n_columns, grid_row * n_columns + label + 1)
            plt.imshow(images[image_index].reshape((28, 28)))
            plt.axis("off")
            if grid_row == 0:
                plt.title(class_name)

    plt.show()


def show_ave_MNIST(labels, images, fashion=False):
    """Plot the pixel-wise mean image of every class side by side.

    For each class the number of images carrying that label is printed, then
    their average is shown as a 28x28 image titled with the class name."""
    if fashion:
        item_type, classes = "Apparel", fashion_classes
    else:
        item_type, classes = "Digit", digit_classes

    n_columns = len(classes)

    for label, class_name in enumerate(classes):
        matching = np.nonzero([lbl == label for lbl in labels])[0]
        print(item_type, label, ":", len(matching), "images.")

        stacked = np.vstack([images[index] for index in matching])
        ave_img = np.mean(stacked, axis=0)

        plt.subplot(1, n_columns, label + 1)
        plt.imshow(ave_img.reshape((28, 28)))
        plt.axis("off")
        plt.title(class_name)

    plt.show()

# ______________________________________________________________________________
# MDP


def make_plot_grid_step_function(columns, rows, U_over_time):
    """Build a one-argument plotting callback for ipywidgets' interactive().

    The returned function takes an iteration index, looks up the utilities
    recorded for that iteration in U_over_time and draws them as a
    columns x rows colour grid with each value printed in its cell. Cells
    absent from the utilities are treated as 0."""

    def plot_grid_step(iteration):
        utilities = defaultdict(int, U_over_time[iteration])
        # Highest row index first, so the picture is laid out like the book.
        grid = [[utilities[(x, y)] for x in range(columns)]
                for y in reversed(range(rows))]
        image = plt.imshow(grid, cmap=plt.cm.bwr, interpolation='nearest')

        plt.axis('off')
        image.axes.get_xaxis().set_visible(False)
        image.axes.get_yaxis().set_visible(False)

        for line_no, line in enumerate(grid):
            for cell_no, value in enumerate(line):
                image.axes.text(cell_no, line_no, "{0:.2f}".format(value),
                                va='center', ha='center')

        plt.show()

    return plot_grid_step

def make_visualize(slider):
    """Take a slider as input and return a callback function
    for timer and animation.

    The callback receives (Visualize, time_step). Only when Visualize is
    exactly True does it step the slider's value from slider.min to
    slider.max inclusive, sleeping time_step seconds after each step."""

    def visualize_callback(Visualize, time_step):
        if Visualize is not True:
            return
        for position in range(slider.min, slider.max + 1):
            slider.value = position
            time.sleep(float(time_step))

    return visualize_callback

# ______________________________________________________________________________
# HTML/JS template that Canvas.__init__ formats and displays for each canvas.
# Placeholders, in the order Canvas.__init__ fills them:
#   {0} canvas id (also the prefix of the JS "<id>_canvas_object" variable),
#   {1} width, {2} height, {3} name handed to the JS click_callback.
_canvas = """
<script type="text/javascript" src="./js/canvas.js"></script>
<div>
<canvas id="{0}" width="{1}" height="{2}" style="background:rgba(158, 167, 184, 0.2);" onclick='click_callback(this, event, "{3}")'></canvas>
</div>

<script> var {0}_canvas_object = new Canvas("{0}");</script>
"""  # noqa


class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.
    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed.

    Drawing methods do not draw immediately: each queues a JavaScript call in
    exec_list, and update() sends the whole queue to the notebook."""

    def __init__(self, varname, width=800, height=600, cid=None):
        """Format the canvas HTML and display it.

        varname: name of the Python variable holding this object.
        width, height: canvas size, also used to scale the *_n methods.
        cid: canvas id; defaults to varname."""
        self.name = varname
        self.cid = cid or varname
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        self.exec_list = []
        display_html(self.html)

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        """Override this method to handle mouse movement to position (x, y)"""
        raise NotImplementedError

    def execute(self, exec_str):
        """Stores the command to be executed to a list which is used later during update()

        A non-string command is reported (printed and alerted) and dropped."""
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Recieved invalid execution command format")
            # Do not fall through: concatenating a non-string below would
            # raise TypeError right after the problem has been reported.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1
        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font("{0}")'.format(font))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y); filled when 'fill' is true, outlined otherwise"""
        if fill:
            self.execute('fill_text("{0}", {1}, {2})'.format(txt, x, y))
        else:
            self.execute('stroke_text("{0}", {1}, {2})'.format(txt, x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        display_html('<script>alert("{0}")</script>'.format(message))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        exec_code = "<script>\n" + '\n'.join(self.exec_list) + "\n</script>"
        self.exec_list = []
        display_html(exec_code)

def display_html(html_string):
    """Wrap the given markup string in an HTML object and display it."""
    markup = HTML(html_string)
    display(markup)


################################################################################

class Canvas_TicTacToe(Canvas):
    """Play a 3x3 TicTacToe game on HTML canvas.

    The grid occupies the top 6/7 of the canvas; the strip below it carries
    the status text and, once the game has ended, a restart button."""

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        """Check the player kinds, start a fresh game and draw the board.

        Raises TypeError when a player is not 'human', 'random' or 'alphabeta'."""
        valid_players = ('human', 'random', 'alphabeta')
        if any(kind not in valid_players for kind in (player_1, player_2)):
            raise TypeError("Players must be one of {}".format(valid_players))
        Canvas.__init__(self, varname, width, height, cid)
        self.players = (player_1, player_2)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        self.turn = 0
        self.strokeWidth(5)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game by one move, or restart it once it is over.

        A finished game only reacts to clicks on the restart button. Otherwise
        a human player moves to the clicked cell (clicks on unavailable cells
        are ignored), while a computer player moves on any click."""
        game = self.ttt
        if game.terminal_test(self.state):
            on_restart = (0.55 <= x/self.width <= 0.95
                          and 6/7 <= y/self.height <= 6/7+1/8)
            if on_restart:
                self.state = game.initial
                self.turn = 0
                self.draw_board()
            return

        kind = self.players[self.turn]
        if kind == 'human':
            # Map pixel coordinates to 1-based (column, row) grid cells.
            move = (int(3*x/self.width) + 1, int(3*y/(self.height*6/7)) + 1)
            if move not in game.actions(self.state):
                return  # invalid move
        elif kind == 'alphabeta':
            move = alphabeta_player(game, self.state)
        else:
            move = random_player(game, self.state)

        self.state = game.result(self.state, move)
        self.turn = 1 - self.turn
        self.draw_board()

    def draw_board(self):
        """Redraw the grid, the marks and the status strip, then flush to the canvas."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1/20
        for level in (1/3, 2/3):  # horizontal grid lines
            self.line_n(0 + offset, level*6/7, 1 - offset, level*6/7)
        for level in (1/3, 2/3):  # vertical grid lines
            self.line_n(level, (0 + offset)*6/7, level, (1 - offset)*6/7)

        board = self.state.board
        for cell in board:
            symbol = board[cell]
            if symbol == 'X':
                self.draw_x(cell)
            elif symbol == 'O':
                self.draw_o(cell)

        if not self.ttt.terminal_test(self.state):
            # Print which player's turn it is
            status = "Player {}'s move({})".format("XO"[self.turn], self.players[self.turn])
            self.text_n(status, offset, 6/7 + offset)
            self.update()
            return

        def completed(cells):
            """True when all the cells are occupied by one and the same mark."""
            return all([cell in board for cell in cells]) and \
                len({board[cell] for cell in cells}) == 1

        # End game message
        utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
        if utility == 0:
            self.text_n('Game Draw!', offset, 6/7 + offset)
        else:
            winner = 'O' if utility < 0 else 'X'
            self.text_n('Player {} wins!'.format(winner), offset, 6/7 + offset)
            # Find the 3 and draw a line
            red, green = ((255, 0), (0, 255))[self.turn]
            self.stroke(red, green, 0)
            for i in range(3):
                if completed([(i + 1, j + 1) for j in range(3)]):
                    self.line_n(i/3 + 1/6, offset*6/7, i/3 + 1/6, (1 - offset)*6/7)
                if completed([(j + 1, i + 1) for j in range(3)]):
                    self.line_n(offset, (i/3 + 1/6)*6/7, 1 - offset, (i/3 + 1/6)*6/7)
            if completed([(i + 1, i + 1) for i in range(3)]):
                self.line_n(offset, offset*6/7, 1 - offset, (1 - offset)*6/7)
            if completed([(i + 1, 3 - i) for i in range(3)]):
                self.line_n(offset, (1 - offset)*6/7, 1 - offset, offset*6/7)
        # restart button
        self.fill(0, 0, 255)
        self.rect_n(0.5 + offset, 6/7, 0.4, 1/8)
        self.fill(0, 0, 0)
        self.text_n('Restart', 0.5 + 2*offset, 13/14)

        self.update()

    def draw_x(self, position):
        """Draw a green cross in the cell at the 1-based grid position."""
        self.stroke(0, 255, 0)
        x, y = (coordinate - 1 for coordinate in position)
        offset = 1/15
        left, right = x/3 + offset, x/3 + 1/3 - offset
        top, bottom = (y/3 + offset)*6/7, (y/3 + 1/3 - offset)*6/7
        self.line_n(left, top, right, bottom)
        self.line_n(right, top, left, bottom)

    def draw_o(self, position):
        """Draw a red circle in the cell at the 1-based grid position."""
        self.stroke(255, 0, 0)
        x, y = (coordinate - 1 for coordinate in position)
        self.arc_n(x/3 + 1/6, (y/3 + 1/6)*6/7, 1/9, 0, 360)


class Canvas_minimax(Canvas):
    """Minimax for Fig52Extended on HTML canvas.

    The whole search is run once and recorded as a list of change events;
    each mouse click then replays events up to the next pause and redraws.

    Event codes stored in change_list:
      ('a', node)        push node on the highlighted stack
      ('e', node)        mark node as explored (its utility is shown)
      ('l', (node, k))   thicken the edge from node to its k-th child
      ('p',)             pop the highlighted stack
      ('h',)             pause until the next click"""

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        """util_list supplies the utilities of nodes 13..39, in order."""
        Canvas.__init__(self, varname, width, height, cid)
        self.utils = dict(zip(range(13, 40), util_list))
        self.game = Fig52Extended()
        # Shared dict: utilities written during the search are seen by the game.
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40
        self.node_pos = {}
        # Four levels of a ternary tree; nodes are numbered level by level.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in range(base, base + row_size):
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       self.l/2 + (self.l + (1 - 5*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        self.explored = set(self.utils)
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def minimax(self, node):
        """Run minimax from node, recording every step in change_list.

        Returns the minimax value of node."""
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            max_a = argmax(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # The children of node n are 3n+1..3n+3, so this is the child index 0..2.
            self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            min_a = argmin(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator that replays change_list, yielding at every pause event.

        The search itself runs when the generator is first advanced."""
        self.minimax(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay up to the next pause (a no-op once the replay is finished) and redraw."""
        next(self.stack_manager, None)
        self.draw_graph()

    def draw_graph(self):
        """Redraw all nodes, utilities and edges from the current replay state."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()

class Canvas_alphabeta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas

    The search is run once and recorded as a list of change events; every
    mouse click replays events up to the next pause and redraws the tree.

    Event codes stored in change_list:
      ('a', node)          push node on the highlighted stack
      ('ab', node, a, b)   store the pair (a, b) displayed beside node
      ('e', node)          mark node as explored
      ('l', (node, k))     thicken the edge from node to its k-th child
      ('p',)               pop the highlighted stack
      ('h',)               pause until the next click"""
    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        # util_list supplies the utilities of nodes 13..39, in order.
        self.utils = {node:util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        # Shared dict: utilities written during the search are seen by the game.
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40
        self.node_pos = {}
        # Four levels of a ternary tree; nodes are numbered level by level.
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       3*self.l/2 + (self.l + (1 - 6*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        self.explored = {node for node in self.utils}
        self.pruned = set()
        self.ab = {}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def alphabeta_search(self, node):
        """Run alpha-beta search from node, recording every step in change_list.

        Returns the value computed for node."""
        game = self.game
        player = game.to_move(node)

        # Functions used by alphabeta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                # Terminal nodes are highlighted for one step, then popped.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab',node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    max_node = game.result(node, a)
                    self.change_list.append(('ab',node, v, beta))
                if v >= beta:
                    # Cut-off: remaining children are skipped. Note that the
                    # pruned set is updated directly, not through change_list.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                # The children of node n are 3n+1..3n+3, so this is the child index 0..2.
                self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e',node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                # Terminal nodes are highlighted for one step, then popped.
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab',node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    min_node = game.result(node, a)
                    self.change_list.append(('ab',node, alpha, v))
                if v <= alpha:
                    # Cut-off: remaining children are skipped. Note that the
                    # pruned set is updated directly, not through change_list.
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e',node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -infinity, infinity)

    def stack_manager_gen(self):
        """Generator that replays change_list, yielding at every pause event.

        The search itself runs when the generator is first advanced."""
        self.alphabeta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                # Keep the two values following the node as the displayed pair.
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Replay up to the next pause (ignored once the replay is finished) and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Redraw nodes, utilities, edges and the stored pairs from the current replay state."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # alpha > beta
            if node not in self.explored and self.ab[node][0] > self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            # Utilities of pruned nodes are not shown.
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l/2, y - self.l/10)
                self.text_n(beta, x + self.l, y - self.l/10)
        self.update()


class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas.

    Draws the first proof found for the query as a tree of goal boxes. When
    no proof exists the drawing area is filled red instead. Clicking a box
    shows its goal in the text area at the bottom of the canvas."""

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        """Search for a proof of query in kb and draw the result."""
        Canvas.__init__(self, varname, width, height, cid)
        self.kb = kb
        self.query = query
        self.l = 1/20
        self.b = 3*self.l
        # Defaults so that mouse_click()/draw_table() also work without a proof.
        self.table = []
        self.pos = {}
        self.edges = set()
        # Only the first proof is drawn, so stop the lazy search right there.
        proof = next(self.fol_bc_ask(), None)
        self.valid = proof is not None
        if self.valid:
            graph = proof[0][0]
            s = proof[1]
            # Apply the substitution repeatedly until the graph stops changing.
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        self.context = None
        self.draw_table()

    def fol_bc_ask(self):
        """Return a generator of (proof graph, substitution) pairs for the query.

        The proof graph is a one-element list [(goal, subgraphs)], where
        subgraphs is the list of proof graphs of the goal's premises."""
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)):
                    yield ([(goal, theta1[0])], theta1[1])

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                pass
            elif not goals:
                yield ([], theta)
            else:
                first, rest = goals[0], goals[1:]
                for theta1 in fol_bc_or(KB, subst(theta, first), theta):
                    for theta2 in fol_bc_and(KB, rest, theta1[1]):
                        yield (theta1[0] + theta2[0], theta2[1])

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof graph: fill self.table, self.pos and self.edges.

        table[depth] lists the goals at that depth; pos maps (depth, index) to
        the normalized top-left corner of the goal's box; edges holds the
        parent-to-child line segments in normalized coordinates."""
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            if len(table) <= depth:
                table.append([])
            column = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, column), child_id))
            return (depth, column)

        dfs(graph, 0)
        y_off = 0.85/len(table)
        for i, row in enumerate(table):
            x_off = 0.95/len(row)
            for j, node in enumerate(row):
                pos[(i, j)] = (0.025 + j*x_off + (x_off - self.b)/2, 0.025 + i*y_off + (y_off - self.l)/2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            # From the bottom centre of the parent to the top centre of the child.
            edges.add((x1 + self.b/2, y1 + self.l, x2 + self.b/2, y2))

        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the box under the click, if any, and redraw."""
        x, y = x/self.width, y/self.height
        for node in self.pos:
            xs, ys = self.pos[node]
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Redraw the proof tree (or the red failure area) and the text area."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text", 0.025, 0.975)
        self.update()