diff options
Diffstat (limited to 'mllib')
4 files changed, 83 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9164c294ac..e9f4175858 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -67,11 +67,13 @@ class PythonMLLibAPI extends Serializable { MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) private def trainRegressionModel( - trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, + learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] - val model = trainFunc(data.rdd, initialWeights) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. + learner.disableUncachedWarning() + val model = learner.run(data.rdd, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(SerDe.dumps(model.weights)) ret.add(model.intercept: java.lang.Double) @@ -106,8 +108,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - lrAlg.run(data, initialWeights), + lrAlg, data, initialWeightsBA) } @@ -122,15 +123,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val lassoAlg = new LassoWithSGD() + lassoAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - LassoWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + lassoAlg, data, initialWeightsBA) } @@ -145,15 +145,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - RidgeRegressionWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + ridgeAlg, data, initialWeightsBA) } @@ -186,8 +185,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - SVMAlg.run(data, initialWeights), + SVMAlg, data, initialWeightsBA) } @@ -220,8 +218,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - LogRegAlg.run(data, initialWeights), + LogRegAlg, data, initialWeightsBA) } @@ -249,7 +246,14 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): KMeansModel = { - KMeans.train(data.rdd, k, maxIterations, runs, initializationMode) + val kMeansAlg = new KMeans() + .setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. + .disableUncachedWarning() + return kMeansAlg.run(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fce8fe29f6..7443f232ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -112,11 +113,26 @@ class KMeans private ( this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. */ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { + + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Compute squared norms and cache them. val norms = data.map(v => breezeNorm(v.toBreeze, 2.0)) norms.persist() @@ -125,6 +141,12 @@ class KMeans private ( } val model = runBreeze(breezeData) norms.unpersist() + + // Warn at the end of the run as well, for increased visibility. + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } model } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2e414a73be..4174f45d23 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: @@ -231,6 +232,10 @@ class RowMatrix( val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => + if (rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter) } @@ -256,6 +261,12 @@ class RowMatrix( logWarning(s"Requested $k singular values but only found $sk nonzeros.") } + // Warn at the end of the run as well, for increased visibility. + if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk)) val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 20c1fdd226..d0fe417968 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -133,6 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. */ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -149,6 +159,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { throw new SparkException("Input validation failed.") @@ -223,6 +238,12 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] weights = scaler.transform(weights) } + // Warn at the end of the run as well, for increased visibility. + if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + createModel(weights, intercept) } } |