aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-07-04 11:55:20 -0700
committerReynold Xin <rxin@databricks.com>2015-07-04 11:55:20 -0700
commitc991ef5abbb501933b2a68eea1987cf8d88794a5 (patch)
treeb615e43f1b77351d124fdd5210965f17f314cb50 /sql
parent347cab85cd924ffd326f3d1367b3b156ee08052d (diff)
downloadspark-c991ef5abbb501933b2a68eea1987cf8d88794a5.tar.gz
spark-c991ef5abbb501933b2a68eea1987cf8d88794a5.tar.bz2
spark-c991ef5abbb501933b2a68eea1987cf8d88794a5.zip
[SPARK-8822][SQL] clean up type checking in math.scala.
Author: Reynold Xin <rxin@databricks.com> Closes #7220 from rxin/SPARK-8822 and squashes the following commits: 0cda076 [Reynold Xin] Test cases. 22d0463 [Reynold Xin] Fixed type precedence. beb2a97 [Reynold Xin] [SPARK-8822][SQL] clean up type checking in math.scala.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala260
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala31
2 files changed, 123 insertions, 168 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 0fc320fb08..45b7e4d340 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
-import java.lang.{Long => JLong}
-import java.util.Arrays
+import java.{lang => jl}
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
if (evalE == null) {
null
} else {
- val input = evalE.asInstanceOf[Integer]
+ val input = evalE.asInstanceOf[jl.Integer]
if (input > 20 || input < 0) {
null
} else {
@@ -290,7 +288,7 @@ case class Bin(child: Expression)
if (evalE == null) {
null
} else {
- UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long]))
+ UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long]))
}
}
@@ -300,27 +298,18 @@ case class Bin(child: Expression)
}
}
-
/**
* If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
* Otherwise if the number is a STRING, it converts each character into its hex representation
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
-case class Hex(child: Expression) extends UnaryExpression with Serializable {
+case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ // TODO: Create code-gen version.
- override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(LongType, StringType, BinaryType))
- override def checkInputDataTypes(): TypeCheckResult = {
- if (child.dataType.isInstanceOf[StringType]
- || child.dataType.isInstanceOf[IntegerType]
- || child.dataType.isInstanceOf[LongType]
- || child.dataType.isInstanceOf[BinaryType]
- || child.dataType == NullType) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type")
- }
- }
+ override def dataType: DataType = StringType
override def eval(input: InternalRow): Any = {
val num = child.eval(input)
@@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
} else {
child.dataType match {
case LongType => hex(num.asInstanceOf[Long])
- case IntegerType => hex(num.asInstanceOf[Integer].toLong)
case BinaryType => hex(num.asInstanceOf[Array[Byte]])
case StringType => hex(num.asInstanceOf[UTF8String])
}
@@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte
numBuf >>>= 4
} while (numBuf != 0)
- UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length))
+ UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
+ }
+}
+
+
+/**
+ * Performs the inverse operation of HEX.
+ * Resulting characters are returned as a byte array.
+ */
+case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ // TODO: Create code-gen version.
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def dataType: DataType = BinaryType
+
+ override def eval(input: InternalRow): Any = {
+ val num = child.eval(input)
+ if (num == null) {
+ null
+ } else {
+ unhex(num.asInstanceOf[UTF8String].getBytes)
+ }
+ }
+
+ private val unhexDigits = {
+ val array = Array.fill[Byte](128)(-1)
+ (0 to 9).foreach(i => array('0' + i) = i.toByte)
+ (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
+ (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
+ array
+ }
+
+ private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
+ var bytes = inputBytes
+ if ((bytes.length & 0x01) != 0) {
+ bytes = '0'.toByte +: bytes
+ }
+ val out = new Array[Byte](bytes.length >> 1)
+ // two characters form the hex value.
+ var i = 0
+ while (i < bytes.length) {
+ val first = unhexDigits(bytes(i))
+ val second = unhexDigits(bytes(i + 1))
+ if (first == -1 || second == -1) { return null}
+ out(i / 2) = (((first << 4) | second) & 0xFF).toByte
+ i += 2
+ }
+ out
}
}
@@ -423,22 +459,19 @@ case class Pow(left: Expression, right: Expression)
}
}
-case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
- override def checkInputDataTypes(): TypeCheckResult = {
- (left.dataType, right.dataType) match {
- case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
- case (_, IntegerType) => left.dataType match {
- case LongType | IntegerType | ShortType | ByteType =>
- return TypeCheckResult.TypeCheckSuccess
- case _ => // failed
- }
- case _ => // failed
- }
- TypeCheckResult.TypeCheckFailure(
- s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
- s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
- }
+/**
+ * Bitwise left shift.
+ * @param left the base number to shift.
+ * @param right number of bits to left shift.
+ */
+case class ShiftLeft(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+ override def dataType: DataType = left.dataType
override def eval(input: InternalRow): Any = {
val valueLeft = left.eval(input)
@@ -446,10 +479,8 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
val valueRight = right.eval(input)
if (valueRight != null) {
valueLeft match {
- case l: Long => l << valueRight.asInstanceOf[Integer]
- case i: Integer => i << valueRight.asInstanceOf[Integer]
- case s: Short => s << valueRight.asInstanceOf[Integer]
- case b: Byte => b << valueRight.asInstanceOf[Integer]
+ case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer]
}
} else {
null
@@ -459,35 +490,24 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
}
}
- override def dataType: DataType = {
- left.dataType match {
- case LongType => LongType
- case IntegerType | ShortType | ByteType => IntegerType
- case _ => NullType
- }
- }
-
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
}
}
-case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
- override def checkInputDataTypes(): TypeCheckResult = {
- (left.dataType, right.dataType) match {
- case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
- case (_, IntegerType) => left.dataType match {
- case LongType | IntegerType | ShortType | ByteType =>
- return TypeCheckResult.TypeCheckSuccess
- case _ => // failed
- }
- case _ => // failed
- }
- TypeCheckResult.TypeCheckFailure(
- s"ShiftRight expects long, integer, short or byte value as first argument and an " +
- s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
- }
+/**
+ * Bitwise (signed) right shift.
+ * @param left the base number to shift.
+ * @param right number of bits to right shift.
+ */
+case class ShiftRight(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+ override def dataType: DataType = left.dataType
override def eval(input: InternalRow): Any = {
val valueLeft = left.eval(input)
@@ -495,10 +515,8 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
val valueRight = right.eval(input)
if (valueRight != null) {
valueLeft match {
- case l: Long => l >> valueRight.asInstanceOf[Integer]
- case i: Integer => i >> valueRight.asInstanceOf[Integer]
- case s: Short => s >> valueRight.asInstanceOf[Integer]
- case b: Byte => b >> valueRight.asInstanceOf[Integer]
+ case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer]
}
} else {
null
@@ -508,35 +526,24 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}
- override def dataType: DataType = {
- left.dataType match {
- case LongType => LongType
- case IntegerType | ShortType | ByteType => IntegerType
- case _ => NullType
- }
- }
-
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
}
}
-case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {
- override def checkInputDataTypes(): TypeCheckResult = {
- (left.dataType, right.dataType) match {
- case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
- case (_, IntegerType) => left.dataType match {
- case LongType | IntegerType | ShortType | ByteType =>
- return TypeCheckResult.TypeCheckSuccess
- case _ => // failed
- }
- case _ => // failed
- }
- TypeCheckResult.TypeCheckFailure(
- s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
- s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
- }
+/**
+ * Bitwise unsigned right shift, for integer and long data type.
+ * @param left the base number.
+ * @param right the number of bits to right shift.
+ */
+case class ShiftRightUnsigned(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+ override def dataType: DataType = left.dataType
override def eval(input: InternalRow): Any = {
val valueLeft = left.eval(input)
@@ -544,10 +551,8 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
val valueRight = right.eval(input)
if (valueRight != null) {
valueLeft match {
- case l: Long => l >>> valueRight.asInstanceOf[Integer]
- case i: Integer => i >>> valueRight.asInstanceOf[Integer]
- case s: Short => s >>> valueRight.asInstanceOf[Integer]
- case b: Byte => b >>> valueRight.asInstanceOf[Integer]
+ case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer]
}
} else {
null
@@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
}
}
- override def dataType: DataType = {
- left.dataType match {
- case LongType => LongType
- case IntegerType | ShortType | ByteType => IntegerType
- case _ => NullType
- }
- }
-
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
}
}
-/**
- * Performs the inverse operation of HEX.
- * Resulting characters are returned as a byte array.
- */
-case class UnHex(child: Expression) extends UnaryExpression with Serializable {
-
- override def dataType: DataType = BinaryType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}")
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val num = child.eval(input)
- if (num == null) {
- null
- } else {
- unhex(num.asInstanceOf[UTF8String].getBytes)
- }
- }
-
- private val unhexDigits = {
- val array = Array.fill[Byte](128)(-1)
- (0 to 9).foreach(i => array('0' + i) = i.toByte)
- (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
- (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
- array
- }
-
- private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
- var bytes = inputBytes
- if ((bytes.length & 0x01) != 0) {
- bytes = '0'.toByte +: bytes
- }
- val out = new Array[Byte](bytes.length >> 1)
- // two characters form the hex value.
- var i = 0
- while (i < bytes.length) {
- val first = unhexDigits(bytes(i))
- val second = unhexDigits(bytes(i + 1))
- if (first == -1 || second == -1) { return null}
- out(i / 2) = (((first << 4) | second) & 0xFF).toByte
- i += 2
- }
- out
- }
-}
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
+
+/**
+ * Computes the logarithm of a number.
+ * @param left the logarithm base, defaults to e.
+ * @param right the number to compute the logarithm of.
+ */
case class Logarithm(left: Expression, right: Expression)
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
@@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression)
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
}
logCode + s"""
- if (Double.valueOf(${ev.primitive}).isNaN()) {
+ if (Double.isNaN(${ev.primitive})) {
${ev.isNull} = true;
}
"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 20839c83d4..03d8400cf3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -161,11 +161,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("factorial") {
- val dataLong = (0 to 20)
- dataLong.foreach { value =>
+ (0 to 20).foreach { value =>
checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
}
- checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null))
+ checkEvaluation(Literal.create(null, IntegerType), null, create_row(null))
checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
}
@@ -244,10 +243,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
- checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
- checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
- checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+ checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
}
@@ -257,10 +254,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
- checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
- checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
- checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+ checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}
@@ -270,16 +265,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
- checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
- checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
- checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
+ checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}
test("hex") {
- checkEvaluation(Hex(Literal(28)), "1C")
- checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
checkEvaluation(Hex(Literal(100800200404L)), "177828FED4")
checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C")
checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578")
@@ -313,6 +304,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
}
+
+ // null input should yield null output
checkEvaluation(
Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
null,
@@ -321,5 +314,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
null,
create_row(null))
+
+ // negative input should yield null output
+ checkEvaluation(
+ Logarithm(Literal(-1.0), Literal(1.0)),
+ null,
+ create_row(null))
+ checkEvaluation(
+ Logarithm(Literal(1.0), Literal(-1.0)),
+ null,
+ create_row(null))
}
}