author    Wenchen Fan <wenchen@databricks.com>   2016-06-05 15:59:52 -0700
committer Cheng Lian <lian@databricks.com>       2016-06-05 15:59:52 -0700
commit    30c4774f33fed63b7d400d220d710fb432f599a8 (patch)
tree      af6af08486b4e3917277317e289a17e3a75b5058 /sql
parent    8a9110510c9e4cbbcb0dede62cb4b9dd1c6bc8cc (diff)
[SPARK-15657][SQL] RowEncoder should validate the data type of input object
## What changes were proposed in this pull request?

This PR improves the error handling of `RowEncoder`. When we create a `RowEncoder` with a given schema, we should validate the data type of the input object: e.g. we should throw an exception when a field holds a boolean but is declared as a string column.

This PR also removes support for using `Product` as a valid external type of struct type. That support was added in https://github.com/apache/spark/pull/9712 but is incomplete; e.g. nested products and products inside arrays both fail. Since this feature was never officially supported, it is reasonable to remove it.

## How was this patch tested?

New tests in `RowEncoderSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13401 from cloud-fan/bug.
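A minimal sketch of the failure mode this validation catches, mirroring the new tests in this commit (the schema and values are illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// The schema declares a string column, but the Row carries a Boolean.
val schema = new StructType().add("flag", StringType)
val encoder = RowEncoder(schema)

// With this patch the mismatch fails fast with a descriptive RuntimeException,
// e.g. "java.lang.Boolean is not a valid external type for schema of string",
// instead of surfacing later as an opaque ClassCastException.
encoder.toRow(Row(true))
```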
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                                  | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala         | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 61
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala    | 47
4 files changed, 95 insertions(+), 40 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index a257b831dd..391001de26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -304,15 +304,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getStruct(i: Int): Row = {
- // Product and Row both are recognized as StructType in a Row
- val t = get(i)
- if (t.isInstanceOf[Product]) {
- Row.fromTuple(t.asInstanceOf[Product])
- } else {
- t.asInstanceOf[Row]
- }
- }
+ def getStruct(i: Int): Row = getAs[Row](i)
/**
* Returns the value at position i.
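With the `Product` shim removed from `getStruct`, nested structs are supplied and read back strictly as `Row`s. A brief sketch of the new contract (values illustrative):

```scala
import org.apache.spark.sql.Row

// A nested struct is a plain Row; tuples and case classes are no longer unwrapped.
val outer = Row(Row(100, "test", 0.123))
val inner: Row = outer.getStruct(0) // now simply outer.getAs[Row](0)

// Supplying a Product where a struct is expected is rejected by RowEncoder:
// Row((100, "test", 0.123)) would fail with something like
// "scala.Tuple3 is not a valid external type for schema of struct<...>"
```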
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6cd7b34ceb..67fca153b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
* MapType -> scala.collection.Map
- * StructType -> org.apache.spark.sql.Row or Product
+ * StructType -> org.apache.spark.sql.Row
* }}}
*/
object RowEncoder {
@@ -121,11 +121,15 @@ object RowEncoder {
case t @ ArrayType(et, _) => et match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ // TODO: validate input type for primitive array.
NewInstance(
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
- case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+ case _ => MapObjects(
+ element => serializerFor(ValidateExternalType(element, et), et),
+ inputObject,
+ ObjectType(classOf[Object]))
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -151,8 +155,9 @@ object RowEncoder {
case StructType(fields) =>
val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
val fieldValue = serializerFor(
- GetExternalRowField(
- inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+ ValidateExternalType(
+ GetExternalRowField(inputObject, index, field.name),
+ field.dataType),
field.dataType)
val convertedField = if (field.nullable) {
If(
@@ -183,7 +188,7 @@ object RowEncoder {
* can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
* `org.apache.spark.sql.types.Decimal`.
*/
- private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+ def externalDataTypeForInput(dt: DataType): DataType = dt match {
// In order to support both Decimal and java/scala BigDecimal in external row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@ object RowEncoder {
case _ => externalDataTypeFor(dt)
}
- private def externalDataTypeFor(dt: DataType): DataType = dt match {
+ def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
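Making `externalDataTypeForInput` public lets `ValidateExternalType` reuse it; the `java.lang.Object` widening for decimals is what keeps all three external representations acceptable. A hedged sketch, assuming the decimal handling in this commit:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("d", DecimalType(10, 2))
val encoder = RowEncoder(schema)

// All three decimal representations pass ValidateExternalType, because
// externalDataTypeForInput maps DecimalType to java.lang.Object:
encoder.toRow(Row(BigDecimal("1.23")))               // scala.math.BigDecimal
encoder.toRow(Row(new java.math.BigDecimal("1.23"))) // java.math.BigDecimal
encoder.toRow(Row(Decimal("1.23")))                  // org.apache.spark.sql.types.Decimal
```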
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d4c71bffe8..87c8a2e54a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
case class GetExternalRowField(
child: Expression,
index: Int,
- fieldName: String,
- dataType: DataType) extends UnaryExpression with NonSQLExpression {
+ fieldName: String) extends UnaryExpression with NonSQLExpression {
override def nullable: Boolean = false
+ override def dataType: DataType = ObjectType(classOf[Object])
+
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val row = child.genCode(ctx)
-
- val getField = dataType match {
- case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
- case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
- }
-
val code = s"""
${row.code}
@@ -720,8 +716,55 @@ case class GetExternalRowField(
"cannot be null.");
}
- final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+ final Object ${ev.value} = ${row.value}.get($index);
"""
ev.copy(code = code, isNull = "false")
}
}
+
+/**
+ * Validates the actual data type of input expression at runtime. If it doesn't match the
+ * expectation, throw an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
+ extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+ override def nullable: Boolean = child.nullable
+
+ override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val input = child.genCode(ctx)
+ val obj = input.value
+
+ val typeCheck = expected match {
+ case _: DecimalType =>
+ Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+ .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+ case _: ArrayType =>
+ s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+ case _ =>
+ s"$obj instanceof ${ctx.boxedType(dataType)}"
+ }
+
+ val code = s"""
+ ${input.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${input.isNull}) {
+ if ($typeCheck) {
+ ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+ } else {
+ throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+ "external type for schema of ${expected.simpleString}");
+ }
+ }
+
+ """
+ ev.copy(code = code, isNull = input.isNull)
+ }
+}
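For illustration, a hedged sketch of how the serializer composes these two expressions for a single int column (the field name, index, and bound reference are made up for the example):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.{GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// The external Row arrives as an object at ordinal 0 of the input.
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)

// GetExternalRowField now returns the raw field as java.lang.Object;
// ValidateExternalType asserts it matches the declared schema type, so a
// mismatch yields "<actual class> is not a valid external type for schema of int".
val validated = ValidateExternalType(
  GetExternalRowField(inputObject, index = 0, fieldName = "a"),
  IntegerType)
```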
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 16abde064f..2e513ea22c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))
- test(s"encode/decode: Product") {
- val schema = new StructType()
- .add("structAsProduct",
- new StructType()
- .add("int", IntegerType)
- .add("string", StringType)
- .add("double", DoubleType))
-
- val encoder = RowEncoder(schema).resolveAndBind()
-
- val input: Row = Row((100, "test", 0.123))
- val row = encoder.toRow(input)
- val convertedBack = encoder.fromRow(row)
- assert(input.getStruct(0) == convertedBack.getStruct(0))
- }
-
test("encode/decode decimal type") {
val schema = new StructType()
.add("int", IntegerType)
@@ -232,6 +216,37 @@ class RowEncoderSuite extends SparkFunSuite {
assert(e.getMessage.contains("top level row object"))
}
+ test("RowEncoder should validate external type") {
+ val e1 = intercept[RuntimeException] {
+ val schema = new StructType().add("a", IntegerType)
+ val encoder = RowEncoder(schema)
+ encoder.toRow(Row(1.toShort))
+ }
+ assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+ val e2 = intercept[RuntimeException] {
+ val schema = new StructType().add("a", StringType)
+ val encoder = RowEncoder(schema)
+ encoder.toRow(Row(1))
+ }
+ assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+ val e3 = intercept[RuntimeException] {
+ val schema = new StructType().add("a",
+ new StructType().add("b", IntegerType).add("c", StringType))
+ val encoder = RowEncoder(schema)
+ encoder.toRow(Row(1 -> "a"))
+ }
+ assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+ val e4 = intercept[RuntimeException] {
+ val schema = new StructType().add("a", ArrayType(TimestampType))
+ val encoder = RowEncoder(schema)
+ encoder.toRow(Row(Array("a")))
+ }
+ assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+ }
+
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema).resolveAndBind()