Amazon Neptune ServerlessでグラフDBを構築し、最短経路問題を解いてみた

2023.08.04

データアナリティクス事業本部のueharaです。

今回は、Neptune ServerlessでグラフDBを構築し、最短経路問題を解いてみたいと思います。

Neptuneとは

Amazon NeptuneはフルマネージドなグラフDBサービスです。

Neptuneについては以下の記事で紹介されていますので、より詳しく知りたい方は以下をご参考ください。

Neptune Serverlessの構築

NeptuneはServerlessでも構築可能なので、今回はそちらを利用したいと思います。

AWSのマネージドコンソールからNeptuneを検索し、「Amazon Neptune を起動」を選択します。

以下のような画面が表示されますので、まずエンジンのタイプを Serverless で設定します。

バージョンは今回は Neptune 1.2.0.1.R2 を利用します。

DBクラスター識別子は任意の値になりますので、 db-test-uehara にしました。

キャパシティの設定もお好みですが、今回はお試し利用なので最小・最大NCU共に 設定可能な最も低い値 を設定しました。

テンプレートは 開発とテスト を選択肢、可用性と耐久性はマルチAZを なし にします。

接続に関しては、適当なVPCを選択頂ければ問題ありません。

サブネットグループやセキュリティグループは今回新規作成することにしました。

Neptuneにアクセスするための、Sagemakerを通じてホストされるJupyter Notebookもこの段階で作成できるので、今回はそちらを作成します。

私の手元では以下のように設定しました。

ここまで設定できたら、「データベースを作成」ボタンからDBを作成します。

以上で環境の構築は終了です。

Neptune Serverlessへの接続

先程Notebookを作成したので、AWSマネージメントコンソールからSagemakerを検索し、作成したNotebookインスタンスが作成されていることを確認します。

確認ができたら、Neptune Serverlessにアタッチしているセキュリティグループについて、Notebookインスタンスにアタッチしているセキュリティグループからのポート番号8182に対するインバウンドトラフィックのアクセスを許可しておきます。

これでNotebookインスタンスからNeptune Serverlessへの接続ができるようになります。

Notebookを開くと、デフォルトでいくつかのサンプルファイルが入ったディレクトリを確認することができます。

前提:構築するグラフ

今回はこちらのサイトで定義されている、以下のグラフを利用したいと思います。

引用元

上記グラフの定義は以下の通りです。

  • ノードはアルファベットに丸で、エッジはノードをつなぐ線で表現されている
  • エッジの近くの数字は、エッジの重み(ノード間の距離)を表す

やってみた

グラフの定義

NeptuneではGremlinやopenCypherなどを使ってクエリを実行することができます。

今回はGremlinを使いたいと思います。

先に示したグラフを構築するため、Notebookで以下を実行します。

%%gremlin

g.addV('node').property('id', 'a').as('a').
  addV('node').property('id', 'b').as('b').
  addV('node').property('id', 'c').as('c').
  addV('node').property('id', 'd').as('d').
  addV('node').property('id', 'e').as('e').
  addV('node').property('id', 'f').as('f').
  addV('node').property('id', 'g').as('g').
  addV('node').property('id', 'h').as('h').
  addE('edge').from('a').to('b').property('cost', 1).
  addE('edge').from('a').to('c').property('cost', 7).
  addE('edge').from('a').to('d').property('cost', 2).
  addE('edge').from('b').to('a').property('cost', 1).
  addE('edge').from('b').to('e').property('cost', 2).
  addE('edge').from('b').to('f').property('cost', 4).
  addE('edge').from('c').to('a').property('cost', 7).
  addE('edge').from('c').to('f').property('cost', 2).
  addE('edge').from('c').to('g').property('cost', 3).
  addE('edge').from('d').to('a').property('cost', 2).
  addE('edge').from('d').to('g').property('cost', 5).
  addE('edge').from('e').to('b').property('cost', 2).
  addE('edge').from('e').to('f').property('cost', 1).
  addE('edge').from('f').to('b').property('cost', 4).
  addE('edge').from('f').to('c').property('cost', 2).
  addE('edge').from('f').to('e').property('cost', 1).
  addE('edge').from('f').to('h').property('cost', 6).
  addE('edge').from('g').to('c').property('cost', 3).
  addE('edge').from('g').to('d').property('cost', 5).
  addE('edge').from('g').to('h').property('cost', 2).
  addE('edge').from('h').to('f').property('cost', 6).
  addE('edge').from('h').to('g').property('cost', 2).iterate()

Jupyter NotebookでGremlinによるクエリの実行は、先頭に%%gremlinを記載することで可能です。

