author    Davies Liu <davies@databricks.com>    2015-10-12 21:12:59 -0700
committer Cheng Lian <lian@databricks.com>      2015-10-12 21:12:59 -0700
commit    c4da5345a0ef643a7518756caaa18ff3f3ea9acc (patch)
tree      330ed74a4ebe7e98b8983df84d0d91f556b7199e /sql/catalyst
parent    f97e9323b526b3d0b0fee0ca03f4276f37bb5750 (diff)
[SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types
This PR improves the unrolling and reading of complex types in the columnar cache:

1) Use UnsafeProjection to serialize complex types, so they are no longer serialized three times (two of those passes were only for actualSize).
2) Copy the bytes from UnsafeRow/UnsafeArrayData into the ByteBuffer directly, avoiding the intermediate byte[].
3) Use the underlying array of the ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copying.

Combined, these optimizations reduce the unrolling time from 25s to 21s (16% less) and the scanning time from 3.5s to 2.5s (28% less).

```
df = sqlContext.read.parquet(path)
t = time.time()
df.cache()
df.count()
print 'unrolling', time.time() - t
for i in range(10):
    t = time.time()
    print df.select("*")._jdf.queryExecution().toRdd().count()
    print time.time() - t
```

The schema is:

```
root
 |-- a: struct (nullable = true)
 |    |-- b: long (nullable = true)
 |    |-- c: string (nullable = true)
 |-- d: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- e: map (nullable = true)
 |    |-- key: long
 |    |-- value: string (valueContainsNull = true)
```

The columnar cache now depends on UnsafeProjection supporting all data types (including UDTs), so this PR also fixes that.

Author: Davies Liu <davies@databricks.com>

Closes #9016 from davies/complex2.
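As a rough illustration (not part of the patch) of how optimizations 1) and 2) combine on the write path, here is a minimal sketch; the two-field schema and values are made up, and `writeTo(ByteBuffer)` is the method added by this commit:

```scala
import java.nio.ByteBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Project once with UnsafeProjection (one serialization pass), then copy
// the row's bytes straight into a heap ByteBuffer via the new writeTo
// method -- no intermediate byte[] is allocated.
val schema = StructType(Seq(
  StructField("b", LongType),
  StructField("c", StringType)))
val toUnsafe = UnsafeProjection.create(schema)
val row = toUnsafe(InternalRow(1L, UTF8String.fromString("x")))
val buffer = ByteBuffer.allocate(4 + row.getSizeInBytes)
buffer.putInt(row.getSizeInBytes)
row.writeTo(buffer)  // bytes land directly in buffer.array()
```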
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java                    | 12
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                          | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala            |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala   |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 29
5 files changed, 50 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index fdd9125613..796f8abec9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
@@ -145,6 +146,8 @@ public class UnsafeArrayData extends ArrayData {
return getArray(ordinal);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof UserDefinedType) {
+ return get(ordinal, ((UserDefinedType)dataType).sqlType());
} else {
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
@@ -306,6 +309,15 @@ public class UnsafeArrayData extends ArrayData {
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
+ public void writeTo(ByteBuffer buffer) {
+ assert(buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public UnsafeArrayData copy() {
UnsafeArrayData arrayCopy = new UnsafeArrayData();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5af7ed5d6e..36859fbab9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import java.io.*;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@@ -326,6 +327,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
return getArray(ordinal);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof UserDefinedType) {
+ return get(ordinal, ((UserDefinedType)dataType).sqlType());
} else {
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
@@ -602,6 +605,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
+ public void writeTo(ByteBuffer buffer) {
+ assert (buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public void writeExternal(ObjectOutput out) throws IOException {
byte[] bytes = getBytes();
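For context, a possible read-side counterpart to the new writeTo (optimization 3), sketched under the assumption of a heap-backed buffer; `numFields` is assumed to be known from the schema, and the pointTo overload is the one that existed at the time of this commit:

```scala
import java.nio.ByteBuffer
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform

// Point an UnsafeRow at the ByteBuffer's backing array instead of copying
// the bytes out; the row then aliases the buffer's memory.
def readRow(buffer: ByteBuffer, numFields: Int): UnsafeRow = {
  val sizeInBytes = buffer.getInt()
  val row = new UnsafeRow()
  row.pointTo(
    buffer.array(),
    Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + buffer.position(),
    numFields,
    sizeInBytes)
  buffer.position(buffer.position() + sizeInBytes)  // skip past the row
  row
}
```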
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a0fe5bd77e..7544d27e3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -129,6 +129,7 @@ class CodeGenContext {
case _: ArrayType => s"$input.getArray($ordinal)"
case _: MapType => s"$input.getMap($ordinal)"
case NullType => "null"
+ case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
case _ => s"($jt)$input.get($ordinal, null)"
}
}
@@ -143,6 +144,7 @@ class CodeGenContext {
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
// The UTF8String may come from UnsafeRow; otherwise clone is cheap (it re-uses the bytes)
case StringType => s"$row.update($ordinal, $value.clone())"
+ case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
case _ => s"$row.update($ordinal, $value)"
}
}
@@ -177,6 +179,7 @@ class CodeGenContext {
case _: MapType => "MapData"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+ case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
case _ => "Object"
@@ -222,6 +225,7 @@ class CodeGenContext {
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+ case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case other => s"$c1.equals($c2)"
}
@@ -255,6 +259,7 @@ class CodeGenContext {
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+ case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
}
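All of the CodeGenContext branches added above follow one pattern: treat a UserDefinedType as its underlying sqlType. A simplified, standalone sketch of that dispatch (not the real CodeGenContext; the cases are abbreviated):

```scala
import org.apache.spark.sql.types._

// A UDT unwraps to its storage type before dispatch, so the generated
// code only ever deals with built-in SQL types.
def javaTypeFor(dt: DataType): String = dt match {
  case IntegerType             => "int"
  case LongType                => "long"
  case StringType              => "UTF8String"
  case udt: UserDefinedType[_] => javaTypeFor(udt.sqlType)  // recurse
  case _                       => "Object"
}
```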
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 9873630937..ee50587ed0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String acts as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
+ case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => GeneratedExpressionCode("", "false", input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 3e0e81733f..1b957a508d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
+ case dt: OpenHashSetUDT => false // it's not a standard UDT
+ case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
- case ((input, dt), index) =>
+ case ((input, dataType), index) =>
+ val dt = dataType match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }
val tmpCursor = ctx.freshName("tmpCursor")
val setNull = dt match {
@@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val index = ctx.freshName("index")
val element = ctx.freshName("element")
- val jt = ctx.javaType(elementType)
+ val et = elementType match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }
+
+ val jt = ctx.javaType(et)
- val fixedElementSize = elementType match {
+ val fixedElementSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
- case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+ case _ if ctx.isPrimitiveType(jt) => et.defaultSize
case _ => 0
}
- val writeElement = elementType match {
+ val writeElement = et match {
case t: StructType =>
s"""
$arrayWriter.setOffset($index);
@@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
"""
- case _ if ctx.isPrimitiveType(elementType) =>
+ case _ if ctx.isPrimitiveType(et) =>
// Should we do word align?
- val dataSize = elementType.defaultSize
+ val dataSize = et.defaultSize
s"""
$arrayWriter.setOffset($index);
- ${writePrimitiveType(ctx, element, elementType,
+ ${writePrimitiveType(ctx, element, et,
s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
$bufferHolder.cursor += $dataSize;
"""
@@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
if ($input.isNullAt($index)) {
$arrayWriter.setNullAt($index);
} else {
- final $jt $element = ${ctx.getValue(input, elementType, index)};
+ final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
}
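The GenerateUnsafeProjection changes apply the same idea one step earlier: resolve the declared type to its storage type once (`dt` for struct fields, `et` for array elements), so every later branch matches on the storage type only. Extracted as a hypothetical helper, equivalent to the inline matches above:

```scala
import org.apache.spark.sql.types._

// A UDT is written exactly as its underlying sqlType; other types pass
// through unchanged.
def storageType(declared: DataType): DataType = declared match {
  case udt: UserDefinedType[_] => udt.sqlType
  case other                   => other
}
```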