tachiken's blog

開発、プログラミング、その他

Pythonで強化学習を試す

強化学習が巷でブームになっておりますが、私もPython初心者ながら、手探りで勉強しています。

その中でとても参考になったのが、以下のHironsanさんの記事です。

qiita.com

強化学習についての説明はとてもわかりやすかったのですが、記事に記載されているコードや、引用元であるUC Berkeleyのコードを見ても一部のメソッドが欠けていたので、自分で補完して動かしてみました。コードは以下のとおりです。

import random
import operator, math, random, copy, sys,  os.path, bisect
import pandas as pd


def print_table(table, header=None, sep=' ', numfmt='%g'):
    """Print *table* (a list of rows) with aligned columns.

    Cells that are real numbers (``numbers.Real``) are formatted with
    *numfmt* and right-justified; every other cell is converted with ``str``
    and left-justified.  Justification is decided per column from the first
    data row.  If *header* is given it is printed as the first row.  Nothing
    is printed for an empty table.  *table* itself is not modified.
    """
    import numbers

    if not table:
        return

    def is_number(x):
        return isinstance(x, numbers.Real)

    justs = ['rjust' if is_number(x) else 'ljust' for x in table[0]]
    if header:
        table = [header] + table
    # Format every cell first so that column widths reflect the printed text.
    cells = [[numfmt % x if is_number(x) else str(x) for x in row]
             for row in table]
    sizes = [max(map(len, column)) for column in zip(*cells)]
    for row in cells:
        print(sep.join(getattr(cell, just)(size)
                       for just, size, cell in zip(justs, sizes, row)))

def argmin(seq, fn):
    """Return the first element of *seq* whose score ``fn(element)`` is smallest.

    *seq* must be indexable and non-empty (``seq[0]`` raises IndexError
    otherwise).  Only a strictly smaller score replaces the current winner,
    so ties are resolved in favour of the earliest element.
    """
    winner = seq[0]
    lowest = fn(winner)
    for candidate in seq:
        score = fn(candidate)
        if not score < lowest:
            continue
        winner = candidate
        lowest = score
    return winner

def argmax(seq, fn):
    """Return the first element of *seq* whose score ``fn(element)`` is largest.

    Delegates to ``argmin`` with the score negated, so it has the same
    requirements (non-empty indexable *seq*) and tie-breaking (earliest wins).
    """
    def negated(item):
        return -fn(item)

    return argmin(seq, negated)

def vector_add(a, b):
    """Add two vectors component-wise and return the sum as a tuple.

    As with ``zip``, components beyond the length of the shorter input are
    ignored.
    """
    return tuple(x + y for x, y in zip(a, b))

# Unit heading vectors (x, y), listed so that each entry is a 90-degree
# left turn from the previous one when y grows upward.
orientations = [(1, 0), (0, 1), (-1, 0), (0, -1)]


def turn_right(orientation):
    """Return the heading one step before *orientation* in ``orientations`` (a right turn)."""
    position = orientations.index(orientation)
    return orientations[(position - 1) % len(orientations)]


def turn_left(orientation):
    """Return the heading one step after *orientation* in ``orientations`` (a left turn)."""
    position = orientations.index(orientation)
    return orientations[(position + 1) % len(orientations)]

import random


class MDP:
    """Base class for a Markov decision process.

    Stores the initial state, the list of actions, the terminal states and
    the discount factor ``gamma``.  ``states`` starts as an empty set and
    ``reward`` as an empty dict; subclasses such as ``GridMDP`` fill them
    in and override ``T`` with a transition model.
    """

    def __init__(self, init, actlist, terminals, gamma=.9):
        if not (0 <= gamma < 1):
            raise ValueError("An MDP must have 0 <= gamma < 1")
        self.init = init
        self.actlist = actlist
        self.terminals = terminals
        self.gamma = gamma
        self.states = set()
        self.reward = {}

    def R(self, state):
        """Return the reward stored for *state* (KeyError if it has none)."""
        return self.reward[state]

    def T(self, state, action):
        """Transition model; must be provided by a subclass."""
        raise NotImplementedError

    def actions(self, state):
        """Return ``[None]`` for a terminal *state*, otherwise ``self.actlist``."""
        return [None] if state in self.terminals else self.actlist


