-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala  241
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala  15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala  147
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala  114
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala  3
9 files changed, 544 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e7088edced..24dc80b1a7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -85,6 +85,14 @@ public final class UnsafeRow extends MutableRow {
})));
}
+ public static boolean isFixedLength(DataType dt) {
+ if (dt instanceof DecimalType) {
+ return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS();
+ } else {
+ return settableFieldTypes.contains(dt);
+ }
+ }
+
//////////////////////////////////////////////////////////////////////////////
// Private fields and methods
//////////////////////////////////////////////////////////////////////////////
@@ -144,6 +152,17 @@ public final class UnsafeRow extends MutableRow {
this.sizeInBytes = sizeInBytes;
}
+ /**
+ * Update this UnsafeRow to point to the underlying byte array.
+ *
+ * @param buf byte array to point to
+ * @param numFields the number of fields in this row
+ * @param sizeInBytes the number of bytes valid in the byte array
+ */
+ public void pointTo(byte[] buf, int numFields, int sizeInBytes) {
+ pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
+ }
+
private void assertIndexIsValid(int index) {
assert index >= 0 : "index (" + index + ") should >= 0";
assert index < numFields : "index (" + index + ") should < " + numFields;
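
A minimal usage sketch of the two new UnsafeRow helpers added above (illustrative values only; the row sizing follows the createUnsafeRow helper in the bitset suite further down). isFixedLength treats decimals whose precision is below Decimal.MAX_LONG_DIGITS and the settable primitive types as fixed-length, and the new pointTo overload assumes the data starts at PlatformDependent.BYTE_ARRAY_OFFSET:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // One 8-byte word per field plus one 8-byte bitset word per 64 fields.
    val numFields = 3
    val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
    val row = new UnsafeRow
    // New convenience overload: base offset defaults to PlatformDependent.BYTE_ARRAY_OFFSET.
    row.pointTo(new Array[Byte](sizeInBytes), numFields, sizeInBytes)

    assert(UnsafeRow.isFixedLength(IntegerType))   // settable primitive type => fixed-length
    assert(!UnsafeRow.isFixedLength(StringType))   // strings are variable-length
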
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e50ec27fc2..36f4e9c6be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.types._
@@ -293,6 +294,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
evaluator.setDefaultImports(Array(
+ classOf[PlatformDependent].getName,
classOf[InternalRow].getName,
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1d223986d9..6c99086046 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -266,16 +266,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val code = s"""
public Object generate($exprType[] exprs) {
- return new SpecificProjection(exprs);
+ return new SpecificUnsafeProjection(exprs);
}
- class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+ class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {
private $exprType[] expressions;
${declareMutableStates(ctx)}
- public SpecificProjection($exprType[] expressions) {
+ public SpecificUnsafeProjection($exprType[] expressions) {
this.expressions = expressions;
${initMutableStates(ctx)}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
new file mode 100644
index 0000000000..645eb48d5a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.PlatformDependent
+
+
+abstract class UnsafeRowJoiner {
+ def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow
+}
+
+
+/**
+ * A code generator for concatenating two [[UnsafeRow]]s into a single [[UnsafeRow]].
+ *
+ * The high level algorithm is:
+ *
+ * 1. Concatenate the two bitsets together into a single one, taking padding into account.
+ * 2. Move fixed-length data.
+ * 3. Move variable-length data.
+ * 4. Update the offset position (i.e. the upper 32 bits in the fixed length part) for all
+ * variable-length data.
+ */
+object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
+
+ def dump(word: Long): String = {
+ Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
+ }
+
+ override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
+ create(in._1, in._2)
+ }
+
+ override protected def canonicalize(in: (StructType, StructType)): (StructType, StructType) = in
+
+ override protected def bind(in: (StructType, StructType), inputSchema: Seq[Attribute])
+ : (StructType, StructType) = {
+ in
+ }
+
+ def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
+ val ctx = newCodeGenContext()
+ val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+
+ val bitset1Words = (schema1.size + 63) / 64
+ val bitset2Words = (schema2.size + 63) / 64
+ val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
+ val bitset1Remainder = schema1.size % 64
+ val bitset2Remainder = schema2.size % 64
+
+ // The number of words we can reduce when we concat two rows together.
+ // The only reduction comes from merging the bitset portion of the two rows, saving 1 word.
+ val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
+
+ // --------------------- copy bitset from row 1 ----------------------- //
+ val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
+ | PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
+ """.stripMargin
+ }.mkString
+
+
+ // --------------------- copy bitset from row 2 ----------------------- //
+ var copyBitset2 = ""
+ if (bitset1Remainder == 0) {
+ copyBitset2 += Seq.tabulate(bitset2Words) { i =>
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8},
+ | PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
+ """.stripMargin
+ }.mkString
+ } else {
+ copyBitset2 = Seq.tabulate(bitset2Words) { i =>
+ s"""
+ |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8});
+ |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1);
+ |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
+ """.stripMargin
+ }.mkString
+
+ copyBitset2 += Seq.tabulate(bitset2Words) { i =>
+ val currentOffset = offset + (bitset1Words + i - 1) * 8
+ if (i == 0) {
+ if (bitset1Words > 0) {
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
+ | bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset));
+ """.stripMargin
+ } else {
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1);
+ """.stripMargin
+ }
+ } else {
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2);
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ if (bitset2Words > 0 &&
+ (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) {
+ val lastWord = bitset2Words - 1
+ copyBitset2 +=
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8},
+ | bs2w${lastWord}p2);
+ """.stripMargin
+ }
+ }
+
+ // --------------------- copy fixed length portion from row 1 ----------------------- //
+ var cursor = offset + outputBitsetWords * 8
+ val copyFixedLengthRow1 = s"""
+ |// Copy fixed length data for row1
+ |PlatformDependent.copyMemory(
+ | obj1, offset1 + ${bitset1Words * 8},
+ | buf, $cursor,
+ | ${schema1.size * 8});
+ """.stripMargin
+ cursor += schema1.size * 8
+
+ // --------------------- copy fixed length portion from row 2 ----------------------- //
+ val copyFixedLengthRow2 = s"""
+ |// Copy fixed length data for row2
+ |PlatformDependent.copyMemory(
+ | obj2, offset2 + ${bitset2Words * 8},
+ | buf, $cursor,
+ | ${schema2.size * 8});
+ """.stripMargin
+ cursor += schema2.size * 8
+
+ // --------------------- copy variable length portion from row 1 ----------------------- //
+ val copyVariableLengthRow1 = s"""
+ |// Copy variable length data for row1
+ |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
+ |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
+ |PlatformDependent.copyMemory(
+ | obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
+ | buf, $cursor,
+ | numBytesVariableRow1);
+ """.stripMargin
+
+ // --------------------- copy variable length portion from row 2 ----------------------- //
+ val copyVariableLengthRow2 = s"""
+ |// Copy variable length data for row2
+ |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
+ |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
+ |PlatformDependent.copyMemory(
+ | obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
+ | buf, $cursor + numBytesVariableRow1,
+ | numBytesVariableRow2);
+ """.stripMargin
+
+ // ------------- update fixed length data for variable length data type --------------- //
+ val updateOffset = (schema1 ++ schema2).zipWithIndex.map { case (field, i) =>
+ // Skip fixed length data types, and only generate code for variable length data
+ if (UnsafeRow.isFixedLength(field.dataType)) {
+ ""
+ } else {
+ // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
+ // offset in the upper 32 bit of the words, we can just shift the offset to the left by
+ // 32 and increment that amount in place.
+ val shift =
+ if (i < schema1.size) {
+ s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
+ } else {
+ s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1"
+ }
+ val cursor = offset + outputBitsetWords * 8 + i * 8
+ s"""
+ |PlatformDependent.UNSAFE.putLong(buf, $cursor,
+ | PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
+ """.stripMargin
+ }
+ }.mkString
+
+ // ------------------------ Finally, put everything together --------------------------- //
+ val code = s"""
+ |public Object generate($exprType[] exprs) {
+ | return new SpecificUnsafeRowJoiner();
+ |}
+ |
+ |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} {
+ | private byte[] buf = new byte[64];
+ | private UnsafeRow out = new UnsafeRow();
+ |
+ | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
+ | // row1: ${schema1.size} fields, $bitset1Words words in bitset
+ | // row2: ${schema2.size} fields, $bitset2Words words in bitset
+ | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset
+ | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes();
+ | if (sizeInBytes > buf.length) {
+ | buf = new byte[sizeInBytes];
+ | }
+ |
+ | final Object obj1 = row1.getBaseObject();
+ | final long offset1 = row1.getBaseOffset();
+ | final Object obj2 = row2.getBaseObject();
+ | final long offset2 = row2.getBaseOffset();
+ |
+ | $copyBitset1
+ | $copyBitset2
+ | $copyFixedLengthRow1
+ | $copyFixedLengthRow2
+ | $copyVariableLengthRow1
+ | $copyVariableLengthRow2
+ | $updateOffset
+ |
+ | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction);
+ |
+ | return out;
+ | }
+ |}
+ """.stripMargin
+
+ logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
+ // println(CodeFormatter.format(code))
+
+ val c = compile(code)
+ c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
+ }
+}
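
For reference, a small standalone sketch of the word-count arithmetic the generator above relies on: each row stores one 8-byte bitset word per 64 fields, so concatenating two rows can save at most one word, which is the sizeReduction subtracted from the output row size.

    // Illustrative only; the same expressions appear in create() above.
    def bitsetWords(numFields: Int): Int = (numFields + 63) / 64

    val (numFields1, numFields2) = (67, 5)                        // first bitset larger than a word
    val bitset1Words = bitsetWords(numFields1)                    // 2 words
    val bitset2Words = bitsetWords(numFields2)                    // 1 word
    val outputBitsetWords = bitsetWords(numFields1 + numFields2)  // 2 words for 72 fields
    val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords  // 1 word saved
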
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 75ae29d690..81267dc915 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -66,6 +66,19 @@ object RandomDataGenerator {
}
/**
+ * Returns a randomly generated schema, based on the given accepted types.
+ *
+ * @param numFields the number of fields in this schema
+ * @param acceptedTypes types to draw from.
+ */
+ def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = {
+ StructType(Seq.tabulate(numFields) { i =>
+ val dt = acceptedTypes(Random.nextInt(acceptedTypes.size))
+ StructField("col_" + i, dt, nullable = true)
+ })
+ }
+
+ /**
* Returns a function which generates random values for the given [[DataType]], or `None` if no
* random data generator is defined for that data type. The generated values will use an external
* representation of the data type; for example, the random generator for [[DateType]] will return
@@ -94,7 +107,7 @@ object RandomDataGenerator {
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
case DecimalType.Fixed(precision, scale) => Some(
- () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision)))
+ () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new MathContext(precision)))
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
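
A quick sketch of how the new randomSchema helper is meant to be called (the joiner suite below uses it the same way); the col_<i> field names come from the implementation above:

    import org.apache.spark.sql.RandomDataGenerator
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    // Four nullable fields named col_0 .. col_3, each type drawn at random from the candidates.
    val schema: StructType = RandomDataGenerator.randomSchema(4, Seq(IntegerType, StringType))
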
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
new file mode 100644
index 0000000000..76d9d991ed
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite for the bitset portion of the row concatenation.
+ */
+class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
+
+ test("bitset concat: boundary size 0, 0") {
+ testBitsets(0, 0)
+ }
+
+ test("bitset concat: boundary size 0, 64") {
+ testBitsets(0, 64)
+ }
+
+ test("bitset concat: boundary size 64, 0") {
+ testBitsets(64, 0)
+ }
+
+ test("bitset concat: boundary size 64, 64") {
+ testBitsets(64, 64)
+ }
+
+ test("bitset concat: boundary size 0, 128") {
+ testBitsets(0, 128)
+ }
+
+ test("bitset concat: boundary size 128, 0") {
+ testBitsets(128, 0)
+ }
+
+ test("bitset concat: boundary size 128, 128") {
+ testBitsets(128, 128)
+ }
+
+ test("bitset concat: single word bitsets") {
+ testBitsets(10, 5)
+ }
+
+ test("bitset concat: first bitset larger than a word") {
+ testBitsets(67, 5)
+ }
+
+ test("bitset concat: second bitset larger than a word") {
+ testBitsets(6, 67)
+ }
+
+ test("bitset concat: no reduction in bitset size") {
+ testBitsets(33, 34)
+ }
+
+ test("bitset concat: two words") {
+ testBitsets(120, 95)
+ }
+
+ test("bitset concat: bitset 65, 128") {
+ testBitsets(65, 128)
+ }
+
+ test("bitset concat: randomized tests") {
+ for (i <- 1 until 20) {
+ val numFields1 = Random.nextInt(1000)
+ val numFields2 = Random.nextInt(1000)
+ testBitsetsOnce(numFields1, numFields2)
+ }
+ }
+
+ private def createUnsafeRow(numFields: Int): UnsafeRow = {
+ val row = new UnsafeRow
+ val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
+ val buf = new Array[Byte](sizeInBytes)
+ row.pointTo(buf, numFields, sizeInBytes)
+ row
+ }
+
+ private def testBitsets(numFields1: Int, numFields2: Int): Unit = {
+ for (i <- 0 until 5) {
+ testBitsetsOnce(numFields1, numFields2)
+ }
+ }
+
+ private def testBitsetsOnce(numFields1: Int, numFields2: Int): Unit = {
+ info(s"num fields: $numFields1 and $numFields2")
+ val schema1 = StructType(Seq.tabulate(numFields1) { i => StructField(s"a_$i", IntegerType) })
+ val schema2 = StructType(Seq.tabulate(numFields2) { i => StructField(s"b_$i", IntegerType) })
+
+ val row1 = createUnsafeRow(numFields1)
+ val row2 = createUnsafeRow(numFields2)
+
+ if (numFields1 > 0) {
+ for (i <- 0 until Random.nextInt(numFields1)) {
+ row1.setNullAt(Random.nextInt(numFields1))
+ }
+ }
+ if (numFields2 > 0) {
+ for (i <- 0 until Random.nextInt(numFields2)) {
+ row2.setNullAt(Random.nextInt(numFields2))
+ }
+ }
+
+ val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+ val output = concater.join(row1, row2)
+
+ def dumpDebug(): String = {
+ val set1 = Seq.tabulate(numFields1) { i => if (row1.isNullAt(i)) "1" else "0" }
+ val set2 = Seq.tabulate(numFields2) { i => if (row2.isNullAt(i)) "1" else "0" }
+ val out = Seq.tabulate(numFields1 + numFields2) { i => if (output.isNullAt(i)) "1" else "0" }
+
+ s"""
+ |input1: ${set1.mkString}
+ |input2: ${set2.mkString}
+ |output: ${out.mkString}
+ """.stripMargin
+ }
+
+ for (i <- 0 until (numFields1 + numFields2)) {
+ if (i < numFields1) {
+ assert(output.isNullAt(i) === row1.isNullAt(i), dumpDebug())
+ } else {
+ assert(output.isNullAt(i) === row2.isNullAt(i - numFields1), dumpDebug())
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
new file mode 100644
index 0000000000..59729e7646
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for [[GenerateUnsafeRowJoiner]].
+ *
+ * There is also a separate [[GenerateUnsafeRowJoinerBitsetSuite]] that specifically tests
+ * concatenation of the bitset portion, since that is the hardest part to get right.
+ */
+class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
+
+ private val fixed = Seq(IntegerType)
+ private val variable = Seq(IntegerType, StringType)
+
+ test("simple fixed width types") {
+ testConcat(0, 0, fixed)
+ testConcat(0, 1, fixed)
+ testConcat(1, 0, fixed)
+ testConcat(64, 0, fixed)
+ testConcat(0, 64, fixed)
+ testConcat(64, 64, fixed)
+ }
+
+ test("randomized fix width types") {
+ for (i <- 0 until 20) {
+ testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
+ }
+ }
+
+ test("simple variable width types") {
+ testConcat(0, 0, variable)
+ testConcat(0, 1, variable)
+ testConcat(1, 0, variable)
+ testConcat(64, 0, variable)
+ testConcat(0, 64, variable)
+ testConcat(64, 64, variable)
+ }
+
+ test("randomized variable width types") {
+ for (i <- 0 until 10) {
+ testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable)
+ }
+ }
+
+ private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
+ for (i <- 0 until 10) {
+ testConcatOnce(numFields1, numFields2, candidateTypes)
+ }
+ }
+
+ private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) {
+ info(s"schema size $numFields1, $numFields2")
+ val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
+ val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)
+
+ // Create the converters needed to convert from external row to internal row and to UnsafeRows.
+ val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
+ val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2)
+ val converter1 = UnsafeProjection.create(schema1)
+ val converter2 = UnsafeProjection.create(schema2)
+
+ // Create the input rows, convert them into UnsafeRows.
+ val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply()
+ val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
+ val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
+ val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+
+ // Run the joiner.
+ val mergedSchema = StructType(schema1 ++ schema2)
+ val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+ val output = concater.join(row1, row2)
+
+ // Test everything equals ...
+ for (i <- mergedSchema.indices) {
+ if (i < schema1.size) {
+ assert(output.isNullAt(i) === row1.isNullAt(i))
+ if (!output.isNullAt(i)) {
+ assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+ }
+ } else {
+ assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
+ if (!output.isNullAt(i)) {
+ assert(output.get(i, mergedSchema(i).dataType) ===
+ row2.get(i - schema1.size, mergedSchema(i).dataType))
+ }
+ }
+ }
+ }
+
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 66012e3c94..08a98cdd94 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -73,12 +73,7 @@ public final class UnsafeFixedWidthAggregationMap {
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
- if (field.dataType() instanceof DecimalType) {
- DecimalType dt = (DecimalType) field.dataType();
- if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
- return false;
- }
- } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ if (!UnsafeRow.isFixedLength(field.dataType())) {
return false;
}
}
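
With the new helper, the buffer-schema check reduces to a per-field fixed-length test; a Scala sketch of the equivalent predicate:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.types.StructType

    // Sketch of the Java loop above: every field's type must be fixed-length.
    def supportsAggregationBufferSchema(schema: StructType): Boolean =
      schema.fields.forall(f => UnsafeRow.isFixedLength(f.dataType))
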
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 450963547c..b3f821e0cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
+/**
+ * A test suite that generates randomized data to test the [[TungstenSort]] operator.
+ */
class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {