from inspect import getsource
from utils import argmax, argmin
from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity
from logic import parse_definite_clause, standardize_variables, unify, subst
from learning import DataSet
from IPython.display import HTML, display
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os, struct
import array
#______________________________________________________________________________
def pseudocode(algorithm):
    """Fetch the AIMA pseudocode for *algorithm* and return it as Markdown.

    The text is downloaded from ``md/<algorithm>.md`` in the aimacode
    aima-pseudocode GitHub repository. The first line of the file is
    dropped, the rest is stripped, and a '#' is prepended before the text
    is wrapped in an IPython ``Markdown`` object.

    Raises:
        urllib.error.URLError (or its subclass HTTPError): if the file
            cannot be fetched.
    """
    from urllib.request import urlopen
    from IPython.display import Markdown
    url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
    # Close the HTTP response (and its socket) as soon as it has been read.
    with urlopen(url) as f:
        md = f.read().decode('utf-8')
    # Drop the first line; the added '#' presumably demotes a leading
    # heading by one level -- TODO confirm against the .md files.
    md = md.split('\n', 1)[-1].strip()
    md = '#' + md
    return Markdown(md)
def psource(*functions):
    """Show the source code of one or more functions.

    The individual sources are separated by a blank line. When pygments is
    importable the code is displayed as syntax-highlighted HTML; otherwise
    it is printed as plain text.
    """
    chunks = [getsource(fn) for fn in functions]
    source_code = '\n\n'.join(chunks)
    try:
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer
        from pygments import highlight
        rendered = highlight(source_code, PythonLexer(), HtmlFormatter(full=True))
        display(HTML(rendered))
    except ImportError:
        # pygments is optional: fall back to plain text.
        print(source_code)
# ______________________________________________________________________________
# NOTE(review): this re-defines psource(), which is already defined just above
# with the same behaviour; this later definition is the one that takes effect.
def psource(*functions):
    """Display the source of the given function(s), highlighted if possible.

    Uses pygments to render highlighted HTML in the notebook; prints the raw
    source when pygments cannot be imported.
    """
    source_code = '\n\n'.join(map(getsource, functions))
    try:
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import PythonLexer
        from pygments import highlight
        lexer, formatter = PythonLexer(), HtmlFormatter(full=True)
        html_markup = highlight(source_code, lexer, formatter)
        display(HTML(html_markup))
    except ImportError:
        print(source_code)
# ______________________________________________________________________________
def show_iris(i=0, j=1, k=2):
    """Draw a 3D scatter plot of the iris data set.

    i, j and k index the four iris features (0 sepal length, 1 sepal width,
    2 petal length, 3 petal width) shown on the x, y and z axes.
    Setosa is drawn as blue squares, virginica as green triangles and
    versicolor as red circles.
    """
    # Imported for its side effect of providing the '3d' projection
    # (required by older matplotlib versions).
    from mpl_toolkits.mplot3d import Axes3D
    plt.rcParams.update(plt.rcParamsDefault)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    buckets = DataSet(name="iris").split_values_by_classes()
    feature_names = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
    x_label, y_label, z_label = (feature_names[axis] for axis in (i, j, k))

    styles = (("setosa", 'b', 's'), ("virginica", 'g', '^'), ("versicolor", 'r', 'o'))
    series = []
    for species, colour, marker in styles:
        rows = buckets[species]
        xs = [row[i] for row in rows]
        ys = [row[j] for row in rows]
        zs = [row[k] for row in rows]
        series.append((colour, marker, xs, ys, zs))

    for colour, marker, xs, ys, zs in series:
        ax.scatter(xs, ys, zs, c=colour, marker=marker)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_zlabel(z_label)
    plt.show()
