aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2015-08-06 11:15:37 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-06 11:15:37 -0700
commit2eca46a17a3d46a605804ff89c010017da91e1bc (patch)
tree3ee8fa52d14bce8b62e152da4aa560eae780338b
parent6e009cb9c4d7a395991e10dab427f37019283758 (diff)
downloadspark-2eca46a17a3d46a605804ff89c010017da91e1bc.tar.gz
spark-2eca46a17a3d46a605804ff89c010017da91e1bc.tar.bz2
spark-2eca46a17a3d46a605804ff89c010017da91e1bc.zip
Revert "[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info"
This reverts commit 6e009cb9c4d7a395991e10dab427f37019283758.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala132
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala132
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala51
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala54
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala6
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala21
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala24
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala10
15 files changed, 217 insertions, 259 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 85b4bf3b6a..7d17cca808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -31,6 +32,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
// This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString
+ override def toString: String = mkString("[", ",", "]")
+
/**
* Make a copy of the current [[InternalRow]] object.
*/
@@ -47,25 +50,136 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
false
}
+ // Subclasses of InternalRow should implement all special getters and equals/hashCode,
+ // or implement this genericGet.
+ protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
+ "Concrete internal rows should implement genericGet, " +
+ "or implement all special getters and equals/hashCode")
+
+ // default implementation (slow)
+ private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+ override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[InternalRow]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[InternalRow]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numFields
+ if (len != other.numFields) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = genericGet(i)
+ val o2 = other.genericGet(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ // Custom hashCode function that matches the efficient code generated version.
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numFields
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ genericGet(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+
/* ---------------------- utility methods for Scala ---------------------- */
/**
* Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
- def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
- val len = numFields
- assert(len == fieldTypes.length)
-
- val values = new Array[Any](len)
+ // todo: remove this as it needs the generic getter
+ def toSeq: Seq[Any] = {
+ val n = numFields
+ val values = new Array[Any](n)
var i = 0
- while (i < len) {
- values(i) = get(i, fieldTypes(i))
+ while (i < n) {
+ values.update(i, genericGet(i))
i += 1
}
values
}
- def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType))
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
object InternalRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 59ce7fc4f2..4296b4b123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -203,11 +203,7 @@ class JoinedRow extends InternalRow {
this
}
- override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
- assert(fieldTypes.length == row1.numFields + row2.numFields)
- val (left, right) = fieldTypes.splitAt(row1.numFields)
- row1.toSeq(left) ++ row2.toSeq(right)
- }
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
override def numFields: Int = row1.numFields + row2.numFields
@@ -280,11 +276,11 @@ class JoinedRow extends InternalRow {
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
} else if (row1 eq null) {
- row2.toString
+ row2.mkString("[", ",", "]")
} else if (row2 eq null) {
- row1.toString
+ row1.mkString("[", ",", "]")
} else {
- s"{${row1.toString} + ${row2.toString}}"
+ mkString("[", ",", "]")
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 4f56f94bd4..b94df6bd66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -192,8 +192,7 @@ final class MutableAny extends MutableValue {
* based on the dataTypes of each column. The intent is to decrease garbage when modifying the
* values of primitive columns.
*/
-final class SpecificMutableRow(val values: Array[MutableValue])
- extends MutableRow with BaseGenericInternalRow {
+final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
def this(dataTypes: Seq[DataType]) =
this(
@@ -214,6 +213,8 @@ final class SpecificMutableRow(val values: Array[MutableValue])
override def numFields: Int = values.length
+ override def toSeq: Seq[Any] = values.map(_.boxed)
+
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c744e84d82..c04fe734d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -26,8 +25,6 @@ import org.apache.spark.sql.types._
*/
abstract class BaseProjection extends Projection {}
-abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow
-
/**
* Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input
* [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]]
@@ -174,7 +171,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
return new SpecificRow((InternalRow) r);
}
- final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} {
+ final class SpecificRow extends ${classOf[MutableRow].getName} {
$columns
@@ -187,8 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
- @Override
- public Object genericGet(int i) {
+ protected Object genericGet(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 207e667792..7657fb535d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -22,130 +22,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
/**
- * An extended version of [[InternalRow]] that implements all special getters, toString
- * and equals/hashCode by `genericGet`.
- */
-trait BaseGenericInternalRow extends InternalRow {
-
- protected def genericGet(ordinal: Int): Any
-
- // default implementation (slow)
- private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
- override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
- override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
- override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
- override def getByte(ordinal: Int): Byte = getAs(ordinal)
- override def getShort(ordinal: Int): Short = getAs(ordinal)
- override def getInt(ordinal: Int): Int = getAs(ordinal)
- override def getLong(ordinal: Int): Long = getAs(ordinal)
- override def getFloat(ordinal: Int): Float = getAs(ordinal)
- override def getDouble(ordinal: Int): Double = getAs(ordinal)
- override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
- override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
- override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
- override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
- override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
- override def getMap(ordinal: Int): MapData = getAs(ordinal)
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
- override def toString(): String = {
- if (numFields == 0) {
- "[empty row]"
- } else {
- val sb = new StringBuilder
- sb.append("[")
- sb.append(genericGet(0))
- val len = numFields
- var i = 1
- while (i < len) {
- sb.append(",")
- sb.append(genericGet(i))
- i += 1
- }
- sb.append("]")
- sb.toString()
- }
- }
-
- override def equals(o: Any): Boolean = {
- if (!o.isInstanceOf[BaseGenericInternalRow]) {
- return false
- }
-
- val other = o.asInstanceOf[BaseGenericInternalRow]
- if (other eq null) {
- return false
- }
-
- val len = numFields
- if (len != other.numFields) {
- return false
- }
-
- var i = 0
- while (i < len) {
- if (isNullAt(i) != other.isNullAt(i)) {
- return false
- }
- if (!isNullAt(i)) {
- val o1 = genericGet(i)
- val o2 = other.genericGet(i)
- o1 match {
- case b1: Array[Byte] =>
- if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
- return false
- }
- case f1: Float if java.lang.Float.isNaN(f1) =>
- if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
- return false
- }
- case d1: Double if java.lang.Double.isNaN(d1) =>
- if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
- return false
- }
- case _ => if (o1 != o2) {
- return false
- }
- }
- }
- i += 1
- }
- true
- }
-
- // Custom hashCode function that matches the efficient code generated version.
- override def hashCode: Int = {
- var result: Int = 37
- var i = 0
- val len = numFields
- while (i < len) {
- val update: Int =
- if (isNullAt(i)) {
- 0
- } else {
- genericGet(i) match {
- case b: Boolean => if (b) 0 else 1
- case b: Byte => b.toInt
- case s: Short => s.toInt
- case i: Int => i
- case l: Long => (l ^ (l >>> 32)).toInt
- case f: Float => java.lang.Float.floatToIntBits(f)
- case d: Double =>
- val b = java.lang.Double.doubleToLongBits(d)
- (b ^ (b >>> 32)).toInt
- case a: Array[Byte] => java.util.Arrays.hashCode(a)
- case other => other.hashCode()
- }
- }
- result = 37 * result + update
- i += 1
- }
- result
- }
-}
-
-/**
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
* Setting a value through a primitive function implicitly marks that column as not null.
*/
@@ -206,7 +82,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
* Note that, while the array is not copied, and thus could technically be mutated after creation,
* this is not allowed.
*/
-class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow {
+class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
@@ -214,7 +90,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri
override protected def genericGet(ordinal: Int) = values(ordinal)
- override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values
+ override def toSeq: Seq[Any] = values
override def numFields: Int = values.length
@@ -233,7 +109,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow {
+class GenericMutableRow(values: Array[Any]) extends MutableRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
@@ -241,7 +117,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI
override protected def genericGet(ordinal: Int) = values(ordinal)
- override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values
+ override def toSeq: Seq[Any] = values
override def numFields: Int = values.length
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e323467af5..e310aee221 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
val plan = GenerateMutableProjection.generate(expressions)()
- val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
+ val actual = plan(new GenericMutableRow(length)).toSeq
val expected = Seq.fill(length)(true)
if (!checkResult(actual, expected)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 5cbd52bc05..af1a8ecca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
* Column statistics represented as a single row, currently including closed lower bound, closed
* upper bound and null count.
*/
- def collectedStatistics: GenericInternalRow
+ def collectedStatistics: InternalRow
}
/**
@@ -75,8 +75,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
private[sql] class NoopColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
+ override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L)
}
private[sql] class BooleanColumnStats extends ColumnStats {
@@ -93,8 +92,8 @@ private[sql] class BooleanColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class ByteColumnStats extends ColumnStats {
@@ -111,8 +110,8 @@ private[sql] class ByteColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class ShortColumnStats extends ColumnStats {
@@ -129,8 +128,8 @@ private[sql] class ShortColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class IntColumnStats extends ColumnStats {
@@ -147,8 +146,8 @@ private[sql] class IntColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class LongColumnStats extends ColumnStats {
@@ -165,8 +164,8 @@ private[sql] class LongColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class FloatColumnStats extends ColumnStats {
@@ -183,8 +182,8 @@ private[sql] class FloatColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class DoubleColumnStats extends ColumnStats {
@@ -201,8 +200,8 @@ private[sql] class DoubleColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class StringColumnStats extends ColumnStats {
@@ -219,8 +218,8 @@ private[sql] class StringColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class BinaryColumnStats extends ColumnStats {
@@ -231,8 +230,8 @@ private[sql] class BinaryColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(null, null, nullCount, count, sizeInBytes)
}
private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
@@ -249,8 +248,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
@@ -263,8 +262,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+ override def collectedStatistics: InternalRow =
+ InternalRow(null, null, nullCount, count, sizeInBytes)
}
private[sql] class DateColumnStats extends IntColumnStats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d553bb6169..5d5b0697d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation(
}
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
- .flatMap(_.values))
+ .flatMap(_.toSeq))
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -330,11 +330,10 @@ private[sql] case class InMemoryColumnarTableScan(
if (inMemoryPartitionPruningEnabled) {
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
- def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map {
- case (a, i) =>
- val value = cachedBatch.stats.get(i, a.dataType)
- s"${a.name}: $value"
- }.mkString(", ")
+ def statsString: String = relation.partitionStatistics.schema
+ .zip(cachedBatch.stats.toSeq)
+ .map { case (a, s) => s"${a.name}: $s" }
+ .mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
false
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index dd3858ea2b..c37007f1ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -156,8 +156,8 @@ package object debug {
def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
case (null, _) =>
- case (row: InternalRow, s: StructType) =>
- row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
+ case (row: InternalRow, StructType(fields)) =>
+ row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (a: ArrayData, ArrayType(elemType, _)) =>
a.foreach(elemType, (_, e) => {
typeCheck(e, elemType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c04557e5a0..7126145ddc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
val spec = discoverPartitions()
val partitionColumnTypes = spec.partitionColumns.map(_.dataType)
val castedPartitions = spec.partitions.map { case p @ Partition(values, path) =>
- val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) =>
- Literal.create(values.get(i, dt), dt)
+ val literals = values.toSeq.zip(partitionColumnTypes).map {
+ case (value, dataType) => Literal.create(value, dataType)
}
val castedValues = partitionSchema.zip(literals).map { case (field, literal) =>
Cast(literal, field.dataType).eval()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index d0430d2a60..16e0187ed2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,36 +19,33 @@ package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
class ColumnStatsSuite extends SparkFunSuite {
- testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
- testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
- testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
- testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0))
+ testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
- createRow(Long.MaxValue, Long.MinValue, 0))
- testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
+ InternalRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
- createRow(Double.MaxValue, Double.MinValue, 0))
- testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
- testDecimalColumnStats(createRow(null, null, 0))
-
- def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)
+ InternalRow(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
+ testDecimalColumnStats(InternalRow(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: GenericInternalRow): Unit = {
+ initialStatistics: InternalRow): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+ columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -64,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
- assertResult(10, "Wrong null count")(stats.values(2))
- assertResult(20, "Wrong row count")(stats.values(3))
- assertResult(stats.values(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+ assertResult(10, "Wrong null count")(stats.get(2, null))
+ assertResult(20, "Wrong row count")(stats.get(3, null))
+ assertResult(stats.get(4, null), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
@@ -76,15 +73,14 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
- def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
- initialStatistics: GenericInternalRow): Unit = {
+ def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) {
val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
val columnType = FIXED_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new FixedDecimalColumnStats(15, 10)
- columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+ columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -100,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
- assertResult(10, "Wrong null count")(stats.values(2))
- assertResult(20, "Wrong row count")(stats.values(3))
- assertResult(stats.values(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+ assertResult(10, "Wrong null count")(stats.get(2, null))
+ assertResult(20, "Wrong row count")(stats.get(3, null))
+ assertResult(stats.get(4, null), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 9824dad239..39d798d072 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -390,10 +390,8 @@ private[hive] trait HiveInspectors {
(o: Any) => {
if (o != null) {
val struct = soi.create()
- val row = o.asInstanceOf[InternalRow]
- soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach {
- case ((field, wrapper), i) =>
- soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType)))
+ (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach {
+ (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
} else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index ade27454b9..a6a343d395 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -88,7 +88,6 @@ case class ScriptTransformation(
// external process. That process's output will be read by this current thread.
val writerThread = new ScriptTransformationWriterThread(
inputIterator,
- input.map(_.dataType),
outputProjection,
inputSerde,
inputSoi,
@@ -202,7 +201,6 @@ case class ScriptTransformation(
private class ScriptTransformationWriterThread(
iter: Iterator[InternalRow],
- inputSchema: Seq[DataType],
outputProjection: Projection,
@Nullable inputSerde: AbstractSerDe,
@Nullable inputSoi: ObjectInspector,
@@ -228,25 +226,12 @@ private class ScriptTransformationWriterThread(
// We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
// let's use a variable to record whether the `finally` block was hit due to an exception
var threwException: Boolean = true
- val len = inputSchema.length
try {
iter.map(outputProjection).foreach { row =>
if (inputSerde == null) {
- val data = if (len == 0) {
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
- } else {
- val sb = new StringBuilder
- sb.append(row.get(0, inputSchema(0)))
- var i = 1
- while (i < len) {
- sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
- sb.append(row.get(i, inputSchema(i)))
- i += 1
- }
- sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
- sb.toString()
- }
- outputStream.write(data.getBytes("utf-8"))
+ val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+ outputStream.write(data)
} else {
val writable = inputSerde.serialize(
row.asInstanceOf[GenericInternalRow].values, inputSoi)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 8dc796b056..684ea1d137 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
}
}
- val nonDynamicPartLen = row.numFields - dynamicPartColNames.length
- val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) =>
- val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType)
- val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal)
- val colString =
- if (string == null || string.isEmpty) {
- defaultPartName
- } else {
- FileUtils.escapePathName(string, defaultPartName)
- }
- s"/$colName=$colString"
- }.mkString
+ val dynamicPartPath = dynamicPartColNames
+ .zip(row.toSeq.takeRight(dynamicPartColNames.length))
+ .map { case (col, rawVal) =>
+ val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal)
+ val colString =
+ if (string == null || string.isEmpty) {
+ defaultPartName
+ } else {
+ FileUtils.escapePathName(string, defaultPartName)
+ }
+ s"/$col=$colString"
+ }.mkString
def newWriter(): FileSinkOperator.RecordWriter = {
val newFileSinkDesc = new FileSinkDesc(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 81a70b8d42..99e95fb921 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
}
}
- def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = {
- row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) =>
+ def checkValues(row1: Seq[Any], row2: InternalRow): Unit = {
+ row1.zip(row2.toSeq).foreach { case (r1, r2) =>
checkValue(r1, r2)
}
}
@@ -211,10 +211,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
case (t, idx) => StructField(s"c_$idx", t)
})
val inspector = toInspector(dt)
- checkValues(
- row,
- unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow],
- dt)
+ checkValues(row,
+ unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow])
checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
}