aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 110
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 13
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 13
3 files changed, 56 insertions, 80 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b6f7618171..f04df1c156 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable {
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeights: Vector): JList[Object] = {
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- learner.disableUncachedWarning()
- val model = learner.run(data.rdd, initialWeights)
- List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ try {
+ val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
+
+ /**
+ * Return the Updater selected by a regularization-type string.
+ *
+ * The mapping implemented below is:
+ *  - "l2" yields a new SquaredL2Updater
+ *  - "l1" yields a new L1Updater
+ *  - null or "none" yields a new SimpleUpdater
+ *
+ * @param regType name of the regularization type; may be null
+ * @return a newly constructed Updater matching `regType`
+ * @throws IllegalArgumentException if `regType` is any other value
+ */
+ def getUpdaterFromString(regType: String): Updater = {
+ if (regType == "l2") {
+ new SquaredL2Updater
+ } else if (regType == "l1") {
+ new L1Updater
+ } else if (regType == null || regType == "none") {
+ new SimpleUpdater
+ } else {
+ throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+ }
}
/**
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- lrAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- lrAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
lrAlg,
data,
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- SVMAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
SVMAlg,
data,
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setNumCorrections(corrections)
.setConvergenceTol(tolerance)
- if (regType == "l2") {
- LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- .disableUncachedWarning()
- kMeansAlg.run(data.rdd)
+ try {
+ kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
}
/**
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable {
numPartitions: Int,
numIterations: Int,
seed: Long): Word2VecModelWrapper = {
- val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
val word2vec = new Word2Vec()
.setVectorSize(vectorSize)
.setLearningRate(learningRate)
.setNumPartitions(numPartitions)
.setNumIterations(numIterations)
.setSeed(seed)
- val model = word2vec.fit(data)
- data.unpersist()
- new Word2VecModelWrapper(model)
+ try {
+ val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+ new Word2VecModelWrapper(model)
+ } finally {
+ dataJRDD.rdd.unpersist(blocking = false)
+ }
}
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
-
- DecisionTree.train(data.rdd, strategy)
+ try {
+ DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
}
/**
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable {
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
- if (algo == Algo.Classification) {
- RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
- } else {
- RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+ try {
+ if (algo == Algo.Classification) {
+ RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ } else {
+ RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ }
+ } finally {
+ cached.unpersist(blocking = false)
}
}
@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable {
def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
if (obj == this) {
out.write(Opcodes.GLOBAL)
- out.write((module + "\n" + name + "\n").getBytes())
+ out.write((module + "\n" + name + "\n").getBytes)
} else {
pickler.save(this) // it will be memorized by Pickler
saveState(obj, out, pickler)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 7443f232ec..34ea0de706 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -143,7 +134,7 @@ class KMeans private (
norms.unpersist()
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 00dfc86c9e..0287f04e2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}