# ______________________________________________________________________________
def load_MNIST(path="aima-data/MNIST"):
    """Load the MNIST training and test sets from IDX files in *path*.

    Expects the files ``train-images-idx3-ubyte``, ``train-labels-idx1-ubyte``,
    ``t10k-images-idx3-ubyte`` and ``t10k-labels-idx1-ubyte`` in *path*.
    Also resets matplotlib's rcParams to their defaults and then sets a
    10x8 figure size, nearest-neighbour interpolation and a gray colormap.

    Returns:
        (train_img, train_lbl, test_img, test_lbl). The image arrays have
        shape (n, rows*cols) and dtype int16; the label arrays have shape
        (n,) and dtype int8. For each split, n is the item count stored in
        its label file header.

    Raises:
        OSError: if a file cannot be opened.
        ValueError: if a file has the wrong IDX magic number or holds fewer
            items than its header/label count requires.
    """
    import os
    import struct
    import numpy as np

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    def check_magic(filename, magic, expected):
        # Reject files that are not the expected IDX kind (e.g. swapped names).
        if magic != expected:
            raise ValueError("{}: expected IDX magic number {}, got {}".format(
                filename, expected, magic))

    def read_labels(filename):
        # IDX1 header: magic number (2049) and item count, big-endian uint32.
        with open(os.path.join(path, filename), "rb") as f:
            magic, size = struct.unpack(">II", f.read(8))
            check_magic(filename, magic, 2049)
            data = f.read()
        # astype() copies, so the result is writable rather than a view of bytes.
        return np.frombuffer(data, dtype=np.int8, count=size).astype(np.int8)

    def read_images(filename, count):
        # IDX3 header: magic number (2051), item count, rows, cols (big-endian uint32).
        with open(os.path.join(path, filename), "rb") as f:
            magic, _, rows, cols = struct.unpack(">IIII", f.read(16))
            check_magic(filename, magic, 2051)
            data = f.read()
        # One flattened rows*cols row per image; `count` comes from the label
        # file so that images and labels stay aligned.
        pixels = np.frombuffer(data, dtype=np.uint8, count=count * rows * cols)
        return pixels.reshape((count, rows * cols)).astype(np.int16)

    train_lbl = read_labels("train-labels-idx1-ubyte")
    train_img = read_images("train-images-idx3-ubyte", len(train_lbl))
    test_lbl = read_labels("t10k-labels-idx1-ubyte")
    test_img = read_images("t10k-images-idx3-ubyte", len(test_lbl))
    return (train_img, train_lbl, test_img, test_lbl)
def show_MNIST(labels, images, samples=8):
    """Plot randomly chosen example images for each digit 0-9.

    Draws a grid with one column per digit and up to *samples* rows. Each
    column shows images whose label equals that digit, sampled without
    replacement, and the top image of each column is titled with the digit.

    labels: sequence of integer labels, one per image.
    images: indexable collection of images, each reshapeable to 28x28.
    samples: maximum number of images shown per digit. A digit with fewer
        matching images shows all of them instead of raising ValueError.
    """
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    num_classes = len(classes)
    labels = np.asarray(labels)
    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(labels == y)
        # np.random.choice(..., replace=False) cannot draw more items than
        # exist, so cap the sample size at the number of matching images.
        idxs = np.random.choice(idxs, min(samples, len(idxs)), replace=False)
        for i, idx in enumerate(idxs):
            # Row i, column y of a samples x num_classes grid (1-based index).
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples, num_classes, plt_idx)
            plt.imshow(images[idx].reshape((28, 28)))
            plt.axis("off")
            if i == 0:
                plt.title(cls)
    plt.show()
def show_ave_MNIST(labels, images):
    """Plot the pixel-wise mean image of each digit 0-9 in a single row.

    Also prints how many images carry each label. A digit with no images
    gets an empty, titled panel instead of raising an error.

    labels: sequence of integer labels, one per image.
    images: collection of equally shaped images; each per-digit mean image
        must be reshapeable to 28x28.
    """
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    num_classes = len(classes)
    labels = np.asarray(labels)
    images = np.asarray(images)
    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(labels == y)
        print("Digit", y, ":", len(idxs), "images.")
        plt.subplot(1, num_classes, y + 1)
        if len(idxs) > 0:
            # Mean over the selected images, one value per pixel.
            ave_img = images[idxs].mean(axis=0)
            plt.imshow(ave_img.reshape((28, 28)))
        plt.axis("off")
        plt.title(cls)
    plt.show()
