Spark SQLとDataFrame API入門 | Hadoop Advent Calendar 2016 #16

こんにちは、小澤です。 この記事はHadoop Advent Calendar 16日目のものとなります。

前回はSparkでWord Countの実装して動かす方法を紹介しました。
今回はSpark SQLとDataFrame APIについて書かせていただきます。

Spark SQLとDataFrame API

SparkはRDDに対して何かしらの処理を行った新しいRDDの生成を繰り返していくことで全体の処理フローを定義するものでした。
これに対してDataFrameというものはデータをテーブル構造で定義して、それに対する操作を記述していくものになります。
これはRやPythonなどデータ分析によく使われる言語におけるDataFrameと同じような概念と考えていいでしょう。 また、テーブル構造であるため、SQLでの処理フローの記述もかけるようになっており、これがSpark SQLとなります。

この段階ではまだ、DataFrameのイメージがつかめない方もいるかと思いますので、RDDとDataFrameの関係はMapReduceとHiveのような感じだと思っていただければいいかと思います。

また、Spark SQLに関してはもちろんテーブル構造になっているDataFrameを扱うだけでなく、JDBCを利用してRDBからデータを取得するなどもできます。

DataFrameを使ってみる

今回は2.0から導入された、SparkSessionというのを使います。1系をお使いの方はSQLContextをご利用かと思いますで適宜読み替えていただければと思います。
DataFrameはScala, Java, Python, Rから利用でします。今回はPythonを利用します。

SparkSessionの作成

まずは、DataFrameの生成に利用する、SparkSessionを生成します。

from pyspark.sql import SparkSession

spark = SparkSession.builder \
        .appName('Spark SQL and DataFrame') \
        .getOrCreate()

spark-shellやpysparkのシェルを利用している場合はSparkContextと同様に、sparkという変数名ですでに用意されていますでの、改めて作成する必要はありません。

DataFrameの作成

まずはデータを読み込んでDataFrameを作成します。 今回は、最近追加されたCSVファイルからの読み込みを行ってみます。

from pyspark.sql.types import *

# スキーマ定義
struct = StructType([
        StructField('sepal_length', DoubleType(), False),
        StructField('sepal_width', DoubleType(), False),
        StructField('petal_length', DoubleType(), False),
        StructField('petal_width', DoubleType(), False),
        StructField('species', StringType(), False)
     ])

# DataFrameの作成
df = spark.read.csv('../resources/iris/', schema=struct)

# 内容確認
df.show(5)
+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|    species|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
+------------+-----------+------------+-----------+-----------+

最初にスキーマの定義を行っています。 これは必須ではありませんが、指定しておくことをお勧めします。 指定せずに読み込みを行うと、以下のようにカラム名が分かりづらいものになりますし、すべてのカラムの型が文字列になってしまいます。

df = spark.read.csv('../resources/iris/')
df.printSchema()
root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)

特に複雑な操作を行うようになると、このような状態ではデータを処理する過程でどのような処理を行っているのかの見通しが悪くなってしまいます。

また、CSVではなく、TSVなどを利用したい場合は

spark.read.csv('../resources/iris_tsv/', schema=struct, sep='\t')

というふうにsepで区切り文字を与えてやります。

データの操作

データの表示

まずはshow()について説明しておきます。 すでにDataFrameの作成時にも利用していましたが、データフレームの内容を表示するための関数になります。 列と行の絞り込みを行います。
show()はデフォルトで20件の表示を行いますが、引数で指定することで表示件数の変更ができます。

# 3件のみ表示
df.show(3)

行や列の絞り込み

列を絞り込むにはselectに絞り込みたい列名を渡してやります。

# 列名を指定して絞り込み
df.select('species').show(3)
+-----------+
|    species|
+-----------+
|Iris-setosa|
|Iris-setosa|
|Iris-setosa|
+-----------+
only showing top 3 rows

# 複数の列を選択する場合はlistで指定する
df.select(['species', 'sepal_length']).show(3)
+-----------+------------+
|    species|sepal_length|
+-----------+------------+
|Iris-setosa|         5.1|
|Iris-setosa|         4.9|
|Iris-setosa|         4.7|
+-----------+------------+
only showing top 3 rows

