PySparkで特定のカラムが全体の最大値であるレコードを取得する
概要
実現はできてはいたものの、もっと良いやり方ないかな?と聞いたら教えてもらったのでメモ。
うまく説明できないのでデータを記載します。
処理前
+----+------+ |name| date| +----+------+ | a|201906| | a|201907| | b|201906| | b|201907| | c|201907| +----+------+
処理後
+----+------+ |name| date| +----+------+ | a|201907| | b|201907| | c|201907| +----+------+
教えてもらった方法
from pyspark.sql import functions as f from pyspark.sql.window import Window as w df = spark.createDataFrame( [['a', '201906'], ['a', '201907'], ['b', '201906'], ['b', '201907'], ['c', '201907']], ['name', 'date'] ) df.show() result_df = ( df .withColumn('max_date', f.max('date').over(w.partitionBy())) .filter(f.col('date') == f.col('max_date')) .drop('max_date') ) result_df.show()
試行錯誤の内容もメモ
当初書いたコードや、途中のコードもメモ
from pyspark.sql import functions as f from pyspark.sql.window import Window as w df = spark.createDataFrame( [['a', '201906'], ['a', '201907'], ['b', '201906'], ['b', '201907'], ['c', '201907']], ['name', 'date'] ) df.show() # 1行ではあるが、一旦actionが走るので遅い気がする。 # また、filterの中でdfを使っているので、dfが定義されている必要がある。(読み込みからメソッドチェーンで繋げない) result_df_2 = df.filter(f.col('date') == df.agg(f.max('date')).first()[0]) result_df_2.show() # 同じwindow関数だったら最大を取るという意図からしてmax使った方がわかりやすい result_df_3 = ( df .withColumn('rank', f.rank().over(w.partitionBy().orderBy(f.col('date').desc()))) .filter(f.col('rank') == 1) .drop('rank') )