author     Nick Pentreath <nickp@za.ibm.com>    2016-04-29 22:01:41 -0700
committer  Xiangrui Meng <meng@databricks.com>  2016-04-29 22:01:41 -0700
commit     90fa2c6e7f4893af51e0cfb6dc162b828ea55995 (patch)
tree       cad8481dab7030e32f8001caab2f67fd99b3c49e /mllib
parent     d7755cfd07554c132b7271730102b8b68eb56b28 (diff)
[SPARK-14412][ML][PYSPARK] Add StorageLevel params to ALS
## What changes were proposed in this pull request?

`mllib` `ALS` supports `setIntermediateRDDStorageLevel` and `setFinalRDDStorageLevel`. This PR adds these as Params in `ml` `ALS`. They are put in the **expertParam** group since few users will need them.

## How was this patch tested?

New test cases in `ALSSuite` and `tests.py`.

cc yanboliang jkbradley sethah rishabhbhardwaj

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12660 from MLnick/SPARK-14412-als-storage-params.
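For context, a minimal usage sketch of the new expert params on the `ml` estimator. This is illustrative only: `ratings` is a hypothetical `DataFrame` with the default `user`/`item`/`rating` columns; the two storage-level setters are the ones added by this patch.

```scala
import org.apache.spark.ml.recommendation.ALS

// `ratings` is assumed to be a DataFrame of (user, item, rating) rows
// loaded elsewhere.
val als = new ALS()
  .setMaxIter(10)
  .setRank(10)
  // New expert params from this patch: string forms of StorageLevel.
  // Values are validated eagerly on set; the intermediate level may
  // not be "NONE".
  .setIntermediateRDDStorageLevel("MEMORY_ONLY")
  .setFinalRDDStorageLevel("MEMORY_AND_DISK_SER")
val model = als.fit(ratings)
```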
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala      | 54
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 81
2 files changed, 129 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index cbcbfe8249..55cea800d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -22,7 +22,7 @@ import java.io.IOException
import scala.collection.mutable
import scala.reflect.ClassTag
-import scala.util.Sorting
+import scala.util.{Sorting, Try}
import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -153,12 +153,42 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
/** @group getParam */
def getNonnegative: Boolean = $(nonnegative)
+ /**
+ * Param for StorageLevel for intermediate RDDs. Pass in a string representation of
+ * [[StorageLevel]]. Cannot be "NONE".
+ * Default: "MEMORY_AND_DISK".
+ *
+ * @group expertParam
+ */
+ val intermediateRDDStorageLevel = new Param[String](this, "intermediateRDDStorageLevel",
+ "StorageLevel for intermediate RDDs. Cannot be 'NONE'. Default: 'MEMORY_AND_DISK'.",
+ (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")
+
+ /** @group expertGetParam */
+ def getIntermediateRDDStorageLevel: String = $(intermediateRDDStorageLevel)
+
+ /**
+ * Param for StorageLevel for ALS model factor RDDs. Pass in a string representation of
+ * [[StorageLevel]].
+ * Default: "MEMORY_AND_DISK".
+ *
+ * @group expertParam
+ */
+ val finalRDDStorageLevel = new Param[String](this, "finalRDDStorageLevel",
+ "StorageLevel for ALS model factor RDDs. Default: 'MEMORY_AND_DISK'.",
+ (s: String) => Try(StorageLevel.fromString(s)).isSuccess)
+
+ /** @group expertGetParam */
+ def getFinalRDDStorageLevel: String = $(finalRDDStorageLevel)
+
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
- ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
+ ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
+ intermediateRDDStorageLevel -> "MEMORY_AND_DISK", finalRDDStorageLevel -> "MEMORY_AND_DISK")
/**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -374,8 +404,21 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("1.3.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /** @group expertSetParam */
+ @Since("2.0.0")
+ def setIntermediateRDDStorageLevel(value: String): this.type = {
+ set(intermediateRDDStorageLevel, value)
+ }
+
+ /** @group expertSetParam */
+ @Since("2.0.0")
+ def setFinalRDDStorageLevel(value: String): this.type = {
+ set(finalRDDStorageLevel, value)
+ }
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
+ *
* @group setParam
*/
@Since("1.3.0")
@@ -403,6 +446,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
+ intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateRDDStorageLevel)),
+ finalRDDStorageLevel = StorageLevel.fromString($(finalRDDStorageLevel)),
checkpointInterval = $(checkpointInterval), seed = $(seed))
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
@@ -754,7 +799,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* ratings are associated with srcIds(i).
* @param dstEncodedIndices encoded dst indices
* @param ratings ratings
- *
* @see [[LocalIndexEncoder]]
*/
private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
@@ -850,7 +894,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @param ratings raw ratings
* @param srcPart partitioner for src IDs
* @param dstPart partitioner for dst IDs
- *
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
*/
private def partitionRatings[ID: ClassTag](
@@ -899,6 +942,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/**
* Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+ *
* @param encoder encoder for dst indices
*/
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
@@ -1099,6 +1143,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/**
* Creates in-blocks and out-blocks from rating blocks.
+ *
* @param prefix prefix for in/out-block names
* @param ratingBlocks rating blocks
* @param srcPart partitioner for src IDs
@@ -1187,7 +1232,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @param implicitPrefs whether to use implicit preference
* @param alpha the alpha constant in the implicit preference formulation
* @param solver solver for least squares problems
- *
* @return dst factors
*/
private def computeFactors[ID](
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index dac76aa7a1..2e5c6a4f20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -33,7 +33,9 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.storage.StorageLevel
class ALSSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
@@ -198,6 +200,7 @@ class ALSSuite
/**
* Generates an explicit feedback dataset for testing ALS.
+ *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -238,6 +241,7 @@ class ALSSuite
/**
* Generates an implicit feedback dataset for testing ALS.
+ *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -286,6 +290,7 @@ class ALSSuite
/**
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
+ *
* @param size number of users/items
* @param rank number of features
* @param random random number generator
@@ -311,6 +316,7 @@ class ALSSuite
/**
* Test ALS using the given training/test splits and parameters.
+ *
* @param training training dataset
* @param test test dataset
* @param rank rank of the matrix factorization
@@ -514,6 +520,77 @@ class ALSSuite
}
}
+class ALSStorageSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
+
+ test("invalid storage params") {
+ intercept[IllegalArgumentException] {
+ new ALS().setIntermediateRDDStorageLevel("foo")
+ }
+ intercept[IllegalArgumentException] {
+ new ALS().setIntermediateRDDStorageLevel("NONE")
+ }
+ intercept[IllegalArgumentException] {
+ new ALS().setFinalRDDStorageLevel("foo")
+ }
+ }
+
+ test("default and non-default storage params set correct RDD StorageLevels") {
+ val sqlContext = this.sqlContext
+ import sqlContext.implicits._
+ val data = Seq(
+ (0, 0, 1.0),
+ (0, 1, 2.0),
+ (1, 2, 3.0),
+ (1, 0, 2.0)
+ ).toDF("user", "item", "rating")
+ val als = new ALS().setMaxIter(1).setRank(1)
+ // add listener to check intermediate RDD default storage levels
+ val defaultListener = new IntermediateRDDStorageListener
+ sc.addSparkListener(defaultListener)
+ val model = als.fit(data)
+ // check final factor RDD default storage levels
+ val defaultFactorRDDs = sc.getPersistentRDDs.collect {
+ case (id, rdd) if rdd.name == "userFactors" || rdd.name == "itemFactors" =>
+ rdd.name -> (id, rdd.getStorageLevel)
+ }.toMap
+ defaultFactorRDDs.foreach { case (_, (id, level)) =>
+ assert(level == StorageLevel.MEMORY_AND_DISK)
+ }
+ defaultListener.storageLevels.foreach(level => assert(level == StorageLevel.MEMORY_AND_DISK))
+
+ // add listener to check intermediate RDD non-default storage levels
+ val nonDefaultListener = new IntermediateRDDStorageListener
+ sc.addSparkListener(nonDefaultListener)
+ val nonDefaultModel = als
+ .setFinalRDDStorageLevel("MEMORY_ONLY")
+ .setIntermediateRDDStorageLevel("DISK_ONLY")
+ .fit(data)
+ // check final factor RDD non-default storage levels
+ val levels = sc.getPersistentRDDs.collect {
+ case (id, rdd) if rdd.name == "userFactors" && rdd.id != defaultFactorRDDs("userFactors")._1
+ || rdd.name == "itemFactors" && rdd.id != defaultFactorRDDs("itemFactors")._1 =>
+ rdd.getStorageLevel
+ }
+ levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
+ nonDefaultListener.storageLevels.foreach(level => assert(level == StorageLevel.DISK_ONLY))
+ }
+}
+
+private class IntermediateRDDStorageListener extends SparkListener {
+
+ val storageLevels: mutable.ArrayBuffer[StorageLevel] = mutable.ArrayBuffer()
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ val stageLevels = stageCompleted.stageInfo.rddInfos.collect {
+ case info if info.name.contains("Blocks") || info.name.contains("Factors-") =>
+ info.storageLevel
+ }
+ storageLevels ++= stageLevels
+ }
+
+}
+
object ALSSuite {
/**
@@ -539,6 +616,8 @@ object ALSSuite {
"implicitPrefs" -> true,
"alpha" -> 0.9,
"nonnegative" -> true,
- "checkpointInterval" -> 20
+ "checkpointInterval" -> 20,
+ "intermediateRDDStorageLevel" -> "MEMORY_ONLY",
+ "finalRDDStorageLevel" -> "MEMORY_AND_DISK_SER"
)
}
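As a postscript, a minimal sketch of the validation rule encoded in the two new Params in `ALSParams`. The helper `isValidIntermediate` is illustrative and not part of the patch; `StorageLevel.fromString` and `scala.util.Try` are the real APIs the validators use, and the behavior below matches the `ALSStorageSuite` expectations above.

```scala
import scala.util.Try
import org.apache.spark.storage.StorageLevel

// Mirrors the Param validators in ALSParams: any string accepted by
// StorageLevel.fromString is a legal level, and the intermediate level
// must additionally not be "NONE", since ALS is not designed to run
// without persisting its intermediate RDDs.
def isValidIntermediate(s: String): Boolean =
  Try(StorageLevel.fromString(s)).isSuccess && s != "NONE"

assert(isValidIntermediate("MEMORY_AND_DISK")) // the default
assert(!isValidIntermediate("NONE"))           // explicitly rejected
assert(!isValidIntermediate("foo"))            // fromString throws
```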