2014-05-08

Bandit Algorithms for Website Optimization, Chapter 5

After Epsilon-Greedy comes Softmax. The arm implementation and the test-harness code are the same ones used before.

In [22]:
%run './shared_functions.ipynb'
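
For reference, this notebook only relies on two things from shared_functions.ipynb: a BernoulliArm and a test_algorithm harness. A rough reconstruction from how they are used below (a sketch only; the actual shared notebook may differ in its details):

from pylab import *  # zeros, exp, average, array, random, plt, as used in this notebook

class BernoulliArm(object):
    # An arm that pays reward 1.0 with probability p, else 0.0
    def __init__(self, p):
        self.p = p

    def draw(self):
        return 1.0 if random.random() < self.p else 0.0

def test_algorithm(algo, arms, num_sims, horizon):
    # Run num_sims independent simulations of horizon pulls each and return
    # (chosen arms, rewards, cumulative rewards) as num_sims x horizon matrices
    chosen_arms = zeros((num_sims, horizon), dtype=int)
    rewards = zeros((num_sims, horizon))
    cumulative_rewards = zeros((num_sims, horizon))
    for sim in range(num_sims):
        algo.initialize(len(arms))
        for t in range(horizon):
            arm = algo.select_arm()
            reward = arms[arm].draw()
            algo.update(arm, reward)
            chosen_arms[sim, t] = arm
            rewards[sim, t] = reward
            cumulative_rewards[sim, t] = rewards[sim, :t + 1].sum()
    return chosen_arms, rewards, cumulative_rewards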

Implementing the Softmax algorithm

Softmax uses the average reward observed so far for each arm and plays arms with higher estimated value more often: arm $i$ is chosen with probability $\exp(\hat{\mu}_i / \tau) / \sum_j \exp(\hat{\mu}_j / \tau)$, where $\hat{\mu}_i$ is the current reward estimate for arm $i$.
How strongly the estimates are exploited is controlled by the temperature parameter $\tau$ (a numeric sketch follows the list):
  • As $\tau \rightarrow \infty$, arms are selected completely at random
  • As $\tau \rightarrow 0$, selection greedily follows the past estimates
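
A small numeric sketch of how $\tau$ trades off exploration and exploitation (the estimates here are made-up values, not output of the algorithm):

from math import exp

def softmax_probs(values, tau):
    # Softmax selection probabilities for a list of value estimates
    weights = [exp(v / tau) for v in values]
    z = sum(weights)
    return [w / z for w in weights]

estimates = [0.1, 0.3]                 # hypothetical per-arm reward estimates
print(softmax_probs(estimates, 2.0))   # high tau -> near uniform: ~[0.48, 0.52]
print(softmax_probs(estimates, 0.05))  # low tau  -> near greedy:  ~[0.02, 0.98]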
In [24]:
class Softmax(object):
    def __init__(self, temperature):
        self.counts = None
        self.values = None
        self.temperature = temperature
        
    def initialize(self, n_arms):
        # Number of times each arm has been pulled
        self.counts = zeros(n_arms, dtype=int)
        # Running average reward observed for each arm
        self.values = zeros(n_arms)
    
    @staticmethod
    def categorical_draw(probs):
        # Sample an index according to the given categorical distribution
        z = random.random()
        cum_prob = 0.0
        for i in range(len(probs)):
            prob = probs[i]
            cum_prob += prob
            if cum_prob > z:
                return i
        return len(probs) - 1
    
    def select_arm(self):
        # Exponentiate each arm's estimate scaled by the temperature,
        # then normalize into a probability distribution
        z = sum([exp(v / self.temperature) for v in self.values])
        probs = [exp(v / self.temperature) / z for v in self.values]
        return self.categorical_draw(probs)
    
    def update(self, chosen_arm, reward):
        # Increment the pull count for the chosen arm
        self.counts[chosen_arm] += 1
        n = self.counts[chosen_arm]
        
        # Update the arm's running average reward (incremental mean update)
        value = self.values[chosen_arm]
        new_value = ((n-1)/float(n)) * value + (1/float(n)) * reward
        self.values[chosen_arm] = new_value
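
A quick sanity check of the class on its own (a usage sketch; it assumes zeros, exp, and random are in scope via shared_functions.ipynb, and the hand-fed rewards below are illustrative only):

algo = Softmax(temperature=0.1)
algo.initialize(2)
algo.update(0, 0.0)   # arm 0 paid nothing
algo.update(1, 1.0)   # arm 1 paid off twice
algo.update(1, 1.0)
print(algo.values)    # [ 0.  1.]
print([algo.select_arm() for _ in range(10)])  # almost always arm 1 at this temperature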
In [23]:
# Helpers for plotting the simulation results
def plot_results(simulate_num, horizon, best_arm, results):
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    plot1, plot2, plot3 = axes
    
    x = range(horizon)
    
    for result in results:
        accuracy = zeros(horizon)
        reward_ave = zeros(horizon)
        cumulative_rewards = zeros(horizon)

        param, chosen_arms_m, rewards_m, cumulative_rewards_m = result

        for i in xrange(horizon):
            best_arm_selected_count = len(filter(lambda choice: choice == best_arm, chosen_arms_m[:,i]))
            accuracy[i] = best_arm_selected_count / float(simulate_num)
            reward_ave[i] = average(rewards_m[:,i])
            cumulative_rewards[i] = average(cumulative_rewards_m[:,i])
            
        plot1.plot(x, accuracy, label='%10.2f' % param)
        plot2.plot(x, reward_ave, label='%10.2f' % param)
        plot3.plot(x, cumulative_rewards, label='%10.2f' % param)

    plot1.legend(loc=4)
    plot1.set_xlabel('Time')
    plot1.set_ylabel('Probability of Selecting Best Arm')
    plot1.set_title('Accuracy of the \nSoftmax Algorithm')
    
    plot2.legend(loc=4)
    plot2.set_xlabel('Time')
    plot2.set_ylabel('Average Reward')
    plot2.set_title('Performance of the \nSoftmax Algorithm')
    
    plot3.legend(loc=4)
    plot3.set_xlabel('Time')
    plot3.set_ylabel('Cumulative Reward of Chosen Arm')
    plot3.set_title('Cumulative Reward of the \nSoftmax Algorithm')

Running the simulation

Softmax concentrates on the high-reward arm faster than Epsilon-Greedy does.
In [30]:
SIMULATE_NUM = 5000
HORIZON = 250

means = [0.1, 0.1, 0.1, 0.1, 0.9]
random.shuffle(means)
arms = map(lambda mu: BernoulliArm(mu), means)
best_arm = array(means).argmax()

test_results = []
for temperature in [0.1, 0.2, 0.4, 0.8, 1.6]:
    algo = Softmax(temperature)
    chosen_arms_mat, rewards_mat, cumulative_rewards_mat = test_algorithm(algo, arms, SIMULATE_NUM, HORIZON)
    test_results.append([temperature, chosen_arms_mat, rewards_mat, cumulative_rewards_mat])
plot_results(SIMULATE_NUM, HORIZON, best_arm, test_results)

When the arms' rewards differ only slightly

Because the differences in estimated value are small, the arms end up being selected at similar frequencies, so the probability of choosing the best arm climbs only slowly.
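
To see why, plug the true means into the softmax formula as if the algorithm had learned them exactly; this is a back-of-the-envelope computation, not simulation output:

from math import exp

# Best-arm selection probability at tau = 0.1, given perfect estimates
for means in ([0.1, 0.1, 0.1, 0.1, 0.9], [0.2, 0.2, 0.2, 0.2, 0.3]):
    weights = [exp(mu / 0.1) for mu in means]
    print(max(weights) / sum(weights))
# -> ~0.999 when the best arm is 0.9 vs 0.1, but only ~0.40 for 0.3 vs 0.2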
In [26]:
SIMULATE_NUM = 5000
HORIZON = 250

means = [0.2, 0.2, 0.2, 0.2, 0.3]
random.shuffle(means)
arms = map(lambda mu: BernoulliArm(mu), means)
best_arm = array(means).argmax()

test_results = []
for temperature in [0.1, 0.2, 0.4, 0.8, 1.6]:
    algo = Softmax(temperature)
    chosen_arms_mat, rewards_mat, cumulative_rewards_mat = test_algorithm(algo, arms, SIMULATE_NUM, HORIZON)
    test_results.append([temperature, chosen_arms_mat, rewards_mat, cumulative_rewards_mat])
plot_results(SIMULATE_NUM, HORIZON, best_arm, test_results)
