[Boto3 Adv-Cal DAY5]Athenaで集計した結果をローカルに落としてみる

boto3 で楽しむ AWS PythonLife 一人AdventCalendarです。

boto3のドキュメントを通して、サービス別にどういった事が出来るのかを理解したり、管理コンソールを通さずにTerminalだけで完結できるように検証していくことが目的になります。

5日目はAmazon Athenaで集計且つ結果をローカルにダウンロードしてみます。

boto3を通してAmazon Athenaで出来ること

ドキュメントは下記リンク先です。

ざっくりと以下のことができます。

  • 名称付きクエリの操作(取得・作成・削除)
  • クエリの実行ステータス取得
  • クエリの実行結果取得
  • クエリの実行
  • クエリの停止

実行可能なQuery文字数

最小1文字、最大262144文字です。

Length Constraints: Minimum length of 1. Maximum length of 262144.

名前付きクエリ - named_query

作成する際に名前を付けたクエリです。名前を付けておくことで再利用等が行いやすくなります。

今回の操作

Athena上にデータベースを作成して、SelectクエリをNamedQueryとして作成し、結果をS3からダウンロードしてみます。

  1. NamedQueryの作成
  2. クエリの実行
  3. S3からCSVのダウンロード

実行

ボリュームの関係で2つに分けました。

% python create_named_query.py
% python execute_query.py

create_named_query.py

import boto3
import os
import re
import argparse

class AthenaNamedQueryWizard:
    _client_name = 'athena'
    _session = None

    def __init__(self, profile_name):
        self._session = boto3.Session(profile_name=profile_name)

    @property
    def session(self):
        return self._session

    def get_client(self, client_name=None):
        if not client_name:
            client_name = self._client_name
        return self.session.client(client_name)

    @property
    def client_name(self):
        return self._client_name

    def get_aws_account_id(self):
        return self.get_client('sts').get_caller_identity().get('Account')

    def create_named_query(self, name, query, database):
        params = {
            'Name': name,
            'QueryString': query,
            'Database': database
        }
        return self.get_client().create_named_query(**params)

    @classmethod
    def prompt_query_name(cls):
        query_name = None
        while True:
            query_name = input('\nInput query_name >>')
            if query_name and len(query_name) != 0:
                break
        return query_name

    @classmethod
    def prompt_query_body(cls):
        query_name = None
        while True:
            query_body = input('\nInput query >>')
            if query_body and len(query_body) != 0:
                break
        return query_body

    @staticmethod
    def prompt(database):
        if not database:
            database = 'default'

        params = {}
        default_profile_name = 'default'
        profile_name = input('Input Profile name [{}]>> '.format(default_profile_name))
        if len(profile_name) == 0:
            profile_name = default_profile_name
        athena = AthenaNamedQueryWizard(profile_name)

        query_name = athena.prompt_query_name()
        query_body = athena.prompt_query_body()

        if query_name and query_body and database:
            print(athena.create_named_query(query_name, query_body, database))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--database')
    args = parser.parse_args()
    AthenaNamedQueryWizard.prompt(args.database)

execute_query.py

import boto3
import os
import time

class AthenaExecuteNamedQueryWizard:
    _client_name = 'athena'
    _session = None
    _bucket_name = "aws-athena-query-results-{account_id}-{region}"

    def __init__(self, profile_name):
        self._session = boto3.Session(profile_name=profile_name)

    @property
    def session(self):
        return self._session

    def get_client(self, client_name=None):
        if not client_name:
            client_name = self._client_name
        return self.session.client(client_name)

    @property
    def client_name(self):
        return self._client_name

    @property
    def bucket_name(self):
        params = {
            'account_id': self.get_aws_account_id(),
            'region': self.session.region_name
        }
        return self._bucket_name.format(**params)

    def get_aws_account_id(self):
        return self.get_client('sts').get_caller_identity().get('Account')

    def get_named_queries(self):
        client = self.get_client()
        list_named_queries = client.list_named_queries()
        return client.batch_get_named_query(NamedQueryIds=list_named_queries['NamedQueryIds'])

    def run_query(self, query_string=None, named_query_id=None, database=None):
        if (not query_string) and named_query_id:
            named_status = self.get_named_query_status(named_query_id)
            query_string = named_status['NamedQuery']['QueryString']
            database = named_status['NamedQuery']['Database']

        execution_status = {
            'QueryString': query_string,
            'ResultConfiguration':{'OutputLocation': "s3://{}/".format(self.bucket_name)}
        }
        if database:
            execution_status['QueryExecutionContext'] = {'Database': database}

        start_response = self.get_client().start_query_execution(**execution_status)
        return self.get_result(start_response['QueryExecutionId'])

    def get_result(self, query_execution_id):
        client = self.get_client()

        def __wait_proc():
            def __get_query_execution():
                wait_params = {
                    'QueryExecutionId': query_execution_id
                }
                result = client.get_query_execution(**wait_params)
                return result['QueryExecution']['Status']['State']

            def __in_process():
                return __get_query_execution() == 'RUNNING'

            def __is_failed():
                return __get_query_execution() == 'FAILED'

            while __in_process():
                time.sleep(3)

            if __is_failed():
                return False
            return True

        if not __wait_proc():
            raise Exception('Failed to complete query')
        return query_execution_id

    def get_named_query_status(self, named_query_id):
        return self.get_client().get_named_query(NamedQueryId=named_query_id)

    def download_csv(self, query_execution_id):
        download_filename = "{}.csv".format(query_execution_id)
        download_path = os.path.join(os.path.dirname(__file__), download_filename)
        message_tmpl = "Download from {from} to {to}"
        msssage_tmpl_params = {
            'from': '/'.join([self.bucket_name, download_filename]),
            'to': download_path
        }
        print(message_tmpl.format(**message_tmpl_params))
        self.get_client('s3').download_file(
            self.bucket_name,
            download_filename,
            download_path
        )

    @staticmethod
    def prompt():
        params = {}
        default_profile_name = 'default'
        profile_name = input('Input Profile name [{}]>> '.format(default_profile_name))
        if len(profile_name) == 0:
            profile_name = default_profile_name
        wizard = AthenaExecuteNamedQueryWizard(profile_name)

        queries = wizard.get_named_queries()
        named_query_ids = [row['NamedQueryId'] for row in queries['NamedQueries']]
        names = [row['Name'] for row in queries['NamedQueries']]
        query_id = None
        while True:
            print('Input QueryId')
            for query in queries['NamedQueries']:
                print('[{}] - {}'.format(query['NamedQueryId'], query['Name']))
            query_id = input('>> ')
            if query_id and query_id in named_query_ids:
                break

        query_execution_id = wizard.run_query(named_query_id=query_id)
        wizard.download_csv(query_execution_id)

if __name__ == '__main__':
    AthenaExecuteNamedQueryWizard.prompt()

batch_get_named_query()

NamedQueryの束を返します。batch処理が実行されるわけではありません。

NamedQueryの実行

NamedQueryIdを直接用いた実行関数は現時点で存在しないようです。NamedQueryIdを元にQueryそのものとDatabase名を取得し、それらを元に実行する形にしてみました。

まとめ

NamedQueryをIDから直接実行できない点に気が付くまで掛かったこと以外は、特に支障なく出来たと思います。

Athenaの実行後、S3からダウンロードするまで手間要らずの実行も実装可能で、頻繁に操作を行っている方にはオススメです。