NNablaでDQN書いてる(途中)
Sonyが公開したDLフレームワーク「NNabla」で何か書いてみようかと思い、Mnih et al., 2015のDQN(Deep Q-Network)を実装してみた。
ソースは以下。 ただし、現在学習中で結果を見れていないため、ソースにはバグがあるかも。
NNablaは初めて使うフレームワークのため、精度がでない原因がNNablaの使い方が間違いか、ハイパーパラメータの調整が足りないのか切り分けられないと困ると思った。 そこで、まずは他のフレームワークの実装をみて、ハイパーパラメータ等を参考にすることにした。 今回は以下を参考にしたが、非常にわかりやすい解説&コードでした。
DQNをKerasとTensorFlowとOpenAI Gymで実装する
この記事では、現状書いたところまでで気がついたことなどを書く。
今後は、まず普通のDQNで動くことを確認後、BinaryなDQNで実験してみたい。
実装
特にNNablaで実装する上でしていて気になったことなどを記載する
Target Networkのパラメータの更新
# update target network if step % Config.TARGET_NET_UPDATE_INTERVAL == 0: # copy parameter from dqn to target with nn.parameter_scope("dqn"): src = nn.get_parameters() with nn.parameter_scope("target"): dst = nn.get_parameters() for (s_key, s_val), (d_key, d_val) in zip(src.items(), dst.items()): # Variable#d method is reference d_val.d = s_val.d.copy()
- NNablaでは、NNのパラメータはすべてグローバルに管理。区別するために名前空間を使う仕組み。
with nn.parameter_scope('<name>')
で、このコンテキストではnameの階層に属するパラメータだけ扱えるようになる- Variable#dで、Variableがもつnumpyのndarrayに直接アクセスできる
順伝播処理
# inference image.d = state q.forward() action = np.argmax(q.d)
プレースホルダーであるVariableへの値のセットと、ネットワークの順伝播実行が命令として別れているのが、少し慣れなかった。 (たしか、プレースホルダーに値をセットせずにforwardした場合、ときたまSEGVしたような…)
その他
- Replay Memoryのサイズを大きくしすぎるとメモリ確保できずに強制修了するので注意
- システムモニターをみてすぐに気がつけてよかった
- 学習環境はDockerのコンテナで作ってる。そのため、gymのGUIが表示できなかった (できる?)
- ドキュメントで、Functionの一覧が把握しづらいなと感じた
まだ慣れていないことと、学習状況の可視化やデバッグがTensorFlowとかに比べてまだプアなので、積極的にこのライブラリを使うモチベーションがなければ、TensorFlowとかのほうが使うのが良いかも?
私の場合は、面白半分、BNNを動かしたい、組み込み用途への期待、でこのライブラリを今後もWatchしていきたいと思ってる。
環境構築
環境構築していて気になったことを記載する
NVIDIA-Docker
今回、初めてNVIDIA-Dockerを使ったところ非常に便利だった。
GPU環境構築の問題として、DLフレームワークによって対応するCUDA ToolkitやcuDNNのバージョンがまちまちなので環境が共存できなかったりする。
NVIDIA-Dockerを使えば、ホストにはGPUドライバだけインストールし、CUDAとcuDNNは各コンテナにインストールできるので、コンテナ毎に異なるバージョンが使える。
NVIDIA-Dockerのベースイメージは、様々なCUDAとcuDNNのバージョンの組み合わせがあり、以下で確認できる。
今回、NNabla用には、 nvidia/cuda:8.0-cudnn6-runtime-ubuntu16.04
を利用した*1
https://hub.docker.com/r/nvidia/cuda/
Hydrogen
JupyterなどとAtomをつないで、Atomをリッチな開発環境にできる。
変数のWatchなどもできてよかったので、今後も使ってこう。
複数のDockerコンテナを登録して簡単に切り替えられるので便利。設定はURL https://t.co/r4FtKxeiS5 pic.twitter.com/usjfY0U8bg
— tkato (@_tkato_) 2017年7月1日