# DataFrame変数は['<列名>']でColumnオブジェクトを取得
# selectの引数はColumnオブジェクトを指定してもよい
df.select(df['petal_width'], df['petal_length']).show(3)
+-----------+------------+
|petal_width|petal_length|
+-----------+------------+
|        0.2|         1.4|
|        0.2|         1.4|
|        0.2|         1.3|
+-----------+------------+
only showing top 3 rows

# Columnオブジェクトに対して演算を行うとその列のデータ全てに対して適用される
df.select(df['petal_width'] + 1, df['petal_length']).show(3)
+-----------------+------------+
|(petal_width + 1)|petal_length|
+-----------------+------------+
|              1.2|         1.4|
|              1.2|         1.4|
|              1.2|         1.3|
+-----------------+------------+
only showing top 3 rows

行を絞り込む場合はfilterやwhereを利用します。

df.filter(df['species'] == 'Iris-virginica').show(3)
+------------+-----------+------------+-----------+--------------+
|sepal_length|sepal_width|petal_length|petal_width|       species|
+------------+-----------+------------+-----------+--------------+
|         6.3|        3.3|         6.0|        2.5|Iris-virginica|
|         5.8|        2.7|         5.1|        1.9|Iris-virginica|
|         7.1|        3.0|         5.9|        2.1|Iris-virginica|
+------------+-----------+------------+-----------+--------------+
only showing top 3 rows

df.where(df['sepal_length'] > 7).show(3)
+------------+-----------+------------+-----------+--------------+
|sepal_length|sepal_width|petal_length|petal_width|       species|
+------------+-----------+------------+-----------+--------------+
|         7.1|        3.0|         5.9|        2.1|Iris-virginica|
|         7.6|        3.0|         6.6|        2.1|Iris-virginica|
|         7.3|        2.9|         6.3|        1.8|Iris-virginica|
+------------+-----------+------------+-----------+--------------+
only showing top 3 rows

どちらもColumnオブジェクトに対する演算で条件を指定しています。

集約関数

集約関数としてはaggを使います。また、groupByと併用することによって、SQLのgroup byと同様のことができます。

from pyspark.sql import functions as func

# 集約関数
df.agg(func.mean('petal_length')).show()
+------------------+
| avg(petal_length)|
+------------------+
|3.7586666666666693|
+------------------+

# group化した上で集約をする
# aggには複数の集約関数を指定可能
df.groupBy('species').agg(func.mean('petal_length'), func.max('sepal_width')).show()
+---------------+-----------------+----------------+
|        species|avg(petal_length)|max(sepal_width)|
+---------------+-----------------+----------------+
| Iris-virginica|            5.552|             3.8|
|    Iris-setosa|            1.464|             4.4|
|Iris-versicolor|             4.26|             3.4|
+---------------+-----------------+----------------+

# aliasを利用して列名を指定する
df.groupBy('species').agg(func.sum('petal_width').alias('sum')).show()
+---------------+------------------+
|        species|               sum|
+---------------+------------------+
| Iris-virginica|101.29999999999998|
|    Iris-setosa|12.199999999999996|
|Iris-versicolor|              66.3|
+---------------+------------------+

Join

次にJoinをしてみます。 ここまでで利用してきたDataFrameが1つのみなので、Join対象となるものがありません。 適当にデータの操作を行い作成しますが、動作を見るための操作を行っているだけなので、この処理自体に特に意味があるわけではありません。

df_agg = df.groupBy('species').agg(func.mean('sepal_length').alias('sepal'), func.mean('petal_length').alias('petal'))

df_a = df_agg.filter(df['species'] != 'Iris-virginica').select(['species', 'sepal'])
df_a.show()
+---------------+-----------------+
|        species|            sepal|
+---------------+-----------------+
|    Iris-setosa|5.005999999999999|
|Iris-versicolor|            5.936|
+---------------+-----------------+

df_b = df_agg.filter(df['species'] != 'Iris-versicolor').select(['species', 'petal'])
df_b.show()
+--------------+-----+
|       species|petal|
+--------------+-----+
|Iris-virginica|5.552|
|   Iris-setosa|1.464|
+--------------+-----+

この2つのdf_aとdf_bを使ってJoinの動作を見てみます

