diff options
author | Davies Liu <davies@databricks.com> | 2015-10-12 21:12:59 -0700 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2015-10-12 21:12:59 -0700 |
commit | c4da5345a0ef643a7518756caaa18ff3f3ea9acc (patch) | |
tree | 330ed74a4ebe7e98b8983df84d0d91f556b7199e /sql/catalyst/src | |
parent | f97e9323b526b3d0b0fee0ca03f4276f37bb5750 (diff) | |
download | spark-c4da5345a0ef643a7518756caaa18ff3f3ea9acc.tar.gz spark-c4da5345a0ef643a7518756caaa18ff3f3ea9acc.tar.bz2 spark-c4da5345a0ef643a7518756caaa18ff3f3ea9acc.zip |
[SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types
This PR improves the unrolling and reading of complex types in the columnar cache:
1) Using UnsafeProjection to do serialization of complex types, so they will not be serialized three times (two for actualSize)
2) Copy the bytes from UnsafeRow/UnsafeArrayData to ByteBuffer directly, avoiding the intermediate byte[]
3) Using the underlying array in ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copying.
Combining these optimizations, we can reduce the unrolling time from 25s to 21s (20% less) and the scanning time from 3.5s to 2.5s (28% less).
```
df = sqlContext.read.parquet(path)
t = time.time()
df.cache()
df.count()
print 'unrolling', time.time() - t
for i in range(10):
t = time.time()
print df.select("*")._jdf.queryExecution().toRdd().count()
print time.time() - t
```
The schema is
```
root
|-- a: struct (nullable = true)
| |-- b: long (nullable = true)
| |-- c: string (nullable = true)
|-- d: array (nullable = true)
| |-- element: long (containsNull = true)
|-- e: map (nullable = true)
| |-- key: long
| |-- value: string (valueContainsNull = true)
```
The columnar cache now depends on UnsafeProjection supporting all data types (including UDTs), so this PR also fixes that.
Author: Davies Liu <davies@databricks.com>
Closes #9016 from davies/complex2.
Diffstat (limited to 'sql/catalyst/src')
5 files changed, 50 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index fdd9125613..796f8abec9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -145,6 +146,8 @@ public class UnsafeArrayData extends ArrayData { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -306,6 +309,15 @@ public class UnsafeArrayData extends ArrayData { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5af7ed5d6e..36859fbab9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -20,6 +20,7 @@ package 
org.apache.spark.sql.catalyst.expressions; import java.io.*; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -326,6 +327,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -602,6 +605,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + public void writeTo(ByteBuffer buffer) { + assert (buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public void writeExternal(ObjectOutput out) throws IOException { byte[] bytes = getBytes(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a0fe5bd77e..7544d27e3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -129,6 +129,7 @@ class CodeGenContext { case _: ArrayType => s"$input.getArray($ordinal)" case _: MapType => s"$input.getMap($ordinal)" case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) case _ => s"($jt)$input.get($ordinal, null)" } } @@ -143,6 +144,7 @@ class CodeGenContext { case 
t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) case StringType => s"$row.update($ordinal, $value.clone())" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case _ => s"$row.update($ordinal, $value)" } } @@ -177,6 +179,7 @@ class CodeGenContext { case _: MapType => "MapData" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName case _ => "Object" @@ -222,6 +225,7 @@ class CodeGenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case other => s"$c1.equals($c2)" } @@ -255,6 +259,7 @@ class CodeGenContext { addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 9873630937..ee50587ed0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => GeneratedExpressionCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3e0e81733f..1b957a508d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case dt: OpenHashSetUDT => false // it's not a standard UDT + case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dt), index) => + case ((input, dataType), index) => + val dt = dataType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt 
match { @@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val index = ctx.freshName("index") val element = ctx.freshName("element") - val jt = ctx.javaType(elementType) + val et = elementType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } + + val jt = ctx.javaType(et) - val fixedElementSize = elementType match { + val fixedElementSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize + case _ if ctx.isPrimitiveType(jt) => et.defaultSize case _ => 0 } - val writeElement = elementType match { + val writeElement = et match { case t: StructType => s""" $arrayWriter.setOffset($index); @@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - case _ if ctx.isPrimitiveType(elementType) => + case _ if ctx.isPrimitiveType(et) => // Should we do word align? - val dataSize = elementType.defaultSize + val dataSize = et.defaultSize s""" $arrayWriter.setOffset($index); - ${writePrimitiveType(ctx, element, elementType, + ${writePrimitiveType(ctx, element, et, s"$bufferHolder.buffer", s"$bufferHolder.cursor")} $bufferHolder.cursor += $dataSize; """ @@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($input.isNullAt($index)) { $arrayWriter.setNullAt($index); } else { - final $jt $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement } } |