山pの楽しいお勉強生活

勉強の成果を垂れ流していきます

PySparkで特定のカラムが全体の最大値であるレコードを取得する

概要

実現はできてはいたものの、もっと良いやり方ないかな?と聞いたら教えてもらったのでメモ。

うまく説明できないのでデータを記載します。

処理前

+----+------+
|name|  date|
+----+------+
|   a|201906|
|   a|201907|
|   b|201906|
|   b|201907|
|   c|201907|
+----+------+

処理後

+----+------+
|name|  date|
+----+------+
|   a|201907|
|   b|201907|
|   c|201907|
+----+------+

教えてもらった方法

from pyspark.sql import functions as f
from pyspark.sql.window import Window as w

df = spark.createDataFrame(
  [['a', '201906'], ['a', '201907'], ['b', '201906'], ['b', '201907'], ['c', '201907']],
  ['name', 'date']
)
df.show()

result_df = (
  df
    .withColumn('max_date', f.max('date').over(w.partitionBy()))
    .filter(f.col('date') == f.col('max_date'))
    .drop('max_date')
)
result_df.show()

試行錯誤の内容もメモ

当初書いたコードや、途中のコードもメモ

from pyspark.sql import functions as f
from pyspark.sql.window import Window as w

df = spark.createDataFrame(
  [['a', '201906'], ['a', '201907'], ['b', '201906'], ['b', '201907'], ['c', '201907']],
  ['name', 'date']
)
df.show()

# 1行ではあるが、一旦actionが走るので遅い気がする。
# また、filterの中でdfを使っているので、dfが定義されている必要がある。(読み込みからメソッドチェーンで繋げない)
result_df_2 = df.filter(f.col('date') == df.agg(f.max('date')).first()[0])
result_df_2.show()

# 同じwindow関数だったら最大を取るという意図からしてmax使った方がわかりやすい
result_df_3 = (
  df
    .withColumn('rank', f.rank().over(w.partitionBy().orderBy(f.col('date').desc())))
    .filter(f.col('rank') == 1)
    .drop('rank')
)