· 

Spiking Neural UnitをChainerで実装してみた

(著)山拓

Spiking Neural Unit (SNU)をChainer-v5で実装しました。

Paper: https://arxiv.org/abs/1812.07040

 

コードはhttps://github.com/takyamamoto/SNU_Chainerにあります。

 

SNUは通常のNeural NetworkのフレームワークでSpiking Neural Networkが実装できる(ついでにBack-propagationを用いて学習もできる)という、面白いユニットです。

 

ユニットをpaper通り実装するのはすぐできたのですが、学習が全く進まなかったので、数カ月放置しました。もう一度チャレンジし、色々いじくってたら学習が進んだので結果をまとめました。

 

Spiking Neural Unitについて

SNUはstateを持つfeed-forwardのユニットです。他にstateを持つユニットとしてはLSTMやGRUなどがあります。

(Wozniak, et al., 2018; Fig.1)

 

構造はLeaky-integrate and fire(LIF)ニューロンの式を離散化し、 スパイクが生じる($y=1$)と、膜電位を想定した$s$が0にリセットされるというものです。式としては \begin{align*} s_t&=g\left(Wx_t+l(\tau)\odot s_{t-1}\odot (1-y_{t-1})\right)\\ y_t&=h(s_t + b) \end{align*} と2つでかけます。ここで$g(\cdot)$はReLU, $h(\cdot)$はStep関数です。$l(\tau)=(1-\frac{\Delta t}{\tau})$は1step前の膜電位をどれだけ引き継ぐかの割合です。学習可能なパラメータとして実装も可能ですが、paperでは定数が用いられています。なお、実装では$\Delta t=1 \text{ms}, \tau=5 \text{ms}$としました。

上の式にはバイアスがありませんが、これは静止膜電位を0mVと設定しているためです。下の式はバイアスを足して(実際には閾値との差を計算して)、0を超えればスパイクが生じます($y=1$)。

学習させずに挙動を試してみると下図のようになります。上から入力(ポアソンスパイク)、SNUの膜電位$s_t$、SNUの出力$y_t$です。なお、ここのパラメータは学習させるものと異なり、Paper中のBrianとの比較における設定と同じです。

コード:check_snu_layer.py

LIFニューロンと同じ挙動をしていることが確認できます。

 

ここで問題となるのは、Step関数があるということです。Step関数は微分するとDiracのデルタ関数となり、誤差逆伝搬できません。そこでpaperでは疑似勾配(pseudo-derivative)としてtanhの微分を用いたそうです。ただし、学習できなかったので、これに手を加えました(後述)。

 

また、Step関数の代わりにSigmoidを用いたsoft-SNUも提案されていました。

 

実装上の変更点

toy-problemとして、後述するJittered MNIST(スパイク列に変換したMNIST)の分類を学習させました。ただし、うまくいかなかったので、以下の4点を変更しました。

変更点1

ReLUだとdying ReLUが起こっているようで学習がうまく進みませんでした。そこで、活性化関数としてExponential Linear Unit (ELU)を代わりに用いました。この変更は発火特性に影響を与えません。

 

変更点2

Step関数の疑似勾配(pseudo-derivative)について。tanhの微分では学習が進まなかったので、 $$ f'(x) = \begin{cases} 1 & (-0.5<x<0.5) \\ 0 & (\text{otherwise}) \end{cases} $$ としました。気持ちとしてはhard sigmoidのような関数の微分です。この考えは自分で思いついたものでないはずですが、どこで見たのか思い出せません…。

変更点3

Loss functionについて。MSEだと学習が進まなかったので、出力ユニットの全スパイク数を加算し、softmaxをかけて、labelとのcross entropyを取りました。

 

