シナプスのシミュレーション 2
引き続きBrain2というライブラリを試してみる.
今回は前回のコードをさらに拡張したものになっている.
具体的には 出力ニューロン増加やヒートマップ表示等々を行なっている.
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
# シミュレーション設定
start_scope()
prefs.codegen.target = 'numpy'
# 定数
tau = 20*ms # 時定数(変更可能)
Vt = -50*mV
Vr = -60*mV
El = -49*mV
# ニューロンモデル(LIF)
eqs = '''
dv/dt = (El - v) / tau : volt (unless refractory)
'
# ニューロングループ(入力2、出力3)
G1 = NeuronGroup(2, eqs, threshold='v > Vt', reset='v = Vr', refractory=5*ms, method='exact')
G2 = NeuronGroup(3, eqs, threshold='v > Vt', reset='v = Vr', refractory=5*ms, method='exact')
G1.v = Vr
G2.v = Vr
# スパイク入力
input_times = [10, 50, 90, 130, 170, 210]*ms
indices = [0, 1, 0, 1, 0, 1]
P = SpikeGeneratorGroup(2, indices, input_times)
Sinput = Synapses(P, G1, on_pre='v_post += 2*mV')
Sinput.connect(j='i')
# **STDPモデル**
stdp_model = '''
w : 1
dApre/dt = -Apre / (tau) : 1 (event-driven)
dApost/dt = -Apost / (tau) : 1 (event-driven)
'
# **通常STDP**
stdp_on_pre = '''
v_post += w * mV
Apre += 0.01
w = clip(w + Apost, 0, 1)
'
stdp_on_post = '''
Apost += -0.012
w = clip(w + Apre, 0, 1)
'
# **Hebbian学習(発火したら強化)**
hebb_on_pre = '''
v_post += w * mV
w = clip(w + 0.005, 0, 1)
'
# **anti-STDP(通常STDPと逆)**
anti_on_pre = '''
v_post += w * mV
Apre += 0.01
w = clip(w - Apost, 0, 1)
'
anti_on_post = '''
Apost += -0.012
w = clip(w - Apre, 0, 1)
'
# **シナプス結合**
S = Synapses(G1, G2, model=stdp_model, on_pre=stdp_on_pre, on_post=stdp_on_post)
S.connect()
S.w = '0.5 + 0.1*rand()' # 初期値ランダム
# **モニタリング**
mon_v = StateMonitor(G2, 'v', record=True)
mon_w = StateMonitor(S, 'w', record=True)
spike_mon = SpikeMonitor(G2)
# **シミュレーション**
run(250*ms)
# **可視化**
fig, axes = plt.subplots(1, 2, figsize=(12,5))
# 出力ニューロンの膜電位
for i in range(3):
axes[0].plot(mon_v.t/ms, mon_v.v[i]/mV, label=f'Output Neuron {i}')
axes[0].set_xlabel('Time (ms)')
axes[0].set_ylabel('Voltage (mV)')
axes[0].set_title('Output Neuron Membrane Potential')
axes[0].legend()
# シナプス重みの時間変化
for i in range(6): # 2入力 x 3出力 = 6接続
axes[1].plot(mon_w.t/ms, mon_w.w[i], label=f'Synapse {i}')
axes[1].set_xlabel('Time (ms)')
axes[1].set_ylabel('Synaptic Weight')
axes[1].set_title('STDP Synaptic Weight Changes')
axes[1].legend()
plt.show()
# **学習前後のシナプス重みヒートマップ**
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
# 初期値
initial_weights = S.w[:].reshape(2, 3)
ax[0].imshow(initial_weights, cmap='viridis', aspect='auto')
ax[0].set_title("Initial Synaptic Weights")
ax[0].set_xlabel("Output Neurons")
ax[0].set_ylabel("Input Neurons")
ax[0].set_xticks(range(3))
ax[0].set_yticks(range(2))
# 学習後
final_weights = mon_w.w[:, -1].reshape(2, 3)
ax[1].imshow(final_weights, cmap='viridis', aspect='auto')
ax[1].set_title("Final Synaptic Weights")
ax[1].set_xlabel("Output Neurons")
ax[1].set_xticks(range(3))
ax[1].set_yticks(range(2))
plt.show()