# inner
df_a.join(df_b, df_a['species'] == df_b['species'], 'inner').show()
+-----------+-----------------+-----------+-----+
|    species|            sepal|    species|petal|
+-----------+-----------------+-----------+-----+
|Iris-setosa|5.005999999999999|Iris-setosa|1.464|
+-----------+-----------------+-----------+-----+

# 'left' or 'leftouter'
df_a.join(df_b, df_a['species'] == df_b['species'], 'left').show()
+---------------+-----------------+-----------+-----+
|        species|            sepal|    species|petal|
+---------------+-----------------+-----------+-----+
|    Iris-setosa|5.005999999999999|Iris-setosa|1.464|
|Iris-versicolor|            5.936|       null| null|
+---------------+-----------------+-----------+-----+

# 'right' or 'rigthouter'
df_a.join(df_b, df_a['species'] == df_b['species'], 'right').show()
+-----------+-----------------+--------------+-----+
|    species|            sepal|       species|petal|
+-----------+-----------------+--------------+-----+
|       null|             null|Iris-virginica|5.552|
|Iris-setosa|5.005999999999999|   Iris-setosa|1.464|
+-----------+-----------------+--------------+-----+

# 'full' or 'fullouter'
df_a.join(df_b, df_a['species'] == df_b['species'], 'full').show()
+---------------+-----------------+--------------+-----+
|        species|            sepal|       species|petal|
+---------------+-----------------+--------------+-----+
|           null|             null|Iris-virginica|5.552|
|    Iris-setosa|5.005999999999999|   Iris-setosa|1.464|
|Iris-versicolor|            5.936|          null| null|
+---------------+-----------------+--------------+-----+

# default
df_a.join(df_b, df_a['species'] == df_b['species']).show()
+-----------+-----------------+-----------+-----+
|    species|            sepal|    species|petal|
+-----------+-----------------+-----------+-----+
|Iris-setosa|5.005999999999999|Iris-setosa|1.464|
+-----------+-----------------+-----------+-----+

# joinのキーとなる値を指定しなかった場合
df_a.join(df_b).show()
+---------------+-----------------+--------------+-----+
|        species|            sepal|       species|petal|
+---------------+-----------------+--------------+-----+
|    Iris-setosa|5.005999999999999|Iris-virginica|5.552|
|    Iris-setosa|5.005999999999999|   Iris-setosa|1.464|
|Iris-versicolor|            5.936|Iris-virginica|5.552|
|Iris-versicolor|            5.936|   Iris-setosa|1.464|
+---------------+-----------------+--------------+-----+

最後がCross Joinになっている以外は想像どうりかと思います。実際のデータでのうっかり結合キーを指定し忘れに注意してください。

データの保存

最後に最終的な結果となるDataFrameを保存します。

df.write.csv('test', sep='\t')

のようにするとTSVファイルとして出力されます。他のHadoop系と同様、引数で指定したディレクトリが作成されその中に各ノードごとに別ファイルで出力を行います。
csvの部分を他のフォーマット名(jsonやjdbc, orc, parquetなど様々あります)に変更することでそのフォーマットでの出力が可能です。

SQLで操作してみる

これまでは、DataFrameの各関数を利用して処理を行う方法を見てきました。 これと同じ操作をSQLで記述することが可能です。
SQLでそうするにはDataFrameのテーブル名を指定します。これだけでSQL文による操作が可能になります。

# このデータフレームに対応するテーブル名を設定
df.createOrReplaceTempView('iris')

# SparkSessionのsqlメソッドから実行
# 戻り値はDataFrame
df2 = spark.sql("select * from iris where species = 'Iris-setosa' limit 3")
df2.show()
+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|    species|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
+------------+-----------+------------+-----------+-----------+

難しい点は特にないかと思います。 SQLの文法さえ知っていいれば、あとは前項に行ったDataFrameに対する各種処理と同様のことがSQLでも実現できることが容易に想像できるかと思います。

終わりに

今回はSpark SQLとDataFrame APIの基本的な機能の紹介をしました。 DataFrameを使っての処理の記述は現在の主流となっていますのでぜひいろいろ試してみてください。

明日は、DataFrameに類似した要素として、DataSet APIというものについて書かせていただく予定です。
ぜひ、お楽しみに!