diff options
author | Reynold Xin <rxin@databricks.com> | 2015-06-13 17:10:13 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-06-13 17:10:13 -0700 |
commit | a138953391975886c88bfe81d4ce6b6dd189cd32 (patch) | |
tree | 7cc61e281f24c1d716ab2998c765068367a8e978 /sql | |
parent | ddec45279ed1061f4c05fd0760309a53581d03f5 (diff) | |
download | spark-a138953391975886c88bfe81d4ce6b6dd189cd32.tar.gz spark-a138953391975886c88bfe81d4ce6b6dd189cd32.tar.bz2 spark-a138953391975886c88bfe81d4ce6b6dd189cd32.zip |
[SPARK-8347][SQL] Add unit tests for abs.
Also addressed code review feedback from #6754
Author: Reynold Xin <rxin@databricks.com>
Closes #6803 from rxin/abs and squashes the following commits:
d07beba [Reynold Xin] [SPARK-8347] Add unit tests for abs.
Diffstat (limited to 'sql')
5 files changed, 31 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ed192360d..e7bf7cc1f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -672,13 +672,13 @@ trait HiveTypeCoercion { findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - i.makeCopy(Array(pred, newLeft, newRight)) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. // Convert If(null literal, _, _) into boolean type. // In the optimizer, we should short-circuit this directly into false value. - case i @ If(pred, left, right) if pred.dataType == NullType => - i.makeCopy(Array(Literal.create(null, BooleanType), left, right)) + case If(pred, left, right) if pred.dataType == NullType => + If(Literal.create(null, BooleanType), left, right) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e1afa81a7a..5ff1bca260 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -75,6 +75,21 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) } + test("Abs") { + def testAbs(convert: (Int) => Any): Unit = { + checkEvaluation(Abs(Literal(convert(0))), convert(0)) + checkEvaluation(Abs(Literal(convert(1))), convert(1)) + checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + } + testAbs(_.toByte) + testAbs(_.toShort) + testAbs(identity) + testAbs(_.toLong) + testAbs(_.toFloat) + testAbs(_.toDouble) + testAbs(Decimal(_)) + } + test("Divide") { checkEvaluation(Divide(Literal(2), Literal(1)), 2) checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index efcdae5bce..5a08578e7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -369,23 +369,6 @@ class ColumnExpressionSuite extends QueryTest { ) } - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key.asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key.desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(lit(null))), - (1 to 100).map(_ => Row(null)) - ) - } - test("upper") { checkAnswer( lowerCaseData.select(upper('l)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 6561c3b232..faa1d1193b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -236,6 +236,18 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("abs") { + val input = + Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) + checkAnswer( + input.toDF("key", "value").select(abs($"key").alias("a")).sort("a"), + input.map(pair => Row(pair._2))) + + checkAnswer( + input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), + input.map(pair => Row(pair._2))) + } + test("log2") { val df = Seq((1, 2)).toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6898d58441..d1520b757e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -178,18 +178,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } - test("SPARK-3176 Added Parser of SQL ABS()") { - checkAnswer( - sql("SELECT ABS(-1.3)"), - Row(1.3)) - checkAnswer( - sql("SELECT ABS(0.0)"), - Row(0.0)) - checkAnswer( - sql("SELECT ABS(2.5)"), - Row(2.5)) - } - test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") |