-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                    93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                  40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala           2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala          4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala                    9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala      2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala                           12
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala   180
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                           4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                    21
10 files changed, 335 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d133ad3f0d..9b6b5b8bd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -18,9 +18,8 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
- def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None)
+ def constructorFor[T : TypeTag]: Expression = {
+ val tpe = localTypeOf[T]
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+ constructorFor(tpe, None, walkedTypePath)
+ }
private def constructorFor(
tpe: `Type`,
- path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
+ path: Option[Expression],
+ walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
/** Returns the current path with a sub-field extracted. */
- def addToPath(part: String): Expression = path
- .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
- .getOrElse(UnresolvedAttribute(part))
+ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
+ val newPath = path
+ .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
+ upCastToExpectedType(newPath, dataType, walkedTypePath)
+ }
/** Returns the current path with a field at ordinal extracted. */
- def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
- .map(p => GetStructField(p, ordinal))
- .getOrElse(BoundReference(ordinal, dataType, false))
+ def addToPathOrdinal(
+ ordinal: Int,
+ dataType: DataType,
+ walkedTypePath: Seq[String]): Expression = {
+ val newPath = path
+ .map(p => GetStructField(p, ordinal))
+ .getOrElse(BoundReference(ordinal, dataType, false))
+ upCastToExpectedType(newPath, dataType, walkedTypePath)
+ }
/** Returns the current path or `BoundReference`. */
- def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+ def getPath: Expression = {
+ val dataType = schemaFor(tpe).dataType
+ if (path.isDefined) {
+ path.get
+ } else {
+ upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+ }
+ }
+
+ /**
+ * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+ * and lose the required data type, which may lead to a runtime error if the real type doesn't
+ * match the encoder's schema.
+ * For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real type
+ * is [a: int, b: long], we will hit a runtime error saying that we can't construct class
+ * `Data` with an int and a long, because we lost the information that `b` should be a string.
+ *
+ * This method helps us "remember" the required data type by adding an `UpCast`. Note that we
+ * don't need to cast struct types because there must be an `UnresolvedExtractValue` or
+ * `GetStructField` wrapping them, so we only need to handle leaf types.
+ */
+ def upCastToExpectedType(
+ expr: Expression,
+ expected: DataType,
+ walkedTypePath: Seq[String]): Expression = expected match {
+ case _: StructType => expr
+ case _ => UpCast(expr, expected, walkedTypePath)
+ }
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- WrapOption(constructorFor(optType, path))
+ val className = getClassNameFromType(optType)
+ val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
+ WrapOption(constructorFor(optType, path, newTypePath))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
primitiveMethod.map { method =>
Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
+ val className = getClassNameFromType(elementType)
+ val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
- p => constructorFor(elementType, Some(p)),
+ p => constructorFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+ val className = getClassNameFromType(elementType)
+ val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val arrayData =
Invoke(
MapObjects(
- p => constructorFor(elementType, Some(p)),
+ p => constructorFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection {
arrayData :: Nil)
case t if t <:< localTypeOf[Map[_, _]] =>
+ // TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p)),
+ p => constructorFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
@@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p)),
+ p => constructorFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
@@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection {
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val dataType = schemaFor(fieldType).dataType
-
+ val clsName = getClassNameFromType(fieldType)
+ val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we grab the inner fields by ordinal instead of by name.
if (cls.getName startsWith "scala.Tuple") {
- constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+ constructorFor(
+ fieldType,
+ Some(addToPathOrdinal(i, dataType, newTypePath)),
+ newTypePath)
} else {
- constructorFor(fieldType, Some(addToPath(fieldName)))
+ constructorFor(
+ fieldType,
+ Some(addToPath(fieldName, dataType, newTypePath)),
+ newTypePath)
}
}
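The walked type path built above is only used for error reporting: each recursive call prepends a human-readable entry, so when an `UpCast` later fails to resolve, the message can name the exact field inside a nested object. A minimal sketch of the strings this produces, using hypothetical `Outer`/`Inner` classes that are not part of the diff:

```scala
// Hypothetical classes, used only to illustrate the path format above.
case class Inner(b: Long)
case class Outer(a: Inner)

// constructorFor prepends one entry per step as it recurses, so the
// innermost field ends up first and the root class last.
val root   = s"""- root class: "Outer"""" :: Nil
val fieldA = s"""- field (class: "Inner", name: "a")""" +: root
val fieldB = s"""- field (class: "scala.Long", name: "b")""" +: fieldA

// fieldB.mkString("\n") yields:
//   - field (class: "scala.Long", name: "b")
//   - field (class: "Inner", name: "a")
//   - root class: "Outer"
```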
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b8f212fca7..765327c474 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
+ ResolveUpCast ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Replaces `UpCast` expressions with `Cast`, and throws an exception if the cast may truncate.
+ */
+object ResolveUpCast extends Rule[LogicalPlan] {
+ private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+ throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
+ s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+ "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+ "You can either add an explicit cast to the input data or choose a higher precision " +
+ "type of the field in the target object")
+ }
+
+ private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+ val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+ val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+ toPrecedence > 0 && fromPrecedence > toPrecedence
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan transformAllExpressions {
+ case u @ UpCast(child, _, _) if !child.resolved => u
+
+ case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+ case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+ fail(child, to, walkedTypePath)
+ case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+ fail(child, to, walkedTypePath)
+ case (from, to) if illegalNumericPrecedence(from, to) =>
+ fail(child, to, walkedTypePath)
+ case (TimestampType, DateType) =>
+ fail(child, DateType, walkedTypePath)
+ case (StringType, to: NumericType) =>
+ fail(child, to, walkedTypePath)
+ case _ => Cast(child, dataType)
+ }
+ }
+ }
+}
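`ResolveUpCast` only allows casts that cannot lose information: numeric casts must move up the widening hierarchy now exposed as `HiveTypeCoercion.numericPrecedence`, decimal conversions are checked with `isWiderThan`/`isTighterThan`, and timestamp-to-date or string-to-number casts are always rejected. A self-contained sketch of the precedence check, abbreviating the sequence to the integral and floating-point types shown in the hunk below:

```scala
import org.apache.spark.sql.types._

// Abbreviated copy of HiveTypeCoercion.numericPrecedence, for illustration only.
val numericPrecedence: IndexedSeq[DataType] =
  IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

// Mirrors illegalNumericPrecedence in the rule: reject when the target is in
// the hierarchy (above ByteType) and the source sits strictly above it.
def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
  val fromPrecedence = numericPrecedence.indexOf(from)
  val toPrecedence = numericPrecedence.indexOf(to)
  toPrecedence > 0 && fromPrecedence > toPrecedence
}

illegalNumericPrecedence(IntegerType, LongType) // false: widening, kept as a plain Cast
illegalNumericPrecedence(LongType, IntegerType) // true: may truncate, analysis fails
```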
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index f90fc3cc12..29502a5991 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -53,7 +53,7 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
- private val numericPrecedence =
+ private[sql] val numericPrecedence =
IndexedSeq(
ByteType,
ShortType,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0c10a56c55..06ffe86455 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](
val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
+ val optimizedPlan = SimplifyCasts(analyzedPlan)
// In order to construct instances of inner classes (for example those declared in a REPL cell),
// we need an instance of the outer scope. This rule substitutes those outer objects into
// expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
// registry.
- copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
+ copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
if (outer == null) {
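Running `SimplifyCasts` here means the `Cast`s that `ResolveUpCast` inserts disappear again whenever the input column already has exactly the expected type, so only genuinely widening casts survive into `fromRowExpression`. A hedged sketch of the no-op case it removes (the real rule lives in `org.apache.spark.sql.catalyst.optimizer`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.LongType

// A cast to the type the child already has is a no-op and can be dropped,
// so a column that already matches the encoder's schema is read directly.
val alreadyLong = 'b.long   // AttributeReference of LongType
val simplified = Cast(alreadyLong, LongType) match {
  case Cast(child, dataType) if child.dataType == dataType => child
  case other => other
}
// simplified is just `alreadyLong` again
```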
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a2c6c39fd8..cb60d5958d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -914,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
"""
}
}
+
+/**
+ * Casts the child expression to the target data type, but throws an error if the cast might
+ * truncate, e.g. long -> int, timestamp -> date.
+ */
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
+ extends UnaryExpression with Unevaluable {
+ override lazy val resolved = false
+}
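`UpCast` never evaluates and never resolves on its own; it is only a marker that carries the expected type and the walked type path until `ResolveUpCast` either rewrites it into a plain `Cast` or raises the analysis error. A small sketch of that behaviour (the attribute and path are made up for illustration):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.UpCast
import org.apache.spark.sql.types.LongType

val up = UpCast('b.int, LongType, Seq("""- root class: "scala.Long""""))
assert(up.child.resolved)  // the child attribute itself is fine
assert(!up.resolved)       // but the placeholder stays unresolved until the rule rewrites it
```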
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1854dfaa7d..72cc89c8be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
/**
- * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
+ * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
* StructType.
*/
def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 0cd352d0fa..ce45245b9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,6 +91,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
}
/**
+ * Returns whether this DecimalType is tighter than `other`. If yes, it means values of `this`
+ * type can be cast to `other` safely without losing any precision or range.
+ */
+ private[sql] def isTighterThan(other: DataType): Boolean = other match {
+ case dt: DecimalType =>
+ (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
+ case dt: IntegralType =>
+ isTighterThan(DecimalType.forType(dt))
+ case _ => false
+ }
+
+ /**
* The default size of a value of the DecimalType is 4096 bytes.
*/
override def defaultSize: Int = 4096
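`isTighterThan` mirrors the existing `isWiderThan`: a decimal fits in another decimal when it needs no more integral digits and no more fractional digits, and it fits in an integral type when it fits in the decimal that `DecimalType.forType` assigns to that type. A few worked checks under that rule (a sketch; the method is `private[sql]`, and the `forType` mapping of `LongType` to decimal(20, 0) is assumed here):

```scala
import org.apache.spark.sql.types._

// decimal(10, 2) has 8 integral digits and 2 fractional digits.
DecimalType(10, 2).isTighterThan(DecimalType(12, 2)) // true:  8 <= 10 and 2 <= 2
DecimalType(10, 2).isTighterThan(DecimalType(12, 4)) // true:  8 <= 8  and 2 <= 4
DecimalType(10, 2).isTighterThan(DecimalType(11, 4)) // false: 8 integral digits > 7
DecimalType(10, 2).isTighterThan(LongType)           // false: 2 fractional digits don't fit decimal(20, 0)
DecimalType(38, 18).isTighterThan(LongType)          // false: this is the `b.b` case in the new test suite
```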
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
new file mode 100644
index 0000000000..0289988342
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+
+case class StringLongClass(a: String, b: Long)
+
+case class StringIntClass(a: String, b: Int)
+
+case class ComplexClass(a: Long, b: StringLongClass)
+
+class EncoderResolutionSuite extends PlanTest {
+ test("real type doesn't match encoder schema but they are compatible: product") {
+ val encoder = ExpressionEncoder[StringLongClass]
+ val cls = classOf[StringLongClass]
+
+ {
+ val attrs = Seq('a.string, 'b.int)
+ val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+ val expected: Expression = NewInstance(
+ cls,
+ toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+ false,
+ ObjectType(cls))
+ compareExpressions(fromRowExpr, expected)
+ }
+
+ {
+ val attrs = Seq('a.int, 'b.long)
+ val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
+ val expected = NewInstance(
+ cls,
+ toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+ false,
+ ObjectType(cls))
+ compareExpressions(fromRowExpr, expected)
+ }
+ }
+
+ test("real type doesn't match encoder schema but they are compatible: nested product") {
+ val encoder = ExpressionEncoder[ComplexClass]
+ val innerCls = classOf[StringLongClass]
+ val cls = classOf[ComplexClass]
+
+ val structType = new StructType().add("a", IntegerType).add("b", LongType)
+ val attrs = Seq('a.int, 'b.struct(structType))
+ val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+ val expected: Expression = NewInstance(
+ cls,
+ Seq(
+ 'a.int.cast(LongType),
+ If(
+ 'b.struct(structType).isNull,
+ Literal.create(null, ObjectType(innerCls)),
+ NewInstance(
+ innerCls,
+ Seq(
+ toExternalString(
+ GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)),
+ GetStructField('b.struct(structType), 1, Some("b"))),
+ false,
+ ObjectType(innerCls))
+ )),
+ false,
+ ObjectType(cls))
+ compareExpressions(fromRowExpr, expected)
+ }
+
+ test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
+ val encoder = ExpressionEncoder.tuple(
+ ExpressionEncoder[StringLongClass],
+ ExpressionEncoder[Long])
+ val cls = classOf[StringLongClass]
+
+ val structType = new StructType().add("a", StringType).add("b", ByteType)
+ val attrs = Seq('a.struct(structType), 'b.int)
+ val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
+ val expected: Expression = NewInstance(
+ classOf[Tuple2[_, _]],
+ Seq(
+ NewInstance(
+ cls,
+ Seq(
+ toExternalString(GetStructField('a.struct(structType), 0, Some("a"))),
+ GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)),
+ false,
+ ObjectType(cls)),
+ 'b.int.cast(LongType)),
+ false,
+ ObjectType(classOf[Tuple2[_, _]]))
+ compareExpressions(fromRowExpr, expected)
+ }
+
+ private def toExternalString(e: Expression): Expression = {
+ Invoke(e, "toString", ObjectType(classOf[String]), Nil)
+ }
+
+ test("throw exception if real type is not compatible with encoder schema") {
+ val msg1 = intercept[AnalysisException] {
+ ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+ }.message
+ assert(msg1 ==
+ s"""
+ |Cannot up cast `b` from bigint to int as it may truncate
+ |The type path of the target object is:
+ |- field (class: "scala.Int", name: "b")
+ |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass"
+ |You can either add an explicit cast to the input data or choose a higher precision type
+ """.stripMargin.trim + " of the field in the target object")
+
+ val msg2 = intercept[AnalysisException] {
+ val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
+ ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
+ }.message
+ assert(msg2 ==
+ s"""
+ |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
+ |The type path of the target object is:
+ |- field (class: "scala.Long", name: "b")
+ |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
+ |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass"
+ |You can either add an explicit cast to the input data or choose a higher precision type
+ """.stripMargin.trim + " of the field in the target object")
+ }
+
+ // test for leaf types
+ castSuccess[Int, Long]
+ castSuccess[java.sql.Date, java.sql.Timestamp]
+ castSuccess[Long, String]
+ castSuccess[Int, java.math.BigDecimal]
+ castSuccess[Long, java.math.BigDecimal]
+
+ castFail[Long, Int]
+ castFail[java.sql.Timestamp, java.sql.Date]
+ castFail[java.math.BigDecimal, Double]
+ castFail[Double, java.math.BigDecimal]
+ castFail[java.math.BigDecimal, Int]
+ castFail[String, Long]
+
+
+ private def castSuccess[T: TypeTag, U: TypeTag]: Unit = {
+ val from = ExpressionEncoder[T]
+ val to = ExpressionEncoder[U]
+ val catalystType = from.schema.head.dataType.simpleString
+ test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") {
+ to.resolve(from.schema.toAttributes, null)
+ }
+ }
+
+ private def castFail[T: TypeTag, U: TypeTag]: Unit = {
+ val from = ExpressionEncoder[T]
+ val to = ExpressionEncoder[U]
+ val catalystType = from.schema.head.dataType.simpleString
+ test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") {
+ intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+ }
+ }
+}
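The leaf-type matrix above translates directly into what `as[T]` accepts on the user side. A hedged illustration of both outcomes, reusing the suite's case classes and assuming an available `sqlContext` (not part of the diff):

```scala
import sqlContext.implicits._

// int -> Long is a widening cast, so this resolves fine:
Seq(("a", 1), ("b", 2)).toDF("a", "b").as[StringLongClass]

// bigint -> Int may truncate, so this fails analysis with
// "Cannot up cast `b` from bigint to int as it may truncate ...":
Seq(("a", 1L), ("b", 2L)).toDF("a", "b").as[StringIntClass]
```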
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 19dce5d1e2..c6d2bf07b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
checkAnswer(
ds.groupBy(_._1).agg(
sum(_._2),
- expr("sum(_2)").as[Int],
+ expr("sum(_2)").as[Long],
count("*")),
- ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
+ ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
}
test("typed aggregation: complex case") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a2c8d20156..542e4d6c43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Int]),
- ("a", 30), ("b", 3), ("c", 1))
+ ds.groupBy(_._1).agg(sum("_2").as[Long]),
+ ("a", 30L), ("b", 3L), ("c", 1L))
}
test("typed aggregation: expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
- ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+ ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+ ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
}
test("typed aggregation: expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
- ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+ ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
+ ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
}
test("typed aggregation: expr, expr, expr, expr") {
@@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(
ds.groupBy(_._1).agg(
- sum("_2").as[Int],
+ sum("_2").as[Long],
sum($"_2" + 1).as[Long],
count("*").as[Long],
avg("_2").as[Double]),
- ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+ ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0))
}
test("cogroup") {
@@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
}
+
+ test("change encoder with compatible schema") {
+ val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
+ assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
+ }
}