class GridMDP(MDP):
    """A rectangular grid-world MDP.

    *grid* is a list of rows given top row first; each cell holds the reward
    for that state, or ``None`` for an obstacle that is not a state.  States
    are ``(x, y)`` tuples with ``(0, 0)`` at the bottom-left cell, and every
    heading in ``orientations`` is an action.
    """

    def __init__(self, grid, terminals, init=(0, 0), gamma=.9):
        # Reversed copy so that row 0 is the bottom row (y grows upward);
        # unlike ``grid.reverse()`` this leaves the caller's list untouched.
        grid = list(reversed(grid))
        MDP.__init__(self, init, actlist=orientations,
                     terminals=terminals, gamma=gamma)
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0])
        for x in range(self.cols):
            for y in range(self.rows):
                # Obstacle cells still get a reward entry (None) but are
                # not added to the state set.
                self.reward[x, y] = grid[y][x]
                if grid[y][x] is not None:
                    self.states.add((x, y))

    def T(self, state, action):
        """Return the outcomes of *action* in *state* as (probability, next_state) pairs.

        The intended move happens with probability 0.8, and with probability
        0.1 each the agent moves 90 degrees to the right or left instead.
        The ``None`` action (the only action in terminal states) yields a
        single outcome with probability 0.0.
        """
        if action is None:
            return [(0.0, state)]
        return [(0.8, self.go(state, action)),
                (0.1, self.go(state, turn_right(action))),
                (0.1, self.go(state, turn_left(action)))]

    def go(self, state, direction):
        """Return the state reached by moving one cell in *direction*.

        If the target cell is not a state (an obstacle or outside the
        grid), the agent stays in *state*.
        """
        state1 = vector_add(state, direction)
        return state1 if state1 in self.states else state

    def to_grid(self, mapping):
        """Lay out a {(x, y): value} *mapping* as rows, top row first.

        Cells missing from *mapping* become ``None``.
        """
        return list(reversed([[mapping.get((x, y), None)
                               for x in range(self.cols)]
                              for y in range(self.rows)]))

    def to_arrows(self, policy, U=None):
        """Render *policy* ({state: action}) as a grid of arrow characters.

        States whose action is ``None`` (terminal states) show their
        utility ``U[s]`` when *U* is given, and ``'.'`` otherwise.
        """
        chars = {(1, 0): '>', (0, 1): '^', (-1, 0): '<', (0, -1): 'v',
                 None: '.'}
        arrows_result = {}
        for s, a in policy.items():
            if a is None and U is not None:
                arrows_result[s] = U[s]
            else:
                arrows_result[s] = chars[a]
        return self.to_grid(arrows_result)

# ______________________________________________________________________________


# The 4x3 grid world from AIMA (Fig. 17.1).  Rows are listed top row first,
# each cell is that state's reward, None marks the obstacle, and the +1 / -1
# cells at (3, 2) and (3, 1) are terminal.
sequential_decision_environment = GridMDP(
    grid=[[-0.04, -0.04, -0.04, +1],
          [-0.04, None, -0.04, -1],
          [-0.04, -0.04, -0.04, -0.04]],
    terminals=[(3, 2), (3, 1)],
)

# ______________________________________________________________________________

# ______________________________________________________________________________



# A 5x4 variant: the top corners are terminal, (0, 3) with reward -1 and
# (4, 3) with reward +1; obstacles sit at (2, 2) and (1, 1).
# NOTE(review): cell (4, 1) has reward +0.04, the only positive non-terminal
# reward -- confirm it is intended and not a typo for -0.04.
sequential_decision_environment2 = GridMDP(
    grid=[[-1, -0.04, -0.04, -0.04, +1],
          [-0.04, -0.04, None, -0.04, -0.04],
          [-0.04, None, -0.04, -0.04, 0.04],
          [-0.04, -0.04, -0.04, -0.04, -0.04]],
    terminals=[(4, 3), (0, 3)],
)

# ______________________________________________________________________________


def value_iteration(mdp, epsilon=0.001):
    """Solve *mdp* by value iteration and return a {state: utility} dict.

    Each sweep applies the Bellman update
    ``U'(s) = R(s) + gamma * max_a sum_s' P(s' | s, a) U(s')`` to every
    state.  Iteration stops once the largest change in a sweep, delta,
    satisfies ``delta < epsilon * (1 - gamma) / gamma``.  By the standard
    contraction bound this puts the newest iterate within *epsilon* of the
    true utilities, so that iterate is returned.
    """
    R, T, gamma = mdp.R, mdp.T, mdp.gamma
    U1 = {s: 0 for s in mdp.states}
    while True:
        U = U1.copy()
        delta = 0
        for s in mdp.states:
            U1[s] = R(s) + gamma * max(
                sum(p * U[s1] for (p, s1) in T(s, a))
                for a in mdp.actions(s))
            delta = max(delta, abs(U1[s] - U[s]))
        # Multiplied through by gamma (gamma >= 0, 1 - gamma > 0), so that
        # gamma == 0, which MDP accepts, no longer divides by zero.
        if delta * gamma < epsilon * (1 - gamma):
            # U1 is the iterate that the stopping bound applies to.
            return U1


def best_policy(mdp, U):
    """Return ``(pi, U)``, where *pi* maps each state to its greedy action.

    For every state, the action with the highest ``expected_utility`` under
    the utilities *U* is chosen via ``argmax``.  *U* itself is passed back
    unchanged as the second element.
    """
    pi = {
        s: argmax(mdp.actions(s),
                  lambda a, state=s: expected_utility(a, state, U, mdp))
        for s in mdp.states
    }
    return pi, U

def expected_utility(a, s, U, mdp):
    """Return the expected utility of taking action *a* in state *s*.

    This is the sum of ``p * U[s1]`` over the ``(p, s1)`` outcomes that
    ``mdp.T(s, a)`` reports.
    """
    outcomes = mdp.T(s, a)
    return sum(p * U[s1] for p, s1 in outcomes)

# ______________________________________________________________________________


def policy_iteration(mdp):
    """Solve *mdp* by policy iteration and return the policy {state: action}.

    Starts from a random policy.  It then alternates approximate policy
    evaluation (``policy_evaluation``) with greedy improvement, and stops
    when an improvement pass changes no state's action.
    """
    U = {s: 0 for s in mdp.states}
    pi = {s: random.choice(mdp.actions(s)) for s in mdp.states}
    while True:
        U = policy_evaluation(pi, U, mdp)
        unchanged = True
        for s in mdp.states:
            # argmax's scoring function is its positional ``fn`` parameter.
            a = argmax(mdp.actions(s),
                       lambda a, state=s: expected_utility(a, state, U, mdp))
            if a != pi[s]:
                pi[s] = a
                unchanged = False
        if unchanged:
            return pi


def policy_evaluation(pi, U, mdp, k=20):
    """Approximate the utilities of following policy *pi*, starting from *U*.

    Performs *k* sweeps of ``U[s] = R(s) + gamma * sum(p * U[s1])`` over the
    outcomes of ``pi[s]``.  *U* is updated in place, so later states in a
    sweep already see earlier states' new values.  The same dict is returned.
    """
    reward, transition, gamma = mdp.R, mdp.T, mdp.gamma
    for _sweep in range(k):
        for s in mdp.states:
            U[s] = reward(s) + gamma * sum(
                p * U[s1] for p, s1 in transition(s, pi[s]))
    return U


# Solve the 4x3 world by value iteration, extract the greedy policy and show
# it as a grid of arrows (terminal cells display their utility).  The table
# is printed explicitly so it also appears when run as a plain script, not
# only as the last expression of a notebook cell.
pi, U = best_policy(sequential_decision_environment,
                    value_iteration(sequential_decision_environment, .01))
print(pd.DataFrame(sequential_decision_environment.to_arrows(pi, U)))

実行すると、結果は以下のようになります。

f:id:tachiken0210:20170529202334p:plain

グリッドの行数・列数やセルの中身を変更した場合の結果は以下のとおりです。

f:id:tachiken0210:20170529202347p:plain