csp.py 54,1 ko
Newer Older
    """Returns a function that is True when x is equal to val, False otherwise"""

    def isv(x):
        return val == x

    isv.__name__ = str(val) + "=="
    return isv


def ne_constraint(val):
    """Build a one-argument predicate that holds when its argument differs from val.

    The predicate evaluates ``val != x`` (val on the left) and is renamed to
    ``"<val>!="`` so that constraints built from it print readably.
    """
    def differs_from_val(candidate):
        return val != candidate

    differs_from_val.__name__ = "%s!=" % (val,)
    return differs_from_val


def no_heuristic(to_do):
    """Identity arc ordering: hand the to-do collection back untouched.

    GAC then pops arcs in whatever order the collection's own pop() yields.
    """
    unchanged = to_do
    return unchanged


def sat_up(to_do):
    """Wrap the (variable, constraint) arcs in a SortedSet keyed on 1 / arity.

    An arc's key is 1 divided by the number of variables in its constraint's
    scope, so arcs of low-arity constraints get the largest keys. GAC removes
    arcs with ``pop()``. Assuming SortedSet is sortedcontainers.SortedSet
    (imported elsewhere; confirm), ``pop()`` takes the last (largest-key) item,
    so arcs of small constraints are revised first.
    """
    def inverse_arity(arc):
        constraint = arc[1]
        return 1 / len(list(constraint.scope))

    return SortedSet(to_do, key=inverse_arity)


class ACSolver:
    """Solves a CSP with arc consistency and domain splitting"""

    def __init__(self, csp):
        """a CSP solver that uses arc consistency
        * csp is the CSP to be solved
        """
        self.csp = csp

    def GAC(self, orig_domains=None, to_do=None, arc_heuristic=sat_up):
        """
        Makes this CSP arc-consistent using Generalized Arc Consistency.

        orig_domains : the starting domains (variable -> domain). Defaults to
                       self.csp.domains. The dict is copied and its domains
                       are replaced, never mutated.
        to_do        : a set of (variable, constraint) arcs still to revise.
                       Defaults to every arc of every constraint. It is copied.
        arc_heuristic: called on the to-do set. It returns the collection that
                       arcs are popped from, which fixes the revision order.

        Returns a tuple (consistent, domains, checks):
          consistent -- False as soon as some domain becomes empty, else True
          domains    -- the reduced domains (arc-consistent when consistent)
          checks     -- number of constraint evaluations performed
        """
        if orig_domains is None:
            orig_domains = self.csp.domains
        if to_do is None:
            to_do = {(var, const) for const in self.csp.constraints for var in const.scope}
        else:
            to_do = to_do.copy()
        domains = orig_domains.copy()
        to_do = arc_heuristic(to_do)
        checks = 0
        while to_do:
            var, const = to_do.pop()
            other_vars = [ov for ov in const.scope if ov != var]
            new_domain = set()
            if not other_vars:  # unary constraint: test each value on its own
                for val in domains[var]:
                    if const.holds({var: val}):
                        new_domain.add(val)
                    checks += 1
            elif len(other_vars) == 1:  # binary: keep val if some partner value supports it
                other = other_vars[0]
                for val in domains[var]:
                    for other_val in domains[other]:
                        checks += 1
                        if const.holds({var: val, other: other_val}):
                            new_domain.add(val)
                            break
            else:  # general case: search for a supporting assignment of all other vars
                for val in domains[var]:
                    holds, checks = self.any_holds(domains, const, {var: val}, other_vars, checks=checks)
                    if holds:
                        new_domain.add(val)
            if new_domain != domains[var]:
                domains[var] = new_domain
                if not new_domain:
                    return False, domains, checks
                # var's domain shrank, so arcs of the other constraints on var
                # may have lost support and must be revised again.
                add_to_do = self.new_to_do(var, const).difference(to_do)
                to_do |= add_to_do
        return True, domains, checks

    def new_to_do(self, var, const):
        """
        Returns new elements to be added to to_do after the domain of
        variable var changed because of constraint const.

        The result holds every (other variable, constraint) arc of the
        constraints on var except const itself. domain_splitting passes
        const=None, which selects all constraints on var.
        """
        return {(nvar, nconst) for nconst in self.csp.var_to_const[var]
                if nconst != const
                for nvar in nconst.scope
                if nvar != var}

    def any_holds(self, domains, const, env, other_vars, ind=0, checks=0):
        """
        Returns (holds, checks). holds is True if Constraint const holds for
        some assignment that extends env with values for the variables in
        other_vars[ind:], drawn from domains. checks is the running count of
        constraint evaluations.
        env is a dictionary.
        Warning: this has side effects and changes the elements of env.
        """
        if ind == len(other_vars):
            return const.holds(env), checks + 1
        var = other_vars[ind]
        for val in domains[var]:
            env[var] = val  # extend env in place (the side effect warned about above)
            holds, checks = self.any_holds(domains, const, env, other_vars, ind + 1, checks)
            if holds:
                return True, checks
        return False, checks

    def domain_splitting(self, domains=None, to_do=None, arc_heuristic=sat_up):
        """
        Return a solution to the current CSP or False if there are no solutions
        to_do is the list of arcs to check

        Runs GAC. If every domain is then a singleton, that is the solution.
        Otherwise the first variable (in self.csp.variables order) with more
        than one value has its domain split in two, and both halves are
        solved recursively.
        """
        if domains is None:
            domains = self.csp.domains
        consistency, new_domains, _ = self.GAC(domains, to_do, arc_heuristic)
        if not consistency:
            return False
        if all(len(new_domains[var]) == 1 for var in domains):
            return {var: first(new_domains[var]) for var in domains}
        var = first(x for x in self.csp.variables if len(new_domains[x]) > 1)
        if var is None:  # no splittable variable left among self.csp.variables
            return False
        dom1, dom2 = partition_domain(new_domains[var])
        new_doms1 = extend(new_domains, var, dom1)
        new_doms2 = extend(new_domains, var, dom2)
        # Only arcs touching var can have lost support after the split.
        to_do = self.new_to_do(var, None)
        return self.domain_splitting(new_doms1, to_do, arc_heuristic) or \
               self.domain_splitting(new_doms2, to_do, arc_heuristic)


