commit    6c70a38c2e60e1b69a310aee1a92ee0b3815c02d
tree      54fada64529e2fc978c7e433c47c8bcbadb5e1de
parent    ea361165e1ddce4d8aa0242ae3e878d7b39f1de2
author    Michal Senkyr <mike.senkyr@gmail.com>  2017-03-28 10:09:49 +0800
committer Wenchen Fan <wenchen@databricks.com>   2017-03-28 10:09:49 +0800
[SPARK-19088][SQL] Optimize sequence type deserialization codegen
## What changes were proposed in this pull request?
This optimizes the deserialization of arbitrary Scala sequence types, support for which was introduced by #16240.
The previous implementation first constructed an array and then converted it with `to`, which required two passes over the data in most cases.
This implementation remedies that by using the `Builder` provided by the `newBuilder` method on each Scala collection's companion object, building the resulting collection directly in a single pass.
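For intuition, the difference looks roughly like this in plain Scala (a sketch against the Scala 2.11/2.12 collections API; the value names are illustrative, not taken from the generated code):
```scala
import scala.collection.mutable.WrappedArray

// Old path (two passes): wrap an intermediate array, then convert with `to`
val viaTo: List[Int] = WrappedArray.make[Int](Array(1, 2, 3)).to[List]

// New path (one pass): append straight into the target collection's builder
val builder = List.newBuilder[Int]
builder.sizeHint(3)
Array(1, 2, 3).foreach(builder += _)
val viaBuilder: List[Int] = builder.result()
```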
Example codegen for a simple `List` (obtained using `Seq(List(1)).toDS().map(identity).queryExecution.debug.codegen`):
Before:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean deserializetoobject_resultIsNull;
/* 010 */ private java.lang.Object[] deserializetoobject_argValue;
/* 011 */ private boolean MapObjects_loopIsNull1;
/* 012 */ private int MapObjects_loopValue0;
/* 013 */ private boolean deserializetoobject_resultIsNull1;
/* 014 */ private scala.collection.generic.CanBuildFrom deserializetoobject_argValue1;
/* 015 */ private UnsafeRow deserializetoobject_result;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 018 */ private scala.collection.immutable.List mapelements_argValue;
/* 019 */ private UnsafeRow mapelements_result;
/* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 022 */ private scala.collection.immutable.List serializefromobject_argValue;
/* 023 */ private UnsafeRow serializefromobject_result;
/* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 027 */
/* 028 */ public GeneratedIterator(Object[] references) {
/* 029 */ this.references = references;
/* 030 */ }
/* 031 */
/* 032 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 033 */ partitionIndex = index;
/* 034 */ this.inputs = inputs;
/* 035 */ inputadapter_input = inputs[0];
/* 036 */
/* 037 */ deserializetoobject_result = new UnsafeRow(1);
/* 038 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 039 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 040 */
/* 041 */ mapelements_result = new UnsafeRow(1);
/* 042 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 043 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 044 */
/* 045 */ serializefromobject_result = new UnsafeRow(1);
/* 046 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 047 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 048 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 049 */
/* 050 */ }
/* 051 */
/* 052 */ protected void processNext() throws java.io.IOException {
/* 053 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 054 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 055 */ ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 056 */
/* 057 */ deserializetoobject_resultIsNull = false;
/* 058 */
/* 059 */ if (!deserializetoobject_resultIsNull) {
/* 060 */ ArrayData deserializetoobject_value3 = null;
/* 061 */
/* 062 */ if (!false) {
/* 063 */ Integer[] deserializetoobject_convertedArray = null;
/* 064 */ int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 065 */ deserializetoobject_convertedArray = new Integer[deserializetoobject_dataLength];
/* 066 */
/* 067 */ int deserializetoobject_loopIndex = 0;
/* 068 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 069 */ MapObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 070 */ MapObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 071 */
/* 072 */ if (MapObjects_loopIsNull1) {
/* 073 */ throw new RuntimeException(((java.lang.String) references[0]));
/* 074 */ }
/* 075 */ if (false) {
/* 076 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 077 */ } else {
/* 078 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue0;
/* 079 */ }
/* 080 */
/* 081 */ deserializetoobject_loopIndex += 1;
/* 082 */ }
/* 083 */
/* 084 */ deserializetoobject_value3 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray);
/* 085 */ }
/* 086 */ boolean deserializetoobject_isNull2 = true;
/* 087 */ java.lang.Object[] deserializetoobject_value2 = null;
/* 088 */ if (!false) {
/* 089 */ deserializetoobject_isNull2 = false;
/* 090 */ if (!deserializetoobject_isNull2) {
/* 091 */ Object deserializetoobject_funcResult = null;
/* 092 */ deserializetoobject_funcResult = deserializetoobject_value3.array();
/* 093 */ if (deserializetoobject_funcResult == null) {
/* 094 */ deserializetoobject_isNull2 = true;
/* 095 */ } else {
/* 096 */ deserializetoobject_value2 = (java.lang.Object[]) deserializetoobject_funcResult;
/* 097 */ }
/* 098 */
/* 099 */ }
/* 100 */ deserializetoobject_isNull2 = deserializetoobject_value2 == null;
/* 101 */ }
/* 102 */ deserializetoobject_resultIsNull = deserializetoobject_isNull2;
/* 103 */ deserializetoobject_argValue = deserializetoobject_value2;
/* 104 */ }
/* 105 */
/* 106 */ boolean deserializetoobject_isNull1 = deserializetoobject_resultIsNull;
/* 107 */ final scala.collection.Seq deserializetoobject_value1 = deserializetoobject_resultIsNull ? null : scala.collection.mutable.WrappedArray.make(deserializetoobject_argValue);
/* 108 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 109 */ boolean deserializetoobject_isNull = true;
/* 110 */ scala.collection.immutable.List deserializetoobject_value = null;
/* 111 */ if (!deserializetoobject_isNull1) {
/* 112 */ deserializetoobject_resultIsNull1 = false;
/* 113 */
/* 114 */ if (!deserializetoobject_resultIsNull1) {
/* 115 */ boolean deserializetoobject_isNull6 = false;
/* 116 */ final scala.collection.generic.CanBuildFrom deserializetoobject_value6 = false ? null : scala.collection.immutable.List.canBuildFrom();
/* 117 */ deserializetoobject_isNull6 = deserializetoobject_value6 == null;
/* 118 */ deserializetoobject_resultIsNull1 = deserializetoobject_isNull6;
/* 119 */ deserializetoobject_argValue1 = deserializetoobject_value6;
/* 120 */ }
/* 121 */
/* 122 */ deserializetoobject_isNull = deserializetoobject_resultIsNull1;
/* 123 */ if (!deserializetoobject_isNull) {
/* 124 */ Object deserializetoobject_funcResult1 = null;
/* 125 */ deserializetoobject_funcResult1 = deserializetoobject_value1.to(deserializetoobject_argValue1);
/* 126 */ if (deserializetoobject_funcResult1 == null) {
/* 127 */ deserializetoobject_isNull = true;
/* 128 */ } else {
/* 129 */ deserializetoobject_value = (scala.collection.immutable.List) deserializetoobject_funcResult1;
/* 130 */ }
/* 131 */
/* 132 */ }
/* 133 */ deserializetoobject_isNull = deserializetoobject_value == null;
/* 134 */ }
/* 135 */
/* 136 */ boolean mapelements_isNull = true;
/* 137 */ scala.collection.immutable.List mapelements_value = null;
/* 138 */ if (!false) {
/* 139 */ mapelements_argValue = deserializetoobject_value;
/* 140 */
/* 141 */ mapelements_isNull = false;
/* 142 */ if (!mapelements_isNull) {
/* 143 */ Object mapelements_funcResult = null;
/* 144 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 145 */ if (mapelements_funcResult == null) {
/* 146 */ mapelements_isNull = true;
/* 147 */ } else {
/* 148 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 149 */ }
/* 150 */
/* 151 */ }
/* 152 */ mapelements_isNull = mapelements_value == null;
/* 153 */ }
/* 154 */
/* 155 */ if (mapelements_isNull) {
/* 156 */ throw new RuntimeException(((java.lang.String) references[2]));
/* 157 */ }
/* 158 */ serializefromobject_argValue = mapelements_value;
/* 159 */
/* 160 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 161 */ serializefromobject_holder.reset();
/* 162 */
/* 163 */ // Remember the current cursor so that we can calculate how many bytes are
/* 164 */ // written later.
/* 165 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 166 */
/* 167 */ if (serializefromobject_value instanceof UnsafeArrayData) {
/* 168 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 169 */ // grow the global buffer before writing data.
/* 170 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 171 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 172 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 173 */
/* 174 */ } else {
/* 175 */ final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 176 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 177 */
/* 178 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 179 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 180 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 181 */ } else {
/* 182 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 183 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 184 */ }
/* 185 */ }
/* 186 */ }
/* 187 */
/* 188 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 189 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 190 */ append(serializefromobject_result);
/* 191 */ if (shouldStop()) return;
/* 192 */ }
/* 193 */ }
/* 194 */ }
```
After:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean CollectObjects_loopIsNull1;
/* 010 */ private int CollectObjects_loopValue0;
/* 011 */ private UnsafeRow deserializetoobject_result;
/* 012 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 013 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 014 */ private scala.collection.immutable.List mapelements_argValue;
/* 015 */ private UnsafeRow mapelements_result;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 018 */ private scala.collection.immutable.List serializefromobject_argValue;
/* 019 */ private UnsafeRow serializefromobject_result;
/* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 023 */
/* 024 */ public GeneratedIterator(Object[] references) {
/* 025 */ this.references = references;
/* 026 */ }
/* 027 */
/* 028 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 029 */ partitionIndex = index;
/* 030 */ this.inputs = inputs;
/* 031 */ inputadapter_input = inputs[0];
/* 032 */
/* 033 */ deserializetoobject_result = new UnsafeRow(1);
/* 034 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 035 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 036 */
/* 037 */ mapelements_result = new UnsafeRow(1);
/* 038 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 039 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 040 */
/* 041 */ serializefromobject_result = new UnsafeRow(1);
/* 042 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 043 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 044 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 045 */
/* 046 */ }
/* 047 */
/* 048 */ protected void processNext() throws java.io.IOException {
/* 049 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 050 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 051 */ ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 052 */
/* 053 */ scala.collection.immutable.List deserializetoobject_value = null;
/* 054 */
/* 055 */ if (!false) {
/* 056 */ int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 057 */ scala.collection.mutable.Builder CollectObjects_builderValue2 = scala.collection.immutable.List$.MODULE$.newBuilder();
/* 058 */ CollectObjects_builderValue2.sizeHint(deserializetoobject_dataLength);
/* 059 */
/* 060 */ int deserializetoobject_loopIndex = 0;
/* 061 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 062 */ CollectObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 063 */ CollectObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 064 */
/* 065 */ if (CollectObjects_loopIsNull1) {
/* 066 */ throw new RuntimeException(((java.lang.String) references[0]));
/* 067 */ }
/* 068 */ if (false) {
/* 069 */ CollectObjects_builderValue2.$plus$eq(null);
/* 070 */ } else {
/* 071 */ CollectObjects_builderValue2.$plus$eq(CollectObjects_loopValue0);
/* 072 */ }
/* 073 */
/* 074 */ deserializetoobject_loopIndex += 1;
/* 075 */ }
/* 076 */
/* 077 */ deserializetoobject_value = (scala.collection.immutable.List) CollectObjects_builderValue2.result();
/* 078 */ }
/* 079 */
/* 080 */ boolean mapelements_isNull = true;
/* 081 */ scala.collection.immutable.List mapelements_value = null;
/* 082 */ if (!false) {
/* 083 */ mapelements_argValue = deserializetoobject_value;
/* 084 */
/* 085 */ mapelements_isNull = false;
/* 086 */ if (!mapelements_isNull) {
/* 087 */ Object mapelements_funcResult = null;
/* 088 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 089 */ if (mapelements_funcResult == null) {
/* 090 */ mapelements_isNull = true;
/* 091 */ } else {
/* 092 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 093 */ }
/* 094 */
/* 095 */ }
/* 096 */ mapelements_isNull = mapelements_value == null;
/* 097 */ }
/* 098 */
/* 099 */ if (mapelements_isNull) {
/* 100 */ throw new RuntimeException(((java.lang.String) references[2]));
/* 101 */ }
/* 102 */ serializefromobject_argValue = mapelements_value;
/* 103 */
/* 104 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 105 */ serializefromobject_holder.reset();
/* 106 */
/* 107 */ // Remember the current cursor so that we can calculate how many bytes are
/* 108 */ // written later.
/* 109 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 110 */
/* 111 */ if (serializefromobject_value instanceof UnsafeArrayData) {
/* 112 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 113 */ // grow the global buffer before writing data.
/* 114 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 115 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 116 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 117 */
/* 118 */ } else {
/* 119 */ final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 120 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 121 */
/* 122 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 123 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 124 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 125 */ } else {
/* 126 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 127 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 128 */ }
/* 129 */ }
/* 130 */ }
/* 131 */
/* 132 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 133 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 134 */ append(serializefromobject_result);
/* 135 */ if (shouldStop()) return;
/* 136 */ }
/* 137 */ }
/* 138 */ }
```
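The `$plus$eq` calls on `CollectObjects_builderValue2` above are simply the JVM-level spelling of Scala's `+=`; the generated loop corresponds roughly to this plain-Scala sketch (illustrative names, not the actual generated code):
```scala
// One pass over the input, no intermediate array.
val input = Array(1, 2, 3)
val builder = scala.collection.immutable.List.newBuilder[Int]
builder.sizeHint(input.length)  // lets builders that can pre-allocate do so

var i = 0
while (i < input.length) {
  builder += input(i)  // appears as $plus$eq in the generated Java
  i += 1
}
val result: List[Int] = builder.result()  // List(1, 2, 3)
```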
Benchmark results before:
```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Seq 269 / 370 0.0 269125.8 1.0X
List 154 / 176 0.0 154453.5 1.7X
mutable.Queue 210 / 233 0.0 209691.6 1.3X
```
Benchmark results after:
```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Seq 255 / 316 0.0 254697.3 1.0X
List 152 / 177 0.0 152410.0 1.7X
mutable.Queue 213 / 235 0.0 213470.0 1.2X
```
## How was this patch tested?
```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```
Additionally, in the Spark shell:
```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])
spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```
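This round-trips a non-default sequence type (`scala.collection.immutable.Queue`) through the new builder-based deserialization path.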
Author: Michal Senkyr <mike.senkyr@gmail.com>
Closes #16541 from michalsenkyr/dataset-seq-builder.
3 files changed, 54 insertions, 69 deletions
```diff
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c4af284f73..1c7720afe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
           }
         }
 
-        val array = Invoke(
-          MapObjects(mapFunction, getPath, dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]))
-
-        val wrappedArray = StaticInvoke(
-          scala.collection.mutable.WrappedArray.getClass,
-          ObjectType(classOf[Seq[_]]),
-          "make",
-          array :: Nil)
-
-        if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
-          wrappedArray
-        } else {
-          // Convert to another type using `to`
-          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
-          import scala.collection.generic.CanBuildFrom
-          import scala.reflect.ClassTag
-
-          // Some canBuildFrom methods take an implicit ClassTag parameter
-          val cbfParams = try {
-            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
-            StaticInvoke(
-              ClassTag.getClass,
-              ObjectType(classOf[ClassTag[_]]),
-              "apply",
-              StaticInvoke(
-                cls,
-                ObjectType(classOf[Class[_]]),
-                "getClass"
-              ) :: Nil
-            ) :: Nil
-          } catch {
-            case _: NoSuchMethodException => Nil
-          }
-
-          Invoke(
-            wrappedArray,
-            "to",
-            ObjectType(cls),
-            StaticInvoke(
-              cls,
-              ObjectType(classOf[CanBuildFrom[_, _, _]]),
-              "canBuildFrom",
-              cbfParams
-            ) :: Nil
-          )
+        val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
+          case NoSymbol => classOf[Seq[_]]
+          case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
+        MapObjects(mapFunction, getPath, dataType, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 771ac28e51..bb584f7d08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
+import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
 
@@ -429,24 +430,34 @@ object MapObjects {
    * @param function The function applied on the collection elements.
    * @param inputData An expression that when evaluated returns a collection object.
    * @param elementType The data type of elements in the collection.
+   * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+   *                            or None (returning ArrayType)
    */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
-      elementType: DataType): MapObjects = {
-    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
-    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+      elementType: DataType,
+      customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    val id = curId.getAndIncrement()
+    val loopValue = s"MapObjects_loopValue$id"
+    val loopIsNull = s"MapObjects_loopIsNull$id"
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+    val builderValue = s"MapObjects_builderValue$id"
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+      customCollectionCls, builderValue)
   }
 }
 
 /**
  * Applies the given expression to every element of a collection of items, returning the result
- * as an ArrayType. This is similar to a typical map operation, but where the lambda function
- * is expressed using catalyst expressions.
+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
+ * function is expressed using catalyst expressions.
+ *
+ * The type of the result is determined as follows:
+ * - ArrayType - when customCollectionCls is None
+ * - ObjectType(collection) - when customCollectionCls contains a collection class
  *
- * The following collection ObjectTypes are currently supported:
+ * The following collection ObjectTypes are currently supported on input:
  *   Seq, Array, ArrayData, java.util.List
  *
  * @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -458,13 +469,19 @@ object MapObjects {
  * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
  *                       to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ *                            or None (returning ArrayType)
+ * @param builderValue The name of the builder variable used to construct the resulting collection
+ *                     (used only when returning ObjectType)
  */
 case class MapObjects private(
     loopValue: String,
     loopIsNull: String,
     loopVarDataType: DataType,
     lambdaFunction: Expression,
-    inputData: Expression) extends Expression with NonSQLExpression {
+    inputData: Expression,
+    customCollectionCls: Option[Class[_]],
+    builderValue: String) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = inputData.nullable
 
@@ -474,7 +491,8 @@ case class MapObjects private(
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def dataType: DataType =
-    ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
+    customCollectionCls.map(ObjectType.apply).getOrElse(
+      ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
@@ -557,15 +575,33 @@ case class MapObjects private(
       case _ => s"$loopIsNull = $loopValue == null;"
     }
 
+    val (initCollection, addElement, getResult): (String, String => String, String) =
+      customCollectionCls match {
+        case Some(cls) =>
+          // collection
+          val collObjectName = s"${cls.getName}$$.MODULE$$"
+          val getBuilderVar = s"$collObjectName.newBuilder()"
+
+          (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+            $builderValue.sizeHint($dataLength);""",
+            genValue => s"$builderValue.$$plus$$eq($genValue);",
+            s"(${cls.getName}) $builderValue.result();")
+        case None =>
+          // array
+          (s"""$convertedType[] $convertedArray = null;
+            $convertedArray = $arrayConstructor;""",
+            genValue => s"$convertedArray[$loopIndex] = $genValue;",
+            s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+      }
+
     val code = s"""
       ${genInputData.code}
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
 
       if (!${genInputData.isNull}) {
         $determineCollectionType
-        $convertedType[] $convertedArray = null;
         int $dataLength = $getLength;
-        $convertedArray = $arrayConstructor;
+        $initCollection
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
@@ -574,15 +610,15 @@ case class MapObjects private(
           ${genFunction.code}
           if (${genFunction.isNull}) {
-            $convertedArray[$loopIndex] = null;
+            ${addElement("null")}
           } else {
-            $convertedArray[$loopIndex] = $genFunctionValue;
+            ${addElement(genFunctionValue)}
          }
 
           $loopIndex += 1;
         }
 
-        ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+        ${ev.value} = $getResult
       }
     """
     ev.copy(code = code, isNull = genInputData.isNull)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 650a35398f..70ad064f93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
       ArrayType(IntegerType, containsNull = false))
     val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
-
-    // Check whether conversion is skipped when using WrappedArray[_] supertype
-    // (would otherwise needlessly add overhead)
-    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-    val seqDeserializer = deserializerFor[Seq[Int]]
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
-      scala.collection.mutable.WrappedArray.getClass)
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
   }
 
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
```
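For reference, the `newBuilder` probe that `ScalaReflection` now performs can be reproduced in isolation with runtime reflection; a minimal sketch (the helper name is hypothetical):
```scala
import scala.reflect.runtime.universe._

// True when the collection type's companion object declares `newBuilder`,
// i.e. when the deserializer can build the target collection directly
// instead of falling back to Seq.
def companionHasNewBuilder(tpe: Type): Boolean =
  tpe.dealias.companion.decl(TermName("newBuilder")) != NoSymbol

assert(companionHasNewBuilder(typeOf[List[Int]]))
assert(companionHasNewBuilder(typeOf[scala.collection.immutable.Queue[Int]]))
```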