
AmundsenからAmazon Athenaの基本統計量を自動取得する
この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。
どうも!DA部の春田です。
Lyft社製OSSデータカタログAmundsenでは、テーブルの各カラムごとに基本統計量等を表示できる枠が用意されています。

出典: Amundsen Frontend Service - User Interface
しかし現状、枠は用意はされているのですが、中身自体は各自で統計情報を算出するためのスクリプトを書く必要があります。公式でもサンプルがなかったので、Amazon Athenaのテーブルを対象に実装してみました。
環境構築
ローカル(macOS)の環境構築については下記事をご参照ください。
EC2インスタンス上での環境構築については下記事の通りです。
今回はEC2インスタンスに構築したAmundsenを使用し、Amazon Athenaにロードしたテーブルを参照します。
基本統計量を計算してAmundsenに表示させる
以前の記事でAthenaのメタデータを取得するスクリプトを作成したので、今回はこれに基本統計量を計算する関数を追加していきます。解説が長くなってしまうので、先にコードと実行結果を載せておきます。
ジョブは以下のathena_sample_dag.pyに定義しています。
amundsen-sample/athena_sample_dag.py at main · TakumiHaruta/amundsen-sample
上記で使用しているAthenaStatsExtractorクラスは、今回のために新規作成した以下のathena_stats_extractor.pyが元となっています。
amundsen-sample/athena_stats_extractor.py at main · TakumiHaruta/amundsen-sample
amundsenのリポジトリ上で両スクリプトを以下のパスで配置し、amundsen/amundsendatabuilder/配下でvenvの仮想環境に対してpython3 setup.py installを実行します。
- amundsen/amundsendatabuilder/example/dags/athena_sample_dag.py
- amundsen/amundsendatabuilder/databuilder/extractor/athena_stats_extractor.py
今回のサンプルデータは、AWS公式で提供されているcloudfront_logsとelb_logs_raw_nativeの2つを使用します。
お持ちのAWSアカウントで、手順通りDDLを流してテーブルを作成してください。一手間かかりますが、CSVのままでは集計関数が使いづらいので、列志向のParquetに変換したものをAmundsenから呼びたいと思います。フォーマットをParquetに指定した以下のDDLを流し、INSERT INTO SELECTでデータを移してください。
amundsen-sample/ddl at main · TakumiHaruta/amundsen-sample
INSERT INTO cloudfront_logs_parquet SELECT * fROM cloudfront_logs; INSERT INTO elb_logs_raw_native_parquet SELECT * fROM elb_logs_raw_native limit 10000; -- 件数が多いので限定
テーブルの準備ができたら、athena_sample_dag.pyを実行します。
python3 example/dags/athena_sample_dag.py --region 'ap-northeast-1' --s3output 's3://my-s3-bucket/athena/' --target_schema 'cm-haruta'
実行が完了すると、Amundsenのカラムメタ情報に基本統計量が追加されているのを確認できましたでしょうか?


以下、スクリプトを解説していきます。
Pythonコード解説
コードの解説にあたり、3節に分けます。
- create_table_stats_job関数の大枠のジョブ構成
- 独自作成したAthenaStatsExtractorクラス
- テーブル情報元のColumn_2.csvを出力している直前の関数create_table_extract_job
create_table_stats_job関数の大枠のジョブ構成
大枠のジョブ構成は以下の通りです。流れとしては、①Athenaから統計情報を取得し、②CSV形式でNeo4jにロードするというものです。AthenaStatsExtractorは今回独自に作成したクラスで、後ほど触れます。
job = DefaultJob(
    conf=job_config,
    task=DefaultTask(
        extractor=AthenaStatsExtractor(),
        loader=FsNeo4jCSVLoader(),
        transformer=NoopTransformer()
    ),
    publisher=Neo4jCsvPublisher()
)
job.launch()
基本的に全てのパラメータはjob_configの中でまとめて渡されます。その中で言及しておきたいのは以下の5つです。
job_config = ConfigFactory.from_dict({
    f'extractor.athena_metadata.{AthenaStatsExtractor.CATALOG_KEY}': catalog_source,
    f'extractor.athena_metadata.{AthenaStatsExtractor.TARGET_SCHEMA}': target_schema,
    f'extractor.athena_metadata.{AthenaStatsExtractor.TARGET_TABLE}': target_table,
    f'extractor.athena_metadata.{AthenaStatsExtractor.COLUMN_LIST}': column_list,
    ...,
    f'loader.filesystem_csv_neo4j.{FsNeo4jCSVLoader.FORCE_CREATE_DIR}': True,
    ...,
})
AthenaStatsExtractorで使用するパラメータは、テーブル名や文字列に変換されたカラム名のリストなど、すなわち1テーブルに関する情報です。複数テーブルにも対応できるよう、スクリプトでは①パラメータを定義、②ジョブを定義、③ジョブを起動というフローをforループで回しています。
FsNeo4jCSVLoader.FORCE_CREATE_DIRがTrueに指定されているのは、直前の別のジョブでFsNeo4jCSVLoaderで使用したCSVファイルを残しているためです。この残してあるCSVファイル(Column_2.csv)を使って、パラメータに渡すテーブル情報を作成しているのが以下です。今回は先日の記事で使用したAthenaのメタデータを取得する関数を使いまわすためのコードになっていますが、例えば既にNeo4Jに登録してあるテーブル情報から作成する、ということも可能でしょう。
target_cols = f'{tmp_folder}/nodes/Column_2.csv'
with open(target_cols, 'r') as r:
    reader = csv.DictReader(r)
    column_data = dict()
    for line in reader:
        key = line['KEY'].split('/')
        catalog_source = key[-3].split('.')[0]
        target_table = key[-3].split('.')[1] + '.' + key[-2]
        target_column = key[-1]
        if target_table in column_data.keys():
            column_data[target_table] += [target_column]
        else:
            column_data[target_table] = [target_column]
for k, column_list in column_data.items():
    target_schema, target_table = k.split('.')
    column_list = json.dumps(column_list)
     job_config = ConfigFactory.from_dict({
        ...,
続いて、ジョブのExtractorで使用しているAthenaStatsExtractorクラスについて解説します。
独自作成したAthenaStatsExtractorクラス
AthenaStatsExtractorは今回の肝です。統計情報をAmundsen上で表示させるためのデータモデルとして用意されている、TableColumnStatsクラスを活用します。TableColumnStatはAmundsen公式のサンプルデータロードにも使われているクラスですね。
下のextractは、job.launch()を実行した時にジョブ内のタスクで実行されるメソッドです。extractが実行されると、SQL Alchemyで取得したAthenaのクエリ結果を、1行ずつTableColumnStatに当てめるイテレータが作成されます。
def extract(self) -> Union[TableColumnStats, None]:
    if not self._extract_iter:
        self._extract_iter = self._get_extract_iter()
    try:
        return next(self._extract_iter)
    except StopIteration:
        return None
def _get_extract_iter(self) -> Iterator[TableColumnStats]:
    """
    Provides iterator of result row from SQLAlchemy extractor
    :return:
    """
    row = self._alchemy_extractor.extract()
    while row:
        yield TableColumnStats(
            row['table_name'],
            row['col_name'],
            row['stat_name'],
            row['stat_val'],
            row['start_epoch'],
            row['end_epoch'],
            row['db'],
            row['cluster'],
            row['schema']
        )
        row = self._alchemy_extractor.extract()
実行されるSQLは、下記のinitメソッドの中で定義されています。job_configで渡されたパラメータを展開し、_create_sqlでSQLを作成した後、SQLAlchemyのExtractorを定義しています。
def init(self, conf: ConfigTree) -> None:
    conf = conf.with_fallback(AthenaStatsExtractor.DEFAULT_CONFIG)
    self._cluster = conf.get_string(AthenaStatsExtractor.CATALOG_KEY)
    self._target_schema = conf.get_string(AthenaStatsExtractor.TARGET_SCHEMA)
    self._target_table = conf.get_string(AthenaStatsExtractor.TARGET_TABLE)
    self._column_list = json.loads(conf.get_string(AthenaStatsExtractor.COLUMN_LIST))
    
    self.sql_stmt = self._create_sql(
        self._cluster,
        self._target_schema,
        self._target_table,
        self._column_list
    )
    
    LOGGER.info('SQL for Athena stats: %%s', self.sql_stmt)
    self._alchemy_extractor = SQLAlchemyExtractor()
    sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
        .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))
    self._alchemy_extractor.init(sql_alch_conf)  # execute_query
    self._extract_iter: Union[None, Iterator] = None