# ______________________________________________________________________________
# HTML template for a Canvas; Canvas.__init__ fills it via
# str.format(cid, width, height, name).
# NOTE(review): in this copy the template body is empty (it contains no {}
# placeholders and formats to a lone newline), so the canvas markup/JS that
# the Canvas methods address appears to be missing -- TODO confirm.
_canvas = """
""" # noqa
class Canvas:
    """Inherit from this class to manage the HTML canvas element in jupyter notebooks.
    To create an object of this class any_name_xyz = Canvas("any_name_xyz")
    The first argument given must be the name of the object being created.
    IPython must be able to reference the variable name that is being passed.

    Drawing methods do not draw directly. Each one queues a JavaScript
    statement addressed to ``<cid>_canvas_object`` via execute(), and
    update() is responsible for running the queued statements.
    Methods ending in ``_n`` take coordinates normalized to [0, 1] and scale
    them by the canvas width/height.
    """

    def __init__(self, varname, width=800, height=600, cid=None):
        self.name = varname
        self.cid = cid or varname  # id used to address the JS canvas object
        self.width = width
        self.height = height
        self.html = _canvas.format(self.cid, self.width, self.height, self.name)
        self.exec_list = []  # queued JS statements, consumed by update()
        display_html(self.html)

    @staticmethod
    def _js_string(value):
        """Return str(value) as a double-quoted JavaScript string literal.

        JSON string syntax is valid JavaScript, so json.dumps escapes quotes,
        backslashes and newlines that would otherwise break the statement.
        """
        import json
        return json.dumps(str(value))

    def mouse_click(self, x, y):
        """Override this method to handle mouse click at position (x, y)"""
        raise NotImplementedError

    def mouse_move(self, x, y):
        """Override this method to handle the mouse moving to position (x, y)"""
        raise NotImplementedError

    def execute(self, exec_str):
        """Store the command to be executed in a list used later by update().

        exec_str: a method call on the JS canvas object, without the object
            prefix or trailing semicolon, e.g. "clear()".
        A non-string argument is reported (printed and alerted) and ignored.
        """
        if not isinstance(exec_str, str):
            print("Invalid execution argument:", exec_str)
            self.alert("Recieved invalid execution command format")
            # Stop here: concatenating a non-string below would raise TypeError.
            return
        prefix = "{0}_canvas_object.".format(self.cid)
        self.exec_list.append(prefix + exec_str + ';')

    def fill(self, r, g, b):
        """Changes the fill color to a color in rgb format"""
        self.execute("fill({0}, {1}, {2})".format(r, g, b))

    def stroke(self, r, g, b):
        """Changes the colors of line/strokes to rgb"""
        self.execute("stroke({0}, {1}, {2})".format(r, g, b))

    def strokeWidth(self, w):
        """Changes the width of lines/strokes to 'w' pixels"""
        self.execute("strokeWidth({0})".format(w))

    def rect(self, x, y, w, h):
        """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner"""
        self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h))

    def rect_n(self, xn, yn, wn, hn):
        """Similar to rect(), but the dimensions are normalized to fall between 0 and 1"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        w = round(wn * self.width)
        h = round(hn * self.height)
        self.rect(x, y, w, h)

    def line(self, x1, y1, x2, y2):
        """Draw a line from (x1, y1) to (x2, y2)"""
        self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2))

    def line_n(self, x1n, y1n, x2n, y2n):
        """Similar to line(), but the dimensions are normalized to fall between 0 and 1"""
        x1 = round(x1n * self.width)
        y1 = round(y1n * self.height)
        x2 = round(x2n * self.width)
        y2 = round(y2n * self.height)
        self.line(x1, y1, x2, y2)

    def arc(self, x, y, r, start, stop):
        """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'"""
        self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop))

    def arc_n(self, xn, yn, rn, start, stop):
        """Similar to arc(), but the dimensions are normalized to fall between 0 and 1
        The normalizing factor for radius is selected between width and height by
        seeing which is smaller."""
        x = round(xn * self.width)
        y = round(yn * self.height)
        r = round(rn * min(self.width, self.height))
        self.arc(x, y, r, start, stop)

    def clear(self):
        """Clear the HTML canvas"""
        self.execute("clear()")

    def font(self, font):
        """Changes the font of text"""
        self.execute('font({0})'.format(self._js_string(font)))

    def text(self, txt, x, y, fill=True):
        """Display a text at (x, y); str(txt) is drawn filled or stroked."""
        # Escape txt so quotes/backslashes in it cannot break the JS statement.
        literal = self._js_string(txt)
        if fill:
            self.execute('fill_text({0}, {1}, {2})'.format(literal, x, y))
        else:
            self.execute('stroke_text({0}, {1}, {2})'.format(literal, x, y))

    def text_n(self, txt, xn, yn, fill=True):
        """Similar to text(), but with normalized coordinates"""
        x = round(xn * self.width)
        y = round(yn * self.height)
        self.text(txt, x, y, fill)

    def alert(self, message):
        """Immediately display an alert"""
        # NOTE(review): the template string is empty, so `message` is dropped
        # and an empty HTML fragment is displayed; the alert markup appears to
        # be missing from this copy -- TODO confirm.
        display_html(''.format(message))

    def update(self):
        """Execute the JS code to execute the commands queued by execute()"""
        # NOTE(review): exec_code is an empty string, so the queued commands
        # are discarded without being sent to the notebook; a script wrapper
        # around the joined self.exec_list appears to be missing from this
        # copy -- TODO confirm.
        exec_code = ""
        self.exec_list = []
        display_html(exec_code)
