author     Wenchen Fan <cloud0fan@outlook.com>    2015-08-04 17:05:19 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-08-04 17:05:19 -0700
commit     7c8fc1f7cb837ff5c32811fdeb3ee2b84de2dea4 (patch)
tree       a6ced68aa4833cedf76b110b2e50a059d3ba82af /sql
parent     b77d3b9688d56d33737909375d1d0db07da5827b (diff)
[SPARK-9598][SQL] do not expose generic getter in internal row
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7932 from cloud-fan/generic-getter and squashes the following commits:

c60de4c [Wenchen Fan] do not expose generic getter in internal row
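
In short, this patch deletes the GenericSpecializedGetters trait and folds its getAs-based defaults into InternalRow itself, with genericGet demoted to a protected method whose default implementation throws. A concrete row can therefore expose only the typed SpecializedGetters API and still be implemented by overriding a single method. A hedged sketch of such a subclass (BoxedRow is a hypothetical name, not part of the patch; it assumes numFields, toSeq and copy are the only other abstract members):

    import org.apache.spark.sql.catalyst.InternalRow

    // Hypothetical example of the pattern this patch enables; not part of the commit.
    class BoxedRow(values: Array[Any]) extends InternalRow {
      override def numFields: Int = values.length

      // Implementing the protected generic getter is enough: InternalRow's new
      // getAs-based defaults derive isNullAt, getInt, getUTF8String, getStruct, ...
      override protected def genericGet(ordinal: Int): Any = values(ordinal)

      override def copy(): InternalRow = new BoxedRow(values.clone())
      override def toSeq: Seq[Any] = values
    }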
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala | 61
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala | 20
11 files changed, 80 insertions, 108 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e6750fce4f..e3e1622de0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -254,11 +254,6 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public Object genericGet(int ordinal) {
- throw new UnsupportedOperationException();
- }
-
- @Override
public Object get(int ordinal, DataType dataType) {
if (isNullAt(ordinal) || dataType instanceof NullType) {
return null;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 7656d054dc..7d17cca808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql.catalyst
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
-// todo: make InternalRow just extends SpecializedGetters, remove generic getter
-abstract class InternalRow extends GenericSpecializedGetters with Serializable {
+abstract class InternalRow extends SpecializedGetters with Serializable {
def numFields: Int
@@ -50,6 +50,31 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
false
}
+ // Subclasses of InternalRow should implement all special getters and equals/hashCode,
+ // or implement this genericGet.
+ protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
+ "Concrete internal rows should implement genericGet, " +
+ "or implement all special getters and equals/hashCode")
+
+ // default implementation (slow)
+ private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+ override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[InternalRow]) {
return false
@@ -159,15 +184,15 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
object InternalRow {
/**
- * This method can be used to construct a [[Row]] with the given values.
+ * This method can be used to construct a [[InternalRow]] with the given values.
*/
def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
/**
- * This method can be used to construct a [[Row]] from a [[Seq]] of values.
+ * This method can be used to construct a [[InternalRow]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
- /** Returns an empty row. */
+ /** Returns an empty [[InternalRow]]. */
val empty = apply()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
deleted file mode 100644
index 6e957928e0..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
-trait GenericSpecializedGetters extends SpecializedGetters {
-
- def genericGet(ordinal: Int): Any
-
- private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
-
- override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
-
- override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
-
- override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
- override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
- override def getShort(ordinal: Int): Short = getAs(ordinal)
-
- override def getInt(ordinal: Int): Int = getAs(ordinal)
-
- override def getLong(ordinal: Int): Long = getAs(ordinal)
-
- override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
- override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
- override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
-
- override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
- override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
- override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
- override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
-
- override def getMap(ordinal: Int): MapData = getAs(ordinal)
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 7964974102..4296b4b123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -207,8 +207,8 @@ class JoinedRow extends InternalRow {
override def numFields: Int = row1.numFields + row2.numFields
- override def genericGet(i: Int): Any =
- if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields)
+ override def get(i: Int, dt: DataType): AnyRef =
+ if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt)
override def isNullAt(i: Int): Boolean =
if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index d149a5b179..b94df6bd66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -232,7 +232,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericInternalRow(newValues)
}
- override def genericGet(i: Int): Any = values(i).boxed
+ override protected def genericGet(i: Int): Any = values(i).boxed
override def update(ordinal: Int, value: Any) {
if (value == null) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5d4b349b15..2cf8312ea5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -623,7 +623,7 @@ case class CombineSetsAndSumFunction(
null
} else {
Cast(Literal(
- casted.iterator.map(f => f.genericGet(0)).reduceLeft(
+ casted.iterator.map(f => f.get(0, null)).reduceLeft(
base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
base.dataType).eval(null)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 1572b2b99a..c04fe734d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -184,7 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
- public Object genericGet(int i) {
+ protected Object genericGet(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index d04434b953..5e5de1d1dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -76,13 +76,13 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
* Note that, while the array is not copied, and thus could technically be mutated after creation,
* this is not allowed.
*/
-class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow {
+class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
- override def genericGet(ordinal: Int): Any = values(ordinal)
+ override protected def genericGet(ordinal: Int) = values(ordinal)
override def toSeq: Seq[Any] = values
@@ -103,13 +103,13 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(val values: Array[Any]) extends MutableRow {
+class GenericMutableRow(values: Array[Any]) extends MutableRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
- override def genericGet(ordinal: Int): Any = values(ordinal)
+ override protected def genericGet(ordinal: Int) = values(ordinal)
override def toSeq: Seq[Any] = values
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index b314acdfe3..459fcb6fc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,22 +17,33 @@
package org.apache.spark.sql.types
-import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
-
-class GenericArrayData(private[sql] val array: Array[Any])
- extends ArrayData with GenericSpecializedGetters {
-
- override def genericGet(ordinal: Int): Any = array(ordinal)
+class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
override def copy(): ArrayData = new GenericArrayData(array.clone())
- // todo: Array is invariant in scala, maybe use toSeq instead?
- override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T])
-
override def numElements(): Int = array.length
+ private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T]
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+ override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+
override def toString(): String = array.mkString("[", ",", "]")
override def equals(o: Any): Boolean = {
@@ -56,8 +67,8 @@ class GenericArrayData(private[sql] val array: Array[Any])
return false
}
if (!isNullAt(i)) {
- val o1 = genericGet(i)
- val o2 = other.genericGet(i)
+ val o1 = array(i)
+ val o2 = other.array(i)
o1 match {
case b1: Array[Byte] =>
if (!o2.isInstanceOf[Array[Byte]] ||
@@ -91,7 +102,7 @@ class GenericArrayData(private[sql] val array: Array[Any])
if (isNullAt(i)) {
0
} else {
- genericGet(i) match {
+ array(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
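
GenericArrayData loses the shared trait as well and implements the same getAs pattern directly over its backing array. A minimal usage sketch based on the constructor and getters shown above:

    import org.apache.spark.sql.types.GenericArrayData

    val arr = new GenericArrayData(Array[Any](1, null, 3))
    arr.numElements()   // 3
    arr.getInt(0)       // 1 (unboxed via the getAs cast)
    arr.isNullAt(1)     // true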
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6b91e51ca5..d9d7bc19bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -187,15 +187,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// To see whether the `index`-th column is a partition column...
val i = partitionColumns.indexOf(name)
if (i != -1) {
+ val dt = schema(partitionColumns(i)).dataType
// If yes, gets column value from partition values.
(mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
- mutableRow(ordinal) = partitionValues.genericGet(i)
+ mutableRow(ordinal) = partitionValues.get(i, dt)
}
} else {
// Otherwise, inherits the value from scanned data.
val i = nonPartitionColumns.indexOf(name)
+ val dt = schema(nonPartitionColumns(i)).dataType
(mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
- mutableRow(ordinal) = dataRow.genericGet(i)
+ mutableRow(ordinal) = dataRow.get(i, dt)
}
}
}
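
On the caller side, DataSourceStrategy now resolves the column's DataType from the schema once and passes it to get(ordinal, dataType) inside the per-row closure, since the public genericGet is gone (the test hunk below simply passes null where the type is irrelevant). A minimal sketch of the access pattern, assuming the same imports as the diff:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.DataType

    // Generic, untyped read of one field; prefer the specialized getters when
    // the static type is known.
    def readField(row: InternalRow, ordinal: Int, dataType: DataType): Any =
      if (row.isNullAt(ordinal)) null else row.get(ordinal, dataType)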
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 66014ddca0..16e0187ed2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
- assertResult(10, "Wrong null count")(stats.genericGet(2))
- assertResult(20, "Wrong row count")(stats.genericGet(3))
- assertResult(stats.genericGet(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+ assertResult(10, "Wrong null count")(stats.get(2, null))
+ assertResult(20, "Wrong row count")(stats.get(3, null))
+ assertResult(stats.get(4, null), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
@@ -96,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
- assertResult(10, "Wrong null count")(stats.genericGet(2))
- assertResult(20, "Wrong row count")(stats.genericGet(3))
- assertResult(stats.genericGet(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+ assertResult(10, "Wrong null count")(stats.get(2, null))
+ assertResult(20, "Wrong row count")(stats.get(3, null))
+ assertResult(stats.get(4, null), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum