スーパーマリオの強化学習を動かす(Stable Baselines 3)

2024-08-07

PyTorchのチュートリアルにファミコンのスーパーマリオを使った強化学習があってずっと昔に試した時にはGoogle Colab上で動かせたはずなんだけど、 今ではライブラリのインストールで競合してバージョンの不整合でエラーが出たりして動かせない。 なんとか動かしたかったのでローカルで起動できるようやってみた。

スーパーマリオの強化学習

チュートリアルと修正方法

モジュール構成

件のチュートリアルがどういう構成になっているのか:

  • gym-super-mario-bros:スーパーマリオをGymのAPIに載せたもの
  • nes-py:ファミコンのエミュレータと、Gym用の環境や行動
  • gym:強化学習プラットフォーム

上記をモジュールとしてインストールした上で、強化学習のコードをColab上で動かしている。

gym

強化学習に必要な要素は環境とエージェントで、エージェントの行動によって得られる報酬をいかに増やすかという学習を行う。 Gymは環境のAPIを用意することで強化学習のプラットフォームの体制を整えている。 Gymの環境gym.Envクラスを継承して必要なメソッドを実装することで自作の環境を用意できる。 Envは重ねることができる。

nes-py

C言語で書かれたファミコンエミュレータをPythonから呼び出せるようにして、その画面をGymの環境として強化学習に利用できるようにしている。 gym.Envを継承したNESEnvが用意されている:

  • 観測空間:240x256x3
  • 行動空間:256(パッド入力各ビット:上下左右ABスタートセレクト)
    • ラッパーのJoypadSpaceで、ボタン名で簡単に指定できるようになっている
  • step:ファミコンエミュレータを1フレーム進める
  • renderhumanの場合、pygletで描画

gym-super-mario-bros

スーパーマリオの強化学習用環境:SuperMarioBrosEnv

問題点

チュートリアルが公開されてからだいぶ年月が経過していることもあって、色々問題がある:

  • OpenAI GymのBreaking Change:env.stepからの戻り値が4タプルから5に変更されたことでエラーが出る
  • OpenAI Gymがdeprecated:後継gymnasiumに載せ替えたい
    • gym-super-mario-brosとnes-pyどちらも
  • np.uint8の演算:np.uint8に対してintの加減算でも自動的に拡張してくれず0~255の範囲にクランプされる

修正方法

pipでモジュールをインストールしてしまってはコードが修正できないので、Github上にフォークして修正することにした。

でメインで動かすコードもGoogle ColabやJupyter Notebookではなく、単なるPythonスクリプトにした。 pipでのGithubのアドレスを指定してのインストールもできるっぽいが、 nes-pyはエミュレータがC言語で書かれていてコンパイルが必要なので、 gitのサブモジュールとして取り込んでビルドするようにした。

強化学習にStable Baselines 3を使ってみる

チュートリアルでは強化学習のアルゴリズムとしてDDQNをPyTorchを使って実際に実装している。 実装することでアルゴリズムを理解したりいじったりできるのは理想ではあるが実際のところ自分には難しいので、ありものを利用したい。 Stable Baselines 3(以下SB3)というライブラリがあるそうなのでそれを使うことにした。

Stable Baselines 3

SB3では強化学習の様々なアルゴリズムが実装されていて、簡単に利用できるようになっている。 エージェントの行動が離散的か連続的かによって使えるアルゴリズムが制限される (を参照のこと)

  • SB3はgymじゃなくてgymnasiumを使用している。

方策ネットワークの指定方法

SB3で方策ネットワークを指定するには、各アルゴリズムのコンストラクタのpolicy_kwargsfeatures_extractor_classとして指定することでできる (Custom Feature Extractor)。

BaseFeaturesExtractorを継承したクラスのforwardメソッドで任意の特徴量を返せるよう、 コンストラクタでネットワークを構築しておく。

学習結果

各アルゴリズムを使って学習させてみた。

  • 学習を簡単にするため、取れる行動を右ダッシュ・右ダッシュ+ジャンプ・左の3つに限定してみた
    • 1-1ならクリア可能なはず…
  • NESの画素(256x240xRGB)をそのまま使うんじゃなく、縮小や畳み込みを行う:元のチュートリアルと同様に
    • スキップフレーム:4フレームごとに行動させる
    • グレイスケール化
    • 84x84に画像縮小(CNNではなく単なる画像縮小)
    • フレームスタック:4フレーム重ねて時間経過がわかるようにする
    • CNN:84x84x4 → 20x20x32 → 9x9x64 → 7x7x64
    • フラット化:特徴数3,136
    • 全結合:3,136 → 512 → 行動数
  • モデルのtrainメソッドのtotal_timesteps5,000,000(エピソード数とは違う)

これをSB3の各種アルゴリズムで動かしてみた:PPO, TRPO, A2C, QRDQN, DQN。 学習には自分のマシンで各1日程度かかった。

PPO
PPO
TRPO
TRPO
A2C
A2C
QRDQN
QRDQN
DQN
DQN

考察

  • 1-1かつ行動が3通りとそんなに難しくなさそうに思うけど、全然学習が安定しない
  • PPOだけ性能がよくて他は全然、ハイパーパラメータが悪いとか?
  • 学習結果をリプレイさせるためにmodel.predictで行動を選んでもランダムに左右される
  • モデルに渡す画像は84x84のグレイスケール化しているが、人間的にはこんな画像で学習できる気がしない

クッパ面に挑戦

試した中ではPPOが一番よかったので、より難しそうなクッパ面で試してみた。

  • 行動:complex
  • グレイスケール化をやめてRGBのままに
  • CNNの特徴数:64-64-64
  • スキップフレーム:4から2に

1,000万ステップほど学習させたところ運がよければクリアできるようになった:https://youtu.be/mlSjsejrrZY

  • 行動をcomplexにしたのにちゃんとBダッシュメインで、穴やファイアを飛び越え、ファイアの前では減速、長ジャンプするの偉い

ソース

参考・リンク

付録

Stable Baselines 3の各種アルゴリズム

SB3に実装されているアルゴリズムがどんなものかググってみた:

  • PPO:A2CとTRPOのアイディアを組み合わせ、更新後のポリシーが前のポリシーから乖離しすぎないようクリップする
  • Maskable PPO:無効な行動をマスク
  • TRPO:方策更新を大きくしすぎないようにして学習を安定化
  • A2C :A3Cの「同期」版で、A3Cは非同期に複数のエージェントを用いて学習する
  • QRDQN:カテゴリカルDQNの分布を分位点から求めることで改良
  • DQN:Q学習をDNNで、リプレイ、ターゲットネットワークの固定化、勾配クリッピング

ハイパーパラメータ:

  • DQN: learning_rate: 1e-4, gamma: 0.99
  • QRDQN: learning_rate: 5e-4, gamma: 0.99
  • A2C: learning_rate: 7e-4, gamma: 0.99
  • TRPO: learning_rate: 1e-3, gamma: 0.99
  • PPO: learning_rate: 3e-4, gamma: 0.99

動かせなかったアルゴリズム:

  • ARS:ValueError: Policy CnnPolicy unknown
  • HER:ImportError: Since Stable Baselines 2.1.0, HER is now a replay buffer class HerReplayBuffer 単独で使えなくなり、他のアルゴリズムのリプレイバッファとして指定する
  • RecurrentPPO:CnnPolicyを受け付けない