def display_html(html_string):
    """Render *html_string* as HTML in the notebook output area."""
    html_obj = HTML(html_string)
    display(html_obj)
################################################################################
class Canvas_TicTacToe(Canvas):
    """Play a 3x3 TicTacToe game on HTML canvas"""

    # Layout, in normalized canvas coordinates: the board occupies the top 6/7
    # of the canvas; the strip below it holds the status text and, once the
    # game is over, the Restart button.

    def __init__(self, varname, player_1='human', player_2='random',
                 width=300, height=350, cid=None):
        """Create the game and draw the empty board.

        player_1 plays X (turn 0) and player_2 plays O (turn 1); each must be
        one of 'human', 'random' or 'alphabeta'.

        Raises:
            TypeError: if a player name is not one of the valid values.
        """
        valid_players = ('human', 'random', 'alphabeta')
        if player_1 not in valid_players or player_2 not in valid_players:
            raise TypeError("Players must be one of {}".format(valid_players))
        Canvas.__init__(self, varname, width, height, cid)
        self.ttt = TicTacToe()
        self.state = self.ttt.initial
        self.turn = 0  # index into self.players of the side to move
        self.strokeWidth(5)
        self.players = (player_1, player_2)
        self.font("20px Arial")
        self.draw_board()

    def mouse_click(self, x, y):
        """Advance the game in response to a click at pixel (x, y).

        When the game is over, only a click inside the Restart button has an
        effect (it resets the game). Otherwise a human player's click selects
        the board cell under the cursor, and a computer player's move is made
        on any click.
        """
        player = self.players[self.turn]
        if self.ttt.terminal_test(self.state):
            # Restart button area (see the rect_n call in draw_board).
            if 0.55 <= x/self.width <= 0.95 and 6/7 <= y/self.height <= 6/7+1/8:
                self.state = self.ttt.initial
                self.turn = 0
                self.draw_board()
            return
        if player == 'human':
            # Map pixels to a 1-based (column, row) cell of the board area.
            x, y = int(3*x/self.width) + 1, int(3*y/(self.height*6/7)) + 1
            if (x, y) not in self.ttt.actions(self.state):
                # Invalid move
                return
            move = (x, y)
        elif player == 'alphabeta':
            move = alphabeta_player(self.ttt, self.state)
        else:
            move = random_player(self.ttt, self.state)
        self.state = self.ttt.result(self.state, move)
        self.turn ^= 1  # flip between 0 and 1
        self.draw_board()

    def draw_board(self):
        """Redraw the grid, marks and status/end-of-game area, then flush."""
        self.clear()
        self.stroke(0, 0, 0)
        offset = 1/20
        # Two horizontal and two vertical grid lines.
        self.line_n(0 + offset, (1/3)*6/7, 1 - offset, (1/3)*6/7)
        self.line_n(0 + offset, (2/3)*6/7, 1 - offset, (2/3)*6/7)
        self.line_n(1/3, (0 + offset)*6/7, 1/3, (1 - offset)*6/7)
        self.line_n(2/3, (0 + offset)*6/7, 2/3, (1 - offset)*6/7)
        board = self.state.board
        for mark in board:
            if board[mark] == 'X':
                self.draw_x(mark)
            elif board[mark] == 'O':
                self.draw_o(mark)
        if self.ttt.terminal_test(self.state):
            # End game message
            # Utility from the first mover's view; a negative value is
            # reported as an O win, any other non-zero value as an X win.
            utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial))
            if utility == 0:
                self.text_n('Game Draw!', offset, 6/7 + offset)
            else:
                self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6/7 + offset)
                # Find the 3 and draw a line
                # self.turn was already flipped after the winning move, so
                # turn 1 (X just moved) gives green and turn 0 gives red,
                # matching the colours used by draw_x / draw_o.
                self.stroke([255, 0][self.turn], [0, 255][self.turn], 0)
                for i in range(3):
                    # Three equal marks sharing first coordinate i + 1:
                    # a vertical line (draw_x uses the first coordinate as x).
                    if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \
                       len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1:
                        self.line_n(i/3 + 1/6, offset*6/7, i/3 + 1/6, (1 - offset)*6/7)
                    # Three equal marks sharing second coordinate i + 1:
                    # a horizontal line.
                    if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \
                       len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1:
                        self.line_n(offset, (i/3 + 1/6)*6/7, 1 - offset, (i/3 + 1/6)*6/7)
                # Main diagonal (1,1)-(3,3).
                if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \
                   len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1:
                    self.line_n(offset, offset*6/7, 1 - offset, (1 - offset)*6/7)
                # Anti-diagonal (1,3)-(3,1).
                if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \
                   len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1:
                    self.line_n(offset, (1 - offset)*6/7, 1 - offset, offset*6/7)
            # restart button
            self.fill(0, 0, 255)
            self.rect_n(0.5 + offset, 6/7, 0.4, 1/8)
            self.fill(0, 0, 0)
            self.text_n('Restart', 0.5 + 2*offset, 13/14)
        else:  # Print which player's turn it is
            self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
                        offset, 6/7 + offset)
        self.update()

    def draw_x(self, position):
        """Draw a green X in the cell at 1-based *position* (column, row)."""
        self.stroke(0, 255, 0)
        x, y = [i-1 for i in position]  # to 0-based cell indices
        offset = 1/15
        self.line_n(x/3 + offset, (y/3 + offset)*6/7, x/3 + 1/3 - offset, (y/3 + 1/3 - offset)*6/7)
        self.line_n(x/3 + 1/3 - offset, (y/3 + offset)*6/7, x/3 + offset, (y/3 + 1/3 - offset)*6/7)

    def draw_o(self, position):
        """Draw a red O (full arc) centred in the cell at 1-based *position*."""
        self.stroke(255, 0, 0)
        x, y = [i-1 for i in position]  # to 0-based cell indices
        self.arc_n(x/3 + 1/6, (y/3 + 1/6)*6/7, 1/9, 0, 360)
