author     Wenchen Fan <cloud0fan@outlook.com>       2015-07-30 10:04:30 -0700
committer  Reynold Xin <rxin@databricks.com>         2015-07-30 10:04:30 -0700
commit     c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 (patch)
tree       582bad5631cde3bac3b5c69e1f22b3c4098de684 /sql
parent     7492a33fdd074446c30c657d771a69932a00246d (diff)
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7724 from cloud-fan/array-data and squashes the following commits:
d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
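In short, this patch stops passing array values around as scala.collection.Seq inside Catalyst and wraps them in a new ArrayData interface that is read by ordinal, much like InternalRow. A minimal sketch of the new calling convention, using only the ArrayData/GenericArrayData API added by this patch (the element values are invented for illustration):

    import org.apache.spark.sql.types.{ArrayData, GenericArrayData}

    val arr: ArrayData = new GenericArrayData(Array[Any](1, 2, null, 4))

    var i = 0
    var sum = 0
    while (i < arr.numElements()) {  // numElements() replaces Seq.size
      if (!arr.isNullAt(i)) {        // per-element null checks, as with InternalRow
        sum += arr.getInt(i)         // specialized getter from SpecializedGetters
      }
      i += 1
    }
    assert(sum == 7)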
Diffstat (limited to 'sql')
32 files changed, 417 insertions, 163 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index bc345dcd00..f7cea13688 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.ArrayData;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -50,4 +51,5 @@ public interface SpecializedGetters {
 
   InternalRow getStruct(int ordinal, int numFields);
 
+  ArrayData getArray(int ordinal);
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d1d89a1f48..22452c0f20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -55,7 +55,6 @@ object CatalystTypeConverters {
   private def isWholePrimitive(dt: DataType): Boolean = dt match {
     case dt if isPrimitive(dt) => true
-    case ArrayType(elementType, _) => isWholePrimitive(elementType)
     case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
     case _ => false
   }
@@ -154,39 +153,41 @@ object CatalystTypeConverters {
 
   /** Converter for arrays, sequences, and Java iterables. */
   private case class ArrayConverter(
-      elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+      elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
 
     private[this] val elementConverter = getConverterForType(elementType)
 
     private[this] val isNoChange = isWholePrimitive(elementType)
 
-    override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+    override def toCatalystImpl(scalaValue: Any): ArrayData = {
       scalaValue match {
-        case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
-        case s: Seq[_] => s.map(elementConverter.toCatalyst)
+        case a: Array[_] =>
+          new GenericArrayData(a.map(elementConverter.toCatalyst))
+        case s: Seq[_] =>
+          new GenericArrayData(s.map(elementConverter.toCatalyst).toArray)
         case i: JavaIterable[_] =>
           val iter = i.iterator
-          var convertedIterable: List[Any] = List()
+          val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any]
           while (iter.hasNext) {
             val item = iter.next()
-            convertedIterable :+= elementConverter.toCatalyst(item)
+            convertedIterable += elementConverter.toCatalyst(item)
           }
-          convertedIterable
+          new GenericArrayData(convertedIterable.toArray)
       }
     }
 
-    override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+    override def toScala(catalystValue: ArrayData): Seq[Any] = {
       if (catalystValue == null) {
         null
       } else if (isNoChange) {
-        catalystValue
+        catalystValue.toArray()
       } else {
-        catalystValue.map(elementConverter.toScala)
+        catalystValue.toArray().map(elementConverter.toScala)
      }
    }
 
     override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
-      toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
+      toScala(row.getArray(column))
   }
 
   private case class MapConverter(
@@ -402,9 +403,9 @@ object CatalystTypeConverters {
     case t: Timestamp => TimestampConverter.toCatalyst(t)
     case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
     case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
-    case seq: Seq[Any] => seq.map(convertToCatalyst)
+    case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
     case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
-    case arr: Array[Any] => arr.map(convertToCatalyst)
+    case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
     case m: Map[_, _] =>
       m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
     case other => other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index a5999e64ec..486ba03654 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
   override def getStruct(ordinal: Int, numFields: Int): InternalRow =
     getAs[InternalRow](ordinal, null)
 
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)
+
   override def toString: String = s"[${this.mkString(",")}]"
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 371681b5d4..45709c1c8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -65,7 +65,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val javaType = ctx.javaType(dataType)
-    val value = ctx.getColumn("i", dataType, ordinal)
+    val value = ctx.getValue("i", dataType, ordinal.toString)
     s"""
       boolean ${ev.isNull} = i.isNullAt($ordinal);
       $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
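For orientation, here is roughly how the rewritten ArrayConverter behaves end to end. This is a sketch assuming the createToCatalystConverter/createToScalaConverter entry points elsewhere in this file, called from within the org.apache.spark.sql package:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(ArrayType(IntegerType))
    val toScala = CatalystTypeConverters.createToScalaConverter(ArrayType(IntegerType))

    val catalystValue = toCatalyst(Seq(1, 2, 3))   // now a GenericArrayData, not a Seq
    assert(toScala(catalystValue) == Seq(1, 2, 3)) // round-trips back to a Seq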
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8c01c13c9c..43be11c48a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType)
 
   private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
     val elementCast = cast(from.elementType, to.elementType)
-    buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+    // TODO: Could be faster?
+    buildCast[ArrayData](_, array => {
+      val length = array.numElements()
+      val values = new Array[Any](length)
+      var i = 0
+      while (i < length) {
+        if (array.isNullAt(i)) {
+          values(i) = null
+        } else {
+          values(i) = elementCast(array.get(i))
+        }
+        i += 1
+      }
+      new GenericArrayData(values)
+    })
   }
 
   private[this] def castMap(from: MapType, to: MapType): Any => Any = {
@@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType)
 
   private[this] def castArrayCode(
       from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
     val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
-
-    val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+    val arrayClass = classOf[GenericArrayData].getName
     val fromElementNull = ctx.freshName("feNull")
     val fromElementPrim = ctx.freshName("fePrim")
     val toElementNull = ctx.freshName("teNull")
     val toElementPrim = ctx.freshName("tePrim")
     val size = ctx.freshName("n")
     val j = ctx.freshName("j")
-    val result = ctx.freshName("result")
+    val values = ctx.freshName("values")
 
     (c, evPrim, evNull) =>
       s"""
-        final int $size = $c.size();
-        final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
+        final int $size = $c.numElements();
+        final Object[] $values = new Object[$size];
         for (int $j = 0; $j < $size; $j ++) {
-          if ($c.apply($j) == null) {
-            $result.update($j, null);
+          if ($c.isNullAt($j)) {
+            $values[$j] = null;
           } else {
            boolean $fromElementNull = false;
            ${ctx.javaType(from.elementType)} $fromElementPrim =
-              (${ctx.boxedType(from.elementType)}) $c.apply($j);
+              ${ctx.getValue(c, from.elementType, j)};
            ${castCode(ctx, fromElementPrim, fromElementNull,
              toElementPrim, toElementNull, to.elementType, elementCast)}
            if ($toElementNull) {
-              $result.update($j, null);
+              $values[$j] = null;
            } else {
-              $result.update($j, $toElementPrim);
+              $values[$j] = $toElementPrim;
            }
          }
        }
-        $evPrim = $result;
+        $evPrim = new $arrayClass($values);
      """
   }
 
@@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType)
           $result.setNullAt($i);
         } else {
           $fromType $fromFieldPrim =
-            ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+            ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
           ${castCode(ctx, fromFieldPrim, fromFieldNull,
             toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
           if ($toFieldNull) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 092f4c9fb0..c39e0df6fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -100,17 +100,18 @@ class CodeGenContext {
   }
 
   /**
-   * Returns the code to access a column in Row for a given DataType.
+   * Returns the code to access a value in `SpecializedGetters` for a given DataType.
    */
-  def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
+  def getValue(getter: String, dataType: DataType, ordinal: String): String = {
     val jt = javaType(dataType)
     dataType match {
-      case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
-      case StringType => s"$row.getUTF8String($ordinal)"
-      case BinaryType => s"$row.getBinary($ordinal)"
-      case CalendarIntervalType => s"$row.getInterval($ordinal)"
-      case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
-      case _ => s"($jt)$row.get($ordinal)"
+      case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+      case StringType => s"$getter.getUTF8String($ordinal)"
+      case BinaryType => s"$getter.getBinary($ordinal)"
+      case CalendarIntervalType => s"$getter.getInterval($ordinal)"
+      case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
+      case a: ArrayType => s"$getter.getArray($ordinal)"
+      case _ => s"($jt)$getter.get($ordinal)"  // todo: remove generic getter.
     }
   }
 
@@ -152,8 +153,8 @@ class CodeGenContext {
     case StringType => "UTF8String"
     case CalendarIntervalType => "CalendarInterval"
     case _: StructType => "InternalRow"
-    case _: ArrayType => s"scala.collection.Seq"
-    case _: MapType => s"scala.collection.Map"
+    case _: ArrayType => "ArrayData"
+    case _: MapType => "scala.collection.Map"
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case _ => "Object"
@@ -214,7 +215,9 @@ class CodeGenContext {
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case NullType => "0"
-    case other => s"$c1.compare($c2)"
+    case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+    case _ => throw new IllegalArgumentException(
+      "cannot generate compare code for un-comparable type")
   }
 
   /**
@@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
     classOf[UnsafeRow].getName,
     classOf[UTF8String].getName,
     classOf[Decimal].getName,
-    classOf[CalendarInterval].getName
+    classOf[CalendarInterval].getName,
+    classOf[ArrayData].getName
   ))
   evaluator.setExtendedClass(classOf[GeneratedClass])
   try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7be60114ce..a662357fb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
         val nestedStructEv = GeneratedExpressionCode(
           code = "",
           isNull = s"${input.primitive}.isNullAt($i)",
-          primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
         )
         createCodeForStruct(ctx, nestedStructEv, st)
       case _ =>
         GeneratedExpressionCode(
           code = "",
           isNull = s"${input.primitive}.isNullAt($i)",
-          primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+          primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
         )
     }
   }
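The getColumn-to-getValue rename matters because the ordinal is now an arbitrary Java code fragment rather than a compile-time Int, so the same helper can index both rows and arrays inside generated loops. Illustrative inputs and the strings they would produce, following the cases above (the getter and variable names are invented):

    ctx.getValue("i", StringType, "4")             // "i.getUTF8String(4)"
    ctx.getValue("arr", IntegerType, "j")          // "arr.getInt(j)", j being a loop variable
    ctx.getValue("row", ArrayType(LongType), "0")  // "row.getArray(0)"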
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2d92dcf23a..1a00dbc254 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
 
   override def nullSafeEval(value: Any): Int = child.dataType match {
-    case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
-    case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+    case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+    case _: MapType => value.asInstanceOf[Map[Any, Any]].size
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+    val sizeCall = child.dataType match {
+      case _: ArrayType => "numElements()"
+      case _: MapType => "size()"
+    }
+    nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0517050a45..a145dfb4bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,12 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.unsafe.types.UTF8String
-
-import scala.collection.mutable
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -46,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def nullable: Boolean = false
 
   override def eval(input: InternalRow): Any = {
-    children.map(_.eval(input))
+    new GenericArrayData(children.map(_.eval(input)).toArray)
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+    val arrayClass = classOf[GenericArrayData].getName
     s"""
-      boolean ${ev.isNull} = false;
-      $arraySeqClass<Object> ${ev.primitive} = new $arraySeqClass<Object>(${children.size});
+      final boolean ${ev.isNull} = false;
+      final Object[] values = new Object[${children.size}];
     """ +
       children.zipWithIndex.map { case (e, i) =>
         val eval = e.gen(ctx)
         eval.code + s"""
          if (${eval.isNull}) {
-            ${ev.primitive}.update($i, null);
+            values[$i] = null;
          } else {
-            ${ev.primitive}.update($i, ${eval.primitive});
+            values[$i] = ${eval.primitive};
          }
        """
-      }.mkString("\n")
+      }.mkString("\n") +
+      s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
   }
 
   override def prettyName: String = "array"
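A quick interpreted-path check of the two expressions above, with invented literals (this assumes the InternalRow.empty helper available in this codebase):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal, Size}

    val array = CreateArray(Seq(Literal(1), Literal(2), Literal(3)))
    // CreateArray.eval now yields a GenericArrayData ...
    val result = array.eval(InternalRow.empty)
    // ... and Size dispatches to numElements() instead of Seq.size.
    assert(Size(array).eval(InternalRow.empty) == 3)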
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6331a9eb60..99393c9c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -57,7 +57,8 @@ object ExtractValue {
       case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
-        GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+        GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
+          ordinal, fields.length, containsNull)
 
       case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
         GetArrayItem(child, extraction)
@@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
       if ($eval.isNullAt($ordinal)) {
         ${ev.isNull} = true;
       } else {
-        ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
+        ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)};
       }
     """
   })
@@ -134,6 +135,7 @@ case class GetArrayStructFields(
     child: Expression,
     field: StructField,
     ordinal: Int,
+    numFields: Int,
     containsNull: Boolean) extends UnaryExpression {
 
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
@@ -141,26 +143,45 @@ case class GetArrayStructFields(
   override def toString: String = s"$child.${field.name}"
 
   protected override def nullSafeEval(input: Any): Any = {
-    input.asInstanceOf[Seq[InternalRow]].map { row =>
-      if (row == null) null else row.get(ordinal, field.dataType)
+    val array = input.asInstanceOf[ArrayData]
+    val length = array.numElements()
+    val result = new Array[Any](length)
+    var i = 0
+    while (i < length) {
+      if (array.isNullAt(i)) {
+        result(i) = null
+      } else {
+        val row = array.getStruct(i, numFields)
+        if (row.isNullAt(ordinal)) {
+          result(i) = null
+        } else {
+          result(i) = row.get(ordinal, field.dataType)
+        }
+      }
+      i += 1
     }
+    new GenericArrayData(result)
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val arraySeqClass = "scala.collection.mutable.ArraySeq"
-    // TODO: consider using Array[_] for ArrayType child to avoid
-    // boxing of primitives
+    val arrayClass = classOf[GenericArrayData].getName
     nullSafeCodeGen(ctx, ev, eval => {
       s"""
-        final int n = $eval.size();
-        final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
+        final int n = $eval.numElements();
+        final Object[] values = new Object[n];
        for (int j = 0; j < n; j++) {
-          InternalRow row = (InternalRow) $eval.apply(j);
-          if (row != null && !row.isNullAt($ordinal)) {
-            values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+          if ($eval.isNullAt(j)) {
+            values[j] = null;
+          } else {
+            final InternalRow row = $eval.getStruct(j, $numFields);
+            if (row.isNullAt($ordinal)) {
+              values[j] = null;
+            } else {
+              values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+            }
          }
        }
-        ${ev.primitive} = (${ctx.javaType(dataType)}) values;
+        ${ev.primitive} = new $arrayClass(values);
      """
    })
  }
@@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression {
 
   protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
     // TODO: consider using Array[_] for ArrayType child to avoid
     // boxing of primitives
-    val baseValue = value.asInstanceOf[Seq[_]]
+    val baseValue = value.asInstanceOf[ArrayData]
     val index = ordinal.asInstanceOf[Number].intValue()
-    if (index >= baseValue.size || index < 0) {
+    if (index >= baseValue.numElements() || index < 0) {
       null
     } else {
-      baseValue(index)
+      baseValue.get(index)
     }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       s"""
-        final int index = (int)$eval2;
-        if (index >= $eval1.size() || index < 0) {
+        final int index = (int) $eval2;
+        if (index >= $eval1.numElements() || index < 0) {
          ${ev.isNull} = true;
        } else {
-          ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
+          ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")};
        }
      """
    })
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 2dbcf2830f..8064235c64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     child.dataType match {
       case ArrayType(_, _) =>
-        val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
-        if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
+        val inputArray = child.eval(input).asInstanceOf[ArrayData]
+        if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v))
       case MapType(_, _, _) =>
         val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
         if (inputMap == null) Nil
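Out-of-bounds behaviour is preserved by the rewrite: GetArrayItem still returns null for a bad index instead of throwing. A small sketch with invented values:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, Literal}
    import org.apache.spark.sql.types.{ArrayType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Literal.create runs the Seq through CatalystTypeConverters, so the
    // stored value is already a GenericArrayData of UTF8Strings.
    val arr = Literal.create(Seq("a", "b"), ArrayType(StringType))
    assert(GetArrayItem(arr, Literal(0)).eval(InternalRow.empty) == UTF8String.fromString("a"))
    assert(GetArrayItem(arr, Literal(5)).eval(InternalRow.empty) == null)  // no exception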
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5b3a64a096..79c0ca56a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression])
     val flatInputs = children.flatMap { child =>
       child.eval(input) match {
         case s: UTF8String => Iterator(s)
-        case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+        case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String])
         case null => Iterator(null.asInstanceOf[UTF8String])
       }
     }
@@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression])
       val evals = children.map(_.gen(ctx))
 
       val inputs = evals.map { eval =>
-        s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+        s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}"
       }.mkString(", ")
 
       evals.map(_.code).mkString("\n") + s"""
@@ -665,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression)
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
   override def nullSafeEval(string: Any, regex: Any): Any = {
-    string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq
+    val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
+    new GenericArrayData(strings.asInstanceOf[Array[Any]])
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val arrayClass = classOf[GenericArrayData].getName
     nullSafeCodeGen(ctx, ev, (str, pattern) =>
-      s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer(
-        java.util.Arrays.asList($str.split($pattern, -1)));""")
+      // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+      s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""")
   }
 
   override def prettyName: String = "split"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 813c620096..29d706dcb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
       case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
       case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+      case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
+        Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
new file mode 100644
index 0000000000..14a7285877
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+
+abstract class ArrayData extends SpecializedGetters with Serializable {
+  // todo: remove this after we handle all types.(map type need special getter)
+  def get(ordinal: Int): Any
+
+  def numElements(): Int
+
+  // todo: need a more efficient way to iterate array type.
+  def toArray(): Array[Any] = {
+    val n = numElements()
+    val values = new Array[Any](n)
+    var i = 0
+    while (i < n) {
+      if (isNullAt(i)) {
+        values(i) = null
+      } else {
+        values(i) = get(i)
+      }
+      i += 1
+    }
+    values
+  }
+
+  override def toString(): String = toArray.mkString("[", ",", "]")
+
+  override def equals(o: Any): Boolean = {
+    if (!o.isInstanceOf[ArrayData]) {
+      return false
+    }
+
+    val other = o.asInstanceOf[ArrayData]
+    if (other eq null) {
+      return false
+    }
+
+    val len = numElements()
+    if (len != other.numElements()) {
+      return false
+    }
+
+    var i = 0
+    while (i < len) {
+      if (isNullAt(i) != other.isNullAt(i)) {
+        return false
+      }
+      if (!isNullAt(i)) {
+        val o1 = get(i)
+        val o2 = other.get(i)
+        o1 match {
+          case b1: Array[Byte] =>
+            if (!o2.isInstanceOf[Array[Byte]] ||
+              !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+              return false
+            }
+          case f1: Float if java.lang.Float.isNaN(f1) =>
+            if (!o2.isInstanceOf[Float] || !java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+              return false
+            }
+          case d1: Double if java.lang.Double.isNaN(d1) =>
+            if (!o2.isInstanceOf[Double] || !java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+              return false
+            }
+          case _ => if (o1 != o2) {
+            return false
+          }
+        }
+      }
+      i += 1
+    }
+    true
+  }
+
+  override def hashCode: Int = {
+    var result: Int = 37
+    var i = 0
+    val len = numElements()
+    while (i < len) {
+      val update: Int =
+        if (isNullAt(i)) {
+          0
+        } else {
+          get(i) match {
+            case b: Boolean => if (b) 0 else 1
+            case b: Byte => b.toInt
+            case s: Short => s.toInt
+            case i: Int => i
+            case l: Long => (l ^ (l >>> 32)).toInt
+            case f: Float => java.lang.Float.floatToIntBits(f)
+            case d: Double =>
+              val b = java.lang.Double.doubleToLongBits(d)
+              (b ^ (b >>> 32)).toInt
+            case a: Array[Byte] => java.util.Arrays.hashCode(a)
+            case other => other.hashCode()
+          }
+        }
+      result = 37 * result + update
+      i += 1
+    }
+    result
+  }
+}
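The equals/hashCode pair above gives ArrayData the value semantics that Seq used to provide for free, with two deliberate quirks: byte arrays compare by content, and NaN compares equal to NaN. A sketch using the GenericArrayData implementation introduced in the next file:

    val a = new GenericArrayData(Array[Any](1, null, Double.NaN))
    val b = new GenericArrayData(Array[Any](1, null, Double.NaN))

    assert(a == b)                        // NaN == NaN under this definition
    assert(a.hashCode == b.hashCode)      // consistent with equals
    assert(a.toString == "[1,null,NaN]")  // toString delegates to toArray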
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
new file mode 100644
index 0000000000..7992ba947c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+
+class GenericArrayData(array: Array[Any]) extends ArrayData {
+  private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+
+  override def toArray(): Array[Any] = array
+
+  override def get(ordinal: Int): Any = array(ordinal)
+
+  override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+
+  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+
+  override def getByte(ordinal: Int): Byte = getAs(ordinal)
+
+  override def getShort(ordinal: Int): Short = getAs(ordinal)
+
+  override def getInt(ordinal: Int): Int = getAs(ordinal)
+
+  override def getLong(ordinal: Int): Long = getAs(ordinal)
+
+  override def getFloat(ordinal: Int): Float = getAs(ordinal)
+
+  override def getDouble(ordinal: Int): Double = getAs(ordinal)
+
+  override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
+
+  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+
+  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+
+  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+
+  override def numElements(): Int = array.length
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a517da9872..4f35b653d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date}
 import java.util.{TimeZone, Calendar}
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -730,13 +731,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("complex casting") {
     val complex = Literal.create(
-      InternalRow(
-        Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")),
-        Map(
-          UTF8String.fromString("a") -> UTF8String.fromString("123"),
-          UTF8String.fromString("b") -> UTF8String.fromString("abc"),
-          UTF8String.fromString("c") -> UTF8String.fromString("")),
-        InternalRow(0)),
+      Row(
+        Seq("123", "abc", ""),
+        Map("a" -> "123", "b" -> "abc", "c" -> ""),
+        Row(0)),
       StructType(Seq(
         StructField("a",
           ArrayType(StringType, containsNull = false), nullable = true),
@@ -756,13 +754,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
           StructField("l", LongType, nullable = true)))))))
 
     assert(ret.resolved === true)
-    checkEvaluation(ret, InternalRow(
+    checkEvaluation(ret, Row(
       Seq(123, null, null),
-      Map(
-        UTF8String.fromString("a") -> true,
-        UTF8String.fromString("b") -> true,
-        UTF8String.fromString("c") -> false),
-      InternalRow(0L)))
+      Map("a" -> true, "b" -> true, "c" -> false),
+      Row(0L)))
   }
 
   test("case between string and interval") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5de5ddce97..3fa246b69d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     expr.dataType match {
       case ArrayType(StructType(fields), containsNull) =>
         val field = fields.find(_.name == fieldName).get
-        GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+        GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index aeeb0e4527..f26f41fb75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -158,8 +158,8 @@ package object debug {
       case (row: InternalRow, StructType(fields)) =>
         row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
-      case (s: Seq[_], ArrayType(elemType, _)) =>
-        s.foreach(typeCheck(_, elemType))
+      case (a: ArrayData, ArrayType(elemType, _)) =>
+        a.toArray().foreach(typeCheck(_, elemType))
       case (m: Map[_, _], MapType(keyType, valueType, _)) =>
         m.keys.foreach(typeCheck(_, keyType))
         m.values.foreach(typeCheck(_, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 3c38916fd7..ef1c6e57dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -134,8 +134,19 @@ object EvaluatePython {
       }
       new GenericInternalRowWithSchema(values, struct)
 
-    case (seq: Seq[Any], array: ArrayType) =>
-      seq.map(x => toJava(x, array.elementType)).asJava
+    case (a: ArrayData, array: ArrayType) =>
+      val length = a.numElements()
+      val values = new java.util.ArrayList[Any](length)
+      var i = 0
+      while (i < length) {
+        if (a.isNullAt(i)) {
+          values.add(null)
+        } else {
+          values.add(toJava(a.get(i), array.elementType))
+        }
+        i += 1
+      }
+      values
 
     case (obj: Map[_, _], mt: MapType) => obj.map {
       case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
@@ -190,10 +201,10 @@ object EvaluatePython {
     case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
 
     case (c: java.util.List[_], ArrayType(elementType, _)) =>
-      c.map { e => fromJava(e, elementType)}.toSeq
+      new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray)
 
     case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-      c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
+      new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
 
     case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
       case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 78da2840da..9329148aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame}
 
 private[sql] object FrequentItems extends Logging {
@@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging {
         baseCounts
       }
     )
-    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
+    val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
     val resultRow = InternalRow(justItems : _*)
     // append frequent Items to the column name for easy debugging
     val outputCols = colInfo.map { v =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 0eb3b04007..04ab5e2217 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -125,7 +125,7 @@ private[sql] object InferSchema {
    * Convert NullType to StringType and remove StructTypes with no fields
    */
   private def canonicalizeType: DataType => Option[DataType] = {
-    case at@ArrayType(elementType, _) =>
+    case at @ ArrayType(elementType, _) =>
       for {
         canonicalType <- canonicalizeType(elementType)
       } yield {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 381e7ed544..1c309f8794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -110,8 +110,13 @@ private[sql] object JacksonParser {
     case (START_OBJECT, st: StructType) =>
       convertObject(factory, parser, st)
 
+    case (START_ARRAY, st: StructType) =>
+      // SPARK-3308: support reading top level JSON arrays and take every element
+      // in such an array as a row
+      convertArray(factory, parser, st)
+
     case (START_ARRAY, ArrayType(st, _)) =>
-      convertList(factory, parser, st)
+      convertArray(factory, parser, st)
 
     case (START_OBJECT, ArrayType(st, _)) =>
       // the business end of SPARK-3308:
@@ -165,16 +170,16 @@ private[sql] object JacksonParser {
     builder.result()
   }
 
-  private def convertList(
+  private def convertArray(
       factory: JsonFactory,
       parser: JsonParser,
-      schema: DataType): Seq[Any] = {
-    val builder = Seq.newBuilder[Any]
+      elementType: DataType): ArrayData = {
+    val values = scala.collection.mutable.ArrayBuffer.empty[Any]
     while (nextUntil(parser, JsonToken.END_ARRAY)) {
-      builder += convertField(factory, parser, schema)
+      values += convertField(factory, parser, elementType)
     }
 
-    builder.result()
+    new GenericArrayData(values.toArray)
   }
 
   private def parseJson(
@@ -201,12 +206,15 @@ private[sql] object JacksonParser {
           val parser = factory.createParser(record)
           parser.nextToken()
 
-          // to support both object and arrays (see SPARK-3308) we'll start
-          // by converting the StructType schema to an ArrayType and let
-          // convertField wrap an object into a single value array when necessary.
-          convertField(factory, parser, ArrayType(schema)) match {
+          convertField(factory, parser, schema) match {
             case null => failedRecord(record)
-            case list: Seq[InternalRow @unchecked] => list
+            case row: InternalRow => row :: Nil
+            case array: ArrayData =>
+              if (array.numElements() == 0) {
+                Nil
+              } else {
+                array.toArray().map(_.asInstanceOf[InternalRow])
+              }
            case _ => sys.error(
              s"Failed to parse record $record. Please make sure that each line of the file " +
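The Python-side change above is the same null-preserving copy loop this patch uses everywhere. A standalone sketch of that pattern at the Py4J boundary; the helper name toJavaList is ours for illustration, not a Spark API:

    import org.apache.spark.sql.types.{ArrayData, GenericArrayData}

    // Copy an ArrayData into a java.util.ArrayList, keeping nulls intact,
    // mirroring the shape of the EvaluatePython.toJava hunk above.
    def toJavaList(a: ArrayData): java.util.ArrayList[Any] = {
      val out = new java.util.ArrayList[Any](a.numElements())
      var i = 0
      while (i < a.numElements()) {
        out.add(if (a.isNullAt(i)) null else a.get(i))
        i += 1
      }
      out
    }

    toJavaList(new GenericArrayData(Array[Any]("x", null)))  // ["x", null]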
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index e00bd90edb..172db8362a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter(
 
     override def getConverter(fieldIndex: Int): Converter = elementConverter
 
-    override def end(): Unit = updater.set(currentArray)
+    override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
 
     // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the
     // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index ea51650fe9..2332a36468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.parquet
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.ArrayData
 
 // TODO Removes this while fixing SPARK-8848
 private[sql] object CatalystConverter {
@@ -32,7 +33,7 @@ private[sql] object CatalystConverter {
   val MAP_SCHEMA_NAME = "map"
 
   // TODO: consider using Array[T] for arrays to avoid boxing of primitive types
-  type ArrayScalaType[T] = Seq[T]
+  type ArrayScalaType[T] = ArrayData
   type StructScalaType[T] = InternalRow
   type MapScalaType[K, V] = Map[K, V]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 78ecfad1d5..79dd16b7b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging {
       array: CatalystConverter.ArrayScalaType[_]): Unit = {
     val elementType = schema.elementType
     writer.startGroup()
-    if (array.size > 0) {
+    if (array.numElements() > 0) {
       if (schema.containsNull) {
         writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
         var i = 0
-        while (i < array.size) {
+        while (i < array.numElements()) {
           writer.startGroup()
-          if (array(i) != null) {
+          if (!array.isNullAt(i)) {
             writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
-            writeValue(elementType, array(i))
+            writeValue(elementType, array.get(i))
             writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
           }
           writer.endGroup()
@@ -164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging {
       } else {
         writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
         var i = 0
-        while (i < array.size) {
-          writeValue(elementType, array(i))
+        while (i < array.numElements()) {
+          writeValue(elementType, array.get(i))
           i = i + 1
         }
         writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 72c42f4fe3..9e61d06f40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -30,7 +30,6 @@ import org.junit.*;
 
 import scala.collection.JavaConversions;
 import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
 
 import java.io.Serializable;
 import java.util.Arrays;
@@ -168,10 +167,10 @@ public class JavaDataFrameSuite {
     for (int i = 0; i < result.length(); i++) {
       Assert.assertEquals(bean.getB()[i], result.apply(i));
     }
-    Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+    Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
     Assert.assertArrayEquals(
       bean.getC().get("hello"),
-      Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer)));
+      Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer)));
     Seq<String> d = first.getAs(3);
     Assert.assertEquals(bean.getD().size(), d.length());
     for (int i = 0; i < d.length(); i++) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 45c9f06941..77ed4a9c0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
   override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): ArrayData = {
     obj match {
       case features: MyDenseVector =>
-        features.data.toSeq
+        new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
     }
   }
 
   override def deserialize(datum: Any): MyDenseVector = {
     datum match {
-      case data: Seq[_] =>
-        new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+      case data: ArrayData =>
+        new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 5e189c3563..cfb03ff485 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -67,12 +67,12 @@ case class AllDataTypesScan(
 
   override def schema: StructType = userSpecifiedSchema
 
-  override def needConversion: Boolean = false
+  override def needConversion: Boolean = true
 
   override def buildScan(): RDD[Row] = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>
-      InternalRow(
-        UTF8String.fromString(s"str_$i"),
+      Row(
+        s"str_$i",
         s"str_$i".getBytes(),
         i % 2 == 0,
         i.toByte,
@@ -81,19 +81,19 @@ case class AllDataTypesScan(
         i.toLong,
         i.toFloat,
         i.toDouble,
-        Decimal(new java.math.BigDecimal(i)),
-        Decimal(new java.math.BigDecimal(i)),
-        DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)),
-        DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
-        UTF8String.fromString(s"varchar_$i"),
+        new java.math.BigDecimal(i),
+        new java.math.BigDecimal(i),
+        new Date(1970, 1, 1),
+        new Timestamp(20000 + i),
+        s"varchar_$i",
         Seq(i, i + 1),
-        Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
-        Map(i -> UTF8String.fromString(i.toString)),
-        Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
-        InternalRow(i, UTF8String.fromString(i.toString)),
-        InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
-          InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
-    }.asInstanceOf[RDD[Row]]
+        Seq(Map(s"str_$i" -> Row(i.toLong))),
+        Map(i -> i.toString),
+        Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+        Row(i, i.toString),
+        Row(Seq(s"str_$i", s"str_${i + 1}"),
+          Row(Seq(new Date(1970, 1, i + 1)))))
+    }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index f467500259..5926ef9aa3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -52,9 +52,8 @@ import scala.collection.JavaConversions._
  *     java.sql.Timestamp
  *   Complex Types =>
  *     Map: scala.collection.immutable.Map
- *     List: scala.collection.immutable.Seq
- *     Struct:
- *       [[org.apache.spark.sql.catalyst.InternalRow]]
+ *     List: [[org.apache.spark.sql.types.ArrayData]]
+ *     Struct: [[org.apache.spark.sql.catalyst.InternalRow]]
  *     Union: NOT SUPPORTED YET
  *  The Complex types plays as a container, which can hold arbitrary data types.
  *
@@ -297,7 +296,10 @@ private[hive] trait HiveInspectors {
       }.toMap
     case li: StandardConstantListObjectInspector =>
       // take the value from the list inspector object, rather than the input data
-      li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
+      val values = li.getWritableConstantValue
+        .map(unwrap(_, li.getListElementObjectInspector))
+        .toArray
+      new GenericArrayData(values)
     // if the value is null, we don't care about the object inspector type
     case _ if data == null => null
     case poi: VoidObjectInspector => null // always be null for void object inspector
@@ -339,7 +341,10 @@ private[hive] trait HiveInspectors {
       }
     case li: ListObjectInspector =>
       Option(li.getList(data))
-        .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
+        .map { l =>
+          val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray
+          new GenericArrayData(values)
+        }
        .orNull
     case mi: MapObjectInspector =>
       Option(mi.getMap(data)).map(
@@ -391,7 +396,13 @@ private[hive] trait HiveInspectors {
 
     case loi: ListObjectInspector =>
       val wrapper = wrapperFor(loi.getListElementObjectInspector)
-      (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null
+      (o: Any) => {
+        if (o != null) {
+          seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper))
+        } else {
+          null
+        }
+      }
 
     case moi: MapObjectInspector =>
       // The Predef.Map is scala.collection.immutable.Map.
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors {
       case x: ListObjectInspector =>
         val list = new java.util.ArrayList[Object]
         val tpe = dataType.asInstanceOf[ArrayType].elementType
-        a.asInstanceOf[Seq[_]].foreach {
+        a.asInstanceOf[ArrayData].toArray().foreach {
           v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
         }
         list
@@ -634,7 +645,8 @@ private[hive] trait HiveInspectors {
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
       } else {
         val list = new java.util.ArrayList[Object]()
-        value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
+        value.asInstanceOf[ArrayData].toArray()
+          .foreach(v => list.add(wrap(v, listObjectInspector, dt)))
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
       }
     case Literal(value, MapType(keyType, valueType, _)) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 741c705e2a..7e3342cc84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -176,13 +176,13 @@ case class ScriptTransformation(
         val prevLine = curLine
         curLine = reader.readLine()
         if (!ioschema.schemaLess) {
-          new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
-            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
-            .asInstanceOf[Array[Any]])
+          new GenericInternalRow(
+            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+              .map(CatalystTypeConverters.convertToCatalyst))
         } else {
-          new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
-            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
-            .asInstanceOf[Array[Any]])
+          new GenericInternalRow(
+            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+              .map(CatalystTypeConverters.convertToCatalyst))
         }
       } else {
         val ret = deserialize()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8732e9abf8..4a13022edd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction(
       // if pivotResult is true, we will get a Seq having the same size with the size
       // of the window frame. At here, we will return the result at the position of
       // index in the output buffer.
-      outputBuffer.asInstanceOf[Seq[Any]].get(index)
+      outputBuffer.asInstanceOf[ArrayData].get(index)
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 0330013f53..f719f2e06a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
 
   test("wrap / unwrap Array Type") {
     val dt = ArrayType(dataTypes(0))
-    val d = row(0) :: row(0) :: Nil
+    val d = new GenericArrayData(Array(row(0), row(0)))
 
     checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
     checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
    checkValue(d,