diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-05-07 16:26:49 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-05-07 16:26:49 -0700 |
commit | 35f0173b8f67e2e506fc4575be6430cfb66e2238 (patch) | |
tree | 94c74039c393804752b762de0351c5dba1c9a4d7 /sql/core | |
parent | 937ba798c56770ec54276b9259e47ae65ee93967 (diff) | |
download | spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.gz spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.bz2 spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.zip |
[SPARK-2155] [SQL] [WHEN D THEN E] [ELSE F] add CaseKeyWhen for "CASE a WHEN b THEN c * END"
Avoid translating to CaseWhen and evaluate the key expression many times.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #5979 from cloud-fan/condition and squashes the following commits:
3ce54e1 [Wenchen Fan] add CaseKeyWhen
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 481ed49248..4a54120ba8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * TODO: This can be optimized to use broadcast join when replacementMap is large. */ private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { - val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) => - df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr :: - lit(target).cast(col.dataType).expr :: Nil + val keyExpr = df.col(col.name).expr + def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) + val branches = replacementMap.flatMap { case (source, target) => + Seq(buildExpr(source), buildExpr(target)) }.toSeq - new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name) + new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } private def convertToDouble(v: Any): Double = v match { |