[競プロ][Python] LRUキャッシュを実装する

2024.05.20

はじめに

LRUCacheとは最近使用されていないデータを優先的に置き換えるキャッシュアルゴリズムです。

すごく素朴にキャッシュを実装すると、HashMapを使って

class SimpleCache:
    def __init__(self):
        self._map = {}

    def get(self, key: str) -> str:
        return self._map[key]

    def put(self, key: str, value: str) -> None:
        self._map[key] = value

# 使い方
cache = SimpleCache()
cache.put("apple", "リンゴ")
cache.put("banana", "バナナ")

のようにキャッシュを実装できるかと思います。

ただ、現実的には使えるメモリには限りがあるため、無制限にkeyとvalueを追加していくといつかその上限を超過してしまいます。

この点、LRUCacheを採用すると格納データが一定量を超えた時にLRU(Least Recently Used=最も最近使われていない)なデータを都度追い出してくれるため、限られたメモリにおいてキャッシュ機構を実現することが可能になります。

この仕組みを手を動かして理解するのに、LeetCodeの

LRU Cache - LeetCode

がとても良い題材だと思ったので解説的に紹介してみたいと思います。

ちなみに自分は海外エンジニアYoutuber NeetCode氏の

LRU Cache - Twitch Interview Question - Leetcode 146 - YouTube

を見て本問を学習しました。動画でわかりやすいので、英語に抵抗がなければとてもおすすめです。

問題

詳しくは元の問題を見て頂ければと思いますが、以下の雛形コード(Pythonを選択しています)が与えられるので、実装を埋めてLRU Cacheを実現してくださいという問題です。

class LRUCache:
    def __init__(self, capacity: int):

    def get(self, key: int) -> int:

    def put(self, key: int, value: int) -> None:

その他以下を考慮する必要があります。

  • capacity(格納できるデータの個数)を考慮すること
  • get() put() は平均してO(1)の操作とすること

解法

LinkedListのNodeクラスを定義する

データ構造としてLikedListとHashMapを用います。HashMapはPythonのdictを使用すれば良いとして、LinkedListはNodeという独自のクラスを定義して実現することにします。

class Node:
    def __init__(self, key: int, value: int):
        self.key = key
        self.value = value
        self.prev_node = None
        self.next_node = None

prev_node (前のNode)と next_node (後ろのNode)を持たせて、双方向連結のLinkedListとします。

LinkedListを用いることで、挿入については挿入位置の前後のNodeへの参照を持っていればO(1)で、削除については対象Nodeの参照を持っていればO(1)でそれぞれ操作を行うことができます。

LRUCacheクラスの初期化処理を実装する

問題より与えられたLRUCacheクラスの初期化処理を実装します。

class LRUCache:
    def __init__(self, capacity: int):
        self._capacity = capacity
        self._map = {}
        self._head_node = Node(0, 0)
        self._tail_node = Node(0, 0)
        self._head_node.next = self._tail_node
        self._tail_node.prev = self._head_node

まずコーディング作法的なところで、すべてのインスタンス変数についてクラスの外から利用する値ではない想定のため _ のprefixをつけてprivateであることを示すようにします。

_capacity は格納するキャッシュの最大件数、 _map はすべてのNodeへの参照を持ちます。

_head_node_tail_node は一番最初と一番最後のNodeを追加・取得しやすいように設置するダミーのNodeです。keyとvalueはなんでも良いため、それぞれ 0 にしておきます。

_head_node_tail_node を双方向に繋いでおき、この2つの間にNodeを追加していきます。

_removeNode()メソッドを定義する

下記のような _removeNode() privateメソッドを定義しておきます。

    # 渡されたNodeを削除する
    def _remove_node(self, node: Node) -> None:
        prev_node = node.prev_node
        next_node = node.next_node

        # 前後のNodeから対象Nodeへの参照を外す
        prev_node.next_node = next_node
        next_node.prev_node = prev_node

        # _mapから対象Nodeを削除する
        del self._map[node.key]

引数で渡されたNodeを削除する操作です。

まず対象の前後のNodeを取得します。

そして、その前後Node同士を繋ぎます。

_map からも対象Nodeを削除します。

こうすることで対象Nodeへの参照がなくなり、ガベージコレクションの対象となる(はずの)ため対象Nodeが削除される、という操作をO(1)で行えます。

_addLatestNode()メソッドを定義する

続いて下記のような _addLatestNode() privateメソッドを定義しておきます。

    # 渡されたNodeを最新として追加する
    def _add_latest_node(self, node: Node) -> None:
        prev_node = self._tail_node.prev_node
        next_node = self._tail_node

        # 前のNodeと対象Nodeを繋げる
        prev_node.next_node = node
        node.prev_node = prev_node

        # 後ろのNodeと対象Nodeを繋げる
        node.next_node = next_node
        next_node.prev_node = node

        # _mapへ対象Nodeを追加する
        self._map[node.key] = node

引数で渡されたNodeを末尾(正確には _tail_node の1つ前)に追加する操作です。

まず _tail_node_tail_node の1つ前のNodeを取得します。

