aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/test')
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala 23
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala 195
2 files changed, 191 insertions, 27 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 1265908182..90790dda75 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
assert(array.numElements == values.length)
- assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
+ assert(array.getSizeInBytes ==
+ 8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length))
values.zipWithIndex.foreach {
case (value, index) => assert(array.getInt(index) == value)
}
@@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
testArrayInt(map.keyArray, keys)
testArrayInt(map.valueArray, values)
- assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+ assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
test("basic conversion with array type") {
@@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedArray = unsafeArray2.getArray(0)
testArrayInt(nestedArray, Seq(3, 4))
- assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+ assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes)
val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
@@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedMap = valueArray.getMap(0)
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
- assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes))
}
- assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
@@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getLong(0) == 2L)
}
- assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getSizeInBytes == 8 + 8)
assert(innerStruct.getLong(0) == 4L)
- assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
}
- assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
+ assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes))
val field2 = unsafeRow.getMap(1)
assert(field2.numElements == 1)
@@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerArray = valueArray.getArray(0)
testArrayInt(innerArray, Seq(4))
- assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes)
}
- assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 1685276ff1..f0e247bf46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -18,27 +18,190 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class UnsafeArraySuite extends SparkFunSuite {
- test("from primitive int array") {
- val array = Array(1, 10, 100)
- val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
- assert(unsafe.numElements == 3)
- assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
- assert(unsafe.getInt(0) == 1)
- assert(unsafe.getInt(1) == 10)
- assert(unsafe.getInt(2) == 100)
+ // Fixture arrays shared by the tests in this suite, one per element type.
+ val booleanArray = Array[Boolean](false, true)
+ val shortArray = Array[Short](1, 10, 100)
+ val intArray = Array[Int](1, 10, 100)
+ val longArray = Array(1L, 10L, 100L)
+ val floatArray = Array(1.1, 2.2, 3.3).map(_.toFloat)
+ val doubleArray = Array[Double](1.1, 2.2, 3.3)
+ val stringArray = Array[String]("1", "10", "100")
+ // Date and timestamp fixtures are parsed from string literals through DateTimeUtils.
+ val dateArray = Array("1970-1-1", "2016-7-26").map { text =>
+   DateTimeUtils.stringToDate(UTF8String.fromString(text)).get
+ }
+ val timestampArray = Array("1970-1-1 00:00:00", "2016-7-26 00:00:00").map { text =>
+   DateTimeUtils.stringToTimestamp(UTF8String.fromString(text)).get
+ }
+ // Decimal fixtures: each literal is rescaled with FLOOR rounding (scale 1 and scale 21).
+ val decimalArray4_1 = Array("123.4", "567.8").map { text =>
+   BigDecimal(text).setScale(1, BigDecimal.RoundingMode.FLOOR)
+ }
+ val decimalArray20_20 = Array(
+   "1.2345678901234567890123456",
+   "2.3456789012345678901234567"
+ ).map(BigDecimal(_).setScale(21, BigDecimal.RoundingMode.FLOOR))
+
+ val calenderintervalArray = Array(
+   new CalendarInterval(3, 321),
+   new CalendarInterval(1, 123))
+
+ // Nested fixtures whose inner arrays have different lengths (1, 2, 3 and 2, 3, 4).
+ val intMultiDimArray = Array(
+   Array(1),
+   Array(2, 20),
+   Array(3, 30, 300))
+ val doubleMultiDimArray = Array(
+   Array(1.1, 11.1),
+   Array(2.2, 22.2, 222.2),
+   Array(3.3, 33.3, 333.3, 3333.3))
+
+ test("read array") {
+   // Every section encodes one fixture array into a row, takes column 0 as array data,
+   // and checks the concrete class, the element count, and each element via its getter.
+   val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
+     toRow(booleanArray).getArray(0)
+   assert(unsafeBoolean.isInstanceOf[UnsafeArrayData])
+   assert(unsafeBoolean.numElements == booleanArray.length)
+   booleanArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeBoolean.getBoolean(i) == e)
+   }
+
+   val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind().
+     toRow(shortArray).getArray(0)
+   assert(unsafeShort.isInstanceOf[UnsafeArrayData])
+   assert(unsafeShort.numElements == shortArray.length)
+   shortArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeShort.getShort(i) == e)
+   }
+
+   val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind().
+     toRow(intArray).getArray(0)
+   assert(unsafeInt.isInstanceOf[UnsafeArrayData])
+   assert(unsafeInt.numElements == intArray.length)
+   intArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeInt.getInt(i) == e)
+   }
+
+   val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind().
+     toRow(longArray).getArray(0)
+   assert(unsafeLong.isInstanceOf[UnsafeArrayData])
+   assert(unsafeLong.numElements == longArray.length)
+   longArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeLong.getLong(i) == e)
+   }
+
+   val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind().
+     toRow(floatArray).getArray(0)
+   assert(unsafeFloat.isInstanceOf[UnsafeArrayData])
+   assert(unsafeFloat.numElements == floatArray.length)
+   floatArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeFloat.getFloat(i) == e)
+   }
+
+   val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind().
+     toRow(doubleArray).getArray(0)
+   assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
+   assert(unsafeDouble.numElements == doubleArray.length)
+   doubleArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeDouble.getDouble(i) == e)
+   }
+
+   val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind().
+     toRow(stringArray).getArray(0)
+   assert(unsafeString.isInstanceOf[UnsafeArrayData])
+   assert(unsafeString.numElements == stringArray.length)
+   stringArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeString.getUTF8String(i).toString().equals(e))
+   }
+
+   // Dates are encoded with the Array[Int] encoder and read back via get(i, DateType).
+   val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind().
+     toRow(dateArray).getArray(0)
+   assert(unsafeDate.isInstanceOf[UnsafeArrayData])
+   assert(unsafeDate.numElements == dateArray.length)
+   dateArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeDate.get(i, DateType) == e)
+   }
+
+   // Timestamps are encoded with the Array[Long] encoder and read back via
+   // get(i, TimestampType).
+   val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind().
+     toRow(timestampArray).getArray(0)
+   assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData])
+   assert(unsafeTimestamp.numElements == timestampArray.length)
+   timestampArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeTimestamp.get(i, TimestampType) == e)
+   }
+
+   // Decimal arrays go through RowEncoder with an explicit schema; the precision and
+   // scale of the DecimalType are taken from the first element of each fixture.
+   Seq(decimalArray4_1, decimalArray20_20).foreach { decimalArray =>
+     val decimal = decimalArray(0)
+     val schema = new StructType().add(
+       "array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
+     val encoder = RowEncoder(schema).resolveAndBind()
+     val externalRow = Row(decimalArray)
+     val ir = encoder.toRow(externalRow)
+
+     val unsafeDecimal = ir.getArray(0)
+     assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
+     assert(unsafeDecimal.numElements == decimalArray.length)
+     decimalArray.zipWithIndex.foreach { case (e, i) =>
+       assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e)
+     }
+   }
+
+   // Calendar intervals also use RowEncoder with an explicit schema.
+   val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
+   val encoder = RowEncoder(schema).resolveAndBind()
+   val externalRow = Row(calenderintervalArray)
+   val ir = encoder.toRow(externalRow)
+   val unsafeCalendar = ir.getArray(0)
+   assert(unsafeCalendar.isInstanceOf[UnsafeArrayData])
+   assert(unsafeCalendar.numElements == calenderintervalArray.length)
+   calenderintervalArray.zipWithIndex.foreach { case (e, i) =>
+     assert(unsafeCalendar.getInterval(i) == e)
+   }
+
+   // Nested arrays: check the outer array, then each inner array element by element.
+   val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind().
+     toRow(intMultiDimArray).getArray(0)
+   assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData])
+   assert(unsafeMultiDimInt.numElements == intMultiDimArray.length)
+   intMultiDimArray.zipWithIndex.foreach { case (a, j) =>
+     val u = unsafeMultiDimInt.getArray(j)
+     assert(u.isInstanceOf[UnsafeArrayData])
+     assert(u.numElements == a.length)
+     a.zipWithIndex.foreach { case (e, i) =>
+       assert(u.getInt(i) == e)
+     }
+   }
+
+   val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind().
+     toRow(doubleMultiDimArray).getArray(0)
+   // Check the nested array built just above, not the flat `unsafeDouble` from earlier.
+   assert(unsafeMultiDimDouble.isInstanceOf[UnsafeArrayData])
+   assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length)
+   doubleMultiDimArray.zipWithIndex.foreach { case (a, j) =>
+     val u = unsafeMultiDimDouble.getArray(j)
+     assert(u.isInstanceOf[UnsafeArrayData])
+     assert(u.numElements == a.length)
+     a.zipWithIndex.foreach { case (e, i) =>
+       assert(u.getDouble(i) == e)
+     }
+   }
}
- test("from primitive double array") {
- val array = Array(1.1, 2.2, 3.3)
- val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
- assert(unsafe.numElements == 3)
- assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
- assert(unsafe.getDouble(0) == 1.1)
- assert(unsafe.getDouble(1) == 2.2)
- assert(unsafe.getDouble(2) == 3.3)
+ test("from primitive array") {
+   // Expected size in bytes for 3 elements of `elementBytes` each: 8 bytes, plus 8 bytes
+   // per started group of 64 elements, plus the element data, rounded up to a multiple
+   // of 8. NOTE(review): presumably this mirrors the UnsafeArrayData layout - confirm.
+   def expectedSize(elementBytes: Int): Int = {
+     val unrounded = 8 + scala.math.ceil(3 / 64.toDouble) * 8 + elementBytes * 3 + 7
+     (unrounded.toInt / 8) * 8
+   }
+
+   val ints = UnsafeArrayData.fromPrimitiveArray(intArray)
+   assert(ints.numElements == 3)
+   assert(ints.getSizeInBytes == expectedSize(4))
+   for ((expected, idx) <- intArray.zipWithIndex) {
+     assert(ints.getInt(idx) == expected)
+   }
+
+   val doubles = UnsafeArrayData.fromPrimitiveArray(doubleArray)
+   assert(doubles.numElements == 3)
+   assert(doubles.getSizeInBytes == expectedSize(8))
+   for ((expected, idx) <- doubleArray.zipWithIndex) {
+     assert(doubles.getDouble(idx) == expected)
+   }
+ }
+
+ test("to primitive array") {
+   // Encode each fixture, read column 0 back as a primitive array, and compare it
+   // element-wise with the original fixture.
+   val intRoundTrip = ExpressionEncoder[Array[Int]].resolveAndBind()
+     .toRow(intArray)
+     .getArray(0)
+     .toIntArray
+   assert(intRoundTrip.sameElements(intArray))
+
+   val doubleRoundTrip = ExpressionEncoder[Array[Double]].resolveAndBind()
+     .toRow(doubleArray)
+     .getArray(0)
+     .toDoubleArray
+   assert(doubleRoundTrip.sameElements(doubleArray))
}
}