From 2eca46a17a3d46a605804ff89c010017da91e1bc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 11:15:37 -0700 Subject: Revert "[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info" This reverts commit 6e009cb9c4d7a395991e10dab427f37019283758. --- .../apache/spark/sql/catalyst/InternalRow.scala | 132 +++++++++++++++++++-- .../sql/catalyst/expressions/Projection.scala | 12 +- .../catalyst/expressions/SpecificMutableRow.scala | 5 +- .../expressions/codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +-------------------- .../catalyst/expressions/CodeGenerationSuite.scala | 2 +- .../apache/spark/sql/columnar/ColumnStats.scala | 51 ++++---- .../sql/columnar/InMemoryColumnarTableScan.scala | 11 +- .../apache/spark/sql/execution/debug/package.scala | 4 +- .../org/apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 ++++----- .../org/apache/spark/sql/hive/HiveInspectors.scala | 6 +- .../sql/hive/execution/ScriptTransformation.scala | 21 +--- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../apache/spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 217 insertions(+), 259 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 85b4bf3b6a..7d17cca808 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -31,6 +32,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString + override def toString: String = mkString("[", ",", "]") + /** * Make a copy of the current [[InternalRow]] object. */ @@ -47,25 +50,136 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } + // Subclasses of InternalRow should implement all special getters and equals/hashCode, + // or implement this genericGet. 
+ protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( + "Concrete internal rows should implement genericGet, " + + "or implement all special getters and equals/hashCode") + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
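Annotation: the hunk above moves the generic-getter machinery back onto InternalRow itself. Every typed getter funnels through genericGet, equality special-cases byte arrays and NaN, and hashCode mirrors the generated code's 37-based mixing. The standalone sketch below (class name and simplified hierarchy are hypothetical, not Spark's) illustrates the equality semantics these restored defaults guarantee:

    class DemoRow(val values: Array[Any]) {
      def numFields: Int = values.length
      protected def genericGet(i: Int): Any = values(i)

      override def equals(o: Any): Boolean = o match {
        case other: DemoRow if other.numFields == numFields =>
          values.indices.forall { i =>
            (genericGet(i), other.genericGet(i)) match {
              case (null, null) => true
              case (b1: Array[Byte], b2: Array[Byte]) =>
                java.util.Arrays.equals(b1, b2)        // content, not reference
              case (f1: Float, f2: Float) =>
                java.lang.Float.compare(f1, f2) == 0   // NaN compares equal
              case (d1: Double, d2: Double) =>
                java.lang.Double.compare(d1, d2) == 0
              case (v1, v2) => v1 == v2
            }
          }
        case _ => false
      }
    }

    // Byte arrays compare by content and NaN compares equal, which plain
    // reference equality (or case-class equality over Array fields) would not give:
    val a = new DemoRow(Array(Array[Byte](1, 2), Double.NaN))
    val b = new DemoRow(Array(Array[Byte](1, 2), Double.NaN))
    assert(a == b)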
*/ - def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - val len = numFields - assert(len == fieldTypes.length) - - val values = new Array[Any](len) + // todo: remove this as it needs the generic getter + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) var i = 0 - while (i < len) { - values(i) = get(i, fieldTypes(i)) + while (i < n) { + values.update(i, genericGet(i)) i += 1 } values } - def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 59ce7fc4f2..4296b4b123 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,11 +203,7 @@ class JoinedRow extends InternalRow { this } - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - assert(fieldTypes.length == row1.numFields + row2.numFields) - val (left, right) = fieldTypes.splitAt(row1.numFields) - row1.toSeq(left) ++ row2.toSeq(right) - } + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq override def numFields: Int = row1.numFields + row2.numFields @@ -280,11 +276,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.toString + row2.mkString("[", ",", "]") } else if (row2 eq null) { - row1.toString + row1.mkString("[", ",", "]") } else { - s"{${row1.toString} + ${row2.toString}}" + mkString("[", ",", "]") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4f56f94bd4..b94df6bd66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,8 +192,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
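Annotation: with the typed toSeq(fieldTypes) removed, toSeq and the mkString helpers lean entirely on genericGet, which is why the restored code carries the "todo: remove this as it needs the generic getter" comment: rows that store fields in binary form (such as UnsafeRow) can only materialize a value when told its type. A hedged sketch of the two shapes, with hypothetical trait names, to show what the typed variant bought:

    import org.apache.spark.sql.types.DataType

    trait TypedRows {
      def numFields: Int
      def get(ordinal: Int, dataType: DataType): AnyRef
      // typed variant: works even when fields are stored as raw bytes
      def toSeq(fieldTypes: Seq[DataType]): Seq[Any] =
        fieldTypes.indices.map(i => get(i, fieldTypes(i)))
    }

    trait GenericRows {
      def numFields: Int
      protected def genericGet(ordinal: Int): Any
      // generic variant: requires boxed values to already exist
      def toSeq: Seq[Any] = (0 until numFields).map(genericGet)
    }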
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) - extends MutableRow with BaseGenericInternalRow { +final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { def this(dataTypes: Seq[DataType]) = this( @@ -214,6 +213,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) override def numFields: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed) + override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d82..c04fe734d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -26,8 +25,6 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -174,7 +171,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -187,8 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - @Override - public Object genericGet(int i) { + protected Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 207e667792..7657fb535d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,130 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** - * An extended version of [[InternalRow]] that implements all special getters, toString - * and equals/hashCode by `genericGet`. 
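Annotation: SpecificMutableRow regains its own toSeq above because each column lives in a specialized mutable box, so the row can produce boxed values without consulting field types. A standalone sketch of that pattern (the box below is a simplified stand-in for Spark's MutableValue subclasses):

    final class MutableIntBox {
      var isNull: Boolean = true
      var value: Int = 0
      def boxed: Any = if (isNull) null else value
    }

    val columns = Array.fill(3)(new MutableIntBox)
    columns(0).isNull = false
    columns(0).value = 42

    // what the restored override computes, with no DataType argument needed:
    val asSeq: Seq[Any] = columns.map(_.boxed).toSeq   // Seq(42, null, null)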
- */ -trait BaseGenericInternalRow extends InternalRow { - - protected def genericGet(ordinal: Int): Any - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def toString(): String = { - if (numFields == 0) { - "[empty row]" - } else { - val sb = new StringBuilder - sb.append("[") - sb.append(genericGet(0)) - val len = numFields - var i = 1 - while (i < len) { - sb.append(",") - sb.append(genericGet(i)) - i += 1 - } - sb.append("]") - sb.toString() - } - } - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[BaseGenericInternalRow]) { - return false - } - - val other = o.asInstanceOf[BaseGenericInternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } -} - /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
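Annotation: the large deletion above is the mirror image of the InternalRow hunk at the top of the patch. The getters, equals, and hashCode this trait supplied as an opt-in mixin now sit directly on the abstract base class, where genericGet throws unless overridden. A sketch of the two shapes, with hypothetical names, to make the structural difference concrete:

    // Mixin shape (what this patch removes): concrete rows opt in explicitly,
    // and rows that skip the mixin carry no generic path at all.
    trait GenericBehaviour {
      protected def genericGet(ordinal: Int): Any
      // ... default getters, equals, hashCode built on genericGet ...
    }

    // Base-class shape (what this patch restores): every row inherits the slow
    // defaults, and forgetting to override genericGet fails only at runtime.
    abstract class Row {
      protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
        "implement genericGet, or all special getters and equals/hashCode")
    }

    class BoxedRow(values: Array[Any]) extends Row {
      override protected def genericGet(ordinal: Int): Any = values(ordinal)
    }

The base-class shape trades a compile-time obligation for a runtime error, which is exactly what the IllegalStateException message in the restored InternalRow spells out.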
* Setting a value through a primitive function implicitly marks that column as not null. @@ -206,7 +82,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -214,7 +90,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length @@ -233,7 +109,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -241,7 +117,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5..e310aee221 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericMutableRow(length)).toSeq val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 5cbd52bc05..af1a8ecca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
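Annotation: every ColumnStats implementation below packs its result into the same five positions: lower bound, upper bound, null count, row count, and size in bytes. A hedged, self-contained sketch of the bookkeeping for an int column (4 bytes per value, nulls included, matching what the test suite later in this patch asserts):

    final class DemoIntStats {
      private var lower = Int.MaxValue
      private var upper = Int.MinValue
      private var nullCount = 0L
      private var count = 0L
      private var sizeInBytes = 0L

      def gather(value: Option[Int]): Unit = {
        count += 1
        sizeInBytes += 4            // 4 bytes per int slot, null or not
        value match {
          case None => nullCount += 1
          case Some(v) =>
            if (v < lower) lower = v
            if (v > upper) upper = v
        }
      }

      // same layout the patch builds with InternalRow(...)
      def collected: Seq[Any] = Seq(lower, upper, nullCount, count, sizeInBytes)
    }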
*/ - def collectedStatistics: GenericInternalRow + def collectedStatistics: InternalRow } /** @@ -75,8 +75,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) + override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,8 +92,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -111,8 +110,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -129,8 +128,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { @@ -147,8 +146,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -165,8 +164,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -183,8 +182,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -201,8 +200,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { @@ -219,8 +218,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, 
sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -231,8 +230,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -249,8 +248,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -263,8 +262,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d553bb6169..5d5b0697d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,11 +330,10 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { - case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") + def statsString: String = relation.partitionStatistics.schema + .zip(cachedBatch.stats.toSeq) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dd3858ea2b..c37007f1ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, s: StructType) => - row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, StructType(fields)) => + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, ArrayType(elemType, _)) => 
a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c04557e5a0..7126145ddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => - Literal.create(values.get(i, dt), dt) + val literals = values.toSeq.zip(partitionColumnTypes).map { + case (value, dataType) => Literal.create(value, dataType) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index d0430d2a60..16e0187ed2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,36 +19,33 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) + InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + InternalRow(Double.MaxValue, Double.MinValue, 0)) + 
testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: InternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -64,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -76,15 +73,14 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -100,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9824dad239..39d798d072 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,10 +390,8 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { - case ((field, wrapper), i) => - soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ade27454b9..a6a343d395 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,7 +88,6 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, - input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -202,7 +201,6 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], - inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -228,25 +226,12 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true - val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = if (len == 0) { - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") - } else { - val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) - var i = 1 - while (i < len) { - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) - i += 1 - } - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) - sb.toString() - } - outputStream.write(data.getBytes("utf-8")) + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8dc796b056..684ea1d137 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - 
FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + val dynamicPartPath = dynamicPartColNames + .zip(row.toSeq.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$col=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 81a70b8d42..99e95fb921 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { - row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { + row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,10 +211,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues( - row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], - dt) + checkValues(row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } -- cgit v1.2.3
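Annotation: in the InMemoryColumnarTableScan hunks, each column contributes one five-field stats row and the batch-level row is their concatenation, now produced with flatMap(_.toSeq) instead of reading the backing values array. A plain-Seq sketch of the shape (values illustrative):

    // per-column stats rows: (lower, upper, nullCount, count, sizeInBytes)
    val intColStats    = Seq[Any](1, 9, 0L, 10L, 40L)
    val stringColStats = Seq[Any]("a", "z", 2L, 10L, 80L)

    // one flat row per cached batch; partitionStatistics.schema names each position
    val batchStats: Seq[Any] = Seq(intColStats, stringColStats).flatten
    // batchStats.length == 10; the rewritten statsString zips these against the schema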
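Annotation: the rewritten ColumnStatsSuite reads bounds with stats.get(i, null). That works only because the restored default get on InternalRow ignores its dataType argument and simply boxes genericGet's result, so passing null is harmless for generic rows (it would not be for a binary row like UnsafeRow). A sketch under that assumption:

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

    val stats = new GenericInternalRow(Array[Any](1, 9, 0L, 10L, 40L))
    // the restored default is `get(ordinal, dataType) = getAs(ordinal)`,
    // which never touches dataType:
    val lower = stats.get(0, null)   // 1, equivalent to genericGet(0)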
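Annotation: the ScriptTransformation hunk collapses the removed StringBuilder loop into a single row.mkString(start, sep, end) call, joining fields with the TOK_TABLEROWFORMATFIELD string and terminating with the TOK_TABLEROWFORMATLINES string. A sketch of the equivalence, with Hive's usual defaults standing in for the ioschema lookups (an assumption, purely illustrative):

    val fieldSep = "\t"   // assumed stand-in for ioschema("TOK_TABLEROWFORMATFIELD")
    val lineSep  = "\n"   // assumed stand-in for ioschema("TOK_TABLEROWFORMATLINES")
    val fields   = Seq[Any](1, "foo", null)

    // what the removed loop built by hand:
    val manual = {
      val sb = new StringBuilder
      sb.append(fields.head)
      fields.tail.foreach { f => sb.append(fieldSep).append(f) }
      sb.append(lineSep)
      sb.toString
    }

    // what the restored one-liner produces:
    val viaMkString = fields.mkString("", fieldSep, lineSep)
    assert(manual == viaMkString)   // both are "1\tfoo\tnull\n"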
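Annotation: in hiveWriterContainers, the dynamic-partition path is now built by zipping the partition column names with the trailing values of the row, relying on the convention that dynamic partition columns come last. A standalone sketch (column names, values, and the default partition name are illustrative assumptions; the FileUtils.escapePathName step is elided):

    val dynamicPartColNames = Seq("year", "month")
    val defaultPartName = "__HIVE_DEFAULT_PARTITION__"   // Hive's usual default, assumed
    val rowValues = Seq[Any]("alice", 3L, 2015, null)    // data columns first

    val dynamicPartPath = dynamicPartColNames
      .zip(rowValues.takeRight(dynamicPartColNames.length))
      .map { case (col, raw) =>
        val str = if (raw == null) null else raw.toString
        val segment = if (str == null || str.isEmpty) defaultPartName else str
        s"/$col=$segment"
      }
      .mkString
    // dynamicPartPath == "/year=2015/month=__HIVE_DEFAULT_PARTITION__"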