author     Wenchen Fan <cloud0fan@outlook.com>    2015-08-06 13:11:59 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-08-06 13:11:59 -0700
commit     1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21 (patch)
tree       f04d74dffd581fa1eeb8e7a1f929f2aa843cf0a0 /sql/core
parent     a1bbf1bc5c51cd796015ac159799cf024de6fa07 (diff)
[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info
This re-applies #7955, which was reverted to fix a build break caused by a race with another patch.

Author: Wenchen Fan <cloud0fan@outlook.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #8002 from rxin/InternalRow-toSeq and squashes the following commits:

332416a [Reynold Xin] Merge pull request #7955 from cloud-fan/toSeq
21665e2 [Wenchen Fan] fix hive again...
4addf29 [Wenchen Fan] fix hive
bc16c59 [Wenchen Fan] minor fix
33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq
3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow
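The gist of the change, sketched with a hypothetical row and schema (only the toSeq(StructType) overload and the GenericInternalRow backing array that the hunks below rely on are assumed; the object and field names are made up for illustration):

package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._

// Placed alongside ColumnStats so the sketch can touch the same members the
// hunks below do.
object ToSeqSketch {
  def demo(): Unit = {
    // Hypothetical two-column schema and a generic row holding its values.
    val schema = StructType(Seq(
      StructField("lower", IntegerType),
      StructField("upper", IntegerType)))
    val row: InternalRow = new GenericInternalRow(Array[Any](1, 10))

    // After this patch, toSeq takes the field types, so row implementations
    // backed by specialized storage know how to extract each column.
    val asSeq: Seq[Any] = row.toSeq(schema)

    // When the concrete type is known to be GenericInternalRow (as
    // ColumnStats.collectedStatistics now guarantees), the backing array can
    // be read directly, which is what the .values calls below rely on.
    val raw: Array[Any] = row.asInstanceOf[GenericInternalRow].values

    assert(asSeq == raw.toSeq)
  }
}

Returning GenericInternalRow from ColumnStats means the stats consumers in InMemoryColumnarTableScan and the test suite no longer need a schema just to enumerate the stats values.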
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala               | 51
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala            |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala                 |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala          | 54
5 files changed, 65 insertions(+), 59 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index af1a8ecca9..5cbd52bc05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
* Column statistics represented as a single row, currently including closed lower bound, closed
* upper bound and null count.
*/
- def collectedStatistics: InternalRow
+ def collectedStatistics: GenericInternalRow
}
/**
@@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable {
private[sql] class NoopColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
- override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
}
private[sql] class BooleanColumnStats extends ColumnStats {
@@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class ByteColumnStats extends ColumnStats {
@@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class ShortColumnStats extends ColumnStats {
@@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class IntColumnStats extends ColumnStats {
@@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class LongColumnStats extends ColumnStats {
@@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class FloatColumnStats extends ColumnStats {
@@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class DoubleColumnStats extends ColumnStats {
@@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class StringColumnStats extends ColumnStats {
@@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class BinaryColumnStats extends ColumnStats {
@@ -230,8 +231,8 @@ private[sql] class BinaryColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
@@ -248,8 +249,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
@@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
private[sql] class DateColumnStats extends IntColumnStats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 5d5b0697d7..d553bb6169 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation(
}
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
- .flatMap(_.toSeq))
+ .flatMap(_.values))
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan(
if (inMemoryPartitionPruningEnabled) {
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
- def statsString: String = relation.partitionStatistics.schema
- .zip(cachedBatch.stats.toSeq)
- .map { case (a, s) => s"${a.name}: $s" }
- .mkString(", ")
+ def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map {
+ case (a, i) =>
+ val value = cachedBatch.stats.get(i, a.dataType)
+ s"${a.name}: $value"
+ }.mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
false
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c37007f1ee..dd3858ea2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -156,8 +156,8 @@ package object debug {
def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
case (null, _) =>
- case (row: InternalRow, StructType(fields)) =>
- row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
+ case (row: InternalRow, s: StructType) =>
+ row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (a: ArrayData, ArrayType(elemType, _)) =>
a.foreach(elemType, (_, e) => {
typeCheck(e, elemType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7126145ddc..c04557e5a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
val spec = discoverPartitions()
val partitionColumnTypes = spec.partitionColumns.map(_.dataType)
val castedPartitions = spec.partitions.map { case p @ Partition(values, path) =>
- val literals = values.toSeq.zip(partitionColumnTypes).map {
- case (value, dataType) => Literal.create(value, dataType)
+ val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) =>
+ Literal.create(values.get(i, dt), dt)
}
val castedValues = partitionSchema.zip(literals).map { case (field, literal) =>
Cast(literal, field.dataType).eval()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 16e0187ed2..d0430d2a60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,33 +19,36 @@ package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
class ColumnStatsSuite extends SparkFunSuite {
- testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0))
- testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0))
- testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0))
- testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
+ testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
- InternalRow(Long.MaxValue, Long.MinValue, 0))
- testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0))
+ createRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
- InternalRow(Double.MaxValue, Double.MinValue, 0))
- testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
- testDecimalColumnStats(InternalRow(null, null, 0))
+ createRow(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
+ testDecimalColumnStats(createRow(null, null, 0))
+
+ def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: InternalRow): Unit = {
+ initialStatistics: GenericInternalRow): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+ columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -61,11 +64,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
- assertResult(10, "Wrong null count")(stats.get(2, null))
- assertResult(20, "Wrong row count")(stats.get(3, null))
- assertResult(stats.get(4, null), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
+ assertResult(10, "Wrong null count")(stats.values(2))
+ assertResult(20, "Wrong row count")(stats.values(3))
+ assertResult(stats.values(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
@@ -73,14 +76,15 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
- def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) {
+ def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
+ initialStatistics: GenericInternalRow): Unit = {
val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
val columnType = FIXED_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new FixedDecimalColumnStats(15, 10)
- columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+ columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -96,11 +100,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
- assertResult(10, "Wrong null count")(stats.get(2, null))
- assertResult(20, "Wrong row count")(stats.get(3, null))
- assertResult(stats.get(4, null), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
+ assertResult(10, "Wrong null count")(stats.values(2))
+ assertResult(20, "Wrong row count")(stats.values(3))
+ assertResult(stats.values(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum