diff options
Diffstat (limited to 'sql/catalyst')
3 files changed, 25 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 97f28fad62..d2003fd689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan // TODO: don't swallow original stack trace if it exists @@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, - val startPosition: Option[Int] = None) + val startPosition: Option[Int] = None, + val plan: Option[LogicalPlan] = None) extends Exception with Serializable { def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d8f755a39c..902644e735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,7 +50,9 @@ object RowEncoder { inputObject: Expression, inputType: DataType): Expression = inputType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => inputObject + FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + + case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType) case udt: UserDefinedType[_] => val obj = NewInstance( @@ -137,6 +139,7 @@ object RowEncoder { private def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt + case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -150,19 +153,23 @@ object RowEncoder { private def constructorFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val field = BoundReference(i, f.dataType, f.nullable) + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + val field = BoundReference(i, dt, f.nullable) If( IsNull(field), - Literal.create(null, externalDataTypeFor(f.dataType)), + Literal.create(null, externalDataTypeFor(dt)), constructorFor(field) ) } - CreateExternalRow(fields) + CreateExternalRow(fields, schema) } private def constructorFor(input: Expression): Expression = input.dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => input + FloatType | DoubleType | BinaryType | CalendarIntervalType => input case udt: UserDefinedType[_] => val obj = NewInstance( @@ -216,7 +223,7 @@ object RowEncoder { "toScalaMap", keyData :: valueData :: Nil) - case StructType(fields) => + case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), @@ -225,6 +232,6 @@ object RowEncoder { } If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), - CreateExternalRow(convertedFields)) + CreateExternalRow(convertedFields, schema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 75ecbaa453..b95c5dd892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -388,6 +388,8 @@ case class MapObjects private( case a: ArrayType => (i: String) => s".getArray($i)" case _: MapType => (i: String) => s".getMap($i)" case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)" + case DateType => (i: String) => s".getInt($i)" } private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { @@ -485,7 +487,9 @@ case class MapObjects private( * * @param children A list of expression to use as content of the external row. */ -case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression { +case class CreateExternalRow(children: Seq[Expression], schema: StructType) + extends Expression with NonSQLExpression { + override def dataType: DataType = ObjectType(classOf[Row]) override def nullable: Boolean = false @@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val rowClass = classOf[GenericRow].getName + val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") + val schemaField = ctx.addReferenceObj("schema", schema) s""" boolean ${ev.isNull} = false; final Object[] $values = new Object[${children.size}]; @@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with } """ }.mkString("\n") + - s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" + s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);" } } |