_create_sqlメソッドでは、パラメータからカラム名を各関数に当てはめて、統計情報をTableColumnStatsに合う形で出力するSQLを作成します。長いので掲載は割愛します。
    def _create_sql(self, catalog_source, target_schema, target_table, column_list):
        col_name_sql = ', '.join(["'" + col + "'" for col in column_list])
        str_convert_sql = ', '.join([
            'cast("' + col + '" as varchar) as "' + col + '"' for col in column_list])
        max_col_sql = ','.join([f'''
            coalesce(
              cast(max(try_cast("{col}" as bigint)) as varchar),
              cast(max(try_cast("{col}" as double)) as varchar),
              cast(try(max("{col}")) as varchar)
            )''' for col in column_list])
        min_col_sql = ','.join([f'''
            coalesce(
              cast(min(try_cast("{col}" as bigint)) as varchar),
              cast(min(try_cast("{col}" as double)) as varchar),
              cast(try(min("{col}")) as varchar)
            )''' for col in column_list])
        avg_col_sql = ','.join([f'''
            coalesce(
              cast(avg(try_cast("{col}" as bigint)) as varchar),
              cast(avg(try_cast("{col}" as double)) as varchar),
              null
            )''' for col in column_list])
        stdev_col_sql = ','.join([f'''
            coalesce(
              cast(stddev(try_cast("{col}" as bigint)) as varchar),
              cast(stddev(try_cast("{col}" as double)) as varchar),
              null
            )''' for col in column_list])
        med_col_sql = ','.join([f'''
            coalesce(
              cast(approx_percentile(try_cast("{col}" as bigint), 0.5) as varchar),
              cast(approx_percentile(try_cast("{col}" as double), 0.5) as varchar),
              null
            )''' for col in column_list])
        cnt_col_sql = ','.join([
            f'cast(count("{col}") as varchar)' for col in column_list])
        uniq_col_sql = ','.join([
            f'cast(count(distinct "{col}") as varchar)' for col in column_list])
        nul_col_sql = ','.join([
            f'cast(sum(case when "{col}" is null then 1 else 0 end) as varchar)' for col in column_list])
        SQL_STATEMENT = f"""
        WITH str_convert AS (
          SELECT {str_convert_sql}
          FROM "{target_schema}"."{target_table}"
        ), max_col AS (
          SELECT
            'max' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{max_col_sql}] as stat_val
          FROM str_convert
        ), min_col AS (
          SELECT
            'min' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{min_col_sql}] as stat_val
          FROM str_convert
        ), avg_col AS (
          SELECT
            'avg' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{avg_col_sql}] as stat_val
          FROM str_convert
        ), stdev_col AS (
          SELECT
            'std dev' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{stdev_col_sql}] as stat_val
          FROM str_convert
        ), med_col AS (
          SELECT
            'median' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{med_col_sql}] as stat_val
          FROM str_convert
        ), cnt_col AS (
          SELECT
            'num rows' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{cnt_col_sql}] as stat_val
          FROM str_convert
        ), uniq_col AS (
          SELECT
            'num uniq' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{uniq_col_sql}] as stat_val
          FROM str_convert
        ), nul_col AS (
          SELECT
            'num nulls' as stat_name,
            array[{col_name_sql}] as col_name,
            array[{nul_col_sql}] as stat_val
          FROM str_convert
        ), union_table AS (
          SELECT t1.stat_name, t2.col_name, t2.stat_val FROM max_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM min_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM avg_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM stdev_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM med_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM cnt_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM uniq_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
          UNION SELECT t1.stat_name, t2.col_name, t2.stat_val FROM nul_col t1
          CROSS JOIN UNNEST (col_name, stat_val) AS t2(col_name, stat_val)
        )
        SELECT
          '{catalog_source}' as cluster,
          'athena' as db,
          '{target_schema}' as schema,
          '{target_table}' as table_name,
          col_name,
          stat_name,
          stat_val,
          to_unixtime(now()) as start_epoch,
          to_unixtime(now()) as end_epoch
        FROM union_table
        ORDER BY cluster, db, schema, table_name, col_name
        ;
        """
        return SQL_STATEMENT
これらのメソッド群を使用して、Athenaのテーブルで統計量を計算し、クエリ結果をNeo4Jのデータモデルに合うロードしています。
テーブル情報元のColumn_2.csvを出力している直前の関数create_table_extract_job
最後に、create_table_extract_jobのパラメータについてだけ触れます。デフォルトでは、ジョブ実行後に出力したデータは削除されてしまうのですが、以下のパラメータを渡してあげることで、自動削除を無効にすることができます。
job_config = ConfigFactory.from_dict({
    ...,
    f'loader.filesystem_csv_neo4j.{FsNeo4jCSVLoader.SHOULD_DELETE_CREATED_DIR}': False,
    f'loader.filesystem_csv_neo4j.{FsNeo4jCSVLoader.FORCE_CREATE_DIR}': True,
    ...,
})
その他は、先日の記事とほぼ同じかと思います。
まとめ
Athena以外のデータソースに対しても、以下のポイントを抑えればご自身で構築できるかと思います。
- 基本統計量の計算(SQL)の出力結果を、Neo4Jのデータモデルに合わせる
- 出力結果を当てはまるデータモデルとして、TableColumnStatsクラスを活用する
- 基本統計量を求めるテーブルのカラム情報の取得方法を考えておく
- 対象DBから取得するか?既存のNeo4Jのデータから取得するか?
 
ご参考になれば幸いです。








