aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala45
1 files changed, 30 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 8ebc7e27ed..84d192db53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -111,11 +111,17 @@ class ALS private (
*/
def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
+ /** If true, do alternating nonnegative least squares. */
+ private var nonnegative = false
+
+ /** storage level for user/product in/out links */
+ private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+
/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
*/
- def setBlocks(numBlocks: Int): ALS = {
+ def setBlocks(numBlocks: Int): this.type = {
this.numUserBlocks = numBlocks
this.numProductBlocks = numBlocks
this
@@ -124,7 +130,7 @@ class ALS private (
/**
* Set the number of user blocks to parallelize the computation.
*/
- def setUserBlocks(numUserBlocks: Int): ALS = {
+ def setUserBlocks(numUserBlocks: Int): this.type = {
this.numUserBlocks = numUserBlocks
this
}
@@ -132,31 +138,31 @@ class ALS private (
/**
* Set the number of product blocks to parallelize the computation.
*/
- def setProductBlocks(numProductBlocks: Int): ALS = {
+ def setProductBlocks(numProductBlocks: Int): this.type = {
this.numProductBlocks = numProductBlocks
this
}
/** Set the rank of the feature matrices computed (number of features). Default: 10. */
- def setRank(rank: Int): ALS = {
+ def setRank(rank: Int): this.type = {
this.rank = rank
this
}
/** Set the number of iterations to run. Default: 10. */
- def setIterations(iterations: Int): ALS = {
+ def setIterations(iterations: Int): this.type = {
this.iterations = iterations
this
}
/** Set the regularization parameter, lambda. Default: 0.01. */
- def setLambda(lambda: Double): ALS = {
+ def setLambda(lambda: Double): this.type = {
this.lambda = lambda
this
}
/** Sets whether to use implicit preference. Default: false. */
- def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
+ def setImplicitPrefs(implicitPrefs: Boolean): this.type = {
this.implicitPrefs = implicitPrefs
this
}
@@ -166,30 +172,39 @@ class ALS private (
* Sets the constant used in computing confidence in implicit ALS. Default: 1.0.
*/
@Experimental
- def setAlpha(alpha: Double): ALS = {
+ def setAlpha(alpha: Double): this.type = {
this.alpha = alpha
this
}
/** Sets a random seed to have deterministic results. */
- def setSeed(seed: Long): ALS = {
+ def setSeed(seed: Long): this.type = {
this.seed = seed
this
}
- /** If true, do alternating nonnegative least squares. */
- private var nonnegative = false
-
/**
* Set whether the least-squares problems solved at each iteration should have
* nonnegativity constraints.
*/
- def setNonnegative(b: Boolean): ALS = {
+ def setNonnegative(b: Boolean): this.type = {
this.nonnegative = b
this
}
/**
+ * :: DeveloperApi ::
+ * Sets storage level for intermediate RDDs (user/product in/out links). The default value is
+ * `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and
+ * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed.
+ */
+ @DeveloperApi
+ def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = {
+ this.intermediateRDDStorageLevel = storageLevel
+ this
+ }
+
+ /**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
*/
@@ -441,8 +456,8 @@ class ALS private (
}, preservesPartitioning = true)
val inLinks = links.mapValues(_._1)
val outLinks = links.mapValues(_._2)
- inLinks.persist(StorageLevel.MEMORY_AND_DISK)
- outLinks.persist(StorageLevel.MEMORY_AND_DISK)
+ inLinks.persist(intermediateRDDStorageLevel)
+ outLinks.persist(intermediateRDDStorageLevel)
(inLinks, outLinks)
}