diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-05-18 18:06:38 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-05-18 18:06:38 -0700 |
commit | ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5 (patch) | |
tree | 9b415f466b981db6e208801df9898944ea28dca0 /sql/catalyst | |
parent | 32be51fba45f5e07a2a3520293c12dc7765a364d (diff) | |
download | spark-ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5.tar.gz spark-ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5.tar.bz2 spark-ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5.zip |
[SPARK-15192][SQL] null check for SparkSession.createDataFrame
## What changes were proposed in this pull request?
This PR adds null check in `SparkSession.createDataFrame`, so that we can make sure the passed in rows matches the given schema.
## How was this patch tested?
new tests in `DatasetSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13008 from cloud-fan/row-encoder.
Diffstat (limited to 'sql/catalyst')
4 files changed, 9 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cb9a62dfd4..c0fa220d34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection { * Returns true if the value of this data type is same between internal and external. */ def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index a5f39aaa23..71b39c54fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -70,8 +70,7 @@ object RowEncoder { private def serializerFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + case dt if ScalaReflection.isNativeType(dt) => inputObject case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) @@ -151,7 +150,7 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => val fieldValue = serializerFor( - GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), + GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)), f.dataType ) if (f.nullable) { @@ -193,7 +192,6 @@ object RowEncoder { private def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt - case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -202,7 +200,6 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) - case _: NullType => ObjectType(classOf[java.lang.Object]) } private def deserializerFor(schema: StructType): Expression = { @@ -222,8 +219,7 @@ object RowEncoder { } private def deserializerFor(input: Expression): Expression = input.dataType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => input + case dt if ScalaReflection.isNativeType(dt) => input case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 99f156a935..a38f1ec091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression { - override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" + override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7df6e06805..fc38369f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -693,6 +693,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) case class GetExternalRowField( child: Expression, index: Int, + fieldName: String, dataType: DataType) extends UnaryExpression with NonSQLExpression { override def nullable: Boolean = false @@ -716,7 +717,8 @@ case class GetExternalRowField( } if (${row.value}.isNullAt($index)) { - throw new RuntimeException("The ${index}th field of input row cannot be null."); + throw new RuntimeException("The ${index}th field '$fieldName' of input row " + + "cannot be null."); } final ${ctx.javaType(dataType)} ${ev.value} = $getField; |