 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                   | 66
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                 | 19
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala |  2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala       | 32
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala   | 52
 sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala                   |  2
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                   |  9
 7 files changed, 141 insertions(+), 41 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index da37eb00dc..206ae2f0e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
- private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
+ private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection {
 * is [a: int, b: long], then we will hit a runtime error saying that we can't construct class
 * `Data` with int and long, because we lost the information that `b` should be a string.
*
- * This method help us "remember" the required data type by adding a `UpCast`. Note that we
- * don't need to cast struct type because there must be `UnresolvedExtractValue` or
- * `GetStructField` wrapping it, thus we only need to handle leaf type.
+ * This method helps us "remember" the required data type by adding an `UpCast`. Note that
+ * we only need to do this for leaf nodes.
*/
def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
+ case _: ArrayType => expr
+ // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
+ // it's not trivial to support by-name resolution for StructType inside MapType.
case _ => UpCast(expr, expected, walkedTypePath)
}
@@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(_, elementNullable) = schemaFor(elementType)
+ val className = getClassNameFromType(elementType)
+ val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
- // TODO: add runtime null check for primitive array
- val primitiveMethod = elementType match {
- case t if t <:< definitions.IntTpe => Some("toIntArray")
- case t if t <:< definitions.LongTpe => Some("toLongArray")
- case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
- case t if t <:< definitions.FloatTpe => Some("toFloatArray")
- case t if t <:< definitions.ShortTpe => Some("toShortArray")
- case t if t <:< definitions.ByteTpe => Some("toByteArray")
- case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
- case _ => None
+ val mapFunction: Expression => Expression = p => {
+ val converter = deserializerFor(elementType, Some(p), newTypePath)
+ if (elementNullable) {
+ converter
+ } else {
+ AssertNotNull(converter, newTypePath)
+ }
}
- primitiveMethod.map { method =>
- Invoke(getPath, method, arrayClassFor(elementType))
- }.getOrElse {
- val className = getClassNameFromType(elementType)
- val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
- Invoke(
- MapObjects(
- p => deserializerFor(elementType, Some(p), newTypePath),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- arrayClassFor(elementType))
+ val arrayData = UnresolvedMapObjects(mapFunction, getPath)
+ val arrayCls = arrayClassFor(elementType)
+
+ if (elementNullable) {
+ Invoke(arrayData, "array", arrayCls)
+ } else {
+ val primitiveMethod = elementType match {
+ case t if t <:< definitions.IntTpe => "toIntArray"
+ case t if t <:< definitions.LongTpe => "toLongArray"
+ case t if t <:< definitions.DoubleTpe => "toDoubleArray"
+ case t if t <:< definitions.FloatTpe => "toFloatArray"
+ case t if t <:< definitions.ShortTpe => "toShortArray"
+ case t if t <:< definitions.ByteTpe => "toByteArray"
+ case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+ case other => throw new IllegalStateException("expect primitive array element type " +
+ "but got " + other)
+ }
+ Invoke(arrayData, primitiveMethod, arrayCls)
}
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
+ val Schema(_, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val mapFunction: Expression => Expression = p => {
val converter = deserializerFor(elementType, Some(p), newTypePath)
- if (nullable) {
+ if (elementNullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
@@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection {
case NoSymbol => classOf[Seq[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
- MapObjects(mapFunction, getPath, dataType, Some(cls))
+ UnresolvedMapObjects(mapFunction, getPath, Some(cls))
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
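
The net effect of the ScalaReflection changes above is that array and Seq element deserializers are now wrapped in UnresolvedMapObjects, leaving the element type open until the analyzer sees the actual input schema. A minimal sketch of the user-visible behavior, adapted from the DatasetSuite test added at the end of this patch (assumes a local SparkSession `spark` with spark.implicits._ in scope):

    import org.apache.spark.sql.functions._

    case class ClassData(a: String, b: Int)

    // The struct stores `b` before `a`; struct fields inside the array are now
    // matched to ClassData by name rather than by position.
    val df = spark.range(3)
      .select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
    val ds = df.as[Seq[ClassData]]
    ds.collect() // Array(Seq(ClassData("a", 0)), Seq(ClassData("a", 1)), Seq(ClassData("a", 2)))
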
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d53d2424a..c698ca6a83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
@@ -2227,8 +2227,21 @@ class Analyzer(
validateTopLevelTupleFields(deserializer, inputs)
val resolved = resolveExpression(
deserializer, LocalRelation(inputs), throws = true)
- validateNestedTupleFields(resolved)
- resolved
+ val result = resolved transformDown {
+ case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
+ inputData.dataType match {
+ case ArrayType(et, _) =>
+ val expr = MapObjects(func, inputData, et, cls) transformUp {
+ case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+ ExtractValue(child, fieldName, resolver)
+ }
+ expr
+ case other =>
+ throw new AnalysisException("need an array field but got " + other.simpleString)
+ }
+ }
+ validateNestedTupleFields(result)
+ result
}
}
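
This rule fires once inputData is resolved: an ArrayType input turns the placeholder into a concrete MapObjects carrying the input's element type, and any UnresolvedExtractValue inside the lambda is then resolved by field name; any other input type fails analysis. A sketch mirroring the EncoderResolutionSuite tests added below (assumes the catalyst expression test DSL is available):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.types._

    case class StringIntClass(a: String, b: Int)
    case class ArrayClass(arr: Seq[StringIntClass])

    val encoder = ExpressionEncoder[ArrayClass]

    // Resolves: the input is an array whose element struct has the required fields.
    encoder.resolveAndBind(Seq('arr.array(new StructType().add("a", "string").add("b", "int"))))

    // Fails analysis with the new message, because the input is not an array:
    //   org.apache.spark.sql.AnalysisException: need an array field but got int
    encoder.resolveAndBind(Seq('arr.int))
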
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index de1594d119..ef88cfb543 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -68,7 +68,7 @@ object ExtractValue {
case StructType(_) =>
s"Field name should be String Literal, but it's $extraction"
case other =>
- s"Can't extract value from $child"
+ s"Can't extract value from $child: need struct type but got ${other.simpleString}"
}
throw new AnalysisException(errorMsg)
}
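
With the message change above, extracting a field from a non-struct column reports the offending type instead of only naming the child. A quick illustration (assumes spark.implicits._; the attribute in the message carries an expression id such as value#0):

    // Column `value` is an int, so the nested field access is rejected with:
    //   org.apache.spark.sql.AnalysisException:
    //   Can't extract value from value#0: need struct type but got int
    Seq(1, 2).toDS().select($"value.foo")
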
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index bb584f7d08..00e2ac91e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -448,6 +448,17 @@ object MapObjects {
}
}
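+/**
+ * A placeholder version of [[MapObjects]] for cases where the element type of `child` is
+ * not yet known: the analyzer rule above replaces it with a resolved [[MapObjects]] once
+ * `child` has been resolved to an ArrayType.
+ */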
+case class UnresolvedMapObjects(
+ function: Expression => Expression,
+ child: Expression,
+ customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
+ override lazy val resolved = false
+
+ override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
+ throw new UnsupportedOperationException("not resolved")
+ }
+}
+
/**
* Applies the given expression to every element of a collection of items, returning the result
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
@@ -581,17 +592,24 @@ case class MapObjects private(
// collection
val collObjectName = s"${cls.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"
-
- (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
- $builderValue.sizeHint($dataLength);""",
+ (
+ s"""
+ ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+ $builderValue.sizeHint($dataLength);
+ """,
genValue => s"$builderValue.$$plus$$eq($genValue);",
- s"(${cls.getName}) $builderValue.result();")
+ s"(${cls.getName}) $builderValue.result();"
+ )
case None =>
// array
- (s"""$convertedType[] $convertedArray = null;
- $convertedArray = $arrayConstructor;""",
+ (
+ s"""
+ $convertedType[] $convertedArray = null;
+ $convertedArray = $arrayConstructor;
+ """,
genValue => s"$convertedArray[$loopIndex] = $genValue;",
- s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+ s"new ${classOf[GenericArrayData].getName}($convertedArray);"
+ )
}
val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 802397d50e..e5a3e1fd37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)
+case class ArrayClass(arr: Seq[StringIntClass])
+
+case class NestedArrayClass(nestedArr: Array[ArrayClass])
+
class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")
@@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest {
encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}
+ test("real type doesn't match encoder schema but they are compatible: array") {
+ val encoder = ExpressionEncoder[ArrayClass]
+ val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
+ val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
+ encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+ }
+
+ test("real type doesn't match encoder schema but they are compatible: nested array") {
+ val encoder = ExpressionEncoder[NestedArrayClass]
+ val et = new StructType().add("arr", ArrayType(
+ new StructType().add("a", "int").add("b", "int").add("c", "int")))
+ val attrs = Seq('nestedArr.array(et))
+ val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
+ val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
+ encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+ }
+
+ test("the real type is not compatible with encoder schema: non-array field") {
+ val encoder = ExpressionEncoder[ArrayClass]
+ val attrs = Seq('arr.int)
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+ "need an array field but got int")
+ }
+
+ test("the real type is not compatible with encoder schema: array element type") {
+ val encoder = ExpressionEncoder[ArrayClass]
+ val attrs = Seq('arr.array(new StructType().add("c", "int")))
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+ "No such struct field a in c")
+ }
+
+ test("the real type is not compatible with encoder schema: nested array element type") {
+ val encoder = ExpressionEncoder[NestedArrayClass]
+
+ withClue("inner element is not array") {
+ val attrs = Seq('nestedArr.array(new StructType().add("arr", "int")))
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+ "need an array field but got int")
+ }
+
+ withClue("nested array element type is not compatible") {
+ val attrs = Seq('nestedArr.array(new StructType()
+ .add("arr", ArrayType(new StructType().add("c", "int")))))
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+ "No such struct field a in c")
+ }
+ }
+
test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index 174378304d..e266ae55cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
extends Aggregator[T, (Boolean, T), T] {
- private val encoder = implicitly[Encoder[T]]
+ @transient private val encoder = implicitly[Encoder[T]]
override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
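
ReduceAggregator is constructed on the driver (KeyValueGroupedDataset.reduceGroups builds one internally) and is then shipped to executors inside the task closure; marking the encoder @transient keeps the driver-side encoder out of that serialized closure. That reading of the intent is inferred, since the patch itself carries no comment. Usage sketch (assumes spark.implicits._):

    // reduceGroups instantiates a ReduceAggregator[String] under the hood; with
    // @transient the captured encoder is not serialized along with the tasks.
    val ds = Seq("a", "bc", "de").toDS()
    ds.groupByKey(_.length).reduceGroups(_ + _).collect()
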
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 68e071a1a6..5b5cd28ad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}
+ test("as seq of case class - reorder fields by name") {
+ val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
+ val ds = df.as[Seq[ClassData]]
+ assert(ds.collect() === Array(
+ Seq(ClassData("a", 0)),
+ Seq(ClassData("a", 1)),
+ Seq(ClassData("a", 2))))
+ }
+
test("map") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(