また出力ユニットの発火数を抑えるため、代謝コスト(metabolic cost)を損失に加えました。これはANNを脳のモデルとして捉える研究ではよく用いられるコストです。正則化の効果もあります。出力層の $i$ 番目のユニットの出力を $y_t^{(i)}$とすると、代謝コスト $C_{\text{met}}$は $$ C_{\text{met}}=\frac{10^{-2}}{N_t \cdot N_{\text{out}}}\sum_{t=1}^{N_t}\sum_{i=1}^{N_{\text{out}}} \left(y_t^{(i)}\right)^2 $$ となります。ただし、$N_t$はシミュレーションの総タイムステップ数、$N_{\text{out}}$は出力ユニットの数(今回だと10個)です。あまり大きくすると、分類誤差よりも代謝コストの方が大きくなってしまうので低めに設定します。

変更点4

optimizerについて。paperにあるようにSGD(lr=0.2)では学習が進まなかったので、Adam(alpha=1e-4)としました。

 

Jittered MNIST

学習はMNISTですが、スパイク列として入力する必要があります。手法が詳しく記されていなかったので、勝手に作りました。

 

まず、MNISTの画像を2値化し、1の画素値を持つユニットに100Hzのポアソンスパイクを入力しました。下の画像は左から「2値化した4」,「1 time stepに入力されるスパイク」、「入力される全てのスパイクを加算したもの」です。 

コード:check_jittered_mnist.py

 

入力がかなりnoisyなので、入力の周波数を大きくすればAccuracyも上がると思います。

 

結果

10 simulation time step (10ms) のシミュレーションで、100 epoch学習させました。

LossとAccuracy

左からLoss, Accuracyの変化です。train lossとaccuracyがとても変動しているので、バッチサイズを増やせば改善されるかもしれません。MNISTにしてはAccuracy低いですが、ずっとchance accuracy(~0.1)ぐらいをさまよっていたので、上手く学習できたといえます。なお、Paperとは入出力の条件が異なるので比較はできません。

コード:train.py

 

出力の挙動

ネットワーク内のユニットの挙動を見てみましょう。とはいえ、中間層は見てもよく分からないので、出力層だけ見てみます。学習は10msだけでしたが、外挿して100msのシミュレーションをテストします。下図の左は入力のスパイクの合計を正規化して画像化したものです。7と読めると思います。右はこのときの出力層のユニットのスパイクです。7番目のユニットがよく発火していることが分かります。

 

コード:analysis.py

 

なお、代謝コストが無い場合は、 頻繁に発火が見られ、biologicalニューロンの挙動とかけ離れたものになってしまいました。

 

このように、ちょっとした工夫でSpiking NNを普通のANNのライブラリで実装することができました。これを応用してSpiking ConvolutionもChainerやPytorch等で実装できるかもしれませんね!(誰かPaper書きそうですが)

  

Firing rate NN  vs Spiking NN

最後に普通のANNとSpiking NNの違い(特にRNNの学習)に触れておきます。

 

今のANNはニューロンの発火率モデル(Firing rate model)と捉えられます。要は、ANNの出力は発火率の期待値ということです。これに対してスパイクを出力するNNがSpiking NNです。今回のSNUはrate modelの実装に用いるフレームワークでspiking modelを実装したということが面白い点です。

 

実際のニューロンはスパイクしているので、spiking NNの方が生理学的にもっともらしいです。ただし、rate modelでもニューロンの挙動は(場合によりますが)近似できます。

 

また、Spiking NNの中でもrecurrent connectionを持つSpiking RNNは学習が困難ですが、今回用いたSNUではback-propagation through time (BPTT)を用いて学習可能です(feed-forwardだけでなくrecurrnet connectionも付けられるため)。他にはrate modelをspiking modelに変換することでSpiking RNNを実装する研究がいくつかあります。

(例えば一番直近だとhttps://www.biorxiv.org/content/10.1101/579706v1

 

Spiking modelに限らず、RNNの学習が難しいのは、Temporal credit assignment (TCA; 時間的信用割当)が必要となる点です。TCAの手法として、現時点で生理学的に妥当な学習方法は提案されていません。BPTTはTCAを可能にしますが、発火を全て記憶しておくのは生体内では困難です。ただ、脳内でやっていないことはないと考えられるので、今後の研究で分かればいいなと思います。

(参考:T.P. Lillicrap, A.Santoro. "Backpropagation through time and the brain". Curr. Opin. Neurobiol. (2019). (sciencedirect))