Goでゼロから作るDeep LearningフレームワークDeZeroを実装してみた
はじめまして。
サイバーエージェントゲーム・エンターテイメント事業部(SGE)に所属する子会社QualiArtsでサーバーサイドエンジニアをしている高嶋です。本記事はQualiArtsの定期ブログ「QualiArts Tech Note」第2弾の記事となります。QualiArtsでは会社で使われている様々な技術の知見をブログとして配信しています。
はじめに
SGEおよび、QualiArtsでは活発に勉強会が開催されており、私はその中でもディープラーニングを含むAIに関する勉強会によく参加しています。この勉強会では、毎回、それぞれ本を読み進め、わからない場所について議論したり、実際に手を動かして実装方法や動きを確認したりしています。現在、勉強会では「ゼロから作るDeep Learning ❸」を読み進めており、この本ではPythonを用いた、ディープラーニングのフレームワーク(通称 DeZero)の作り方が紹介されています。
一方で、私は一昨年から業務でGoを使うようになり、今や一番好きな言語になりつつあります。そのような背景から、この本の実装をPythonではなく、Goで実装することに挑戦しました。このGoでの実装を行うにあたり、そのまま実装することができなかった点について、どのように実装したかをお話しようと思います。なお、今回の実装はあくまで、「なるべくそのまま」載っている実装をGoで再現し、実装することを目的としており、より良い実装方法がある(であろう)ことはご承知ください。
Goとは
GoはGoogleで開発されたオープンソースのプログラミング言語(https://go.dev/)です。
私は一昨年までJavaのプロジェクトに従事していましたが、Goの学習コストはそれほど高いとは感じませんでした。上記の特徴はQ&Aなどに記載されている内容ですが、上記以外に私がGoの特徴として感じている所は以下のとおりです。
- 依存モジュール管理が簡単
- コード整形の機能が標準で付いている
- コードの自動生成が簡単
- エラーは基本的に戻り値として返す
- 継承・総称型がない
継承がないというのはJavaを使っていた私としてはだいぶ戸惑いました。今回のPythonのコードをGoにするにあたっても問題となりました。このあたりについては、後ほどお話します。
DeZeroとは
DeZeroはこの「ゼロから作るDeep Learning ❸(https://www.oreilly.co.jp/books/9784873119069/)」で出てくるオリジナルのフレームワークです。この本では、DeZeroを段階的に作りながらディープラーニングのフレームワークの作成方法について学んでいきます。コードのサンプルはPythonで記述されており、Pythonならではの効率的な書き方などについても触れられているので、大変勉強になりました。しかし、私はGoでこれを作るとどうなるのかについて興味が湧いたので、それを実践してみました。
実装
では早速ですが、実際に「ゼロから作るDeep Learning ❸」のステップ43「ニューラルネットワーク」までの内容を実装したものがこちら(https://github.com/qua-tkmax/dezerogo)です。なお、計算グラフの可視化についてはスキップしています。
Gonum
Goでの行列の計算には、Gonum(https://www.gonum.org/)を使用しました。
Gonumは数値計算のアルゴリズムを効率的に書くためのパッケージで、ここには行列計算のためのライブラリを始め、統計や確率計算のためのライブラリなどが含まれています。
Goで行列の計算をしたいときにはよく使われている印象があります。例えば、Goの機械学習のためのライブラリ「Gorgonia(https://gorgonia.org/)」でも使用されています。使い方としては、やや直感的ではないところがあり、例えばXとYの和を計算したいときは以下のようになります。
func add(x ,y *mat.Dense) *mat.Dense {
var z mat.Dense
z.Add(x, y)
return &z
}
このように、Gonumでは基本的に、計算結果を戻り値として返すのではなく、計算結果を入れるDenseを作り、そのDenseのメソッドを使って計算します。ここは初めて使うときには少し躓く点かもしれません。
実装したネットワーク
ステップ43では以下のような2層のニューラルネットワークを作ります。
システムの主な処理の流れは以下のとおりです。
- y=sin(2πx)+randの関数を使って、xとyのデータセットを100個作る
(randは[0,1)の一様乱数) - 最適化する重みとバイアス(W1,W2,b1,b2)を初期化する
- 2層ニューラルネットワークを使用して、各xに対するyPredを計算し、
yとの誤差を計算する - 傾きを計算し、それをもとに重みとバイアスを調整する
- 3に戻る
この3と4の繰り返しを10000回行い、最終的な結果としました。
実行の結果を以下に記します。
まず、ループ回数10000回までの、1000回ごとの誤差の変化は以下のとおりです。
ループ回数: 0 誤差:0.713802
ループ回数: 1000 誤差:0.244415
ループ回数: 2000 誤差:0.240605
ループ回数: 3000 誤差:0.232764
ループ回数: 4000 誤差:0.208652
ループ回数: 5000 誤差:0.153773
ループ回数: 6000 誤差:0.099183
ループ回数: 7000 誤差:0.084764
ループ回数: 8000 誤差:0.081310
ループ回数: 9000 誤差:0.079427
ループ回数:10000 誤差:0.078173
このように、ループごとに誤差が減っていっているのが確認できます。
続いて、最適化された重みとバイアスを使い、xを0から1までを100分割し計算した結果をプロットしたものが以下のグラフになります(なお、期待される結果である y=sin(2πx)+0.5 をプロットしています)。
ちゃんと動いていそうですね。
ここまで作るにあたって、本に書いてある通りの方法で実装できなかった箇所と、私が試したその解決方法を説明します。
1. 継承
DeZeroでは、各演算はFunctionを継承する形で実装しています。しかし、Goには継承がありません。これは、継承がないほうが記述量が減り、また多重継承などの複雑さを回避できるという考えに基づきます。しかし、これまで継承を活用していた人からすると、共通の処理を継承させたいというときもあります。そういうとき、Goではメソッドのポインタが役に立ちます。GoではC言語などのように、メソッドのポインタも変数として保持することができます。今回のケースで言うと、各演算の共通処理に関してはFunction側に実装し、各演算特有の処理に関しては、それぞれで実装しFunction生成時に引数としてメソッドのポインタを渡す実装にしました。
以下がFunctionの実装です。
type Function struct {
forwardFunc func(values []*Variable, args []interface{})
[]*Variable
backwardFunc func(values []*Variable, gradYs []*Variable, args
[]interface{}) ([]*Variable, error)
stringFunc func(values []*Variable, args []interface{}) string
Inputs []*Variable
Outputs []*Variable
generation int
args []interface{}
}
func CreateFunction(
forwardFunc func(values []*Variable, args []interface{})
[]*Variable,
backwardFunc func(values []*Variable, gradYs []*Variable, args
[]interface{}) ([]*Variable, error),
stringFunc func(values []*Variable, args []interface{}) string,
args ...interface{},
) *Function {
return &Function{
forwardFunc: forwardFunc,
backwardFunc: backwardFunc,
stringFunc: stringFunc,
args: args,
}
}
このようにCreateFunctionというメソ ッドを作っておき、引数としてforwardとbackwardのメソッドのポインタを受け取るようにしています。このCreateFunctionは各Functionの生成ごとに呼ばれます。
例えば、AddのFunctionであれば以下のように実装されています。
func Add(variable1 *model.Variable, variable2 *model.Variable)
*model.Variable {
return model.CreateFunction(forwardAdd, backwardAdd,
toStringAdd).ForwardOne(variable1, variable2)
}
func forwardAdd(values []*model.Variable, _ []interface{})
[]*model.Variable {
value1 := values[0].Data
value2 := values[1].Data
value1, value2, err := util.BroadcastDenses(value1, value2)
if err != nil {
panic(err)
}
var result mat.Dense
result.Add(value1, value2)
return []*model.Variable{{Data: &result}}
}
func backwardAdd(values []*model.Variable, gradYs []*model.Variable,
_ []interface{}) ([]*model.Variable, error) {
rows1, cols1 := values[0].Data.Dims()
rows2, cols2 := values[1].Data.Dims()
var result1, result2 *model.Variable
result1 = gradYs[0]
result2 = gradYs[0]
if rows1 == rows2 && cols1 == cols2 {
return []*model.Variable{result1, result2}, nil
}
rRows1, rCols1 := result1.Data.Dims()
if rows1 == 1 && cols1 == 1 {
result1 = Sum(result1, 0)
} else if cols1 == 1 && rows1 == rRows1 {
result1 = Sum(result1, 1)
} else if rows1 == 1 && cols1 == rCols1 {
result1 = Sum(result1, 2)
}
rRows2, rCols2 := result2.Data.Dims()
if rows2 == 2 && cols2 == 2 {
result2 = Sum(result2, 0)
} else if cols2 == 2 && rows2 == rRows2 {
result2 = Sum(result2, 1)
} else if rows2 == 2 && cols2 == rCols2 {
result2 = Sum(result2, 2)
}
return []*model.Variable{result1, result2}, nil
}
func toStringAdd(values []*model.Variable, _ []interface{}) string {
stringValues := make([]string, 0, len(values))
for _, value := range values {
stringValues = append(stringValues, value.ToString())
}
return strings.Join(stringValues, " + ")
}
このように実装することで、継承のようなことができます。
2.オーバーロード
DeZeroでは、計算の表記を簡潔にするために演算子のオーバーロードを使用しています。また、引数に型もないので、一つのメソッドで、スカラ値でも行列でも受け取れるようになっています。一方Goでは演算子・メソッドのオーバーロードはありません。これは、オーバーロードは便利であるが、実際には混乱しやすく壊れやすいという考えに基づいています。オーバーロードに関しては同じように実装するのは諦めました。引数に関しては、interface{}を使うことも考えましたが各所に型の判定を入れることも避けたかったので、代わりに、1×1を簡単に作ることができるメソッドを用意することで対応しました。
func CreateScalarVariable(value float64) *Variable {
return &Variable{
Data: mat.NewDense(1, 1, []float64{value}),
}
}
おわりに
このブログでは、DeZeroフレームワークをGoで実装し、その結果とGoで組むにあた って工夫した点について書きました。このように、そのままとはいきませんが、Goでも十分DeZeroフレームワークを作ることはできそうです。別の言語で作ってみることで、ただ内容を追うよりも理解できる気がするので、ぜひ皆さんもやってみてください。