It's Really Easy To Play Flappy Bird Via Reinforcement Learning


Flappy Bird played by the reinforcement learning agent (DQN) after 18 hours of training

Introduction

I think you must have heard of the famous game Flappy Bird. If not, please try it first. It's not easy to get a high score. For a computer, however, it may be very simple!

In this post, I'll show you how to play this game via reinforcement learning in only 180 lines of code! Since the algorithm was already introduced in the post How To Play Cart Pole Game with Temporal Difference Learning By Pytorch?, I'll focus on how to train the computer to play the game rather than on the math.

The code is written in Python and the deep learning library is PyTorch.

Flappy Bird

We choose the PyGame Learning Environment (PLE) as the game rendering environment. First you need to install it; please check the official site.

After installing the environment, you need to add the PLE directory to your Python path:

import sys
sys.path.append("/path/to/PyGame-Learning-Environment")

For example, if you run the code below, you can watch the Flappy Bird game in action:

import sys
sys.path.append("/path/to/PyGame-Learning-Environment")
from ple.games.flappybird import FlappyBird
from ple import PLE

game = FlappyBird()
p = PLE(game, fps=30, display_screen=True, force_fps=False)
p.init()
actions = p.getActionSet()
action_dict = {0: actions[1], 1:actions[0]}

reward = 0.0

for i in range(10000):
    if p.game_over():
        p.reset_game()

    state = p.getScreenRGB()             # raw RGB frame of the current screen
    action = 1                           # always pick action index 1, just to see the game run
    reward = p.act(action_dict[action])  # apply the action and collect the reward
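One detail worth noting: for Flappy Bird, getActionSet() returns two entries, the key code that makes the bird flap and None for doing nothing, so action_dict above maps index 0 to "do nothing" and index 1 to "flap" (if your PLE version orders the set differently, adjust the mapping). You can check it directly:

print(actions)
# Typically something like [<flap key code>, None]:
# action_dict then maps 0 -> actions[1] (do nothing) and 1 -> actions[0] (flap).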

The state in this game can be either the RGB image returned by getScreenRGB() or the non-visual state representation returned by getGameState(). In this post I will show you how to train the computer with convolutional neural networks, so we choose getScreenRGB(). The returned image size is (512, 288, 3). Since the ground is useless, we remove it, resize the image to 80x80, and then binarize it for faster convergence. Here is the image preprocessing function.

import cv2

def img_preprocess(img):
    # Rotate and flip the raw frame into the usual upright orientation
    img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    img = cv2.flip(img, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = img[:400, :]                           # crop away the ground
    img = cv2.resize(img, (80, 80))              # downscale to 80x80
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # convert to grayscale
    retval, img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY)  # binarize
    return img
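You can sanity-check the preprocessing on a live frame (a small sketch, assuming the PLE setup from the example above is still in scope):

frame = p.getScreenRGB()                 # raw RGB frame from PLE
state = img_preprocess(frame)            # (80, 80) binary image
print(frame.shape, state.shape)
cv2.imwrite("preprocessed.png", state)   # write it out to inspect the result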

After preprocessing, the image looks like this:

Preprocessed image

DQN

It's a very simple convolutional neural network written in PyTorch.

import torch
import torch.nn as nn

class DQN(torch.nn.Module):
    def __init__(self, action_space):
        super(DQN, self).__init__()
        # Three convolutional layers followed by two fully connected layers;
        # the output is one Q-value per action.
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
        self.fc1 = nn.Sequential(nn.Linear(6 * 6 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, action_space)

    def forward(self, observation):
        # observation: a (batch, 1, 80, 80) preprocessed frame
        output = self.conv1(observation)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)  # flatten to (batch, 6 * 6 * 64)
        output = self.fc1(output)
        output = self.fc2(output)
        return output
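As a quick sanity check of the 6 * 6 * 64 input size of fc1, you can push a dummy 80x80 frame through the network (a small sketch, assuming the DQN class above; for Flappy Bird, action_space is 2):

net = DQN(action_space=2)
dummy = torch.zeros(1, 1, 80, 80)  # a batch containing one 80x80 grayscale frame
q_values = net(dummy)
print(q_values.shape)              # torch.Size([1, 2]): one Q-value per action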

Here is the algorithm from the Nature DQN paper. The training code follows it step by step, so please read the code and the algorithm side by side.

Deep Q-learning with experience replay
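At the core of the update is the one-step TD target. As a minimal sketch of what the replay step below computes (names mirror the full script, shapes simplified):

# For one sampled transition (s, a, r, s_next, terminal):
#   y = r                                          if terminal
#   y = r + GAMMA * max_a' Q_target(s_next, a')    otherwise
# and the loss is the mean squared error between Q_policy(s, a) and y.
def td_target(reward, terminal, next_q_values, gamma=0.99):
    # next_q_values: Q_target's output for s_next, one value per action
    return reward if terminal else reward + gamma * next_q_values.max()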

Show me the code

Here is the full script.

import sys
sys.path.append("/path/to/PyGame-Learning-Environment")
from ple.games.flappybird import FlappyBird
from ple import PLE
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import cv2
from torch.utils.tensorboard import SummaryWriter

GAMMA = 0.99
LEARNING_RATE = 3e-6
MEMORY_SIZE = 50000
BATCH_SIZE = 48
EXPLORATION_MAX = 5e-2
EXPLORATION_MIN = 1e-5
EXPLORATION_DECAY = 0.9999

Model_path = "./model.pth"
writer = SummaryWriter()

class DQN(torch.nn.Module):
    def __init__(self, action_space):
        super(DQN, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
        self.fc1 = nn.Sequential(nn.Linear(6 * 6 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, action_space)

    def forward(self, observation):
        output = self.conv1(observation)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

class GameSolver:
    def __init__(self, action_space):
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.Q_policy = DQN(action_space)
        self.Q_target = copy.deepcopy(self.Q_policy)
        if torch.cuda.is_available():
            self.Q_policy = self.Q_policy.cuda()
            self.Q_target = self.Q_target.cuda()
        self.optimizer = optim.Adam(self.Q_policy.parameters(), lr = LEARNING_RATE)
        self.lossFuc = torch.nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def predict(self, state):

        q_values = self.Q_policy(state)[0]

        # epsilon-greedy: random action with probability exploration_rate, otherwise the greedy action
        if np.random.rand() < self.exploration_rate:
            max_q_index = [random.randrange(self.action_space)]
            max_q_one_hot =  one_hot_embedding(max_q_index, self.action_space)
        else:
            max_q_index =  [torch.max(q_values, 0).indices.cpu().numpy().tolist()]
            max_q_one_hot =  one_hot_embedding(max_q_index, self.action_space)
       # print(torch.max(q_values).detach().numpy())
        return max_q_index, max_q_one_hot, q_values

    def experience_replay(self):
        # Sample a random minibatch from the replay memory and take one gradient step on Q_policy.
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)
        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.cat(tuple(action for action in action_batch))
        reward_batch = torch.cat(tuple(reward for reward in reward_batch))
        reward_batch = reward_batch.view(len(reward_batch), 1)
        next_state_batch = torch.cat(tuple(next_state for next_state in next_state_batch))
        current_prediction_batch = self.Q_policy(state_batch)
        next_prediction_batch = self.Q_target(next_state_batch)

        # TD target: r if the episode ended, otherwise r + GAMMA * max_a' Q_target(s', a')
        y_batch = torch.cat(
            tuple(reward if terminal else reward + GAMMA * torch.max(prediction) for reward, terminal, prediction in
                  zip(reward_batch, terminal_batch, next_prediction_batch)))

        # Q_policy(s, a) for the actions actually taken (action_batch is one-hot)
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        self.optimizer.zero_grad()
        # y_batch = y_batch.detach()
        loss = self.lossFuc(q_value, y_batch)
        loss.backward()
        self.optimizer.step()

        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate  = max(EXPLORATION_MIN, self.exploration_rate)

def Flappy_bird():
    game = FlappyBird()
    game.allowed_fps = None
    p = PLE(game, fps=200, display_screen=True, force_fps=False)
    p.init()
    actions = p.getActionSet()
    action_dict = {0: actions[1], 1:actions[0]}
    reward = 0.0
    gameSolver = GameSolver(len(actions))
    run = 0
    while True:
        run += 1
        p.reset_game()
        state = img_preprocess(p.getScreenRGB())
        step = 0
        if run % 5000 == 4999:
            torch.save(gameSolver.Q_policy.state_dict(), Model_path)
        while True:
            step += 1
            action, action_one_hot, p_values = gameSolver.predict(state)
            reward = p.act(action_dict[action[0]])
            state_next = img_preprocess(p.getScreenRGB())
            terminal = p.game_over()

            reward = torch.tensor([reward])
            if torch.cuda.is_available():
                reward = reward.cuda()
                action_one_hot = action_one_hot.cuda()
            gameSolver.remember(state, action_one_hot, reward, state_next,
                                terminal)
            state = state_next

            if terminal:
                print ("Run: " + str(run) + ", exploration: " + str(gameSolver.exploration_rate) + ", score: " + str(step) )
                writer.add_scalar('Q_value', torch.max(p_values), run)
                writer.add_scalar('Score', step, run)
                break
            gameSolver.experience_replay()
            # Periodically sync the target network with the policy network
            if step % 50 == 0:
                gameSolver.Q_target.load_state_dict(copy.deepcopy(gameSolver.Q_policy.state_dict()))

def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.
    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N, #classes].
    """
    y = torch.eye(num_classes) 
    return y[labels] 

def img_preprocess(img):
    img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    img = cv2.flip(img,1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img[:400,:]
    img = cv2.resize(img, (80, 80))
    img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
    retval, img = cv2.threshold(img, 50, 255, cv2.THRESH_BINARY)
    img = torch.FloatTensor(img)
    img = img.unsqueeze(0)
    img = img.unsqueeze(0)
    if torch.cuda.is_available():
        img = img.cuda()
    return img

if __name__ == "__main__":
    Flappy_bird()
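Once a model.pth has been saved, you can watch the trained agent play greedily without any exploration. Here is a minimal evaluation sketch, reusing the DQN, img_preprocess, and Model_path definitions from the script above:

def evaluate():
    # Greedy play with the saved policy network (no exploration).
    game = FlappyBird()
    p = PLE(game, fps=30, display_screen=True, force_fps=False)
    p.init()
    actions = p.getActionSet()
    action_dict = {0: actions[1], 1: actions[0]}

    model = DQN(len(actions))
    model.load_state_dict(torch.load(Model_path, map_location="cpu"))
    if torch.cuda.is_available():
        model = model.cuda()  # img_preprocess moves the frame to the GPU when available
    model.eval()

    p.reset_game()
    while not p.game_over():
        state = img_preprocess(p.getScreenRGB())         # (1, 1, 80, 80) tensor
        with torch.no_grad():
            action = int(torch.argmax(model(state)[0]))  # greedy action index
        p.act(action_dict[action])

Call evaluate() instead of Flappy_bird() in the __main__ block to watch the saved agent.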

Here is the result after training for 18 hours (the same agent shown in the clip at the top of this post).

Summary

In this post we showed how to use a neural network to play the Flappy Bird game in only 180 lines of code. However, as you can see, the system needs a lot of training time even for such a simple game! Also, DQN is not very stable, and the agent still makes many mistakes. In the following posts, we will introduce methods that improve on DQN, as well as policy gradient methods that can outperform plain DQN.
