aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-20 12:04:42 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-20 12:04:42 -0800
commit3b9d2a347f9c796b90852173d84189834e499e25 (patch)
treedc88c8bb396fb031b28646a32040f46bb20088c9
parent60bfb113325c71491f8dcf98b6036b0caa2144fe (diff)
downloadspark-3b9d2a347f9c796b90852173d84189834e499e25.tar.gz
spark-3b9d2a347f9c796b90852173d84189834e499e25.tar.bz2
spark-3b9d2a347f9c796b90852173d84189834e499e25.zip
[SPARK-11819][SQL] nice error message for missing encoder
before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is not supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan <wenchen@databricks.com> Closes #9810 from cloud-fan/error-message.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala90
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala62
2 files changed, 129 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 33ae700706..918050b531 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
case _ =>
- val className: String = tpe.erasure.typeSymbol.asClass.fullName
+ val className = getClassNameFromType(tpe)
className match {
case "scala.Array" =>
val TypeRef(_, _, Seq(elementType)) = tpe
@@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection {
}
}
- /** Returns expressions for extracting all the fields from the given type. */
+ /**
+ * Returns expressions for extracting all the fields from the given type.
+ *
+ * If the given type is not supported, i.e. there is no encoder can be built for this type,
+ * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
+ * the type path walked so far and which class we are not supporting.
+ * There are 4 kinds of type path:
+ * * the root type: `root class: "abc.xyz.MyClass"`
+ * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
+ * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
+ * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
+ */
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
- extractorFor(inputObject, localTypeOf[T]) match {
+ val tpe = localTypeOf[T]
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+ extractorFor(inputObject, tpe, walkedTypePath) match {
case s: CreateNamedStruct => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
@@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection {
/** Helper for extracting internal fields from a case class. */
private def extractorFor(
inputObject: Expression,
- tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
+ tpe: `Type`,
+ walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
+
+ def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
+ val externalDataType = dataTypeFor(elementType)
+ val Schema(catalystType, nullable) = silentSchemaFor(elementType)
+ if (isNativeType(catalystType)) {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(catalystType, nullable))
+ } else {
+ val clsName = getClassNameFromType(elementType)
+ val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
+ // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
+ // to trigger the type check.
+ extractorFor(inputObject, elementType, newPath)
+
+ MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
+ }
+ }
+
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
@@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection {
// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
- val className: String = optType.erasure.typeSymbol.asClass.fullName
+ val className = getClassNameFromType(optType)
val classObj = Utils.classForName(className)
val optionObjectType = ObjectType(classObj)
+ val newPath = s"""- option value class: "$className"""" +: walkedTypePath
val unwrapped = UnwrapOption(optionObjectType, inputObject)
expressions.If(
IsNull(unwrapped),
- expressions.Literal.create(null, schemaFor(optType).dataType),
- extractorFor(unwrapped, optType))
+ expressions.Literal.create(null, silentSchemaFor(optType).dataType),
+ extractorFor(unwrapped, optType, newPath))
}
case t if t <:< localTypeOf[Product] =>
@@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection {
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ val clsName = getClassNameFromType(fieldType)
+ val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+
+ expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
})
case t if t <:< localTypeOf[Array[_]] =>
@@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection {
Invoke(inputObject, "booleanValue", BooleanType)
case other =>
- throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+ throw new UnsupportedOperationException(
+ s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
}
}
-
- private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
- val externalDataType = dataTypeFor(elementType)
- val Schema(catalystType, nullable) = schemaFor(elementType)
- if (isNativeType(catalystType)) {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(catalystType, nullable))
- } else {
- MapObjects(extractorFor(_, elementType), input, externalDataType)
- }
- }
}
/**
@@ -561,7 +588,7 @@ trait ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
- val className: String = tpe.erasure.typeSymbol.asClass.fullName
+ val className = getClassNameFromType(tpe)
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
@@ -638,6 +665,23 @@ trait ScalaReflection {
}
/**
+ * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+ *
+ * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
+ * `NullType` silently instead.
+ */
+ private def silentSchemaFor(tpe: `Type`): Schema = try {
+ schemaFor(tpe)
+ } catch {
+ case _: UnsupportedOperationException => Schema(NullType, nullable = true)
+ }
+
+ /** Returns the full class name for a type. */
+ private def getClassNameFromType(tpe: `Type`): String = {
+ tpe.erasure.typeSymbol.asClass.fullName
+ }
+
+ /**
* Returns classes of input parameters of scala function object.
*/
def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
index 0b2a10bb04..8c766ef829 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala
@@ -17,9 +17,22 @@
package org.apache.spark.sql.catalyst.encoders
+import scala.reflect.ClassTag
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
+class NonEncodable(i: Int)
+
+case class ComplexNonEncodable1(name1: NonEncodable)
+
+case class ComplexNonEncodable2(name2: ComplexNonEncodable1)
+
+case class ComplexNonEncodable3(name3: Option[NonEncodable])
+
+case class ComplexNonEncodable4(name4: Array[NonEncodable])
+
+case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])
class EncoderErrorMessageSuite extends SparkFunSuite {
@@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] }
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
}
+
+ test("nice error message for missing encoder") {
+ val errorMsg1 =
+ intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
+ assert(errorMsg1.contains(
+ s"""root class: "${clsName[ComplexNonEncodable1]}""""))
+ assert(errorMsg1.contains(
+ s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
+
+ val errorMsg2 =
+ intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
+ assert(errorMsg2.contains(
+ s"""root class: "${clsName[ComplexNonEncodable2]}""""))
+ assert(errorMsg2.contains(
+ s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
+ assert(errorMsg1.contains(
+ s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
+
+ val errorMsg3 =
+ intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
+ assert(errorMsg3.contains(
+ s"""root class: "${clsName[ComplexNonEncodable3]}""""))
+ assert(errorMsg3.contains(
+ s"""field (class: "scala.Option", name: "name3")"""))
+ assert(errorMsg3.contains(
+ s"""option value class: "${clsName[NonEncodable]}""""))
+
+ val errorMsg4 =
+ intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
+ assert(errorMsg4.contains(
+ s"""root class: "${clsName[ComplexNonEncodable4]}""""))
+ assert(errorMsg4.contains(
+ s"""field (class: "scala.Array", name: "name4")"""))
+ assert(errorMsg4.contains(
+ s"""array element class: "${clsName[NonEncodable]}""""))
+
+ val errorMsg5 =
+ intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
+ assert(errorMsg5.contains(
+ s"""root class: "${clsName[ComplexNonEncodable5]}""""))
+ assert(errorMsg5.contains(
+ s"""field (class: "scala.Option", name: "name5")"""))
+ assert(errorMsg5.contains(
+ s"""option value class: "scala.Array""""))
+ assert(errorMsg5.contains(
+ s"""array element class: "${clsName[NonEncodable]}""""))
+ }
+
+ private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
}