aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2014-04-24 09:57:28 -0700
committerReynold Xin <rxin@apache.org>2014-04-24 09:57:28 -0700
commit27b2821cf16948962c7a6f513621a1eba60b8cf3 (patch)
tree1d0b82bab718526cfe019eaddf3812d34e08176b /sql
parent1fdf659d2fdf23c5562e5dc646d05083062281ed (diff)
downloadspark-27b2821cf16948962c7a6f513621a1eba60b8cf3.tar.gz
spark-27b2821cf16948962c7a6f513621a1eba60b8cf3.tar.bz2
spark-27b2821cf16948962c7a6f513621a1eba60b8cf3.zip
[SPARK-1610] [SQL] Fix Cast to use exact type value when cast from BooleanType to NumericTy...
...pe. `Cast` from `BooleanType` to `NumericType` are all using `Int` value. But it causes `ClassCastException` when the casted value is used by the following evaluation like the code below: ``` scala scala> import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst._ scala> import types._ import types._ scala> import expressions._ import expressions._ scala> Add(Cast(Literal(true), ShortType), Literal(1.toShort)).eval() java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Short at scala.runtime.BoxesRunTime.unboxToShort(BoxesRunTime.java:102) at scala.math.Numeric$ShortIsIntegral$.plus(Numeric.scala:72) at org.apache.spark.sql.catalyst.expressions.Add$$anonfun$eval$2.apply(arithmetic.scala:58) at org.apache.spark.sql.catalyst.expressions.Add$$anonfun$eval$2.apply(arithmetic.scala:58) at org.apache.spark.sql.catalyst.expressions.Expression.n2(Expression.scala:114) at org.apache.spark.sql.catalyst.expressions.Add.eval(arithmetic.scala:58) at .<init>(<console>:17) at .<clinit>(<console>) at .<init>(<console>:7) at .<clinit>(<console>) at $print(<console>) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:483) at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734) at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983) at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573) at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604) at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:568) at scala.tools.nsc.interpreter.ILoop.reallyInterpret$1(ILoop.scala:760) at scala.tools.nsc.interpreter.ILoop.interpretStartingWith(ILoop.scala:805) at scala.tools.nsc.interpreter.ILoop.command(ILoop.scala:717) at scala.tools.nsc.interpreter.ILoop.processLine$1(ILoop.scala:581) at scala.tools.nsc.interpreter.ILoop.innerLoop$1(ILoop.scala:588) at scala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:591) at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:882) at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:837) at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:837) at scala.tools.nsc.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:135) at scala.tools.nsc.interpreter.ILoop.process(ILoop.scala:837) at scala.tools.nsc.MainGenericRunner.runTarget$1(MainGenericRunner.scala:83) at scala.tools.nsc.MainGenericRunner.process(MainGenericRunner.scala:96) at scala.tools.nsc.MainGenericRunner$.main(MainGenericRunner.scala:105) at scala.tools.nsc.MainGenericRunner.main(MainGenericRunner.scala) ``` Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #533 from ueshin/issues/SPARK-1610 and squashes the following commits: 70f36e8 [Takuya UESHIN] Fix Cast to use exact type value when cast from BooleanType to NumericType.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala10
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala7
2 files changed, 12 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1f3fab09e9..8b79b0cd65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -111,7 +111,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType => nullOrCast[String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
@@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType => nullOrCast[String](_, s => try s.toShort catch {
case _: NumberFormatException => null
})
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
@@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType => nullOrCast[String](_, s => try s.toByte catch {
case _: NumberFormatException => null
})
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
@@ -162,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType => nullOrCast[String](_, s => try s.toDouble catch {
case _: NumberFormatException => null
})
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
@@ -172,7 +172,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType => nullOrCast[String](_, s => try s.toFloat catch {
case _: NumberFormatException => null
})
- case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 2cd0d2b0e1..4ce0dff9e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -237,6 +237,13 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)
+ checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+ checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
+ checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
+ checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
+ checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
+ checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
}