aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala51
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala64
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala8
3 files changed, 54 insertions, 69 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c4af284f73..1c7720afe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
}
}
- val array = Invoke(
- MapObjects(mapFunction, getPath, dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- val wrappedArray = StaticInvoke(
- scala.collection.mutable.WrappedArray.getClass,
- ObjectType(classOf[Seq[_]]),
- "make",
- array :: Nil)
-
- if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
- wrappedArray
- } else {
- // Convert to another type using `to`
- val cls = mirror.runtimeClass(t.typeSymbol.asClass)
- import scala.collection.generic.CanBuildFrom
- import scala.reflect.ClassTag
-
- // Some canBuildFrom methods take an implicit ClassTag parameter
- val cbfParams = try {
- cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
- StaticInvoke(
- ClassTag.getClass,
- ObjectType(classOf[ClassTag[_]]),
- "apply",
- StaticInvoke(
- cls,
- ObjectType(classOf[Class[_]]),
- "getClass"
- ) :: Nil
- ) :: Nil
- } catch {
- case _: NoSuchMethodException => Nil
- }
-
- Invoke(
- wrappedArray,
- "to",
- ObjectType(cls),
- StaticInvoke(
- cls,
- ObjectType(classOf[CanBuildFrom[_, _, _]]),
- "canBuildFrom",
- cbfParams
- ) :: Nil
- )
+ val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
+ case NoSymbol => classOf[Seq[_]]
+ case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
+ MapObjects(mapFunction, getPath, dataType, Some(cls))
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 771ac28e51..bb584f7d08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.Modifier
+import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -429,24 +430,34 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ * or None (returning ArrayType)
*/
def apply(
function: Expression => Expression,
inputData: Expression,
- elementType: DataType): MapObjects = {
- val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
- val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+ elementType: DataType,
+ customCollectionCls: Option[Class[_]] = None): MapObjects = {
+ val id = curId.getAndIncrement()
+ val loopValue = s"MapObjects_loopValue$id"
+ val loopIsNull = s"MapObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
- MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+ val builderValue = s"MapObjects_builderValue$id"
+ MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+ customCollectionCls, builderValue)
}
}
/**
* Applies the given expression to every element of a collection of items, returning the result
- * as an ArrayType. This is similar to a typical map operation, but where the lambda function
- * is expressed using catalyst expressions.
+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
+ * function is expressed using catalyst expressions.
+ *
+ * The type of the result is determined as follows:
+ * - ArrayType - when customCollectionCls is None
+ * - ObjectType(collection) - when customCollectionCls contains a collection class
*
- * The following collection ObjectTypes are currently supported:
+ * The following collection ObjectTypes are currently supported on input:
* Seq, Array, ArrayData, java.util.List
*
* @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -458,13 +469,19 @@ object MapObjects {
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
* to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ * or None (returning ArrayType)
+ * @param builderValue The name of the builder variable used to construct the resulting collection
+ * (used only when returning ObjectType)
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
- inputData: Expression) extends Expression with NonSQLExpression {
+ inputData: Expression,
+ customCollectionCls: Option[Class[_]],
+ builderValue: String) extends Expression with NonSQLExpression {
override def nullable: Boolean = inputData.nullable
@@ -474,7 +491,8 @@ case class MapObjects private(
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def dataType: DataType =
- ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
+ customCollectionCls.map(ObjectType.apply).getOrElse(
+ ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
@@ -557,15 +575,33 @@ case class MapObjects private(
case _ => s"$loopIsNull = $loopValue == null;"
}
+ val (initCollection, addElement, getResult): (String, String => String, String) =
+ customCollectionCls match {
+ case Some(cls) =>
+ // collection
+ val collObjectName = s"${cls.getName}$$.MODULE$$"
+ val getBuilderVar = s"$collObjectName.newBuilder()"
+
+ (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+ $builderValue.sizeHint($dataLength);""",
+ genValue => s"$builderValue.$$plus$$eq($genValue);",
+ s"(${cls.getName}) $builderValue.result();")
+ case None =>
+ // array
+ (s"""$convertedType[] $convertedArray = null;
+ $convertedArray = $arrayConstructor;""",
+ genValue => s"$convertedArray[$loopIndex] = $genValue;",
+ s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+ }
+
val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${genInputData.isNull}) {
$determineCollectionType
- $convertedType[] $convertedArray = null;
int $dataLength = $getLength;
- $convertedArray = $arrayConstructor;
+ $initCollection
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
@@ -574,15 +610,15 @@ case class MapObjects private(
${genFunction.code}
if (${genFunction.isNull}) {
- $convertedArray[$loopIndex] = null;
+ ${addElement("null")}
} else {
- $convertedArray[$loopIndex] = $genFunctionValue;
+ ${addElement(genFunctionValue)}
}
$loopIndex += 1;
}
- ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+ ${ev.value} = $getResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 650a35398f..70ad064f93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
-
- // Check whether conversion is skipped when using WrappedArray[_] supertype
- // (would otherwise needlessly add overhead)
- import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
- val seqDeserializer = deserializerFor[Seq[Int]]
- assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
- scala.collection.mutable.WrappedArray.getClass)
- assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}
private val dataTypeForComplexData = dataTypeFor[ComplexData]