[Rust] バンディッドアルゴリズム(Epsilon-greedy)の実装

[Rust] バンディッドアルゴリズム(Epsilon-greedy)の実装

Clock Icon2021.07.16

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

Intoroduction

例えば、「複数のスロットマシンがあり、それぞれ期待値が異なるがその値はわからない場合、
どのスロットマシンを選ぶのが一番よいのか」

この問題を解決する手法の1つが、今回紹介するバンディッドアルゴリズムです。
本稿ではバンディッドアルゴリズムの簡単な説明と、
そのアルゴリズムを使用したコードをRustで実装してみます。
  

Bandit Algorithm?

バンディッドアルゴリズムは、システムが自分でいろいろと試行錯誤しながら
最適な結果を実現する強化学習(Reinforcement Learning)の中で、
代表的な手法のひとつです。

Multi-Armed Bandit Problem(多腕バンディット問題)とよばれる問題を解くためのアルゴリズムで、
これは先程もいったように、報酬の確率分布が未知の複数台のスロットマシンを繰り返しプレイするとき、
どういった方策をとればベストなのかを考える問題です。

このアルゴリズムを使用すると、A/Bテストと同じくデータに基づいた意思決定が可能です。

terminology

バンディッドアルゴリズムでよく使う用語は以下です。

  • アーム
    任意の時点で選択可能な選択肢。
    A/Bテストをバンディッドアルゴリズムの対象にするならアームは2つ。

  • 報酬(リターン)
    アームを選択したことにより得られる価値。
    それぞれのアームが持つ確率密度分布に従って報酬が得られると仮定する。

  • 方策
    事前に決めたアルゴリズムに従ってアームを選択する手法。
    ※Policyというときもある
    この方策の設計如何でバンディットアルゴリズムの性能がきまる。

  • 試行
    あるアームを選択し、報酬を得る行動。

Implement the Epsilon-greedy algorithm

バンディッドアルゴリズムでもいくつか種類があります。
今回はバンディッドアルゴリズムの中でも最もシンプルなEpsilon-greedyアルゴリズムを実装します。
(他にSoftmax AlgorithmやUpper Confidence Boundといったアルゴリズムもあります)

Epsilon-greedy spec

Epsilon-greedyはかなり単純です。

  • 設定した任意の割合(0 < ε < 1)でランダムに選択肢からアームを選択
  • 1-εの割合で、「いままでの結果が最も良い選択肢」を選択

例えばε=0.3にしたら、

30%の確率でアームをランダムに選択、70%の確率で、その時点で最も結果が良いアームを選択。
そして、選択したアームで試行します。
試行の結果、報酬が得られたらそれを積み上げます。

とてもシンプルです。

Environment

  • OS : MacOS 10.15.7
  • Rust : 1.52.1

Make Program with Rust

ではRustで実装してみます。
cargoをつかってプロジェクトの作成。

% cargo new bandit-sample

Cargo.tomlでライブラリの追加をします。
二項分布、ランダム値を使うのでstatrsとrandのcrateを追加します。

[dependencies]
statrs = "0.15.0"
rand = "0.8.0"

そしてmain.rsで実装します。
※今回は面倒なので1つのファイルにすべて記述

//! bandit algorithm sample project.

use rand::prelude::*;
use rand::rngs::StdRng;
use rand::{thread_rng, Rng};
use statrs::distribution::Binomial;

/// Arm struct
///
/// # Parameters
/// * `p` - probability
/// * `success` - Number of successes
/// * `fail` - Number of failures
/// * `binomial` - Binomial
struct Arm {
    _p: f64,
    success: i32,
    fail: i32,
    binomial: Binomial,
}

impl Arm {
    /// create new Arm.
    ///
    /// # Parameters
    /// * `p` - probability
    pub fn new(p: f64) -> Self {
        Arm {
            _p: p,
            success: 0,
            fail: 0,
            binomial: Binomial::new(p, 1).unwrap(),
        }
    }

    /// trial Arm.
    pub fn play(&mut self) -> i32 {
        let mut rng = rand::rngs::StdRng::from_entropy();
        let result = self.binomial.sample::<StdRng>(&mut rng) as i32;
        if result == 1 {
            self.success += 1;
        } else {
            self.fail += 1;
        }
        result
    }

