# UNCOMMENT FOR INTERACTIVE PLOTTING
# %matplotlib notebook
%matplotlib widget
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import animation, rc, cm
import time
from snakelib import FastSnake, show_gui, NeuralAgent
import snakelib as snklib
import utils
from scipy.ndimage import label

rc("animation", html="html5")

Reinforcement learning on snake with a genetic neural network

Required files

In order to work properly, this notebook requires the snakelib and utils modules imported above.

Put them in your working directory along with this notebook.

This notebook is an example of reinforcement learning applied to video games. You will use the legendary game Snake, rewritten in Python for the occasion, and try to develop an automatic game strategy: first by hand, then by using a genetic algorithm to evolve a neural network. Graphical examples let you follow how the game performance evolves.

Part 1: Try the game

In this first part, you are asked to try the game and check that you understand the rules. Try to imagine what you need to know to win a game.

np.random.seed(0)  # Fixing the seed
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
snake = FastSnake(Nrow=10, Ncol=10)
display(show_gui(snake, ax))
plt.close()

Some explanations


1. Game Grid

The game grid is composed of a number of columns, denoted Ncol, and a number of rows, denoted Nrow.

print(
    f"Grid dimensions: {snake.Ncol} columns, {snake.Nrow} rows, a total of {snake.Ncell} elements"
)
Grid dimensions: 10 columns, 10 rows, a total of 100 elements

Each element of the grid is identified by a unique index.

utils.plot_grid(snake)

2. Snake directions

From the current position of its head, the snake can only take three directions at each turn: right, forward, or left.

Here is an example of commands that move the snake:

# Moving the snake
np.random.seed(0)  # Fixing the seed
snake.reset()  # Reset snake position
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(1)  # Go Left

# Access to neighbors element positions
print(f"Current neighbors index {snake.get_neighbors_pos()}")
Current neighbors index [44 35 24]

Note

Note that the value -1 turns the snake to the right, 0 keeps it going forward, and 1 turns it to the left, relative to the snake’s current direction.

Here are the results of the actions taken by these various successive orders:

utils.plot_snake(snake)

Question 1

What commands must be given for the snake to eat the first fruit?

# Initial snake position
np.random.seed(0)  # Fixing the seed
snake.reset()  # Reset snake position
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(1)  # Go Left

# Your answer

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
display(show_gui(snake, ax))

3. Snake sensors

Chris Roberts (video game developer)

It turns out that in their infinite wisdom, the game designers gave the snake some very weird sensors to help the players. Here’s how to display the sensor data:

Default sensors
# Moving the snake
np.random.seed(0)  # Fixing the seed
snake.reset()  # Reset snake position
snake.display_sensor_method = "default"
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
display(show_gui(snake, ax))

The first 3 elements indicate, respectively, the nature of the cells directly to the right, in front, and to the left of the snake’s head (relative to its direction).

The values of the elements can be:

  • 1.0 = fruit present

  • -1.0 = forbidden position (lava or snake tail)

  • 0.0 = nothing special present

utils.plot_snake(snake)

Example

The head of the snake is in position 11. To its left is lava, so a value of -1 is read on the third element of the sensor output. To its right, there is nothing, so a value of 0 is read on the first element of the sensor output.

The last two values give the direction of the fruit relative to the head of the snake: there is an angle \(\theta\) between the direction of the snake and the direction from its head to the fruit.

These last two values are respectively the sign of \(\cos(\theta)\) and \(\sin(\theta)\).
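
As a rough illustration of these last two sensor values, the sketch below computes the signs of \(\cos(\theta)\) and \(\sin(\theta)\) from the head position, the fruit position and the heading. The helper name `fruit_direction_signs`, the (row, column) coordinates with rows growing downward, and the choice that a positive sine means "fruit on the left" are all assumptions made for this example; snakelib may use a different convention internally.

```python
import numpy as np

# Unit heading vectors as (d_row, d_col), with rows growing downward as in imshow.
# Direction codes: 0 = right, 1 = up, 2 = left, 3 = down.
HEADINGS = {0: (0, 1), 1: (-1, 0), 2: (0, -1), 3: (1, 0)}


def fruit_direction_signs(head, fruit, direction):
    """Return (sign(cos(theta)), sign(sin(theta))) for the fruit seen from the head."""
    h_row, h_col = HEADINGS[direction]
    d_row, d_col = fruit[0] - head[0], fruit[1] - head[1]
    # Dot product with the heading: positive when the fruit is ahead
    forward = h_row * d_row + h_col * d_col
    # Vector pointing to the snake's left (assumed sign convention)
    l_row, l_col = -h_col, h_row
    lateral = l_row * d_row + l_col * d_col
    return float(np.sign(forward)), float(np.sign(lateral))


# Snake heading right at (5, 5), fruit up and to the right at (3, 7)
signs = fruit_direction_signs(head=(5, 5), fruit=(3, 7), direction=0)
print(signs)  # (1.0, 1.0): the fruit is ahead and on the left
```

Only the signs are kept, so the agent knows in which quadrant the fruit lies but not how far away it is.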

# Moving the snake
np.random.seed(0)  # Fixing the seed
snake.reset()  # Reset snake position
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(-1)  # Go Right
snake.turn(0)  # Go Front
snake.turn(1)  # Go Left

# Snake current direction
direction = snake.get_current_direction()
print(f"Current direction: {direction}")
Current direction: 0
The current direction of the snake based on the positions of its head and neck.
direction (int): An integer representing the current direction of the snake, where:
    0 = right
    1 = up
    2 = left
    3 = down
    -1 = if the direction could not be determined
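
Assuming the direction codes listed above increase counter-clockwise (right, up, left, down), a relative turn can be applied to an absolute direction with modular arithmetic. The function `next_direction` is a hypothetical helper, not part of snakelib; it only illustrates how the commands -1, 0 and 1 combine with the direction codes.

```python
def next_direction(direction, turn):
    """Apply a relative turn (-1 = right, 0 = front, 1 = left) to a direction code."""
    return (direction + turn) % 4


# Replay the command sequence used above, starting from a snake heading up (1).
# The starting direction is an assumption chosen for this illustration.
direction = 1
for turn in (-1, 0, -1, 0, 1):
    direction = next_direction(direction, turn)
print(direction)  # 0, i.e. heading right
```

Under that assumed starting direction, the sequence ends on 0 (right), which agrees with the value printed by the notebook.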
utils.plot_snake_theta(snake)
Advanced sensors: label sensors

The label sensors have the same shape as the “default” sensors. The main difference is that they identify the number of free cells that are reachable with a given choice. If this number is the maximum among the choices, the sensor returns 1. Otherwise, including when no cell is reachable, it returns -1.
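
The idea of counting reachable cells can be sketched with `scipy.ndimage.label`, which is imported at the top of this notebook. This suggests, but does not confirm, that snakelib relies on it. The functions `reachable_counts` and `label_sensor` below are hypothetical and work on a plain boolean grid rather than on a `FastSnake` object.

```python
import numpy as np
from scipy.ndimage import label


def reachable_counts(free, candidates):
    """Count the free cells connected to each candidate cell.

    free: 2D boolean array, True where a cell can be entered.
    candidates: list of (row, col) cells next to the head.
    """
    # label() groups free cells into 4-connected regions numbered from 1;
    # blocked cells get the label 0.
    regions, _ = label(free)
    sizes = np.bincount(regions.ravel())
    sizes[0] = 0  # blocked cells belong to no region
    return [int(sizes[regions[r, c]]) for r, c in candidates]


def label_sensor(counts):
    """Return 1.0 for the choices reaching the most cells and -1.0 for the others."""
    best = max(counts)
    return [1.0 if n == best and n > 0 else -1.0 for n in counts]


free = np.array(
    [
        [1, 1, 0, 1],
        [1, 1, 0, 1],
        [0, 0, 0, 1],
    ],
    dtype=bool,
)
counts = reachable_counts(free, [(0, 1), (0, 3), (2, 0)])
print(counts, label_sensor(counts))  # [4, 3, 0] [1.0, -1.0, -1.0]
```

A sensor of this kind warns the agent before it enters a pocket that is too small to get out of, which the default sensors cannot detect.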

# Moving the snake
np.random.seed(0)  # Fixing the seed
snake.reset()  # Reset snake position
snake.display_sensor_method = "label"
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
display(show_gui(snake, ax))

Automatic playing

Question 2

First, you are asked to build an agent that plays automatically. It must decide, from the snake’s sensor values, which direction is best to take.

Here is an example of a very naive agent:

def my_agent(sensors):
    """
    A dummy agent that moves in a random direction.
    """
    return np.random.randint(3) - 1


snake = FastSnake(Nrow=10, Ncol=10)
my_choice = my_agent(snake.sensors())
# Your answer

Automatic play with graphic output

snake5 = FastSnake(Nrow=15, Ncol=15)


def updatefig(*args):
    sensors = snake5.sensors(method="label")
    my_choice = my_agent(sensors)
    snake5.turn(my_choice)
    im2.set_array(snake5.grid)
    score = snake5.score
    iteration = snake5.iteration
    title2.set_text(f"Score = {score} Iteration = {iteration}")
    if snake5.status != 0:
        snake5.reset()
    return (im2,)


fig2, ax2 = plt.subplots()
ax2.axis("off")
im2 = plt.imshow(snake5.grid, interpolation="nearest", animated=True)
score = snake5.score
title2 = ax2.set_title(f"Score = {score}", animated=True)
anim = animation.FuncAnimation(fig2, updatefig, frames=40, interval=50, blit=True)
# plt.show()  # UNCOMMENT TO PLAY
plt.close()  # COMMENT TO PLAY
anim  # COMMENT TO PLAY
anim.pause()
plt.close(fig2)

Benchmark

Nagent_ids = 200
max_turns = 1000
snake3 = FastSnake(Nrow=10, Ncol=10)
scores = np.zeros(Nagent_ids)
turns = np.zeros(Nagent_ids)
# for agent_id in tqdm.trange(Nagent_ids):
for agent_id in range(Nagent_ids):
    snake3.reset()
    turn = 0
    while snake3.status == 0:
        sensors = snake3.sensors()
        my_choice = my_agent(sensors)
        snake3.turn(my_choice)
        turn += 1
        if turn >= max_turns:
            break
    scores[agent_id] = snake3.score
    turns[agent_id] = turn
data = pd.DataFrame({"score": scores, "turns": turns})
data.describe().loc[["mean", "std", "max", "min", "count"]].T
       mean       std   max  min  count
score  0.03  0.198233   2.0  0.0  200.0
turns  3.32  5.791833  37.0  1.0  200.0

Genetic Neural Network agent

# NEURAL FUNCTIONS
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def ReLu(x):
    return np.where(x > 0.0, x, 0.0)


def identity(x):
    return x


# ARG MAX
def argMax(x):
    return int(np.where(x == x.max())[0][0])
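
To make the role of these functions concrete, here is a minimal sketch of a feed-forward pass that turns a sensor vector into a choice of direction. It assumes a flat weight vector that stores, for each layer, `nin + 1` rows (the inputs plus one bias) by `nout` columns. The function `forward` and this memory layout are illustrative assumptions and may differ from what `NeuralAgent` does internally.

```python
import numpy as np


def forward(sensors, weights, structure, neural_functions):
    """Propagate sensors through a dense network stored as one flat weight vector."""
    x = np.asarray(sensors, dtype=float)
    start = 0
    for nin, nout, func in zip(structure[:-1], structure[1:], neural_functions):
        stop = start + (nin + 1) * nout
        w = weights[start:stop].reshape(nin + 1, nout)
        # Append a constant 1.0 so that the last row of w acts as the bias
        x = func(np.append(x, 1.0) @ w)
        start = stop
    return x


rng = np.random.default_rng(0)
structure = (5, 3)  # 5 sensor inputs, 3 outputs (right, front, left)
weights = rng.uniform(-1.0, 1.0, size=(5 + 1) * 3)
outputs = forward([0.0, -1.0, -1.0, 1.0, 1.0], weights, structure, [lambda v: v])
# The largest output selects the turn command
choice = np.array([-1.0, 0.0, 1.0])[int(np.argmax(outputs))]
```

With a (5, 3) structure and an identity activation, the agent is a plain linear map from sensors to three scores; hidden layers and the sigmoid or ReLu functions only matter for deeper structures.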
# GENETIC ALGORITHM SETUP

Npop = 100  # NUMBER OF INDIVIDUALS IN THE POPULATION
Ngen = 5  # NUMBER OF GENERATIONS OF EVOLUTION
Ntries = 1  # NUMBER OF TRIES PER INDIVIDUAL PER GENERATION
Net_struct = 5, 3  # NETWORK STRUCTURE
keep_ratio = 0.2  # GENETIC ALGORITHM KEEP RATIO
mutation_ratio = 0.1  # MUTATION RATIO
mutation_sigma = 1.0  # MUTATION GAUSSIAN SIGMA
max_turns = 600  # MAX PLAY TURNS PER TRIAL
neural_functions = [identity]  # NEURAL FUNCTION VECTOR
sensor_method = "default"

# PRE-PROCESSING
keep_individuals = int(keep_ratio * Npop)
Nw = 0
for i in range(len(Net_struct) - 1):
    nin = Net_struct[i]
    nout = Net_struct[i + 1]
    Nw += (nin + 1) * nout
# all_weights = np.random.normal(loc=0.0, scale=1.0, size=(Npop, Nw))
all_weights = (np.random.rand(Npop, Nw) - 0.5) * 2.0

agents = []
for agent_id in range(Npop):
    weights = all_weights[agent_id]
    agent = NeuralAgent(
        weights=weights, structure=Net_struct, neural_functions=neural_functions
    )
    agents.append(agent)
agent_functions = [agent.get_caller() for agent in agents]
func = agent_functions[0]
total_generations = 1
generation_store = []
best_score_store = []
# DATA STORAGE
snake4 = FastSnake(
    Nrow=12, Ncol=12, record_turns=False, recorded_sensors_method=sensor_method
)
scores = np.zeros(Npop)
turns = np.zeros(Npop)
tries_scores = np.zeros(Ntries)
tries_turns = np.zeros(Ntries)
new_all_weights = np.zeros_like(all_weights)
turn_ids = np.array([-1.0, 0.0, 1.0])
for generation in range(Ngen):
    print(f"Generation: {total_generations}")
    generation_store.append(total_generations)
    scores[:] = 0.0
    turns[:] = 0.0
    for agent_id in range(Npop):
        # np.random.seed(0)
        tries_scores[:] = 0.0
        tries_turns[:] = 0.0
        agent_func = agent_functions[agent_id]
        for trial in range(Ntries):
            snake4.reset()
            Ncol = snake4.Ncol
            Nrow = snake4.Nrow
            snake4.fruit_position = (Nrow - 2) * Ncol + Ncol - 2
            turn = 0
            while snake4.status == 0:
                sensors = snake4.sensors(method=sensor_method)
                my_choice = turn_ids[argMax(agent_func(sensors))]
                snake4.turn(my_choice)
                turn += 1
                if turn >= max_turns:
                    break
            tries_scores[trial] = snake4.score
            tries_turns[trial] = turn
        scores[agent_id] = tries_scores.mean()
        turns[agent_id] = tries_turns.mean()
    perf = scores * 100 - turns
    order = np.argsort(perf)[::-1]
    new_all_weights[:] = 0.0
    # SELECTION
    new_all_weights[:keep_individuals] = all_weights[order][:keep_individuals]
    # HYBRIDATION
    keep_range = np.arange(keep_individuals)
    for indiv in range(keep_individuals, Npop):
        parents = np.random.choice(keep_range, 2)
        while parents[1] == parents[0]:
            parents = np.random.choice(keep_range, 2)
        pw = np.random.rand(Nw)
        new_all_weights[indiv] = (
            new_all_weights[parents[0]] * pw + (1.0 - pw) * new_all_weights[parents[1]]
        )

        # MUTATION:
        if np.random.rand() <= mutation_ratio:
            # mutation_loc = np.random.randint(Nw)
            new_all_weights[indiv] *= np.random.normal(
                loc=1.0, scale=mutation_sigma, size=Nw
            )
    total_generations += 1
    all_weights[:] = new_all_weights

    data = pd.DataFrame(
        {"score": scores[order], "turns": turns[order], "perf": perf[order]}
    )  # .sort_values( "perf", ascending=False    )
    print(data.head(5))
    print(f"=> best score = {scores.max()}")
    best_score_store.append(data.iloc[0].score)
print("FINISHED")
plt.figure("Snake at the gym !")
plt.clf()
plt.plot(generation_store, best_score_store, "or-")
plt.xlabel("Generation")
plt.ylabel("Best Score")
plt.grid()
plt.show()
Generation: 1
   score  turns   perf
0    6.0   65.0  535.0
1    1.0   19.0   81.0
2    1.0   20.0   80.0
3    1.0   22.0   78.0
4    0.0    1.0   -1.0
=> best score = 6.0
Generation: 2
   score  turns    perf
0   16.0  140.0  1460.0
1   15.0  118.0  1382.0
2   13.0  175.0  1125.0
3   11.0   93.0  1007.0
4   11.0  172.0   928.0
=> best score = 16.0
Generation: 3
   score  turns    perf
0   22.0  248.0  1952.0
1   19.0  142.0  1758.0
2   19.0  184.0  1716.0
3   17.0  134.0  1566.0
4   17.0  142.0  1558.0
=> best score = 22.0
Generation: 4
   score  turns    perf
0   27.0  200.0  2500.0
1   25.0  259.0  2241.0
2   21.0  169.0  1931.0
3   17.0  145.0  1555.0
4   17.0  149.0  1551.0
=> best score = 27.0
Generation: 5
   score  turns    perf
0   26.0  174.0  2426.0
1   23.0  221.0  2079.0
2   21.0  183.0  1917.0
3   21.0  193.0  1907.0
4   20.0  148.0  1852.0
=> best score = 26.0
FINISHED
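
The selection, hybridization and mutation steps of the loop above can be isolated in a single function, which makes them easier to test on a toy problem. The sketch below follows the same recipe: keep the best fraction, blend two distinct parents with random per-weight coefficients, and occasionally multiply a child by Gaussian noise. The function name `next_generation`, the use of `np.random.default_rng` and the toy objective are choices made for this example, not part of the notebook.

```python
import numpy as np


def next_generation(weights, perf, keep_ratio=0.2, mutation_ratio=0.1,
                    mutation_sigma=1.0, rng=None):
    """Build a new population from the current weights and their performance."""
    rng = np.random.default_rng() if rng is None else rng
    npop, nw = weights.shape
    nkeep = max(2, int(keep_ratio * npop))  # at least two parents are needed
    order = np.argsort(perf)[::-1]  # best individuals first
    new = np.empty_like(weights)
    # Selection: the best individuals survive unchanged
    new[:nkeep] = weights[order[:nkeep]]
    for i in range(nkeep, npop):
        # Hybridization: random per-weight blend of two distinct parents
        p0, p1 = rng.choice(nkeep, size=2, replace=False)
        pw = rng.random(nw)
        new[i] = pw * new[p0] + (1.0 - pw) * new[p1]
        # Mutation: multiplicative Gaussian noise on some children
        if rng.random() <= mutation_ratio:
            new[i] *= rng.normal(loc=1.0, scale=mutation_sigma, size=nw)
    return new


rng = np.random.default_rng(0)
pop = rng.uniform(-1.0, 1.0, size=(10, 4))
perf = -np.abs(pop).sum(axis=1)  # toy objective: prefer small weights
new_pop = next_generation(pop, perf, rng=rng)
```

Because the best individuals are copied unchanged, the best weights are never lost between generations. The best score can still drop from one generation to the next, as it does between generations 4 and 5 above, since each evaluation is a new game.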
snake5 = FastSnake(
    Nrow=12, Ncol=12, record_turns=True, recorded_sensors_method=sensor_method
)
# np.random.seed(0)
# weights = all_weights[0]  # BEST AGENT
best_agent_func = agent_functions[0]
turn_ids = np.array([-1.0, 0.0, 1.0])
np.random.seed(0)


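def updatefig(*args):
    # Use snake5 (the recorded game), not the snake from Part 1
    if snake5.status == 0:
        sensors = snake5.sensors(method=sensor_method)
        my_choice = turn_ids[argMax(best_agent_func(sensors))]

        snake5.turn(my_choice)
        im2.set_array(snake5.grid)
        score = snake5.score
        iteration = snake5.iteration
        title2.set_text(f"Score = {score} Iteration = {iteration}")

    else:
        # Read the score from the snake: the local variable is unset in this branch
        title2.set_text(f"You died with score = {snake5.score}")
    # blit=True expects an iterable of artists
    return (im2,)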


score = 0
fig2, ax2 = plt.subplots()
ax2.axis("off")
title2 = ax2.set_title(f"Score = {score}", animated=True)
im2 = plt.imshow(snake5.grid, interpolation="nearest", animated=True)
anim = animation.FuncAnimation(fig2, updatefig, frames=1000, interval=30, blit=True)
# plt.show()
plt.close()
anim
turns = snake5.recorded_turns
sensors = np.array(snake5.recorded_sensors)

out = {
    "turn": turns,
}
for i, s in enumerate(sensors.T):
    out[f"s{i}"] = s

data = pd.DataFrame(out)
data.iloc[-10:]
turn s0 s1 s2 s3 s4
105 -1.0 0.0 0.0 0.0 1.0 -1.0
106 1.0 -1.0 -1.0 0.0 1.0 1.0
107 -1.0 0.0 0.0 0.0 0.0 -1.0
108 0.0 0.0 0.0 0.0 1.0 0.0
109 0.0 0.0 1.0 0.0 1.0 0.0
110 -1.0 0.0 -1.0 0.0 -1.0 -1.0
111 -1.0 0.0 0.0 -1.0 1.0 -1.0
112 0.0 -1.0 0.0 0.0 1.0 1.0
113 1.0 -1.0 -1.0 0.0 1.0 1.0
114 -1.0 -1.0 0.0 0.0 1.0 -1.0

Question 4

Play with the neural agent and find a performance formulation that gets the highest score in fewer than 1000 turns.
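
One way to explore this question is to make the performance formula a function of a few tunable parameters and compare variants over several generations. The sketch below generalizes the `perf = scores * 100 - turns` line used above. The function name `performance` and its parameters `fruit_reward` and `turn_cost` are hypothetical, and the default values only reproduce the existing formula rather than suggest an answer.

```python
import numpy as np


def performance(scores, turns, fruit_reward=100.0, turn_cost=1.0):
    """Rank agents by rewarding fruits eaten and penalizing turns played."""
    scores = np.asarray(scores, dtype=float)
    turns = np.asarray(turns, dtype=float)
    return fruit_reward * scores - turn_cost * turns


# Values taken from the first two rows of the generation 1 table above
perf = performance([6.0, 1.0], [65.0, 19.0])
print(perf.tolist())  # [535.0, 81.0]
```

A negative `turn_cost` would instead reward survival, which changes which behaviours are selected in the early generations.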