Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala   9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala                   39
2 files changed, 30 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
index ecd3b16598..534edac56b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
@@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
predict(SerDe.asTupleRDD(userAndProducts.rdd))
def getUserFeatures: RDD[Array[Any]] = {
- SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ SerDe.fromTuple2RDD(userFeatures.map {
+ case (user, feature) => (user, Vectors.dense(feature))
+ }.asInstanceOf[RDD[(Any, Any)]])
}
def getProductFeatures: RDD[Array[Any]] = {
- SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ SerDe.fromTuple2RDD(productFeatures.map {
+ case (product, feature) => (product, Vectors.dense(feature))
+ }.asInstanceOf[RDD[(Any, Any)]])
}
}
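
For context on the wrapper change above: userFeatures and productFeatures are RDD[(Int, Array[Double])], and the raw factor arrays used to go straight to SerDe. Wrapping each array in Vectors.dense means the pickler ships DenseVector objects that unpickle as MLlib vectors on the Python side. A minimal sketch of the new mapping (the helper name is hypothetical):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.rdd.RDD

// Mirrors getUserFeatures above: each (id, Array[Double]) factor pair
// becomes (id, DenseVector) before the tuple RDD is handed to SerDe
// for pickling.
def denseUserFeatures(model: MatrixFactorizationModel): RDD[(Int, Vector)] =
  model.userFeatures.map { case (user, feature) =>
    (user, Vectors.dense(feature))
  }
```
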
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index ab15f0f36a..f976d2f97b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,7 +28,6 @@ import scala.reflect.ClassTag
import net.razorvine.pickle._
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
@@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
-import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
+import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
+import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.loss.Losses
-import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
+import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable {
data: JavaRDD[LabeledPoint],
lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)).
map(_.asInstanceOf[Object]).asJava
}
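
The NaiveBayes change above is the same idea applied row by row: model.theta is an Array[Array[Double]] of per-class log probabilities, and mapping Vectors.dense over it returns each row to Python as a DenseVector. A sketch, assuming a trained NaiveBayesModel:

```scala
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// theta holds one row of class-conditional log probabilities per
// label; wrapping each row gives the Python side one DenseVector
// per class instead of a bare double array.
def thetaAsVectors(model: NaiveBayesModel): Array[Vector] =
  model.theta.map(Vectors.dense)
```
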
@@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable {
mu += model.gaussians(i).mu
sigma += model.gaussians(i).sigma
}
- List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
} finally {
data.rdd.unpersist(blocking = false)
}
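
Likewise for the mixture weights in the hunk above: wt buffers one weight per Gaussian component, and Vectors.dense(wt.toArray) sends them back as a single DenseVector rather than a raw array. A tiny sketch with illustrative values only:

```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.mllib.linalg.Vectors

// Two mixture components with illustrative weights; toArray yields the
// Array[Double] that Vectors.dense packs into a single DenseVector.
val wt = ArrayBuffer(0.4, 0.6)
val weights = Vectors.dense(wt.toArray) // DenseVector [0.4, 0.6]
```
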
@@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable {
*/
def predictSoftGMM(
data: JavaRDD[Vector],
- wt: Object,
+ wt: Vector,
mu: Array[Object],
- si: Array[Object]): RDD[Array[Double]] = {
- val weight = wt.asInstanceOf[Array[Double]]
+ si: Array[Object]): RDD[Vector] = {
+ val weight = wt.toArray
val mean = mu.map(_.asInstanceOf[DenseVector])
val sigma = si.map(_.asInstanceOf[DenseMatrix])
val gaussians = Array.tabulate(weight.length){
i => new MultivariateGaussian(mean(i), sigma(i))
}
val model = new GaussianMixtureModel(weight, gaussians)
- model.predictSoft(data)
+ model.predictSoft(data).map(Vectors.dense)
}
-
+
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
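
Stepping back to the predictSoftGMM hunk above: the rework tightens both ends of the stub, so the weights arrive as a Vector (wt.toArray replaces the Array[Double] cast) and the per-point membership probabilities leave as RDD[Vector]. A sketch of the output conversion, assuming a trained GaussianMixtureModel:

```scala
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// predictSoft returns RDD[Array[Double]] -- one membership-probability
// array per input point -- so one Vectors.dense per row yields the
// RDD[Vector] the Python side now expects.
def softPredictions(model: GaussianMixtureModel, data: RDD[Vector]): RDD[Vector] =
  model.predictSoft(data).map(Vectors.dense)
```
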
@@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable {
out.write(code)
}
+ protected def getBytes(obj: Object): Array[Byte] = {
+ if (obj.getClass.isArray) {
+ obj.asInstanceOf[Array[Byte]]
+ } else {
+ obj.asInstanceOf[String].getBytes(LATIN1)
+ }
+ }
+
private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
}
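
The new getBytes helper above exists because a pickled byte payload can reach the JVM either as a raw Array[Byte] or as a Latin-1 string, depending on how the Python side pickled it; the old code assumed the string form only. A standalone sketch of the same normalization (LATIN1 here stands in for the charset constant the surrounding object already defines):

```scala
// Assumed stand-in for the object's existing charset constant.
val LATIN1 = "ISO-8859-1"

// Accept either representation of a pickled byte payload: arrays pass
// straight through, strings decode as Latin-1.
def getBytes(obj: Object): Array[Byte] =
  if (obj.getClass.isArray) obj.asInstanceOf[Array[Byte]]
  else obj.asInstanceOf[String].getBytes(LATIN1)
```
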
@@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable {
if (args.length != 1) {
throw new PickleException("should be 1")
}
- val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
+ val bytes = getBytes(args(0))
val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
bb.order(ByteOrder.nativeOrder())
val db = bb.asDoubleBuffer()
@@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val bytes = getBytes(args(2))
val n = bytes.length / 8
val values = new Array[Double](n)
val order = ByteOrder.nativeOrder()
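
Both unpickler hunks above decode the payload the same way once getBytes has produced the raw bytes: wrap them in a native-order ByteBuffer and bulk-read doubles. A self-contained sketch of that decoding step:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Bulk-decode a payload of 8-byte doubles written in the machine's
// native byte order -- the step both unpicklers perform after getBytes.
def bytesToDoubles(bytes: Array[Byte]): Array[Double] = {
  val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
  bb.order(ByteOrder.nativeOrder())
  val values = new Array[Double](bytes.length / 8)
  bb.asDoubleBuffer().get(values)
  values
}
```
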
@@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable {
throw new PickleException("should be 3")
}
val size = args(0).asInstanceOf[Int]
- val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
- val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val indiceBytes = getBytes(args(1))
+ val valueBytes = getBytes(args(2))
val n = indiceBytes.length / 4
val indices = new Array[Int](n)
val values = new Array[Double](n)
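
The sparse-vector case above follows the same pattern with two payloads: indices packed as 4-byte ints and values as 8-byte doubles, which is why n is derived from indiceBytes.length / 4. A sketch of decoding both sides (the function name is hypothetical):

```scala
import java.nio.{ByteBuffer, ByteOrder}

// A sparse vector arrives as two packed payloads: native-order 4-byte
// ints for the indices and native-order 8-byte doubles for the values.
def decodeSparse(indiceBytes: Array[Byte], valueBytes: Array[Byte]): (Array[Int], Array[Double]) = {
  val n = indiceBytes.length / 4
  val indices = new Array[Int](n)
  val values = new Array[Double](n)
  val order = ByteOrder.nativeOrder()
  ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
  ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
  (indices, values)
}
```
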