"""Markov Decision Processes (Chapter 17)

First we define an MDP, and the special case of a GridMDP, in which
states are laid out in a 2-dimensional grid.  We also represent a policy
as a dictionary of {state:action} pairs, and a Utility function as a
dictionary of {state:number} pairs.  We then define the value_iteration
and policy_iteration algorithms."""

from .utils import *
import random


class MDP:

    """A Markov Decision Process, defined by an initial state, transition model,
    and reward function. We also keep track of a gamma value, for use by
    algorithms. The transition model is represented somewhat differently from
    the text.  Instead of P(s' | s, a) being a probability number for each
    state/state/action triplet, we instead have T(s, a) return a list of (p, s')
    pairs.  We also keep track of the possible states, terminal states, and
    actions for each state. [page 646]"""
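    # Concretely, T(s, a) returns something like [(0.8, s1), (0.1, s2), (0.1, s3)]:
    # reach s1 with probability 0.8, and s2 or s3 with probability 0.1 each
    # (see GridMDP.T below for the grid-world version of this model).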
    def __init__(self, init, actlist, terminals, gamma=.9):
        self.init = init
        self.actlist = actlist
        self.terminals = terminals
        if not (0 <= gamma < 1):
            raise ValueError("An MDP must have 0 <= gamma < 1")
        self.gamma = gamma
        self.states = set()
        self.reward = {}
    def R(self, state):
        "Return a numeric reward for this state."
        return self.reward[state]

    def T(self, state, action):
        """Transition model.  From a state and an action, return a list
        of (probability, result-state) pairs."""
        raise NotImplementedError

    def actions(self, state):
        """Set of actions that can be performed in this state.  By default, a
        fixed list of actions, except for terminal states. Override this
        method if you need to specialize by state."""
        if state in self.terminals:
            return [None]
        else:
            return self.actlist

class GridMDP(MDP):
    """A two-dimensional grid MDP, as in [Figure 17.1].  All you have to do is
    specify the grid as a list of lists of rewards; use None for an obstacle
    (unreachable state).  Also, you should specify the terminal states.
    An action is an (x, y) unit vector; e.g. (1, 0) means move east."""
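    # For example (a sketch, not used elsewhere in this module), a 2x2 world
    # with a single +1 exit in its top-right corner could be built as:
    #
    #     GridMDP([[-0.04, +1],
    #              [-0.04, -0.04]], terminals=[(1, 1)])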

    def __init__(self, grid, terminals, init=(0, 0), gamma=.9):
        grid.reverse()  # because we want row 0 on bottom, not on top
        MDP.__init__(self, init, actlist=orientations,
                     terminals=terminals, gamma=gamma)
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0])
        for x in range(self.cols):
            for y in range(self.rows):
                self.reward[x, y] = grid[y][x]
                if grid[y][x] is not None:
                    self.states.add((x, y))

    def T(self, state, action):
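        """The stochastic motion model of [Figure 17.1]: the intended move
        succeeds with probability 0.8, and with probability 0.1 each the
        agent slips to the right or to the left of that direction."""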
        if action is None:
            return [(0.0, state)]
        else:
            return [(0.8, self.go(state, action)),
                    (0.1, self.go(state, turn_right(action))),
                    (0.1, self.go(state, turn_left(action)))]

    def go(self, state, direction):
        "Return the state that results from going in this direction."
        state1 = vector_add(state, direction)
        return (state1 if state1 in self.states else state)

    def to_grid(self, mapping):
        """Convert a mapping from (x, y) to v into a [[..., v, ...]] grid."""
        return list(reversed([[mapping.get((x, y), None)
                               for x in range(self.cols)]
                              for y in range(self.rows)]))

    def to_arrows(self, policy):
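        """Convert a {state: action} policy into a grid of arrow characters,
        with '.' marking the (terminal) states whose action is None."""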
        chars = {
            (1, 0): '>', (0, 1): '^', (-1, 0): '<', (0, -1): 'v', None: '.'}
        return self.to_grid(dict([(s, chars[a]) for (s, a) in policy.items()]))

#______________________________________________________________________________

Fig[17, 1] = GridMDP([[-0.04, -0.04, -0.04, +1],
                      [-0.04, None,  -0.04, -1],
                      [-0.04, -0.04, -0.04, -0.04]],
                     terminals=[(3, 2), (3, 1)])
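
# A quick usage sketch (the doctests at the end of this file are the
# authoritative examples): solve the grid above and read off a policy.
#
#     U = value_iteration(Fig[17, 1], .01)   # {state: utility}
#     pi = best_policy(Fig[17, 1], U)        # {state: action}
#     print_table(Fig[17, 1].to_arrows(pi))  # print the policy as arrows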

#______________________________________________________________________________


def value_iteration(mdp, epsilon=0.001):
    "Solving an MDP by value iteration. [Fig. 17.4]"
    U1 = dict([(s, 0) for s in mdp.states])
    R, T, gamma = mdp.R, mdp.T, mdp.gamma
    while True:
        U = U1.copy()
        delta = 0
        for s in mdp.states:
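            # Bellman update: U1[s] = R(s) + gamma * max_a sum_s' P(s'|s,a) * U[s']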
            U1[s] = R(s) + gamma * max([sum([p * U[s1] for (p, s1) in T(s, a)])
                                        for a in mdp.actions(s)])
            delta = max(delta, abs(U1[s] - U[s]))
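        # Stop when the largest change is small enough that, by the book's
        # convergence bound, U is within epsilon of the true utilities.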
        if delta < epsilon * (1 - gamma) / gamma:
            return U


def best_policy(mdp, U):
    """Given an MDP and a utility function U, determine the best policy,
    as a mapping from state to action. (Equation 17.4)"""
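    # pi[s] is the action that maximizes the expected utility of the outcome,
    # i.e. argmax_a sum_s' P(s'|s,a) * U[s'].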
    pi = {}
    for s in mdp.states:
        pi[s] = argmax(
            mdp.actions(s), lambda a: expected_utility(a, s, U, mdp))
    return pi


def expected_utility(a, s, U, mdp):
    "The expected utility of doing a in state s, according to the MDP and U."
    return sum([p * U[s1] for (p, s1) in mdp.T(s, a)])

#______________________________________________________________________________


def policy_iteration(mdp):
    "Solve an MDP by policy iteration [Fig. 17.7]"
    U = dict([(s, 0) for s in mdp.states])
    pi = dict([(s, random.choice(mdp.actions(s))) for s in mdp.states])
    while True:
        U = policy_evaluation(pi, U, mdp)
        unchanged = True
        for s in mdp.states:
            a = argmax(
                mdp.actions(s), lambda a: expected_utility(a, s, U, mdp))
            if a != pi[s]:
                pi[s] = a
                unchanged = False
        if unchanged:
            return pi


def policy_evaluation(pi, U, mdp, k=20):
    """Return an updated utility mapping U from each state in the MDP to its
    utility, using an approximation (modified policy iteration)."""
    R, T, gamma = mdp.R, mdp.T, mdp.gamma
    for i in range(k):
        for s in mdp.states:
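            # Simplified Bellman update for a fixed policy: no max over actions.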
            U[s] = R(s) + gamma * sum([p * U[s1] for (p, s1) in T(s, pi[s])])
    return U


__doc__ += """
>>> pi = best_policy(Fig[17,1], value_iteration(Fig[17,1], .01))

>>> Fig[17,1].to_arrows(pi)
[['>', '>', '>', '.'], ['^', None, '^', '.'], ['^', '>', '^', '<']]

>>> print_table(Fig[17,1].to_arrows(pi))
>   >      >   .
^   None   ^   .
^   >      ^   <

>>> print_table(Fig[17,1].to_arrows(policy_iteration(Fig[17,1])))
>   >      >   .
^   None   ^   .
^   >      ^   <
"""

__doc__ += """
Random tests:
>>> pi
{(3, 2): None, (3, 1): None, (3, 0): (-1, 0), (2, 1): (0, 1), (0, 2): (1, 0), (1, 0): (1, 0), (0, 0): (0, 1), (1, 2): (1, 0), (2, 0): (0, 1), (0, 1): (0, 1), (2, 2): (1, 0)}

>>> value_iteration(Fig[17,1], .01)
{(3, 2): 1.0, (3, 1): -1.0, (3, 0): 0.12958868267972745, (0, 1): 0.39810203830605462, (0, 2): 0.50928545646220924, (1, 0): 0.25348746162470537, (0, 0): 0.29543540628363629, (1, 2): 0.64958064617168676, (2, 0): 0.34461306281476806, (2, 1): 0.48643676237737926, (2, 2): 0.79536093684710951}

>>> policy_iteration(Fig[17,1])
{(3, 2): None, (3, 1): None, (3, 0): (0, -1), (2, 1): (-1, 0), (0, 2): (1, 0), (1, 0): (1, 0), (0, 0): (1, 0), (1, 2): (1, 0), (2, 0): (1, 0), (0, 1): (1, 0), (2, 2): (1, 0)}
"""