def partition_domain(dom):
    """Partitions domain dom into two.

    dom may be any finite iterable of hashable values: a set, or a list such
    as the letter domains the Crossword CSP builds. Returns (dom1, dom2) as
    disjoint sets. dom1 holds the first len(dom) // 2 values in iteration
    order and dom2 holds the remaining values.
    """
    values = list(dom)
    dom1 = set(values[:len(values) // 2])
    # set(values) rather than dom itself, so list domains work too (a list
    # does not support the set difference operator).
    dom2 = set(values) - dom1
    return dom1, dom2


class ACSearchSolver(search.Problem):
    """A search problem with arc consistency and domain splitting.

    A state is a dict mapping each variable to its domain. The initial state
    is the GAC-reduced domains of the CSP. Each action splits the first
    non-singleton domain in two and re-runs GAC. The action value is itself
    the successor state (see result).
    """

    def __init__(self, csp, arc_heuristic=sat_up):
        """Make csp arc-consistent and use the result as the initial state.

        Raises Exception('CSP is inconsistent') when GAC empties a domain.
        """
        self.cons = ACSolver(csp)
        consistency, self.domains, _ = self.cons.GAC(arc_heuristic=arc_heuristic)
        if not consistency:
            raise Exception('CSP is inconsistent')
        self.heuristic = arc_heuristic
        super().__init__(self.domains)

    def goal_test(self, node):
        """Node (a variable -> domain state) is a goal if all domains have 1 element"""
        return all(len(node[var]) == 1 for var in node)

    def actions(self, state):
        """Return the arc-consistent states obtained by splitting one domain.

        The first variable with more than one value is split into two halves.
        Each half is propagated with GAC, and only the consistent results are
        returned.
        """
        var = first(x for x in state if len(state[x]) > 1)
        neighs = []
        if var is not None:
            dom1, dom2 = partition_domain(state[var])
            # Only arcs touching var can lose support after the split.
            to_do = self.cons.new_to_do(var, None)
            for dom in [dom1, dom2]:
                new_domains = extend(state, var, dom)
                consistency, cons_doms, _ = self.cons.GAC(new_domains, to_do, self.heuristic)
                if consistency:
                    neighs.append(cons_doms)
        return neighs

    def result(self, state, action):
        """Actions are already the successor states, so return the action."""
        return action


def ac_solver(csp, arc_heuristic=sat_up):
    """Arc consistency (domain splitting interface)

    Builds an ACSolver for csp and returns whatever its domain_splitting
    method returns for the given arc heuristic.
    """
    solver = ACSolver(csp)
    result = solver.domain_splitting(arc_heuristic=arc_heuristic)
    return result


def ac_search_solver(csp, arc_heuristic=sat_up):
    """Arc consistency (search interface)

    Solves csp by depth-first tree search over ACSearchSolver states.
    Returns a dict mapping each variable to its value. Returns None when the
    CSP is inconsistent or the search reaches no goal.
    """
    from search import depth_first_tree_search
    try:
        problem = ACSearchSolver(csp, arc_heuristic=arc_heuristic)
    except Exception:
        # ACSearchSolver signals an inconsistent CSP by raising a plain
        # Exception; that means "no solution", not an error. Unlike the old
        # bare `except:`, KeyboardInterrupt/SystemExit are no longer swallowed.
        return None
    node = depth_first_tree_search(problem)
    # NOTE(review): assumes the search returns None when no goal is found
    # (the old code relied on the resulting AttributeError); confirm in search.py.
    if node is None:
        return None
    solution = node.state
    if solution:
        return {var: first(solution[var]) for var in solution}
    return None


# ______________________________________________________________________________
# Crossword Problem


# Five-slot crossword instance. Each variable is a word slot whose domain is
# its set of candidate words, and each Constraint links two slots that cross.
# meet_at_constraint(p1, p2) is defined elsewhere; presumably it requires
# letter p1 of the first slot's word to equal letter p2 of the second (confirm).
csp_crossword = NaryCSP({'one_across': {'ant', 'big', 'bus', 'car', 'has'},
                         'one_down': {'book', 'buys', 'hold', 'lane', 'year'},
                         'two_down': {'ginger', 'search', 'symbol', 'syntax'},
                         'three_across': {'book', 'buys', 'hold', 'land', 'year'},
                         'four_across': {'ant', 'big', 'bus', 'car', 'has'}},
                        [Constraint(('one_across', 'one_down'), meet_at_constraint(0, 0)),
                         Constraint(('one_across', 'two_down'), meet_at_constraint(2, 0)),
                         Constraint(('three_across', 'two_down'), meet_at_constraint(2, 2)),
                         Constraint(('three_across', 'one_down'), meet_at_constraint(0, 2)),
                         Constraint(('four_across', 'two_down'), meet_at_constraint(0, 4))])

# Sample grid for the Crossword class below: '_' cells become letter
# variables and '*' cells are blocked. words1 is the vocabulary to fill it
# with (presumably used as Crossword(crossword1, words1); confirm with callers).
crossword1 = [['_', '_', '_', '*', '*'],
              ['_', '*', '_', '*', '*'],
              ['_', '_', '_', '_', '*'],
              ['_', '*', '_', '*', '*'],
              ['*', '*', '_', '_', '_'],
              ['*', '*', '_', '*', '*']]

words1 = {'ant', 'big', 'bus', 'car', 'has', 'book', 'buys', 'hold',
          'lane', 'year', 'ginger', 'search', 'symbol', 'syntax'}


class Crossword(NaryCSP):
    """Crossword-filling CSP built from a character grid.

    puzzle is a list of rows. A '_' cell is an open square; any other cell
    ends a word run (the sample grid uses '*' for blocked cells). Each open
    cell becomes a variable "p<column><row>" whose domain is the 26
    lowercase letters. Every maximal horizontal or vertical run of two or
    more open cells gets an is_word_constraint(words) over its cells.

    NOTE(review): names concatenate the raw indices, so grids with 10 or more
    rows/columns can produce colliding names (column 1 row 11 and column 11
    row 1 both give "p111"). Kept as-is because callers index assignments by
    these names.
    """

    def __init__(self, puzzle, words):
        domains = {}
        constraints = []

        def close_run(scope):
            # Only runs of two or more open cells form a word slot.
            if len(scope) > 1:
                constraints.append(Constraint(tuple(scope), is_word_constraint(words)))

        # Horizontal runs: scan each row, creating the cell variables as we go.
        for i, line in enumerate(puzzle):
            scope = []
            for j, element in enumerate(line):
                if element == '_':
                    var = "p" + str(j) + str(i)
                    domains[var] = list(string.ascii_lowercase)
                    scope.append(var)
                else:
                    close_run(scope)
                    scope = []
            close_run(scope)  # a run may also end at the right edge of the row
        # Vertical runs: scan the transposed grid (row i of puzzle_t is column i).
        puzzle_t = list(map(list, zip(*puzzle)))
        for i, line in enumerate(puzzle_t):
            scope = []
            for j, element in enumerate(line):
                if element == '_':
                    scope.append("p" + str(i) + str(j))
                else:
                    close_run(scope)
                    scope = []
            close_run(scope)  # a run may also end at the bottom edge
        super().__init__(domains, constraints)
        self.puzzle = puzzle

    def display(self, assignment=None):
        """Print the grid, one line per row.

        '*' cells print as "[*] ". Any other cell prints its letter in upper
        case when assignment maps it to a one-element set or to a string.
        Otherwise it prints "[_] ", which includes the case where assignment
        is None.
        """
        for i, line in enumerate(self.puzzle):
            puzzle = ""
            for j, element in enumerate(line):
                if element == '*':
                    puzzle += "[*] "
                else:
                    var = "p" + str(j) + str(i)
                    if assignment is not None:
                        # `== 1`, not `is 1`: identity of ints is an
                        # implementation detail (and a SyntaxWarning on 3.8+).
                        if isinstance(assignment[var], set) and len(assignment[var]) == 1:
                            puzzle += "[" + str(first(assignment[var])).upper() + "] "
                        elif isinstance(assignment[var], str):
                            puzzle += "[" + str(assignment[var]).upper() + "] "
                        else:
                            puzzle += "[_] "
                    else:
                        puzzle += "[_] "
            print(puzzle)


# ______________________________________________________________________________
           ['*', [4, ''], [3, 3], '_', '_'],
           [['', 10], '_', '_', '_', '_'],
           [['', 3], '_', '_', '*', '*']]

# difficulty 0
    ['*', [10, ''], [13, ''], '*'],
    [['', 3], '_', '_', [13, '']],
    [['', 12], '_', '_', '_'],
    [['', 21], '_', '_', '_']]

