-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala      | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala      | 42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala   | 17
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 17
4 files changed, 82 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 1a70322b4c..5d660d1e15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
/** @group getParam */
def getOutputCol: String = get(outputCol)
}
+
+private[ml] trait HasCheckpointInterval extends Params {
+ /**
+ * param for checkpoint interval
+ * @group param
+ */
+ val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+
+ /** @group getParam */
+ def getCheckpointInterval: Int = get(checkpointInterval)
+}
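
The new trait mirrors the other shared params in this file: the trait declares the IntParam and the getter, while each concrete estimator mixes it in and contributes its own setter and default. Once mixed in, the param is used through the usual setter/getter pair; a small usage sketch against the spark.ml ALS estimator modified below (the value 5 is purely illustrative, not part of this patch):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setCheckpointInterval(5)            // setter added to ALS in this patch
    assert(als.getCheckpointInterval == 5) // getter inherited from HasCheckpointInterval
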
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index e3515ee81a..514b4ef98d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.recommendation
import java.{util => ju}
+import java.io.IOException
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.netlib.util.intW
import org.apache.spark.{Logging, Partitioner}
@@ -46,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol {
+ with HasPredictionCol with HasCheckpointInterval {
/**
* Param for rank of the matrix factorization.
@@ -164,6 +166,7 @@ class ALSModel private[ml] (
itemFactors: RDD[(Int, Array[Float])])
extends Model[ALSModel] with ALSParams {
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -262,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -274,6 +280,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
+ setCheckpointInterval(10)
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
@@ -285,7 +292,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha), nonnegative = map(nonnegative))
+ alpha = map(alpha), nonnegative = map(nonnegative),
+ checkpointInterval = map(checkpointInterval))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -494,6 +502,7 @@ object ALS extends Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -521,6 +530,18 @@ object ALS extends Logging {
val seedGen = new XORShiftRandom(seed)
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+ var previousCheckpointFile: Option[String] = None
+ val shouldCheckpoint: Int => Boolean = (iter) =>
+ sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+ val deletePreviousCheckpointFile: () => Unit = () =>
+ previousCheckpointFile.foreach { file =>
+ try {
+ FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
+ } catch {
+ case e: IOException =>
+ logWarning(s"Cannot delete checkpoint file $file:", e)
+ }
+ }
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
@@ -528,19 +549,30 @@ object ALS extends Logging {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
- if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
- itemFactors.checkpoint()
- }
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+ // TODO: Generalize PeriodicGraphCheckpointer and use it here.
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
+ }
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+ if (shouldCheckpoint(iter)) {
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint()
+ itemFactors.count() // checkpoint item factors and cut lineage
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
}
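
With these changes, the item factor RDDs are checkpointed every checkpointInterval iterations, but only when a checkpoint directory has been set on the SparkContext; each new checkpoint also deletes the previous checkpoint file, so disk usage stays bounded. A hedged end-to-end sketch of calling the modified estimator (the directory, the columns of trainingDF, and the parameter values are illustrative, not part of this patch):

    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.recommendation.ALS

    // Without a checkpoint directory, shouldCheckpoint is always false and the
    // interval setting is silently ignored.
    sc.setCheckpointDir("/tmp/als-checkpoints")

    val als = new ALS()
      .setMaxIter(50)             // long enough for lineage and shuffle files to matter
      .setCheckpointInterval(10)  // the default, spelled out here for clarity
    val model = als.fit(trainingDF, ParamMap.empty)  // trainingDF has user/item/rating columns
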
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index caacab9430..dddefe1944 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+ /** checkpoint interval */
+ private var checkpointInterval: Int = 10
+
/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -183,6 +186,19 @@ class ALS private (
}
/**
+ * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with
+ * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps
+ * with eliminating temporary shuffle files on disk, which can be important when there are many
+ * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]],
+ * this setting is ignored.
+ */
+ @DeveloperApi
+ def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
+ /**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
*/
@@ -212,6 +228,7 @@ class ALS private (
nonnegative = nonnegative,
intermediateRDDStorageLevel = intermediateRDDStorageLevel,
finalRDDStorageLevel = StorageLevel.NONE,
+ checkpointInterval = checkpointInterval,
seed = seed)
val userFactors = floatUserFactors
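
In the RDD-based API the interval is a plain builder-style setter that is forwarded to the new implementation through the checkpointInterval argument shown above. A hedged sketch of how a caller would use it (the directory, the ratings RDD, and the parameter values are illustrative, not part of this patch):

    import org.apache.spark.mllib.recommendation.{ALS, Rating}
    import org.apache.spark.rdd.RDD

    sc.setCheckpointDir("/tmp/als-checkpoints")  // otherwise the setting is ignored
    val ratings: RDD[Rating] = ???               // (user, product, rating) triples from the caller
    val model = new ALS()
      .setRank(10)
      .setIterations(50)
      .setCheckpointInterval(10)                 // new setter added in this patch
      .run(ratings)
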
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bb86bafc0e..0bb06e9e8a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.recommendation
+import java.io.File
import java.util.Random
import scala.collection.mutable
@@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.util.Utils
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
private var sqlContext: SQLContext = _
+ private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
+ tempDir = Utils.createTempDir()
+ sc.setCheckpointDir(tempDir.getAbsolutePath)
sqlContext = new SQLContext(sc)
}
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
test("LocalIndexEncoder") {
val random = new Random
for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
@@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}.count()
}
}
+
+ test("als with large number of iterations") {
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
+ ALS.train(
+ ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
+ }
}
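
The new test pushes both the explicit and implicit training paths through 50 iterations, well past the default interval of 10, with the checkpoint directory registered in beforeAll, so the checkpointing branch is actually exercised and the run no longer risks a StackOverflowError from long lineage. A hedged way to eyeball the effect (not part of this patch) is to walk the temp directory after a run; earlier checkpoints are removed by deletePreviousCheckpointFile as training progresses, so only the most recent checkpointed RDD from each run should remain:

    // Illustrative only: list the surviving rdd-* checkpoint directories under tempDir.
    def walk(f: File): Seq[File] =
      f +: Option(f.listFiles()).getOrElse(Array.empty[File]).flatMap(walk).toSeq
    walk(tempDir).filter(_.getName.startsWith("rdd-")).foreach(d => logInfo(d.toString))
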