author    Nong Li <nong@databricks.com>    2016-02-01 13:56:14 -0800
committer Reynold Xin <rxin@databricks.com>    2016-02-01 13:56:14 -0800
commit    064b029c6a15481fc4dfb147100c19a68cd1cc95 (patch)
tree      d6d10a8d0026556f873780483d010077f6c16ac4
parent    c9b89a0a0921ce3d52864afd4feb7f37b90f7b46 (diff)
[SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch.
This includes float, boolean, short, decimal, and calendar interval. Decimal is mapped to a long or a byte array depending on precision, and calendar interval is mapped to a struct of an int and a long. The only remaining type is map. The schema mapping is straightforward, but we might want to revisit how we deal with it in the rest of the execution engine.

Author: Nong Li <nong@databricks.com>

Closes #10961 from nongli/spark-13043.
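To make the decimal mapping concrete, here is a minimal standalone sketch of the rule the patch applies, assuming only that Decimal.MAX_LONG_DIGITS is 18 (its value in Decimal.scala); the class and method names below are illustrative and not part of the patch.

    import java.math.BigDecimal;
    import java.math.BigInteger;

    class DecimalEncodingSketch {
      static final int MAX_LONG_DIGITS = 18;  // mirrors Decimal.MAX_LONG_DIGITS

      // precision <= 18: the unscaled value fits in a long column;
      // otherwise: store the unscaled BigInteger's big-endian bytes.
      static Object encode(BigDecimal d, int precision) {
        if (precision <= MAX_LONG_DIGITS) return d.unscaledValue().longValueExact();
        return d.unscaledValue().toByteArray();
      }

      static BigDecimal decode(Object stored, int scale) {
        if (stored instanceof Long) return BigDecimal.valueOf((Long) stored, scale);
        return new BigDecimal(new BigInteger((byte[]) stored), scale);
      }
    }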
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala  22
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java  180
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java  34
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java  46
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java  98
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java  94
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala  44
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/Platform.java  8
8 files changed, 484 insertions, 42 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index cf5322125b..5dd661ee6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType {
}
}
+ /**
+ * Returns true if dt is a DecimalType that fits inside a long.
+ */
+ def is64BitDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision <= Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
+ /**
+ * Returns true if dt is a DecimalType that does not fit inside a long.
+ */
+ def isByteArrayDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision > Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
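The Java vector code later in this diff reaches these helpers through the companion object's static forwarders; a quick usage sketch (the precisions are chosen for illustration):

    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.DecimalType;

    class DecimalKindCheck {
      public static void main(String[] args) {
        DataType small = DataTypes.createDecimalType(12, 2);  // 12 digits fit in a long
        DataType big = DataTypes.createDecimalType(30, 10);   // 30 digits do not
        System.out.println(DecimalType.is64BitDecimalType(small));    // true
        System.out.println(DecimalType.isByteArrayDecimalType(big));  // true
      }
    }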
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a0bf8734b6..a5bc506a65 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -102,18 +105,36 @@ public abstract class ColumnVector {
DataType dt = data.dataType();
Object[] list = new Object[length];
- if (dt instanceof ByteType) {
+ if (dt instanceof BooleanType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getBoolean(offset + i);
+ }
+ }
+ } else if (dt instanceof ByteType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getByte(offset + i);
}
}
+ } else if (dt instanceof ShortType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getShort(offset + i);
+ }
+ }
} else if (dt instanceof IntegerType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getInt(offset + i);
}
}
+ } else if (dt instanceof FloatType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getFloat(offset + i);
+ }
+ }
} else if (dt instanceof DoubleType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
@@ -126,12 +147,25 @@ public abstract class ColumnVector {
list[i] = data.getLong(offset + i);
}
}
+ } else if (dt instanceof DecimalType) {
+ DecimalType decType = (DecimalType)dt;
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getDecimal(i, decType.precision(), decType.scale());
+ }
+ }
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
}
}
+ } else if (dt instanceof CalendarIntervalType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getInterval(i);
+ }
+ }
} else {
throw new NotImplementedException("Type " + dt);
}
@@ -170,7 +204,14 @@ public abstract class ColumnVector {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
@@ -181,17 +222,22 @@ public abstract class ColumnVector {
@Override
public byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = data.getByteArray(offset + ordinal);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ int month = data.getChildColumn(0).getInt(offset + ordinal);
+ long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
+ return new CalendarInterval(month, microseconds);
}
@Override
public InternalRow getStruct(int ordinal, int numFields) {
- throw new NotImplementedException();
+ return data.getStruct(offset + ordinal);
}
@Override
@@ -282,6 +328,21 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putBoolean(int rowId, boolean value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putBooleans(int rowId, int count, boolean value);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract boolean getBoolean(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putByte(int rowId, byte value);
/**
@@ -302,6 +363,26 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putShort(int rowId, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putShorts(int rowId, int count, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract short getShort(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putInt(int rowId, int value);
/**
@@ -354,6 +435,33 @@ public abstract class ColumnVector {
/**
* Sets the value at rowId to `value`.
*/
+ public abstract void putFloat(int rowId, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putFloats(int rowId, int count, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ * src should contain `count` doubles written as ieee format.
+ */
+ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be ieee formatted floats.
+ */
+ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract float getFloat(int rowId);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
public abstract void putDouble(int rowId, double value);
/**
@@ -369,7 +477,7 @@ public abstract class ColumnVector {
/**
* Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be ieee formated doubles.
+ * The data in src must be ieee formatted doubles.
*/
public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
@@ -469,6 +577,20 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendBoolean(boolean v) {
+ reserve(elementsAppended + 1);
+ putBoolean(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendBooleans(int count, boolean v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putBooleans(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendByte(byte v) {
reserve(elementsAppended + 1);
putByte(elementsAppended, v);
@@ -491,6 +613,28 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendShort(short v) {
+ reserve(elementsAppended + 1);
+ putShort(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendShorts(int count, short v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putShorts(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendShorts(int length, short[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putShorts(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
public final int appendInt(int v) {
reserve(elementsAppended + 1);
putInt(elementsAppended, v);
@@ -535,6 +679,20 @@ public abstract class ColumnVector {
return result;
}
+ public final int appendFloat(float v) {
+ reserve(elementsAppended + 1);
+ putFloat(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendFloats(int count, float v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putFloats(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendDouble(double v) {
reserve(elementsAppended + 1);
putDouble(elementsAppended, v);
@@ -661,7 +819,8 @@ public abstract class ColumnVector {
this.capacity = capacity;
this.type = type;
- if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) {
+ if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
+ || DecimalType.isByteArrayDecimalType(type)) {
DataType childType;
int childCapacity = capacity;
if (type instanceof ArrayType) {
@@ -682,6 +841,13 @@ public abstract class ColumnVector {
}
this.resultArray = null;
this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+ } else if (type instanceof CalendarIntervalType) {
+ // Two child columns: months as an int, microseconds as a long.
+ this.childColumns = new ColumnVector[2];
+ this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
+ this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
+ this.resultArray = null;
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
} else {
this.childColumns = null;
this.resultArray = null;
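A minimal sketch of the two-child interval layout this constructor sets up, using only APIs that appear elsewhere in this patch and assuming the code lives in the same package (org.apache.spark.sql.execution.vectorized), with on-heap mode picked for simplicity:

    import org.apache.spark.memory.MemoryMode;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.unsafe.types.CalendarInterval;

    class IntervalLayoutSketch {
      static CalendarInterval roundTrip() {
        ColumnVector col =
            ColumnVector.allocate(4, DataTypes.CalendarIntervalType, MemoryMode.ON_HEAP);
        col.appendStruct(false);                     // one non-null interval row
        col.getChildColumn(0).appendInt(13);         // months child
        col.getChildColumn(1).appendLong(1000000L);  // microseconds child (one second)
        return new CalendarInterval(
            col.getChildColumn(0).getInt(0),
            col.getChildColumn(1).getLong(0));
      }
    }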
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 6c651a759d..453bc15e13 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -16,12 +16,15 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.commons.lang.NotImplementedException;
@@ -59,19 +62,44 @@ public class ColumnVectorUtils {
private static void appendValue(ColumnVector dst, DataType t, Object o) {
if (o == null) {
- dst.appendNull();
+ if (t instanceof CalendarIntervalType) {
+ dst.appendStruct(true);
+ } else {
+ dst.appendNull();
+ }
} else {
- if (t == DataTypes.ByteType) {
- dst.appendByte(((Byte)o).byteValue());
+ if (t == DataTypes.BooleanType) {
+ dst.appendBoolean(((Boolean)o).booleanValue());
+ } else if (t == DataTypes.ByteType) {
+ dst.appendByte(((Byte) o).byteValue());
+ } else if (t == DataTypes.ShortType) {
+ dst.appendShort(((Short)o).shortValue());
} else if (t == DataTypes.IntegerType) {
dst.appendInt(((Integer)o).intValue());
} else if (t == DataTypes.LongType) {
dst.appendLong(((Long)o).longValue());
+ } else if (t == DataTypes.FloatType) {
+ dst.appendFloat(((Float)o).floatValue());
} else if (t == DataTypes.DoubleType) {
dst.appendDouble(((Double)o).doubleValue());
} else if (t == DataTypes.StringType) {
byte[] b =((String)o).getBytes();
dst.appendByteArray(b, 0, b.length);
+ } else if (t instanceof DecimalType) {
+ DecimalType dt = (DecimalType)t;
+ Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
+ if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+ dst.appendLong(d.toUnscaledLong());
+ } else {
+ final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
+ byte[] bytes = integer.toByteArray();
+ dst.appendByteArray(bytes, 0, bytes.length);
+ }
+ } else if (t instanceof CalendarIntervalType) {
+ CalendarInterval c = (CalendarInterval)o;
+ dst.appendStruct(false);
+ dst.getChildColumn(0).appendInt(c.months);
+ dst.getChildColumn(1).appendLong(c.microseconds);
} else {
throw new NotImplementedException("Type " + t);
}
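Note the null branch at the top of appendValue: a null interval goes through appendStruct(true) rather than appendNull() so the months and microseconds children advance in lockstep with the parent. A fragment extending the interval sketch shown earlier (assuming appendStruct(true) pads the children, which is what that branch relies on):

    col.appendStruct(true);     // null row: parent marked null, children padded
    col.appendStruct(false);    // following non-null row
    col.getChildColumn(0).appendInt(7);
    col.getChildColumn(1).appendLong(42L);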
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 5a575811fa..dbad5e070f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Arrays;
import java.util.Iterator;
@@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -150,44 +153,40 @@ public final class ColumnarBatch {
}
@Override
- public final boolean isNullAt(int ordinal) {
- return columns[ordinal].getIsNull(rowId);
- }
+ public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); }
@Override
- public final boolean getBoolean(int ordinal) {
- throw new NotImplementedException();
- }
+ public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
@Override
public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
@Override
- public final short getShort(int ordinal) {
- throw new NotImplementedException();
- }
+ public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
@Override
- public final int getInt(int ordinal) {
- return columns[ordinal].getInt(rowId);
- }
+ public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
@Override
public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
@Override
- public final float getFloat(int ordinal) {
- throw new NotImplementedException();
- }
+ public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
@Override
- public final double getDouble(int ordinal) {
- return columns[ordinal].getDouble(rowId);
- }
+ public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
@Override
public final Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+        // TODO: revisit whether going through BigInteger/BigDecimal is the fastest path.
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
@@ -198,12 +197,17 @@ public final class ColumnarBatch {
@Override
public final byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public final CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+ final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+ return new CalendarInterval(months, microseconds);
}
@Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 335124fd5a..22c5e5fc81 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -19,11 +19,15 @@ package org.apache.spark.sql.execution.vectorized;
import java.nio.ByteOrder;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
@@ -122,6 +126,26 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0));
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ Platform.putByte(null, data + rowId + i, v);
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
+
+ //
// APIs dealing with Bytes
//
@@ -149,6 +173,34 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ Platform.putShort(null, data + 2 * rowId, value);
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ long offset = data + 2 * rowId;
+    for (int i = 0; i < count; ++i, offset += 2) {
+ Platform.putShort(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
+ null, data + 2 * rowId, count * 2);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return Platform.getShort(null, data + 2 * rowId);
+ }
+
+ //
// APIs dealing with ints
//
@@ -217,6 +269,41 @@ public final class OffHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) {
+ Platform.putFloat(null, data + rowId * 4, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ long offset = data + 4 * rowId;
+ for (int i = 0; i < count; ++i, offset += 4) {
+ Platform.putFloat(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
+ null, data + 4 * rowId, count * 4);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+ null, data + rowId * 4, count * 4);
+ }
+
+ @Override
+ public final float getFloat(int rowId) {
+ return Platform.getFloat(null, data + rowId * 4);
+ }
+
+ //
// APIs dealing with doubles
//
@@ -241,7 +328,7 @@ public final class OffHeapColumnVector extends ColumnVector {
@Override
public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex,
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
null, data + rowId * 8, count * 8);
}
@@ -300,11 +387,14 @@ public final class OffHeapColumnVector extends ColumnVector {
Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
this.offsetData =
Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof ByteType) {
+ } else if (type instanceof ByteType || type instanceof BooleanType) {
this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
- } else if (type instanceof IntegerType) {
+ } else if (type instanceof ShortType) {
+ this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
+ } else if (type instanceof IntegerType || type instanceof FloatType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof LongType || type instanceof DoubleType) {
+ } else if (type instanceof LongType || type instanceof DoubleType ||
+ DecimalType.is64BitDecimalType(type)) {
this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8);
} else if (resultStruct != null) {
// Nothing to store.
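The reallocation arms above imply one element width per physical type; the same mapping written out as a sketch for reference (helper name is illustrative, not part of the patch, with the type classes imported from org.apache.spark.sql.types):

    static int elementWidth(DataType t) {
      if (t instanceof BooleanType || t instanceof ByteType) return 1;
      if (t instanceof ShortType) return 2;
      if (t instanceof IntegerType || t instanceof FloatType) return 4;
      if (t instanceof LongType || t instanceof DoubleType
          || DecimalType.is64BitDecimalType(t)) return 8;  // 64-bit decimals ride the long path
      throw new UnsupportedOperationException("Type " + t);
    }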
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 8197fa11cd..32356334c0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector {
// Array for each type. Only 1 is populated for any type.
private byte[] byteData;
+ private short[] shortData;
private int[] intData;
private long[] longData;
+ private float[] floatData;
private double[] doubleData;
// Only set if type is Array.
@@ -105,6 +107,30 @@ public final class OnHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ byteData[rowId] = (byte)((value) ? 1 : 0);
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ byteData[i + rowId] = v;
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) {
+ return byteData[rowId] == 1;
+ }
+
+ //
// APIs dealing with Bytes
//
@@ -131,6 +157,33 @@ public final class OnHeapColumnVector extends ColumnVector {
}
//
+ // APIs dealing with Shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ shortData[rowId] = value;
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ for (int i = 0; i < count; ++i) {
+ shortData[i + rowId] = value;
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, shortData, rowId, count);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return shortData[rowId];
+ }
+
+ //
// APIs dealing with Ints
//
@@ -202,6 +255,31 @@ public final class OnHeapColumnVector extends ColumnVector {
return longData[rowId];
}
+ //
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) { floatData[rowId] = value; }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ Arrays.fill(floatData, rowId, rowId + count, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, floatData, rowId, count);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+      floatData, Platform.FLOAT_ARRAY_OFFSET + rowId * 4, count * 4);
+ }
+
+ @Override
+ public final float getFloat(int rowId) { return floatData[rowId]; }
//
// APIs dealing with doubles
@@ -277,7 +355,7 @@ public final class OnHeapColumnVector extends ColumnVector {
// Split this function out since it is the slow path.
private final void reserveInternal(int newCapacity) {
- if (this.resultArray != null) {
+ if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
if (this.arrayLengths != null) {
@@ -286,18 +364,30 @@ public final class OnHeapColumnVector extends ColumnVector {
}
arrayLengths = newLengths;
arrayOffsets = newOffsets;
+ } else if (type instanceof BooleanType) {
+ byte[] newData = new byte[newCapacity];
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ byteData = newData;
} else if (type instanceof ByteType) {
byte[] newData = new byte[newCapacity];
if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
byteData = newData;
+ } else if (type instanceof ShortType) {
+ short[] newData = new short[newCapacity];
+ if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+ shortData = newData;
} else if (type instanceof IntegerType) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
- } else if (type instanceof LongType) {
+ } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) {
long[] newData = new long[newCapacity];
if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
longData = newData;
+ } else if (type instanceof FloatType) {
+ float[] newData = new float[newCapacity];
+ if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+ floatData = newData;
} else if (type instanceof DoubleType) {
double[] newData = new double[newCapacity];
if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 67cc08b6fc..445f311107 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.CalendarInterval
class ColumnarBatchSuite extends SparkFunSuite {
test("Null Apis") {
@@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
}}
}
-
private def doubleEquals(d1: Double, d2: Double): Boolean = {
if (d1.isNaN && d2.isNaN) {
true
@@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
if (!r1.isNullAt(v._2)) {
v._1.dataType match {
+ case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
+ case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
+ case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+ "Seed = " + seed)
case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
"Seed = " + seed)
+ case t: DecimalType =>
+ val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
+ val d2 = r2.getDecimal(v._2)
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
case StringType =>
assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+ case CalendarIntervalType =>
+ assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
case ArrayType(childType, n) =>
val a1 = r1.getArray(v._2).array
val a2 = r2.getList(v._2).toArray
@@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite {
i += 1
}
}
+ case FloatType => {
+ var i = 0
+ while (i < a1.length) {
+ assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]),
+ "Seed = " + seed)
+ i += 1
+ }
+ }
+
+ case t: DecimalType =>
+ var i = 0
+ while (i < a1.length) {
+ assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+ if (a1(i) != null) {
+ val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal
+ val d2 = a2(i).asInstanceOf[java.math.BigDecimal]
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
+ }
+ i += 1
+ }
+
case _ => assert(a1 === a2, "Seed = " + seed)
}
case StructType(childFields) =>
@@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
* results.
*/
def testRandomRows(flatSchema: Boolean, numFields: Int) {
- // TODO: add remaining types. Figure out why StringType doesn't work on jenkins.
- val types = Array(ByteType, IntegerType, LongType, DoubleType)
+ // TODO: Figure out why StringType doesn't work on jenkins.
+ val types = Array(
+ BooleanType, ByteType, FloatType, DoubleType,
+ IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
+ CalendarIntervalType)
val seed = System.nanoTime()
- val NUM_ROWS = 500
+ val NUM_ROWS = 200
val NUM_ITERS = 1000
val random = new Random(seed)
var i = 0
@@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
}
test("Random flat schema") {
- testRandomRows(true, 10)
+ testRandomRows(true, 15)
}
test("Random nested schema") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index b29bf6a464..18761bfd22 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -27,10 +27,14 @@ public final class Platform {
public static final int BYTE_ARRAY_OFFSET;
+ public static final int SHORT_ARRAY_OFFSET;
+
public static final int INT_ARRAY_OFFSET;
public static final int LONG_ARRAY_OFFSET;
+ public static final int FLOAT_ARRAY_OFFSET;
+
public static final int DOUBLE_ARRAY_OFFSET;
public static int getInt(Object object, long offset) {
@@ -168,13 +172,17 @@ public final class Platform {
if (_UNSAFE != null) {
BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+ SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+ FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
} else {
BYTE_ARRAY_OFFSET = 0;
+ SHORT_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
+ FLOAT_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
}
}
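A short sketch of how the new FLOAT_ARRAY_OFFSET pairs with Platform.copyMemory, mirroring the byte[]-to-float[] copy in OnHeapColumnVector.putFloats above (helper name is illustrative):

    import org.apache.spark.unsafe.Platform;

    class FloatCopySketch {
      // Reinterpret `count` ieee-format floats held in a byte[] as a float[].
      static float[] bytesToFloats(byte[] src, int count) {
        float[] dst = new float[count];
        Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET,
            dst, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
        return dst;
      }
    }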