Implementing SARSA(λ) in Python

This post shows how to implement the SARSA(λ) algorithm with eligibility traces in Python. It is part of a series of articles about reinforcement learning that I will be writing.
Please note that I will go into further detail as soon as I can. This first version of the article simply presents the code; I will soon explain the SARSA(λ) algorithm in depth, along with eligibility traces and their benefits.


SARSA(λ)

The SARSA(λ) pseudocode is the following, as seen in Sutton & Barto's book:
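
In outline (a plain-text transcription of the boxed algorithm, with traces reset at the start of every episode; notation follows the book):

    Initialize Q(s, a) arbitrarily, for all s, a
    Repeat (for each episode):
        E(s, a) = 0, for all s, a
        Initialize S; choose A from S using the policy derived from Q (e.g. ε-greedy)
        Repeat (for each step of the episode):
            Take action A, observe R, S'
            Choose A' from S' using the policy derived from Q
            δ ← R + γ·Q(S', A') − Q(S, A)
            E(S, A) ← E(S, A) + 1
            For all s, a:
                Q(s, a) ← Q(s, a) + α·δ·E(s, a)
                E(s, a) ← γ·λ·E(s, a)
            S ← S'; A ← A'
        until S is terminal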

Python code

    import gym
    import itertools
    from collections import defaultdict
    import numpy as np
    import sys
    import time

    # Make the custom gym-botenv environment importable
    if "./gym-botenv/" not in sys.path:
        sys.path.append("./gym-botenv/")

    from gym_botenv.envs.botenv_env import BotenvEnv
    from utils import plotting
    
    
    env = BotenvEnv(1000)
    
    def make_epsilon_greedy_policy(Q, epsilon, nA):
        """Return a function mapping an observation to epsilon-greedy action probabilities."""
        def policy_fn(observation):
            A = np.ones(nA, dtype=float) * epsilon / nA  # every action gets probability epsilon/nA
            best_action = np.argmax(Q[observation])
            A[best_action] += (1.0 - epsilon)  # the greedy action gets the remaining probability mass
            return A
        return policy_fn
    
    def sarsa_lambda(env, num_episodes, discount=0.9, alpha=0.01, trace_decay=0.9, epsilon=0.1, trace_type='accumulate'):
        """Tabular SARSA(λ) with accumulating or replacing eligibility traces."""
        Q = defaultdict(lambda: np.zeros(env.nA))

        policy = make_epsilon_greedy_policy(Q, epsilon, env.nA)
    
        stats = plotting.EpisodeStats(
            episode_lengths=np.zeros(num_episodes),
            episode_rewards=np.zeros(num_episodes)
        )
        rewards = [0.]

        for i_episode in range(num_episodes):

            # Eligibility traces must be reset at the start of every episode
            E = defaultdict(lambda: np.zeros(env.nA))

            print("\rEpisode {}/{}. ({})".format(i_episode + 1, num_episodes, rewards[-1]), end="")
            sys.stdout.flush()
    
            state = env.reset()
            action_probs = policy(state)
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
    
            for t in itertools.count():
    
                next_state, reward, done, _ = env.step(action)
    
                next_action_probs = policy(next_state)
                next_action = np.random.choice(np.arange(len(next_action_probs)), p=next_action_probs)
    
                # TD error between the one-step target and the current estimate
                delta = reward + discount * Q[next_state][next_action] - Q[state][action]
    
                stats.episode_rewards[i_episode] += reward
    
                # Update the trace of the pair we just visited: accumulating
                # traces increment it, replacing traces reset it to 1
                if trace_type == 'accumulate':
                    E[state][action] += 1
                elif trace_type == 'replace':
                    E[state][action] = 1

                # Move every known state-action value along the TD error in
                # proportion to its trace, then decay all the traces
                for s, _ in Q.items():
                    Q[s][:] += alpha * delta * E[s][:]
                    E[s][:] *= discount * trace_decay
    
                if done:
                    break

                state = next_state
                action = next_action

            # Remember this episode's return for the progress display
            rewards.append(stats.episode_rewards[i_episode])
    
        return Q, stats
    
    
    if __name__ == '__main__':
        start = time.time()
        Q, stats = sarsa_lambda(env, 100)
        end = time.time()
        print(stats)
        print("Algorithm took {:.2f} seconds to execute".format(end - start))
        plotting.plot_episode_stats(stats, title="sarsa_lambda")
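
A quick note on the two trace types handled by the trace_type argument (γ is the discount, λ the trace decay): accumulating traces increment the visited pair's trace, while replacing traces reset it to 1:

    accumulating:  e(s, a) ← γλ·e(s, a) + 1    for the pair just visited
                   e(s, a) ← γλ·e(s, a)        for every other pair
    replacing:     e(s, a) ← 1                 for the pair just visited
                   e(s, a) ← γλ·e(s, a)        for every other pair

Accumulating traces can grow beyond 1 when a pair is revisited within an episode, which is why replacing traces are sometimes the more stable choice.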

Please note that the environment I used is a custom environment I set up for the needs of my final year project, which involves bot detection; I will discuss it in future articles.
You can find the code on github.
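
If you don't have gym-botenv, any environment with discrete, hashable observations and a small discrete action set should work as a drop-in replacement. A minimal sketch, assuming the older gym API used above (reset() returning the state, step() returning a 4-tuple) and the unwrapped Taxi environment, whose nA attribute the code relies on:

    import gym

    env = gym.make("Taxi-v2").unwrapped  # toy-text envs expose nA (and nS)
    Q, stats = sarsa_lambda(env, num_episodes=500)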