From 03377d2522776267a07b7d6ae9bddf79a4e0f516 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 31 Jul 2015 21:09:00 -0700
Subject: [SPARK-9358][SQL] Code generation for UnsafeRow joiner.

This patch creates a code-generated unsafe row concatenator that can be used to
concatenate/join two UnsafeRows into a single UnsafeRow.

Since it is inherently hard to test this low-level code, the test suites employ
randomized testing heavily in order to guarantee correctness.

Author: Reynold Xin

Closes #7821 from rxin/rowconcat and squashes the following commits:

8717f35 [Reynold Xin] Rebase and code review.
72c5d8e [Reynold Xin] Fixed a bug.
a84ed2e [Reynold Xin] Fixed offset.
40c3fb2 [Reynold Xin] Reset random data generator.
f0913aa [Reynold Xin] Test fixes.
6687b6f [Reynold Xin] Updated documentation.
00354b9 [Reynold Xin] Support concat data as well.
e9a4347 [Reynold Xin] Updated.
6269f96 [Reynold Xin] Fixed a bug.
0f89716 [Reynold Xin] [SPARK-9358][SQL][WIP] Code generation for UnsafeRow concat.
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  |  19 ++
 .../expressions/codegen/CodeGenerator.scala        |   2 +
 .../codegen/GenerateUnsafeProjection.scala         |   6 +-
 .../codegen/GenerateUnsafeRowJoiner.scala          | 241 +++++++++++++++++++++
 .../org/apache/spark/sql/RandomDataGenerator.scala |  15 +-
 .../GenerateUnsafeRowJoinerBitsetSuite.scala       | 147 +++++++++++++
 .../codegen/GenerateUnsafeRowJoinerSuite.scala     | 114 ++++++++++
 .../execution/UnsafeFixedWidthAggregationMap.java  |   7 +-
 .../spark/sql/execution/TungstenSortSuite.scala    |   3 +
 9 files changed, 544 insertions(+), 10 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e7088edced..24dc80b1a7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -85,6 +85,14 @@ public final class UnsafeRow extends MutableRow {
     })));
   }
 
+  public static boolean isFixedLength(DataType dt) {
+    if (dt instanceof DecimalType) {
+      return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS();
+    } else {
+      return settableFieldTypes.contains(dt);
+    }
+  }
+
   //////////////////////////////////////////////////////////////////////////////
   // Private fields and methods
   //////////////////////////////////////////////////////////////////////////////
@@ -144,6 +152,17 @@ public final class UnsafeRow extends MutableRow {
     this.sizeInBytes = sizeInBytes;
   }
 
+  /**
+   * Update this UnsafeRow to point to the underlying byte array.
+ * + * @param buf byte array to point to + * @param numFields the number of fields in this row + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int numFields, int sizeInBytes) { + pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + } + private void assertIndexIsValid(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < numFields : "index (" + index + ") should < " + numFields; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e50ec27fc2..36f4e9c6be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,6 +27,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.types._ @@ -293,6 +294,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) evaluator.setDefaultImports(Array( + classOf[PlatformDependent].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 1d223986d9..6c99086046 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -266,16 +266,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" public Object generate($exprType[] exprs) { - return new SpecificProjection(exprs); + return new SpecificUnsafeProjection(exprs); } - class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { private $exprType[] expressions; ${declareMutableStates(ctx)} - public SpecificProjection($exprType[] expressions) { + public SpecificUnsafeProjection($exprType[] expressions) { this.expressions = expressions; ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala new file mode 100644 index 0000000000..645eb48d5a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.PlatformDependent + + +abstract class UnsafeRowJoiner { + def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow +} + + +/** + * A code generator for concatenating two [[UnsafeRow]]s into a single [[UnsafeRow]]. + * + * The high level algorithm is: + * + * 1. Concatenate the two bitsets together into a single one, taking padding into account. + * 2. Move fixed-length data. + * 3. Move variable-length data. + * 4. Update the offset position (i.e. the upper 32 bits in the fixed length part) for all + * variable-length data. + */ +object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] { + + def dump(word: Long): String = { + Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString + } + + override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = { + create(in._1, in._2) + } + + override protected def canonicalize(in: (StructType, StructType)): (StructType, StructType) = in + + override protected def bind(in: (StructType, StructType), inputSchema: Seq[Attribute]) + : (StructType, StructType) = { + in + } + + def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { + val ctx = newCodeGenContext() + val offset = PlatformDependent.BYTE_ARRAY_OFFSET + + val bitset1Words = (schema1.size + 63) / 64 + val bitset2Words = (schema2.size + 63) / 64 + val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 + val bitset1Remainder = schema1.size % 64 + val bitset2Remainder = schema2.size % 64 + + // The number of words we can reduce when we concat two rows together. + // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. 
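To make the bitset word arithmetic above concrete before the generator code continues, here is a small standalone Scala sketch. It is an illustration only, not part of the patch, and the schema sizes used are hypothetical.

// Illustration only (not part of the patch): the bitset word arithmetic above,
// applied to a hypothetical pair of schema sizes.
object BitsetWordMathExample {
  private def words(numFields: Int): Int = (numFields + 63) / 64

  def main(args: Array[String]): Unit = {
    val fields1 = 70                                  // hypothetical size of schema1
    val fields2 = 5                                   // hypothetical size of schema2
    val bitset1Words = words(fields1)                 // 2 words cover 70 null bits
    val bitset2Words = words(fields2)                 // 1 word covers 5 null bits
    val outputBitsetWords = words(fields1 + fields2)  // 2 words cover all 75 null bits
    val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
    println(s"merging the bitsets saves $sizeReduction word(s)")  // prints 1
    // bitset1Remainder = 70 % 64 = 6, so row2's null bits must be shifted left by 6
    // and OR-ed into the last word of row1's bitset, which is what the generated
    // copyBitset2 code does with the bs2w*p1 / bs2w*p2 intermediates.
  }
}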
+ val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords + + // --------------------- copy bitset from row 1 ----------------------- // + val copyBitset1 = Seq.tabulate(bitset1Words) { i => + s""" + |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8}, + | PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8})); + """.stripMargin + }.mkString + + + // --------------------- copy bitset from row 2 ----------------------- // + var copyBitset2 = "" + if (bitset1Remainder == 0) { + copyBitset2 += Seq.tabulate(bitset2Words) { i => + s""" + |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8}, + | PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8})); + """.stripMargin + }.mkString + } else { + copyBitset2 = Seq.tabulate(bitset2Words) { i => + s""" + |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}); + |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1); + |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder}); + """.stripMargin + }.mkString + + copyBitset2 += Seq.tabulate(bitset2Words) { i => + val currentOffset = offset + (bitset1Words + i - 1) * 8 + if (i == 0) { + if (bitset1Words > 0) { + s""" + |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, + | bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset)); + """.stripMargin + } else { + s""" + |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1); + """.stripMargin + } + } else { + s""" + |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2); + """.stripMargin + } + }.mkString("\n") + + if (bitset2Words > 0 && + (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) { + val lastWord = bitset2Words - 1 + copyBitset2 += + s""" + |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8}, + | bs2w${lastWord}p2); + """.stripMargin + } + } + + // --------------------- copy fixed length portion from row 1 ----------------------- // + var cursor = offset + outputBitsetWords * 8 + val copyFixedLengthRow1 = s""" + |// Copy fixed length data for row1 + |PlatformDependent.copyMemory( + | obj1, offset1 + ${bitset1Words * 8}, + | buf, $cursor, + | ${schema1.size * 8}); + """.stripMargin + cursor += schema1.size * 8 + + // --------------------- copy fixed length portion from row 2 ----------------------- // + val copyFixedLengthRow2 = s""" + |// Copy fixed length data for row2 + |PlatformDependent.copyMemory( + | obj2, offset2 + ${bitset2Words * 8}, + | buf, $cursor, + | ${schema2.size * 8}); + """.stripMargin + cursor += schema2.size * 8 + + // --------------------- copy variable length portion from row 1 ----------------------- // + val copyVariableLengthRow1 = s""" + |// Copy variable length data for row1 + |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8}; + |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1; + |PlatformDependent.copyMemory( + | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, + | buf, $cursor, + | numBytesVariableRow1); + """.stripMargin + + // --------------------- copy variable length portion from row 2 ----------------------- // + val copyVariableLengthRow2 = s""" + |// Copy variable length data for row2 + |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8}; + |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2; + |PlatformDependent.copyMemory( + | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, + | buf, $cursor + 
numBytesVariableRow1, + | numBytesVariableRow2); + """.stripMargin + + // ------------- update fixed length data for variable length data type --------------- // + val updateOffset = (schema1 ++ schema2).zipWithIndex.map { case (field, i) => + // Skip fixed length data types, and only generate code for variable length data + if (UnsafeRow.isFixedLength(field.dataType)) { + "" + } else { + // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the + // offset in the upper 32 bit of the words, we can just shift the offset to the left by + // 32 and increment that amount in place. + val shift = + if (i < schema1.size) { + s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" + } else { + s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1" + } + val cursor = offset + outputBitsetWords * 8 + i * 8 + s""" + |PlatformDependent.UNSAFE.putLong(buf, $cursor, + | PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32)); + """.stripMargin + } + }.mkString + + // ------------------------ Finally, put everything together --------------------------- // + val code = s""" + |public Object generate($exprType[] exprs) { + | return new SpecificUnsafeRowJoiner(); + |} + | + |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} { + | private byte[] buf = new byte[64]; + | private UnsafeRow out = new UnsafeRow(); + | + | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { + | // row1: ${schema1.size} fields, $bitset1Words words in bitset + | // row2: ${schema2.size}, $bitset2Words words in bitset + | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset + | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes(); + | if (sizeInBytes > buf.length) { + | buf = new byte[sizeInBytes]; + | } + | + | final Object obj1 = row1.getBaseObject(); + | final long offset1 = row1.getBaseOffset(); + | final Object obj2 = row2.getBaseObject(); + | final long offset2 = row2.getBaseOffset(); + | + | $copyBitset1 + | $copyBitset2 + | $copyFixedLengthRow1 + | $copyFixedLengthRow2 + | $copyVariableLengthRow1 + | $copyVariableLengthRow2 + | $updateOffset + | + | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction); + | + | return out; + | } + |} + """.stripMargin + + logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") + // println(CodeFormatter.format(code)) + + val c = compile(code) + c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 75ae29d690..81267dc915 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -65,6 +65,19 @@ object RandomDataGenerator { Some(f) } + /** + * Returns a randomly generated schema, based on the given accepted types. + * + * @param numFields the number of fields in this schema + * @param acceptedTypes types to draw from. 
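As an aside on the offset-update step generated above: each variable-length field stores its offset in the upper 32 bits and its length in the lower 32 bits of its fixed-length slot, so the generated code relocates the value simply by adding a byte delta shifted left by 32. The standalone sketch below illustrates that arithmetic with made-up numbers; it is not part of the patch.

// Illustration only (not part of the patch): how adding (delta << 32) to a
// variable-length field's fixed-length slot moves its offset but not its length.
object OffsetUpdateExample {
  def main(args: Array[String]): Unit = {
    val oldOffset = 96L    // hypothetical byte offset of the value in the original row
    val length = 24L       // hypothetical length of the value in bytes
    val slot = (oldOffset << 32) | length
    val delta = 48L        // extra bytes now preceding the value in the joined row
    val updated = slot + (delta << 32)  // same in-place update the generated putLong performs
    assert((updated >>> 32) == oldOffset + delta)  // new offset is 144
    assert((updated & 0xFFFFFFFFL) == length)      // length stays 24
    println(s"offset ${updated >>> 32}, length ${updated & 0xFFFFFFFFL}")
  }
}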
+ */ + def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = { + StructType(Seq.tabulate(numFields) { i => + val dt = acceptedTypes(Random.nextInt(acceptedTypes.size)) + StructField("col_" + i, dt, nullable = true) + }) + } + /** * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. The generated values will use an external @@ -94,7 +107,7 @@ object RandomDataGenerator { case DateType => Some(() => new java.sql.Date(rand.nextInt())) case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) case DecimalType.Fixed(precision, scale) => Some( - () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision))) + () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new MathContext(precision))) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala new file mode 100644 index 0000000000..76d9d991ed --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ + +/** + * A test suite for the bitset portion of the row concatenation. 
+ */ +class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { + + test("bitset concat: boundary size 0, 0") { + testBitsets(0, 0) + } + + test("bitset concat: boundary size 0, 64") { + testBitsets(0, 64) + } + + test("bitset concat: boundary size 64, 0") { + testBitsets(64, 0) + } + + test("bitset concat: boundary size 64, 64") { + testBitsets(64, 64) + } + + test("bitset concat: boundary size 0, 128") { + testBitsets(0, 128) + } + + test("bitset concat: boundary size 128, 0") { + testBitsets(128, 0) + } + + test("bitset concat: boundary size 128, 128") { + testBitsets(128, 128) + } + + test("bitset concat: single word bitsets") { + testBitsets(10, 5) + } + + test("bitset concat: first bitset larger than a word") { + testBitsets(67, 5) + } + + test("bitset concat: second bitset larger than a word") { + testBitsets(6, 67) + } + + test("bitset concat: no reduction in bitset size") { + testBitsets(33, 34) + } + + test("bitset concat: two words") { + testBitsets(120, 95) + } + + test("bitset concat: bitset 65, 128") { + testBitsets(65, 128) + } + + test("bitset concat: randomized tests") { + for (i <- 1 until 20) { + val numFields1 = Random.nextInt(1000) + val numFields2 = Random.nextInt(1000) + testBitsetsOnce(numFields1, numFields2) + } + } + + private def createUnsafeRow(numFields: Int): UnsafeRow = { + val row = new UnsafeRow + val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 + val buf = new Array[Byte](sizeInBytes) + row.pointTo(buf, numFields, sizeInBytes) + row + } + + private def testBitsets(numFields1: Int, numFields2: Int): Unit = { + for (i <- 0 until 5) { + testBitsetsOnce(numFields1, numFields2) + } + } + + private def testBitsetsOnce(numFields1: Int, numFields2: Int): Unit = { + info(s"num fields: $numFields1 and $numFields2") + val schema1 = StructType(Seq.tabulate(numFields1) { i => StructField(s"a_$i", IntegerType) }) + val schema2 = StructType(Seq.tabulate(numFields2) { i => StructField(s"b_$i", IntegerType) }) + + val row1 = createUnsafeRow(numFields1) + val row2 = createUnsafeRow(numFields2) + + if (numFields1 > 0) { + for (i <- 0 until Random.nextInt(numFields1)) { + row1.setNullAt(Random.nextInt(numFields1)) + } + } + if (numFields2 > 0) { + for (i <- 0 until Random.nextInt(numFields2)) { + row2.setNullAt(Random.nextInt(numFields2)) + } + } + + val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) + val output = concater.join(row1, row2) + + def dumpDebug(): String = { + val set1 = Seq.tabulate(numFields1) { i => if (row1.isNullAt(i)) "1" else "0" } + val set2 = Seq.tabulate(numFields2) { i => if (row2.isNullAt(i)) "1" else "0" } + val out = Seq.tabulate(numFields1 + numFields2) { i => if (output.isNullAt(i)) "1" else "0" } + + s""" + |input1: ${set1.mkString} + |input2: ${set2.mkString} + |output: ${out.mkString} + """.stripMargin + } + + for (i <- 0 until (numFields1 + numFields2)) { + if (i < numFields1) { + assert(output.isNullAt(i) === row1.isNullAt(i), dumpDebug()) + } else { + assert(output.isNullAt(i) === row2.isNullAt(i - numFields1), dumpDebug()) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala new file mode 100644 index 0000000000..59729e7646 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.types._ + +/** + * Test suite for [[GenerateUnsafeRowJoiner]]. + * + * There is also a separate [[GenerateUnsafeRowJoinerBitsetSuite]] that tests specifically + * concatenation for the bitset portion, since that is the hardest one to get right. + */ +class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { + + private val fixed = Seq(IntegerType) + private val variable = Seq(IntegerType, StringType) + + test("simple fixed width types") { + testConcat(0, 0, fixed) + testConcat(0, 1, fixed) + testConcat(1, 0, fixed) + testConcat(64, 0, fixed) + testConcat(0, 64, fixed) + testConcat(64, 64, fixed) + } + + test("randomized fix width types") { + for (i <- 0 until 20) { + testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) + } + } + + test("simple variable width types") { + testConcat(0, 0, variable) + testConcat(0, 1, variable) + testConcat(1, 0, variable) + testConcat(64, 0, variable) + testConcat(0, 64, variable) + testConcat(64, 64, variable) + } + + test("randomized variable width types") { + for (i <- 0 until 10) { + testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable) + } + } + + private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = { + for (i <- 0 until 10) { + testConcatOnce(numFields1, numFields2, candidateTypes) + } + } + + private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { + info(s"schema size $numFields1, $numFields2") + val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) + val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) + + // Create the converters needed to convert from external row to internal row and to UnsafeRows. + val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) + val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2) + val converter1 = UnsafeProjection.create(schema1) + val converter2 = UnsafeProjection.create(schema2) + + // Create the input rows, convert them into UnsafeRows. + val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply() + val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() + val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) + val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + + // Run the joiner. 
+    val mergedSchema = StructType(schema1 ++ schema2)
+    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+    val output = concater.join(row1, row2)
+
+    // Test everything equals ...
+    for (i <- mergedSchema.indices) {
+      if (i < schema1.size) {
+        assert(output.isNullAt(i) === row1.isNullAt(i))
+        if (!output.isNullAt(i)) {
+          assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+        }
+      } else {
+        assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
+        if (!output.isNullAt(i)) {
+          assert(output.get(i, mergedSchema(i).dataType) ===
+            row2.get(i - schema1.size, mergedSchema(i).dataType))
+        }
+      }
+    }
+  }
+
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 66012e3c94..08a98cdd94 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -73,12 +73,7 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   public static boolean supportsAggregationBufferSchema(StructType schema) {
     for (StructField field: schema.fields()) {
-      if (field.dataType() instanceof DecimalType) {
-        DecimalType dt = (DecimalType) field.dataType();
-        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
-          return false;
-        }
-      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+      if (!UnsafeRow.isFixedLength(field.dataType())) {
         return false;
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 450963547c..b3f821e0cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types._
 
+/**
+ * A test suite that generates randomized data to test the [[TungstenSort]] operator.
+ */
 class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   override def beforeAll(): Unit = {
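For reference, a hypothetical end-to-end usage sketch of the API added by this patch; the object name, the concrete schemas, and the values are illustrative and not part of the change.

// Hypothetical usage sketch (not part of the patch): generate a joiner once for a
// pair of schemas, then use it to join two UnsafeRows into a single UnsafeRow.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.types._

object UnsafeRowJoinerExample {
  def main(args: Array[String]): Unit = {
    val schema1 = StructType(Seq(StructField("a", IntegerType), StructField("b", LongType)))
    val schema2 = StructType(Seq(StructField("c", IntegerType)))

    // Produce UnsafeRows with the existing projection code generator.
    val row1 = UnsafeProjection.create(schema1).apply(InternalRow(1, 2L))
    val row2 = UnsafeProjection.create(schema2).apply(InternalRow(3))

    // Generate and compile a joiner specialized for this pair of schemas, then join.
    val joiner = GenerateUnsafeRowJoiner.create(schema1, schema2)
    val joined = joiner.join(row1, row2)  // an UnsafeRow laid out as schema1 ++ schema2

    assert(joined.numFields == 3)
    assert(joined.getInt(0) == 1 && joined.getLong(1) == 2L && joined.getInt(2) == 3)
  }
}

Note that the generated joiner keeps a private buffer and a single output UnsafeRow that are reused across join calls, so callers that need to retain a joined row beyond the next call should copy it.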