author     Xiangrui Meng <meng@databricks.com>    2015-03-20 15:02:57 -0400
committer  Xiangrui Meng <meng@databricks.com>    2015-03-24 11:32:18 -0700
commit     bc92a2e405241542770a64adfef39dcb02e96461
tree       5901ed47608a09a82651e8e1a1cc66ce4b098662 /mllib/src/main
parent     4ff577160235c0ca82de8330702ed07293024de1
[SPARK-5955][MLLIB] add checkpointInterval to ALS
Add checkpointInterval to ALS to prevent:

1. StackOverflow exceptions caused by long lineage,
2. large shuffle files generated during iterations,
3. slow recovery when some nodes fail.

srowen coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #5076 from mengxr/SPARK-5955 and squashes the following commits:

df56791 [Xiangrui Meng] update impl to reuse code
29affcb [Xiangrui Meng] do not materialize factors in implicit
20d3f7f [Xiangrui Meng] add checkpointInterval to ALS

(cherry picked from commit 6b36470c66bd6140c45e45d3f1d51b0082c3fd97)
Signed-off-by: Xiangrui Meng <meng@databricks.com>

Conflicts:
    mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
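For context, a minimal sketch of the user-facing effect on the mllib API, assuming a SparkContext named sc and an RDD[Rating] named ratings (the checkpoint path is a placeholder):

    import org.apache.spark.mllib.recommendation.{ALS, Rating}

    // Checkpointing is a no-op unless a checkpoint directory is set on the context.
    sc.setCheckpointDir("hdfs:///tmp/als-checkpoints")  // placeholder path

    val model = new ALS()
      .setRank(10)
      .setIterations(50)
      .setLambda(0.01)
      .setCheckpointInterval(5)  // added by this patch; defaults to 10
      .run(ratings)

With the interval set and a checkpoint directory configured, every fifth iteration writes the intermediate factors to reliable storage and truncates their lineage, which is what keeps the DAG and the accumulated shuffle files bounded over long runs.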
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala    | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala    | 42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala | 17
3 files changed, 65 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 1a70322b4c..5d660d1e15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
/** @group getParam */
def getOutputCol: String = get(outputCol)
}
+
+private[ml] trait HasCheckpointInterval extends Params {
+ /**
+ * param for checkpoint interval
+ * @group param
+ */
+ val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+
+ /** @group getParam */
+ def getCheckpointInterval: Int = get(checkpointInterval)
+}
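The trait only declares the param and its getter; each estimator that wants the knob mixes it in and adds its own chained setter, as ALS does in the next file. A sketch of the shape of that pattern (MyAlgoParams is an illustrative name, not part of the patch):

    // Sketch only: mix the shared param into an algorithm's param trait...
    private[ml] trait MyAlgoParams extends Params with HasCheckpointInterval

    // ...then the concrete Estimator exposes a setter and picks a default:
    //   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
    //   setCheckpointInterval(10)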
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 7bb69df653..058076d309 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.recommendation
import java.{util => ju}
+import java.io.IOException
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.jblas.DoubleMatrix
import org.netlib.util.intW
@@ -47,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol {
+ with HasPredictionCol with HasCheckpointInterval {
/**
* Param for rank of the matrix factorization.
@@ -165,6 +167,7 @@ class ALSModel private[ml] (
itemFactors: RDD[(Int, Array[Float])])
extends Model[ALSModel] with ALSParams {
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -263,6 +266,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -275,6 +281,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
+ setCheckpointInterval(10)
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
@@ -286,7 +293,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha), nonnegative = map(nonnegative))
+ alpha = map(alpha), nonnegative = map(nonnegative),
+ checkpointInterval = map(checkpointInterval))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -496,6 +504,7 @@ object ALS extends Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -523,6 +532,18 @@ object ALS extends Logging {
val seedGen = new XORShiftRandom(seed)
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+ var previousCheckpointFile: Option[String] = None
+ val shouldCheckpoint: Int => Boolean = (iter) =>
+ sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+ val deletePreviousCheckpointFile: () => Unit = () =>
+ previousCheckpointFile.foreach { file =>
+ try {
+ FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
+ } catch {
+ case e: IOException =>
+ logWarning(s"Cannot delete checkpoint file $file:", e)
+ }
+ }
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
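Together, the two closures added above capture the core pattern of this patch: checkpoint only when a directory is configured and the iteration number divides the interval, and delete the previous checkpoint file only after the new one has been written. A standalone sketch of the same pattern, distilled from these hunks (checkpointAndClean is a hypothetical helper, not in the patch):

    import java.io.IOException

    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD

    // Checkpoint rdd, force it to materialize, then delete the previous
    // checkpoint file. Returns the new checkpoint file so the caller can
    // thread it into the next iteration.
    def checkpointAndClean[T](sc: SparkContext, rdd: RDD[T],
        previous: Option[String]): Option[String] = {
      rdd.checkpoint()
      rdd.count()  // materialize so the checkpoint is written and the lineage is cut
      previous.foreach { file =>
        try {
          FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
        } catch {
          case _: IOException =>  // best effort; a leaked checkpoint file is not fatal
        }
      }
      rdd.getCheckpointFile
    }

Note how the implicit-preference branch below skips the explicit count(): there the next computeFactors call materializes itemFactors anyway, which is what the "do not materialize factors in implicit" commit in the squash list refers to.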
@@ -530,19 +551,30 @@ object ALS extends Logging {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
- if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
- itemFactors.checkpoint()
- }
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+ // TODO: Generalize PeriodicGraphCheckpointer and use it here.
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
+ }
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+ if (shouldCheckpoint(iter)) {
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint()
+ itemFactors.count() // checkpoint item factors and cut lineage
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
}
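On the Pipelines side, the new setter slots into the usual chained configuration. A hedged sketch of driving the estimator (assumes a SparkContext with a checkpoint directory set and a DataFrame named ratings carrying the expected user/item/rating columns; setup not shown):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setRank(10)
      .setMaxIter(30)
      .setRegParam(0.1)
      .setCheckpointInterval(5)  // overrides the default of 10 set in the constructor
    val model = als.fit(ratings)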
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index caacab9430..dddefe1944 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+ /** checkpoint interval */
+ private var checkpointInterval: Int = 10
+
/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -183,6 +186,19 @@ class ALS private (
}
/**
+ * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with
+ * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps
+ * with eliminating temporary shuffle files on disk, which can be important when there are many
+ * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]],
+ * this setting is ignored.
+ */
+ @DeveloperApi
+ def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
+ /**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
*/
@@ -212,6 +228,7 @@ class ALS private (
nonnegative = nonnegative,
intermediateRDDStorageLevel = intermediateRDDStorageLevel,
finalRDDStorageLevel = StorageLevel.NONE,
+ checkpointInterval = checkpointInterval,
seed = seed)
val userFactors = floatUserFactors
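Finally, a sketch of calling the low-level ALS.train in the ml package directly with the new parameter, which is the path the mllib wrapper above now takes (assumes an RDD[Rating[Int]] named ratings; unnamed parameters keep their defaults):

    import org.apache.spark.ml.recommendation.ALS
    import org.apache.spark.ml.recommendation.ALS.Rating

    val (userFactors, itemFactors) = ALS.train(
      ratings,
      rank = 10,
      maxIter = 30,
      regParam = 0.1,
      checkpointInterval = 5)  // the parameter added by this patch; default 10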