Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java   68
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java  197
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java   19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala   41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala    4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala   40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala   50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala    2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala    2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala   16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala   12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala  149
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala    4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala    6
14 files changed, 166 insertions, 444 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
deleted file mode 100644
index acec2bf452..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql;
-
-import org.apache.spark.sql.catalyst.expressions.MutableRow;
-
-public abstract class BaseMutableRow extends BaseRow implements MutableRow {
-
- @Override
- public void update(int ordinal, Object value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setInt(int ordinal, int value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setLong(int ordinal, long value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setDouble(int ordinal, double value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setBoolean(int ordinal, boolean value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setShort(int ordinal, short value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setByte(int ordinal, byte value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setFloat(int ordinal, float value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void setString(int ordinal, String value) {
- throw new UnsupportedOperationException();
- }
-}
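
BaseMutableRow existed only so that Java row classes could inherit typed setters that throw UnsupportedOperationException. After this patch that role passes to the MutableRow abstract class in rows.scala (further down in this diff), whose typed setters instead delegate to update(). A minimal sketch of the new contract, assuming no abstract members beyond those visible in this diff (the class here is hypothetical):

    // Hypothetical one-column row: only the core mutators are implemented;
    // setInt, setLong, setString, etc. come from MutableRow's default setters.
    class SingleValueRow extends MutableRow {
      private var value: Any = null
      override def length: Int = 1
      override def apply(i: Int): Any = value
      override def setNullAt(i: Int): Unit = { value = null }
      override def update(i: Int, v: Any): Unit = { value = v }
    }

    val row = new SingleValueRow
    row.setInt(0, 42)       // routed through update(0, 42)
    row.setString(0, "a")   // stored as a UTF8String by the default setString
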
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
deleted file mode 100644
index 6a2356f1f9..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql;
-
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.sql.Timestamp;
-import java.util.List;
-
-import scala.collection.Seq;
-import scala.collection.mutable.ArraySeq;
-
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.StructType;
-
-public abstract class BaseRow extends InternalRow {
-
- @Override
- final public int length() {
- return size();
- }
-
- @Override
- public boolean anyNull() {
- final int n = size();
- for (int i=0; i < n; i++) {
- if (isNullAt(i)) {
- return true;
- }
- }
- return false;
- }
-
- @Override
- public StructType schema() { throw new UnsupportedOperationException(); }
-
- @Override
- final public Object apply(int i) {
- return get(i);
- }
-
- @Override
- public int getInt(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public long getLong(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public float getFloat(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public double getDouble(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public byte getByte(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public short getShort(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public boolean getBoolean(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public String getString(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public BigDecimal getDecimal(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Date getDate(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Timestamp getTimestamp(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> Seq<T> getSeq(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> List<T> getList(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> scala.collection.Map<K, V> getMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> java.util.Map<K, V> getJavaMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Row getStruct(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(String fieldName) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int fieldIndex(String name) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public InternalRow copy() {
- final int n = size();
- Object[] arr = new Object[n];
- for (int i = 0; i < n; i++) {
- arr[i] = get(i);
- }
- return new GenericRow(arr);
- }
-
- @Override
- public Seq<Object> toSeq() {
- final int n = size();
- final ArraySeq<Object> values = new ArraySeq<Object>(n);
- for (int i = 0; i < n; i++) {
- values.update(i, get(i));
- }
- return values;
- }
-
- @Override
- public String toString() {
- return mkString("[", ",", "]");
- }
-
- @Override
- public String mkString() {
- return toSeq().mkString();
- }
-
- @Override
- public String mkString(String sep) {
- return toSeq().mkString(sep);
- }
-
- @Override
- public String mkString(String start, String sep, String end) {
- return toSeq().mkString(start, sep, end);
- }
-}
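
BaseRow duplicated, in Java, the defaults that the Scala Row trait could not previously hand to Java subclasses: anyNull, copy, toSeq, and the mkString family. With its Java consumers gone (UnsafeRow, next in this diff, now extends MutableRow directly), those loops can live in one place on the Scala side. The shape of the shared anyNull default, as a sketch rather than a quote from the patch:

    def anyNull: Boolean = {
      val n = length
      var i = 0
      while (i < n) {
        if (isNullAt(i)) return true   // first null short-circuits the scan
        i += 1
      }
      false
    }
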
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb2f2079b4..11d51d90f1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -23,16 +23,12 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
-import scala.collection.Seq;
-import scala.collection.mutable.ArraySeq;
-
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
@@ -52,7 +48,7 @@ import static org.apache.spark.sql.types.DataTypes.*;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow extends BaseMutableRow {
+public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
@@ -63,6 +59,8 @@ public final class UnsafeRow extends BaseMutableRow {
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
+ public int length() { return numFields; }
+
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
/**
@@ -344,13 +342,4 @@ public final class UnsafeRow extends BaseMutableRow {
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
}
-
- @Override
- public Seq<Object> toSeq() {
- final ArraySeq<Object> values = new ArraySeq<Object>(numFields);
- for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) {
- values.update(fieldNumber, get(fieldNumber));
- }
- return values;
- }
}
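
UnsafeRow keeps its own anyNull because it can test the null bitset in bulk instead of probing fields one at a time. The arithmetic that bounds that scan, sketched from the layout described in the class comment (the helper name here is hypothetical):

    // One 64-bit word of null bits per 64 fields; anyNull only has to scan
    // bitSetWidthInBytes bytes starting at baseOffset.
    def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

    assert(bitSetWidthInBytes(1) == 8)     // 1..64 fields: one word
    assert(bitSetWidthInBytes(65) == 16)   // 65..128 fields: two words
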
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index e99d5c87a4..0f2fd6a86d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -179,7 +179,7 @@ trait Row extends Serializable {
def get(i: Int): Any = apply(i)
/** Checks whether the value at position i is null. */
- def isNullAt(i: Int): Boolean
+ def isNullAt(i: Int): Boolean = apply(i) == null
/**
* Returns the value at position i as a primitive boolean.
@@ -187,7 +187,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getBoolean(i: Int): Boolean
+ def getBoolean(i: Int): Boolean = getAs[Boolean](i)
/**
* Returns the value at position i as a primitive byte.
@@ -195,7 +195,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getByte(i: Int): Byte
+ def getByte(i: Int): Byte = getAs[Byte](i)
/**
* Returns the value at position i as a primitive short.
@@ -203,7 +203,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getShort(i: Int): Short
+ def getShort(i: Int): Short = getAs[Short](i)
/**
* Returns the value at position i as a primitive int.
@@ -211,7 +211,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getInt(i: Int): Int
+ def getInt(i: Int): Int = getAs[Int](i)
/**
* Returns the value at position i as a primitive long.
@@ -219,7 +219,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getLong(i: Int): Long
+ def getLong(i: Int): Long = getAs[Long](i)
/**
* Returns the value at position i as a primitive float.
@@ -228,7 +228,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getFloat(i: Int): Float
+ def getFloat(i: Int): Float = getAs[Float](i)
/**
* Returns the value at position i as a primitive double.
@@ -236,7 +236,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getDouble(i: Int): Double
+ def getDouble(i: Int): Double = getAs[Double](i)
/**
* Returns the value at position i as a String object.
@@ -244,35 +244,35 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
- def getString(i: Int): String
+ def getString(i: Int): String = getAs[String](i)
/**
* Returns the value at position i of decimal type as java.math.BigDecimal.
*
* @throws ClassCastException when data type does not match.
*/
- def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal]
+ def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
/**
* Returns the value at position i of date type as java.sql.Date.
*
* @throws ClassCastException when data type does not match.
*/
- def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
+ def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
/**
* Returns the value at position i of date type as java.sql.Timestamp.
*
* @throws ClassCastException when data type does not match.
*/
- def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp]
+ def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
/**
* Returns the value at position i of array type as a Scala Seq.
*
* @throws ClassCastException when data type does not match.
*/
- def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]]
+ def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i)
/**
* Returns the value at position i of array type as [[java.util.List]].
@@ -288,7 +288,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]]
+ def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i)
/**
* Returns the value at position i of array type as a [[java.util.Map]].
@@ -366,9 +366,18 @@ trait Row extends Serializable {
/* ---------------------- utility methods for Scala ---------------------- */
/**
- * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq.
+ * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
- def toSeq: Seq[Any]
+ def toSeq: Seq[Any] = {
+ val n = length
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ values.update(i, get(i))
+ i += 1
+ }
+ values.toSeq
+ }
/** Displays all elements of this sequence in a string (without a separator). */
def mkString: String = toSeq.mkString
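
With these defaults in place, a Row implementation only has to supply length, apply, and copy; isNullAt, the typed getters, toSeq, and the mkString family all derive from them. A minimal sketch (this class is hypothetical, not part of the patch):

    class ArrayRow(values: Array[Any]) extends Row {
      override def length: Int = values.length
      override def apply(i: Int): Any = values(i)
      override def copy(): Row = this   // immutable, so sharing is safe
    }

    val r = new ArrayRow(Array(1, "a", null))
    r.getInt(0)       // 1, via the getAs[Int] default
    r.isNullAt(2)     // true, via the apply(i) == null default
    r.mkString(",")   // "1,a,null", via the new toSeq default
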
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 012f8bbecb..8f63d2120a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -242,7 +242,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(row(idx))
idx += 1
}
- new GenericRowWithSchema(ar, structType)
+ new GenericInternalRow(ar)
case p: Product =>
val ar = new Array[Any](structType.size)
@@ -252,7 +252,7 @@ object CatalystTypeConverters {
ar(idx) = converters(idx).toCatalyst(iter.next())
idx += 1
}
- new GenericRowWithSchema(ar, structType)
+ new GenericInternalRow(ar)
}
override def toScala(row: InternalRow): Row = {
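
The Catalyst-side converter now produces a schema-less GenericInternalRow; a schema is attached only when converting back to an external Row. A sketch of the round trip, assuming the createToCatalystConverter/createToScalaConverter entry points (the latter is visible in the generators.scala hunk below):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types._

    val schema = StructType(StructField("name", StringType) :: Nil)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)

    val internal = toCatalyst(Row("cats"))   // GenericInternalRow; the string becomes a UTF8String
    val external = toScala(internal)         // back to an external Row carrying the schema
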
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index d7b537a9fe..61a29c89d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,14 +19,38 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.unsafe.types.UTF8String
/**
* An abstract class for a row used internally in Spark SQL, which contains only columns with
* internal types.
*/
abstract class InternalRow extends Row {
+
+ // This is only used in tests
+ override def getString(i: Int): String = getAs[UTF8String](i).toString
+
+ // These expensive APIs should not be used internally.
+ final override def getDecimal(i: Int): java.math.BigDecimal =
+ throw new UnsupportedOperationException
+ final override def getDate(i: Int): java.sql.Date =
+ throw new UnsupportedOperationException
+ final override def getTimestamp(i: Int): java.sql.Timestamp =
+ throw new UnsupportedOperationException
+ final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException
+ final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException
+ final override def getMap[K, V](i: Int): scala.collection.Map[K, V] =
+ throw new UnsupportedOperationException
+ final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] =
+ throw new UnsupportedOperationException
+ final override def getStruct(i: Int): Row = throw new UnsupportedOperationException
+ final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException
+ final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] =
+ throw new UnsupportedOperationException
+
// A default implementation to change the return type
override def copy(): InternalRow = this
+ override def apply(i: Int): Any = get(i)
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[Row]) {
@@ -93,27 +117,15 @@ abstract class InternalRow extends Row {
}
object InternalRow {
- def unapplySeq(row: InternalRow): Some[Seq[Any]] = Some(row.toSeq)
-
/**
* This method can be used to construct a [[Row]] with the given values.
*/
- def apply(values: Any*): InternalRow = new GenericRow(values.toArray)
+ def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
/**
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
- def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray)
-
- def fromTuple(tuple: Product): InternalRow = fromSeq(tuple.productIterator.toSeq)
-
- /**
- * Merge multiple rows into a single row, one after another.
- */
- def merge(rows: InternalRow*): InternalRow = {
- // TODO: Improve the performance of this if used in performance critical part.
- new GenericRow(rows.flatMap(_.toSeq).toArray)
- }
+ def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
/** Returns an empty row. */
val empty = apply()
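
Both factories now build a GenericInternalRow, and the removed helpers have one-line replacements at their call sites: fromTuple(t) becomes fromSeq(t.productIterator.toSeq), and merge(rowA, rowB) becomes fromSeq over the concatenated values, exactly the body that was deleted. For example (rowA and rowB stand for arbitrary InternalRows):

    import org.apache.spark.unsafe.types.UTF8String

    val row = InternalRow(1, UTF8String.fromString("a"))                // a GenericInternalRow
    val merged = InternalRow.fromSeq(Seq(rowA, rowB).flatMap(_.toSeq))  // the old merge(rowA, rowB)
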
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index d5967438cc..fcfe83ceb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -36,7 +36,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
outputArray(i) = exprArray(i).eval(input)
i += 1
}
- new GenericRow(outputArray)
+ new GenericInternalRow(outputArray)
}
override def toString: String = s"Row => [${exprArray.mkString(",")}]"
@@ -135,12 +135,6 @@ class JoinedRow extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -149,7 +143,7 @@ class JoinedRow extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -235,12 +229,6 @@ class JoinedRow2 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -249,7 +237,7 @@ class JoinedRow2 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -329,12 +317,6 @@ class JoinedRow3 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -343,7 +325,7 @@ class JoinedRow3 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -423,12 +405,6 @@ class JoinedRow4 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -437,7 +413,7 @@ class JoinedRow4 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -517,12 +493,6 @@ class JoinedRow5 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -531,7 +501,7 @@ class JoinedRow5 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
@@ -611,12 +581,6 @@ class JoinedRow6 extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- override def getString(i: Int): String =
- if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
-
- override def getAs[T](i: Int): T =
- if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
-
override def copy(): InternalRow = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
@@ -625,7 +589,7 @@ class JoinedRow6 extends InternalRow {
copiedValues(i) = apply(i)
i += 1
}
- new GenericRow(copiedValues)
+ new GenericInternalRow(copiedValues)
}
override def toString: String = {
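
Each JoinedRowN presents a pair of rows as a single row, so the deleted getString and getAs overrides were redundant: InternalRow now supplies getString (through UTF8String) and the Row trait supplies getAs (through apply). The usage pattern, per the surrounding code (left and right are placeholder InternalRows):

    val joined = new JoinedRow
    joined(left, right)         // re-point the mutable view at a new pair of rows
    val owned = joined.copy()   // materialize a GenericInternalRow that outlives the view
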
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 05aab34559..53fedb531c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
i += 1
}
- new GenericRow(newValues)
+ new GenericInternalRow(newValues)
}
override def update(ordinal: Int, value: Any) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e75e82d380..64ef357a4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
// MutableProjection is not accessible in Java
-abstract class BaseMutableProjection extends MutableProjection {}
+abstract class BaseMutableProjection extends MutableProjection
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 624e1cf4e2..39d32b78cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import org.apache.spark.sql.BaseMutableRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -149,6 +148,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
"""
}.mkString("\n")
+ val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
+ s"""arr[$i] = c$i;"""
+ }.mkString("\n ")
+
val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
@@ -167,7 +170,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
}
- final class SpecificRow extends ${typeOf[BaseMutableRow]} {
+ final class SpecificRow extends ${typeOf[MutableRow]} {
$columns
@@ -175,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
$initColumns
}
- public int size() { return ${expressions.length};}
+ public int length() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
@@ -216,6 +219,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
return super.equals(other);
}
+
+ @Override
+ public InternalRow copy() {
+ Object[] arr = new Object[${expressions.length}];
+ ${copyColumns}
+ return new ${typeOf[GenericInternalRow]}(arr);
+ }
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 356560e54c..7a42a1d310 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@@ -68,19 +69,19 @@ abstract class Generator extends Expression {
*/
case class UserDefinedGenerator(
elementTypes: Seq[(DataType, Boolean)],
- function: InternalRow => TraversableOnce[InternalRow],
+ function: Row => TraversableOnce[InternalRow],
children: Seq[Expression])
extends Generator {
@transient private[this] var inputRow: InterpretedProjection = _
- @transient private[this] var convertToScala: (InternalRow) => InternalRow = _
+ @transient private[this] var convertToScala: (InternalRow) => Row = _
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
CatalystTypeConverters.createToScalaConverter(inputSchema)
- }.asInstanceOf[(InternalRow => InternalRow)]
+ }.asInstanceOf[InternalRow => Row]
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -118,10 +119,11 @@ case class Explode(child: Expression)
child.dataType match {
case ArrayType(_, _) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
- if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
+ if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
- if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) }
+ if (inputMap == null) Nil
+ else inputMap.map { case (k, v) => InternalRow(k, v) }
}
}
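
The new signature makes the conversion boundary explicit: the user-supplied function receives an external Row, produced by the createToScalaConverter call above, but must still emit InternalRows for the engine. A sketch of a generator under the new signature (hypothetical; boundChild stands for a bound input expression):

    val gen = UserDefinedGenerator(
      elementTypes = Seq((IntegerType, false)),   // one non-nullable int column out
      function = (row: Row) => Seq(InternalRow(row.getInt(0) + 1)),
      children = Seq(boundChild))
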
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0d4c9ace5e..dd5f2ed2d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
@@ -24,19 +25,32 @@ import org.apache.spark.unsafe.types.UTF8String
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
* Setting a value through a primitive function implicitly marks that column as not null.
*/
-trait MutableRow extends InternalRow {
+abstract class MutableRow extends InternalRow {
def setNullAt(i: Int): Unit
- def update(ordinal: Int, value: Any)
+ def update(i: Int, value: Any)
+
+ // default implementations (slow): each typed setter delegates to update()
+ def setInt(i: Int, value: Int): Unit = { update(i, value) }
+ def setLong(i: Int, value: Long): Unit = { update(i, value) }
+ def setDouble(i: Int, value: Double): Unit = { update(i, value) }
+ def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
+ def setShort(i: Int, value: Short): Unit = { update(i, value) }
+ def setByte(i: Int, value: Byte): Unit = { update(i, value) }
+ def setFloat(i: Int, value: Float): Unit = { update(i, value) }
+ def setString(i: Int, value: String): Unit = {
+ update(i, UTF8String.fromString(value))
+ }
- def setInt(ordinal: Int, value: Int)
- def setLong(ordinal: Int, value: Long)
- def setDouble(ordinal: Int, value: Double)
- def setBoolean(ordinal: Int, value: Boolean)
- def setShort(ordinal: Int, value: Short)
- def setByte(ordinal: Int, value: Byte)
- def setFloat(ordinal: Int, value: Float)
- def setString(ordinal: Int, value: String)
+ override def copy(): InternalRow = {
+ val arr = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ arr(i) = get(i)
+ i += 1
+ }
+ new GenericInternalRow(arr)
+ }
}
/**
@@ -60,68 +74,57 @@ object EmptyRow extends InternalRow {
}
/**
- * A row implementation that uses an array of objects as the underlying storage. Note that, while
- * the array is not copied, and thus could technically be mutated after creation, this is not
- * allowed.
+ * A mixin for row implementations that use an array of objects as the underlying storage.
*/
-class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
- /** No-arg constructor for serialization. */
- protected def this() = this(null)
+trait ArrayBackedRow {
+ self: Row =>
- def this(size: Int) = this(new Array[Any](size))
+ protected val values: Array[Any]
override def toSeq: Seq[Any] = values.toSeq
- override def length: Int = values.length
+ def length: Int = values.length
override def apply(i: Int): Any = values(i)
- override def isNullAt(i: Int): Boolean = values(i) == null
-
- override def getInt(i: Int): Int = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
- values(i).asInstanceOf[Int]
- }
-
- override def getLong(i: Int): Long = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
- values(i).asInstanceOf[Long]
- }
-
- override def getDouble(i: Int): Double = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
- values(i).asInstanceOf[Double]
- }
-
- override def getFloat(i: Int): Float = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
- values(i).asInstanceOf[Float]
- }
+ def setNullAt(i: Int): Unit = { values(i) = null }
- override def getBoolean(i: Int): Boolean = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
- values(i).asInstanceOf[Boolean]
- }
+ def update(i: Int, value: Any): Unit = { values(i) = value }
+}
- override def getShort(i: Int): Short = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
- values(i).asInstanceOf[Short]
- }
+/**
+ * A row implementation that uses an array of objects as the underlying storage. Note that, while
+ * the array is not copied, and thus could technically be mutated after creation, this is not
+ * allowed.
+ */
+class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow {
+ /** No-arg constructor for serialization. */
+ protected def this() = this(null)
- override def getByte(i: Int): Byte = {
- if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
- values(i).asInstanceOf[Byte]
- }
+ def this(size: Int) = this(new Array[Any](size))
- override def getString(i: Int): String = {
- values(i) match {
- case null => null
- case s: String => s
- case utf8: UTF8String => utf8.toString
- }
+ // This is used by tests and by external callers
+ override def equals(o: Any): Boolean = o match {
+ case other: Row if other.length == length =>
+ var i = 0
+ while (i < length) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ val equal = (apply(i), other.apply(i)) match {
+ case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b)
+ case (a, b) => a == b
+ }
+ if (!equal) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
}
- override def copy(): InternalRow = this
+ override def copy(): Row = this
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
@@ -133,32 +136,30 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
+/**
+ * An internal row implementation that uses an array of objects as the underlying storage.
+ * Note that, while the array is not copied, and thus could technically be mutated after creation,
+ * this is not allowed.
+ */
+class GenericInternalRow(protected[sql] val values: Array[Any])
+ extends InternalRow with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
- override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
- override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
- override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
- override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
- override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
- override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = {
- values(ordinal) = UTF8String.fromString(value)
- }
-
- override def setNullAt(i: Int): Unit = { values(i) = null }
+ override def copy(): InternalRow = this
+}
- override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
+class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
+ /** No-arg constructor for serialization. */
+ protected def this() = this(null)
- override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
+ def this(size: Int) = this(new Array[Any](size))
- override def copy(): InternalRow = new GenericRow(values.clone())
+ override def copy(): InternalRow = new GenericInternalRow(values.clone())
}
-
class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
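
The net effect of this file is one array-backed storage mixin feeding two deliberately separate hierarchies: GenericRow stays on the external Row side (plain Scala values, with the content-based equals shown above), while GenericInternalRow and GenericMutableRow live on the InternalRow side and hold Catalyst values such as UTF8String. For instance:

    import org.apache.spark.unsafe.types.UTF8String

    val external = new GenericRow(Array[Any](1, "a"))                                 // a Row
    val internal = new GenericInternalRow(Array[Any](1, UTF8String.fromString("a")))  // an InternalRow

    // Byte arrays now compare by content in GenericRow.equals:
    new GenericRow(Array[Any](Array[Byte](1, 2))) ==
      new GenericRow(Array[Any](Array[Byte](1, 2)))   // true
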
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 158f54af13..7d95ef7f71 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -33,7 +33,7 @@ trait ExpressionEvalHelper {
self: SparkFunSuite =>
protected def create_row(values: Any*): InternalRow = {
- new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
protected def checkEvaluation(
@@ -122,7 +122,7 @@ trait ExpressionEvalHelper {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = InternalRow(expected)
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 7aae2bbd8a..3095ccb777 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -37,7 +37,7 @@ class UnsafeFixedWidthAggregationMapSuite
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
- private def emptyAggregationBuffer: InternalRow = new GenericRow(Array[Any](0))
+ private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private var memoryManager: TaskMemoryManager = null
@@ -84,7 +84,7 @@ class UnsafeFixedWidthAggregationMapSuite
1024, // initial capacity
false // disable perf metrics
)
- val groupKey = new GenericRow(Array[Any](UTF8String.fromString("cats")))
+ val groupKey = InternalRow(UTF8String.fromString("cats"))
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
map.getAggregationBuffer(groupKey)
@@ -113,7 +113,7 @@ class UnsafeFixedWidthAggregationMapSuite
val rand = new Random(42)
val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
groupKeys.foreach { keyString =>
- map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String.fromString(keyString))))
+ map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
}
val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
entry.key.getString(0)