 mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 16
 mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 15
 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 29
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 39
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 28
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 10
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 20
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 59
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala | 121
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala | 59
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 21
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 4
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 19
 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala | 4
 sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala | 30
 sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 3
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 12
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 5
 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 8
 sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 30
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 28
 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 12
 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 2
 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 2
 34 files changed, 430 insertions(+), 181 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index d82ba2456d..88914fa875 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
- row.update(3, sm.colPtrs.toSeq)
- row.update(4, sm.rowIndices.toSeq)
- row.update(5, sm.values.toSeq)
+ row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
+ row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
+ row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, sm.isTransposed)
case dm: DenseMatrix =>
@@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
- row.update(5, dm.values.toSeq)
+ row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, dm.isTransposed)
}
row
@@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
- val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
+ val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
- val colPtrs =
- row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
- val rowIndices =
- row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
+ val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
+ val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 23c2c16d68..89a1818db0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
- row.update(2, indices.toSeq)
- row.update(3, values.toSeq)
+ row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
+ row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
- row.update(3, values.toSeq)
+ row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
}
}
@@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
tpe match {
case 0 =>
val size = row.getInt(1)
- val indices =
- row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray
- val values =
- row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+ val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
+ val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new SparseVector(size, indices, values)
case 1 =>
- val values =
- row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+ val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new DenseVector(values)
}
}
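
Editor's note, not part of the patch: a hand-worked sketch of the boxing/unboxing pattern that MatrixUDT and VectorUDT now use to move primitive arrays through the GenericArrayData wrapper introduced later in this change.

    import org.apache.spark.sql.types.GenericArrayData

    // Box a primitive array into the Catalyst representation, then read it back out.
    val values: Array[Double] = Array(1.0, 2.0, 3.0)
    val serialized = new GenericArrayData(values.map(_.asInstanceOf[Any]))
    val deserialized: Array[Double] = serialized.toArray().map(_.asInstanceOf[Double])
    assert(deserialized.sameElements(values))
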
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index bc345dcd00..f7cea13688 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.ArrayData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -50,4 +51,5 @@ public interface SpecializedGetters {
InternalRow getStruct(int ordinal, int numFields);
+ ArrayData getArray(int ordinal);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d1d89a1f48..22452c0f20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -55,7 +55,6 @@ object CatalystTypeConverters {
private def isWholePrimitive(dt: DataType): Boolean = dt match {
case dt if isPrimitive(dt) => true
- case ArrayType(elementType, _) => isWholePrimitive(elementType)
case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
case _ => false
}
@@ -154,39 +153,41 @@ object CatalystTypeConverters {
/** Converter for arrays, sequences, and Java iterables. */
private case class ArrayConverter(
- elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+ elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
private[this] val elementConverter = getConverterForType(elementType)
private[this] val isNoChange = isWholePrimitive(elementType)
- override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+ override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
- case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
- case s: Seq[_] => s.map(elementConverter.toCatalyst)
+ case a: Array[_] =>
+ new GenericArrayData(a.map(elementConverter.toCatalyst))
+ case s: Seq[_] =>
+ new GenericArrayData(s.map(elementConverter.toCatalyst).toArray)
case i: JavaIterable[_] =>
val iter = i.iterator
- var convertedIterable: List[Any] = List()
+ val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any]
while (iter.hasNext) {
val item = iter.next()
- convertedIterable :+= elementConverter.toCatalyst(item)
+ convertedIterable += elementConverter.toCatalyst(item)
}
- convertedIterable
+ new GenericArrayData(convertedIterable.toArray)
}
}
- override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+ override def toScala(catalystValue: ArrayData): Seq[Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
- catalystValue
+ catalystValue.toArray()
} else {
- catalystValue.map(elementConverter.toScala)
+ catalystValue.toArray().map(elementConverter.toScala)
}
}
override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
- toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
+ toScala(row.getArray(column))
}
private case class MapConverter(
@@ -402,9 +403,9 @@ object CatalystTypeConverters {
case t: Timestamp => TimestampConverter.toCatalyst(t)
case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
- case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
- case arr: Array[Any] => arr.map(convertToCatalyst)
+ case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
case m: Map[_, _] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
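
Editor's sketch (not part of the patch): with this change, converting a Scala collection for an ArrayType column yields an ArrayData rather than a Seq. The converter factory below is the existing entry point in this file; treat this as a sketch against internal APIs.

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{ArrayData, ArrayType, IntegerType}

    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(ArrayType(IntegerType))
    val converted = toCatalyst(Seq(1, 2, 3)).asInstanceOf[ArrayData]
    assert(converted.numElements() == 3)
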
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index a5999e64ec..486ba03654 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
getAs[InternalRow](ordinal, null)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)
+
override def toString: String = s"[${this.mkString(",")}]"
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 371681b5d4..45709c1c8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -65,7 +65,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
- val value = ctx.getColumn("i", dataType, ordinal)
+ val value = ctx.getValue("i", dataType, ordinal.toString)
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
$javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8c01c13c9c..43be11c48a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
val elementCast = cast(from.elementType, to.elementType)
- buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+ // TODO: Could be faster?
+ buildCast[ArrayData](_, array => {
+ val length = array.numElements()
+ val values = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = elementCast(array.get(i))
+ }
+ i += 1
+ }
+ new GenericArrayData(values)
+ })
}
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
@@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArrayCode(
from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
-
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
val fromElementNull = ctx.freshName("feNull")
val fromElementPrim = ctx.freshName("fePrim")
val toElementNull = ctx.freshName("teNull")
val toElementPrim = ctx.freshName("tePrim")
val size = ctx.freshName("n")
val j = ctx.freshName("j")
- val result = ctx.freshName("result")
+ val values = ctx.freshName("values")
(c, evPrim, evNull) =>
s"""
- final int $size = $c.size();
- final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
+ final int $size = $c.numElements();
+ final Object[] $values = new Object[$size];
for (int $j = 0; $j < $size; $j ++) {
- if ($c.apply($j) == null) {
- $result.update($j, null);
+ if ($c.isNullAt($j)) {
+ $values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(from.elementType)} $fromElementPrim =
- (${ctx.boxedType(from.elementType)}) $c.apply($j);
+ ${ctx.getValue(c, from.elementType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
if ($toElementNull) {
- $result.update($j, null);
+ $values[$j] = null;
} else {
- $result.update($j, $toElementPrim);
+ $values[$j] = $toElementPrim;
}
}
}
- $evPrim = $result;
+ $evPrim = new $arrayClass($values);
"""
}
@@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType)
$result.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+ ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
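
Editor's sketch (not part of the patch): the interpreted castArray loop above, hand-evaluated for a hypothetical int-to-long element cast with a null slot.

    import org.apache.spark.sql.types.GenericArrayData

    val input = new GenericArrayData(Array[Any](1, null, 3))
    // Stand-in for cast(from.elementType, to.elementType).
    val elementCast: Any => Any = v => v.asInstanceOf[Int].toLong
    val out = new Array[Any](input.numElements())
    var i = 0
    while (i < input.numElements()) {
      out(i) = if (input.isNullAt(i)) null else elementCast(input.get(i))
      i += 1
    }
    val casted = new GenericArrayData(out)
    assert(casted.isNullAt(1) && casted.getLong(2) == 3L)
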
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 092f4c9fb0..c39e0df6fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -100,17 +100,18 @@ class CodeGenContext {
}
/**
- * Returns the code to access a column in Row for a given DataType.
+ * Returns the code to access a value in `SpecializedGetters` for a given DataType.
*/
- def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
+ def getValue(getter: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
- case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
- case StringType => s"$row.getUTF8String($ordinal)"
- case BinaryType => s"$row.getBinary($ordinal)"
- case CalendarIntervalType => s"$row.getInterval($ordinal)"
- case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
- case _ => s"($jt)$row.get($ordinal)"
+ case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+ case StringType => s"$getter.getUTF8String($ordinal)"
+ case BinaryType => s"$getter.getBinary($ordinal)"
+ case CalendarIntervalType => s"$getter.getInterval($ordinal)"
+ case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
+ case a: ArrayType => s"$getter.getArray($ordinal)"
+ case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
}
}
@@ -152,8 +153,8 @@ class CodeGenContext {
case StringType => "UTF8String"
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
- case _: ArrayType => s"scala.collection.Seq"
- case _: MapType => s"scala.collection.Map"
+ case _: ArrayType => "ArrayData"
+ case _: MapType => "scala.collection.Map"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
@@ -214,7 +215,9 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
- case other => s"$c1.compare($c2)"
+ case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+ case _ => throw new IllegalArgumentException(
+ "cannot generate compare code for un-comparable type")
}
/**
@@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
classOf[Decimal].getName,
- classOf[CalendarInterval].getName
+ classOf[CalendarInterval].getName,
+ classOf[ArrayData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
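
Editor's sketch (not part of the patch): because the ordinal parameter of the renamed getValue helper is now a String, generated code can index with a loop variable as well as a constant. Assuming a CodeGenContext can be instantiated directly:

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
    import org.apache.spark.sql.types._

    val ctx = new CodeGenContext
    assert(ctx.getValue("i", IntegerType, "0") == "i.getInt(0)")
    assert(ctx.getValue("row", StringType, "2") == "row.getUTF8String(2)")
    assert(ctx.getValue("row", ArrayType(LongType), "3") == "row.getArray(3)")
    // Element access inside a generated loop, using the loop variable as the ordinal.
    assert(ctx.getValue("arr", LongType, "j") == "arr.getLong(j)")
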
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7be60114ce..a662357fb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val nestedStructEv = GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
createCodeForStruct(ctx, nestedStructEv, st)
case _ =>
GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2d92dcf23a..1a00dbc254 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
override def nullSafeEval(value: Any): Int = child.dataType match {
- case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
- case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+ case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+ case _: MapType => value.asInstanceOf[Map[Any, Any]].size
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+ val sizeCall = child.dataType match {
+ case _: ArrayType => "numElements()"
+ case _: MapType => "size()"
+ }
+ nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}
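
Editor's sketch (not part of the patch): what the interpreted Size path now computes for an ArrayType input; the generated code likewise calls numElements() for arrays and size() for maps.

    import org.apache.spark.sql.types.GenericArrayData

    val arrayValue = new GenericArrayData(Array[Any]("a", "b", "c"))
    val arraySize = arrayValue.numElements()   // 3, what Size.nullSafeEval returns for arrays
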
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0517050a45..a145dfb4bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,12 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.unsafe.types.UTF8String
-
-import scala.collection.mutable
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -46,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
- children.map(_.eval(input))
+ new GenericArrayData(children.map(_.eval(input)).toArray)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
s"""
- boolean ${ev.isNull} = false;
- $arraySeqClass<Object> ${ev.primitive} = new $arraySeqClass<Object>(${children.size});
+ final boolean ${ev.isNull} = false;
+ final Object[] values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
+ values[$i] = null;
} else {
- ${ev.primitive}.update($i, ${eval.primitive});
+ values[$i] = ${eval.primitive};
}
"""
- }.mkString("\n")
+ }.mkString("\n") +
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
}
override def prettyName: String = "array"
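
Editor's sketch (not part of the patch): the interpreted CreateArray result is now a GenericArrayData that preserves null children as null slots.

    import org.apache.spark.sql.types.GenericArrayData

    val evaluatedChildren: Array[Any] = Array(1, null, 3)   // e.g. the children of array(1, NULL, 3)
    val created = new GenericArrayData(evaluatedChildren)
    assert(created.isNullAt(1) && created.getInt(2) == 3)
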
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6331a9eb60..99393c9c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -57,7 +57,8 @@ object ExtractValue {
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
- GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
+ ordinal, fields.length, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
@@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
+ ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)};
}
"""
})
@@ -134,6 +135,7 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
+ numFields: Int,
containsNull: Boolean) extends UnaryExpression {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
@@ -141,26 +143,45 @@ case class GetArrayStructFields(
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any = {
- input.asInstanceOf[Seq[InternalRow]].map { row =>
- if (row == null) null else row.get(ordinal, field.dataType)
+ val array = input.asInstanceOf[ArrayData]
+ val length = array.numElements()
+ val result = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ result(i) = null
+ } else {
+ val row = array.getStruct(i, numFields)
+ if (row.isNullAt(ordinal)) {
+ result(i) = null
+ } else {
+ result(i) = row.get(ordinal, field.dataType)
+ }
+ }
+ i += 1
}
+ new GenericArrayData(result)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = "scala.collection.mutable.ArraySeq"
- // TODO: consider using Array[_] for ArrayType child to avoid
- // boxing of primitives
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
s"""
- final int n = $eval.size();
- final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
+ final int n = $eval.numElements();
+ final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
- InternalRow row = (InternalRow) $eval.apply(j);
- if (row != null && !row.isNullAt($ordinal)) {
- values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+ if ($eval.isNullAt(j)) {
+ values[j] = null;
+ } else {
+ final InternalRow row = $eval.getStruct(j, $numFields);
+ if (row.isNullAt($ordinal)) {
+ values[j] = null;
+ } else {
+ values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+ }
}
}
- ${ev.primitive} = (${ctx.javaType(dataType)}) values;
+ ${ev.primitive} = new $arrayClass(values);
"""
})
}
@@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
- val baseValue = value.asInstanceOf[Seq[_]]
+ val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
- if (index >= baseValue.size || index < 0) {
+ if (index >= baseValue.numElements() || index < 0) {
null
} else {
- baseValue(index)
+ baseValue.get(index)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- final int index = (int)$eval2;
- if (index >= $eval1.size() || index < 0) {
+ final int index = (int) $eval2;
+ if (index >= $eval1.numElements() || index < 0) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
+ ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")};
}
"""
})
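
Editor's sketch (not part of the patch): the new GetArrayItem semantics, hand-evaluated; out-of-range or negative indices yield null, and valid indices go through ArrayData.get instead of Seq.apply.

    import org.apache.spark.sql.types.GenericArrayData

    val items = new GenericArrayData(Array[Any](10L, 20L, 30L))
    def arrayItem(index: Int): Any =
      if (index >= items.numElements() || index < 0) null else items.get(index)
    assert(arrayItem(1) == 20L)
    assert(arrayItem(5) == null)
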
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 2dbcf2830f..8064235c64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
child.dataType match {
case ArrayType(_, _) =>
- val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
- if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
+ val inputArray = child.eval(input).asInstanceOf[ArrayData]
+ if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
if (inputMap == null) Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5b3a64a096..79c0ca56a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression])
val flatInputs = children.flatMap { child =>
child.eval(input) match {
case s: UTF8String => Iterator(s)
- case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+ case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String])
case null => Iterator(null.asInstanceOf[UTF8String])
}
}
@@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression])
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+ s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
@@ -665,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def nullSafeEval(string: Any, regex: Any): Any = {
- string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq
+ val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
+ new GenericArrayData(strings.asInstanceOf[Array[Any]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, pattern) =>
- s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer(
- java.util.Arrays.asList($str.split($pattern, -1)));""")
+ // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+ s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""")
}
override def prettyName: String = "split"
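
Editor's sketch (not part of the patch): the interpreted StringSplit path after the change. UTF8String.split already returns a UTF8String[], and Java array covariance lets it be handed to GenericArrayData as an Object[].

    import org.apache.spark.sql.types.GenericArrayData
    import org.apache.spark.unsafe.types.UTF8String

    val parts = UTF8String.fromString("a,b,c").split(UTF8String.fromString(","), -1)
    val result = new GenericArrayData(parts.asInstanceOf[Array[Any]])
    assert(result.numElements() == 3 && result.getUTF8String(0) == UTF8String.fromString("a"))
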
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 813c620096..29d706dcb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+ case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
+ Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
new file mode 100644
index 0000000000..14a7285877
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+
+abstract class ArrayData extends SpecializedGetters with Serializable {
+ // TODO: remove this after we handle all types. (MapType needs a special getter.)
+ def get(ordinal: Int): Any
+
+ def numElements(): Int
+
+ // todo: need a more efficient way to iterate array type.
+ def toArray(): Array[Any] = {
+ val n = numElements()
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ if (isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = get(i)
+ }
+ i += 1
+ }
+ values
+ }
+
+ override def toString(): String = toArray.mkString("[", ",", "]")
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[ArrayData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[ArrayData]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numElements()
+ if (len != other.numElements()) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = get(i)
+ val o2 = other.get(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numElements()
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ get(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+}
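
Editor's sketch (not part of the patch): the equality and hash-code contract defined above, exercised through the GenericArrayData implementation added in the next file; byte arrays compare by content and NaN compares equal to NaN, unlike default JVM equality.

    import org.apache.spark.sql.types.GenericArrayData

    val a = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN))
    val b = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN))
    assert(a == b)
    assert(a.hashCode == b.hashCode)
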
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
new file mode 100644
index 0000000000..7992ba947c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+
+class GenericArrayData(array: Array[Any]) extends ArrayData {
+ private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+
+ override def toArray(): Array[Any] = array
+
+ override def get(ordinal: Int): Any = array(ordinal)
+
+ override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+
+ override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
+
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+
+ override def numElements(): Int = array.length
+}
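
Editor's sketch (not part of the patch): minimal usage of the new wrapper through the SpecializedGetters interface it implements.

    import org.apache.spark.sql.types.GenericArrayData

    val data = new GenericArrayData(Array[Any](1, null, 3))
    assert(data.numElements() == 3)
    assert(data.getInt(0) == 1)
    assert(data.isNullAt(1))
    assert(data.toArray().length == 3)
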
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a517da9872..4f35b653d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date}
import java.util.{TimeZone, Calendar}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -730,13 +731,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
- InternalRow(
- Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")),
- Map(
- UTF8String.fromString("a") -> UTF8String.fromString("123"),
- UTF8String.fromString("b") -> UTF8String.fromString("abc"),
- UTF8String.fromString("c") -> UTF8String.fromString("")),
- InternalRow(0)),
+ Row(
+ Seq("123", "abc", ""),
+ Map("a" ->"123", "b" -> "abc", "c" -> ""),
+ Row(0)),
StructType(Seq(
StructField("a",
ArrayType(StringType, containsNull = false), nullable = true),
@@ -756,13 +754,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("l", LongType, nullable = true)))))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(
+ checkEvaluation(ret, Row(
Seq(123, null, null),
- Map(
- UTF8String.fromString("a") -> true,
- UTF8String.fromString("b") -> true,
- UTF8String.fromString("c") -> false),
- InternalRow(0L)))
+ Map("a" -> true, "b" -> true, "c" -> false),
+ Row(0L)))
}
test("case between string and interval") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5de5ddce97..3fa246b69d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
expr.dataType match {
case ArrayType(StructType(fields), containsNull) =>
val field = fields.find(_.name == fieldName).get
- GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+ GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index aeeb0e4527..f26f41fb75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -158,8 +158,8 @@ package object debug {
case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
- case (s: Seq[_], ArrayType(elemType, _)) =>
- s.foreach(typeCheck(_, elemType))
+ case (a: ArrayData, ArrayType(elemType, _)) =>
+ a.toArray().foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
m.keys.foreach(typeCheck(_, keyType))
m.values.foreach(typeCheck(_, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 3c38916fd7..ef1c6e57dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -134,8 +134,19 @@ object EvaluatePython {
}
new GenericInternalRowWithSchema(values, struct)
- case (seq: Seq[Any], array: ArrayType) =>
- seq.map(x => toJava(x, array.elementType)).asJava
+ case (a: ArrayData, array: ArrayType) =>
+ val length = a.numElements()
+ val values = new java.util.ArrayList[Any](length)
+ var i = 0
+ while (i < length) {
+ if (a.isNullAt(i)) {
+ values.add(null)
+ } else {
+ values.add(toJava(a.get(i), array.elementType))
+ }
+ i += 1
+ }
+ values
case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
@@ -190,10 +201,10 @@ object EvaluatePython {
case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => fromJava(e, elementType)}.toSeq
+ new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray)
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
+ new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 78da2840da..9329148aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame}
private[sql] object FrequentItems extends Logging {
@@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
- val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
+ val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
val resultRow = InternalRow(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 0eb3b04007..04ab5e2217 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -125,7 +125,7 @@ private[sql] object InferSchema {
* Convert NullType to StringType and remove StructTypes with no fields
*/
private def canonicalizeType: DataType => Option[DataType] = {
- case at@ArrayType(elementType, _) =>
+ case at @ ArrayType(elementType, _) =>
for {
canonicalType <- canonicalizeType(elementType)
} yield {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 381e7ed544..1c309f8794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -110,8 +110,13 @@ private[sql] object JacksonParser {
case (START_OBJECT, st: StructType) =>
convertObject(factory, parser, st)
+ case (START_ARRAY, st: StructType) =>
+ // SPARK-3308: support reading top level JSON arrays and take every element
+ // in such an array as a row
+ convertArray(factory, parser, st)
+
case (START_ARRAY, ArrayType(st, _)) =>
- convertList(factory, parser, st)
+ convertArray(factory, parser, st)
case (START_OBJECT, ArrayType(st, _)) =>
// the business end of SPARK-3308:
@@ -165,16 +170,16 @@ private[sql] object JacksonParser {
builder.result()
}
- private def convertList(
+ private def convertArray(
factory: JsonFactory,
parser: JsonParser,
- schema: DataType): Seq[Any] = {
- val builder = Seq.newBuilder[Any]
+ elementType: DataType): ArrayData = {
+ val values = scala.collection.mutable.ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
- builder += convertField(factory, parser, schema)
+ values += convertField(factory, parser, elementType)
}
- builder.result()
+ new GenericArrayData(values.toArray)
}
private def parseJson(
@@ -201,12 +206,15 @@ private[sql] object JacksonParser {
val parser = factory.createParser(record)
parser.nextToken()
- // to support both object and arrays (see SPARK-3308) we'll start
- // by converting the StructType schema to an ArrayType and let
- // convertField wrap an object into a single value array when necessary.
- convertField(factory, parser, ArrayType(schema)) match {
+ convertField(factory, parser, schema) match {
case null => failedRecord(record)
- case list: Seq[InternalRow @unchecked] => list
+ case row: InternalRow => row :: Nil
+ case array: ArrayData =>
+ if (array.numElements() == 0) {
+ Nil
+ } else {
+ array.toArray().map(_.asInstanceOf[InternalRow])
+ }
case _ =>
sys.error(
s"Failed to parse record $record. Please make sure that each line of the file " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index e00bd90edb..172db8362a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = elementConverter
- override def end(): Unit = updater.set(currentArray)
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
// NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the
// next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index ea51650fe9..2332a36468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.parquet
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.ArrayData
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -32,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = Seq[T]
+ type ArrayScalaType[T] = ArrayData
type StructScalaType[T] = InternalRow
type MapScalaType[K, V] = Map[K, V]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 78ecfad1d5..79dd16b7b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
array: CatalystConverter.ArrayScalaType[_]): Unit = {
val elementType = schema.elementType
writer.startGroup()
- if (array.size > 0) {
+ if (array.numElements() > 0) {
if (schema.containsNull) {
writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
+ while (i < array.numElements()) {
writer.startGroup()
- if (array(i) != null) {
+ if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array(i))
+ writeValue(elementType, array.get(i))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
} else {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
- writeValue(elementType, array(i))
+ while (i < array.numElements()) {
+ writeValue(elementType, array.get(i))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 72c42f4fe3..9e61d06f40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -30,7 +30,6 @@ import org.junit.*;
import scala.collection.JavaConversions;
import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
import java.io.Serializable;
import java.util.Arrays;
@@ -168,10 +167,10 @@ public class JavaDataFrameSuite {
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
- Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+ Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
- Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer)));
+ Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer)));
Seq<String> d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 45c9f06941..77ed4a9c0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
- override def serialize(obj: Any): Seq[Double] = {
+ override def serialize(obj: Any): ArrayData = {
obj match {
case features: MyDenseVector =>
- features.data.toSeq
+ new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}
}
override def deserialize(datum: Any): MyDenseVector = {
datum match {
- case data: Seq[_] =>
- new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+ case data: ArrayData =>
+ new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 5e189c3563..cfb03ff485 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -67,12 +67,12 @@ case class AllDataTypesScan(
override def schema: StructType = userSpecifiedSchema
- override def needConversion: Boolean = false
+ override def needConversion: Boolean = true
override def buildScan(): RDD[Row] = {
sqlContext.sparkContext.parallelize(from to to).map { i =>
- InternalRow(
- UTF8String.fromString(s"str_$i"),
+ Row(
+ s"str_$i",
s"str_$i".getBytes(),
i % 2 == 0,
i.toByte,
@@ -81,19 +81,19 @@ case class AllDataTypesScan(
i.toLong,
i.toFloat,
i.toDouble,
- Decimal(new java.math.BigDecimal(i)),
- Decimal(new java.math.BigDecimal(i)),
- DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)),
- DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
- UTF8String.fromString(s"varchar_$i"),
+ new java.math.BigDecimal(i),
+ new java.math.BigDecimal(i),
+ new Date(1970, 1, 1),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
Seq(i, i + 1),
- Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
- Map(i -> UTF8String.fromString(i.toString)),
- Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
- InternalRow(i, UTF8String.fromString(i.toString)),
- InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
- InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
- }.asInstanceOf[RDD[Row]]
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"),
+ Row(Seq(new Date(1970, 1, i + 1)))))
+ }
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index f467500259..5926ef9aa3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -52,9 +52,8 @@ import scala.collection.JavaConversions._
* java.sql.Timestamp
* Complex Types =>
* Map: scala.collection.immutable.Map
- * List: scala.collection.immutable.Seq
- * Struct:
- * [[org.apache.spark.sql.catalyst.InternalRow]]
+ * List: [[org.apache.spark.sql.types.ArrayData]]
+ * Struct: [[org.apache.spark.sql.catalyst.InternalRow]]
* Union: NOT SUPPORTED YET
* The Complex types plays as a container, which can hold arbitrary data types.
*
@@ -297,7 +296,10 @@ private[hive] trait HiveInspectors {
}.toMap
case li: StandardConstantListObjectInspector =>
// take the value from the list inspector object, rather than the input data
- li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
+ val values = li.getWritableConstantValue
+ .map(unwrap(_, li.getListElementObjectInspector))
+ .toArray
+ new GenericArrayData(values)
// if the value is null, we don't care about the object inspector type
case _ if data == null => null
case poi: VoidObjectInspector => null // always be null for void object inspector
@@ -339,7 +341,10 @@ private[hive] trait HiveInspectors {
}
case li: ListObjectInspector =>
Option(li.getList(data))
- .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
+ .map { l =>
+ val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray
+ new GenericArrayData(values)
+ }
.orNull
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
@@ -391,7 +396,13 @@ private[hive] trait HiveInspectors {
case loi: ListObjectInspector =>
val wrapper = wrapperFor(loi.getListElementObjectInspector)
- (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null
+ (o: Any) => {
+ if (o != null) {
+ seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper))
+ } else {
+ null
+ }
+ }
case moi: MapObjectInspector =>
// The Predef.Map is scala.collection.immutable.Map.
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors {
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
val tpe = dataType.asInstanceOf[ArrayType].elementType
- a.asInstanceOf[Seq[_]].foreach {
+ a.asInstanceOf[ArrayData].toArray().foreach {
v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
}
list
@@ -634,7 +645,8 @@ private[hive] trait HiveInspectors {
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
} else {
val list = new java.util.ArrayList[Object]()
- value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
+ value.asInstanceOf[ArrayData].toArray()
+ .foreach(v => list.add(wrap(v, listObjectInspector, dt)))
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
}
case Literal(value, MapType(keyType, valueType, _)) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 741c705e2a..7e3342cc84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -176,13 +176,13 @@ case class ScriptTransformation(
val prevLine = curLine
curLine = reader.readLine()
if (!ioschema.schemaLess) {
- new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
- .asInstanceOf[Array[Any]])
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ .map(CatalystTypeConverters.convertToCatalyst))
} else {
- new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
- .asInstanceOf[Array[Any]])
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+ .map(CatalystTypeConverters.convertToCatalyst))
}
} else {
val ret = deserialize()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8732e9abf8..4a13022edd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction(
// if pivotResult is true, we will get a Seq having the same size with the size
// of the window frame. At here, we will return the result at the position of
// index in the output buffer.
- outputBuffer.asInstanceOf[Seq[Any]].get(index)
+ outputBuffer.asInstanceOf[ArrayData].get(index)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 0330013f53..f719f2e06a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
test("wrap / unwrap Array Type") {
val dt = ArrayType(dataTypes(0))
- val d = row(0) :: row(0) :: Nil
+ val d = new GenericArrayData(Array(row(0), row(0)))
checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
checkValue(d,