author      Wenchen Fan <wenchen@databricks.com>    2016-04-29 23:04:51 -0700
committer   Xiangrui Meng <meng@databricks.com>     2016-04-29 23:04:51 -0700
commit      43b149fb885a27f9467aab28e5195f6f03aadcf0 (patch)
tree        c8620d5d0f42e9f3238020e3bce8f8ea527182eb /sql/catalyst/src
parent      4bac703eb9dcc286d6b89630cf433f95b63a4a1f (diff)
download    spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.tar.gz
            spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.tar.bz2
            spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.zip
[SPARK-14850][ML] convert primitive array from/to unsafe array directly in VectorUDT/MatrixUDT
## What changes were proposed in this pull request?

This PR adds `fromPrimitiveArray` and `toPrimitiveArray` in `UnsafeArrayData`, so that we can do the conversion much faster in VectorUDT/MatrixUDT.

## How was this patch tested?

Existing tests and the new test suite `UnsafeArraySuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12640 from cloud-fan/ml.
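For context on how the new helper is meant to be used, here is a minimal sketch in the spirit of the VectorUDT/MatrixUDT serialization path (not the actual UDT code from this PR), contrasting the boxed `GenericArrayData` route with the direct `UnsafeArrayData.fromPrimitiveArray` route added below. The `serializeValues*` helper names are made up for illustration; only `fromPrimitiveArray` comes from this patch.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

object VectorSerializationSketch {
  // Boxed route: every primitive double becomes a java.lang.Double inside an Object array.
  def serializeValuesBoxed(values: Array[Double]): GenericArrayData =
    new GenericArrayData(values.map(_.asInstanceOf[Any]))

  // Direct route enabled by this patch: one Platform.copyMemory call, no per-element boxing.
  def serializeValuesUnsafe(values: Array[Double]): UnsafeArrayData =
    UnsafeArrayData.fromPrimitiveArray(values)

  def main(args: Array[String]): Unit = {
    val unsafe = serializeValuesUnsafe(Array(1.1, 2.2, 3.3))
    // 4-byte header + a 4-byte offset per element + an 8-byte value per element.
    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
    assert(unsafe.getDouble(2) == 3.3)
  }
}
```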
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java | 64
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java   |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala     | 44
3 files changed, 107 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 648625b2cc..02a863b2bb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -47,7 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String;
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
*/
// todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
-public class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData {
private Object baseObject;
private long baseOffset;
@@ -81,7 +81,7 @@ public class UnsafeArrayData extends ArrayData {
}
public Object[] array() {
- throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+ throw new UnsupportedOperationException("Not supported on UnsafeArrayData.");
}
/**
@@ -336,4 +336,64 @@ public class UnsafeArrayData extends ArrayData {
arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
return arrayCopy;
}
+
+ public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+ if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
+ throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+ "it's too big.");
+ }
+
+ final int offsetRegionSize = 4 * arr.length;
+ final int valueRegionSize = 4 * arr.length;
+ final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+ final byte[] data = new byte[totalSize];
+
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+ int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+ int valueOffset = 4 + offsetRegionSize;
+ for (int i = 0; i < arr.length; i++) {
+ Platform.putInt(data, offsetPosition, valueOffset);
+ offsetPosition += 4;
+ valueOffset += 4;
+ }
+
+ Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
+ Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+ UnsafeArrayData result = new UnsafeArrayData();
+ result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+ return result;
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+ if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
+ throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+ "it's too big.");
+ }
+
+ final int offsetRegionSize = 4 * arr.length;
+ final int valueRegionSize = 8 * arr.length;
+ final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+ final byte[] data = new byte[totalSize];
+
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+ int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+ int valueOffset = 4 + offsetRegionSize;
+ for (int i = 0; i < arr.length; i++) {
+ Platform.putInt(data, offsetPosition, valueOffset);
+ offsetPosition += 4;
+ valueOffset += 8;
+ }
+
+ Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
+ Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+ UnsafeArrayData result = new UnsafeArrayData();
+ result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+ return result;
+ }
+
+ // TODO: add more specialized methods.
}
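As an illustration of the byte layout the new `fromPrimitiveArray` methods emit, the sketch below builds the same format by hand for a three-element int array and points an `UnsafeArrayData` at it: a 4-byte element count, then a 4-byte offset per element (relative to the start of the array data), then the packed primitive values. This is a reader's sketch of the format in the hunk above, not code from the patch.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.unsafe.Platform

object UnsafeIntArrayLayoutSketch {
  def main(args: Array[String]): Unit = {
    val values = Array(7, 8, 9)
    val offsetRegionSize = 4 * values.length           // one 4-byte offset per element
    val totalSize = 4 + offsetRegionSize + 4 * values.length
    val data = new Array[Byte](totalSize)

    // [numElements]
    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, values.length)

    // [offset region]: each entry points past the header and the offset region.
    var pos = Platform.BYTE_ARRAY_OFFSET + 4
    var valueOffset = 4 + offsetRegionSize              // first value starts at byte 16
    for (_ <- values.indices) {
      Platform.putInt(data, pos, valueOffset)
      pos += 4
      valueOffset += 4
    }

    // [value region]: the primitive ints, copied verbatim.
    Platform.copyMemory(values, Platform.INT_ARRAY_OFFSET, data,
      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, 4 * values.length)

    val arrayData = new UnsafeArrayData
    arrayData.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize)
    assert(arrayData.numElements == 3)
    assert(arrayData.getInt(2) == 9)                    // reads back through the stored offsets
  }
}
```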
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 651eb1ff0c..0700148bec 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.Platform;
* [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*/
// TODO: Use a more efficient format which doesn't depend on unsafe array.
-public class UnsafeMapData extends MapData {
+public final class UnsafeMapData extends MapData {
private Object baseObject;
private long baseOffset;
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
new file mode 100644
index 0000000000..1685276ff1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+
+class UnsafeArraySuite extends SparkFunSuite {
+
+ test("from primitive int array") {
+ val array = Array(1, 10, 100)
+ val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+ assert(unsafe.numElements == 3)
+ assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
+ assert(unsafe.getInt(0) == 1)
+ assert(unsafe.getInt(1) == 10)
+ assert(unsafe.getInt(2) == 100)
+ }
+
+ test("from primitive double array") {
+ val array = Array(1.1, 2.2, 3.3)
+ val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+ assert(unsafe.numElements == 3)
+ assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
+ assert(unsafe.getDouble(0) == 1.1)
+ assert(unsafe.getDouble(1) == 2.2)
+ assert(unsafe.getDouble(2) == 3.3)
+ }
+}
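The PR's stated motivation is that the direct copy is much faster than boxing each element into a `GenericArrayData`; a rough, unofficial micro-benchmark along the following lines can be used to check that claim locally. The array size and loop counts are arbitrary and the numbers depend heavily on JIT warm-up, so treat the output as illustrative only.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

object PrimitiveArrayConversionBench {
  def time(label: String)(body: => Unit): Unit = {
    val start = System.nanoTime()
    body
    println(s"$label: ${(System.nanoTime() - start) / 1e6} ms")
  }

  def main(args: Array[String]): Unit = {
    val arr = Array.tabulate(1 << 20)(_.toDouble)
    // Warm up both paths so the JIT has compiled them before timing.
    (1 to 10).foreach { _ =>
      new GenericArrayData(arr.map(_.asInstanceOf[Any]))
      UnsafeArrayData.fromPrimitiveArray(arr)
    }
    time("boxed GenericArrayData") {
      (1 to 100).foreach(_ => new GenericArrayData(arr.map(_.asInstanceOf[Any])))
    }
    time("UnsafeArrayData.fromPrimitiveArray") {
      (1 to 100).foreach(_ => UnsafeArrayData.fromPrimitiveArray(arr))
    }
  }
}
```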