from inspect import getsource
from utils import argmax, argmin
from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity
from logic import parse_definite_clause, standardize_variables, unify, subst
from learning import DataSet
from IPython.display import HTML, display
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np

import os
import struct
import array

#______________________________________________________________________________


def pseudocode(algorithm):
    """Return the AIMA pseudocode for `algorithm` as an IPython Markdown object.

    The text is downloaded from the aima-pseudocode GitHub repository, so this
    needs network access. The first line of the downloaded file is dropped and
    the remainder is stripped and prefixed with '#'.
    """
    from urllib.request import urlopen
    from IPython.display import Markdown

    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # Use the response as a context manager so the connection is always closed.
    with urlopen(url) as f:
        md = f.read().decode('utf-8')
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)


def psource(*functions):
    """Print the source code for the given function(s).

    If pygments is installed the source is displayed as syntax-highlighted
    HTML; otherwise it is printed as plain text.
    """
    source_code = '\n\n'.join(getsource(fn) for fn in functions)
    try:
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer
        from pygments import highlight
    except ImportError:
        print(source_code)
        return
    display(HTML(highlight(source_code, PythonLexer(), HtmlFormatter(full=True))))

# ______________________________________________________________________________


def show_iris(i=0, j=1, k=2):
    """Plots the iris dataset in a 3D plot.

    The three axes are given by i, j and k, which correspond to three of the
    four iris features (indices into each example of the dataset).
    """
    # Imported for its side effect of registering the '3d' projection on
    # older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    plt.rcParams.update(plt.rcParamsDefault)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    iris = DataSet(name="iris")
    buckets = iris.split_values_by_classes()

    features = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]

    # (class name, colour, marker) -- plotted in this order.
    styles = [("setosa", 'b', 's'),
              ("virginica", 'g', '^'),
              ("versicolor", 'r', 'o')]
    for cls, color, marker in styles:
        rows = buckets[cls]
        ax.scatter([v[i] for v in rows],
                   [v[j] for v in rows],
                   [v[k] for v in rows],
                   c=color, marker=marker)

    ax.set_xlabel(features[i])
    ax.set_ylabel(features[j])
    ax.set_zlabel(features[k])

    plt.show()

# ______________________________________________________________________________


def _read_idx_images(filename):
    """Read an IDX image file and return an int16 array of shape (size, rows*cols).

    The file starts with a 16-byte big-endian header (magic, size, rows, cols);
    the bytes that follow are read as unsigned 8-bit pixel values.
    """
    with open(filename, "rb") as f:
        _magic, size, rows, cols = struct.unpack(">IIII", f.read(16))
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    n_pixels = rows * cols
    # astype() copies, so the result is writable and independent of the buffer.
    return pixels[:size * n_pixels].astype(np.int16).reshape((size, n_pixels))


def _read_idx_labels(filename):
    """Read an IDX label file and return an int8 array with `size` labels.

    The file starts with an 8-byte big-endian header (magic, size); the bytes
    that follow are read as signed 8-bit values.
    """
    with open(filename, "rb") as f:
        _magic, size = struct.unpack(">II", f.read(8))
        labels = np.frombuffer(f.read(), dtype=np.int8)
    # copy() so the returned array owns (and may modify) its data.
    return labels[:size].copy()


def load_MNIST(path="aima-data/MNIST"):
    """Load the MNIST training and test sets from the IDX files in `path`.

    Also resets matplotlib's rcParams to settings suited to showing the digits.

    Returns:
        (train_img, train_lbl, test_img, test_lbl): image arrays are int16 of
        shape (n, rows*cols), one flattened image per row; label arrays are
        int8 of shape (n,).
    """
    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    # Each image set uses its own rows/cols from its own header. The old code
    # reshaped training images with the test set's column count.
    train_img = _read_idx_images(os.path.join(path, "train-images-idx3-ubyte"))
    train_lbl = _read_idx_labels(os.path.join(path, "train-labels-idx1-ubyte"))
    test_img = _read_idx_images(os.path.join(path, "t10k-images-idx3-ubyte"))
    test_lbl = _read_idx_labels(os.path.join(path, 't10k-labels-idx1-ubyte'))

    return (train_img, train_lbl, test_img, test_lbl)