    /// calc success
    pub fn calc_success(&self) -> f64 {
        if self.success + self.fail == 0 {
            0.0
        } else {
            (self.success as f64) / (self.success + self.fail) as f64
        }
    }
}

/// ramdom
///
/// # Parameters
/// * `arms` - Vector of Arm
/// * `count` - count of get reward
fn ramdom_greedy(arms: &mut Vec<Arm>, count: i32) -> i32 {
    let mut reward = 0;

    for _i in 0..count {
        let mut rng = thread_rng();
        let index = rng.gen_range(0..arms.len());
        reward += arms[index].play();
    }
    reward
}

/// search & utilization
///
/// # Parameters
/// * `arms` - Vector of Arm
/// * `count` - count of get reward
/// * `epsilon` - search rate parameter
fn epsilon_greedy(arms: &mut Vec<Arm>, count: i32, epsilon: f64) -> i32 {
    let mut reward = 0;

    for _i in 0..count {
        //epsiron(〜1.0)の確率で表がでるコインをふる
        let b = Binomial::new(epsilon, 1).unwrap();
        let mut rng = rand::rngs::StdRng::from_entropy();
        let index: usize;

        if b.sample(&mut rng) as i32 == 1 {
            // search: select random arm
            let mut rng = thread_rng();
            index = rng.gen_range(0..arms.len());
        } else {
            // utilization : Choose the arm with the highest probability of success ever
            index = select_highest_arm_index(&arms);
        }
        reward += arms[index].play();
    }
    reward
}

/// select highest Arm Index
///
/// # Parameters
/// * `arms` - Vector of Arm
fn select_highest_arm_index(arms: &Vec<Arm>) -> usize {
    if arms.len() == 1 {
        0
    } else {
        let mut highest_arm_index: usize = 0;
        for i in 0..arms.len() - 1 {
            if arms[i].calc_success() < arms[i + 1].calc_success() {
                highest_arm_index = i + 1;
            }
        }
        highest_arm_index
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_play() {
        let mut test_arm = Arm::new(0.5);
        let result = test_arm.play();
        assert!(result == 0 || result == 1);
        assert_ne!(test_arm.success, test_arm.fail);
    }
    #[test]
    fn test_calc_success() {
        let mut test_arm = Arm::new(0.5);
        test_arm.fail = 20;
        test_arm.success = 5;
        assert_eq!(test_arm.calc_success(), 0.2);
    }

    #[test]
    fn test_select_highest_arm_index() {
        let mut test_arm1 = Arm::new(0.1);
        test_arm1.fail = 9;
        test_arm1.success = 1;
        let mut test_arm2 = Arm::new(0.1);
        test_arm2.fail = 8;
        test_arm2.success = 2;
        let mut test_arm3 = Arm::new(0.1);
        test_arm3.fail = 7;
        test_arm3.success = 3;
        let mut test_arm4 = Arm::new(0.1);
        test_arm4.fail = 6;
        test_arm4.success = 4;

        let mut v = vec![test_arm1];
        assert_eq!(select_highest_arm_index(&v), 0);
        v.push(test_arm2);
        assert_eq!(select_highest_arm_index(&v), 1);
        v.push(test_arm3);
        assert_eq!(select_highest_arm_index(&v), 2);
        v.push(test_arm4);
        assert_eq!(select_highest_arm_index(&v), 3);
    }
}

/// main関数
fn main() {
    let count = 10000;
    let arm1 = Arm::new(0.5);
    let arm2 = Arm::new(0.2);
    let arm3 = Arm::new(0.7);
    let mut v = vec![arm1, arm2, arm3];

    println!("Ramdom : {}", ramdom_greedy(&mut v, count));

    let epsilon = 0.4;
    println!("ϵ-greedy : {}", epsilon_greedy(&mut v, count, epsilon));
}

main関数では完全ランダムで10000回試行した結果と
バンディッドアルゴリズム(ε-greedy)で試行した結果を表示します。

% cargo run
Finished dev [unoptimized + debuginfo] target(s) in 0.16s
Running `target/debug/bandit-sample`

Ramdom : 4670
ϵ-greedy : 5994

ϵ-greedyのほうが報酬が高いことがわかります。

Summary

今回はRustでバンディットアルゴリズム(ε-greedy)を実装してみました。
バンディットアルゴリズムには、これ以外にもいろいろな種類があるので、興味のあるかたは
このあたりこのあたりを参考にしてみてください。

References

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.