author     Reynold Xin <rxin@databricks.com>    2015-07-24 09:37:36 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-07-24 09:37:36 -0700
commit     431ca39be51352dfcdacc87de7e64c2af313558d (patch)
tree       d2e3bf64fb56b2bd4950d61d3a845db32ea1c847
parent     3aec9f4e2d8fcce9ddf84ab4d0e10147c18afa16 (diff)
[SPARK-9285][SQL] Remove InternalRow's inheritance from Row.
I also changed InternalRow's size/length function to numFields, to make it more obvious that it is not about bytes, but the number of fields.

Author: Reynold Xin <rxin@databricks.com>

Closes #7626 from rxin/internalRow and squashes the following commits:

e124daf [Reynold Xin] Fixed test case.
805ceb7 [Reynold Xin] Commented out the failed test suite.
f8a9ca5 [Reynold Xin] Fixed more bugs. Still at least one more remaining.
76d9081 [Reynold Xin] Fixed data sources.
7807f70 [Reynold Xin] Fixed DataFrameSuite.
cb60cd2 [Reynold Xin] Code review & small bug fixes.
0a2948b [Reynold Xin] Fixed style.
3280d03 [Reynold Xin] [SPARK-9285][SQL] Remove InternalRow's inheritance from Row.
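For context, a minimal Scala sketch (not part of this patch) of what the reduced InternalRow contract looks like to callers after the change: the field count is read via numFields rather than length/size, and generic field access goes through get(i) rather than apply(i). The SketchInternalRow and SketchGenericRow names below are hypothetical stand-ins used only to illustrate the shape of the API.

// A hypothetical row type mirroring the post-patch contract:
// numFields + get(i), with convenience accessors derived from them.
abstract class SketchInternalRow extends Serializable {
  def numFields: Int
  def get(i: Int): Any

  def getAs[T](i: Int): T = get(i).asInstanceOf[T]
  def isNullAt(i: Int): Boolean = get(i) == null
  def getInt(i: Int): Int = getAs[Int](i)
}

// Array-backed implementation, loosely analogous to GenericInternalRow.
class SketchGenericRow(values: Array[Any]) extends SketchInternalRow {
  override def numFields: Int = values.length
  override def get(i: Int): Any = values(i)
}

object SketchExample {
  def main(args: Array[String]): Unit = {
    val row: SketchInternalRow = new SketchGenericRow(Array(1, "abc", null))

    // Pre-patch callers wrote row.length / row.size and row(i);
    // the post-patch equivalents are row.numFields and row.get(i).
    var i = 0
    while (i < row.numFields) {
      println(s"field $i = ${if (row.isNullAt(i)) "NULL" else row.get(i)}")
      i += 1
    }
  }
}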
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala  4
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala  14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala  153
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala  168
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala  57
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala  10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala  24
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala  53
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala  22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala  9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala  8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala  8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala  47
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala  139
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala  57
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala  166
41 files changed, 647 insertions, 433 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 55da0e094d..b6e2c30fbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -174,8 +174,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
override def deserialize(datum: Any): Matrix = {
datum match {
case row: InternalRow =>
- require(row.length == 7,
- s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
+ require(row.numFields == 7,
+ s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 9067b3ba9a..c884aad088 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -203,8 +203,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def deserialize(datum: Any): Vector = {
datum match {
case row: InternalRow =>
- require(row.length == 4,
- s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ require(row.numFields == 4,
+ s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
val tpe = row.getByte(0)
tpe match {
case 0 =>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index fa1216b455..a898660885 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -64,7 +64,8 @@ public final class UnsafeRow extends MutableRow {
/** The size of this row's backing data, in bytes) */
private int sizeInBytes;
- public int length() { return numFields; }
+ @Override
+ public int numFields() { return numFields; }
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
@@ -218,12 +219,12 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public int size() {
- return numFields;
+ public Object get(int i) {
+ throw new UnsupportedOperationException();
}
@Override
- public Object get(int i) {
+ public <T> T getAs(int i) {
throw new UnsupportedOperationException();
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index bfaee04f33..5c3072a77a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -140,14 +140,14 @@ object CatalystTypeConverters {
private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
override def toCatalystImpl(scalaValue: Any): Any = scalaValue
override def toScala(catalystValue: Any): Any = catalystValue
- override def toScalaImpl(row: InternalRow, column: Int): Any = row(column)
+ override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column)
}
private case class UDTConverter(
udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
- override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column))
+ override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column))
}
/** Converter for arrays, sequences, and Java iterables. */
@@ -184,7 +184,7 @@ object CatalystTypeConverters {
}
override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
- toScala(row(column).asInstanceOf[Seq[Any]])
+ toScala(row.get(column).asInstanceOf[Seq[Any]])
}
private case class MapConverter(
@@ -227,7 +227,7 @@ object CatalystTypeConverters {
}
override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] =
- toScala(row(column).asInstanceOf[Map[Any, Any]])
+ toScala(row.get(column).asInstanceOf[Map[Any, Any]])
}
private case class StructConverter(
@@ -260,9 +260,9 @@ object CatalystTypeConverters {
if (row == null) {
null
} else {
- val ar = new Array[Any](row.size)
+ val ar = new Array[Any](row.numFields)
var idx = 0
- while (idx < row.size) {
+ while (idx < row.numFields) {
ar(idx) = converters(idx).toScala(row, idx)
idx += 1
}
@@ -271,7 +271,7 @@ object CatalystTypeConverters {
}
override def toScalaImpl(row: InternalRow, column: Int): Row =
- toScala(row(column).asInstanceOf[InternalRow])
+ toScala(row.get(column).asInstanceOf[InternalRow])
}
private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index c7ec49b3d6..efc4faea56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -25,48 +25,139 @@ import org.apache.spark.unsafe.types.UTF8String
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
-abstract class InternalRow extends Row {
+abstract class InternalRow extends Serializable {
- def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i)
+ def numFields: Int
- def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i)
+ def get(i: Int): Any
- // This is only use for test
- override def getString(i: Int): String = getAs[UTF8String](i).toString
-
- // These expensive API should not be used internally.
- final override def getDecimal(i: Int): java.math.BigDecimal =
- throw new UnsupportedOperationException
- final override def getDate(i: Int): java.sql.Date =
- throw new UnsupportedOperationException
- final override def getTimestamp(i: Int): java.sql.Timestamp =
- throw new UnsupportedOperationException
- final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException
- final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException
- final override def getMap[K, V](i: Int): scala.collection.Map[K, V] =
- throw new UnsupportedOperationException
- final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] =
- throw new UnsupportedOperationException
- final override def getStruct(i: Int): Row = throw new UnsupportedOperationException
- final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException
- final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] =
- throw new UnsupportedOperationException
-
- // A default implementation to change the return type
- override def copy(): InternalRow = this
+ // TODO: Remove this.
+ def apply(i: Int): Any = get(i)
+
+ def getAs[T](i: Int): T = get(i).asInstanceOf[T]
+
+ def isNullAt(i: Int): Boolean = get(i) == null
+
+ def getBoolean(i: Int): Boolean = getAs[Boolean](i)
+
+ def getByte(i: Int): Byte = getAs[Byte](i)
+
+ def getShort(i: Int): Short = getAs[Short](i)
+
+ def getInt(i: Int): Int = getAs[Int](i)
+
+ def getLong(i: Int): Long = getAs[Long](i)
+
+ def getFloat(i: Int): Float = getAs[Float](i)
+
+ def getDouble(i: Int): Double = getAs[Double](i)
+
+ override def toString: String = s"[${this.mkString(",")}]"
+
+ /**
+ * Make a copy of the current [[InternalRow]] object.
+ */
+ def copy(): InternalRow = this
+
+ /** Returns true if there are any NULL values in this row. */
+ def anyNull: Boolean = {
+ val len = numFields
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i)) { return true }
+ i += 1
+ }
+ false
+ }
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[InternalRow]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[InternalRow]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numFields
+ if (len != other.numFields) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = get(i)
+ val o2 = other.get(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ /* ---------------------- utility methods for Scala ---------------------- */
/**
- * Returns true if we can check equality for these 2 rows.
- * Equality check between external row and internal row is not allowed.
- * Here we do this check to prevent call `equals` on internal row with external row.
+ * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
- protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow]
+ def toSeq: Seq[Any] = {
+ val n = numFields
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ values.update(i, get(i))
+ i += 1
+ }
+ values.toSeq
+ }
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
+
+ def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i)
+
+ def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i)
+
+ // This is only use for test
+ def getString(i: Int): String = getAs[UTF8String](i).toString
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
var i = 0
- while (i < length) {
+ val len = numFields
+ while (i < len) {
val update: Int =
if (isNullAt(i)) {
0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c66854d52c..47ad3e089e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -382,8 +382,8 @@ case class Cast(child: Expression, dataType: DataType)
val newRow = new GenericMutableRow(from.fields.length)
buildCast[InternalRow](_, row => {
var i = 0
- while (i < row.length) {
- newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i)))
+ while (i < row.numFields) {
+ newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i)))
i += 1
}
newRow.copy()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 04872fbc8b..dbda05a792 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -176,49 +176,49 @@ class JoinedRow extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
@@ -278,49 +278,49 @@ class JoinedRow2 extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
@@ -374,50 +374,50 @@ class JoinedRow3 extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
@@ -471,50 +471,50 @@ class JoinedRow4 extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
@@ -568,50 +568,50 @@ class JoinedRow5 extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
@@ -665,50 +665,50 @@ class JoinedRow6 extends InternalRow {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length: Int = row1.length + row2.length
+ override def numFields: Int = row1.numFields + row2.numFields
override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length)
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
}
override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length)
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
}
override def get(i: Int): Any =
- if (i < row1.length) row1(i) else row2(i - row1.length)
+ if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
- if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+ if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getInt(i: Int): Int =
- if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+ if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
- if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+ if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getDouble(i: Int): Double =
- if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+ if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
- if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getShort(i: Int): Short =
- if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getByte(i: Int): Byte =
- if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getFloat(i: Int): Float =
- if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def copy(): InternalRow = {
- val totalSize = row1.length + row2.length
+ val totalSize = row1.numFields + row2.numFields
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
- copiedValues(i) = apply(i)
+ copiedValues(i) = get(i)
i += 1
}
new GenericInternalRow(copiedValues)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 6f291d2c86..4b4833bd06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -211,7 +211,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
def this() = this(Seq.empty)
- override def length: Int = values.length
+ override def numFields: Int = values.length
override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
@@ -245,7 +245,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def setString(ordinal: Int, value: String): Unit =
update(ordinal, UTF8String.fromString(value))
- override def getString(ordinal: Int): String = apply(ordinal).toString
+ override def getString(ordinal: Int): String = get(ordinal).toString
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 73fde4e916..62b6cc834c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -675,7 +675,7 @@ case class CombineSetsAndSumFunction(
val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
val inputIterator = inputSetEval.iterator
while (inputIterator.hasNext) {
- seen.add(inputIterator.next)
+ seen.add(inputIterator.next())
}
}
@@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction(
null
} else {
Cast(Literal(
- casted.iterator.map(f => f.apply(0)).reduceLeft(
+ casted.iterator.map(f => f.get(0)).reduceLeft(
base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
base.dataType).eval(null)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 405d6b0e3b..f0efc4bff1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -178,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
$initColumns
}
- public int length() { return ${expressions.length};}
+ public int numFields() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 5504781edc..c91122cda2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -110,7 +110,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any =
- input.asInstanceOf[InternalRow](ordinal)
+ input.asInstanceOf[InternalRow].get(ordinal)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
@@ -142,7 +142,7 @@ case class GetArrayStructFields(
protected override def nullSafeEval(input: Any): Any = {
input.asInstanceOf[Seq[InternalRow]].map { row =>
- if (row == null) null else row(ordinal)
+ if (row == null) null else row.get(ordinal)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index d78be5a595..53779dd404 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,9 +44,10 @@ abstract class MutableRow extends InternalRow {
}
override def copy(): InternalRow = {
- val arr = new Array[Any](length)
+ val n = numFields
+ val arr = new Array[Any](n)
var i = 0
- while (i < length) {
+ while (i < n) {
arr(i) = get(i)
i += 1
}
@@ -55,35 +56,22 @@ abstract class MutableRow extends InternalRow {
}
/**
- * A row implementation that uses an array of objects as the underlying storage.
- */
-trait ArrayBackedRow {
- self: Row =>
-
- protected val values: Array[Any]
-
- override def toSeq: Seq[Any] = values.toSeq
-
- def length: Int = values.length
-
- override def get(i: Int): Any = values(i)
-
- def setNullAt(i: Int): Unit = { values(i) = null}
-
- def update(i: Int, value: Any): Unit = { values(i) = value }
-}
-
-/**
* A row implementation that uses an array of objects as the underlying storage. Note that, while
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
-class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow {
+class GenericRow(protected[sql] val values: Array[Any]) extends Row {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
+ override def length: Int = values.length
+
+ override def get(i: Int): Any = values(i)
+
+ override def toSeq: Seq[Any] = values.toSeq
+
override def copy(): Row = this
}
@@ -101,34 +89,49 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
* Note that, while the array is not copied, and thus could technically be mutated after creation,
* this is not allowed.
*/
-class GenericInternalRow(protected[sql] val values: Array[Any])
- extends InternalRow with ArrayBackedRow {
+class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
+ override def toSeq: Seq[Any] = values.toSeq
+
+ override def numFields: Int = values.length
+
+ override def get(i: Int): Any = values(i)
+
override def copy(): InternalRow = this
}
/**
* This is used for serialization of Python DataFrame
*/
-class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType)
+class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
extends GenericInternalRow(values) {
/** No-arg constructor for serialization. */
protected def this() = this(null, null)
- override def fieldIndex(name: String): Int = schema.fieldIndex(name)
+ def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
+class GenericMutableRow(val values: Array[Any]) extends MutableRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
def this(size: Int) = this(new Array[Any](size))
+ override def toSeq: Seq[Any] = values.toSeq
+
+ override def numFields: Int = values.length
+
+ override def get(i: Int): Any = values(i)
+
+ override def setNullAt(i: Int): Unit = { values(i) = null}
+
+ override def update(i: Int, value: Any): Unit = { values(i) = value }
+
override def copy(): InternalRow = new GenericInternalRow(values.clone())
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index 878a1bb9b7..01ff84cb56 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -83,15 +83,5 @@ class RowTest extends FunSpec with Matchers {
it("equality check for internal rows") {
internalRow shouldEqual internalRow2
}
-
- it("throws an exception when check equality between external and internal rows") {
- def assertError(f: => Unit): Unit = {
- val e = intercept[UnsupportedOperationException](f)
- e.getMessage.contains("cannot check equality between external and internal rows")
- }
-
- assertError(internalRow.equals(externalRow))
- assertError(externalRow.equals(internalRow))
- }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index facf65c155..408353cf70 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Test suite for data type casting expression [[Cast]].
@@ -580,14 +581,21 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from struct") {
val struct = Literal.create(
- InternalRow("123", "abc", "", null),
+ InternalRow(
+ UTF8String.fromString("123"),
+ UTF8String.fromString("abc"),
+ UTF8String.fromString(""),
+ null),
StructType(Seq(
StructField("a", StringType, nullable = true),
StructField("b", StringType, nullable = true),
StructField("c", StringType, nullable = true),
StructField("d", StringType, nullable = true))))
val struct_notNull = Literal.create(
- InternalRow("123", "abc", ""),
+ InternalRow(
+ UTF8String.fromString("123"),
+ UTF8String.fromString("abc"),
+ UTF8String.fromString("")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
@@ -676,8 +684,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
InternalRow(
- Seq("123", "abc", ""),
- Map("a" -> "123", "b" -> "abc", "c" -> ""),
+ Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")),
+ Map(
+ UTF8String.fromString("a") -> UTF8String.fromString("123"),
+ UTF8String.fromString("b") -> UTF8String.fromString("abc"),
+ UTF8String.fromString("c") -> UTF8String.fromString("")),
InternalRow(0)),
StructType(Seq(
StructField("a",
@@ -700,7 +711,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(
Seq(123, null, null),
- Map("a" -> true, "b" -> true, "c" -> false),
+ Map(
+ UTF8String.fromString("a") -> true,
+ UTF8String.fromString("b") -> true,
+ UTF8String.fromString("c") -> false),
InternalRow(0L)))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index a8aee8f634..fc842772f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -150,12 +151,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("CreateNamedStruct with literal field") {
val row = InternalRow(1, 2, 3)
val c1 = 'a.int.at(0)
- checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row)
+ checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
+ InternalRow(1, UTF8String.fromString("y")), row)
}
test("CreateNamedStruct from all literal fields") {
checkEvaluation(
- CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty)
+ CreateNamedStruct(Seq("a", "x", "b", 2.0)),
+ InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty)
}
test("test dsl for complex type") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 9d8415f063..ac42bde07c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -309,7 +309,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
override def actualSize(row: InternalRow, ordinal: Int): Int = {
- row.getString(ordinal).getBytes("utf-8").length + 4
+ row.getUTF8String(ordinal).numBytes() + 4
}
override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 38720968c1..5d5b0697d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -134,13 +134,13 @@ private[sql] case class InMemoryRelation(
// may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
// hard to decipher.
assert(
- row.size == columnBuilders.size,
- s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}.
- |Row content: $row
- """.stripMargin)
+ row.numFields == columnBuilders.size,
+ s"Row column number mismatch, expected ${output.size} columns, " +
+ s"but got ${row.numFields}." +
+ s"\nRow content: $row")
var i = 0
- while (i < row.length) {
+ while (i < row.numFields) {
columnBuilders(i).appendFrom(row, i)
i += 1
}
@@ -304,7 +304,7 @@ private[sql] case class InMemoryColumnarTableScan(
// Extract rows via column accessors
new Iterator[InternalRow] {
- private[this] val rowLen = nextRow.length
+ private[this] val rowLen = nextRow.numFields
override def next(): InternalRow = {
var i = 0
while (i < rowLen) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index c87e2064a8..83c4e8733f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -25,7 +25,6 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.serializer._
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.types._
@@ -53,7 +52,7 @@ private[sql] class Serializer2SerializationStream(
private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
override def writeObject[T: ClassTag](t: T): SerializationStream = {
- val kv = t.asInstanceOf[Product2[Row, Row]]
+ val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]]
writeKey(kv._1)
writeValue(kv._2)
@@ -66,7 +65,7 @@ private[sql] class Serializer2SerializationStream(
}
override def writeValue[T: ClassTag](t: T): SerializationStream = {
- writeRowFunc(t.asInstanceOf[Row])
+ writeRowFunc(t.asInstanceOf[InternalRow])
this
}
@@ -205,8 +204,9 @@ private[sql] object SparkSqlSerializer2 {
/**
* The util function to create the serialization function based on the given schema.
*/
- def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = {
- (row: Row) =>
+ def createSerializationFunction(schema: Array[DataType], out: DataOutputStream)
+ : InternalRow => Unit = {
+ (row: InternalRow) =>
// If the schema is null, the returned function does nothing when it get called.
if (schema != null) {
var i = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2b40092617..7f452daef3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -206,7 +206,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val mutableRow = new SpecificMutableRow(dataTypes)
iterator.map { dataRow =>
var i = 0
- while (i < mutableRow.length) {
+ while (i < mutableRow.numFields) {
mergers(i)(mutableRow, dataRow, i)
i += 1
}
@@ -315,7 +315,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
if (relation.relation.needConversion) {
execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
} else {
- rdd.map(_.asInstanceOf[InternalRow])
+ rdd.asInstanceOf[RDD[InternalRow]]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala
index cd2aa7f743..d551f386ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala
@@ -174,14 +174,19 @@ private[sql] case class InsertIntoHadoopFsRelation(
try {
writerContainer.executorSideSetup(taskContext)
- val converter: InternalRow => Row = if (needsConversion) {
- CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
+ if (needsConversion) {
+ val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
+ .asInstanceOf[InternalRow => Row]
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ writerContainer.outputWriterForRow(internalRow).write(converter(internalRow))
+ }
} else {
- r: InternalRow => r.asInstanceOf[Row]
- }
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- writerContainer.outputWriterForRow(internalRow).write(converter(internalRow))
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ writerContainer.outputWriterForRow(internalRow)
+ .asInstanceOf[OutputWriterInternal].writeInternal(internalRow)
+ }
}
writerContainer.commitTask()
@@ -248,17 +253,23 @@ private[sql] case class InsertIntoHadoopFsRelation(
val partitionProj = newProjection(codegenEnabled, partitionCasts, output)
val dataProj = newProjection(codegenEnabled, dataOutput, output)
- val dataConverter: InternalRow => Row = if (needsConversion) {
- CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
+ if (needsConversion) {
+ val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
+ .asInstanceOf[InternalRow => Row]
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ val partitionPart = partitionProj(internalRow)
+ val dataPart = converter(dataProj(internalRow))
+ writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+ }
} else {
- r: InternalRow => r.asInstanceOf[Row]
- }
-
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- val partitionPart = partitionProj(internalRow)
- val dataPart = dataConverter(dataProj(internalRow))
- writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ val partitionPart = partitionProj(internalRow)
+ val dataPart = dataProj(internalRow)
+ writerContainer.outputWriterForRow(partitionPart)
+ .asInstanceOf[OutputWriterInternal].writeInternal(dataPart)
+ }
}
writerContainer.commitTask()
@@ -530,8 +541,12 @@ private[sql] class DynamicPartitionWriterContainer(
while (i < partitionColumns.length) {
val col = partitionColumns(i)
val partitionValueString = {
- val string = row.getString(i)
- if (string.eq(null)) defaultPartitionName else PartitioningUtils.escapePathName(string)
+ val string = row.getUTF8String(i)
+ if (string.eq(null)) {
+ defaultPartitionName
+ } else {
+ PartitioningUtils.escapePathName(string.toString)
+ }
}
if (i > 0) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index c8033d3c04..1f2797ec55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -23,11 +23,11 @@ import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode}
+import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -415,12 +415,12 @@ private[sql] case class CreateTempTableUsing(
provider: String,
options: Map[String, String]) extends RunnableCommand {
- def run(sqlContext: SQLContext): Seq[InternalRow] = {
+ def run(sqlContext: SQLContext): Seq[Row] = {
val resolved = ResolvedDataSource(
sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
- Seq.empty
+ Seq.empty[Row]
}
}
@@ -432,20 +432,20 @@ private[sql] case class CreateTempTableUsingAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+ override def run(sqlContext: SQLContext): Seq[Row] = {
val df = DataFrame(sqlContext, query)
val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
- Seq.empty
+ Seq.empty[Row]
}
}
private[sql] case class RefreshTable(databaseName: String, tableName: String)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+ override def run(sqlContext: SQLContext): Seq[Row] = {
// Refresh the given table's metadata first.
sqlContext.catalog.refreshTable(databaseName, tableName)
@@ -464,7 +464,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String)
sqlContext.cacheManager.cacheQuery(df, Some(tableName))
}
- Seq.empty[InternalRow]
+ Seq.empty[Row]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index e6e27a87c7..40bf03a3f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -126,9 +126,9 @@ object EvaluatePython {
case (null, _) => null
case (row: InternalRow, struct: StructType) =>
- val values = new Array[Any](row.size)
+ val values = new Array[Any](row.numFields)
var i = 0
- while (i < row.size) {
+ while (i < row.numFields) {
values(i) = toJava(row(i), struct.fields(i).dataType)
i += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
index 6c49a906c8..46f0fac861 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -148,7 +148,7 @@ class InputAggregationBuffer private[sql] (
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
- var underlyingInputBuffer: Row)
+ var underlyingInputBuffer: InternalRow)
extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
override def get(i: Int): Any = {
@@ -156,6 +156,7 @@ class InputAggregationBuffer private[sql] (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
+ // TODO: Use buffer schema to avoid using generic getter.
toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 4d3aac464c..41d0ecb4bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -128,6 +128,7 @@ private[sql] case class JDBCRelation(
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverRegistry.getDriverClassName(url)
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sqlContext.sparkContext,
schema,
@@ -137,7 +138,7 @@ private[sql] case class JDBCRelation(
table,
requiredColumns,
filters,
- parts).map(_.asInstanceOf[Row])
+ parts).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 922794ac9a..562b058414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -154,17 +154,19 @@ private[sql] class JSONRelation(
}
override def buildScan(): RDD[Row] = {
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JacksonParser(
baseRDD(),
schema,
- sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
+ sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]]
}
override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = {
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JacksonParser(
baseRDD(),
StructType.fromAttributes(requiredColumns),
- sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
+ sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
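The "type erasure hack" comments above (and the same pattern in the Parquet, ORC, and test relations below) work because JVM generics are erased: an RDD[InternalRow] and an RDD[Row] have the same runtime class, so the whole-RDD cast compiles and never throws, whereas the old per-element map(_.asInstanceOf[Row]) would now fail with a ClassCastException once the RDD is evaluated, since InternalRow no longer extends Row. A minimal sketch, not part of the patch (the helper name is invented):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow

// Only sound for relations that also override needConversion to return false, so that
// Spark treats the returned rows as already-internal rows instead of converting them again.
def eraseToRowRdd(internal: RDD[InternalRow]): RDD[Row] =
  internal.asInstanceOf[RDD[Row]]

The same erasure argument covers the Seq[Row] to Seq[InternalRow] cast in InsertIntoHiveTable further down.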
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index 0c3d8fdab6..b5e4263008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveCo
import org.apache.parquet.schema.Type.Repetition
import org.apache.parquet.schema.{GroupType, PrimitiveType, Type}
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -55,8 +55,8 @@ private[parquet] trait ParentContainerUpdater {
private[parquet] object NoopUpdater extends ParentContainerUpdater
/**
- * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since
- * any Parquet record is also a struct, this converter can also be used as root converter.
+ * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s.
+ * Since any Parquet record is also a struct, this converter can also be used as a root converter.
*
* When used as a root converter, [[NoopUpdater]] should be used since root converters don't have
* any "parent" container.
@@ -108,7 +108,7 @@ private[parquet] class CatalystRowConverter(
override def start(): Unit = {
var i = 0
- while (i < currentRow.length) {
+ while (i < currentRow.numFields) {
currentRow.setNullAt(i)
i += 1
}
@@ -178,7 +178,7 @@ private[parquet] class CatalystRowConverter(
case t: StructType =>
new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater {
- override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy())
+ override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy())
})
case t: UserDefinedType[_] =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 28cba5e54d..8cab27d6e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -178,7 +178,7 @@ private[sql] case class ParquetTableScan(
val row = iter.next()._2.asInstanceOf[InternalRow]
var i = 0
- while (i < row.size) {
+ while (i < row.numFields) {
mutableRow(i) = row(i)
i += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index d1040bf556..c7c58e69d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -208,9 +208,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
override def write(record: InternalRow): Unit = {
val attributesSize = attributes.size
- if (attributesSize > record.size) {
- throw new IndexOutOfBoundsException(
- s"Trying to write more fields than contained in row ($attributesSize > ${record.size})")
+ if (attributesSize > record.numFields) {
+ throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " +
+ s"($attributesSize > ${record.numFields})")
}
var index = 0
@@ -378,9 +378,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
override def write(record: InternalRow): Unit = {
val attributesSize = attributes.size
- if (attributesSize > record.size) {
- throw new IndexOutOfBoundsException(
- s"Trying to write more fields than contained in row ($attributesSize > ${record.size})")
+ if (attributesSize > record.numFields) {
+ throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " +
+ s"($attributesSize > ${record.numFields})")
}
var index = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index c384697c0e..8ec228c2b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -61,7 +61,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
- extends OutputWriter {
+ extends OutputWriterInternal {
private val recordWriter: RecordWriter[Void, InternalRow] = {
val outputFormat = {
@@ -86,7 +86,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
outputFormat.getRecordWriter(context)
}
- override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow])
+ override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
override def close(): Unit = recordWriter.close(context)
}
@@ -324,7 +324,7 @@ private[sql] class ParquetRelation2(
new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
}
- }.values.map(_.asInstanceOf[Row])
+ }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7cd005b959..119bac786d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -345,6 +345,18 @@ abstract class OutputWriter {
}
/**
+ * This is an internal, private version of [[OutputWriter]] with a writeInternal method that
+ * accepts an [[InternalRow]] rather than a [[Row]]. Data sources that return this must have
+ * the conversion flag set to false.
+ */
+private[sql] abstract class OutputWriterInternal extends OutputWriter {
+
+ override def write(row: Row): Unit = throw new UnsupportedOperationException
+
+ def writeInternal(row: InternalRow): Unit
+}
+
+/**
* ::Experimental::
* A [[BaseRelation]] that provides much of the common code required for formats that store their
* data to an HDFS compatible filesystem.
@@ -592,12 +604,12 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable)
}.toSeq
- val rdd = buildScan(inputFiles)
- val converted =
+ val rdd: RDD[Row] = buildScan(inputFiles)
+ val converted: RDD[InternalRow] =
if (needConversion) {
RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType))
} else {
- rdd.map(_.asInstanceOf[InternalRow])
+ rdd.asInstanceOf[RDD[InternalRow]]
}
converted.mapPartitions { rows =>
val buildProjection = if (codegenEnabled) {
@@ -606,8 +618,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
() => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes)
}
val mutableProjection = buildProjection()
- rows.map(r => mutableProjection(r).asInstanceOf[Row])
- }
+ rows.map(r => mutableProjection(r))
+ }.asInstanceOf[RDD[Row]]
}
/**
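For reference, a minimal sketch of a writer built against the new OutputWriterInternal contract added above; it is not part of the patch and the class name is invented. Like the Parquet and ORC writers elsewhere in this diff, it has to live under an org.apache.spark.sql package because OutputWriterInternal is private[sql].

package org.apache.spark.sql.sources

import org.apache.spark.sql.catalyst.InternalRow

private[sql] class ExampleInternalWriter extends OutputWriterInternal {
  // Receives catalyst rows directly; the inherited write(Row) override simply throws.
  override def writeInternal(row: InternalRow): Unit = {
    var i = 0
    while (i < row.numFields) {
      // A real writer would serialize row(i) here; this sketch only walks the fields.
      i += 1
    }
  }
  override def close(): Unit = ()
}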
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 7cc6ffd754..0e5c5abff8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -35,14 +35,14 @@ class RowSuite extends SparkFunSuite {
expected.update(2, false)
expected.update(3, null)
val actual1 = Row(2147483647, "this is a string", false, null)
- assert(expected.size === actual1.size)
+ assert(expected.numFields === actual1.size)
assert(expected.getInt(0) === actual1.getInt(0))
assert(expected.getString(1) === actual1.getString(1))
assert(expected.getBoolean(2) === actual1.getBoolean(2))
assert(expected(3) === actual1(3))
val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
- assert(expected.size === actual2.size)
+ assert(expected.numFields === actual2.size)
assert(expected.getInt(0) === actual2.getInt(0))
assert(expected.getString(1) === actual2.getString(1))
assert(expected.getBoolean(2) === actual2.getBoolean(2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index da53ec16b5..84855ce45e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -61,9 +61,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
override def needConversion: Boolean = false
override def buildScan(): RDD[Row] = {
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
sqlContext.sparkContext.parallelize(from to to).map { e =>
- InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row
- }
+ InternalRow(UTF8String.fromString(s"people$e"), e * 2)
+ }.asInstanceOf[RDD[Row]]
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 257526feab..0d5183444a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -131,7 +131,7 @@ class PrunedScanSuite extends DataSourceTest {
queryExecution)
}
- if (rawOutput.size != expectedColumns.size) {
+ if (rawOutput.numFields != expectedColumns.size) {
fail(s"Wrong output row. Got $rawOutput\n$queryExecution")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 143aadc08b..5e189c3563 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -93,7 +93,7 @@ case class AllDataTypesScan(
InternalRow(i, UTF8String.fromString(i.toString)),
InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
- }
+ }.asInstanceOf[RDD[Row]]
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 8202e553af..34b629403e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -122,7 +122,7 @@ case class InsertIntoHiveTable(
*
* Note: this is run once and then kept to avoid double insertions.
*/
- protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
+ protected[sql] lazy val sideEffectResult: Seq[Row] = {
// Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
// instances within the closure, since Serializer is not serializable while TableDesc is.
val tableDesc = table.tableDesc
@@ -252,13 +252,12 @@ case class InsertIntoHiveTable(
// however for now we return an empty list to simplify compatibility checks with hive, which
// does not return anything for insert operations.
// TODO: implement hive compatibility as rules.
- Seq.empty[InternalRow]
+ Seq.empty[Row]
}
- override def executeCollect(): Array[Row] =
- sideEffectResult.toArray
+ override def executeCollect(): Array[Row] = sideEffectResult.toArray
protected override def doExecute(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index ecc78a5f8d..8850e060d2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.Row
import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.types._
@@ -94,7 +95,9 @@ private[hive] class SparkHiveWriterContainer(
"part-" + numberFormat.format(splitID) + extension
}
- def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer
+ def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = {
+ writer
+ }
def close() {
// Seems the boolean value passed into close does not matter.
@@ -197,7 +200,8 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker)
}
- override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = {
+ override def getLocalFileWriter(row: InternalRow, schema: StructType)
+ : FileSinkOperator.RecordWriter = {
def convertToHiveRawString(col: String, value: Any): String = {
val raw = String.valueOf(value)
schema(col).dataType match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index de63ee56dd..10623dc820 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext)
- extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors {
+ extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors {
private val serializer = {
val table = new Properties()
@@ -119,9 +119,9 @@ private[orc] class OrcOutputWriter(
).asInstanceOf[RecordWriter[NullWritable, Writable]]
}
- override def write(row: Row): Unit = {
+ override def writeInternal(row: InternalRow): Unit = {
var i = 0
- while (i < row.length) {
+ while (i < row.numFields) {
reusableOutputBuffer(i) = wrappers(i)(row(i))
i += 1
}
@@ -192,7 +192,7 @@ private[sql] class OrcRelation(
filters: Array[Filter],
inputPaths: Array[FileStatus]): RDD[Row] = {
val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
- OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row])
+ OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]]
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
new file mode 100644
index 0000000000..e976125b37
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.SQLTestUtils
+
+
+class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
+ override val sqlContext = TestHive
+
+ // When committing a task, `CommitFailureTestSource` throws an exception for testing purposes.
+ val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
+
+ test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+ withTempPath { file =>
+ // Here we coalesce the partition number to 1 to ensure that only a single task is issued.
+ // This prevents a race condition that happens when FileOutputCommitter tries to remove the
+ // `_temporary` directory while committing/aborting the job. See SPARK-8513 for more details.
+ val df = sqlContext.range(0, 10).coalesce(1)
+ intercept[SparkException] {
+ df.write.format(dataSourceName).save(file.getCanonicalPath)
+ }
+
+ val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+ assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
new file mode 100644
index 0000000000..d280543a07
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import com.google.common.io.Files
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.{AnalysisException, SaveMode, parquet}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+
+class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
+ override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName
+
+ import sqlContext._
+ import sqlContext.implicits._
+
+ test("save()/load() - partitioned table - simple queries - partition columns in data") {
+ withTempDir { file =>
+ val basePath = new Path(file.getCanonicalPath)
+ val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
+ val qualifiedBasePath = fs.makeQualified(basePath)
+
+ for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
+ val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ sparkContext
+ .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
+ .toDF("a", "b", "p1")
+ .write.parquet(partitionDir.toString)
+ }
+
+ val dataSchemaWithPartition =
+ StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
+
+ checkQueries(
+ read.format(dataSourceName)
+ .option("dataSchema", dataSchemaWithPartition.json)
+ .load(file.getCanonicalPath))
+ }
+ }
+
+ test("SPARK-7868: _temporary directories should be ignored") {
+ withTempPath { dir =>
+ val df = Seq("a", "b", "c").zipWithIndex.toDF()
+
+ df.write
+ .format("parquet")
+ .save(dir.getCanonicalPath)
+
+ df.write
+ .format("parquet")
+ .save(s"${dir.getCanonicalPath}/_temporary")
+
+ checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
+ }
+ }
+
+ test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") {
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ val df = Seq(1 -> "a").toDF()
+
+ // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw
+ // since the file is not a valid Parquet file.
+ val emptyFile = new File(path, "empty")
+ Files.createParentDirs(emptyFile)
+ Files.touch(emptyFile)
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Ignore).save(path)
+
+ // This should only complain that the destination directory already exists, rather than that
+ // the file "empty" is not a Parquet file.
+ assert {
+ intercept[AnalysisException] {
+ df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path)
+ }.getMessage.contains("already exists")
+ }
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
+ checkAnswer(read.format("parquet").load(path), df)
+ }
+ }
+
+ test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") {
+ withTempPath { dir =>
+ intercept[AnalysisException] {
+ // Parquet doesn't allow field names with spaces. Here we intentionally make the
+ // `ParquetRelation2.prepareForWriteJob()` method throw an exception to trigger the bug.
+ // Please refer to SPARK-8079 for more details.
+ range(1, 10)
+ .withColumnRenamed("id", "a b")
+ .write
+ .format("parquet")
+ .save(dir.getCanonicalPath)
+ }
+ }
+ }
+
+ test("SPARK-8604: Parquet data source should write summary file while doing appending") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val df = sqlContext.range(0, 5)
+ df.write.mode(SaveMode.Overwrite).parquet(path)
+
+ val summaryPath = new Path(path, "_metadata")
+ val commonSummaryPath = new Path(path, "_common_metadata")
+
+ val fs = summaryPath.getFileSystem(configuration)
+ fs.delete(summaryPath, true)
+ fs.delete(commonSummaryPath, true)
+
+ df.write.mode(SaveMode.Append).parquet(path)
+ checkAnswer(sqlContext.read.parquet(path), df.unionAll(df))
+
+ assert(fs.exists(summaryPath))
+ assert(fs.exists(commonSummaryPath))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
new file mode 100644
index 0000000000..d761909d60
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/*
+This is commented out due to a bug in the data source API (SPARK-9291).
+
+
+class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
+ override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
+
+ import sqlContext._
+
+ test("save()/load() - partitioned table - simple queries - partition columns in data") {
+ withTempDir { file =>
+ val basePath = new Path(file.getCanonicalPath)
+ val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
+ val qualifiedBasePath = fs.makeQualified(basePath)
+
+ for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
+ val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ sparkContext
+ .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1")
+ .saveAsTextFile(partitionDir.toString)
+ }
+
+ val dataSchemaWithPartition =
+ StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
+
+ checkQueries(
+ read.format(dataSourceName)
+ .option("dataSchema", dataSchemaWithPartition.json)
+ .load(file.getCanonicalPath))
+ }
+ }
+}
+*/
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 2a8748d913..dd274023a1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -17,18 +17,14 @@
package org.apache.spark.sql.sources
-import java.io.File
-
import scala.collection.JavaConversions._
-import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.parquet.hadoop.ParquetOutputCommitter
-import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -581,165 +577,3 @@ class AlwaysFailParquetOutputCommitter(
sys.error("Intentional job commitment failure for testing purpose.")
}
}
-
-class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
- override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
-
- import sqlContext._
-
- test("save()/load() - partitioned table - simple queries - partition columns in data") {
- withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
- for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
- sparkContext
- .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1")
- .saveAsTextFile(partitionDir.toString)
- }
-
- val dataSchemaWithPartition =
- StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
-
- checkQueries(
- read.format(dataSourceName)
- .option("dataSchema", dataSchemaWithPartition.json)
- .load(file.getCanonicalPath))
- }
- }
-}
-
-class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
- override val sqlContext = TestHive
-
- // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
- val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
-
- test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
- withTempPath { file =>
- // Here we coalesce partition number to 1 to ensure that only a single task is issued. This
- // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
- // directory while committing/aborting the job. See SPARK-8513 for more details.
- val df = sqlContext.range(0, 10).coalesce(1)
- intercept[SparkException] {
- df.write.format(dataSourceName).save(file.getCanonicalPath)
- }
-
- val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
- assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
- }
- }
-}
-
-class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
- override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName
-
- import sqlContext._
- import sqlContext.implicits._
-
- test("save()/load() - partitioned table - simple queries - partition columns in data") {
- withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
- for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
- sparkContext
- .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
- .toDF("a", "b", "p1")
- .write.parquet(partitionDir.toString)
- }
-
- val dataSchemaWithPartition =
- StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
-
- checkQueries(
- read.format(dataSourceName)
- .option("dataSchema", dataSchemaWithPartition.json)
- .load(file.getCanonicalPath))
- }
- }
-
- test("SPARK-7868: _temporary directories should be ignored") {
- withTempPath { dir =>
- val df = Seq("a", "b", "c").zipWithIndex.toDF()
-
- df.write
- .format("parquet")
- .save(dir.getCanonicalPath)
-
- df.write
- .format("parquet")
- .save(s"${dir.getCanonicalPath}/_temporary")
-
- checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
- }
- }
-
- test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") {
- withTempDir { dir =>
- val path = dir.getCanonicalPath
- val df = Seq(1 -> "a").toDF()
-
- // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw
- // since it's not a valid Parquet file.
- val emptyFile = new File(path, "empty")
- Files.createParentDirs(emptyFile)
- Files.touch(emptyFile)
-
- // This shouldn't throw anything.
- df.write.format("parquet").mode(SaveMode.Ignore).save(path)
-
- // This should only complain that the destination directory already exists, rather than file
- // "empty" is not a Parquet file.
- assert {
- intercept[AnalysisException] {
- df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path)
- }.getMessage.contains("already exists")
- }
-
- // This shouldn't throw anything.
- df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
- checkAnswer(read.format("parquet").load(path), df)
- }
- }
-
- test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") {
- withTempPath { dir =>
- intercept[AnalysisException] {
- // Parquet doesn't allow field names with spaces. Here we are intentionally making an
- // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger
- // the bug. Please refer to spark-8079 for more details.
- range(1, 10)
- .withColumnRenamed("id", "a b")
- .write
- .format("parquet")
- .save(dir.getCanonicalPath)
- }
- }
- }
-
- test("SPARK-8604: Parquet data source should write summary file while doing appending") {
- withTempPath { dir =>
- val path = dir.getCanonicalPath
- val df = sqlContext.range(0, 5)
- df.write.mode(SaveMode.Overwrite).parquet(path)
-
- val summaryPath = new Path(path, "_metadata")
- val commonSummaryPath = new Path(path, "_common_metadata")
-
- val fs = summaryPath.getFileSystem(configuration)
- fs.delete(summaryPath, true)
- fs.delete(commonSummaryPath, true)
-
- df.write.mode(SaveMode.Append).parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df.unionAll(df))
-
- assert(fs.exists(summaryPath))
- assert(fs.exists(commonSummaryPath))
- }
- }
-}