def show_MNIST(labels, images, samples=8):
    """Plot up to `samples` randomly chosen images for each digit 0-9.

    Column d of the grid holds images whose label equals d. Each image is
    reshaped to 28x28 for display. A digit with fewer than `samples` examples
    shows all of them instead of raising.
    """
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    num_classes = len(classes)
    labels = np.asarray(labels)

    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(labels == y)
        # Sampling without replacement cannot draw more items than exist.
        idxs = np.random.choice(idxs, min(samples, len(idxs)), replace=False)
        for i, idx in enumerate(idxs):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples, num_classes, plt_idx)
            plt.imshow(images[idx].reshape((28, 28)))
            plt.axis("off")
            if i == 0:
                plt.title(cls)

    plt.show()


def show_ave_MNIST(labels, images):
    """Plot the pixel-wise mean image of each digit 0-9 and print class counts.

    A digit with no examples gets an empty, titled panel instead of raising.
    """
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    num_classes = len(classes)
    labels = np.asarray(labels)

    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(labels == y)
        print("Digit", y, ":", len(idxs), "images.")

        plt.subplot(1, num_classes, y + 1)
        if len(idxs):
            ave_img = np.mean(np.vstack([images[i] for i in idxs]), axis=0)
            plt.imshow(ave_img.reshape((28, 28)))
        plt.axis("off")
        plt.title(cls)

    plt.show()
# ______________________________________________________________________________


# HTML injected once per Canvas. It creates the <canvas> element and a JS object
# named "{0}_canvas_object", which every command queued by Canvas.execute()
# targets. Format args: {0}=canvas id, {1}=width, {2}=height, {3}=Python
# variable name passed to the click callback.
_canvas = """
<script type="text/javascript" src="./js/canvas.js"></script>
<div>
<canvas id="{0}" width="{1}" height="{2}" style="background:rgba(158, 167, 184, 0.2);" onclick='click_callback(this, event, "{3}")'></canvas>
</div>

<script> var {0}_canvas_object = new Canvas("{0}");</script>
"""  # noqa


class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.
    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed."""

    def __init__(self, varname, width=800, height=600, cid=None):
        self.name = varname
        self.cid = cid or varname
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        # JS statements queued by execute() and flushed by update().
        self.exec_list = []
        display_html(self.html)

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        raise NotImplementedError

    def execute(self, exec_str):
        """Stores the command to be executed to a list which is used later during update()"""
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Recieved invalid execution command format")
            # Do not queue it: concatenating a non-str below would raise TypeError.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1
        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font("{0}")'.format(font))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y)"""
        if fill:
            self.execute('fill_text("{0}", {1}, {2})'.format(txt, x, y))
        else:
            self.execute('stroke_text("{0}", {1}, {2})'.format(txt, x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        display_html('<script>alert("{0}")</script>'.format(message))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        # Flush every queued statement in a single <script> block, then reset the queue.
        exec_code = "<script>\n" + '\n'.join(self.exec_list) + "\n</script>"
        self.exec_list = []
        display_html(exec_code)


def display_html(html_string):
    """Render `html_string` as HTML output in the notebook."""
    display(HTML(html_string))


################################################################################


class Canvas_TicTacToe(Canvas):
    """Play a 3x3 TicTacToe game on HTML canvas.

    The board occupies the top 6/7 of the canvas; the bottom strip shows the
    status text and, once the game is over, a Restart button.
    """

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        valid_players = ('human', 'random', 'alphabeta')
        if player_1 not in valid_players or player_2 not in valid_players:
            raise TypeError("Players must be one of {}".format(valid_players))
        Canvas.__init__(self, varname, width, height, cid)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        self.turn = 0  # index into self.players; 0 plays 'X', 1 plays 'O'
        self.strokeWidth(5)
        self.players = (player_1, player_2)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game by one move (or restart it) in response to a click at (x, y)."""
        player = self.players[self.turn]
        if self.ttt.terminal_test(self.state):
            # Only the Restart button is active once the game has ended; any
            # other click is ignored instead of playing on a finished board.
            if 0.55 <= x/self.width <= 0.95 and 6/7 <= y/self.height <= 6/7+1/8:
                self.state = self.ttt.initial
                self.turn = 0
                self.draw_board()
            return

        if player == 'human':
            # Map pixel coordinates to 1-based (column, row) on the 3x3 board.
            x, y = int(3*x/self.width) + 1, int(3*y/(self.height*6/7)) + 1
            if (x, y) not in self.ttt.actions(self.state):
                # Invalid move
                return
            move = (x, y)
        elif player == 'alphabeta':
            move = alphabeta_player(self.ttt, self.state)
        else:
            move = random_player(self.ttt, self.state)
        self.state = self.ttt.result(self.state, move)
        self.turn ^= 1
        self.draw_board()

    def draw_board(self):
        """Redraw the grid, all marks, and the status / restart area."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1/20
        self.line_n(0 + offset, (1/3)*6/7, 1 - offset, (1/3)*6/7)
        self.line_n(0 + offset, (2/3)*6/7, 1 - offset, (2/3)*6/7)
        self.line_n(1/3, (0 + offset)*6/7, 1/3, (1 - offset)*6/7)
        self.line_n(2/3, (0 + offset)*6/7, 2/3, (1 - offset)*6/7)

        board = self.state.board
        for mark in board:
            if board[mark] == 'X':
                self.draw_x(mark)
            elif board[mark] == 'O':
                self.draw_o(mark)

        if self.ttt.terminal_test(self.state):
            # End game message
            utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
            if utility == 0:
                self.text_n('Game Draw!', offset, 6/7 + offset)
            else:
                self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6/7 + offset)
                # Find the 3 and draw a line
                self.stroke([255, 0][self.turn], [0, 255][self.turn], 0)
                for i in range(3):
                    if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \
                       len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1:
                        self.line_n(i/3 + 1/6, offset*6/7, i/3 + 1/6, (1 - offset)*6/7)
                    if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \
                       len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1:
                        self.line_n(offset, (i/3 + 1/6)*6/7, 1 - offset, (i/3 + 1/6)*6/7)
                if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \
                   len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1:
                    self.line_n(offset, offset*6/7, 1 - offset, (1 - offset)*6/7)
                if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \
                   len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1:
                    self.line_n(offset, (1 - offset)*6/7, 1 - offset, offset*6/7)
            # restart button
            self.fill(0, 0, 255)
            self.rect_n(0.5 + offset, 6/7, 0.4, 1/8)
            self.fill(0, 0, 0)
            self.text_n('Restart', 0.5 + 2*offset, 13/14)
        else:
            # Print which player's turn it is
            self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
                        offset, 6/7 + offset)

        self.update()

    def draw_x(self, position):
        """Draw an X (green) in the 1-based (column, row) cell `position`."""
        self.stroke(0, 255, 0)
        x, y = [i-1 for i in position]
        offset = 1/15
        self.line_n(x/3 + offset, (y/3 + offset)*6/7, x/3 + 1/3 - offset, (y/3 + 1/3 - offset)*6/7)
        self.line_n(x/3 + 1/3 - offset, (y/3 + offset)*6/7, x/3 + offset, (y/3 + 1/3 - offset)*6/7)

    def draw_o(self, position):
        """Draw an O (red) in the 1-based (column, row) cell `position`."""
        self.stroke(255, 0, 0)
        x, y = [i-1 for i in position]
        self.arc_n(x/3 + 1/6, (y/3 + 1/6)*6/7, 1/9, 0, 360)


# Step log codes shared by Canvas_minimax and Canvas_alphabeta, replayed one
# click at a time by stack_manager_gen():
#   ('a', node)        push node onto the highlighted stack
#   ('p',)             pop the stack
#   ('e', node)        mark node as explored (its value is shown)
#   ('l', (node, j))   thicken the edge from node to its j-th child
#   ('ab', node, a, b) record the alpha/beta values shown for node
#   ('h',)             halt: stop here until the next click


class Canvas_minimax(Canvas):
    """Minimax for Fig52Extended on HTML canvas"""

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        # Leaves are nodes 13..39; util_list gives their utilities in order.
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40  # normalized side length of a node box
        self.node_pos = {}
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       self.l/2 + (self.l + (1 - 5*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        self.explored = {node for node in self.utils}
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def minimax(self, node):
        """Run minimax from `node`, recording each visualisation step in self.change_list."""
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            max_a = argmax(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # Children of node n are 3n+1..3n+3, so this is the child's 0..2 index.
            self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            min_a = argmin(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator that replays self.change_list, pausing at every ('h',) step."""
        self.minimax(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation by one step and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw the game tree, highlighting stacked nodes and chosen edges."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()


class Canvas_alphabeta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas"""

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        # Leaves are nodes 13..39; util_list gives their utilities in order.
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40  # normalized side length of a node box
        self.node_pos = {}
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in [base + j for j in range(row_size)]:
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       3*self.l/2 + (self.l + (1 - 6*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []
        self.explored = {node for node in self.utils}
        self.pruned = set()
        self.ab = {}  # node -> (alpha, beta) values shown next to it
        self.thick_lines = set()
        self.change_list = []
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def alphabeta_search(self, node):
        """Run alpha-beta search from `node`, recording each visualisation step in self.change_list."""
        game = self.game
        player = game.to_move(node)

        # Functions used by alphabeta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab', node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    max_node = game.result(node, a)
                    self.change_list.append(('ab', node, v, beta))
                if v >= beta:
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                # Children of node n are 3n+1..3n+3, so this is the child's 0..2 index.
                self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab', node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    min_node = game.result(node, a)
                    self.change_list.append(('ab', node, alpha, v))
                if v <= alpha:
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -infinity, infinity)

    def stack_manager_gen(self):
        """Generator that replays self.change_list, pausing at every ('h',) step."""
        self.alphabeta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation by one step and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw the game tree with stack highlights, pruned nodes and alpha/beta labels."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # alpha > beta
            if node not in self.explored and self.ab[node][0] > self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l/2, y - self.l/10)
                self.text_n(beta, x + self.l, y - self.l/10)
        self.update()


class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas"""

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        self.kb = kb
        self.query = query
        self.l = 1/20       # normalized node-box height
        self.b = 3*self.l   # normalized node-box width
        # Only the first proof is drawn, so stop after finding it rather than
        # enumerating every (possibly unbounded) proof.
        first = next(self.fol_bc_ask(), None)
        if first is None:
            self.valid = False
            # Keep the layout attributes defined so mouse_click/draw_table work.
            self.table, self.pos, self.edges = [], {}, set()
        else:
            self.valid = True
            graph = first[0][0]
            s = first[1]
            # Apply the substitution repeatedly until the proof tree stops changing.
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        self.context = None
        self.draw_table()

    def fol_bc_ask(self):
        """Backward chaining that yields (proof, theta) pairs for self.query.

        Each proof node is a (goal, children) tuple.
        """
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)):
                    yield ([(goal, theta1[0])], theta1[1])

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                pass
            elif not goals:
                yield ([], theta)
            else:
                first, rest = goals[0], goals[1:]
                for theta1 in fol_bc_or(KB, subst(theta, first), theta):
                    for theta2 in fol_bc_and(KB, rest, theta1[1]):
                        yield (theta1[0] + theta2[0], theta2[1])

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof tree `graph` ((goal, children) tuples) row by depth.

        Sets self.table (goals per depth), self.pos ((depth, col) -> normalized
        top-left corner) and self.edges (normalized line segments parent -> child).
        """
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            if len(table) <= depth:
                table.append([])
            col = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, col), child_id))
            return (depth, col)

        dfs(graph, 0)
        y_off = 0.85/len(table)
        for i, row in enumerate(table):
            x_off = 0.95/len(row)
            for j, node in enumerate(row):
                pos[(i, j)] = (0.025 + j*x_off + (x_off - self.b)/2,
                               0.025 + i*y_off + (y_off - self.l)/2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            edges.add((x1 + self.b/2, y1 + self.l, x2 + self.b/2, y2))

        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the clicked node (if any) so its full text is shown, then redraw."""
        x, y = x/self.width, y/self.height
        for node in self.pos:
            xs, ys = self.pos[node]
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Draw the proof tree (or a red panel if no proof exists) and the text area."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text",
                    0.025, 0.975)
        self.update()