aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
author: Wenchen Fan <cloud0fan@outlook.com> 2015-05-07 16:26:49 -0700
committer: Michael Armbrust <michael@databricks.com> 2015-05-07 16:26:49 -0700
commit: 35f0173b8f67e2e506fc4575be6430cfb66e2238 (patch)
tree: 94c74039c393804752b762de0351c5dba1c9a4d7 /sql/core
parent: 937ba798c56770ec54276b9259e47ae65ee93967 (diff)
download: spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.gz
spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.bz2
spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.zip
[SPARK-2155] [SQL] [WHEN D THEN E] [ELSE F] add CaseKeyWhen for "CASE a WHEN b THEN c * END"
Avoid translating to CaseWhen, which evaluates the key expression once per branch; CaseKeyWhen takes the key expression a single time. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #5979 from cloud-fan/condition and squashes the following commits: 3ce54e1 [Wenchen Fan] add CaseKeyWhen
Diffstat (limited to 'sql/core')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 9
1 files changed, 5 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 481ed49248..4a54120ba8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* TODO: This can be optimized to use broadcast join when replacementMap is large.
*/
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
- val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
- df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
- lit(target).cast(col.dataType).expr :: Nil
+ val keyExpr = df.col(col.name).expr
+ def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+ val branches = replacementMap.flatMap { case (source, target) =>
+ Seq(buildExpr(source), buildExpr(target))
}.toSeq
- new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
+ new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}
private def convertToDouble(v: Any): Double = v match {