author     Davies Liu <davies@databricks.com>        2016-03-01 13:07:04 -0800
committer  Davies Liu <davies.liu@gmail.com>         2016-03-01 13:07:04 -0800
commit     c27ba0d547a0cd3fd00bb42c76ad971b2d48b4a0 (patch)
tree       f529168194ef53ded5cda96b0353f62fcd9bcad7 /sql
parent     c37bbb3a1cbd93c749aaaeca1345817e0c20094f (diff)
[SPARK-13582] [SQL] defer dictionary decoding in parquet reader
## What changes were proposed in this pull request?

This PR defers resolving a dictionary id to its value until the column is actually accessed (inside getInt/getLong), which is very useful for columns and rows that are filtered out. It also helps the binary type, since we no longer need to copy all the byte arrays. The PR additionally changes the underlying type for small decimals that fit within an Int, so that getInt() can be used to look up the value in an IntDictionary.

## How was this patch tested?

Manually tested TPCDS Q7 at scale factor 10; saw about 30% improvement (after PR #11274).

Author: Davies Liu <davies@databricks.com>

Closes #11437 from davies/decode_dict.
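For readers skimming the diff below, here is a minimal, self-contained sketch of the idea (a hypothetical toy class, not Spark's ColumnVector/Dictionary API): the column keeps only the integer dictionary ids plus a reference to the dictionary, and decodes a value only when a getter is actually called, so rows that are filtered out never pay the decode cost.

    import java.util.Arrays;

    // Toy illustration of deferred dictionary decoding; names are hypothetical.
    final class LazyDictColumn {
      private final int[] ids;          // dictionary ids, one per row
      private final long[] dictionary;  // decoded values, one per distinct id

      LazyDictColumn(int[] ids, long[] dictionary) {
        this.ids = ids;
        this.dictionary = dictionary;
      }

      // Decoding happens only when the value is actually read.
      long getLong(int rowId) {
        return dictionary[ids[rowId]];
      }

      public static void main(String[] args) {
        // Three distinct values, five rows; rows that are never read are never decoded.
        LazyDictColumn col =
            new LazyDictColumn(new int[] {0, 2, 1, 2, 0}, new long[] {10L, 20L, 30L});
        System.out.println(col.getLong(1)); // 30
        System.out.println(Arrays.toString(new long[] {col.getLong(0), col.getLong(4)})); // [10, 10]
      }
    }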
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 11
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java | 101
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java | 33
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java | 105
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java | 8
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java | 24
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java | 58
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java | 47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala | 6
15 files changed, 221 insertions(+), 203 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 38ce1604b1..6a59e9728a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -340,6 +340,9 @@ object Decimal {
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
+ /** Maximum number of decimal digits an Int can represent */
+ val MAX_INT_DIGITS = 9
+
/** Maximum number of decimal digits a Long can represent */
val MAX_LONG_DIGITS = 18
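As background for the new constant: Int.MaxValue (2,147,483,647) has 10 digits, so any unscaled value with at most 9 decimal digits always fits in an int, and Long.MaxValue (9,223,372,036,854,775,807) has 19 digits, so 18 digits always fit in a long. A quick sanity check in plain Java (not part of the patch):

    public class DigitLimits {
      public static void main(String[] args) {
        // 10^9 - 1 is the largest 9-digit unscaled value; always representable as an int.
        System.out.println(999_999_999 <= Integer.MAX_VALUE);            // true
        // 10^18 - 1 is the largest 18-digit unscaled value; always representable as a long.
        System.out.println(999_999_999_999_999_999L <= Long.MAX_VALUE);  // true
        // One more digit no longer fits: 10^10 - 1 overflows int.
        System.out.println(9_999_999_999L > Integer.MAX_VALUE);          // true
      }
    }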
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 2e03ddae76..9c1319c1c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -151,6 +151,17 @@ object DecimalType extends AbstractDataType {
}
/**
+ * Returns if dt is a DecimalType that fits inside an int
+ */
+ def is32BitDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision <= Decimal.MAX_INT_DIGITS
+ case _ => false
+ }
+ }
+
+ /**
* Returns if dt is a DecimalType that fits inside a long
*/
def is64BitDecimalType(dt: DataType): Boolean = {
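A sketch of how a reader can use these predicates to pick the physical storage for a decimal column; the helper below is hypothetical and only mirrors the precision thresholds, it is not part of the patch:

    // Hypothetical dispatch mirroring is32BitDecimalType / is64BitDecimalType:
    // small-precision decimals are stored as unscaled ints or longs, everything
    // else falls back to binary.
    enum PhysicalStorage { INT32, INT64, BINARY }

    final class DecimalStorage {
      static PhysicalStorage forPrecision(int precision) {
        if (precision <= 9) return PhysicalStorage.INT32;   // Decimal.MAX_INT_DIGITS
        if (precision <= 18) return PhysicalStorage.INT64;  // Decimal.MAX_LONG_DIGITS
        return PhysicalStorage.BINARY;
      }

      public static void main(String[] args) {
        System.out.println(forPrecision(7));   // INT32
        System.out.println(forPrecision(12));  // INT64
        System.out.println(forPrecision(25));  // BINARY
      }
    }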
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index e7f0ec2e77..57dbd7c2ff 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -257,8 +257,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
throw new IOException("Unsupported type: " + t);
}
if (originalTypes[i] == OriginalType.DECIMAL &&
- primitiveType.getDecimalMetadata().getPrecision() >
- CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
+ primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
throw new IOException("Decimal with high precision is not supported.");
}
if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
@@ -439,7 +438,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
int precision = type.getDecimalMetadata().getPrecision();
int scale = type.getDecimalMetadata().getScale();
- Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
+ Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
"Unsupported precision.");
for (int n = 0; n < num; ++n) {
@@ -481,11 +480,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private boolean useDictionary;
/**
- * If useDictionary is true, the staging vector used to decode the ids.
- */
- private ColumnVector dictionaryIds;
-
- /**
* Maximum definition level for this column.
*/
private final int maxDefLevel;
@@ -620,18 +614,13 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
int num = Math.min(total, leftInPage);
if (useDictionary) {
- // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
- if (dictionaryIds == null) {
- dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
- } else {
- dictionaryIds.reset();
- dictionaryIds.reserve(total);
- }
// Read and decode dictionary ids.
+ ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
defColumn.readIntegers(
num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
- decodeDictionaryIds(rowId, num, column);
+ decodeDictionaryIds(rowId, num, column, dictionaryIds);
} else {
+ column.setDictionary(null);
switch (descriptor.getType()) {
case BOOLEAN:
readBooleanBatch(rowId, num, column);
@@ -668,55 +657,25 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
- private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+ private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
+ ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
- if (column.dataType() == DataTypes.IntegerType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (column.dataType() == DataTypes.ByteType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (column.dataType() == DataTypes.ShortType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else {
- throw new NotImplementedException("Unimplemented type: " + column.dataType());
- }
- break;
-
case INT64:
- if (column.dataType() == DataTypes.LongType ||
- DecimalType.is64BitDecimalType(column.dataType())) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
- }
- } else {
- throw new NotImplementedException("Unimplemented type: " + column.dataType());
- }
- break;
-
case FLOAT:
- for (int i = rowId; i < rowId + num; ++i) {
- column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
- }
- break;
-
case DOUBLE:
- for (int i = rowId; i < rowId + num; ++i) {
- column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
- }
+ case BINARY:
+ column.setDictionary(dictionary);
break;
case FIXED_LEN_BYTE_ARRAY:
- if (DecimalType.is64BitDecimalType(column.dataType())) {
+ // DecimalType written in the legacy mode
+ if (DecimalType.is32BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
@@ -726,17 +685,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
break;
- case BINARY:
- // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
- // need to do this better. We should probably add the dictionary data to the ColumnVector
- // and reuse it across batches. This should mean adding a ByteArray would just update
- // the length and offset.
- for (int i = rowId; i < rowId + num; ++i) {
- Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
- column.putByteArray(i, v.getBytes());
- }
- break;
-
default:
throw new NotImplementedException("Unsupported type: " + descriptor.getType());
}
@@ -756,15 +704,13 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
+ if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
+ DecimalType.is32BitDecimalType(column.dataType())) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
- defColumn.readIntsAsLongs(
- num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ShortType) {
defColumn.readShorts(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -822,7 +768,16 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (DecimalType.is64BitDecimalType(column.dataType())) {
+ if (DecimalType.is32BitDecimalType(column.dataType())) {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ column.putInt(rowId + i,
+ (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+ } else {
+ column.putNull(rowId + i);
+ }
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i,
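The FIXED_LEN_BYTE_ARRAY branches above rely on CatalystRowConverter.binaryToUnscaledLong to turn the legacy big-endian two's-complement bytes into an unscaled long (cast down to int when the precision fits 32 bits). Below is a standalone, illustrative re-implementation of that conversion, assuming at most 8 significant bytes; it is a sketch, not the class used by the patch:

    import java.math.BigInteger;

    public class UnscaledLongSketch {
      // Interpret a big-endian two's-complement byte array as an unscaled long.
      static long binaryToUnscaledLong(byte[] bytes) {
        long unscaled = 0L;
        for (byte b : bytes) {
          unscaled = (unscaled << 8) | (b & 0xff);
        }
        // Sign-extend if fewer than 8 bytes were used.
        int bits = 8 * bytes.length;
        if (bits < 64) {
          unscaled = (unscaled << (64 - bits)) >> (64 - bits);
        }
        return unscaled;
      }

      public static void main(String[] args) {
        byte[] negOne = new byte[] {(byte) 0xff, (byte) 0xff}; // -1 encoded in 2 bytes
        System.out.println(binaryToUnscaledLong(negOne)); // -1
        System.out.println(
            binaryToUnscaledLong(BigInteger.valueOf(123456789L).toByteArray())); // 123456789
      }
    }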
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 8613fcae0b..6215738901 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -25,7 +25,6 @@ import org.apache.parquet.column.values.bitpacking.Packer;
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;
-import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
/**
@@ -239,38 +238,6 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
}
- public void readIntsAsLongs(int total, ColumnVector c,
- int rowId, int level, VectorizedValuesReader data) {
- int left = total;
- while (left > 0) {
- if (this.currentCount == 0) this.readNextGroup();
- int n = Math.min(left, this.currentCount);
- switch (mode) {
- case RLE:
- if (currentValue == level) {
- for (int i = 0; i < n; i++) {
- c.putLong(rowId + i, data.readInteger());
- }
- } else {
- c.putNulls(rowId, n);
- }
- break;
- case PACKED:
- for (int i = 0; i < n; ++i) {
- if (currentBuffer[currentBufferIdx++] == level) {
- c.putLong(rowId + i, data.readInteger());
- } else {
- c.putNull(rowId + i);
- }
- }
- break;
- }
- rowId += n;
- left -= n;
- currentCount -= n;
- }
- }
-
public void readBytes(int total, ColumnVector c,
int rowId, int level, VectorizedValuesReader data) {
int left = total;
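The removed readIntsAsLongs followed the same definition-level pattern as the remaining readers: in RLE mode one level applies to the whole run (all values present, or all null), while in PACKED mode each row's level is checked individually. The sketch below flattens the two modes into a single per-row check just to illustrate the null handling; the types are simplified and hypothetical:

    import java.util.Arrays;

    // Schematic version of the definition-level loop used by the vectorized readers.
    public class DefLevelSketch {
      // defLevels[i] == maxDefLevel means row i has a value; anything lower means null.
      static Long[] read(int[] defLevels, long[] values, int maxDefLevel) {
        Long[] out = new Long[defLevels.length];
        int v = 0; // index into the dense, non-null value stream
        for (int i = 0; i < defLevels.length; i++) {
          out[i] = (defLevels[i] == maxDefLevel) ? values[v++] : null;
        }
        return out;
      }

      public static void main(String[] args) {
        // Two present values and one null row.
        System.out.println(Arrays.toString(read(new int[] {1, 0, 1}, new long[] {7L, 9L}, 1)));
        // -> [7, null, 9]
      }
    }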
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 0514252a8e..bb0247c2fb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -19,6 +19,10 @@ package org.apache.spark.sql.execution.vectorized;
import java.math.BigDecimal;
import java.math.BigInteger;
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.parquet.column.Dictionary;
+import org.apache.parquet.io.api.Binary;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -27,8 +31,6 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* This class represents a column of values and provides the main APIs to access the data
* values. It supports all the types and contains get/put APIs as well as their batched versions.
@@ -157,7 +159,7 @@ public abstract class ColumnVector {
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
- list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
+ list[i] = getUTF8String(i).toString();
}
}
} else if (dt instanceof CalendarIntervalType) {
@@ -204,28 +206,17 @@ public abstract class ColumnVector {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
- return Decimal.apply(getLong(ordinal), precision, scale);
- } else {
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(javaDecimal, precision, scale);
- }
+ return data.getDecimal(offset + ordinal, precision, scale);
}
@Override
public UTF8String getUTF8String(int ordinal) {
- Array child = data.getByteArray(offset + ordinal);
- return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length);
+ return data.getUTF8String(offset + ordinal);
}
@Override
public byte[] getBinary(int ordinal) {
- ColumnVector.Array array = data.getByteArray(offset + ordinal);
- byte[] bytes = new byte[array.length];
- System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
- return bytes;
+ return data.getBinary(offset + ordinal);
}
@Override
@@ -534,13 +525,58 @@ public abstract class ColumnVector {
/**
* Returns the value for rowId.
*/
- public final Array getByteArray(int rowId) {
+ private Array getByteArray(int rowId) {
Array array = getArray(rowId);
array.data.loadBytes(array);
return array;
}
/**
+ * Returns the decimal for rowId.
+ */
+ public final Decimal getDecimal(int rowId, int precision, int scale) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ return Decimal.apply(getInt(rowId), precision, scale);
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(rowId), precision, scale);
+ } else {
+ // TODO: best perf?
+ byte[] bytes = getBinary(rowId);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
+ }
+
+ /**
+ * Returns the UTF8String for rowId.
+ */
+ public final UTF8String getUTF8String(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array a = getByteArray(rowId);
+ return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
+ } else {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
+ return UTF8String.fromBytes(v.getBytes());
+ }
+ }
+
+ /**
+ * Returns the byte array for rowId.
+ */
+ public final byte[] getBinary(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array array = getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
+ } else {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
+ return v.getBytes();
+ }
+ }
+
+ /**
* Append APIs. These APIs all behave similarly and will append data to the current vector. It
* is not valid to mix the put and append APIs. The append APIs are slower and should only be
* used if the sizes are not known up front.
@@ -817,6 +853,39 @@ public abstract class ColumnVector {
protected final ColumnarBatch.Row resultStruct;
/**
+ * The Dictionary for this column.
+ *
+ * If it is not null, it will be used to decode the value in getXXX().
+ */
+ protected Dictionary dictionary;
+
+ /**
+ * Reusable column for dictionary ids.
+ */
+ protected ColumnVector dictionaryIds;
+
+ /**
+ * Update the dictionary.
+ */
+ public void setDictionary(Dictionary dictionary) {
+ this.dictionary = dictionary;
+ }
+
+ /**
+ * Reserve an integer column for dictionary ids.
+ */
+ public ColumnVector reserveDictionaryIds(int capacity) {
+ if (dictionaryIds == null) {
+ dictionaryIds = allocate(capacity, DataTypes.IntegerType,
+ this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP);
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(capacity);
+ }
+ return dictionaryIds;
+ }
+
+ /**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
*/
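The new ColumnVector.getDecimal centralizes the precision-based dispatch: unscaled ints for precision <= 9, unscaled longs for precision <= 18, and the byte-array path otherwise. A self-contained sketch of the same dispatch using java.math types (illustrative only; it does not use Spark's Decimal class):

    import java.math.BigDecimal;
    import java.math.BigInteger;

    public class DecimalDispatchSketch {
      // Mirrors the precision-based branches in ColumnVector.getDecimal,
      // but returns java.math.BigDecimal instead of Spark's Decimal.
      static BigDecimal fromUnscaledInt(int unscaled, int scale) {
        return BigDecimal.valueOf(unscaled, scale);           // precision <= 9 path
      }
      static BigDecimal fromUnscaledLong(long unscaled, int scale) {
        return BigDecimal.valueOf(unscaled, scale);           // precision <= 18 path
      }
      static BigDecimal fromBinary(byte[] unscaledBytes, int scale) {
        return new BigDecimal(new BigInteger(unscaledBytes), scale); // high-precision path
      }

      public static void main(String[] args) {
        System.out.println(fromUnscaledInt(12345, 2));          // 123.45
        System.out.println(fromUnscaledLong(123456789012L, 3)); // 123456789.012
        System.out.println(
            fromBinary(new BigInteger("12345678901234567890123").toByteArray(), 4));
        // 1234567890123456789.0123
      }
    }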
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 2aeef7f2f9..681ace3387 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -22,24 +22,20 @@ import java.sql.Date;
import java.util.Iterator;
import java.util.List;
+import org.apache.commons.lang.NotImplementedException;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* Utilities to help manipulate data associated with ColumnVectors. These should be used mostly
* for debugging or other non-performance critical paths.
* These utilities are mostly used to convert ColumnVectors into other formats.
*/
public class ColumnVectorUtils {
- public static String toString(ColumnVector.Array a) {
- return new String(a.byteArray, a.byteArrayOffset, a.length);
- }
-
/**
* Returns the array data as the java primitive array.
* For example, an array of IntegerType will return an int[].
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 070d897a71..8a0d7f8b12 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,11 +16,11 @@
*/
package org.apache.spark.sql.execution.vectorized;
-import java.math.BigDecimal;
-import java.math.BigInteger;
import java.util.Arrays;
import java.util.Iterator;
+import org.apache.commons.lang.NotImplementedException;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
@@ -31,8 +31,6 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* This class is the in memory representation of rows as they are streamed through operators. It
* is designed to maximize CPU efficiency and not storage footprint. Since it is expected that
@@ -193,29 +191,17 @@ public final class ColumnarBatch {
@Override
public final Decimal getDecimal(int ordinal, int precision, int scale) {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
- return Decimal.apply(getLong(ordinal), precision, scale);
- } else {
- // TODO: best perf?
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(javaDecimal, precision, scale);
- }
+ return columns[ordinal].getDecimal(rowId, precision, scale);
}
@Override
public final UTF8String getUTF8String(int ordinal) {
- ColumnVector.Array a = columns[ordinal].getByteArray(rowId);
- return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
+ return columns[ordinal].getUTF8String(rowId);
}
@Override
public final byte[] getBinary(int ordinal) {
- ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
- byte[] bytes = new byte[array.length];
- System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
- return bytes;
+ return columns[ordinal].getBinary(rowId);
}
@Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index e38ed05121..b06b7f2457 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -18,25 +18,11 @@ package org.apache.spark.sql.execution.vectorized;
import java.nio.ByteOrder;
-import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.execution.vectorized.ColumnVector.Array;
-import org.apache.spark.sql.types.BooleanType;
-import org.apache.spark.sql.types.ByteType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DateType;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.sql.types.DoubleType;
-import org.apache.spark.sql.types.FloatType;
-import org.apache.spark.sql.types.IntegerType;
-import org.apache.spark.sql.types.LongType;
-import org.apache.spark.sql.types.ShortType;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.types.UTF8String;
-
-
import org.apache.commons.lang.NotImplementedException;
-import org.apache.commons.lang.NotImplementedException;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
/**
* Column data backed using offheap memory.
@@ -171,7 +157,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final byte getByte(int rowId) {
- return Platform.getByte(null, data + rowId);
+ if (dictionary == null) {
+ return Platform.getByte(null, data + rowId);
+ } else {
+ return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -199,7 +189,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final short getShort(int rowId) {
- return Platform.getShort(null, data + 2 * rowId);
+ if (dictionary == null) {
+ return Platform.getShort(null, data + 2 * rowId);
+ } else {
+ return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -233,7 +227,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final int getInt(int rowId) {
- return Platform.getInt(null, data + 4 * rowId);
+ if (dictionary == null) {
+ return Platform.getInt(null, data + 4 * rowId);
+ } else {
+ return dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -267,7 +265,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final long getLong(int rowId) {
- return Platform.getLong(null, data + 8 * rowId);
+ if (dictionary == null) {
+ return Platform.getLong(null, data + 8 * rowId);
+ } else {
+ return dictionary.decodeToLong(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -301,7 +303,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final float getFloat(int rowId) {
- return Platform.getFloat(null, data + rowId * 4);
+ if (dictionary == null) {
+ return Platform.getFloat(null, data + rowId * 4);
+ } else {
+ return dictionary.decodeToFloat(dictionaryIds.getInt(rowId));
+ }
}
@@ -336,7 +342,11 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final double getDouble(int rowId) {
- return Platform.getDouble(null, data + rowId * 8);
+ if (dictionary == null) {
+ return Platform.getDouble(null, data + rowId * 8);
+ } else {
+ return dictionary.decodeToDouble(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -394,7 +404,7 @@ public final class OffHeapColumnVector extends ColumnVector {
} else if (type instanceof ShortType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
} else if (type instanceof IntegerType || type instanceof FloatType ||
- type instanceof DateType) {
+ type instanceof DateType || DecimalType.is32BitDecimalType(type)) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type)) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 3502d31bd1..305e84a86b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -16,13 +16,12 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.util.Arrays;
+
import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.execution.vectorized.ColumnVector.Array;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
-import java.util.Arrays;
-
/**
* A column backed by an in memory JVM array. This stores the NULLs as a byte per value
* and a java array for the values.
@@ -68,7 +67,6 @@ public final class OnHeapColumnVector extends ColumnVector {
doubleData = null;
}
-
//
// APIs dealing with nulls
//
@@ -154,7 +152,11 @@ public final class OnHeapColumnVector extends ColumnVector {
@Override
public final byte getByte(int rowId) {
- return byteData[rowId];
+ if (dictionary == null) {
+ return byteData[rowId];
+ } else {
+ return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -180,7 +182,11 @@ public final class OnHeapColumnVector extends ColumnVector {
@Override
public final short getShort(int rowId) {
- return shortData[rowId];
+ if (dictionary == null) {
+ return shortData[rowId];
+ } else {
+ return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
@@ -217,7 +223,11 @@ public final class OnHeapColumnVector extends ColumnVector {
@Override
public final int getInt(int rowId) {
- return intData[rowId];
+ if (dictionary == null) {
+ return intData[rowId];
+ } else {
+ return dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -253,7 +263,11 @@ public final class OnHeapColumnVector extends ColumnVector {
@Override
public final long getLong(int rowId) {
- return longData[rowId];
+ if (dictionary == null) {
+ return longData[rowId];
+ } else {
+ return dictionary.decodeToLong(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -280,7 +294,13 @@ public final class OnHeapColumnVector extends ColumnVector {
}
@Override
- public final float getFloat(int rowId) { return floatData[rowId]; }
+ public final float getFloat(int rowId) {
+ if (dictionary == null) {
+ return floatData[rowId];
+ } else {
+ return dictionary.decodeToFloat(dictionaryIds.getInt(rowId));
+ }
+ }
//
// APIs dealing with doubles
@@ -309,7 +329,11 @@ public final class OnHeapColumnVector extends ColumnVector {
@Override
public final double getDouble(int rowId) {
- return doubleData[rowId];
+ if (dictionary == null) {
+ return doubleData[rowId];
+ } else {
+ return dictionary.decodeToDouble(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -377,7 +401,8 @@ public final class OnHeapColumnVector extends ColumnVector {
short[] newData = new short[newCapacity];
if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
shortData = newData;
- } else if (type instanceof IntegerType || type instanceof DateType) {
+ } else if (type instanceof IntegerType || type instanceof DateType ||
+ DecimalType.is32BitDecimalType(type)) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 42d89f4bf8..8a128b4b61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter(
}
protected def decimalFromBinary(value: Binary): Decimal = {
- if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
val unscaled = CatalystRowConverter.binaryToUnscaledLong(value)
Decimal(unscaled, precision, scale)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index ab4250d0ad..6f6340f541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition._
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64}
+import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter(
case INT_16 => ShortType
case INT_32 | null => IntegerType
case DATE => DateType
- case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32)
+ case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS)
case UINT_8 => typeNotSupported()
case UINT_16 => typeNotSupported()
case UINT_32 => typeNotSupported()
@@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter(
case INT64 =>
originalType match {
case INT_64 | null => LongType
- case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64)
+ case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
case UINT_64 => typeNotSupported()
case TIMESTAMP_MILLIS => typeNotImplemented()
case _ => illegalType()
@@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT32 for 1 <= precision <= 9
case DecimalType.Fixed(precision, scale)
- if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat =>
+ if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat =>
Types
.primitive(INT32, repetition)
.as(DECIMAL)
@@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT64 for 1 <= precision <= 18
case DecimalType.Fixed(precision, scale)
- if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat =>
+ if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat =>
Types
.primitive(INT64, repetition)
.as(DECIMAL)
@@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter {
// Returns the minimum number of bytes needed to store a decimal with a given `precision`.
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
- val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */
-
- val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */
-
// Max precision of a decimal value stored in `numBytes` bytes
def maxPrecisionForBytes(numBytes: Int): Int = {
Math.round( // convert double to long
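For context on why these constants could be dropped: maxPrecisionForBytes computes floor(log10(2^(8*numBytes - 1) - 1)), the largest decimal precision a signed integer of numBytes bytes can always hold, which is 9 for 4 bytes and 18 for 8 bytes and therefore coincides with Decimal.MAX_INT_DIGITS and Decimal.MAX_LONG_DIGITS. A quick check of the arithmetic (plain Java, outside the patch):

    public class MaxPrecisionCheck {
      // floor(log10(2^(8 * numBytes - 1) - 1)): max decimal precision that a
      // signed integer of numBytes bytes can always hold.
      static int maxPrecisionForBytes(int numBytes) {
        return (int) Math.floor(Math.log10(Math.pow(2, 8 * numBytes - 1) - 1));
      }

      public static void main(String[] args) {
        System.out.println(maxPrecisionForBytes(4)); // 9  == Decimal.MAX_INT_DIGITS
        System.out.println(maxPrecisionForBytes(8)); // 18 == Decimal.MAX_LONG_DIGITS
      }
    }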
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
index 3508220c95..0252c79d8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
@@ -33,7 +33,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64}
+import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi
writeLegacyParquetFormat match {
// Standard mode, 1 <= precision <= 9, writes as INT32
- case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer
+ case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer
// Standard mode, 10 <= precision <= 18, writes as INT64
- case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer
+ case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer
// Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY
- case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong
+ case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong
// Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY
case _ => binaryWriterUsingUnscaledBytes
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
index cef6b79a09..281a2cffa8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex
assert(batch.column(0).getByte(i) == 1)
assert(batch.column(1).getInt(i) == 2)
assert(batch.column(2).getLong(i) == 3)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc")
+ assert(batch.column(3).getUTF8String(i).toString == "abc")
i += 1
}
reader.close()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
index 8efdf8adb0..97638a66ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
@@ -370,7 +370,7 @@ object ColumnarBatchBenchmark {
}
i = 0
while (i < count) {
- sum += column.getByteArray(i).length
+ sum += column.getUTF8String(i).numBytes()
i += 1
}
column.reset()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 445f311107..b3c3e66fbc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
reference.zipWithIndex.foreach { v =>
assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode)
- assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)),
+ assert(v._1 == column.getUTF8String(v._2).toString,
"MemoryMode" + memMode)
}
@@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(batch.column(1).getDouble(0) == 1.1)
assert(batch.column(1).getIsNull(0) == false)
assert(batch.column(2).getIsNull(0) == true)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello")
+ assert(batch.column(3).getUTF8String(0).toString == "Hello")
// Verify the iterator works correctly.
val it = batch.rowIterator()
@@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(row.getDouble(1) == 1.1)
assert(row.isNullAt(1) == false)
assert(row.isNullAt(2) == true)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello")
+ assert(batch.column(3).getUTF8String(0).toString == "Hello")
assert(it.hasNext == false)
assert(it.hasNext == false)