# difficulty 1
    ['*', [17, ''], [28, ''], '*', [42, ''], [22, '']],
    [['', 9], '_', '_', [31, 14], '_', '_'],
    [['', 20], '_', '_', '_', '_', '_'],
    ['*', ['', 30], '_', '_', '_', '_'],
    ['*', [22, 24], '_', '_', '_', '*'],
    [['', 25], '_', '_', '_', '_', [11, '']],
    [['', 20], '_', '_', '_', '_', '_'],
    [['', 14], '_', '_', ['', 17], '_', '_']]

# difficulty 2
    ['*', '*', '*', '*', '*', [4, ''], [24, ''], [11, ''], '*', '*', '*', [11, ''], [17, ''], '*', '*'],
    ['*', '*', '*', [17, ''], [11, 12], '_', '_', '_', '*', '*', [24, 10], '_', '_', [11, ''], '*'],
    ['*', [4, ''], [16, 26], '_', '_', '_', '_', '_', '*', ['', 20], '_', '_', '_', '_', [16, '']],
    [['', 20], '_', '_', '_', '_', [24, 13], '_', '_', [16, ''], ['', 12], '_', '_', [23, 10], '_', '_'],
    [['', 10], '_', '_', [24, 12], '_', '_', [16, 5], '_', '_', [16, 30], '_', '_', '_', '_', '_'],
    ['*', '*', [3, 26], '_', '_', '_', '_', ['', 12], '_', '_', [4, ''], [16, 14], '_', '_', '*'],
    ['*', ['', 8], '_', '_', ['', 15], '_', '_', [34, 26], '_', '_', '_', '_', '_', '*', '*'],
    ['*', ['', 11], '_', '_', [3, ''], [17, ''], ['', 14], '_', '_', ['', 8], '_', '_', [7, ''], [17, ''], '*'],
    ['*', '*', '*', [23, 10], '_', '_', [3, 9], '_', '_', [4, ''], [23, ''], ['', 13], '_', '_', '*'],
    ['*', '*', [10, 26], '_', '_', '_', '_', '_', ['', 7], '_', '_', [30, 9], '_', '_', '*'],
    ['*', [17, 11], '_', '_', [11, ''], [24, 8], '_', '_', [11, 21], '_', '_', '_', '_', [16, ''], [17, '']],
    [['', 29], '_', '_', '_', '_', '_', ['', 7], '_', '_', [23, 14], '_', '_', [3, 17], '_', '_'],
    [['', 10], '_', '_', [3, 10], '_', '_', '*', ['', 8], '_', '_', [4, 25], '_', '_', '_', '_'],
    ['*', ['', 16], '_', '_', '_', '_', '*', ['', 23], '_', '_', '_', '_', '_', '*', '*'],
    ['*', '*', ['', 6], '_', '_', '*', '*', ['', 15], '_', '_', '_', '*', '*', '*', '*']]



    def __init__(self, puzzle):
        """Build a Kakuro CSP from a grid.

        puzzle is a list of rows. Cell kinds:
          '_'  -- an empty cell: a variable "X<row><col>" (row and column
                  zero-padded to two digits) with domain {1, ..., 9}
          '*'  -- a blocked cell
          else -- a clue pair [down, right]. A non-empty down (right) entry
                  requires the run of '_' cells directly below (to the right
                  of) the clue to sum to it and to be all different.
        """

        def cell_name(row, col):
            # Two-digit padding keeps names unique for grids up to 100x100,
            # e.g. (1, 11) -> "X0111" but (11, 1) -> "X1101".
            return "X" + str(row).zfill(2) + str(col).zfill(2)

        def add_clue(total, run):
            # Both constraints share the same scope list object.
            constraints.append(Constraint(run, sum_constraint(total)))
            constraints.append(Constraint(run, all_diff_constraint))

        variables = [cell_name(i, j)
                     for i, line in enumerate(puzzle)
                     for j, element in enumerate(line)
                     if element == '_']
        domains = {var: set(range(1, 10)) for var in variables}
        constraints = []
        for i, line in enumerate(puzzle):
            for j, element in enumerate(line):
                if element == '_' or element == '*':
                    continue
                # Clue cell [down, right]; '' means no clue in that direction.
                if element[0] != '':
                    # down - column: the '_' cells below (i, j)
                    run = []
                    for k in range(i + 1, len(puzzle)):
                        if puzzle[k][j] != '_':
                            break
                        run.append(cell_name(k, j))
                    add_clue(element[0], run)
                if element[1] != '':
                    # right - line: the '_' cells to the right of (i, j)
                    run = []
                    for k in range(j + 1, len(puzzle[i])):
                        if puzzle[i][k] != '_':
                            break
                        run.append(cell_name(i, k))
                    add_clue(element[1], run)
        super().__init__(domains, constraints)
        self.puzzle = puzzle

    def display(self, assignment=None):
        """Print the grid, one tab-separated line per row.

        '*' prints as "[*]" and a clue pair prints as "down\\right". A '_'
        cell prints its value when assignment maps it to a one-element set
        or an int. Otherwise, including when assignment is None, it prints
        "[_]".
        """
        for i, line in enumerate(self.puzzle):
            puzzle = ""
            for j, element in enumerate(line):
                if element == '*':
                    puzzle += "[*]\t"
                elif element == '_':
                    var = "X" + str(i).zfill(2) + str(j).zfill(2)
                    value = None if assignment is None else assignment[var]
                    # `== 1`, not `is 1`: int identity is an implementation
                    # detail (and a SyntaxWarning on Python 3.8+).
                    if isinstance(value, set) and len(value) == 1:
                        puzzle += "[" + str(first(value)) + "]\t"
                    elif isinstance(value, int):
                        puzzle += "[" + str(value) + "]\t"
                    else:
                        puzzle += "[_]\t"
                else:
                    puzzle += str(element[0]) + "\\" + str(element[1]) + "\t"
            print(puzzle)


