aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorAaron Staple <aaron.staple@gmail.com>2014-09-25 16:11:00 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-25 16:11:00 -0700
commitff637c9380a6342fd0a4dde0710ec23856751dd4 (patch)
treef126756bb8adecd89014b5c18df60c4a6b44bf19 /mllib/src
parent9b56e249e09d8da20f703b9381c5c3c8a1a1d4a9 (diff)
downloadspark-ff637c9380a6342fd0a4dde0710ec23856751dd4.tar.gz
spark-ff637c9380a6342fd0a4dde0710ec23856751dd4.tar.bz2
spark-ff637c9380a6342fd0a4dde0710ec23856751dd4.zip
[SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data.
Add warnings to KMeans, GeneralizedLinearAlgorithm, and computeSVD when called with input data that is not cached. KMeans is implemented iteratively, and I believe that GeneralizedLinearAlgorithm’s current optimizers are iterative and its future optimizers are also likely to be iterative. RowMatrix’s computeSVD is iterative against an RDD when run in DistARPACK mode. ALS and DecisionTree are iterative as well, but they implement RDD caching internally so do not require a warning. I added a warning to GeneralizedLinearAlgorithm rather than inside its optimizers, where the iteration actually occurs, because internally GeneralizedLinearAlgorithm maps its input data to an uncached RDD before passing it to an optimizer. (In other words, the warning would be printed for every GeneralizedLinearAlgorithm run, regardless of whether its input is cached, if the warning were in GradientDescent or other optimizer.) I assume that use of an uncached RDD by GeneralizedLinearAlgorithm is intentional, and that the mapping there (adding label, intercepts and scaling) is a lightweight operation. Arguably a user calling an optimizer such as GradientDescent will be knowledgable enough to cache their data without needing a log warning, so lack of a warning in the optimizers may be ok. Some of the documentation examples making use of these iterative algorithms did not cache their training RDDs (while others did). I updated the examples to always cache. I also fixed some (unrelated) minor errors in the documentation examples. Author: Aaron Staple <aaron.staple@gmail.com> Closes #2347 from staple/SPARK-1484 and squashes the following commits: bd49701 [Aaron Staple] Address review comments. ab2d4a4 [Aaron Staple] Disable warnings on python code path. a7a0f99 [Aaron Staple] Change code comments per review comments. 7cca1dc [Aaron Staple] Change warning message text. c77e939 [Aaron Staple] [SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data. 3b6c511 [Aaron Staple] Minor doc example fixes.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala54
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala22
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala21
4 files changed, 83 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9164c294ac..e9f4175858 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -67,11 +67,13 @@ class PythonMLLibAPI extends Serializable {
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
private def trainRegressionModel(
- trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
+ learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
- val model = trainFunc(data.rdd, initialWeights)
+ // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
+ learner.disableUncachedWarning()
+ val model = learner.run(data.rdd, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(SerDe.dumps(model.weights))
ret.add(model.intercept: java.lang.Double)
@@ -106,8 +108,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- lrAlg.run(data, initialWeights),
+ lrAlg,
data,
initialWeightsBA)
}
@@ -122,15 +123,14 @@ class PythonMLLibAPI extends Serializable {
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ val lassoAlg = new LassoWithSGD()
+ lassoAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
trainRegressionModel(
- (data, initialWeights) =>
- LassoWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ lassoAlg,
data,
initialWeightsBA)
}
@@ -145,15 +145,14 @@ class PythonMLLibAPI extends Serializable {
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ val ridgeAlg = new RidgeRegressionWithSGD()
+ ridgeAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
trainRegressionModel(
- (data, initialWeights) =>
- RidgeRegressionWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ ridgeAlg,
data,
initialWeightsBA)
}
@@ -186,8 +185,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- SVMAlg.run(data, initialWeights),
+ SVMAlg,
data,
initialWeightsBA)
}
@@ -220,8 +218,7 @@ class PythonMLLibAPI extends Serializable {
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
- (data, initialWeights) =>
- LogRegAlg.run(data, initialWeights),
+ LogRegAlg,
data,
initialWeightsBA)
}
@@ -249,7 +246,14 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): KMeansModel = {
- KMeans.train(data.rdd, k, maxIterations, runs, initializationMode)
+ val kMeansAlg = new KMeans()
+ .setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
+ .disableUncachedWarning()
+ return kMeansAlg.run(data.rdd)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index fce8fe29f6..7443f232ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -112,11 +113,26 @@ class KMeans private (
this
}
+ /** Whether a warning should be logged if the input RDD is uncached. */
+ private var warnOnUncachedInput = true
+
+ /** Disable warnings about uncached input. */
+ private[spark] def disableUncachedWarning(): this.type = {
+ warnOnUncachedInput = false
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
+
+ if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
// Compute squared norms and cache them.
val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
norms.persist()
@@ -125,6 +141,12 @@ class KMeans private (
}
val model = runBreeze(breezeData)
norms.unpersist()
+
+ // Warn at the end of the run as well, for increased visibility.
+ if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 2e414a73be..4174f45d23 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
+import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
@@ -231,6 +232,10 @@ class RowMatrix(
val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G)
(sigmaSquaresFull, uFull)
case SVDMode.DistARPACK =>
+ if (rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.")
EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter)
}
@@ -256,6 +261,12 @@ class RowMatrix(
logWarning(s"Requested $k singular values but only found $sk nonzeros.")
}
+ // Warn at the end of the run as well, for increased visibility.
+ if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk))
val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 20c1fdd226..d0fe417968 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.storage.StorageLevel
/**
* :: DeveloperApi ::
@@ -133,6 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
+ /** Whether a warning should be logged if the input RDD is uncached. */
+ private var warnOnUncachedInput = true
+
+ /** Disable warnings about uncached input. */
+ private[spark] def disableUncachedWarning(): this.type = {
+ warnOnUncachedInput = false
+ this
+ }
+
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -149,6 +159,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+ if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
// Check the data properties before running the optimizer
if (validateData && !validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
@@ -223,6 +238,12 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
weights = scaler.transform(weights)
}
+ // Warn at the end of the run as well, for increased visibility.
+ if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
createModel(weights, intercept)
}
}