From 7c8fc1f7cb837ff5c32811fdeb3ee2b84de2dea4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 4 Aug 2015 17:05:19 -0700
Subject: [SPARK-9598][SQL] do not expose generic getter in internal row

Author: Wenchen Fan

Closes #7932 from cloud-fan/generic-getter and squashes the following commits:

c60de4c [Wenchen Fan] do not expose generic getter in internal row
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java | 5 --
 .../apache/spark/sql/catalyst/InternalRow.scala | 37 ++++++++++---
 .../expressions/GenericSpecializedGetters.scala | 61 ----------------------
 .../sql/catalyst/expressions/Projection.scala | 4 +-
 .../catalyst/expressions/SpecificMutableRow.scala | 2 +-
 .../sql/catalyst/expressions/aggregates.scala | 2 +-
 .../expressions/codegen/GenerateProjection.scala | 2 +-
 .../spark/sql/catalyst/expressions/rows.scala | 12 ++---
 .../apache/spark/sql/types/GenericArrayData.scala | 37 ++++++++-----
 .../execution/datasources/DataSourceStrategy.scala | 6 ++-
 .../spark/sql/columnar/ColumnStatsSuite.scala | 20 +++----
 11 files changed, 80 insertions(+), 108 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e6750fce4f..e3e1622de0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -253,11 +253,6 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
-  @Override
-  public Object genericGet(int ordinal) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public Object get(int ordinal, DataType dataType) {
     if (isNullAt(ordinal) || dataType instanceof NullType) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 7656d054dc..7d17cca808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -17,15 +17,15 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * An abstract class for row used internal in Spark SQL, which only contain the columns as
  * internal types.
  */
-// todo: make InternalRow just extends SpecializedGetters, remove generic getter
-abstract class InternalRow extends GenericSpecializedGetters with Serializable {
+abstract class InternalRow extends SpecializedGetters with Serializable {
 
   def numFields: Int
 
@@ -50,6 +50,31 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
     false
   }
 
+  // Subclasses of InternalRow should implement all special getters and equals/hashCode,
+  // or implement this genericGet.
+  protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
+    "Concrete internal rows should implement genericGet, " +
+      "or implement all special getters and equals/hashCode")
+
+  // default implementation (slow)
+  private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+  override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
+  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+  override def getByte(ordinal: Int): Byte = getAs(ordinal)
+  override def getShort(ordinal: Int): Short = getAs(ordinal)
+  override def getInt(ordinal: Int): Int = getAs(ordinal)
+  override def getLong(ordinal: Int): Long = getAs(ordinal)
+  override def getFloat(ordinal: Int): Float = getAs(ordinal)
+  override def getDouble(ordinal: Int): Double = getAs(ordinal)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+  override def getMap(ordinal: Int): MapData = getAs(ordinal)
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
   override def equals(o: Any): Boolean = {
     if (!o.isInstanceOf[InternalRow]) {
       return false
@@ -159,15 +184,15 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
 
 object InternalRow {
   /**
-   * This method can be used to construct a [[Row]] with the given values.
+   * This method can be used to construct a [[InternalRow]] with the given values.
    */
   def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
 
   /**
-   * This method can be used to construct a [[Row]] from a [[Seq]] of values.
+   * This method can be used to construct a [[InternalRow]] from a [[Seq]] of values.
    */
   def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
 
-  /** Returns an empty row. */
+  /** Returns an empty [[InternalRow]]. */
   val empty = apply()
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
deleted file mode 100644
index 6e957928e0..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
-trait GenericSpecializedGetters extends SpecializedGetters {
-
-  def genericGet(ordinal: Int): Any
-
-  private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
-
-  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
-
-  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
-  override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
-  override def getShort(ordinal: Int): Short = getAs(ordinal)
-
-  override def getInt(ordinal: Int): Int = getAs(ordinal)
-
-  override def getLong(ordinal: Int): Long = getAs(ordinal)
-
-  override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
-  override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
-  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
-
-  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
-  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
-  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
-  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
-  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
-
-  override def getMap(ordinal: Int): MapData = getAs(ordinal)
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 7964974102..4296b4b123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -207,8 +207,8 @@ class JoinedRow extends InternalRow {
 
   override def numFields: Int = row1.numFields + row2.numFields
 
-  override def genericGet(i: Int): Any =
-    if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields)
+  override def get(i: Int, dt: DataType): AnyRef =
+    if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt)
 
   override def isNullAt(i: Int): Boolean =
     if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index d149a5b179..b94df6bd66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -232,7 +232,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
     new GenericInternalRow(newValues)
   }
 
-  override def genericGet(i: Int): Any = values(i).boxed
+  override protected def genericGet(i: Int): Any = values(i).boxed
 
   override def update(ordinal: Int, value: Any) {
     if (value == null) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5d4b349b15..2cf8312ea5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -623,7 +623,7 @@ case class CombineSetsAndSumFunction(
       null
     } else {
       Cast(Literal(
-        casted.iterator.map(f => f.genericGet(0)).reduceLeft(
+        casted.iterator.map(f => f.get(0, null)).reduceLeft(
           base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
         base.dataType).eval(null)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 1572b2b99a..c04fe734d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -184,7 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       public void setNullAt(int i) { nullBits[i] = true; }
       public boolean isNullAt(int i) { return nullBits[i]; }
 
-      public Object genericGet(int i) {
+      protected Object genericGet(int i) {
         if (isNullAt(i)) return null;
         switch (i) {
           $getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index d04434b953..5e5de1d1dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -76,13 +76,13 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
  * Note that, while the array is not copied, and thus could technically be mutated after creation,
 * this is not allowed.
 */
-class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow {
+class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow {
   /** No-arg constructor for serialization. */
   protected def this() = this(null)
 
   def this(size: Int) = this(new Array[Any](size))
 
-  override def genericGet(ordinal: Int): Any = values(ordinal)
+  override protected def genericGet(ordinal: Int) = values(ordinal)
 
   override def toSeq: Seq[Any] = values
@@ -103,13 +103,13 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
   def fieldIndex(name: String): Int = schema.fieldIndex(name)
 }
 
-class GenericMutableRow(val values: Array[Any]) extends MutableRow {
+class GenericMutableRow(values: Array[Any]) extends MutableRow {
   /** No-arg constructor for serialization. */
   protected def this() = this(null)
 
   def this(size: Int) = this(new Array[Any](size))
 
-  override def genericGet(ordinal: Int): Any = values(ordinal)
+  override protected def genericGet(ordinal: Int) = values(ordinal)
 
   override def toSeq: Seq[Any] = values
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index b314acdfe3..459fcb6fc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,22 +17,33 @@
 
 package org.apache.spark.sql.types
 
-import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
-import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
-
-class GenericArrayData(private[sql] val array: Array[Any])
-  extends ArrayData with GenericSpecializedGetters {
-
-  override def genericGet(ordinal: Int): Any = array(ordinal)
+class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
 
   override def copy(): ArrayData = new GenericArrayData(array.clone())
 
-  // todo: Array is invariant in scala, maybe use toSeq instead?
-  override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T])
-
   override def numElements(): Int = array.length
 
+  private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T]
+  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+  override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
+  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+  override def getByte(ordinal: Int): Byte = getAs(ordinal)
+  override def getShort(ordinal: Int): Short = getAs(ordinal)
+  override def getInt(ordinal: Int): Int = getAs(ordinal)
+  override def getLong(ordinal: Int): Long = getAs(ordinal)
+  override def getFloat(ordinal: Int): Float = getAs(ordinal)
+  override def getDouble(ordinal: Int): Double = getAs(ordinal)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+  override def getMap(ordinal: Int): MapData = getAs(ordinal)
+
   override def toString(): String = array.mkString("[", ",", "]")
 
   override def equals(o: Any): Boolean = {
@@ -56,8 +67,8 @@ class GenericArrayData(private[sql] val array: Array[Any])
         return false
       }
       if (!isNullAt(i)) {
-        val o1 = genericGet(i)
-        val o2 = other.genericGet(i)
+        val o1 = array(i)
+        val o2 = other.array(i)
         o1 match {
           case b1: Array[Byte] =>
             if (!o2.isInstanceOf[Array[Byte]] ||
@@ -91,7 +102,7 @@ class GenericArrayData(private[sql] val array: Array[Any])
       if (isNullAt(i)) {
         0
       } else {
-        genericGet(i) match {
+        array(i) match {
           case b: Boolean => if (b) 0 else 1
           case b: Byte => b.toInt
           case s: Short => s.toInt
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6b91e51ca5..d9d7bc19bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -187,15 +187,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       // To see whether the `index`-th column is a partition column...
       val i = partitionColumns.indexOf(name)
       if (i != -1) {
+        val dt = schema(partitionColumns(i)).dataType
         // If yes, gets column value from partition values.
         (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-          mutableRow(ordinal) = partitionValues.genericGet(i)
+          mutableRow(ordinal) = partitionValues.get(i, dt)
         }
       } else {
         // Otherwise, inherits the value from scanned data.
         val i = nonPartitionColumns.indexOf(name)
+        val dt = schema(nonPartitionColumns(i)).dataType
         (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-          mutableRow(ordinal) = dataRow.genericGet(i)
+          mutableRow(ordinal) = dataRow.get(i, dt)
         }
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 66014ddca0..16e0187ed2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
-      assertResult(10, "Wrong null count")(stats.genericGet(2))
-      assertResult(20, "Wrong row count")(stats.genericGet(3))
-      assertResult(stats.genericGet(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+      assertResult(10, "Wrong null count")(stats.get(2, null))
+      assertResult(20, "Wrong row count")(stats.get(3, null))
+      assertResult(stats.get(4, null), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
@@ -96,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
-      assertResult(10, "Wrong null count")(stats.genericGet(2))
-      assertResult(20, "Wrong row count")(stats.genericGet(3))
-      assertResult(stats.genericGet(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+      assertResult(10, "Wrong null count")(stats.get(2, null))
+      assertResult(20, "Wrong row count")(stats.get(3, null))
+      assertResult(stats.get(4, null), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
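
The patch hinges on one pattern: InternalRow keeps a single protected genericGet as the slow, boxing fallback, every specialized getter defaults to it through getAs, and a concrete row either overrides genericGet (GenericInternalRow, SpecificMutableRow, the generated projections) or overrides every specialized getter and never reaches the fallback (UnsafeRow, which is why its genericGet stub could be deleted). Below is a minimal, dependency-free Scala sketch of that scheme, not Spark code; MiniRow and MiniGenericRow are illustrative names and the getter list is trimmed to a couple of types.

abstract class MiniRow {
  def numFields: Int

  // Slow, boxing fallback; rows with typed storage never need to implement it.
  protected def genericGet(ordinal: Int): Any =
    throw new IllegalStateException(
      "implement genericGet, or override all specialized getters")

  private def getAs[T](ordinal: Int): T = genericGet(ordinal).asInstanceOf[T]

  def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
  def getInt(ordinal: Int): Int = getAs[Int](ordinal)
  def getLong(ordinal: Int): Long = getAs[Long](ordinal)
}

// An array-backed row needs exactly one override to obtain every getter.
class MiniGenericRow(values: Array[Any]) extends MiniRow {
  override def numFields: Int = values.length
  override protected def genericGet(ordinal: Int): Any = values(ordinal)
}

object MiniRowDemo extends App {
  val row = new MiniGenericRow(Array(1, null, 42L))
  assert(row.getInt(0) == 1)      // unboxed through the generic fallback
  assert(row.isNullAt(1))
  assert(row.getLong(2) == 42L)
  println("ok")
}

Because genericGet is now protected, code outside a row class can no longer pull out a boxed value without declaring what type it expects, which is exactly the exposure this commit removes.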
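
The other half of the change is at the call sites: with genericGet off the public surface, callers go through get(ordinal, dataType). DataSourceStrategy now looks the DataType up from the schema once and threads it into the per-row closure, while the test and aggregate code passes null where the row is known to be generic and the hint is ignored. The type argument matters for rows with no boxed storage at all, such as UnsafeRow, where it selects the specialized accessor so boxing happens only on demand. A dependency-free sketch of that dispatch follows; MiniDataType and PrimitiveRow are stand-ins invented for illustration, not Spark types.

sealed trait MiniDataType
case object MiniIntType extends MiniDataType
case object MiniLongType extends MiniDataType

// A row holding primitives only: there is no boxed value to hand back.
final class PrimitiveRow(ints: Array[Int], longs: Array[Long]) {
  def getInt(ordinal: Int): Int = ints(ordinal)
  def getLong(ordinal: Int): Long = longs(ordinal)

  // The object-returning path needs the caller to name the type so it can
  // route to the right typed accessor and box only when asked.
  def get(ordinal: Int, dataType: MiniDataType): AnyRef = dataType match {
    case MiniIntType  => Int.box(getInt(ordinal))
    case MiniLongType => Long.box(getLong(ordinal))
  }
}

object DispatchDemo extends App {
  val row = new PrimitiveRow(Array(7), Array(8L))
  assert(row.getInt(0) == 7)                       // primitive path, no boxing
  assert(row.get(0, MiniLongType) == Long.box(8L)) // generic path, explicit type
  println("ok")
}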