PyTorchのチュートリアルにファミコンのスーパーマリオを使った強化学習があってずっと昔に試した時にはGoogle Colab上で動かせたはずなんだけど、 今ではライブラリのインストールで競合してバージョンの不整合でエラーが出たりして動かせない。 なんとか動かしたかったのでローカルで起動できるようやってみた。
チュートリアルと修正方法
モジュール構成
件のチュートリアルがどういう構成になっているのか:
- gym-super-mario-bros:スーパーマリオをGymのAPIに載せたもの
- nes-py:ファミコンのエミュレータと、Gym用の環境や行動
- gym:強化学習プラットフォーム
上記をモジュールとしてインストールした上で、強化学習のコードをColab上で動かしている。
gym
強化学習に必要な要素は環境とエージェントで、エージェントの行動によって得られる報酬をいかに増やすかという学習を行う。
Gymは環境のAPIを用意することで強化学習のプラットフォームの体制を整えている。
Gymの環境gym.Env
クラスを継承して必要なメソッドを実装することで自作の環境を用意できる。
Env
は重ねることができる。
nes-py
C言語で書かれたファミコンエミュレータをPythonから呼び出せるようにして、その画面をGymの環境として強化学習に利用できるようにしている。
gym.Env
を継承したNESEnv
が用意されている:
- 観測空間:240x256x3
- 行動空間:256(パッド入力各ビット:上下左右ABスタートセレクト)
- ラッパーの
JoypadSpace
で、ボタン名で簡単に指定できるようになっている
- ラッパーの
step
:ファミコンエミュレータを1フレーム進めるrender
:human
の場合、pygletで描画
gym-super-mario-bros
スーパーマリオの強化学習用環境:SuperMarioBrosEnv
- 観測空間:
NESEnv
のまま - 行動:
- 報酬:
問題点
チュートリアルが公開されてからだいぶ年月が経過していることもあって、色々問題がある:
- OpenAI GymのBreaking Change:
env.step
からの戻り値が4タプルから5に変更されたことでエラーが出る - OpenAI Gymがdeprecated:後継gymnasiumに載せ替えたい
- gym-super-mario-brosとnes-pyどちらも
np.uint8
の演算:np.uint8
に対してint
の加減算でも自動的に拡張してくれず0~255
の範囲にクランプされる
修正方法
pipでモジュールをインストールしてしまってはコードが修正できないので、Github上にフォークして修正することにした。
でメインで動かすコードもGoogle ColabやJupyter Notebookではなく、単なるPythonスクリプトにした。
pipでのGithubのアドレスを指定してのインストールもできるっぽいが、
nes-py
はエミュレータがC言語で書かれていてコンパイルが必要なので、
gitのサブモジュールとして取り込んでビルドするようにした。
強化学習にStable Baselines 3を使ってみる
チュートリアルでは強化学習のアルゴリズムとしてDDQNをPyTorchを使って実際に実装している。 実装することでアルゴリズムを理解したりいじったりできるのは理想ではあるが実際のところ自分には難しいので、ありものを利用したい。 Stable Baselines 3(以下SB3)というライブラリがあるそうなのでそれを使うことにした。
Stable Baselines 3
SB3では強化学習の様々なアルゴリズムが実装されていて、簡単に利用できるようになっている。 エージェントの行動が離散的か連続的かによって使えるアルゴリズムが制限される (表を参照のこと)
- SB3は
gym
じゃなくてgymnasium
を使用している。
方策ネットワークの指定方法
SB3で方策ネットワークを指定するには、各アルゴリズムのコンストラクタのpolicy_kwargs
でfeatures_extractor_class
として指定することでできる
(Custom Feature Extractor)。
BaseFeaturesExtractor
を継承したクラスのforward
メソッドで任意の特徴量を返せるよう、
コンストラクタでネットワークを構築しておく。
学習結果
各アルゴリズムを使って学習させてみた。
- 学習を簡単にするため、取れる行動を右ダッシュ・右ダッシュ+ジャンプ・左の3つに限定してみた
- 1-1ならクリア可能なはず…
- NESの画素(256x240xRGB)をそのまま使うんじゃなく、縮小や畳み込みを行う:元のチュートリアルと同様に
- スキップフレーム:4フレームごとに行動させる
- グレイスケール化
- 84x84に画像縮小(CNNではなく単なる画像縮小)
- フレームスタック:4フレーム重ねて時間経過がわかるようにする
- CNN:84x84x4 → 20x20x32 → 9x9x64 → 7x7x64
- フラット化:特徴数3,136
- 全結合:3,136 → 512 → 行動数
- モデルの
train
メソッドのtotal_timesteps
で5,000,000
(エピソード数とは違う)
これをSB3の各種アルゴリズムで動かしてみた:PPO, TRPO, A2C, QRDQN, DQN。 学習には自分のマシンで各1日程度かかった。
考察
- 1-1かつ行動が3通りとそんなに難しくなさそうに思うけど、全然学習が安定しない
- PPOだけ性能がよくて他は全然、ハイパーパラメータが悪いとか?
- 学習結果をリプレイさせるために
model.predict
で行動を選んでもランダムに左右される - モデルに渡す画像は84x84のグレイスケール化しているが、人間的にはこんな画像で学習できる気がしない
クッパ面に挑戦
試した中ではPPOが一番よかったので、より難しそうなクッパ面で試してみた。
- 行動:
complex
- グレイスケール化をやめてRGBのままに
- CNNの特徴数:64-64-64
- スキップフレーム:4から2に
1,000万ステップほど学習させたところ運がよければクリアできるようになった:https://youtu.be/mlSjsejrrZY
- 行動を
complex
にしたのにちゃんとBダッシュメインで、穴やファイアを飛び越え、ファイアの前では減速、長ジャンプするの偉い
ソース
参考・リンク
- Stable Baselines 3 メインドキュメント
- PyTorchチュートリアル(日本語翻訳版) 4. 深層強化学習に「強化学習を用いたマリオの訓練」のノートがある
- Stable Baselinesを使ってスーパーマリオブラザーズ1-1をクリアするまで #Python - Qiita
- Super Mario Bros. with Stable-Baseline3 PPO SB3/PPOとGymnasiumを使ってマリオ
- 結構ちゃんと学習している、画像を
Rectangle
に、行動をSIMPLE
より絞ってるから?
- 結構ちゃんと学習している、画像を
- uvipen/Super-mario-bros-A3C-pytorch: Asynchronous Advantage Actor-Critic (A3C) algorithm for Super Mario Bros
- gifを見ると完璧と言っていいくらいちゃんと学習できてるんだけどそんなに上手くいくのか?
- 動画:Stable Baselines3 Tutorial: Beginner’s Guide to Choosing Reinforcement Learning Algorithms - YouTube SB3とその使い方の説明
付録
Stable Baselines 3の各種アルゴリズム
SB3に実装されているアルゴリズムがどんなものかググってみた:
- PPO:A2CとTRPOのアイディアを組み合わせ、更新後のポリシーが前のポリシーから乖離しすぎないようクリップする
- Maskable PPO:無効な行動をマスク
- TRPO:方策更新を大きくしすぎないようにして学習を安定化
- A2C :A3Cの「同期」版で、A3Cは非同期に複数のエージェントを用いて学習する
- QRDQN:カテゴリカルDQNの分布を分位点から求めることで改良
- DQN:Q学習をDNNで、リプレイ、ターゲットネットワークの固定化、勾配クリッピング
ハイパーパラメータ:
- DQN: learning_rate: 1e-4, gamma: 0.99
- QRDQN: learning_rate: 5e-4, gamma: 0.99
- A2C: learning_rate: 7e-4, gamma: 0.99
- TRPO: learning_rate: 1e-3, gamma: 0.99
- PPO: learning_rate: 3e-4, gamma: 0.99
動かせなかったアルゴリズム:
- ARS:
ValueError: Policy CnnPolicy unknown
- HER:
ImportError: Since Stable Baselines 2.1.0, HER is now a replay buffer class HerReplayBuffer
単独で使えなくなり、他のアルゴリズムのリプレイバッファとして指定する - RecurrentPPO:
CnnPolicy
を受け付けない