# ______________________________________________________________________________
# Cryptarithmetic Problem

# [Figure 6.2]
# T W O + T W O = F O U R
# Letter variables range over digits; the leading letters T and F exclude 0.
# C1..C3 are the carries (0 or 1) out of the units, tens and hundreds columns.
# Each lambda receives its scope's values in scope order and checks one
# column of the addition. F must equal the final carry C3 (eq is defined or
# imported elsewhere; presumably operator.eq).
two_two_four = NaryCSP({'T': set(range(1, 10)), 'F': set(range(1, 10)),
                        'W': set(range(0, 10)), 'O': set(range(0, 10)), 'U': set(range(0, 10)), 'R': set(range(0, 10)),
                        'C1': set(range(0, 2)), 'C2': set(range(0, 2)), 'C3': set(range(0, 2))},
                       [Constraint(('T', 'F', 'W', 'O', 'U', 'R'), all_diff_constraint),
                        Constraint(('O', 'R', 'C1'), lambda o, r, c1: o + o == r + 10 * c1),
                        Constraint(('W', 'U', 'C1', 'C2'), lambda w, u, c1, c2: c1 + w + w == u + 10 * c2),
                        Constraint(('T', 'O', 'C2', 'C3'), lambda t, o, c2, c3: c2 + t + t == o + 10 * c3),
                        Constraint(('F', 'C3'), eq)])

# S E N D + M O R E = M O N E Y
# Letter variables range over digits; the leading letters S and M exclude 0.
# C1..C4 are the column carries (0 or 1), units first. Each lambda checks one
# column of the addition from its scope's values in scope order. M must equal
# the final carry C4 (eq is defined or imported elsewhere; presumably
# operator.eq).
send_more_money = NaryCSP({'S': set(range(1, 10)), 'M': set(range(1, 10)),
                           'E': set(range(0, 10)), 'N': set(range(0, 10)), 'D': set(range(0, 10)),
                           'O': set(range(0, 10)), 'R': set(range(0, 10)), 'Y': set(range(0, 10)),
                           'C1': set(range(0, 2)), 'C2': set(range(0, 2)), 'C3': set(range(0, 2)),
                           'C4': set(range(0, 2))},
                          [Constraint(('S', 'E', 'N', 'D', 'M', 'O', 'R', 'Y'), all_diff_constraint),
                           Constraint(('D', 'E', 'Y', 'C1'), lambda d, e, y, c1: d + e == y + 10 * c1),
                           Constraint(('N', 'R', 'E', 'C1', 'C2'), lambda n, r, e, c1, c2: c1 + n + r == e + 10 * c2),
                           Constraint(('E', 'O', 'N', 'C2', 'C3'), lambda e, o, n, c2, c3: c2 + e + o == n + 10 * c3),
                           Constraint(('S', 'M', 'O', 'C3', 'C4'), lambda s, m, o, c3, c4: c3 + s + m == o + 10 * c4),
                           Constraint(('M', 'C4'), eq)])