PyTorch provides a simple DQN implementation for solving the CartPole game. However, the code has a problem: it diverges after training (this has been discussed here).

The training curve of the official code is shown below; its high score is about 50, and it eventually diverges.

There are several reasons for the divergence.

First, the tutorial uses the difference between two frames as the input. Not only does this lose the cart's absolute position (useful information, since the game terminates when the cart moves too far from the centre), it also confuses the agent when the difference is the same but the underlying states are different.

Second, the replay memory is small. If the memory is too small, the agent forgets the strategy it has taken in some states. I'm not sure whether a memory of 10000 transitions is big enough, but I suggest using a larger value.

Third, the hyperparameters. learning_rate and target_update_interval can cause fluctuation. Here is an example on Stack Overflow. I also met this problem when training the CartPole agent: the reward stopped growing after 1000 episodes.

After doing some research on the CartPole DQN code, I managed to build a model that plays Flappy Bird. Here are the changes from the original CartPole code. Most of the techniques can be found in this paper: Playing Atari with Deep Reinforcement Learning.

  1. Dueling DQN

    Vanilla DQN has an overestimation problem: the max operator accumulates noise during training, which leads to convergence at a suboptimal point. The following two architectures were proposed to address this.

    Double DQN was published two years after DQN. It uses two value functions: one chooses the action with the maximum Q value, and the other calculates the Q value of that action.

    \(a^{\max}(S'_j, w) = \arg\max_{a'}Q(\phi(S'_j),a',w)\)

    \(y_j = R_j + \gamma Q'(\phi(S'_j),a^{\max}(S'_j, w),w')\)
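
    In code, the Double DQN target looks roughly like this (a sketch assuming PyTorch tensors and a gym-style batch; q_net is the online network w, target_net is w'):

    import torch

    def double_dqn_targets(q_net, target_net, rewards, next_states, dones, gamma=0.99):
        # Online network (w) picks the greedy action for each next state...
        with torch.no_grad():
            best_actions = q_net(next_states).argmax(1, keepdim=True)
            # ...and the target network (w') evaluates that action
            next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
        # Zero out the bootstrap term for terminal transitions
        return rewards + gamma * next_q * (1 - dones)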

    Dueling DQN is another solution. It has two estimators: one estimates the value of the current state, and the other estimates the advantage of each action.

    \(Q(S,A,w,\alpha, \beta) = V(S,w,\alpha) + A(S,A,w,\beta)\)

    So that the state value and the action advantages can be separated cleanly, the returned Q-value subtracts the mean advantage over the actions:

    x = val + adv - adv.mean(1, keepdim=True)

    In this project, I use dueling DQN.
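
    Putting it together, a minimal sketch of a dueling head in PyTorch (the feature dimension and layer sizes are placeholders, not my exact values):

    import torch.nn as nn

    class DuelingHead(nn.Module):
        # Splits shared features into a state-value stream V(s)
        # and an advantage stream A(s, a), then recombines them.
        def __init__(self, feature_dim, n_actions):
            super().__init__()
            self.val = nn.Linear(feature_dim, 1)          # V(s)
            self.adv = nn.Linear(feature_dim, n_actions)  # A(s, a)

        def forward(self, features):
            val = self.val(features)   # (batch, 1)
            adv = self.adv(features)   # (batch, n_actions)
            # Subtract the mean advantage so V and A are identifiable
            return val + adv - adv.mean(1, keepdim=True)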

  2. Image processing

    I convert each frame to grayscale, then remove the background color.
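
    Roughly, the preprocessing looks like this (a sketch using OpenCV; the output size and the threshold used to wipe the background are illustrative, not my exact values):

    import cv2
    import numpy as np

    def preprocess(frame):
        # RGB frame -> grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Downscale to a small square input (size is a placeholder)
        gray = cv2.resize(gray, (84, 84))
        # Crude background removal: threshold so the bird and pipes stand
        # out and the background becomes 0 (threshold value is a guess)
        _, binary = cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY)
        return binary.astype(np.float32) / 255.0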

  3. Stack frames

    I use the last 4 frames as the input. This helps the agent perceive how the environment is changing.
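
    A small sketch of the frame stacking I mean (the stack becomes the channel dimension of the network input):

    from collections import deque
    import numpy as np

    class FrameStack:
        # Keeps the last k preprocessed frames and stacks them
        # along the channel dimension as the network input.
        def __init__(self, k=4):
            self.k = k
            self.frames = deque(maxlen=k)

        def reset(self, first_frame):
            for _ in range(self.k):
                self.frames.append(first_frame)
            return np.stack(self.frames, axis=0)   # shape (k, H, W)

        def step(self, frame):
            self.frames.append(frame)
            return np.stack(self.frames, axis=0)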

  4. Extra FC before last layer

    I add a fully connected layer between the image features and the final FC layer that computes the Q-values, as sketched below.
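
    Roughly, the layout is the following (channel counts, kernel sizes, and hidden sizes are placeholders, not my exact values):

    import torch.nn as nn

    # Conv feature extractor -> extra FC -> Q-value head (sizes are illustrative)
    features = nn.Sequential(
        nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
        nn.Flatten(),
    )
    # The added layer; its output feeds the dueling head from item 1
    extra_fc = nn.Sequential(nn.Linear(64 * 9 * 9, 512), nn.ReLU())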

  5. Frame Skipping

    Frame skipping means the agent sees and selects actions on every k-th frame instead of every frame; the last action is repeated on the skipped frames. This speeds up training. I tried different frame-skipping values. With k=1 I got the maximum reward, about 100. With k=2 the maximum reward dropped to 50. With k=4 the agent failed to play the game, and the maximum reward stayed at 0. Although frame skipping did not work for Flappy Bird, it has been shown to work in many other cases. More details can be found in this post.
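
    The skipping loop itself is simple (a sketch assuming a gym-style env.step interface):

    def step_with_skip(env, action, k=2):
        # Repeat the chosen action for k frames and sum the rewards;
        # only the final frame is returned as the new observation.
        total_reward, done, info = 0.0, False, {}
        for _ in range(k):
            obs, reward, done, info = env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info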

  6. Prioritized Experience Replay

    This idea was published here. It's a very simple idea: replay experiences with a high TD error more frequently. My implementation is not efficient, but in the CartPole game this technique helps the agent converge faster. Here are the results on CartPole: the former is uniform replay, the latter is prioritized replay.

    (Figures: ddqn_cartpole_normal.png shows uniform replay, ddqn_cartpole_prioritized.png shows prioritized replay.)
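
    For reference, a minimal proportional prioritized replay buffer in the same inefficient O(n) spirit as my implementation (alpha and the small epsilon are the usual knobs from the paper; importance-sampling weights are omitted for brevity):

    import numpy as np

    class PrioritizedReplay:
        # Transitions with a larger TD error get a larger sampling probability.
        def __init__(self, capacity, alpha=0.6, eps=1e-5):
            self.capacity, self.alpha, self.eps = capacity, alpha, eps
            self.buffer, self.priorities = [], []

        def push(self, transition):
            # New transitions get the current max priority so they are replayed at least once
            p = max(self.priorities, default=1.0)
            if len(self.buffer) >= self.capacity:
                self.buffer.pop(0)
                self.priorities.pop(0)
            self.buffer.append(transition)
            self.priorities.append(p)

        def sample(self, batch_size):
            probs = np.array(self.priorities) ** self.alpha
            probs /= probs.sum()
            idxs = np.random.choice(len(self.buffer), batch_size, p=probs)
            return [self.buffer[i] for i in idxs], idxs

        def update_priorities(self, idxs, td_errors):
            for i, err in zip(idxs, td_errors):
                self.priorities[i] = abs(err) + self.eps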

  7. Colab and Kaggle Kernel

    My MacBook doesn't support CUDA, so I used these two websites to train the model. Here is a comparison of them. During training, Kaggle seems more stable; Colab usually disconnects after about 1 hour.

                         Colab         Kaggle Kernel
    GPU                  Tesla T4      Tesla P100
    Max training time    12h           9h
    Export trained model Google Drive  -

The lesson I learned from this project is patience. It takes a long time (maybe hundreds of thousands of steps) to see whether a model works, and there are many parameters that can affect the final performance. It took me about 3 weeks to build the final model. So if you want to build your own model, be patient, and good luck. Here are two articles about debugging and hyperparameter tuning in DQN:

Here are some other things that may help with this task.

I'm still trying to find a better solution. I've uploaded the code to this repo.

Ref:

  1. PyTorch Reinforcement Learning (DQN) Tutorial
  2. 强化学习 (Reinforcement Learning: a series of Chinese posts about reinforcement learning)
  3. Deep Reinforcement Learning for Flappy Bird
  4. Flappy-Bird-Double-DQN-Pytorch
  5. DeepRL-Tutorials