aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com>  2015-11-18 18:34:01 -0800
committer: Joseph K. Bradley <joseph@databricks.com>  2015-11-18 18:34:01 -0800
commit: e99d3392068bc929c900a4cc7b50e9e2b437a23a (patch)
tree: ecc75c1b1d75173742d95c4156cce653c12418b2 /mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
parent: 59a501359a267fbdb7689058693aa788703e54b1 (diff)
download: spark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.tar.gz
          spark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.tar.bz2
          spark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.zip
[SPARK-11839][ML] refactor save/write traits
* add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.*
* define `DefaultParamsReadable/Writable` and use them to save some code
* use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9827 from mengxr/SPARK-11839.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  27
1 file changed, 11 insertions(+), 16 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index d92514d2e2..795b73c4c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -185,7 +185,7 @@ class ALSModel private[ml] (
val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
- extends Model[ALSModel] with ALSModelParams with Writable {
+ extends Model[ALSModel] with ALSModelParams with MLWritable {
/** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
@@ -225,19 +225,19 @@ class ALSModel private[ml] (
}
@Since("1.6.0")
- override def write: Writer = new ALSModel.ALSModelWriter(this)
+ override def write: MLWriter = new ALSModel.ALSModelWriter(this)
}
@Since("1.6.0")
-object ALSModel extends Readable[ALSModel] {
+object ALSModel extends MLReadable[ALSModel] {
@Since("1.6.0")
- override def read: Reader[ALSModel] = new ALSModelReader
+ override def read: MLReader[ALSModel] = new ALSModelReader
@Since("1.6.0")
- override def load(path: String): ALSModel = read.load(path)
+ override def load(path: String): ALSModel = super.load(path)
- private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {
+ private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
@@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] {
}
}
- private[recommendation] class ALSModelReader extends Reader[ALSModel] {
+ private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.recommendation.ALSModel"
@@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] {
* preferences rather than explicit ratings given to items.
*/
@Experimental
-class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {
+class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams
+ with DefaultParamsWritable {
import org.apache.spark.ml.recommendation.ALS.Rating
@@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w
}
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
-
- @Since("1.6.0")
- override def write: Writer = new DefaultParamsWriter(this)
}
@@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w
* than 2 billion.
*/
@DeveloperApi
-object ALS extends Readable[ALS] with Logging {
+object ALS extends DefaultParamsReadable[ALS] with Logging {
/**
* :: DeveloperApi ::
@@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging {
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
@Since("1.6.0")
- override def read: Reader[ALS] = new DefaultParamsReader[ALS]
-
- @Since("1.6.0")
- override def load(path: String): ALS = read.load(path)
+ override def load(path: String): ALS = super.load(path)
/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {