Implementing SARSA(λ) in Python
This post shows how to implement the SARSA(λ) algorithm with eligibility traces in Python. It is part of a series of articles about reinforcement learning that I will be writing.
Please note that I will go into further detail as soon as I can. In this first version of the article I am simply publishing the code; I will soon explain the SARSA(λ) algorithm in depth, along with eligibility traces and their benefits.
SARSA(λ)
The SARSA(λ) pseudocode, as given in Sutton & Barto's book, is summarized below.
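In words: from state s, the agent takes action a, observes the reward r and the next state s', and picks the next action a' with the ε-greedy policy derived from Q. It then computes the TD error δ = r + γ·Q(s', a') − Q(s, a) and increments the eligibility trace e(s, a). Every state-action pair is updated in proportion to its trace, Q(s, a) ← Q(s, a) + α·δ·e(s, a), and every trace is decayed, e(s, a) ← γ·λ·e(s, a), before the agent moves on with s ← s' and a ← a'. The trace-decay parameter λ controls how far back along the trajectory credit for the TD error propagates: λ = 0 gives plain one-step SARSA, while λ close to 1 spreads the update over the whole episode.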
Python code
import gym
import itertools
import sys
import time
from collections import defaultdict

import numpy as np

# Make the custom gym-botenv package importable
if "./gym-botenv/" not in sys.path:
    sys.path.append("./gym-botenv/")

from gym_botenv.envs.botenv_env import BotenvEnv
from utils import plotting
# Instantiate the custom Botenv environment (see the note at the end of the post)
env = BotenvEnv(1000)
def make_epsilon_greedy_policy(Q, epsilon, nA):
    """Create an epsilon-greedy policy from the action-value function Q.

    Returns a function that takes an observation and returns a vector of
    nA action probabilities.
    """
    def policy_fn(observation):
        A = np.ones(nA, dtype=float) * epsilon / nA
        best_action = np.argmax(Q[observation])
        A[best_action] += (1.0 - epsilon)
        return A
    return policy_fn
def sarsa_lambda(env, num_episodes, discount=0.9, alpha=0.01, trace_decay=0.9,
                 epsilon=0.1, type='accumulate'):
    # Q maps a state to a vector of action values, E holds the eligibility traces
    Q = defaultdict(lambda: np.zeros(env.nA))
    E = defaultdict(lambda: np.zeros(env.nA))

    policy = make_epsilon_greedy_policy(Q, epsilon, env.nA)

    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes)
    )

    rewards = [0.]

    for i_episode in range(num_episodes):
        print("\rEpisode {}/{}. ({})".format(i_episode + 1, num_episodes, rewards[-1]), end="")
        sys.stdout.flush()

        # Reset the eligibility traces at the start of each episode, as in the pseudocode
        E = defaultdict(lambda: np.zeros(env.nA))

        # Initial state and first epsilon-greedy action
        state = env.reset()
        action_probs = policy(state)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)

        for t in itertools.count():
            # Take a step, then choose the next action with the epsilon-greedy policy
            next_state, reward, done, _ = env.step(action)
            next_action_probs = policy(next_state)
            next_action = np.random.choice(np.arange(len(next_action_probs)), p=next_action_probs)

            # TD error for the current (state, action) pair, then increment its trace
            delta = reward + discount * Q[next_state][next_action] - Q[state][action]
            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t
            E[state][action] += 1
            # Update every state seen so far in proportion to its trace, then decay the traces
            for s, _ in Q.items():
                Q[s][:] += alpha * delta * E[s][:]
                if type == 'accumulate':
                    # Accumulating traces: decay every trace by gamma * lambda
                    E[s][:] *= trace_decay * discount
                elif type == 'replace':
                    if s == state:
                        # Replacing traces: the trace of the visited pair is reset to 1
                        # instead of being accumulated
                        E[s][:] = 0
                        E[s][action] = 1
                    else:
                        E[s][:] *= discount * trace_decay
            if done:
                break

            state = next_state
            action = next_action

        # Keep track of the episode's total reward for the progress display
        rewards.append(stats.episode_rewards[i_episode])

    return Q, stats
if __name__ == '__main__':
    start = time.time()
    Q, stats = sarsa_lambda(env, 100)
    print(stats)
    end = time.time()
    print("Algorithm took {:.2f} seconds to execute".format(end - start))
    plotting.plot_episode_stats(stats, title="sarsa_lambda")
Please note that the environment I used is a custom environment that I set up for the needs of my final year project, which involves bot detection, and that I will discuss in future articles.
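If you do not have access to this environment, the same function should run on any discrete Gym environment with hashable states, provided it exposes an nA attribute the way Botenv does. A minimal sketch, assuming the classic Gym API (env.reset() returns the state and env.step() returns a 4-tuple) and the FrozenLake-v0 environment, neither of which is part of the original project:

# A minimal sketch: running sarsa_lambda on a standard toy-text environment
# instead of the custom Botenv one. FrozenLake-v0 and the classic Gym API are
# assumptions here, not something used in the original post.
import gym

frozen_lake = gym.make("FrozenLake-v0")
frozen_lake.nA = frozen_lake.action_space.n  # expose nA the way Botenv does

Q, stats = sarsa_lambda(frozen_lake, num_episodes=500)
plotting.plot_episode_stats(stats, title="sarsa_lambda on FrozenLake")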
You can find the full code on GitHub.