aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala11
3 files changed, 25 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 97f28fad62..d2003fd689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
// TODO: don't swallow original stack trace if it exists
@@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi
class AnalysisException protected[sql] (
val message: String,
val line: Option[Int] = None,
- val startPosition: Option[Int] = None)
+ val startPosition: Option[Int] = None,
+ val plan: Option[LogicalPlan] = None)
extends Exception with Serializable {
def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d8f755a39c..902644e735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,7 +50,9 @@ object RowEncoder {
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => inputObject
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+
+ case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -137,6 +139,7 @@ object RowEncoder {
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
+ case CalendarIntervalType => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -150,19 +153,23 @@ object RowEncoder {
private def constructorFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
- val field = BoundReference(i, f.dataType, f.nullable)
+ val dt = f.dataType match {
+ case p: PythonUserDefinedType => p.sqlType
+ case other => other
+ }
+ val field = BoundReference(i, dt, f.nullable)
If(
IsNull(field),
- Literal.create(null, externalDataTypeFor(f.dataType)),
+ Literal.create(null, externalDataTypeFor(dt)),
constructorFor(field)
)
}
- CreateExternalRow(fields)
+ CreateExternalRow(fields, schema)
}
private def constructorFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => input
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => input
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -216,7 +223,7 @@ object RowEncoder {
"toScalaMap",
keyData :: valueData :: Nil)
- case StructType(fields) =>
+ case schema @ StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
@@ -225,6 +232,6 @@ object RowEncoder {
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
- CreateExternalRow(convertedFields))
+ CreateExternalRow(convertedFields, schema))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 75ecbaa453..b95c5dd892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -388,6 +388,8 @@ case class MapObjects private(
case a: ArrayType => (i: String) => s".getArray($i)"
case _: MapType => (i: String) => s".getMap($i)"
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+ case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
+ case DateType => (i: String) => s".getInt($i)"
}
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
@@ -485,7 +487,9 @@ case class MapObjects private(
*
* @param children A list of expression to use as content of the external row.
*/
-case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
+case class CreateExternalRow(children: Seq[Expression], schema: StructType)
+ extends Expression with NonSQLExpression {
+
override def dataType: DataType = ObjectType(classOf[Row])
override def nullable: Boolean = false
@@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
- val rowClass = classOf[GenericRow].getName
+ val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
+ val schemaField = ctx.addReferenceObj("schema", schema)
s"""
boolean ${ev.isNull} = false;
final Object[] $values = new Object[${children.size}];
@@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
}
"""
}.mkString("\n") +
- s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+ s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
}
}