-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala            |  22
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java  | 180
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java |  34
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java |  46
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java |  98
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java |  94
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala |  44
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/Platform.java                          |   8
8 files changed, 484 insertions(+), 42 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index cf5322125b..5dd661ee6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType {
}
}
+ /**
+ * Returns true if dt is a DecimalType that fits inside a long
+ */
+ def is64BitDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision <= Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
+ /**
+ * Returns true if dt is a DecimalType that doesn't fit inside a long
+ */
+ def isByteArrayDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision > Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
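
These two predicates drive the storage decision the rest of this patch makes.
Decimal.MAX_LONG_DIGITS is 18, the largest precision whose unscaled value is guaranteed to
fit in a signed 64-bit long, so narrow decimals can live in a plain long column while wider
ones fall back to a byte-array column. A minimal sketch of that dispatch, assuming only the
predicates above and Spark's existing DataTypes factory (the class and method names in the
sketch are illustrative, not part of the patch):

  import org.apache.spark.sql.types.DataType;
  import org.apache.spark.sql.types.DataTypes;
  import org.apache.spark.sql.types.DecimalType;

  public class DecimalStorageSketch {
    // Picks the physical representation the vectorized code paths below would use.
    static String storageFor(DataType dt) {
      if (DecimalType.is64BitDecimalType(dt)) return "long column";        // precision <= 18
      if (DecimalType.isByteArrayDecimalType(dt)) return "byte[] column";  // precision > 18
      return "not a decimal";
    }

    public static void main(String[] args) {
      System.out.println(storageFor(DataTypes.createDecimalType(10, 2)));  // long column
      System.out.println(storageFor(DataTypes.createDecimalType(30, 10))); // byte[] column
    }
  }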
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a0bf8734b6..a5bc506a65 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -102,18 +105,36 @@ public abstract class ColumnVector {
DataType dt = data.dataType();
Object[] list = new Object[length];
- if (dt instanceof ByteType) {
+ if (dt instanceof BooleanType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getBoolean(offset + i);
+ }
+ }
+ } else if (dt instanceof ByteType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getByte(offset + i);
}
}
+ } else if (dt instanceof ShortType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getShort(offset + i);
+ }
+ }
} else if (dt instanceof IntegerType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getInt(offset + i);
}
}
+ } else if (dt instanceof FloatType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getFloat(offset + i);
+ }
+ }
} else if (dt instanceof DoubleType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
@@ -126,12 +147,25 @@ public abstract class ColumnVector {
list[i] = data.getLong(offset + i);
}
}
+ } else if (dt instanceof DecimalType) {
+ DecimalType decType = (DecimalType)dt;
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getDecimal(i, decType.precision(), decType.scale());
+ }
+ }
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
}
}
+ } else if (dt instanceof CalendarIntervalType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getInterval(i);
+ }
+ }
} else {
throw new NotImplementedException("Type " + dt);
}
@@ -170,7 +204,14 @@ public abstract class ColumnVector {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
@@ -181,17 +222,22 @@ public abstract class ColumnVector {
@Override
public byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = data.getByteArray(offset + ordinal);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ int months = data.getChildColumn(0).getInt(offset + ordinal);
+ long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
+ return new CalendarInterval(months, microseconds);
}
@Override
public InternalRow getStruct(int ordinal, int numFields) {
- throw new NotImplementedException();
+ return data.getStruct(offset + ordinal);
}
@Override
@@ -282,6 +328,21 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putBoolean(int rowId, boolean value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putBooleans(int rowId, int count, boolean value);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract boolean getBoolean(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putByte(int rowId, byte value);
/**
@@ -302,6 +363,26 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putShort(int rowId, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putShorts(int rowId, int count, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract short getShort(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putInt(int rowId, int value);
/**
@@ -354,6 +435,33 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putFloat(int rowId, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putFloats(int rowId, int count, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ * src should contain `count` floats written in IEEE format.
+ */
+ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be IEEE-formatted floats.
+ */
+ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract float getFloat(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putDouble(int rowId, double value);
/**
@@ -369,7 +477,7 @@ public abstract class ColumnVector {
/**
* Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be ieee formated doubles.
+ * The data in src must be IEEE-formatted doubles.
*/
public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
@@ -469,6 +577,20 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendBoolean(boolean v) {
+ reserve(elementsAppended + 1);
+ putBoolean(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendBooleans(int count, boolean v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putBooleans(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendByte(byte v) {
reserve(elementsAppended + 1);
putByte(elementsAppended, v);
@@ -491,6 +613,28 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendShort(short v) {
+ reserve(elementsAppended + 1);
+ putShort(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendShorts(int count, short v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putShorts(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendShorts(int length, short[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putShorts(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
public final int appendInt(int v) {
reserve(elementsAppended + 1);
putInt(elementsAppended, v);
@@ -535,6 +679,20 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendFloat(float v) {
+ reserve(elementsAppended + 1);
+ putFloat(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendFloats(int count, float v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putFloats(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendDouble(double v) {
reserve(elementsAppended + 1);
putDouble(elementsAppended, v);
@@ -661,7 +819,8 @@ public abstract class ColumnVector {
this.capacity = capacity;
this.type = type;
- if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) {
+ if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
+ || DecimalType.isByteArrayDecimalType(type)) {
DataType childType;
int childCapacity = capacity;
if (type instanceof ArrayType) {
@@ -682,6 +841,13 @@ public abstract class ColumnVector {
}
this.resultArray = null;
this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+ } else if (type instanceof CalendarIntervalType) {
+ // Two child columns: months as an int, microseconds as a long.
+ this.childColumns = new ColumnVector[2];
+ this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
+ this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
+ this.resultArray = null;
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
} else {
this.childColumns = null;
this.resultArray = null;
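
Note on the decimal accessor above: it mirrors the writer path added to ColumnVectorUtils
below. Small precisions round-trip through Decimal's unscaled long; wide precisions
round-trip through the two's-complement byte array of the unscaled BigInteger. A
self-contained JDK-only sketch (no Spark classes) showing the byte-array encoding and
decoding agree:

  import java.math.BigDecimal;
  import java.math.BigInteger;

  public class DecimalRoundTrip {
    public static void main(String[] args) {
      BigDecimal wide = new BigDecimal("123456789012345678901.2345678901"); // precision 31

      // Writer side (cf. appendValue below): store the unscaled value's bytes.
      byte[] stored = wide.unscaledValue().toByteArray();

      // Reader side (cf. getDecimal above): rebuild from the bytes plus the column scale.
      BigDecimal readBack = new BigDecimal(new BigInteger(stored), 10);

      System.out.println(wide.compareTo(readBack) == 0); // true
    }
  }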
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 6c651a759d..453bc15e13 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -16,12 +16,15 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.commons.lang.NotImplementedException;
@@ -59,19 +62,44 @@ public class ColumnVectorUtils {
private static void appendValue(ColumnVector dst, DataType t, Object o) {
if (o == null) {
- dst.appendNull();
+ if (t instanceof CalendarIntervalType) {
+ dst.appendStruct(true);
+ } else {
+ dst.appendNull();
+ }
} else {
- if (t == DataTypes.ByteType) {
- dst.appendByte(((Byte)o).byteValue());
+ if (t == DataTypes.BooleanType) {
+ dst.appendBoolean(((Boolean)o).booleanValue());
+ } else if (t == DataTypes.ByteType) {
+ dst.appendByte(((Byte)o).byteValue());
+ } else if (t == DataTypes.ShortType) {
+ dst.appendShort(((Short)o).shortValue());
} else if (t == DataTypes.IntegerType) {
dst.appendInt(((Integer)o).intValue());
} else if (t == DataTypes.LongType) {
dst.appendLong(((Long)o).longValue());
+ } else if (t == DataTypes.FloatType) {
+ dst.appendFloat(((Float)o).floatValue());
} else if (t == DataTypes.DoubleType) {
dst.appendDouble(((Double)o).doubleValue());
} else if (t == DataTypes.StringType) {
byte[] b = ((String)o).getBytes();
dst.appendByteArray(b, 0, b.length);
+ } else if (t instanceof DecimalType) {
+ DecimalType dt = (DecimalType)t;
+ Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
+ if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+ dst.appendLong(d.toUnscaledLong());
+ } else {
+ final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
+ byte[] bytes = integer.toByteArray();
+ dst.appendByteArray(bytes, 0, bytes.length);
+ }
+ } else if (t instanceof CalendarIntervalType) {
+ CalendarInterval c = (CalendarInterval)o;
+ dst.appendStruct(false);
+ dst.getChildColumn(0).appendInt(c.months);
+ dst.getChildColumn(1).appendLong(c.microseconds);
} else {
throw new NotImplementedException("Type " + t);
}
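
The CalendarInterval path above writes a two-field struct: appendStruct(false) marks the
row as present, then each child column gets one value. A hedged usage sketch built only
from APIs visible in this patch (ColumnVector.allocate, appendStruct, getChildColumn):

  import org.apache.spark.memory.MemoryMode;
  import org.apache.spark.sql.execution.vectorized.ColumnVector;
  import org.apache.spark.sql.types.DataTypes;
  import org.apache.spark.unsafe.types.CalendarInterval;

  public class IntervalAppendSketch {
    public static void main(String[] args) {
      ColumnVector col =
        ColumnVector.allocate(4, DataTypes.CalendarIntervalType, MemoryMode.ON_HEAP);
      CalendarInterval iv = new CalendarInterval(14, 123L); // 14 months, 123 microseconds

      col.appendStruct(false);                            // non-null row
      col.getChildColumn(0).appendInt(iv.months);         // child 0: months as int
      col.getChildColumn(1).appendLong(iv.microseconds);  // child 1: microseconds as long

      col.appendStruct(true);                             // null row, children padded
      System.out.println(col.getIsNull(1));               // true
    }
  }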
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 5a575811fa..dbad5e070f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Arrays;
import java.util.Iterator;
@@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -150,44 +153,40 @@ public final class ColumnarBatch {
}
@Override
- public final boolean isNullAt(int ordinal) {
- return columns[ordinal].getIsNull(rowId);
- }
+ public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); }
@Override
- public final boolean getBoolean(int ordinal) {
- throw new NotImplementedException();
- }
+ public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
@Override
public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
@Override
- public final short getShort(int ordinal) {
- throw new NotImplementedException();
- }
+ public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
@Override
- public final int getInt(int ordinal) {
- return columns[ordinal].getInt(rowId);
- }
+ public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
@Override
public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
@Override
- public final float getFloat(int ordinal) {
- throw new NotImplementedException();
- }
+ public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
@Override
- public final double getDouble(int ordinal) {
- return columns[ordinal].getDouble(rowId);
- }
+ public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
@Override
public final Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+ // TODO: best perf?
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
@@ -198,12 +197,17 @@ public final class ColumnarBatch {
@Override
public final byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public final CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+ final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+ return new CalendarInterval(months, microseconds);
}
@Override
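
With every Row accessor above now implemented, a full write-then-read pass works for the
new types. A hedged sketch, assuming the ColumnarBatch.allocate / column / setNumRows /
rowIterator API that this patch's test suite exercises:

  import org.apache.spark.memory.MemoryMode;
  import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
  import org.apache.spark.sql.types.DataTypes;
  import org.apache.spark.sql.types.StructType;

  public class RowAccessorSketch {
    public static void main(String[] args) {
      StructType schema = new StructType()
        .add("f", DataTypes.FloatType)
        .add("d", DataTypes.createDecimalType(10, 2)); // precision <= 18: long-backed

      ColumnarBatch batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP);
      batch.column(0).putFloat(0, 1.5f);
      batch.column(1).putLong(0, 12345L); // unscaled value of 123.45
      batch.setNumRows(1);

      ColumnarBatch.Row row = batch.rowIterator().next();
      System.out.println(row.getFloat(0));          // 1.5
      System.out.println(row.getDecimal(1, 10, 2)); // 123.45
    }
  }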
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 335124fd5a..22c5e5fc81 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -19,11 +19,15 @@ package org.apache.spark.sql.execution.vectorized;
import java.nio.ByteOrder;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
@@ -122,6 +126,26 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0));
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ Platform.putByte(null, data + rowId + i, v);
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
+
+ //
// APIs dealing with Bytes
//
@@ -149,6 +173,34 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ Platform.putShort(null, data + 2 * rowId, value);
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ long offset = data + 2 * rowId;
+ for (int i = 0; i < count; ++i, offset += 2) {
+ Platform.putShort(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
+ null, data + 2 * rowId, count * 2);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return Platform.getShort(null, data + 2 * rowId);
+ }
+
+ //
// APIs dealing with ints
//
@@ -217,6 +269,41 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) {
+ Platform.putFloat(null, data + rowId * 4, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ long offset = data + 4 * rowId;
+ for (int i = 0; i < count; ++i, offset += 4) {
+ Platform.putFloat(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
+ null, data + 4 * rowId, count * 4);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+ null, data + rowId * 4, count * 4);
+ }
+
+ @Override
+ public final float getFloat(int rowId) {
+ return Platform.getFloat(null, data + rowId * 4);
+ }
+
+ //
// APIs dealing with doubles
//
@@ -241,7 +328,7 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex,
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
null, data + rowId * 8, count * 8);
}
@@ -300,11 +387,14 @@ public final class OffHeapColumnVector extends ColumnVector {
Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
this.offsetData =
Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof ByteType) {
+ } else if (type instanceof ByteType || type instanceof BooleanType) {
this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
- } else if (type instanceof IntegerType) {
+ } else if (type instanceof ShortType) {
+ this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
+ } else if (type instanceof IntegerType || type instanceof FloatType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof LongType || type instanceof DoubleType) {
+ } else if (type instanceof LongType || type instanceof DoubleType ||
+ DecimalType.is64BitDecimalType(type)) {
this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8);
} else if (resultStruct != null) {
// Nothing to store.
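
The reserveInternal branches above encode one fixed byte width per primitive type. An
illustrative helper (not part of the patch) that makes those widths explicit:

  import org.apache.spark.sql.types.*;

  public class ElementWidth {
    // Bytes per element, matching the reallocateMemory multipliers above.
    static int widthOf(DataType t) {
      if (t instanceof BooleanType || t instanceof ByteType) return 1;
      if (t instanceof ShortType) return 2;
      if (t instanceof IntegerType || t instanceof FloatType) return 4;
      if (t instanceof LongType || t instanceof DoubleType
          || DecimalType.is64BitDecimalType(t)) return 8;
      throw new UnsupportedOperationException("variable-length or nested type: " + t);
    }
  }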
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 8197fa11cd..32356334c0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector {
// Array for each type. Only 1 is populated for any type.
private byte[] byteData;
+ private short[] shortData;
private int[] intData;
private long[] longData;
+ private float[] floatData;
private double[] doubleData;
// Only set if type is Array.
@@ -105,6 +107,30 @@ public final class OnHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ byteData[rowId] = (byte)((value) ? 1 : 0);
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ byteData[i + rowId] = v;
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) {
+ return byteData[rowId] == 1;
+ }
+
+ //
// APIs dealing with Bytes
//
@@ -131,6 +157,33 @@ public final class OnHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ shortData[rowId] = value;
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ for (int i = 0; i < count; ++i) {
+ shortData[i + rowId] = value;
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, shortData, rowId, count);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return shortData[rowId];
+ }
+
+ //
// APIs dealing with Ints
//
@@ -202,6 +255,31 @@ public final class OnHeapColumnVector extends ColumnVector {
return longData[rowId];
}
+ //
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) { floatData[rowId] = value; }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ Arrays.fill(floatData, rowId, rowId + count, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, floatData, rowId, count);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
floatData, Platform.FLOAT_ARRAY_OFFSET + rowId * 4, count * 4);
+ }
+
+ @Override
+ public final float getFloat(int rowId) { return floatData[rowId]; }
//
// APIs dealing with doubles
@@ -277,7 +355,7 @@ public final class OnHeapColumnVector extends ColumnVector {
// Split this function out since it is the slow path.
private final void reserveInternal(int newCapacity) {
- if (this.resultArray != null) {
+ if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
if (this.arrayLengths != null) {
@@ -286,18 +364,30 @@ public final class OnHeapColumnVector extends ColumnVector {
}
arrayLengths = newLengths;
arrayOffsets = newOffsets;
+ } else if (type instanceof BooleanType) {
+ byte[] newData = new byte[newCapacity];
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ byteData = newData;
} else if (type instanceof ByteType) {
byte[] newData = new byte[newCapacity];
if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
byteData = newData;
+ } else if (type instanceof ShortType) {
+ short[] newData = new short[newCapacity];
+ if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+ shortData = newData;
} else if (type instanceof IntegerType) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
- } else if (type instanceof LongType) {
+ } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) {
long[] newData = new long[newCapacity];
if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
longData = newData;
+ } else if (type instanceof FloatType) {
+ float[] newData = new float[newCapacity];
+ if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+ floatData = newData;
} else if (type instanceof DoubleType) {
double[] newData = new double[newCapacity];
if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
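
Booleans on the heap reuse the same byteData array as ByteType, storing 1 or 0 per row, as
the reserveInternal change above shows. A hedged usage sketch via the allocate API from
this patch:

  import org.apache.spark.memory.MemoryMode;
  import org.apache.spark.sql.execution.vectorized.ColumnVector;
  import org.apache.spark.sql.types.DataTypes;

  public class BooleanColumnSketch {
    public static void main(String[] args) {
      ColumnVector col = ColumnVector.allocate(8, DataTypes.BooleanType, MemoryMode.ON_HEAP);
      col.appendBoolean(true);
      col.appendBooleans(3, false);          // bulk-fill rows 1..3
      System.out.println(col.getBoolean(0)); // true
      System.out.println(col.getBoolean(2)); // false
    }
  }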
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 67cc08b6fc..445f311107 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.CalendarInterval
class ColumnarBatchSuite extends SparkFunSuite {
test("Null Apis") {
@@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
}}
}
-
private def doubleEquals(d1: Double, d2: Double): Boolean = {
if (d1.isNaN && d2.isNaN) {
true
@@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
if (!r1.isNullAt(v._2)) {
v._1.dataType match {
+ case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
+ case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
+ case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+ "Seed = " + seed)
case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
"Seed = " + seed)
+ case t: DecimalType =>
+ val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
+ val d2 = r2.getDecimal(v._2)
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
case StringType =>
assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+ case CalendarIntervalType =>
+ assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
case ArrayType(childType, n) =>
val a1 = r1.getArray(v._2).array
val a2 = r2.getList(v._2).toArray
@@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite {
i += 1
}
}
+ case FloatType => {
+ var i = 0
+ while (i < a1.length) {
+ assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]),
+ "Seed = " + seed)
+ i += 1
+ }
+ }
+
+ case t: DecimalType =>
+ var i = 0
+ while (i < a1.length) {
+ assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+ if (a1(i) != null) {
+ val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal
+ val d2 = a2(i).asInstanceOf[java.math.BigDecimal]
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
+ }
+ i += 1
+ }
+
case _ => assert(a1 === a2, "Seed = " + seed)
}
case StructType(childFields) =>
@@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
* results.
*/
def testRandomRows(flatSchema: Boolean, numFields: Int) {
- // TODO: add remaining types. Figure out why StringType doesn't work on jenkins.
- val types = Array(ByteType, IntegerType, LongType, DoubleType)
+ // TODO: Figure out why StringType doesn't work on Jenkins.
+ val types = Array(
+ BooleanType, ByteType, FloatType, DoubleType,
+ IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
+ CalendarIntervalType)
val seed = System.nanoTime()
- val NUM_ROWS = 500
+ val NUM_ROWS = 200
val NUM_ITERS = 1000
val random = new Random(seed)
var i = 0
@@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
}
test("Random flat schema") {
- testRandomRows(true, 10)
+ testRandomRows(true, 15)
}
test("Random nested schema") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index b29bf6a464..18761bfd22 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -27,10 +27,14 @@ public final class Platform {
public static final int BYTE_ARRAY_OFFSET;
+ public static final int SHORT_ARRAY_OFFSET;
+
public static final int INT_ARRAY_OFFSET;
public static final int LONG_ARRAY_OFFSET;
+ public static final int FLOAT_ARRAY_OFFSET;
+
public static final int DOUBLE_ARRAY_OFFSET;
public static int getInt(Object object, long offset) {
@@ -168,13 +172,17 @@ public final class Platform {
if (_UNSAFE != null) {
BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+ SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+ FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
} else {
BYTE_ARRAY_OFFSET = 0;
+ SHORT_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
+ FLOAT_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
}
}
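
These array base offsets are what let the column vectors bulk-copy whole primitive arrays
with a single Platform.copyMemory call, for both on-heap and off-heap destinations. A small
demonstration (the class name is illustrative) that copies a short[] through a byte[] and
back using the new SHORT_ARRAY_OFFSET:

  import org.apache.spark.unsafe.Platform;

  public class ArrayOffsetDemo {
    public static void main(String[] args) {
      short[] src = {1, 2, 3, 4};
      byte[] raw = new byte[src.length * 2];

      // Same base-object + offset addressing OffHeapColumnVector.putShorts uses,
      // but heap-to-heap here.
      Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET,
        raw, Platform.BYTE_ARRAY_OFFSET, src.length * 2);

      short[] back = new short[src.length];
      Platform.copyMemory(raw, Platform.BYTE_ARRAY_OFFSET,
        back, Platform.SHORT_ARRAY_OFFSET, raw.length);
      System.out.println(back[2]); // 3
    }
  }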