about | summary | refs | log | tree | commit | diff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-06-28 08:03:58 -0700
committerDavies Liu <davies@databricks.com>2015-06-28 08:03:58 -0700
commit77da5be6f11a7e9cb1d44f7fb97b93481505afe8 (patch)
tree95badc0ee5149fb7ce126a7a3590877415e10981 /sql/catalyst/src
parent52d128180166280af443fae84ac61386f3d6c500 (diff)
downloadspark-77da5be6f11a7e9cb1d44f7fb97b93481505afe8.tar.gz
spark-77da5be6f11a7e9cb1d44f7fb97b93481505afe8.tar.bz2
spark-77da5be6f11a7e9cb1d44f7fb97b93481505afe8.zip
[SPARK-8610] [SQL] Separate Row and InternalRow (part 2)
Currently, we use GenericRow both for Row and InternalRow, which is confusing because it could contain Scala types as well as Catalyst types. This PR changes to use GenericInternalRow for InternalRow (which contains Catalyst types) and GenericRow for Row (which contains Scala types). Also fixes some incorrect uses of InternalRow or Row. Author: Davies Liu <davies@databricks.com> Closes #7003 from davies/internalrow and squashes the following commits: d05866c [Davies Liu] fix test: rollback changes for pyspark 72878dd [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow efd0b25 [Davies Liu] fix copy of MutableRow 87b13cf [Davies Liu] fix test d2ebd72 [Davies Liu] fix style eb4b473 [Davies Liu] mark expensive API as final bd4e99c [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow bdfb78f [Davies Liu] remove BaseMutableRow 6f99a97 [Davies Liu] fix catalyst test defe931 [Davies Liu] remove BaseRow 288b31f [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow 9d24350 [Davies Liu] separate Row and InternalRow (part 2)
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java68
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java197
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala40
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala50
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala149
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala6
14 files changed, 166 insertions, 444 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
deleted file mode 100644
index acec2bf452..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql;
-
-import org.apache.spark.sql.catalyst.expressions.MutableRow;
-
-public abstract class BaseMutableRow extends BaseRow implements MutableRow {
-
- @Override
- public void update(int ordinal, Object value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setInt(int ordinal, int value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setLong(int ordinal, long value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setDouble(int ordinal, double value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setBoolean(int ordinal, boolean value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setShort(int ordinal, short value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setByte(int ordinal, byte value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setFloat(int ordinal, float value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setString(int ordinal, String value) {
- throw new UnsupportedOperationException();
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
deleted file mode 100644
index 6a2356f1f9..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql;
-
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.sql.Timestamp;
-import java.util.List;
-
-import scala.collection.Seq;
-import scala.collection.mutable.ArraySeq;
-
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.StructType;
-
-public abstract class BaseRow extends InternalRow {
-
- @Override
- final public int length() {
- return size();
- }
-
- @Override
- public boolean anyNull() {
- final int n = size();
- for (int i=0; i < n; i++) {
- if (isNullAt(i)) {
- return true;
- }
- }
- return false;
- }
-
- @Override
- public StructType schema() { throw new UnsupportedOperationException(); }
-
- @Override
- final public Object apply(int i) {
- return get(i);
- }
-
- @Override
- public int getInt(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public long getLong(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public float getFloat(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public double getDouble(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public byte getByte(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public short getShort(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public boolean getBoolean(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public String getString(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public BigDecimal getDecimal(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Date getDate(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Timestamp getTimestamp(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> Seq<T> getSeq(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> List<T> getList(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> scala.collection.Map<K, V> getMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> java.util.Map<K, V> getJavaMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Row getStruct(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(String fieldName) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int fieldIndex(String name) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public InternalRow copy() {
- final int n = size();
- Object[] arr = new Object[n];
- for (int i = 0; i < n; i++) {
- arr[i] = get(i);
- }
- return new GenericRow(arr);
- }
-
- @Override
- public Seq<Object> toSeq() {
- final int n = size();
- final ArraySeq<Object> values = new ArraySeq<Object>(n);
- for (int i = 0; i < n; i++) {
- values.update(i, get(i));
- }
- return values;
- }
-
- @Override
- public String toString() {
- return mkString("[", ",", "]");
- }
-
- @Override
- public String mkString() {
- return toSeq().mkString();
- }
-
- @Override
- public String mkString(String sep) {
- return toSeq().mkString(sep);
- }
-
- @Override
- public String mkString(String start, String sep, String end) {
- return toSeq().mkString(start, sep, end);
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb2f2079b4..11d51d90f1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -23,16 +23,12 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
-import scala.collection.Seq;
-import scala.collection.mutable.ArraySeq;
-
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
@@ -52,7 +48,7 @@ import static org.apache.spark.sql.types.DataTypes.*;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow extends BaseMutableRow {
+public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
@@ -63,6 +59,8 @@ public final class UnsafeRow extends BaseMutableRow {
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
+ public int length() { return numFields; }
+
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
/**
@@ -344,13 +342,4 @@ public final class UnsafeRow extends BaseMutableRow {
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
}
-
- @Override
- public Seq<Object> toSeq() {
- final ArraySeq<Object> values = new ArraySeq<Object>(numFields);
- for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) {
- values.update(fieldNumber, get(fieldNumber));
- }
- return values;
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index e99d5c87a4..0f2fd6a86d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -179,7 +179,7 @@ trait Row extends Serializable {
def get(i: Int): Any = apply(i)
/** Checks whether the value at position i is null. */
- def isNullAt(i: Int): Boolean
+ def isNullAt(i: Int): Boolean = apply(i) == null
/**
* Returns the value at position i as a primitive boolean.
@@ -187,7 +187,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getBoolean(i: Int): Boolean
+ def getBoolean(i: Int): Boolean = getAs[Boolean](i)
/**
* Returns the value at position i as a primitive byte.
@@ -195,7 +195,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getByte(i: Int): Byte
+ def getByte(i: Int): Byte = getAs[Byte](i)
/**
* Returns the value at position i as a primitive short.
@@ -203,7 +203,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getShort(i: Int): Short
+ def getShort(i: Int): Short = getAs[Short](i)
/**
* Returns the value at position i as a primitive int.
@@ -211,7 +211,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getInt(i: Int): Int
+ def getInt(i: Int): Int = getAs[Int](i)
/**
* Returns the value at position i as a primitive long.
@@ -219,7 +219,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getLong(i: Int): Long
+ def getLong(i: Int): Long = getAs[Long](i)
/**
* Returns the value at position i as a primitive float.
@@ -228,7 +228,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getFloat(i: Int): Float
+ def getFloat(i: Int): Float = getAs[Float](i)
/**
* Returns the value at position i as a primitive double.
@@ -236,7 +236,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getDouble(i: Int): Double
+ def getDouble(i: Int): Double = getAs[Double](i)
/**
* Returns the value at position i as a String object.
@@ -244,35 +244,35 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getString(i: Int): String
+ def getString(i: Int): String = getAs[String](i)
/**
* Returns the value at position i of decimal type as java.math.BigDecimal.
*
* @throws ClassCastException when data type does not match.
*/
- def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal]
+ def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
/**
* Returns the value at position i of date type as java.sql.Date.
*
* @throws ClassCastException when data type does not match.
*/
- def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
+ def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
/**
* Returns the value at position i of date type as java.sql.Timestamp.
*
* @throws ClassCastException when data type does not match.
*/
- def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp]
+ def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
/**
* Returns the value at position i of array type as a Scala Seq.
*
* @throws ClassCastException when data type does not match.
*/
- def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]]
+ def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i)
/**
* Returns the value at position i of array type as [[java.util.List]].
@@ -288,7 +288,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]]
+ def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i)
/**
* Returns the value at position i of array type as a [[java.util.Map]].
@@ -366,9 +366,18 @@ trait Row extends Serializable {
/* ---------------------- utility methods for Scala ---------------------- */
/**
- * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq.
+ * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
- def toSeq: Seq[Any]
+ def toSeq: Seq[Any] = {
+ val n = length
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ values.update(i, get(i))
+ i += 1
+ }
+ values.toSeq
+ }
/** Displays all elements of this sequence in a string (without a separator). */
def mkString: String = toSeq.mkString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 012f8bbecb..8f63d2120a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -242,7 +242,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(row(idx))
idx += 1
}
- new GenericRowWithSchema(ar, structType)
+ new GenericInternalRow(ar)
case p: Product =>
val ar = new Array[Any](structType.size)
@@ -252,7 +252,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(iter.next())
idx += 1
}
- new GenericRowWithSchema(ar, structType)
+ new GenericInternalRow(ar)
}
override def toScala(row: InternalRow): Row = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index d7b537a9fe..61a29c89d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,14 +19,38 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.unsafe.types.UTF8String
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
abstract class InternalRow extends Row {
+
+ // This is only use for test
+ override def getString(i: Int): String = getAs[UTF8String](i).toString
+
+ // These expensive API should not be used internally.
+ final override def getDecimal(i: Int): java.math.BigDecimal =
+ throw new UnsupportedOperationException
+ final override def getDate(i: Int): java.sql.Date =
+ throw new UnsupportedOperationException
+ final override def getTimestamp(i: Int): java.sql.Timestamp =
+ throw new UnsupportedOperationException
+ final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException
+ final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException
+ final override def getMap[K, V](i: Int): scala.collection.Map[K, V] =
+ throw new UnsupportedOperationException
+ final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] =
+ throw new UnsupportedOperationException
+ final override def getStruct(i: Int): Row = throw new UnsupportedOperationException
+ final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException
+ final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] =
+ throw new UnsupportedOperationException
+
// A default implementation to change the return type
override def copy(): InternalRow = this
+ override def apply(i: Int): Any = get(i)
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[Row]) {
@@ -93,27 +117,15 @@ abstract class InternalRow extends Row {
}
object InternalRow {
- def unapplySeq(row: InternalRow): Some[Seq[Any]] = Some(row.toSeq)
-
/**
* This method can be used to construct a [[Row]] with the given values.
*/
- def apply(values: Any*): InternalRow = new GenericRow(values.toArray)
+ def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
/**
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
- def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray)
-
- def fromTuple(tuple: Product): InternalRow = fromSeq(tuple.productIterator.toSeq)
-
- /**
- * Merge multiple rows into a single row, one after another.
- */
- def merge(rows: InternalRow*): InternalRow = {
- // TODO: Improve the performance of this if used in performance critical part.
- new GenericRow(rows.flatMap(_.toSeq).toArray)
- }
+ def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
/** Returns an empty row. */
val empty = apply()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index d5967438cc..fcfe83ceb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -36,7 +36,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
outputArray(i) = exprArray(i).eval(input)
i += 1
}
- new GenericRow(outputArray)
+ new GenericInternalRow(outputArray)
}
override def toString: String = s"Row => [${exprArray.mkString(",")}]"
@@ -135,12 +135,6 @@ class JoinedRow extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -149,7 +143,7 @@ class JoinedRow extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -235,12 +229,6 @@ class JoinedRow2 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -249,7 +237,7 @@ class JoinedRow2 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -329,12 +317,6 @@ class JoinedRow3 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -343,7 +325,7 @@ class JoinedRow3 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -423,12 +405,6 @@ class JoinedRow4 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -437,7 +413,7 @@ class JoinedRow4 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -517,12 +493,6 @@ class JoinedRow5 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -531,7 +501,7 @@ class JoinedRow5 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -611,12 +581,6 @@ class JoinedRow6 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -625,7 +589,7 @@ class JoinedRow6 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 05aab34559..53fedb531c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
i += 1
}
- new GenericRow(newValues)
+ new GenericInternalRow(newValues)
}
override def update(ordinal: Int, value: Any) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e75e82d380..64ef357a4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
// MutableProjection is not accessible in Java
-abstract class BaseMutableProjection extends MutableProjection {}
+abstract class BaseMutableProjection extends MutableProjection
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 624e1cf4e2..39d32b78cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import org.apache.spark.sql.BaseMutableRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -149,6 +148,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
"""
}.mkString("\n")
+ val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
+ s"""arr[$i] = c$i;"""
+ }.mkString("\n ")
+
val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
@@ -167,7 +170,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
}
- final class SpecificRow extends ${typeOf[BaseMutableRow]} {
+ final class SpecificRow extends ${typeOf[MutableRow]} {
$columns
@@ -175,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
$initColumns
}
- public int size() { return ${expressions.length};}
+ public int length() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
@@ -216,6 +219,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
return super.equals(other);
}
+
+ @Override
+ public InternalRow copy() {
+ Object[] arr = new Object[${expressions.length}];
+ ${copyColumns}
+ return new ${typeOf[GenericInternalRow]}(arr);
+ }
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 356560e54c..7a42a1d310 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@@ -68,19 +69,19 @@ abstract class Generator extends Expression {
*/
case class UserDefinedGenerator(
elementTypes: Seq[(DataType, Boolean)],
- function: InternalRow => TraversableOnce[InternalRow],
+ function: Row => TraversableOnce[InternalRow],
children: Seq[Expression])
extends Generator {
@transient private[this] var inputRow: InterpretedProjection = _
- @transient private[this] var convertToScala: (InternalRow) => InternalRow = _
+ @transient private[this] var convertToScala: (InternalRow) => Row = _
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
CatalystTypeConverters.createToScalaConverter(inputSchema)
- }.asInstanceOf[(InternalRow => InternalRow)]
+ }.asInstanceOf[InternalRow => Row]
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -118,10 +119,11 @@ case class Explode(child: Expression)
child.dataType match {
case ArrayType(_, _) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
- if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
+ if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
- if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) }
+ if (inputMap == null) Nil
+ else inputMap.map { case (k, v) => InternalRow(k, v) }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0d4c9ace5e..dd5f2ed2d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
@@ -24,19 +25,32 @@ import org.apache.spark.unsafe.types.UTF8String
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
* Setting a value through a primitive function implicitly marks that column as not null.
*/
-trait MutableRow extends InternalRow {
+abstract class MutableRow extends InternalRow {
def setNullAt(i: Int): Unit
- def update(ordinal: Int, value: Any)
+ def update(i: Int, value: Any)
+
+ // default implementation (slow)
+ def setInt(i: Int, value: Int): Unit = { update(i, value) }
+ def setLong(i: Int, value: Long): Unit = { update(i, value) }
+ def setDouble(i: Int, value: Double): Unit = { update(i, value) }
+ def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
+ def setShort(i: Int, value: Short): Unit = { update(i, value) }
+ def setByte(i: Int, value: Byte): Unit = { update(i, value) }
+ def setFloat(i: Int, value: Float): Unit = { update(i, value) }
+ def setString(i: Int, value: String): Unit = {
+ update(i, UTF8String.fromString(value))
+ }
- def setInt(ordinal: Int, value: Int)
- def setLong(ordinal: Int, value: Long)
- def setDouble(ordinal: Int, value: Double)
- def setBoolean(ordinal: Int, value: Boolean)
- def setShort(ordinal: Int, value: Short)
- def setByte(ordinal: Int, value: Byte)
- def setFloat(ordinal: Int, value: Float)
- def setString(ordinal: Int, value: String)
+ override def copy(): InternalRow = {
+ val arr = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ arr(i) = get(i)
+ i += 1
+ }
+ new GenericInternalRow(arr)
+ }
}
/**
@@ -60,68 +74,57 @@ object EmptyRow extends InternalRow {
}
/**
- * A row implementation that uses an array of objects as the underlying storage. Note that, while
- * the array is not copied, and thus could technically be mutated after creation, this is not
- * allowed.
+ * A row implementation that uses an array of objects as the underlying storage.
*/
-class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
- /** No-arg constructor for serialization. */
- protected def this() = this(null)
+trait ArrayBackedRow {
+ self: Row =>
- def this(size: Int) = this(new Array[Any](size))
+ protected val values: Array[Any]
override def toSeq: Seq[Any] = values.toSeq
- override def length: Int = values.length
+ def length: Int = values.length
override def apply(i: Int): Any = values(i)
- override def isNullAt(i: Int): Boolean = values(i) == null
-
- override def getInt(i: Int): Int = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
- values(i).asInstanceOf[Int]
- }
-
- override def getLong(i: Int): Long = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
- values(i).asInstanceOf[Long]
- }
-
- override def getDouble(i: Int): Double = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
- values(i).asInstanceOf[Double]
- }
-
- override def getFloat(i: Int): Float = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
- values(i).asInstanceOf[Float]
- }
+ def setNullAt(i: Int): Unit = { values(i) = null}
- override def getBoolean(i: Int): Boolean = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
- values(i).asInstanceOf[Boolean]
- }
+ def update(i: Int, value: Any): Unit = { values(i) = value }
+}
- override def getShort(i: Int): Short = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
- values(i).asInstanceOf[Short]
- }
+/**
+ * A row implementation that uses an array of objects as the underlying storage. Note that, while
+ * the array is not copied, and thus could technically be mutated after creation, this is not
+ * allowed.
+ */
+class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow {
+ /** No-arg constructor for serialization. */
+ protected def this() = this(null)
- override def getByte(i: Int): Byte = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
- values(i).asInstanceOf[Byte]
- }
+ def this(size: Int) = this(new Array[Any](size))
- override def getString(i: Int): String = {
- values(i) match {
- case null => null
- case s: String => s
- case utf8: UTF8String => utf8.toString
- }
+ // This equals() is used by tests and by external callers comparing Rows.
+ override def equals(o: Any): Boolean = o match {
+ case other: Row if other.length == length =>
+ var i = 0
+ while (i < length) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ val equal = (apply(i), other.apply(i)) match {
+ case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b)
+ case (a, b) => a == b
+ }
+ if (!equal) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
}
- override def copy(): InternalRow = this
+ override def copy(): Row = this
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
@@ -133,32 +136,30 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
+/**
+ * An internal row implementation that uses an array of objects as the underlying storage.
+ * Note that, while the array is not copied, and thus could technically be mutated after creation,
+ * this is not allowed.
+ */
+class GenericInternalRow(protected[sql] val values: Array[Any])
+ extends InternalRow with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
- override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
- override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
- override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
- override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
- override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
- override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = {
- values(ordinal) = UTF8String.fromString(value)
- }
-
- override def setNullAt(i: Int): Unit = { values(i) = null }
+ override def copy(): InternalRow = this
+}
- override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
+class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
+ /** No-arg constructor for serialization. */
+ protected def this() = this(null)
- override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
+ def this(size: Int) = this(new Array[Any](size))
- override def copy(): InternalRow = new GenericRow(values.clone())
+ override def copy(): InternalRow = new GenericInternalRow(values.clone())
}
-
class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 158f54af13..7d95ef7f71 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -33,7 +33,7 @@ trait ExpressionEvalHelper {
self: SparkFunSuite =>
protected def create_row(values: Any*): InternalRow = {
- new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
protected def checkEvaluation(
@@ -122,7 +122,7 @@ trait ExpressionEvalHelper {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = InternalRow(expected)
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 7aae2bbd8a..3095ccb777 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -37,7 +37,7 @@ class UnsafeFixedWidthAggregationMapSuite
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
- private def emptyAggregationBuffer: InternalRow = new GenericRow(Array[Any](0))
+ private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private var memoryManager: TaskMemoryManager = null
@@ -84,7 +84,7 @@ class UnsafeFixedWidthAggregationMapSuite
1024, // initial capacity
false // disable perf metrics
)
- val groupKey = new GenericRow(Array[Any](UTF8String.fromString("cats")))
+ val groupKey = InternalRow(UTF8String.fromString("cats"))
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
map.getAggregationBuffer(groupKey)
@@ -113,7 +113,7 @@ class UnsafeFixedWidthAggregationMapSuite
val rand = new Random(42)
val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
groupKeys.foreach { keyString =>
- map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String.fromString(keyString))))
+ map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
}
val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
entry.key.getString(0)