Diffstat (limited to 'mllib/src/main')
 mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala             | 54
 mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala                     | 22
 mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala          | 11
 mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 21
 4 files changed, 83 insertions(+), 25 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9164c294ac..e9f4175858 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -67,11 +67,13 @@ class PythonMLLibAPI extends Serializable {
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
private def trainRegressionModel(
- trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
+ learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
- val model = trainFunc(data.rdd, initialWeights)
+ // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
+ learner.disableUncachedWarning()
+ val model = learner.run(data.rdd, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(SerDe.dumps(model.weights))
ret.add(model.intercept: java.lang.Double)
@@ -106,8 +108,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- lrAlg.run(data, initialWeights),
+ lrAlg,
data,
initialWeightsBA)
}
@@ -122,15 +123,14 @@ class PythonMLLibAPI extends Serializable {
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ val lassoAlg = new LassoWithSGD()
+ lassoAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
trainRegressionModel(
- (data, initialWeights) =>
- LassoWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ lassoAlg,
data,
initialWeightsBA)
}
@@ -145,15 +145,14 @@ class PythonMLLibAPI extends Serializable {
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ val ridgeAlg = new RidgeRegressionWithSGD()
+ ridgeAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
trainRegressionModel(
- (data, initialWeights) =>
- RidgeRegressionWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ ridgeAlg,
data,
initialWeightsBA)
}
@@ -186,8 +185,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- SVMAlg.run(data, initialWeights),
+ SVMAlg,
data,
initialWeightsBA)
}
@@ -220,8 +218,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- LogRegAlg.run(data, initialWeights),
+ LogRegAlg,
data,
initialWeightsBA)
}
@@ -249,7 +246,14 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): KMeansModel = {
- KMeans.train(data.rdd, k, maxIterations, runs, initializationMode)
+ val kMeansAlg = new KMeans()
+ .setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
+ .disableUncachedWarning()
+ return kMeansAlg.run(data.rdd)
}
/**
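The refactor above replaces per-call training closures with configured algorithm objects, so the shared trainRegressionModel helper can call disableUncachedWarning() on the learner before running it. As a minimal sketch of the resulting calling pattern (a hypothetical standalone driver; the local SparkContext and the toy dataset are assumptions for illustration, not part of this patch):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD}

object LassoDriverSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "LassoDriverSketch")
    // Cache the input so the new uncached-input warning does not fire.
    val data = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)))).cache()

    // Configure the algorithm object up front, as the refactored API code does,
    // instead of closing over the hyperparameters in a training function.
    val lassoAlg = new LassoWithSGD()
    lassoAlg.optimizer
      .setNumIterations(100)
      .setRegParam(0.01)
      .setStepSize(1.0)
      .setMiniBatchFraction(1.0)

    val model = lassoAlg.run(data, Vectors.dense(0.0, 0.0))
    println(s"weights=${model.weights}, intercept=${model.intercept}")
    sc.stop()
  }
}

Passing the learner itself also keeps the Python entry points uniform: every regression path now funnels through the same helper regardless of algorithm.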
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index fce8fe29f6..7443f232ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -112,11 +113,26 @@ class KMeans private (
this
}
+ /** Whether a warning should be logged if the input RDD is uncached. */
+ private var warnOnUncachedInput = true
+
+ /** Disable warnings about uncached input. */
+ private[spark] def disableUncachedWarning(): this.type = {
+ warnOnUncachedInput = false
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
+
+ if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
// Compute squared norms and cache them.
val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
norms.persist()
@@ -125,6 +141,12 @@ class KMeans private (
}
val model = runBreeze(breezeData)
norms.unpersist()
+
+ // Warn at the end of the run as well, for increased visibility.
+ if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
model
}
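disableUncachedWarning() is private[spark], so from user code the way to keep the new warning quiet is simply to cache the input. The warning is deliberately logged both before and after the run: on long jobs the first message may scroll out of view, so the repetition increases the chance a user actually sees it. A minimal sketch, assuming a local SparkContext and toy points (both hypothetical, not part of the patch):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object KMeansCachingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "KMeansCachingSketch")
    val points = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

    // Uncached input would now log the warning twice (before and after the run);
    // caching keeps getStorageLevel != StorageLevel.NONE, so run() stays quiet.
    val cached = points.cache()

    val model = new KMeans()
      .setK(2)
      .setMaxIterations(10)
      .run(cached)
    model.clusterCenters.foreach(println)
    sc.stop()
  }
}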
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 2e414a73be..4174f45d23 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
+import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
@@ -231,6 +232,10 @@ class RowMatrix(
val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G)
(sigmaSquaresFull, uFull)
case SVDMode.DistARPACK =>
+ if (rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.")
EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter)
}
@@ -256,6 +261,12 @@ class RowMatrix(
logWarning(s"Requested $k singular values but only found $sk nonzeros.")
}
+ // Warn at the end of the run as well, for increased visibility.
+ if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk))
val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
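The RowMatrix check only fires in DistARPACK mode, where the row RDD is scanned once per ARPACK matrix-vector multiply, but caching the rows is cheap insurance for any multi-pass use. A minimal sketch (hypothetical driver; a matrix this small would in practice be handled by one of the local SVD modes, which skip the warning entirely):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

object RowMatrixSvdSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "RowMatrixSvdSketch")
    // Cache the rows so the dist-eigs warning cannot fire on larger inputs.
    val rows = sc.parallelize(Seq(
      Vectors.dense(1.0, 0.0, 0.0),
      Vectors.dense(0.0, 2.0, 0.0),
      Vectors.dense(0.0, 0.0, 3.0))).cache()

    val mat = new RowMatrix(rows)
    val svd = mat.computeSVD(2, computeU = true)
    println(s"singular values: ${svd.s}")
    sc.stop()
  }
}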
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 20c1fdd226..d0fe417968 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.storage.StorageLevel
/**
* :: DeveloperApi ::
@@ -133,6 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
+ /** Whether a warning should be logged if the input RDD is uncached. */
+ private var warnOnUncachedInput = true
+
+ /** Disable warnings about uncached input. */
+ private[spark] def disableUncachedWarning(): this.type = {
+ warnOnUncachedInput = false
+ this
+ }
+
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -149,6 +159,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+ if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
// Check the data properties before running the optimizer
if (validateData && !validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
@@ -223,6 +238,12 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
weights = scaler.transform(weights)
}
+ // Warn at the end of the run as well, for increased visibility.
+ if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
createModel(weights, intercept)
}
}
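Every GeneralizedLinearAlgorithm subclass inherits the new check, since it lives in the shared run(input, initialWeights) path. A minimal sketch with logistic regression (hypothetical driver; the data and initial weights are placeholders):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object GlmCachingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "GlmCachingSketch")
    // Cached input: neither the pre-run nor the post-run warning is logged.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)))).cache()

    val lr = new LogisticRegressionWithSGD()
    lr.optimizer.setNumIterations(50)
    val model = lr.run(training, Vectors.dense(0.0, 0.0))
    println(s"weights=${model.weights}")
    sc.stop()
  }
}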