そして、 _tail_node と対象Nodeを繋ぎ、_tail_node の1つ前にあったNodeと対象Nodeも繋ぎます。

2つのNodeの間に入れ込むようなイメージですね。

_map へのNodeの追加も忘れずに行います。

これで対象Nodeを末尾へ追加する操作をO(1)で行うことができます。

get()メソッドを実装する

では問題で実装を要求されている get() メソッドを実装していきます。

    # keyに対応したvalueを取得する
    # 取得したNodeを最後尾に移動させる
    def get(self, key: int) -> int:
        # 対象keyがない場合は-1を返す
        if key not in self._map:
            return -1

        # 対象keyに紐づくNodeをLinkedListの最後尾に移動させる
        node = self._map[key]
        self._remove_node(node)
        self._add_latest_node(node)

        # valueを返す
        return node.value

渡されたkeyに対応したNodeを _map から探し、存在しなければ -1 を返して終了します。

存在する場合、一旦NodeをLinkedListから取り除いてしまってから最後尾に追加します。こうすることで参照されたNodeを最新として更新できます。

put()メソッドを実装する

続けて put() メソッドも実装していきます。

    # 新しいkeyとvalueを挿入する
    # すでにkeyが存在する場合は更新する
    def put(self, key: int, value: int) -> None:
        # すでに対象keyが存在する場合、一旦削除する
        if key in self._map:
            self._remove_node(self._map[key])

        # 新しくNodeを作成し、LinkedListの最後尾に追加する
        new_node = Node(key, value)
        self._add_latest_node(new_node)

        # Nodeの総数が_capacityを超えた場合、先頭(つまり最も最近参照されていない)Nodeを削除する
        current_capacity = len(self._map)
        if current_capacity > self._capacity:
            oldest_node = self._head_node.next_node
            self._remove_node(oldest_node)

すでに渡されたkeyに紐づくNodeが存在する場合、削除してから再び最新Nodeとして加えることでvalueとLinkedListの順序を更新します。存在しなかった場合は、単に最新Nodeとして追加します。

ここで、現在のNode総数がインスタンス初期化時に設定した _capacity を上回ってしまう場合、許容量オーバーということなのでいずれかのNodeを削除する必要があります。本問はLRUキャッシュの実装なので、LRU(最も最近参照されていない、言い換えると順序が最も古い)であるLinkedList先頭のNodeを削除することとなります。

コード全体

以上のコード全体が下記になります。

class Node:
    def __init__(self, key: int, value: int):
        self.key = key
        self.value = value
        self.prev_node = None
        self.next_node = None

class LRUCache:
    def __init__(self, capacity: int):
        self._capacity = capacity
        self._map = {}
        self._head_node = Node(0, 0)
        self._tail_node = Node(0, 0)
        self._head_node.next_node = self._tail_node
        self._tail_node.prev_node = self._head_node

    # 渡されたNodeを削除する
    def _remove_node(self, node: Node) -> None:
        prev_node = node.prev_node
        next_node = node.next_node

        # 前後のNodeから対象Nodeへの参照を外す
        prev_node.next_node = next_node
        next_node.prev_node = prev_node

        # _mapから対象Nodeを削除する
        del self._map[node.key]

    # 渡されたNodeを最新として追加する
    def _add_latest_node(self, node: Node) -> None:
        prev_node = self._tail_node.prev_node
        next_node = self._tail_node

        # 前のNodeと対象Nodeを繋げる
        prev_node.next_node = node
        node.prev_node = prev_node

        # 後ろのNodeと対象Nodeを繋げる
        node.next_node = next_node
        next_node.prev_node = node

        # _mapへ対象Nodeを追加する
        self._map[node.key] = node

    # keyに対応したvalueを取得する
    # 取得したNodeを最後尾に移動させる
    def get(self, key: int) -> int:
        # 対象keyがない場合は-1を返す
        if key not in self._map:
            return -1

        # 対象keyに紐づくNodeをLinkedListの最後尾に移動させる
        node = self._map[key]
        self._remove_node(node)
        self._add_latest_node(node)

        # valueを返す
        return node.value

    # 新しいkeyとvalueを挿入する
    # すでにkeyが存在する場合は更新する
    def put(self, key: int, value: int) -> None:
        # すでに対象keyが存在する場合、一旦削除する
        if key in self._map:
            self._remove_node(self._map[key])

        # 新しくNodeを作成し、LinkedListの最後尾に追加する
        new_node = Node(key, value)
        self._add_latest_node(new_node)

        # Nodeの総数が_capacityを超えた場合、先頭(つまり最も最近参照されていない)Nodeを削除する
        current_capacity = len(self._map)
        if current_capacity > self._capacity:
            oldest_node = self._head_node.next_node
            self._remove_node(oldest_node)

おわりに

以上、LRUキャッシュについて、LeetCode問題を題材に実際にコードを書いて学んでみるという内容でした。

次はLFU(Least Frequently Used)キャッシュの構造も下記問題などに取り組んで学習したいと思います。

LFU Cache - LeetCode

以上、参考になれば幸いです。

参考