path: root/sql/catalyst
author    Wenchen Fan <wenchen@databricks.com>    2016-05-18 18:06:38 -0700
committer Yin Huai <yhuai@databricks.com>         2016-05-18 18:06:38 -0700
commit    ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5 (patch)
tree      9b415f466b981db6e208801df9898944ea28dca0 /sql/catalyst
parent    32be51fba45f5e07a2a3520293c12dc7765a364d (diff)
[SPARK-15192][SQL] null check for SparkSession.createDataFrame
## What changes were proposed in this pull request?

This PR adds a null check in `SparkSession.createDataFrame`, so that we can make sure the passed-in rows match the given schema.

## How was this patch tested?

New tests in `DatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13008 from cloud-fan/row-encoder.
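For context, a minimal sketch of how the new check surfaces to users (the session setup, schema, and field names below are illustrative assumptions, not part of the patch):

```scala
// Minimal sketch of the behavior this patch adds (assumes a local Spark 2.x
// environment; the session, schema, and data here are made up for illustration).
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object NullCheckExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-check").getOrCreate()

    // "id" is declared non-nullable, but the first row carries a null for it.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)))
    val rows = Seq(Row(null, "a"), Row(1, "b"))

    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    // With this patch, the mismatch should surface as a RuntimeException whose
    // message names the offending field, along the lines of
    //   "The 0th field 'id' of input row cannot be null."
    // rather than an opaque failure deep inside the encoder.
    df.collect()

    spark.stop()
  }
}
```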
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala         10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala  4
4 files changed, 9 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cb9a62dfd4..c0fa220d34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection {
* Returns true if the value of this data type is same between internal and external.
*/
def isNativeType(dt: DataType): Boolean = dt match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => true
+ case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => true
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a5f39aaa23..71b39c54fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -70,8 +70,7 @@ object RowEncoder {
private def serializerFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
- case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+ case dt if ScalaReflection.isNativeType(dt) => inputObject
case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
@@ -151,7 +150,7 @@ object RowEncoder {
case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val fieldValue = serializerFor(
- GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
+ GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)),
f.dataType
)
if (f.nullable) {
@@ -193,7 +192,6 @@ object RowEncoder {
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
- case CalendarIntervalType => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -202,7 +200,6 @@ object RowEncoder {
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
- case _: NullType => ObjectType(classOf[java.lang.Object])
}
private def deserializerFor(schema: StructType): Expression = {
@@ -222,8 +219,7 @@ object RowEncoder {
}
private def deserializerFor(input: Expression): Expression = input.dataType match {
- case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType | CalendarIntervalType => input
+ case dt if ScalaReflection.isNativeType(dt) => input
case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 99f156a935..a38f1ec091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression {
- override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
+ override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7df6e06805..fc38369f38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -693,6 +693,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
case class GetExternalRowField(
child: Expression,
index: Int,
+ fieldName: String,
dataType: DataType) extends UnaryExpression with NonSQLExpression {
override def nullable: Boolean = false
@@ -716,7 +717,8 @@ case class GetExternalRowField(
}
if (${row.value}.isNullAt($index)) {
- throw new RuntimeException("The ${index}th field of input row cannot be null.");
+ throw new RuntimeException("The ${index}th field '$fieldName' of input row " +
+ "cannot be null.");
}
final ${ctx.javaType(dataType)} ${ev.value} = $getField;