-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              | 16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala          | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala               | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala  | 33
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala     | 26
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                              | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala | 82
7 files changed, 203 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 31c6e5def1..7bcaea7ea2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -441,6 +441,22 @@ object ScalaReflection extends ScalaReflection {
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, dt)
+ case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType) =>
+ val cls = input.dataType.asInstanceOf[ObjectType].cls
+ if (cls.isArray && cls.getComponentType.isPrimitive) {
+ StaticInvoke(
+ classOf[UnsafeArrayData],
+ ArrayType(dt, false),
+ "fromPrimitiveArray",
+ input :: Nil)
+ } else {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(dt, schemaFor(elementType).nullable))
+ }
+
case dt =>
NewInstance(
classOf[GenericArrayData],
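
Net effect of the ScalaReflection change above: when the element type is primitive and the incoming object is a Java primitive array (e.g. Array[Int]), the serializer now copies it in one step via UnsafeArrayData.fromPrimitiveArray instead of building a GenericArrayData. A minimal sketch of the observable behaviour, assuming an existing SparkSession named spark with its implicits imported (illustrative only, not part of this patch):

    import spark.implicits._

    // Array[Int] now serializes through UnsafeArrayData.fromPrimitiveArray,
    // avoiding per-element boxing into a GenericArrayData.
    val ds = Seq(Array(1, 2, 3)).toDS()
    ds.map(a => a).collect()   // Array(Array(1, 2, 3))
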
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 2a6fcd03a2..e95e97b9dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -119,18 +119,19 @@ object RowEncoder {
"fromString",
inputObject :: Nil)
- case t @ ArrayType(et, _) => et match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
- // TODO: validate input type for primitive array.
- NewInstance(
- classOf[GenericArrayData],
- inputObject :: Nil,
- dataType = t)
- case _ => MapObjects(
- element => serializerFor(ValidateExternalType(element, et), et),
- inputObject,
- ObjectType(classOf[Object]))
- }
+ case t @ ArrayType(et, cn) =>
+ et match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ StaticInvoke(
+ classOf[ArrayData],
+ t,
+ "toArrayData",
+ inputObject :: Nil)
+ case _ => MapObjects(
+ element => serializerFor(ValidateExternalType(element, et), et),
+ inputObject,
+ ObjectType(classOf[Object]))
+ }
case t @ MapType(kt, vt, valueNullable) =>
val keys =
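
With this RowEncoder change, arrays whose elements are non-nullable primitives are handed to the new ArrayData.toArrayData helper (added in the next file) via StaticInvoke, instead of being wrapped in GenericArrayData. A short round-trip sketch in the spirit of the RowEncoderSuite test further down, using the same APIs that patch's tests use (illustrative only):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("ints", ArrayType(IntegerType, containsNull = false))
    val encoder = RowEncoder(schema).resolveAndBind()
    // toRow dispatches to ArrayData.toArrayData, i.e. UnsafeArrayData.fromPrimitiveArray for Array[Int]
    val internalRow = encoder.toRow(Row(Array(1, 2, 3)))
    encoder.fromRow(internalRow).getSeq[Int](0)   // elements 1, 2, 3
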
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index cad4a08b0d..140e86d670 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
import org.apache.spark.sql.types.DataType
+object ArrayData {
+ def toArrayData(input: Any): ArrayData = input match {
+ case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
+ case other => new GenericArrayData(other)
+ }
+}
+
abstract class ArrayData extends SpecializedGetters with Serializable {
def numElements(): Int
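
ArrayData.toArrayData is the single dispatch point behind the StaticInvoke calls above: each of the seven primitive array types goes through UnsafeArrayData.fromPrimitiveArray, and anything else falls back to GenericArrayData. A quick illustrative use (not part of the patch):

    import org.apache.spark.sql.catalyst.util.ArrayData

    ArrayData.toArrayData(Array(1, 2, 3))        // UnsafeArrayData via fromPrimitiveArray, no per-element boxing
    ArrayData.toArrayData(Array("a", "b", "c"))  // falls through to new GenericArrayData(...)
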
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 03bb102c67..f3702ec92b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
class CatalystTypeConvertersSuite extends SparkFunSuite {
@@ -61,4 +63,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
test("option handling in createToCatalystConverter") {
assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
}
+
+ test("primitive array handling") {
+ val intArray = Array(1, 100, 10000)
+ val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
+ val intArrayType = ArrayType(IntegerType, false)
+ assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)
+
+ val doubleArray = Array(1.1, 111.1, 11111.1)
+ val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
+ val doubleArrayType = ArrayType(DoubleType, false)
+ assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
+ === doubleArray)
+ }
+
+ test("An array with null handling") {
+ val intArray = Array(1, null, 100, null, 10000)
+ val intGenericArray = new GenericArrayData(intArray)
+ val intArrayType = ArrayType(IntegerType, true)
+ assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
+ === intArray)
+ assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
+ == intGenericArray)
+
+ val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
+ val doubleGenericArray = new GenericArrayData(doubleArray)
+ val doubleArrayType = ArrayType(DoubleType, true)
+ assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
+ === doubleArray)
+ assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
+ == doubleGenericArray)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 2e513ea22c..1a5569a77d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}
+ test("RowEncoder should support primitive arrays") {
+ val schema = new StructType()
+ .add("booleanPrimitiveArray", ArrayType(BooleanType, false))
+ .add("bytePrimitiveArray", ArrayType(ByteType, false))
+ .add("shortPrimitiveArray", ArrayType(ShortType, false))
+ .add("intPrimitiveArray", ArrayType(IntegerType, false))
+ .add("longPrimitiveArray", ArrayType(LongType, false))
+ .add("floatPrimitiveArray", ArrayType(FloatType, false))
+ .add("doublePrimitiveArray", ArrayType(DoubleType, false))
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val input = Seq(
+ Array(true, false),
+ Array(1.toByte, 64.toByte, Byte.MaxValue),
+ Array(1.toShort, 255.toShort, Short.MaxValue),
+ Array(1, 10000, Int.MaxValue),
+ Array(1.toLong, 1000000.toLong, Long.MaxValue),
+ Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
+ Array(11.1111, 123456.7890123, Double.MaxValue)
+ )
+ val row = encoder.toRow(Row.fromSeq(input))
+ val convertedBack = encoder.fromRow(row)
+ input.zipWithIndex.map { case (array, index) =>
+ assert(convertedBack.getSeq(index) === array)
+ }
+ }
+
test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType()
.add("array", ArrayType(IntegerType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a8dd422aa0..81fa8cbf22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1033,6 +1033,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
}
}
+
+ test("identity map for primitive arrays") {
+ val arrayByte = Array(1.toByte, 2.toByte, 3.toByte)
+ val arrayInt = Array(1, 2, 3)
+ val arrayLong = Array(1.toLong, 2.toLong, 3.toLong)
+ val arrayDouble = Array(1.1, 2.2, 3.3)
+ val arrayString = Array("a", "b", "c")
+ val dsByte = sparkContext.parallelize(Seq(arrayByte), 1).toDS.map(e => e)
+ val dsInt = sparkContext.parallelize(Seq(arrayInt), 1).toDS.map(e => e)
+ val dsLong = sparkContext.parallelize(Seq(arrayLong), 1).toDS.map(e => e)
+ val dsDouble = sparkContext.parallelize(Seq(arrayDouble), 1).toDS.map(e => e)
+ val dsString = sparkContext.parallelize(Seq(arrayString), 1).toDS.map(e => e)
+ checkDataset(dsByte, arrayByte)
+ checkDataset(dsInt, arrayInt)
+ checkDataset(dsLong, arrayLong)
+ checkDataset(dsDouble, arrayDouble)
+ checkDataset(dsString, arrayString)
+ }
}
case class Generic[T](id: T, value: Double)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
new file mode 100644
index 0000000000..e7c8f2717f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset programs that write primitive arrays.
+ * To run this:
+ * 1. replace ignore(...) with test(...)
+ * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark"
+ *
+ * Benchmarks in this file are skipped in normal builds.
+ */
+class PrimitiveArrayBenchmark extends BenchmarkBase {
+
+ def writeDatasetArray(iters: Int): Unit = {
+ import sparkSession.implicits._
+
+ val count = 1024 * 1024 * 2
+
+ val sc = sparkSession.sparkContext
+ val primitiveIntArray = Array.fill[Int](count)(65535)
+ val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS
+ dsInt.count // force to build dataset
+ val intArray = { i: Int =>
+ var n = 0
+ var len = 0
+ while (n < iters) {
+ len += dsInt.map(e => e).queryExecution.toRdd.collect.length
+ n += 1
+ }
+ }
+ val primitiveDoubleArray = Array.fill[Double](count)(65535.0)
+ val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS
+ dsDouble.count // force to build dataset
+ val doubleArray = { i: Int =>
+ var n = 0
+ var len = 0
+ while (n < iters) {
+ len += dsDouble.map(e => e).queryExecution.toRdd.collect.length
+ n += 1
+ }
+ }
+
+ val benchmark = new Benchmark("Write an array in Dataset", count * iters)
+ benchmark.addCase("Int ")(intArray)
+ benchmark.addCase("Double")(doubleArray)
+ benchmark.run
+ /*
+ OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
+ Intel Xeon E3-12xx v2 (Ivy Bridge)
+ Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ Int 352 / 401 23.8 42.0 1.0X
+ Double 821 / 885 10.2 97.9 0.4X
+ */
+ }
+
+ ignore("Write an array in Dataset") {
+ writeDatasetArray(4)
+ }
+}