-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala                 224
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala   309
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala              72
-rw-r--r--  python/pyspark/java_gateway.py                                                   1
-rw-r--r--  python/pyspark/ml/base.py                                                        2
-rw-r--r--  python/pyspark/ml/classification.py                                              2
-rw-r--r--  python/pyspark/ml/clustering.py                                                  2
-rw-r--r--  python/pyspark/ml/common.py                                                    137
-rw-r--r--  python/pyspark/ml/evaluation.py                                                  2
-rwxr-xr-x  python/pyspark/ml/feature.py                                                     2
-rw-r--r--  python/pyspark/ml/pipeline.py                                                    2
-rw-r--r--  python/pyspark/ml/recommendation.py                                              2
-rw-r--r--  python/pyspark/ml/regression.py                                                  2
-rwxr-xr-x  python/pyspark/ml/tests.py                                                      10
-rw-r--r--  python/pyspark/ml/tuning.py                                                      2
-rw-r--r--  python/pyspark/ml/util.py                                                        2
-rw-r--r--  python/pyspark/ml/wrapper.py                                                     2
17 files changed, 518 insertions, 257 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
new file mode 100644
index 0000000000..1279c901c5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util.{ArrayList => JArrayList}
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle._
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.ml.linalg._
+import org.apache.spark.mllib.api.python.SerDeBase
+import org.apache.spark.rdd.RDD
+
+/**
+ * SerDe utility functions for pyspark.ml.
+ */
+private[spark] object MLSerDe extends SerDeBase with Serializable {
+
+ override val PYSPARK_PACKAGE = "pyspark.ml"
+
+ // Pickler for DenseVector
+ private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val vector: DenseVector = obj.asInstanceOf[DenseVector]
+ val bytes = new Array[Byte](8 * vector.size)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE1)
+ }
+
+ def construct(args: Array[Object]): Object = {
+ require(args.length == 1)
+ if (args.length != 1) {
+ throw new PickleException("should be 1")
+ }
+ val bytes = getBytes(args(0))
+ val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](bytes.length / 8)
+ db.get(ans)
+ Vectors.dense(ans)
+ }
+ }
+
+ // Pickler for DenseMatrix
+ private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+ val bytes = new Array[Byte](8 * m.values.length)
+ val order = ByteOrder.nativeOrder()
+ val isTransposed = if (m.isTransposed) 1 else 0
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+ out.write(Opcodes.MARK)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(isTransposed))
+ out.write(Opcodes.TUPLE)
+ }
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 4) {
+ throw new PickleException("should be 4")
+ }
+ val bytes = getBytes(args(2))
+ val n = bytes.length / 8
+ val values = new Array[Double](n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+ val isTransposed = args(3).asInstanceOf[Int] == 1
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
+ }
+ }
+
+ // Pickler for SparseMatrix
+ private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val s = obj.asInstanceOf[SparseMatrix]
+ val order = ByteOrder.nativeOrder()
+
+ val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+ val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+ val valuesBytes = new Array[Byte](8 * s.values.length)
+ val isTransposed = if (s.isTransposed) 1 else 0
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+ out.write(Opcodes.MARK)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+ out.write(colPtrsBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+ out.write(indicesBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+ out.write(valuesBytes)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(isTransposed))
+ out.write(Opcodes.TUPLE)
+ }
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 6) {
+ throw new PickleException("should be 6")
+ }
+ val order = ByteOrder.nativeOrder()
+ val colPtrsBytes = getBytes(args(2))
+ val indicesBytes = getBytes(args(3))
+ val valuesBytes = getBytes(args(4))
+ val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+ val rowIndices = new Array[Int](indicesBytes.length / 4)
+ val values = new Array[Double](valuesBytes.length / 8)
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+ val isTransposed = args(5).asInstanceOf[Int] == 1
+ new SparseMatrix(
+ args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+ isTransposed)
+ }
+ }
+
+ // Pickler for SparseVector
+ private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val v: SparseVector = obj.asInstanceOf[SparseVector]
+ val n = v.indices.length
+ val indiceBytes = new Array[Byte](4 * n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+ val valueBytes = new Array[Byte](8 * n)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(v.size))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+ out.write(indiceBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+ out.write(valueBytes)
+ out.write(Opcodes.TUPLE3)
+ }
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 3) {
+ throw new PickleException("should be 3")
+ }
+ val size = args(0).asInstanceOf[Int]
+ val indiceBytes = getBytes(args(1))
+ val valueBytes = getBytes(args(2))
+ val n = indiceBytes.length / 4
+ val indices = new Array[Int](n)
+ val values = new Array[Double](n)
+ if (n > 0) {
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+ }
+ new SparseVector(size, indices, values)
+ }
+ }
+
+ var initialized = false
+ // This should be called before trying to serialize any above classes
+ // In cluster mode, this should be put in the closure
+ override def initialize(): Unit = {
+ SerDeUtil.initialize()
+ synchronized {
+ if (!initialized) {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ initialized = true
+ }
+ }
+ }
+ // will not called in Executor automatically
+ initialize()
+}
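
The new MLSerDe object above registers Pyrolite picklers for the org.apache.spark.ml.linalg vector and matrix classes under the pyspark.ml package, replacing the "New*Pickler" classes that previously lived inside the mllib SerDe object (removed in the next file). As a rough illustration of the round trip this enables, here is a minimal Python sketch mirroring the updated VectorTests in python/pyspark/ml/tests.py later in this diff; it assumes an already-running SparkContext named sc and is not part of the patch.

    from pyspark.ml.linalg import Vectors
    from pyspark.serializers import PickleSerializer

    ser = PickleSerializer()
    v = Vectors.sparse(4, [1, 3], [5.0, -2.0])

    # Python -> JVM: MLSerDe.loads unpickles the bytes into an ml.linalg.SparseVector
    jvec = sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))

    # JVM -> Python: MLSerDe.dumps pickles the JVM object back for Python to load
    nv = ser.loads(bytes(sc._jvm.MLSerDe.dumps(jvec)))
    assert v == nv
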
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e43469bf1c..7df61601fb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -30,7 +30,6 @@ import net.razorvine.pickle._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix, DenseVector => NewDenseVector, SparseMatrix => NewSparseMatrix, SparseVector => NewSparseVector, Vectors => NewVectors}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.evaluation.RankingMetrics
@@ -1205,23 +1204,21 @@ private[python] class PythonMLLibAPI extends Serializable {
}
/**
- * SerDe utility functions for PythonMLLibAPI.
+ * Basic SerDe utility class.
*/
-private[spark] object SerDe extends Serializable {
+private[spark] abstract class SerDeBase {
- val PYSPARK_PACKAGE = "pyspark.mllib"
- val PYSPARK_ML_PACKAGE = "pyspark.ml"
+ val PYSPARK_PACKAGE: String
+ def initialize(): Unit
/**
* Base class used for pickle
*/
- private[python] abstract class BasePickler[T: ClassTag]
+ private[spark] abstract class BasePickler[T: ClassTag]
extends IObjectPickler with IObjectConstructor {
- protected def packageName: String = PYSPARK_PACKAGE
-
private val cls = implicitly[ClassTag[T]].runtimeClass
- private val module = packageName + "." + cls.getName.split('.')(4)
+ private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
private val name = cls.getSimpleName
// register this to Pickler and Unpickler
@@ -1268,45 +1265,73 @@ private[spark] object SerDe extends Serializable {
private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
}
- // Pickler for (mllib) DenseVector
- private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+ def dumps(obj: AnyRef): Array[Byte] = {
+ obj match {
+ // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
+ case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
+ case _ => new Pickler().dumps(obj)
+ }
+ }
- def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
- val vector: DenseVector = obj.asInstanceOf[DenseVector]
- val bytes = new Array[Byte](8 * vector.size)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val db = bb.asDoubleBuffer()
- db.put(vector.values)
+ def loads(bytes: Array[Byte]): AnyRef = {
+ new Unpickler().loads(bytes)
+ }
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(bytes.length))
- out.write(bytes)
- out.write(Opcodes.TUPLE1)
+ /* convert object into Tuple */
+ def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
+ rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
+ }
+
+ /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+ rdd.map(x => Array(x._1, x._2))
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ initialize() // let it called in executor
+ new SerDeUtil.AutoBatchedPickler(iter)
}
+ }
- def construct(args: Array[Object]): Object = {
- require(args.length == 1)
- if (args.length != 1) {
- throw new PickleException("should be 1")
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize() // let it called in executor
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj match {
+ case list: JArrayList[_] => list.asScala
+ case arr: Array[_] => arr
+ }
+ } else {
+ Seq(obj)
+ }
}
- val bytes = getBytes(args(0))
- val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
- bb.order(ByteOrder.nativeOrder())
- val db = bb.asDoubleBuffer()
- val ans = new Array[Double](bytes.length / 8)
- db.get(ans)
- Vectors.dense(ans)
- }
+ }.toJavaRDD()
}
+}
- // Pickler for (new) DenseVector
- private[python] class NewDenseVectorPickler extends BasePickler[NewDenseVector] {
+/**
+ * SerDe utility functions for PythonMLLibAPI.
+ */
+private[spark] object SerDe extends SerDeBase with Serializable {
+
+ override val PYSPARK_PACKAGE = "pyspark.mllib"
- override protected def packageName = PYSPARK_ML_PACKAGE
+ // Pickler for DenseVector
+ private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
- val vector: NewDenseVector = obj.asInstanceOf[NewDenseVector]
+ val vector: DenseVector = obj.asInstanceOf[DenseVector]
val bytes = new Array[Byte](8 * vector.size)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
@@ -1330,11 +1355,11 @@ private[spark] object SerDe extends Serializable {
val db = bb.asDoubleBuffer()
val ans = new Array[Double](bytes.length / 8)
db.get(ans)
- NewVectors.dense(ans)
+ Vectors.dense(ans)
}
}
- // Pickler for (mllib) DenseMatrix
+ // Pickler for DenseMatrix
private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1371,46 +1396,7 @@ private[spark] object SerDe extends Serializable {
}
}
- // Pickler for (new) DenseMatrix
- private[python] class NewDenseMatrixPickler extends BasePickler[NewDenseMatrix] {
-
- override protected def packageName = PYSPARK_ML_PACKAGE
-
- def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
- val m: NewDenseMatrix = obj.asInstanceOf[NewDenseMatrix]
- val bytes = new Array[Byte](8 * m.values.length)
- val order = ByteOrder.nativeOrder()
- val isTransposed = if (m.isTransposed) 1 else 0
- ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
-
- out.write(Opcodes.MARK)
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(m.numRows))
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(m.numCols))
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(bytes.length))
- out.write(bytes)
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(isTransposed))
- out.write(Opcodes.TUPLE)
- }
-
- def construct(args: Array[Object]): Object = {
- if (args.length != 4) {
- throw new PickleException("should be 4")
- }
- val bytes = getBytes(args(2))
- val n = bytes.length / 8
- val values = new Array[Double](n)
- val order = ByteOrder.nativeOrder()
- ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
- val isTransposed = args(3).asInstanceOf[Int] == 1
- new NewDenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
- }
- }
-
- // Pickler for (mllib) SparseMatrix
+ // Pickler for SparseMatrix
private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1465,64 +1451,7 @@ private[spark] object SerDe extends Serializable {
}
}
- // Pickler for (new) SparseMatrix
- private[python] class NewSparseMatrixPickler extends BasePickler[NewSparseMatrix] {
-
- override protected def packageName = PYSPARK_ML_PACKAGE
-
- def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
- val s = obj.asInstanceOf[NewSparseMatrix]
- val order = ByteOrder.nativeOrder()
-
- val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
- val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
- val valuesBytes = new Array[Byte](8 * s.values.length)
- val isTransposed = if (s.isTransposed) 1 else 0
- ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
- ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
- ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
-
- out.write(Opcodes.MARK)
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(s.numRows))
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(s.numCols))
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
- out.write(colPtrsBytes)
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
- out.write(indicesBytes)
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
- out.write(valuesBytes)
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(isTransposed))
- out.write(Opcodes.TUPLE)
- }
-
- def construct(args: Array[Object]): Object = {
- if (args.length != 6) {
- throw new PickleException("should be 6")
- }
- val order = ByteOrder.nativeOrder()
- val colPtrsBytes = getBytes(args(2))
- val indicesBytes = getBytes(args(3))
- val valuesBytes = getBytes(args(4))
- val colPtrs = new Array[Int](colPtrsBytes.length / 4)
- val rowIndices = new Array[Int](indicesBytes.length / 4)
- val values = new Array[Double](valuesBytes.length / 8)
- ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
- ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
- ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
- val isTransposed = args(5).asInstanceOf[Int] == 1
- new NewSparseMatrix(
- args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
- isTransposed)
- }
- }
-
- // Pickler for (mllib) SparseVector
+ // Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1564,50 +1493,6 @@ private[spark] object SerDe extends Serializable {
}
}
- // Pickler for (new) SparseVector
- private[python] class NewSparseVectorPickler extends BasePickler[NewSparseVector] {
-
- override protected def packageName = PYSPARK_ML_PACKAGE
-
- def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
- val v: NewSparseVector = obj.asInstanceOf[NewSparseVector]
- val n = v.indices.length
- val indiceBytes = new Array[Byte](4 * n)
- val order = ByteOrder.nativeOrder()
- ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
- val valueBytes = new Array[Byte](8 * n)
- ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
-
- out.write(Opcodes.BININT)
- out.write(PickleUtils.integer_to_bytes(v.size))
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
- out.write(indiceBytes)
- out.write(Opcodes.BINSTRING)
- out.write(PickleUtils.integer_to_bytes(valueBytes.length))
- out.write(valueBytes)
- out.write(Opcodes.TUPLE3)
- }
-
- def construct(args: Array[Object]): Object = {
- if (args.length != 3) {
- throw new PickleException("should be 3")
- }
- val size = args(0).asInstanceOf[Int]
- val indiceBytes = getBytes(args(1))
- val valueBytes = getBytes(args(2))
- val n = indiceBytes.length / 4
- val indices = new Array[Int](n)
- val values = new Array[Double](n)
- if (n > 0) {
- val order = ByteOrder.nativeOrder()
- ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
- ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
- }
- new NewSparseVector(size, indices, values)
- }
- }
-
// Pickler for MLlib LabeledPoint
private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
@@ -1654,7 +1539,7 @@ private[spark] object SerDe extends Serializable {
var initialized = false
// This should be called before trying to serialize any above classes
// In cluster mode, this should be put in the closure
- def initialize(): Unit = {
+ override def initialize(): Unit = {
SerDeUtil.initialize()
synchronized {
if (!initialized) {
@@ -1662,10 +1547,6 @@ private[spark] object SerDe extends Serializable {
new DenseMatrixPickler().register()
new SparseMatrixPickler().register()
new SparseVectorPickler().register()
- new NewDenseVectorPickler().register()
- new NewDenseMatrixPickler().register()
- new NewSparseMatrixPickler().register()
- new NewSparseVectorPickler().register()
new LabeledPointPickler().register()
new RatingPickler().register()
initialized = true
@@ -1674,58 +1555,4 @@ private[spark] object SerDe extends Serializable {
}
// will not called in Executor automatically
initialize()
-
- def dumps(obj: AnyRef): Array[Byte] = {
- obj match {
- // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
- case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
- case _ => new Pickler().dumps(obj)
- }
- }
-
- def loads(bytes: Array[Byte]): AnyRef = {
- new Unpickler().loads(bytes)
- }
-
- /* convert object into Tuple */
- def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
- rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
- }
-
- /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
- def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
- rdd.map(x => Array(x._1, x._2))
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter =>
- initialize() // let it called in executor
- new SerDeUtil.AutoBatchedPickler(iter)
- }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- initialize() // let it called in executor
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj match {
- case list: JArrayList[_] => list.asScala
- case arr: Array[_] => arr
- }
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
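
The change above extracts the package-independent plumbing (BasePickler, dumps/loads, javaToPython/pythonToJava, the tuple helpers) into an abstract SerDeBase, so the mllib SerDe object and the new ml MLSerDe differ only in their PYSPARK_PACKAGE and the picklers they register. From Python, each package now talks to its own JVM entry point; a short illustrative sketch (not part of the patch, assuming a running SparkContext sc):

    from pyspark.serializers import PickleSerializer
    from pyspark.mllib.linalg import Vectors as MLlibVectors
    from pyspark.ml.linalg import Vectors as MLVectors

    ser = PickleSerializer()

    # old-style mllib vectors still go through SerDe ...
    j_old = sc._jvm.SerDe.loads(bytearray(ser.dumps(MLlibVectors.dense([1.0, 2.0]))))
    # ... while the new ml vectors go through MLSerDe
    j_new = sc._jvm.MLSerDe.loads(bytearray(ser.dumps(MLVectors.dense([1.0, 2.0]))))
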
diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
new file mode 100644
index 0000000000..5eaef9aabd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors}
+
+class MLSerDeSuite extends SparkFunSuite {
+
+ MLSerDe.initialize()
+
+ test("pickle vector") {
+ val vectors = Seq(
+ Vectors.dense(Array.empty[Double]),
+ Vectors.dense(0.0),
+ Vectors.dense(0.0, -2.0),
+ Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
+ Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
+ Vectors.sparse(2, Array(1), Array(-2.0)))
+ vectors.foreach { v =>
+ val u = MLSerDe.loads(MLSerDe.dumps(v))
+ assert(u.getClass === v.getClass)
+ assert(u === v)
+ }
+ }
+
+ test("pickle double") {
+ for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
+ val deser = MLSerDe.loads(MLSerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
+ // We use `equals` here for comparison because we cannot use `==` for NaN
+ assert(x.equals(deser))
+ }
+ }
+
+ test("pickle matrix") {
+ val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
+ val matrix = Matrices.dense(2, 3, values)
+ val nm = MLSerDe.loads(MLSerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
+ assert(matrix === nm)
+
+ // Test conversion for empty matrix
+ val empty = Array[Double]()
+ val emptyMatrix = Matrices.dense(0, 0, empty)
+ val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
+ assert(emptyMatrix == ne)
+
+ val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+ val nsm = MLSerDe.loads(MLSerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+ assert(sm.toArray === nsm.toArray)
+
+ val smt = new SparseMatrix(
+ 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+ isTransposed = true)
+ val nsmt = MLSerDe.loads(MLSerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+ assert(smt.toArray === nsmt.toArray)
+ }
+}
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index cd4c55f79f..527ca82d31 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -116,6 +116,7 @@ def launch_gateway():
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a7a58e17a4..339e5d6af5 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -19,7 +19,7 @@ from abc import ABCMeta, abstractmethod
from pyspark import since
from pyspark.ml.param import Params
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
@inherit_doc
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 77badebeb4..121b9262dd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,7 +26,7 @@ from pyspark.ml.regression import (
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 92df19e804..75d9a0e8ca 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
'KMeans', 'KMeansModel',
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
new file mode 100644
index 0000000000..256e91e141
--- /dev/null
+++ b/python/pyspark/ml/common.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= '3':
+ long = int
+ unicode = str
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
+
+# Hack for support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+ if isinstance(obj, float):
+ s = str(obj)
+ return _float_str_mapping.get(s, s)
+ return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+ 'SparseVector',
+ 'DenseVector',
+ 'DenseMatrix',
+]
+
+
+# this will call the ML version of pythonToJava()
+def _to_java_object_rdd(rdd):
+ """ Return an JavaRDD of Object by unpickling
+
+ It will convert each Python object into Java object by Pyrolite, whenever the
+ RDD is serialized in batch or not.
+ """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+ """ Convert Python object into Java """
+ if isinstance(obj, RDD):
+ obj = _to_java_object_rdd(obj)
+ elif isinstance(obj, DataFrame):
+ obj = obj._jdf
+ elif isinstance(obj, SparkContext):
+ obj = obj._jsc
+ elif isinstance(obj, list):
+ obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+ elif isinstance(obj, JavaObject):
+ pass
+ elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
+ pass
+ else:
+ data = bytearray(PickleSerializer().dumps(obj))
+ obj = sc._jvm.MLSerDe.loads(data)
+ return obj
+
+
+def _java2py(sc, r, encoding="bytes"):
+ if isinstance(r, JavaObject):
+ clsName = r.getClass().getSimpleName()
+ # convert RDD into JavaRDD
+ if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ r = r.toJavaRDD()
+ clsName = 'JavaRDD'
+
+ if clsName == 'JavaRDD':
+ jrdd = sc._jvm.MLSerDe.javaToPython(r)
+ return RDD(jrdd, sc)
+
+ if clsName == 'Dataset':
+ return DataFrame(r, SQLContext.getOrCreate(sc))
+
+ if clsName in _picklable_classes:
+ r = sc._jvm.MLSerDe.dumps(r)
+ elif isinstance(r, (JavaArray, JavaList)):
+ try:
+ r = sc._jvm.MLSerDe.dumps(r)
+ except Py4JJavaError:
+ pass # not pickable
+
+ if isinstance(r, (bytearray, bytes)):
+ r = PickleSerializer().loads(bytes(r), encoding=encoding)
+ return r
+
+
+def callJavaFunc(sc, func, *args):
+ """ Call Java Function """
+ args = [_py2java(sc, a) for a in args]
+ return _java2py(sc, func(*args))
+
+
+def inherit_doc(cls):
+ """
+ A decorator that makes a class inherit documentation from its parents.
+ """
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
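
The new python/pyspark/ml/common.py above mirrors python/pyspark/mllib/common.py but routes all conversions through MLSerDe, so pyspark.ml no longer depends on pyspark.mllib for Java interop. A hedged usage sketch of the added helpers follows (this mirrors how pyspark/ml/wrapper.py calls them; it assumes a running SparkContext sc and is not part of the patch):

    from pyspark.ml.common import _py2java, _java2py
    from pyspark.ml.linalg import Vectors

    v = Vectors.dense([0.0, -2.0])
    jv = _py2java(sc, v)          # pickled, then MLSerDe.loads builds a JVM ml.linalg.DenseVector
    roundtrip = _java2py(sc, jv)  # MLSerDe.dumps pickles it back; unpickled to a Python DenseVector
    assert roundtrip == v
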
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index cd071f1b7c..1fe8772da7 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
'MulticlassClassificationEvaluator']
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ca77ac395d..a28764a752 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import _convert_to_vector
from pyspark.ml.param.shared import *
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['Binarizer',
'Bucketizer',
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 0777527134..a48f4bb2ad 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
@inherit_doc
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 1778bfe938..0a7096794d 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['ALS', 'ALSModel']
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 7c79ab73c7..db31993f0f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4358175a57..981ed9dda0 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -61,7 +61,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
GeneralizedLinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import _java2py
+from pyspark.ml.common import _java2py
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
def _test_serialize(self, v):
self.assertEqual(v, ser.loads(ser.dumps(v)))
- jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
- nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+ jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
+ nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
self.assertEqual(v, nv)
vs = [v] * 100
- jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
- nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+ jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
+ nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
self.assertEqual(vs, nvs)
def test_serialize(self):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index fe87b6cdb9..f857c5e8c8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -25,7 +25,7 @@ from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
-from pyspark.mllib.common import inherit_doc, _py2java
+from pyspark.ml.common import inherit_doc, _py2java
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9d28823196..4a31a29809 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -23,7 +23,7 @@ if sys.version > '3':
unicode = str
from pyspark import SparkContext, since
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
def _jvm():
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index fef0040faf..25c44b7533 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -22,7 +22,7 @@ from pyspark.sql import DataFrame
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.param import Params
from pyspark.ml.util import _jvm
-from pyspark.mllib.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
class JavaWrapper(object):