From b1f4b4abfd8d038c3684685b245b5fd31b927da0 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 25 Jul 2015 18:41:51 -0700
Subject: [SPARK-9348][SQL] Remove apply method on InternalRow.

Author: Reynold Xin

Closes #7665 from rxin/remove-row-apply and squashes the following commits:

0b43001 [Reynold Xin] support getString in UnsafeRow.
176d633 [Reynold Xin] apply -> get.
2941324 [Reynold Xin] [SPARK-9348][SQL] Remove apply method on InternalRow.
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  | 88 +++++++++++-----------
 .../apache/spark/sql/catalyst/InternalRow.scala    | 32 ++++----
 .../expressions/codegen/CodeGenerator.scala        |  2 +-
 .../catalyst/expressions/MathFunctionsSuite.scala  |  2 +-
 .../apache/spark/sql/columnar/ColumnStats.scala    |  4 +-
 .../org/apache/spark/sql/columnar/ColumnType.scala | 16 ++--
 .../columnar/compression/compressionSchemes.scala  |  2 +-
 .../spark/sql/execution/SparkSqlSerializer2.scala  |  4 +-
 .../execution/datasources/DataSourceStrategy.scala |  6 +-
 .../apache/spark/sql/execution/debug/package.scala |  2 +-
 .../apache/spark/sql/execution/pythonUDFs.scala    |  2 +-
 .../spark/sql/expressions/aggregate/udaf.scala     |  4 +-
 .../spark/sql/parquet/ParquetTableOperations.scala |  6 +-
 .../spark/sql/parquet/ParquetTableSupport.scala    | 22 +++---
 .../test/scala/org/apache/spark/sql/RowSuite.scala |  4 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala      | 12 +--
 .../sql/columnar/NullableColumnAccessorSuite.scala |  2 +-
 .../sql/columnar/NullableColumnBuilderSuite.scala  |  2 +-
 .../columnar/compression/BooleanBitSetSuite.scala  |  2 +-
 .../org/apache/spark/sql/hive/HiveInspectors.scala |  6 +-
 .../sql/hive/execution/InsertIntoHiveTable.scala   |  2 +-
 .../apache/spark/sql/hive/orc/OrcRelation.scala    |  2 +-
 22 files changed, 113 insertions(+), 111 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 225f6e6553..9be9089493 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -231,84 +231,89 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public Object get(int i) {
+  public Object get(int ordinal) {
     throw new UnsupportedOperationException();
   }
 
   @Override
-  public <T> T getAs(int i) {
+  public <T> T getAs(int ordinal) {
     throw new UnsupportedOperationException();
   }
 
   @Override
-  public boolean isNullAt(int i) {
-    assertIndexIsValid(i);
-    return BitSetMethods.isSet(baseObject, baseOffset, i);
+  public boolean isNullAt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return BitSetMethods.isSet(baseObject, baseOffset, ordinal);
   }
 
   @Override
-  public boolean getBoolean(int i) {
-    assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i));
+  public boolean getBoolean(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal));
   }
 
   @Override
-  public byte getByte(int i) {
-    assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i));
+  public byte getByte(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal));
   }
 
   @Override
-  public short getShort(int i) {
-    assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i));
+  public short getShort(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal));
   }
 
   @Override
-  public int getInt(int i) {
-    assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i));
+  public int getInt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal));
   }
 
   @Override
-  public long getLong(int i) {
-    assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+  public long getLong(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal));
   }
 
   @Override
-  public float getFloat(int i) {
-    assertIndexIsValid(i);
-    if (isNullAt(i)) {
+  public float getFloat(int ordinal) {
+    assertIndexIsValid(ordinal);
+    if (isNullAt(ordinal)) {
       return Float.NaN;
     } else {
-      return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
+      return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal));
     }
   }
 
   @Override
-  public double getDouble(int i) {
-    assertIndexIsValid(i);
-    if (isNullAt(i)) {
+  public double getDouble(int ordinal) {
+    assertIndexIsValid(ordinal);
+    if (isNullAt(ordinal)) {
       return Float.NaN;
     } else {
-      return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
+      return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
     }
   }
 
   @Override
-  public UTF8String getUTF8String(int i) {
-    assertIndexIsValid(i);
-    return isNullAt(i) ? null : UTF8String.fromBytes(getBinary(i));
+  public UTF8String getUTF8String(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));
   }
 
   @Override
-  public byte[] getBinary(int i) {
-    if (isNullAt(i)) {
+  public String getString(int ordinal) {
+    return getUTF8String(ordinal).toString();
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    if (isNullAt(ordinal)) {
       return null;
     } else {
-      assertIndexIsValid(i);
-      final long offsetAndSize = getLong(i);
+      assertIndexIsValid(ordinal);
+      final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final byte[] bytes = new byte[size];
@@ -324,17 +329,12 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public String getString(int i) {
-    return getUTF8String(i).toString();
-  }
-
-  @Override
-  public UnsafeRow getStruct(int i, int numFields) {
-    if (isNullAt(i)) {
+  public UnsafeRow getStruct(int ordinal, int numFields) {
+    if (isNullAt(ordinal)) {
       return null;
     } else {
-      assertIndexIsValid(i);
-      final long offsetAndSize = getLong(i);
+      assertIndexIsValid(ordinal);
+      final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final UnsafeRow row = new UnsafeRow();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index f248b1f338..37f0f57e9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -29,35 +30,34 @@ abstract class InternalRow extends Serializable {
 
   def numFields: Int
 
-  def get(i: Int): Any
+  def get(ordinal: Int): Any
 
-  // TODO: Remove this.
-  def apply(i: Int): Any = get(i)
+  def getAs[T](ordinal: Int): T = get(ordinal).asInstanceOf[T]
 
-  def getAs[T](i: Int): T = get(i).asInstanceOf[T]
+  def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
 
-  def isNullAt(i: Int): Boolean = get(i) == null
+  def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal)
 
-  def getBoolean(i: Int): Boolean = getAs[Boolean](i)
+  def getByte(ordinal: Int): Byte = getAs[Byte](ordinal)
 
-  def getByte(i: Int): Byte = getAs[Byte](i)
+  def getShort(ordinal: Int): Short = getAs[Short](ordinal)
 
-  def getShort(i: Int): Short = getAs[Short](i)
+  def getInt(ordinal: Int): Int = getAs[Int](ordinal)
 
-  def getInt(i: Int): Int = getAs[Int](i)
+  def getLong(ordinal: Int): Long = getAs[Long](ordinal)
 
-  def getLong(i: Int): Long = getAs[Long](i)
+  def getFloat(ordinal: Int): Float = getAs[Float](ordinal)
 
-  def getFloat(i: Int): Float = getAs[Float](i)
+  def getDouble(ordinal: Int): Double = getAs[Double](ordinal)
 
-  def getDouble(i: Int): Double = getAs[Double](i)
+  def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal)
 
-  def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i)
+  def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal)
 
-  def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i)
+  def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal)
 
-  // This is only use for test
-  def getString(i: Int): String = getAs[UTF8String](i).toString
+  // This is only used for tests and will throw a null pointer exception if the position is null.
+  def getString(ordinal: Int): String = getAs[UTF8String](ordinal).toString
 
   /**
    * Returns a struct from ordinal position.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 508882acbe..2a1e288cb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -110,7 +110,7 @@ class CodeGenContext {
       case StringType => s"$row.getUTF8String($ordinal)"
       case BinaryType => s"$row.getBinary($ordinal)"
       case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
-      case _ => s"($jt)$row.apply($ordinal)"
+      case _ => s"($jt)$row.get($ordinal)"
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index a2b0fad7b7..6caf8baf24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -158,7 +158,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(),
       expression)
 
-    val actual = plan(inputRow).apply(0)
+    val actual = plan(inputRow).get(0)
     if (!actual.asInstanceOf[Double].isNaN) {
       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 00374d1fa3..7c63179af6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -211,7 +211,7 @@ private[sql] class StringColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      val value = row(ordinal).asInstanceOf[UTF8String]
+      val value = row.getUTF8String(ordinal)
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
       sizeInBytes += STRING.actualSize(row, ordinal)
@@ -241,7 +241,7 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      val value = row(ordinal).asInstanceOf[Decimal]
+      val value = row.getDecimal(ordinal)
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
       sizeInBytes += FIXED_DECIMAL.defaultSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index ac42bde07c..c0ca52751b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -90,7 +90,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
    * boxing/unboxing costs whenever possible.
    */
   def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
-    to(toOrdinal) = from(fromOrdinal)
+    to(toOrdinal) = from.get(fromOrdinal)
   }
 
   /**
@@ -329,11 +329,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): UTF8String = {
-    row(ordinal).asInstanceOf[UTF8String]
+    row.getUTF8String(ordinal)
   }
 
   override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
-    to.update(toOrdinal, from(fromOrdinal))
+    to.update(toOrdinal, from.getUTF8String(fromOrdinal))
   }
 }
 
@@ -347,7 +347,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Int = {
-    row(ordinal).asInstanceOf[Int]
+    row.getInt(ordinal)
   }
 
   def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
@@ -365,7 +365,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Long = {
-    row(ordinal).asInstanceOf[Long]
+    row.getLong(ordinal)
   }
 
   override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
@@ -388,7 +388,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
   }
 
   override def getField(row: InternalRow, ordinal: Int): Decimal = {
-    row(ordinal).asInstanceOf[Decimal]
+    row.getDecimal(ordinal)
   }
 
   override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
@@ -427,7 +427,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16)
   }
 
   override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
-    row(ordinal).asInstanceOf[Array[Byte]]
+    row.getBinary(ordinal)
   }
 }
 
@@ -440,7 +440,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
-    SparkSqlSerializer.serialize(row(ordinal))
+    SparkSqlSerializer.serialize(row.get(ordinal))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 5abc1259a1..6150df6930 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
       while (from.hasRemaining) {
         columnType.extract(from, value, 0)
 
-        if (value(0) == currentValue(0)) {
+        if (value.get(0) == currentValue.get(0)) {
           currentRun += 1
         } else {
           // Writes current run
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 83c4e8733f..6ee833c7b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -278,7 +278,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeByte(NULL)
             } else {
               out.writeByte(NOT_NULL)
-              val bytes = row.getAs[UTF8String](i).getBytes
+              val bytes = row.getUTF8String(i).getBytes
               out.writeInt(bytes.length)
               out.write(bytes)
             }
@@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeByte(NULL)
             } else {
               out.writeByte(NOT_NULL)
-              val value = row.apply(i).asInstanceOf[Decimal]
+              val value = row.getAs[Decimal](i)
               val javaBigDecimal = value.toJavaBigDecimal
               // First, write out the unscaled value.
               val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 7f452daef3..cdbe42381a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -170,6 +170,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
     execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows)
   }
 
+  // TODO: refactor this thing. It is very complicated because it does projection internally.
+  // We should just put a project on top of this.
   private def mergeWithPartitionValues(
       schema: StructType,
       requiredColumns: Array[String],
@@ -187,13 +189,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         if (i != -1) {
           // If yes, gets column value from partition values.
           (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = partitionValues(i)
+            mutableRow(ordinal) = partitionValues.get(i)
           }
         } else {
           // Otherwise, inherits the value from scanned data.
           val i = nonPartitionColumns.indexOf(name)
           (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = dataRow(i)
+            mutableRow(ordinal) = dataRow.get(i)
           }
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e6081cb05b..1fdcc6a850 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
         tupleCount += 1
         var i = 0
         while (i < numColumns) {
-          val value = currentRow(i)
+          val value = currentRow.get(i)
           if (value != null) {
             columnStats(i).elementTypes += HashSet(value.getClass.getName)
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 40bf03a3f1..970c40dc61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -129,7 +129,7 @@ object EvaluatePython {
       val values = new Array[Any](row.numFields)
       var i = 0
       while (i < row.numFields) {
-        values(i) = toJava(row(i), struct.fields(i).dataType)
+        values(i) = toJava(row.get(i), struct.fields(i).dataType)
         i += 1
       }
       new GenericInternalRowWithSchema(values, struct)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
index 46f0fac861..7a6e86779b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -121,7 +121,7 @@ class MutableAggregationBuffer private[sql] (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    toScalaConverters(i)(underlyingBuffer(offsets(i)))
+    toScalaConverters(i)(underlyingBuffer.get(offsets(i)))
   }
 
   def update(i: Int, value: Any): Unit = {
@@ -157,7 +157,7 @@ class InputAggregationBuffer private[sql] (
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
     // TODO: Use buffer schema to avoid using generic getter.
-    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i)))
   }
 
   override def copy(): InputAggregationBuffer = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 8cab27d6e1..38bb1e3967 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -159,7 +159,7 @@ private[sql] case class ParquetTableScan(
 
           // Parquet will leave partitioning columns empty, so we fill them in here.
           var i = 0
-          while (i < requestedPartitionOrdinals.size) {
+          while (i < requestedPartitionOrdinals.length) {
            row(requestedPartitionOrdinals(i)._2) =
              partitionRowValues(requestedPartitionOrdinals(i)._1)
            i += 1
@@ -179,12 +179,12 @@ private[sql] case class ParquetTableScan(
 
             var i = 0
             while (i < row.numFields) {
-              mutableRow(i) = row(i)
+              mutableRow(i) = row.get(i)
               i += 1
             }
             // Parquet will leave partitioning columns empty, so we fill them in here.
             i = 0
-            while (i < requestedPartitionOrdinals.size) {
+            while (i < requestedPartitionOrdinals.length) {
               mutableRow(requestedPartitionOrdinals(i)._2) =
                 partitionRowValues(requestedPartitionOrdinals(i)._1)
               i += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index c7c58e69d4..2c23d4e8a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -217,9 +217,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
     writer.startMessage()
     while(index < attributesSize) {
       // null values indicate optional fields but we do not check currently
-      if (record(index) != null) {
+      if (!record.isNullAt(index)) {
         writer.startField(attributes(index).name, index)
-        writeValue(attributes(index).dataType, record(index))
+        writeValue(attributes(index).dataType, record.get(index))
         writer.endField(attributes(index).name, index)
       }
       index = index + 1
@@ -277,10 +277,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
     val fields = schema.fields.toArray
     writer.startGroup()
     var i = 0
-    while(i < fields.size) {
-      if (struct(i) != null) {
+    while(i < fields.length) {
+      if (!struct.isNullAt(i)) {
         writer.startField(fields(i).name, i)
-        writeValue(fields(i).dataType, struct(i))
+        writeValue(fields(i).dataType, struct.get(i))
         writer.endField(fields(i).name, i)
       }
       i = i + 1
@@ -387,7 +387,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
     writer.startMessage()
     while(index < attributesSize) {
       // null values indicate optional fields but we do not check currently
-      if (record(index) != null && record(index) != Nil) {
+      if (!record.isNullAt(index)) {
        writer.startField(attributes(index).name, index)
        consumeType(attributes(index).dataType, record, index)
        writer.endField(attributes(index).name, index)
@@ -410,15 +410,15 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
       case TimestampType => writeTimestamp(record.getLong(index))
       case FloatType => writer.addFloat(record.getFloat(index))
       case DoubleType => writer.addDouble(record.getDouble(index))
-      case StringType => writer.addBinary(
-        Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
-      case BinaryType => writer.addBinary(
-        Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
+      case StringType =>
+        writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
+      case BinaryType =>
+        writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
       case d: DecimalType =>
         if (d.precision > 18) {
          sys.error(s"Unsupported datatype $d, cannot write to consumer")
         }
-        writeDecimal(record(index).asInstanceOf[Decimal], d.precision)
+        writeDecimal(record.getDecimal(index), d.precision)
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 0e5c5abff8..c6804e8482 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -39,14 +39,14 @@ class RowSuite extends SparkFunSuite {
     assert(expected.getInt(0) === actual1.getInt(0))
     assert(expected.getString(1) === actual1.getString(1))
     assert(expected.getBoolean(2) === actual1.getBoolean(2))
-    assert(expected(3) === actual1(3))
+    assert(expected.get(3) === actual1.get(3))
 
     val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
     assert(expected.numFields === actual2.size)
     assert(expected.getInt(0) === actual2.getInt(0))
     assert(expected.getString(1) === actual2.getString(1))
     assert(expected.getBoolean(2) === actual2.getBoolean(2))
-    assert(expected(3) === actual2(3))
+    assert(expected.get(3) === actual2.get(3))
   }
 
   test("SpecificMutableRow.update with null") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 3333fee671..31e7b0e72e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -58,15 +58,15 @@ class ColumnStatsSuite extends SparkFunSuite {
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(columnStats.gatherStats(_, 0))
 
-      val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType])
+      val values = rows.take(10).map(_.get(0).asInstanceOf[T#InternalType])
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
-      assertResult(10, "Wrong null count")(stats(2))
-      assertResult(20, "Wrong row count")(stats(3))
-      assertResult(stats(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1))
+      assertResult(10, "Wrong null count")(stats.get(2))
+      assertResult(20, "Wrong row count")(stats.get(3))
+      assertResult(stats.get(4), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 9eaa769846..d421f4d8d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
       (0 until 4).foreach { _ =>
         assert(accessor.hasNext)
         accessor.extractTo(row, 0)
-        assert(row(0) === randomRow(0))
+        assert(row.get(0) === randomRow.get(0))
 
         assert(accessor.hasNext)
         accessor.extractTo(row, 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 17e9ae464b..cd8bf75ff1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -98,7 +98,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
           columnType.extract(buffer)
         }
 
-        assert(actual === randomRow(0), "Extracted value didn't equal to the original one")
+        assert(actual === randomRow.get(0), "Extracted value didn't equal to the original one")
       }
 
       assert(!buffer.hasRemaining)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index f606e2133b..33092c83a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite {
     val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
     val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN))
-    val values = rows.map(_(0))
+    val values = rows.map(_.get(0))
 
     rows.foreach(builder.appendFrom(_, 0))
     val buffer = builder.build()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 592cfa0ee8..16977ce30c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -497,7 +497,7 @@ private[hive] trait HiveInspectors {
         x.setStructFieldData(
           result,
           fieldRefs.get(i),
-          wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
+          wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
         i += 1
       }
 
@@ -508,7 +508,7 @@ private[hive] trait HiveInspectors {
       val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
       var i = 0
       while (i < fieldRefs.length) {
-        result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
+        result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
         i += 1
       }
 
@@ -536,7 +536,7 @@ private[hive] trait HiveInspectors {
       cache: Array[AnyRef]): Array[AnyRef] = {
     var i = 0
     while (i < inspectors.length) {
-      cache(i) = wrap(row(i), inspectors(i))
+      cache(i) = wrap(row.get(i), inspectors(i))
       i += 1
     }
     cache
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 34b629403e..f0e0ca05a8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -102,7 +102,7 @@ case class InsertIntoHiveTable(
     iterator.foreach { row =>
       var i = 0
       while (i < fieldOIs.length) {
-        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i))
+        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i))
         i += 1
       }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 10623dc820..58445095ad 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -122,7 +122,7 @@ private[orc] class OrcOutputWriter(
   override def writeInternal(row: InternalRow): Unit = {
     var i = 0
     while (i < row.numFields) {
-      reusableOutputBuffer(i) = wrappers(i)(row(i))
+      reusableOutputBuffer(i) = wrappers(i)(row.get(i))
       i += 1
     }
-- 
cgit v1.2.3
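
Editor's note: for readers who want the upshot without tracing the diff, below is a minimal
before/after sketch of the API change. It is illustrative only: the two-column row layout is
hypothetical, and GenericInternalRow (from org.apache.spark.sql.catalyst.expressions, the same
package the patch touches; the diff itself references GenericInternalRowWithSchema) is used
just to get a concrete InternalRow in hand.

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
  import org.apache.spark.unsafe.types.UTF8String

  // A hypothetical two-column row: (Int, UTF8String). Internal rows carry
  // strings as UTF8String, not java.lang.String.
  val row: InternalRow =
    new GenericInternalRow(Array[Any](42, UTF8String.fromString("spark")))

  // Before this patch, call sites could fall back on the generic apply():
  //   val id = row(0).asInstanceOf[Int]   // i.e. row.apply(0)

  // After this patch, apply() is gone; call sites must say what they mean:
  val id: Int = row.getInt(0)                   // typed accessor
  val name: UTF8String = row.getUTF8String(1)   // typed accessor
  val anything: Any = row.get(0)                // explicitly generic getter

Removing apply makes every generic access visible as get, so typed accessors become the
default path; that is what lets specialized implementations such as UnsafeRow serve primitive
reads without boxing, and why the code generator above now emits row.get($ordinal) only as
the true fallback case.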