author     Davies Liu <davies@databricks.com>       2014-11-21 15:02:31 -0800
committer  Xiangrui Meng <meng@databricks.com>      2014-11-21 15:02:31 -0800
commit     ce95bd8e130b2c7688b94be40683bdd90d86012d
tree       396d4e26517f3fc6e84a904ca6466ffb5da2f222 /mllib
parent     a81918c5a66fc6040f9796fc1a9d4e0bfb8d0cbe
[SPARK-4531] [MLlib] cache serialized java object
Pyrolite is quite slow compared to the ad-hoc serializer in 1.1, which caused a significant performance regression in 1.2: we cached the serialized Python objects in the JVM and deserialized them into Java objects on every iteration. This PR caches the deserialized JavaRDD instead of the PythonRDD, avoiding the repeated Pyrolite deserialization. Memory usage should be similar to before, but training is much faster.

Author: Davies Liu <davies@databricks.com>

Closes #3397 from davies/cache and squashes the following commits:

7f6e6ce [Davies Liu] Update -> Updater
4b52edd [Davies Liu] using named argument
63b984e [Davies Liu] fix
7da0332 [Davies Liu] add unpersist()
dff33e1 [Davies Liu] address comments
c2bdfc2 [Davies Liu] refactor
d572f00 [Davies Liu] Merge branch 'master' into cache
f1063e1 [Davies Liu] cache serialized java object
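The core of the change is the pattern that recurs throughout the diff below: persist the deserialized RDD before training and unpersist it in a finally block so the cache is released even if training throws. A minimal standalone sketch of that pattern, with a hypothetical train() standing in for the MLlib learners (the function and app name are illustrative, not part of the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object CacheDuringTraining {
  // Hypothetical stand-in for an MLlib learner's iterative run() method.
  def train(data: RDD[Double], iterations: Int = 10): Double =
    (1 to iterations).map(_ => data.sum()).last

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cache-sketch").setMaster("local[*]"))
    val data: RDD[Double] = sc.parallelize(1 to 1000).map(_.toDouble)
    // Persist the deserialized RDD once, so every iteration reuses Java
    // objects instead of re-deserializing pickled Python data via Pyrolite.
    data.persist(StorageLevel.MEMORY_AND_DISK)
    try {
      println(train(data))
    } finally {
      // Free the cache even if training throws; non-blocking, as in the patch.
      data.unpersist(blocking = false)
    }
    sc.stop()
  }
}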
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala               110
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala                        13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala    13
3 files changed, 56 insertions, 80 deletions
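The bulk of the PythonMLLibAPI.scala diff consolidates four duplicated if/else ladders into a single getUpdaterFromString helper. As a standalone sketch, the same mapping written as an idiomatic Scala match (the updater classes are from org.apache.spark.mllib.optimization; the patch itself uses an if/else chain):

import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater, Updater}

// Sketch of the helper added in this patch, rewritten as a pattern match.
// null must be handled explicitly because Python's None arrives as null.
def getUpdaterFromString(regType: String): Updater = regType match {
  case "l2"          => new SquaredL2Updater
  case "l1"          => new L1Updater
  case null | "none" => new SimpleUpdater
  case other => throw new IllegalArgumentException(
    s"Invalid value '$other' for 'regType' parameter."
      + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}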
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b6f7618171..f04df1c156 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable {
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeights: Vector): JList[Object] = {
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- learner.disableUncachedWarning()
- val model = learner.run(data.rdd, initialWeights)
- List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ try {
+ val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
+
+ /**
+ * Return the Updater from string
+ */
+ def getUpdaterFromString(regType: String): Updater = {
+ if (regType == "l2") {
+ new SquaredL2Updater
+ } else if (regType == "l1") {
+ new L1Updater
+ } else if (regType == null || regType == "none") {
+ new SimpleUpdater
+ } else {
+ throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+ }
}
/**
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- lrAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- lrAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
lrAlg,
data,
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- SVMAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
SVMAlg,
data,
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setNumCorrections(corrections)
.setConvergenceTol(tolerance)
- if (regType == "l2") {
- LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
- LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- .disableUncachedWarning()
- kMeansAlg.run(data.rdd)
+ try {
+ kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
}
/**
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable {
numPartitions: Int,
numIterations: Int,
seed: Long): Word2VecModelWrapper = {
- val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
val word2vec = new Word2Vec()
.setVectorSize(vectorSize)
.setLearningRate(learningRate)
.setNumPartitions(numPartitions)
.setNumIterations(numIterations)
.setSeed(seed)
- val model = word2vec.fit(data)
- data.unpersist()
- new Word2VecModelWrapper(model)
+ try {
+ val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+ new Word2VecModelWrapper(model)
+ } finally {
+ dataJRDD.rdd.unpersist(blocking = false)
+ }
}
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
-
- DecisionTree.train(data.rdd, strategy)
+ try {
+ DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
}
/**
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable {
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
- if (algo == Algo.Classification) {
- RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
- } else {
- RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+ try {
+ if (algo == Algo.Classification) {
+ RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ } else {
+ RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ }
+ } finally {
+ cached.unpersist(blocking = false)
}
}
@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable {
def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
if (obj == this) {
out.write(Opcodes.GLOBAL)
- out.write((module + "\n" + name + "\n").getBytes())
+ out.write((module + "\n" + name + "\n").getBytes)
} else {
pickler.save(this) // it will be memorized by Pickler
saveState(obj, out, pickler)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 7443f232ec..34ea0de706 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -143,7 +134,7 @@ class KMeans private (
norms.unpersist()
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 00dfc86c9e..0287f04e2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}