class Canvas_minimax(Canvas):
    """Minimax for Fig52Extended on HTML canvas

    Visualizes minimax search over a 40-node tree with branching factor 3
    (rows of 1, 3, 9 and 27 nodes). Nodes 13-39 are leaves whose utilities
    are taken, in order, from *util_list*. The children of node n are
    3*n + 1 .. 3*n + 3. Each mouse click advances the animation one step.
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        # Leaf utilities; internal-node values are written here during search.
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40  # side of a node box, in normalized units
        # Lay nodes out row by row; row i holds 3**i evenly spaced nodes.
        self.node_pos = {}
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in range(base, base + row_size):
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       self.l/2 + (self.l + (1 - 5*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []  # nodes on the current search path (highlighted)
        self.explored = set(self.utils)  # nodes whose value is shown
        self.thick_lines = set()  # (parent, child_index) edges of chosen moves
        self.change_list = []  # recorded animation steps, replayed lazily
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def minimax(self, node):
        """Run minimax from *node*, recording animation steps in change_list.

        Step codes: ('a', n) push n on the search path, ('h',) pause for a
        click, ('l', (n, j)) highlight the edge from n to its j-th child,
        ('e', n) mark n explored, ('p',) pop the search path.
        Returns the minimax value of *node* for the player to move there.
        """
        game = self.game
        player = game.to_move(node)

        def max_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            max_a = argmax(game.actions(node), key=lambda x: min_value(game.result(node, x)))
            max_node = game.result(node, max_a)
            self.utils[node] = self.utils[max_node]
            # Children of n are 3*n + 1 + j, so this recovers the child index j.
            self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        def min_value(node):
            if game.terminal_test(node):
                return game.utility(node, player)
            self.change_list.append(('a', node))
            self.change_list.append(('h',))
            min_a = argmin(game.actions(node), key=lambda x: max_value(game.result(node, x)))
            min_node = game.result(node, min_a)
            self.utils[node] = self.utils[min_node]
            self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return self.utils[node]

        return max_value(node)

    def stack_manager_gen(self):
        """Generator that replays change_list, pausing at each ('h',) step."""
        self.minimax(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation by one step (if any remain) and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw the nodes, their values, and the tree edges, then flush."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges: from the bottom centre of each internal node to the top
        # centre of each of its three children.
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                # Edges leaving row-1 nodes are red, all others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        self.update()
class Canvas_alphabeta(Canvas):
    """Alpha-beta pruning for Fig52Extended on HTML canvas

    Visualizes alpha-beta search over a 40-node tree with branching factor 3
    (rows of 1, 3, 9 and 27 nodes). Nodes 13-39 are leaves whose utilities
    are taken, in order, from *util_list*. The children of node n are
    3*n + 1 .. 3*n + 3. Each mouse click advances the animation one step.
    """

    def __init__(self, varname, util_list, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        # Leaf utilities; internal-node values are written here during search.
        self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
        self.game = Fig52Extended()
        self.game.utils = self.utils
        self.nodes = list(range(40))
        self.l = 1/40  # side of a node box, in normalized units
        # Lay nodes out row by row; row i holds 3**i evenly spaced nodes
        # (rows start lower than in Canvas_minimax to leave room for labels).
        self.node_pos = {}
        for i in range(4):
            base = len(self.node_pos)
            row_size = 3**i
            for node in range(base, base + row_size):
                self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2,
                                       3*self.l/2 + (self.l + (1 - 6*self.l)/3)*i)
        self.font("12px Arial")
        self.node_stack = []  # nodes on the current search path (highlighted)
        self.explored = set(self.utils)  # nodes whose value is known
        self.pruned = set()  # nodes whose loop was cut off
        self.ab = {}  # node -> last recorded (lower, upper) bound pair
        self.thick_lines = set()  # (parent, child_index) edges of chosen moves
        self.change_list = []  # recorded animation steps, replayed lazily
        self.draw_graph()
        self.stack_manager = self.stack_manager_gen()

    def alphabeta_search(self, node):
        """Run alpha-beta from *node*, recording animation steps in change_list.

        Step codes: ('a', n) push n on the search path, ('ab', n, lo, hi)
        record n's current bounds, ('h',) pause for a click, ('l', (n, j))
        highlight the edge from n to its j-th child, ('e', n) mark n
        explored, ('p',) pop the search path.
        Max nodes record (v, beta) and min nodes record (alpha, v).
        Returns the value of *node* for the player to move there.
        """
        game = self.game
        player = game.to_move(node)

        # Functions used by alphabeta
        def max_value(node, alpha, beta):
            if game.terminal_test(node):
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = -infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab', node, v, beta))
            self.change_list.append(('h',))
            for a in game.actions(node):
                min_val = min_value(game.result(node, a), alpha, beta)
                if v < min_val:
                    v = min_val
                    max_node = game.result(node, a)
                    self.change_list.append(('ab', node, v, beta))
                if v >= beta:
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                alpha = max(alpha, v)
            self.utils[node] = v
            if node not in self.pruned:
                # Children of n are 3*n + 1 + j, so this recovers the child index j.
                self.change_list.append(('l', (node, max_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        def min_value(node, alpha, beta):
            if game.terminal_test(node):
                self.change_list.append(('a', node))
                self.change_list.append(('h',))
                self.change_list.append(('p',))
                return game.utility(node, player)
            v = infinity
            self.change_list.append(('a', node))
            self.change_list.append(('ab', node, alpha, v))
            self.change_list.append(('h',))
            for a in game.actions(node):
                max_val = max_value(game.result(node, a), alpha, beta)
                if v > max_val:
                    v = max_val
                    min_node = game.result(node, a)
                    self.change_list.append(('ab', node, alpha, v))
                if v <= alpha:
                    self.change_list.append(('h',))
                    self.pruned.add(node)
                    break
                beta = min(beta, v)
            self.utils[node] = v
            if node not in self.pruned:
                self.change_list.append(('l', (node, min_node - 3*node - 1)))
            self.change_list.append(('e', node))
            self.change_list.append(('p',))
            self.change_list.append(('h',))
            return v

        return max_value(node, -infinity, infinity)

    def stack_manager_gen(self):
        """Generator that replays change_list, pausing at each ('h',) step."""
        self.alphabeta_search(0)
        for change in self.change_list:
            if change[0] == 'a':
                self.node_stack.append(change[1])
            elif change[0] == 'ab':
                self.ab[change[1]] = change[2:]
            elif change[0] == 'e':
                self.explored.add(change[1])
            elif change[0] == 'h':
                yield
            elif change[0] == 'l':
                self.thick_lines.add(change[1])
            elif change[0] == 'p':
                self.node_stack.pop()

    def mouse_click(self, x, y):
        """Advance the animation by one step (if any remain) and redraw."""
        try:
            self.stack_manager.send(None)
        except StopIteration:
            pass
        self.draw_graph()

    def draw_graph(self):
        """Draw the nodes, values, edges and alpha/beta labels, then flush."""
        self.clear()
        # draw nodes
        self.stroke(0, 0, 0)
        self.strokeWidth(1)
        # highlight for nodes in stack
        for node in self.node_stack:
            x, y = self.node_pos[node]
            # The node's bounds have crossed, so it is being pruned. The search
            # prunes on equality too (v >= beta / v <= alpha), hence >=.
            if node not in self.explored and self.ab[node][0] >= self.ab[node][1]:
                self.fill(200, 100, 100)
            else:
                self.fill(200, 200, 0)
            self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5)
        for node in self.nodes:
            x, y = self.node_pos[node]
            if node in self.explored:
                if node in self.pruned:
                    self.fill(50, 50, 50)
                else:
                    self.fill(255, 255, 255)
            else:
                self.fill(200, 200, 200)
            self.rect_n(x, y, self.l, self.l)
            self.line_n(x, y, x + self.l, y)
            self.line_n(x, y, x, y + self.l)
            self.line_n(x + self.l, y + self.l, x + self.l, y)
            self.line_n(x + self.l, y + self.l, x, y + self.l)
            self.fill(0, 0, 0)
            if node in self.explored and node not in self.pruned:
                self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10)
        # draw edges: from the bottom centre of each internal node to the top
        # centre of each of its three children.
        for i in range(13):
            x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l
            for j in range(3):
                x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1]
                # Edges leaving row-1 nodes are red, all others green.
                if i in [1, 2, 3]:
                    self.stroke(200, 0, 0)
                else:
                    self.stroke(0, 200, 0)
                if (i, j) in self.thick_lines:
                    self.strokeWidth(3)
                else:
                    self.strokeWidth(1)
                self.line_n(x1, y1, x2, y2)
        # display alpha and beta
        for node in self.node_stack:
            if node not in self.explored:
                x, y = self.node_pos[node]
                alpha, beta = self.ab[node]
                self.text_n(alpha, x - self.l/2, y - self.l/10)
                self.text_n(beta, x + self.l, y - self.l/10)
        self.update()
class Canvas_fol_bc_ask(Canvas):
    """fol_bc_ask() on HTML canvas

    Runs first-order backward chaining for *query* against *kb* and draws
    the proof tree of the first solution found, one row per tree depth.
    Clicking a box shows its full text in the strip at the bottom. If no
    proof is found, the drawing area is filled red.
    """

    def __init__(self, varname, kb, query, width=800, height=600, cid=None):
        Canvas.__init__(self, varname, width, height, cid)
        self.kb = kb
        self.query = query
        self.l = 1/20  # box height, in normalized units
        self.b = 3*self.l  # box width, in normalized units
        # Only the first proof is drawn, so take one solution instead of
        # materializing every proof the (possibly unbounded) search yields.
        first_proof = next(self.fol_bc_ask(), None)
        if first_proof is None:
            self.valid = False
            # Empty layout so mouse_click() and draw_table() still work.
            self.table, self.pos, self.edges = [], {}, set()
        else:
            self.valid = True
            graph = first_proof[0][0]
            s = first_proof[1]
            # Re-apply the substitution until nothing changes, so chained
            # variable bindings are fully resolved in the displayed goals.
            while True:
                new_graph = subst(s, graph)
                if graph == new_graph:
                    break
                graph = new_graph
            self.make_table(graph)
        self.context = None  # (row, col) of the last clicked box, if any
        self.draw_table()

    def fol_bc_ask(self):
        """Return a generator of (proof, theta) pairs for self.query.

        proof is a list of (goal, subproofs) pairs, where subproofs has the
        same shape, and theta is the substitution that proves the goal.
        """
        KB = self.kb
        query = self.query

        def fol_bc_or(KB, goal, theta):
            for rule in KB.fetch_rules_for_goal(goal):
                lhs, rhs = parse_definite_clause(standardize_variables(rule))
                for theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)):
                    yield ([(goal, theta1[0])], theta1[1])

        def fol_bc_and(KB, goals, theta):
            if theta is None:
                # Unification failed: no solutions.
                pass
            elif not goals:
                yield ([], theta)
            else:
                first, rest = goals[0], goals[1:]
                for theta1 in fol_bc_or(KB, subst(theta, first), theta):
                    for theta2 in fol_bc_and(KB, rest, theta1[1]):
                        yield (theta1[0] + theta2[0], theta2[1])

        return fol_bc_or(KB, query, {})

    def make_table(self, graph):
        """Lay out the proof tree *graph*, a (goal, children) pair.

        Sets self.table (goals grouped by depth), self.pos (normalized
        top-left corner of each (row, col) box) and self.edges (line
        segments from each parent box to its children).
        """
        table = []
        pos = {}
        links = set()
        edges = set()

        def dfs(node, depth):
            # Append node to its depth row and return its (row, col) id.
            if len(table) <= depth:
                table.append([])
            col = len(table[depth])
            table[depth].append(node[0])
            for child in node[1]:
                child_id = dfs(child, depth + 1)
                links.add(((depth, col), child_id))
            return (depth, col)

        dfs(graph, 0)
        # Rows share the top 85% of the canvas; boxes are centred in their cell.
        y_off = 0.85/len(table)
        for i, row in enumerate(table):
            x_off = 0.95/len(row)
            for j, node in enumerate(row):
                pos[(i, j)] = (0.025 + j*x_off + (x_off - self.b)/2, 0.025 + i*y_off + (y_off - self.l)/2)
        for p, c in links:
            x1, y1 = pos[p]
            x2, y2 = pos[c]
            # From the bottom centre of the parent to the top centre of the child.
            edges.add((x1 + self.b/2, y1 + self.l, x2 + self.b/2, y2))
        self.table = table
        self.pos = pos
        self.edges = edges

    def mouse_click(self, x, y):
        """Select the box under pixel (x, y), if any, and redraw."""
        x, y = x/self.width, y/self.height
        for node in self.pos:
            xs, ys = self.pos[node]
            xe, ye = xs + self.b, ys + self.l
            if xs <= x <= xe and ys <= y <= ye:
                self.context = node
                break
        self.draw_table()

    def draw_table(self):
        """Draw the proof boxes and edges (or a red failure panel) and the text strip."""
        self.clear()
        self.strokeWidth(3)
        self.stroke(0, 0, 0)
        self.font("12px Arial")
        if self.valid:
            # draw nodes
            for i, j in self.pos:
                x, y = self.pos[(i, j)]
                self.fill(200, 200, 200)
                self.rect_n(x, y, self.b, self.l)
                self.line_n(x, y, x + self.b, y)
                self.line_n(x, y, x, y + self.l)
                self.line_n(x + self.b, y, x + self.b, y + self.l)
                self.line_n(x, y + self.l, x + self.b, y + self.l)
                self.fill(0, 0, 0)
                self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01)
            # draw edges
            for x1, y1, x2, y2 in self.edges:
                self.line_n(x1, y1, x2, y2)
        else:
            self.fill(255, 0, 0)
            self.rect_n(0, 0, 1, 1)
        # text area
        self.fill(255, 255, 255)
        self.rect_n(0, 0.9, 1, 0.1)
        self.strokeWidth(5)
        self.stroke(0, 0, 0)
        self.line_n(0, 0.9, 1, 0.9)
        self.font("22px Arial")
        self.fill(0, 0, 0)
        self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text", 0.025, 0.975)
        self.update()