ノードのプロパティは先に記載したグラフのようにa〜hまでのidを持たせ、エッジのプロパティにはcostとしてノード間の重みを定義しました。

以下のクエリで、作成したグラフを可視化してみます。

%%gremlin -p v,oute,inv

g.V().outE().inV().path().by('id').by('cost')

以下のように、グラフが定義できていれば成功です。

Gremlinのクエリで最短経路を確認してみる

今回はノードaからノードhまでの最短経路を求めることを要件にしたいと思います。

以下のクエリで、ノードaからノードhまでの各経路のトータルコストを計算し、昇順にならべることで最短経路を確認することができます。

%%gremlin

g.withSack(0).
  V().
  has('id','a').
  repeat(
    outE().sack(sum).by('cost').
    inV().simplePath()).
  until(has('id','h')).
  limit(5).
  order().
    by(sack()).
  local(
    union(
      path().
        by('id').
        by('cost').
        unfold(),
      sack()).
      fold())

結果は次のようになっており、a→d→g→hの順がトータルコスト9となり最短であることが分かります。

Pythonを用いたダイクストラ法で最短経路を確認してみる

グラフ上のノード間の最短経路を求めるアルゴリズムにダイクストラ法があります。

ここではダイクストラ法のアルゴリズムをPythonで実装し、再度ノードa〜ノードh間の最短経路を確認したいと思います。

PythonからNeptuneにアクセスするためにはgremlinpythonモジュールを使用します。

インストールは以下のコマンドをJupyter Notebookで実行することで可能です。

!pip install gremlinpython

インストールが完了したら、ダイクストラ法で最短経路を求めるPythonスクリプトを記載します。

全体のコードは以下の通りです。

from gremlin_python import statics
from gremlin_python.structure.graph import Graph
from gremlin_python.process.graph_traversal import __
from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection


# Neptuneへの接続を確立
graph = Graph()
connection = DriverRemoteConnection('wss://<your-neptune-endpoint>:8182/gremlin','g')
g = Graph().traversal().withRemote(connection)

# ノードとエッジを取得
nodes = g.V().valueMap(True).toList()  # 各ノードのIDとプロパティを取得
edges = g.E().toList()

# Neptuneから取得したノードとエッジをPythonのグラフ表現に変換
graph = {node['id'][0]: {} for node in nodes} 
for edge in edges:
    node_out_id = g.V(edge.outV.id).values('id').next()  # outVのidを取得
    node_in_id = g.V(edge.inV.id).values('id').next()  # inVのidを取得
    # 属性 'cost' をコストとして使用
    cost = g.E(edge.id).values('cost').next()
    # グラフデータ構造を更新
    graph[node_out_id][node_in_id] = cost

# 接続を閉じる
connection.close()


# ダイクストラ法
def dijkstra(graph, start, end):
    shortest_paths = {start: (None, 0)}
    current_node = start
    visited = set()

    while current_node != end:
        visited.add(current_node)
        destinations = graph[current_node]
        weight_to_current_node = shortest_paths[current_node][1]

        for next_node, weight in destinations.items():
            weight = weight + weight_to_current_node
            if next_node not in shortest_paths:
                shortest_paths[next_node] = (current_node, weight)
            else:
                current_shortest_weight = shortest_paths[next_node][1]
                if current_shortest_weight > weight:
                    shortest_paths[next_node] = (current_node, weight)

        next_destinations = {node: shortest_paths[node] for node in shortest_paths if node not in visited}
        if not next_destinations:
            return ("Route Not Found", None)
        current_node = min(next_destinations, key=lambda k: next_destinations[k][1])

    path = []
    while current_node is not None:
        path.append(current_node)
        next_node = shortest_paths[current_node][0]
        current_node = next_node
    path = path[::-1]
    
    return (path, shortest_paths[end][1])


# ノード'a'からノード'h'までの最短経路を求める
path, cost = dijkstra(graph, "a", "h")
print(f"最短経路: {path}")
print(f"トータルコスト: {cost}")

Neptuneから取得したノードとエッジの関係をPythonのdict形式に落とし込み、ダイクストラ法により最短経路を求めるスクリプトになっています。

your-neptune-endpointの部分はご自身のNeptune Serverlessのエンドポイントに置き換えてください。

実行結果は次の通りで、やはりa→d→g→hの順がトータルコスト9となり最短であることが分かります。

最後に

今回は、Neptune ServerlessでグラフDBを構築し、最短経路問題を解いてみました。

参考になりましたら幸いです。

参考文献