path: root/mllib
author    Xiangrui Meng <meng@databricks.com>    2014-04-09 02:21:15 -0700
committer Patrick Wendell <pwendell@gmail.com>   2014-04-09 02:21:15 -0700
commit    bde9cc11fee42a0a41ec52d5dc7fa0502ce94f77 (patch)
tree      2999353e9a277154728775a9cddb045816d51ef3 /mllib
parent    87bd1f9ef7d547ee54a8a83214b45462e0751efb (diff)
[SPARK-1357] [MLLIB] Annotate developer and experimental APIs
Annotate developer and experimental APIs in MLlib.

Author: Xiangrui Meng <meng@databricks.com>

Closes #298 from mengxr/api and squashes the following commits:

13390e8 [Xiangrui Meng] Merge branch 'master' into api
dc4cbb3 [Xiangrui Meng] mark distributed matrices experimental
6b9f8e2 [Xiangrui Meng] add Experimental annotation
8773d0d [Xiangrui Meng] add DeveloperApi annotation
da31733 [Xiangrui Meng] update developer and experimental tags
555e0fe [Xiangrui Meng] Merge branch 'master' into api
ef1a717 [Xiangrui Meng] mark some constructors private; add default parameters to JavaDoc
00ffbcc [Xiangrui Meng] update tree API annotation
0b674fa [Xiangrui Meng] mark decision tree APIs
86b9e34 [Xiangrui Meng] one pass over APIs of GLMs, NaiveBayes, and ALS
f21d862 [Xiangrui Meng] Merge branch 'master' into api
2b133d6 [Xiangrui Meng] initial annotation of developer and experimental APIs
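For illustration, a minimal sketch of how the two audience annotations read in MLlib source, following the `:: Experimental ::` / `:: DeveloperApi ::` ScalaDoc convention used throughout this patch (the class names here are hypothetical and not part of the commit):

import org.apache.spark.annotation.{DeveloperApi, Experimental}

/**
 * :: Experimental ::
 *
 * A public API that may change or be removed in a future minor release.
 */
@Experimental
class ExampleModel(val weights: Array[Double]) extends Serializable

/**
 * :: DeveloperApi ::
 *
 * A lower-level hook aimed at advanced users, with weaker compatibility guarantees.
 */
@DeveloperApi
abstract class ExampleGradient extends Serializable {
  def compute(data: Array[Double], label: Double): Double
}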
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala | 21
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 49
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala | 34
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala | 32
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala | 4
42 files changed, 355 insertions, 122 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2df5b0d02b..ae27c57799 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.api.python
import java.nio.{ByteBuffer, ByteOrder}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
@@ -28,8 +29,11 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.rdd.RDD
/**
+ * :: DeveloperApi ::
+ *
* The Java stubs necessary for the Python mllib bindings.
*/
+@DeveloperApi
class PythonMLLibAPI extends Serializable {
private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
val packetLength = bytes.length
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 798f3a5c94..4f9eaacf67 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -55,7 +55,7 @@ class LogisticRegressionModel(
this
}
- override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+ override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
val score = 1.0/ (1.0 + math.exp(-margin))
@@ -71,27 +71,27 @@ class LogisticRegressionModel(
* NOTE: Labels used in Logistic Regression should be {0, 1}
*/
class LogisticRegressionWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var regParam: Double,
- var miniBatchFraction: Double)
+ private var stepSize: Double,
+ private var numIterations: Int,
+ private var regParam: Double,
+ private var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
- val gradient = new LogisticGradient()
- val updater = new SimpleUpdater()
+ private val gradient = new LogisticGradient()
+ private val updater = new SimpleUpdater()
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- override val validators = List(DataValidators.classificationLabels)
+ override protected val validators = List(DataValidators.binaryLabelValidator)
/**
* Construct a LogisticRegression object with default parameters
*/
def this() = this(1.0, 100, 0.0, 1.0)
- def createModel(weights: Vector, intercept: Double) = {
+ override protected def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
}
}
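The `predictPoint` change above only tightens visibility; the scoring rule itself is unchanged: margin = weights · x + intercept and score = 1 / (1 + exp(-margin)). A standalone sketch of that rule on plain arrays (a hypothetical helper for illustration, not the MLlib API):

import scala.math.exp

// Logistic score for a single point: sigmoid of the linear margin.
def logisticScore(weights: Array[Double], x: Array[Double], intercept: Double): Double = {
  val margin = weights.zip(x).map { case (w, v) => w * v }.sum + intercept
  1.0 / (1.0 + exp(-margin))
}

logisticScore(Array(0.5, -0.25), Array(2.0, 4.0), 0.0)   // margin = 0.0, score = 0.5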
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index e956185319..5a45f12f1a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
@@ -27,11 +28,16 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
/**
+ * :: Experimental ::
+ *
* Model for Naive Bayes Classifiers.
*
- * @param pi Log of class priors, whose dimension is C.
- * @param theta Log of class conditional probabilities, whose dimension is CxD.
+ * @param labels list of labels
+ * @param pi log of class priors, whose dimension is C, number of labels
+ * @param theta log of class conditional probabilities, whose dimension is C-by-D,
+ * where D is number of features
*/
+@Experimental
class NaiveBayesModel(
val labels: Array[Double],
val pi: Array[Double],
@@ -40,14 +46,17 @@ class NaiveBayesModel(
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
- var i = 0
- while (i < theta.length) {
- var j = 0
- while (j < theta(i).length) {
- brzTheta(i, j) = theta(i)(j)
- j += 1
+ {
+ // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
+ var i = 0
+ while (i < theta.length) {
+ var j = 0
+ while (j < theta(i).length) {
+ brzTheta(i, j) = theta(i)(j)
+ j += 1
+ }
+ i += 1
}
- i += 1
}
override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
@@ -65,7 +74,7 @@ class NaiveBayesModel(
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*/
-class NaiveBayes private (var lambda: Double) extends Serializable with Logging {
+class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
def this() = this(1.0)
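The extra braces introduced in the `NaiveBayesModel` hunk above work around a Scala detail: statements written directly in a class body run as constructor code, but a top-level `var` there becomes a member field. A minimal sketch of the difference (hypothetical classes, not part of this patch):

class WithoutBraces {
  var i = 0                  // compiles to a member field of the class
  while (i < 3) { i += 1 }
}

class WithBraces {
  {
    var i = 0                // local to this initializer block; no field is generated
    while (i < 3) { i += 1 }
  }
}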
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e31a08899f..956654b1fe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -55,7 +55,9 @@ class SVMModel(
this
}
- override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match {
@@ -70,28 +72,27 @@ class SVMModel(
* NOTE: Labels used in SVM should be {0, 1}.
*/
class SVMWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var regParam: Double,
- var miniBatchFraction: Double)
+ private var stepSize: Double,
+ private var numIterations: Int,
+ private var regParam: Double,
+ private var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
- val gradient = new HingeGradient()
- val updater = new SquaredL2Updater()
+ private val gradient = new HingeGradient()
+ private val updater = new SquaredL2Updater()
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
-
- override val validators = List(DataValidators.classificationLabels)
+ override protected val validators = List(DataValidators.binaryLabelValidator)
/**
* Construct a SVM object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Vector, intercept: Double) = {
+ override protected def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index a78503df31..8f565eb60a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -37,12 +38,17 @@ import org.apache.spark.util.random.XORShiftRandom
* to it should be cached by the user.
*/
class KMeans private (
- var k: Int,
- var maxIterations: Int,
- var runs: Int,
- var initializationMode: String,
- var initializationSteps: Int,
- var epsilon: Double) extends Serializable with Logging {
+ private var k: Int,
+ private var maxIterations: Int,
+ private var runs: Int,
+ private var initializationMode: String,
+ private var initializationSteps: Int,
+ private var epsilon: Double) extends Serializable with Logging {
+
+ /**
+ * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
+ * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
+ */
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
/** Set the number of clusters to create (k). Default: 2. */
@@ -71,6 +77,8 @@ class KMeans private (
}
/**
+ * :: Experimental ::
+ *
* Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm
* this many times with random starting conditions (configured by the initialization mode), then
* return the best clustering found over any run. Default: 1.
@@ -316,8 +324,8 @@ object KMeans {
data: RDD[Vector],
k: Int,
maxIterations: Int,
- runs: Int = 1,
- initializationMode: String = K_MEANS_PARALLEL): KMeansModel = {
+ runs: Int,
+ initializationMode: String): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setRuns(runs)
@@ -326,6 +334,27 @@ object KMeans {
}
/**
+ * Trains a k-means model using specified parameters and the default values for unspecified.
+ */
+ def train(
+ data: RDD[Vector],
+ k: Int,
+ maxIterations: Int): KMeansModel = {
+ train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
+ }
+
+ /**
+ * Trains a k-means model using specified parameters and the default values for unspecified.
+ */
+ def train(
+ data: RDD[Vector],
+ k: Int,
+ maxIterations: Int,
+ runs: Int): KMeansModel = {
+ train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
+ }
+
+ /**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
private[mllib] def findClosest(
@@ -369,6 +398,10 @@ object KMeans {
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
}
+ /**
+ * :: Experimental ::
+ */
+ @Experimental
def main(args: Array[String]) {
if (args.length < 4) {
println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 2cea58cd3f..99a849f1c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -64,11 +64,13 @@ trait Vector extends Serializable {
/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
+ * We don't use the name `Vector` because Scala imports
+ * [[scala.collection.immutable.Vector]] by default.
*/
object Vectors {
/**
- * Creates a dense vector.
+ * Creates a dense vector from its values.
*/
@varargs
def dense(firstValue: Double, otherValues: Double*): Vector =
@@ -158,20 +160,21 @@ class DenseVector(val values: Array[Double]) extends Vector {
/**
* A sparse vector represented by an index array and an value array.
*
- * @param n size of the vector.
+ * @param size size of the vector.
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
-class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double]) extends Vector {
-
- override def size: Int = n
+class SparseVector(
+ override val size: Int,
+ val indices: Array[Int],
+ val values: Array[Double]) extends Vector {
override def toString: String = {
- "(" + n + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
+ "(" + size + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
}
override def toArray: Array[Double] = {
- val data = new Array[Double](n)
+ val data = new Array[Double](size)
var i = 0
val nnz = indices.length
while (i < nnz) {
@@ -181,5 +184,5 @@ class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double
data
}
- private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, n)
+ private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
}
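A short sketch of the `Vectors` factory methods and the renamed `size` parameter documented above (illustrative values only; assumes the `sparse` factory alongside `dense`):

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Dense vector from its values.
val dv: Vector = Vectors.dense(1.0, 0.0, 3.0)

// Sparse vector of size 3 with nonzeros at indices 0 and 2.
val sv: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))

dv.toArray   // Array(1.0, 0.0, 3.0)
sv.size      // 3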
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
index 9194f65749..89d5c03d76 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
import breeze.linalg.{DenseMatrix => BDM}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
@@ -32,6 +33,8 @@ import org.apache.spark.mllib.linalg.Vectors
case class MatrixEntry(i: Long, j: Long, value: Double)
/**
+ * :: Experimental ::
+ *
* Represents a matrix in coordinate format.
*
* @param entries matrix entries
@@ -40,6 +43,7 @@ case class MatrixEntry(i: Long, j: Long, value: Double)
* @param nCols number of columns. A non-positive value means unknown, and then the number of
* columns will be determined by the max column index plus one.
*/
+@Experimental
class CoordinateMatrix(
val entries: RDD[MatrixEntry],
private var nRows: Long,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
index 13f72a3c72..a0e26ce3bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
@@ -19,8 +19,6 @@ package org.apache.spark.mllib.linalg.distributed
import breeze.linalg.{DenseMatrix => BDM}
-import org.apache.spark.mllib.linalg.Matrix
-
/**
* Represents a distributively stored matrix backed by one or more RDDs.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index e110f070bd..24c123ab7e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -19,14 +19,22 @@ package org.apache.spark.mllib.linalg.distributed
import breeze.linalg.{DenseMatrix => BDM}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.SingularValueDecomposition
-/** Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */
+/**
+ * :: Experimental ::
+ *
+ * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]].
+ */
+@Experimental
case class IndexedRow(index: Long, vector: Vector)
/**
+ * :: Experimental ::
+ *
* Represents a row-oriented [[org.apache.spark.mllib.linalg.distributed.DistributedMatrix]] with
* indexed rows.
*
@@ -36,6 +44,7 @@ case class IndexedRow(index: Long, vector: Vector)
* @param nCols number of columns. A non-positive value means unknown, and then the number of
* columns will be determined by the size of the first row.
*/
+@Experimental
class IndexedRowMatrix(
val rows: RDD[IndexedRow],
private var nRows: Long,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f59811f18a..8d32c1a6db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -23,11 +23,14 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
/**
+ * :: Experimental ::
+ *
* Represents a row-oriented distributed Matrix with no meaningful row indices.
*
* @param rows rows stored as an RDD[Vector]
@@ -36,6 +39,7 @@ import org.apache.spark.Logging
* @param nCols number of columns. A non-positive value means unknown, and then the number of
* columns will be determined by the size of the first row.
*/
+@Experimental
class RowMatrix(
val rows: RDD[Vector],
private var nRows: Long,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 2065428496..1176dc9dbc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -19,11 +19,15 @@ package org.apache.spark.mllib.optimization
import breeze.linalg.{axpy => brzAxpy}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
+ * :: DeveloperApi ::
+ *
* Class used to compute the gradient for a loss function, given a single data point.
*/
+@DeveloperApi
abstract class Gradient extends Serializable {
/**
* Compute the gradient and loss given the features of a single data point.
@@ -51,9 +55,12 @@ abstract class Gradient extends Serializable {
}
/**
+ * :: DeveloperApi ::
+ *
* Compute gradient and loss for a logistic loss function, as used in binary classification.
* See also the documentation for the precise formulation.
*/
+@DeveloperApi
class LogisticGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val brzData = data.toBreeze
@@ -92,11 +99,14 @@ class LogisticGradient extends Gradient {
}
/**
+ * :: DeveloperApi ::
+ *
* Compute gradient and loss for a Least-squared loss function, as used in linear regression.
* This is correct for the averaged least squares loss function (mean squared error)
* L = 1/n ||A weights-y||^2
* See also the documentation for the precise formulation.
*/
+@DeveloperApi
class LeastSquaresGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val brzData = data.toBreeze
@@ -124,10 +134,13 @@ class LeastSquaresGradient extends Gradient {
}
/**
+ * :: DeveloperApi ::
+ *
* Compute gradient and loss for a Hinge loss function, as used in SVM binary classification.
* See also the documentation for the precise formulation.
* NOTE: This assumes that the labels are {0,1}
*/
+@DeveloperApi
class HingeGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val brzData = data.toBreeze
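As a toy analogue of the `Gradient.compute` contract above for the squared-loss case (plain arrays instead of MLlib Vectors, constant factors chosen for illustration, not the code in this patch): prediction = weights · x, loss = (prediction - label)^2, gradient = 2 · (prediction - label) · x.

// Toy squared-loss gradient mirroring compute(data, label, weights) => (gradient, loss).
def leastSquaresGradient(
    data: Array[Double],
    label: Double,
    weights: Array[Double]): (Array[Double], Double) = {
  val diff = weights.zip(data).map { case (w, x) => w * x }.sum - label
  (data.map(x => 2.0 * diff * x), diff * diff)
}

leastSquaresGradient(Array(1.0, 2.0), 5.0, Array(1.0, 1.0))   // (Array(-4.0, -8.0), 4.0)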
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index d0777ffd63..04267d967d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -19,18 +19,22 @@ package org.apache.spark.mllib.optimization
import scala.collection.mutable.ArrayBuffer
-import breeze.linalg.{Vector => BV, DenseVector => BDV}
+import breeze.linalg.{DenseVector => BDV}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
+ * :: DeveloperApi ::
+ *
* Class used to solve an optimization problem using Gradient Descent.
* @param gradient Gradient function to be used.
* @param updater Updater to be used to update weights after every iteration.
*/
-class GradientDescent(var gradient: Gradient, var updater: Updater)
+@DeveloperApi
+class GradientDescent(private var gradient: Gradient, private var updater: Updater)
extends Optimizer with Logging
{
private var stepSize: Double = 1.0
@@ -107,7 +111,12 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
-// Top-level method to run gradient descent.
+/**
+ * :: DeveloperApi ::
+ *
+ * Top-level method to run gradient descent.
+ */
+@DeveloperApi
object GradientDescent extends Logging {
/**
* Run stochastic gradient descent (SGD) in parallel using mini batches.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
index f9ce908a5f..0a313f3104 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
@@ -19,8 +19,15 @@ package org.apache.spark.mllib.optimization
import org.apache.spark.rdd.RDD
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for optimization problem solvers.
+ */
+@DeveloperApi
trait Optimizer extends Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 3b7754cd7a..e67816796c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -21,9 +21,12 @@ import scala.math._
import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
+ * :: DeveloperApi ::
+ *
* Class used to perform steps (weight update) using Gradient Descent methods.
*
* For general minimization problems, or for regularized problems of the form
@@ -35,6 +38,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
* The updater is responsible to also perform the update coming from the
* regularization term R(w) (if any regularization is used).
*/
+@DeveloperApi
abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize, iteration number and
@@ -59,9 +63,12 @@ abstract class Updater extends Serializable {
}
/**
+ * :: DeveloperApi ::
+ *
* A simple updater for gradient descent *without* any regularization.
* Uses a step-size decreasing with the square root of the number of iterations.
*/
+@DeveloperApi
class SimpleUpdater extends Updater {
override def compute(
weightsOld: Vector,
@@ -78,6 +85,8 @@ class SimpleUpdater extends Updater {
}
/**
+ * :: DeveloperApi ::
+ *
* Updater for L1 regularized problems.
* R(w) = ||w||_1
* Uses a step-size decreasing with the square root of the number of iterations.
@@ -95,6 +104,7 @@ class SimpleUpdater extends Updater {
*
* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
*/
+@DeveloperApi
class L1Updater extends Updater {
override def compute(
weightsOld: Vector,
@@ -120,10 +130,13 @@ class L1Updater extends Updater {
}
/**
+ * :: DeveloperApi ::
+ *
* Updater for L2 regularized problems.
* R(w) = 1/2 ||w||^2
* Uses a step-size decreasing with the square root of the number of iterations.
*/
+@DeveloperApi
class SquaredL2Updater extends Updater {
override def compute(
weightsOld: Vector,
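The L1 update documented above ("set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)") is soft thresholding. A standalone sketch of that step on an array of weights (hypothetical helper, independent of the Updater API):

import scala.math.{abs, max, signum}

// Shrink each weight toward zero by shrinkageVal, clamping at zero.
def softThreshold(weights: Array[Double], shrinkageVal: Double): Array[Double] =
  weights.map(w => signum(w) * max(0.0, abs(w) - shrinkageVal))

softThreshold(Array(0.8, -0.3, 0.05), 0.1)   // approximately Array(0.7, -0.2, 0.0)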
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 3124fac326..60cbb1c1e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -22,6 +22,10 @@ import scala.math.{abs, sqrt}
import scala.util.Random
import scala.util.Sorting
+import com.esotericsoftware.kryo.Kryo
+import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
+
+import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf}
import org.apache.spark.storage.StorageLevel
@@ -29,10 +33,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.SparkContext._
-import com.esotericsoftware.kryo.Kryo
-import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
-
-
/**
* Out-link information for a user or product block. This includes the original user/product IDs
* of the elements within this block, and the list of destination blocks that each user or
@@ -90,14 +90,19 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* preferences rather than explicit ratings given to items.
*/
class ALS private (
- var numBlocks: Int,
- var rank: Int,
- var iterations: Int,
- var lambda: Double,
- var implicitPrefs: Boolean,
- var alpha: Double,
- var seed: Long = System.nanoTime()
+ private var numBlocks: Int,
+ private var rank: Int,
+ private var iterations: Int,
+ private var lambda: Double,
+ private var implicitPrefs: Boolean,
+ private var alpha: Double,
+ private var seed: Long = System.nanoTime()
) extends Serializable with Logging {
+
+ /**
+ * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10,
+ * lambda: 0.01, implicitPrefs: false, alpha: 1.0}.
+ */
def this() = this(-1, 10, 10, 0.01, false, 1.0)
/**
@@ -127,11 +132,18 @@ class ALS private (
this
}
+ /** Sets whether to use implicit preference. Default: false. */
def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
this.implicitPrefs = implicitPrefs
this
}
+ /**
+ * :: Experimental ::
+ *
+ * Sets the constant used in computing confidence in implicit ALS. Default: 1.0.
+ */
+ @Experimental
def setAlpha(alpha: Double): ALS = {
this.alpha = alpha
this
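A hedged usage sketch of ALS with the default parameters documented above, assuming an existing SparkContext named `sc` and a made-up "user,product,rating" CSV file:

import org.apache.spark.mllib.recommendation.{ALS, Rating}

val ratings = sc.textFile("ratings.csv").map { line =>
  val Array(user, product, rating) = line.split(',')
  Rating(user.toInt, product.toInt, rating.toDouble)
}

// rank = 10 latent factors, 10 iterations, lambda = 0.01 (explicit ratings).
val model = ALS.train(ratings, 10, 10, 0.01)
val predictedRating = model.predict(1, 42)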
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 443fc5de5b..e05224fc7c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.recommendation
+import org.jblas._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.PythonMLLibAPI
-import org.jblas._
-import org.apache.spark.api.java.JavaRDD
-
/**
* Model representing the result of matrix factorization.
@@ -68,6 +69,8 @@ class MatrixFactorizationModel(
}
/**
+ * :: DeveloperApi ::
+ *
* Predict the rating of many users for many products.
* This is a Java stub for python predictAll()
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 80dc0f12ff..c24f5afb99 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
@@ -79,7 +80,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
- val optimizer: Optimizer
+ /** The optimizer to solve the problem. */
+ def optimizer: Optimizer
/** Whether to add intercept (default: true). */
protected var addIntercept: Boolean = true
@@ -100,8 +102,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
/**
+ * :: Experimental ::
+ *
* Set if the algorithm should validate data before training. Default true.
*/
+ @Experimental
def setValidateData(validateData: Boolean): this.type = {
this.validateData = validateData
this
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 25920d0dc9..5f0812fd2e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -52,15 +52,16 @@ class LassoModel(
* See also the documentation for the precise formulation.
*/
class LassoWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var regParam: Double,
- var miniBatchFraction: Double)
+ private var stepSize: Double,
+ private var numIterations: Int,
+ private var regParam: Double,
+ private var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel] with Serializable {
- val gradient = new LeastSquaresGradient()
- val updater = new L1Updater()
- @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ private val gradient = new LeastSquaresGradient()
+ private val updater = new L1Updater()
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
@@ -69,7 +70,8 @@ class LassoWithSGD private (
super.setIntercept(false)
/**
- * Construct a Lasso object with default parameters
+ * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
+ * regParam: 1.0, miniBatchFraction: 1.0}.
*/
def this() = this(1.0, 100, 1.0, 1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 9ed927994e..228fa8db3e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -52,19 +52,21 @@ class LinearRegressionModel(
* See also the documentation for the precise formulation.
*/
class LinearRegressionWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var miniBatchFraction: Double)
+ private var stepSize: Double,
+ private var numIterations: Int,
+ private var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
- val gradient = new LeastSquaresGradient()
- val updater = new SimpleUpdater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ private val gradient = new LeastSquaresGradient()
+ private val updater = new SimpleUpdater()
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
.setNumIterations(numIterations)
.setMiniBatchFraction(miniBatchFraction)
/**
- * Construct a LinearRegression object with default parameters
+ * Construct a LinearRegression object with default parameters: {stepSize: 1.0,
+ * numIterations: 100, miniBatchFraction: 1.0}.
*/
def this() = this(1.0, 100, 1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 1f17d2107f..e702027c7c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -52,16 +52,17 @@ class RidgeRegressionModel(
* See also the documentation for the precise formulation.
*/
class RidgeRegressionWithSGD private (
- var stepSize: Double,
- var numIterations: Int,
- var regParam: Double,
- var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable {
+ private var stepSize: Double,
+ private var numIterations: Int,
+ private var regParam: Double,
+ private var miniBatchFraction: Double)
+ extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable {
- val gradient = new LeastSquaresGradient()
- val updater = new SquaredL2Updater()
+ private val gradient = new LeastSquaresGradient()
+ private val updater = new SquaredL2Updater()
- @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
@@ -70,7 +71,8 @@ class RidgeRegressionWithSGD private (
super.setIntercept(false)
/**
- * Construct a RidgeRegression object with default parameters
+ * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
+ * regParam: 1.0, miniBatchFraction: 1.0}.
*/
def this() = this(1.0, 100, 1.0, 1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index dee9594a9d..c8a966cd5f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree
import scala.util.control.Breaks._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.regression.LabeledPoint
@@ -33,13 +34,16 @@ import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
+ * :: Experimental ::
+ *
* A class that implements a decision tree algorithm for classification and regression. It
* supports both continuous and categorical features.
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
-class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
+@Experimental
+class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
/**
* Method to train a decision tree model over an RDD
@@ -1024,7 +1028,7 @@ object DecisionTree extends Serializable with Logging {
}
}
- val usage = """
+ private val usage = """
Usage: DecisionTreeRunner <master>[slices] --algo <Classification,
Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,
Variance>] [--maxBins num]
@@ -1113,7 +1117,7 @@ object DecisionTree extends Serializable with Logging {
* @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
* the label, and the second element represents the feature values (an array of Double).
*/
- def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+ private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
sc.textFile(dir).map { line =>
val parts = line.trim().split(",")
val label = parts(0).toDouble
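A hedged usage sketch of the tree API touched above, assuming an RDD of LabeledPoints named `labeledPoints` with {0, 1} labels is already available:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

// Classification tree of depth 5 using Gini impurity; other Strategy fields keep their defaults.
val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 5)
val model = DecisionTree.train(labeledPoints, strategy)
val prediction = model.predict(labeledPoints.first().features)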
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
index 2dd1f0f27b..017f84f3b9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -17,9 +17,14 @@
package org.apache.spark.mllib.tree.configuration
+import org.apache.spark.annotation.Experimental
+
/**
+ * :: Experimental ::
+ *
* Enum to select the algorithm for the decision tree
*/
+@Experimental
object Algo extends Enumeration {
type Algo = Value
val Classification, Regression = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
index 09ee0586c5..c0254c32c2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -17,9 +17,14 @@
package org.apache.spark.mllib.tree.configuration
+import org.apache.spark.annotation.Experimental
+
/**
+ * :: Experimental ::
+ *
* Enum to describe whether a feature is "continuous" or "categorical"
*/
+@Experimental
object FeatureType extends Enumeration {
type FeatureType = Value
val Continuous, Categorical = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
index 2457a480c2..b3e8b224be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -17,9 +17,14 @@
package org.apache.spark.mllib.tree.configuration
+import org.apache.spark.annotation.Experimental
+
/**
+ * :: Experimental ::
+ *
* Enum for selecting the quantile calculation strategy
*/
+@Experimental
object QuantileStrategy extends Enumeration {
type QuantileStrategy = Value
val Sort, MinMax, ApproxHist = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index df565f3eb8..482faaa9e7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.tree.configuration
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
/**
+ * :: Experimental ::
+ *
* Stores all the configuration options for tree construction
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
@@ -34,10 +37,11 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
+@Experimental
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
+ val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index b93995fcf9..55c43f2fcf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -17,31 +17,39 @@
package org.apache.spark.mllib.tree.impurity
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+
/**
+ * :: Experimental ::
+ *
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
*/
+@Experimental
object Entropy extends Impurity {
- def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+ private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
/**
+ * :: DeveloperApi ::
+ *
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
*/
- def calculate(c0: Double, c1: Double): Double = {
- if (c0 == 0 || c1 == 0) {
- 0
- } else {
- val total = c0 + c1
- val f0 = c0 / total
- val f1 = c1 / total
- -(f0 * log2(f0)) - (f1 * log2(f1))
- }
- }
+ @DeveloperApi
+ override def calculate(c0: Double, c1: Double): Double = {
+ if (c0 == 0 || c1 == 0) {
+ 0
+ } else {
+ val total = c0 + c1
+ val f0 = c0 / total
+ val f1 = c1 / total
+ -(f0 * log2(f0)) - (f1 * log2(f1))
+ }
+ }
- def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
}
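A quick numerical check of the binary entropy formula above (values rounded; uses the `Entropy` object as annotated in this patch):

import org.apache.spark.mllib.tree.impurity.Entropy

Entropy.calculate(5.0, 5.0)    // 1.0: an evenly mixed node has maximal entropy
Entropy.calculate(10.0, 0.0)   // 0.0: a pure node, defined as zero above
Entropy.calculate(9.0, 1.0)    // approximately 0.469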
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c0407554a9..c923b8e8f4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -17,19 +17,27 @@
package org.apache.spark.mllib.tree.impurity
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+
/**
+ * :: Experimental ::
+ *
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
*/
+@Experimental
object Gini extends Impurity {
/**
+ * :: DeveloperApi ::
+ *
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return Gini coefficient value
*/
+ @DeveloperApi
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
@@ -41,6 +49,6 @@ object Gini extends Impurity {
}
}
- def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a4069063af..f407796596 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -17,26 +17,36 @@
package org.apache.spark.mllib.tree.impurity
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+
/**
+ * :: Experimental ::
+ *
* Trait for calculating information gain.
*/
+@Experimental
trait Impurity extends Serializable {
/**
+ * :: DeveloperApi ::
+ *
* information calculation for binary classification
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return information value
*/
+ @DeveloperApi
def calculate(c0 : Double, c1 : Double): Double
/**
+ * :: DeveloperApi ::
+ *
* information calculation for regression
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return information value
*/
+ @DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index b74577dcec..2c64644f4e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -17,19 +17,27 @@
package org.apache.spark.mllib.tree.impurity
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+
/**
+ * :: Experimental ::
+ *
* Class for calculating variance during regression
*/
+@Experimental
object Variance extends Impurity {
override def calculate(c0: Double, c1: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")
/**
+ * :: DeveloperApi ::
+ *
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
*/
+ @DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
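As a worked check of the variance formula above: for labels {1, 2, 3}, count = 3, sum = 6, sumSquares = 14, so squaredLoss = 14 - 36/3 = 2 and the returned value is 2/3, the population variance.

import org.apache.spark.mllib.tree.impurity.Variance

Variance.calculate(3.0, 6.0, 14.0)   // 0.666...: population variance of labels 1, 2, 3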
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index a57faa1374..2d71e1e366 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -30,4 +30,5 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
+private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a6dca84a2c..0f76f4a049 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,15 +17,19 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
/**
+ * :: Experimental ::
+ *
* Model to store the decision tree parameters
* @param topNode root node
* @param algo algorithm type -- classification or regression
*/
+@Experimental
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
index ebc9595eaf..2deaf4ae8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -22,7 +22,7 @@ package org.apache.spark.mllib.tree.model
* @param split split specifying the feature index, type and threshold
* @param comparison integer specifying <,=,>
*/
-case class Filter(split: Split, comparison: Int) {
+private[tree] case class Filter(split: Split, comparison: Int) {
// Comparison -1,0,1 signifies <.=,>
override def toString = " split = " + split + "comparison = " + comparison
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 99bf79cf12..d36b58e92c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -17,7 +17,11 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
+ *
* Information gain statistics for each split
* @param gain information gain value
* @param impurity current node impurity
@@ -25,6 +29,7 @@ package org.apache.spark.mllib.tree.model
* @param rightImpurity right node impurity
* @param predict predicted value
*/
+@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index aac3f9ce30..3399721414 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector
/**
+ * :: DeveloperApi ::
+ *
* Node in a decision tree
* @param id integer node id
* @param predict predicted value at the node
@@ -31,6 +34,7 @@ import org.apache.spark.mllib.linalg.Vector
* @param rightNode right child
* @param stats information gain stats
*/
+@DeveloperApi
class Node (
val id: Int,
val predict: Double,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 4e64a81dda..8bbb343079 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -17,20 +17,24 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
/**
+ * :: DeveloperApi ::
+ *
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
*/
+@DeveloperApi
case class Split(
feature: Int,
threshold: Double,
featureType: FeatureType,
- categories: List[Double]){
+ categories: List[Double]) {
override def toString =
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
@@ -42,7 +46,7 @@ case class Split(
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
-class DummyLowSplit(feature: Int, featureType: FeatureType)
+private[tree] class DummyLowSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MinValue, featureType, List())
/**
@@ -50,7 +54,7 @@ class DummyLowSplit(feature: Int, featureType: FeatureType)
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
-class DummyHighSplit(feature: Int, featureType: FeatureType)
+private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
/**
@@ -59,6 +63,6 @@ class DummyHighSplit(feature: Int, featureType: FeatureType)
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
-class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
+private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
index 8b55bce7c4..230c409e1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
@@ -17,23 +17,25 @@
package org.apache.spark.mllib.util
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
/**
+ * :: DeveloperApi ::
+ *
* A collection of methods used to validate data before applying ML algorithms.
*/
+@DeveloperApi
object DataValidators extends Logging {
/**
* Function to check if labels used for classification are either zero or one.
*
- * @param data - input data set that needs to be checked
- *
* @return True if labels are all zero or one, false otherwise.
*/
- val classificationLabels: RDD[LabeledPoint] => Boolean = { data =>
+ val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
if (numInvalid != 0) {
logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
index 9109189dff..e693d13703 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
@@ -19,15 +19,18 @@ package org.apache.spark.mllib.util
import scala.util.Random
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
/**
+ * :: DeveloperApi ::
+ *
* Generate test data for KMeans. This class first chooses k cluster centers
* from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian
* cluster with scale 1 around each center.
*/
-
+@DeveloperApi
object KMeansDataGenerator {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index 81e4eda2a6..140ff92869 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -22,16 +22,20 @@ import scala.util.Random
import org.jblas.DoubleMatrix
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
+ * :: DeveloperApi ::
+ *
* Generate sample data used for Linear Data. This class generates
* uniformly random values for every feature and adds Gaussian noise with mean `eps` to the
* response variable `Y`.
*/
+@DeveloperApi
object LinearDataGenerator {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
index 61498dcc2b..ca06b9ad58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
@@ -19,16 +19,19 @@ package org.apache.spark.mllib.util
import scala.util.Random
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
/**
+ * :: DeveloperApi ::
+ *
* Generate test data for LogisticRegression. This class chooses positive labels
* with probability `probOne` and scales features for positive examples by `eps`.
*/
-
+@DeveloperApi
object LogisticRegressionDataGenerator {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
index 348aba1dea..3bd86d6813 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
@@ -21,10 +21,13 @@ import scala.util.Random
import org.jblas.DoubleMatrix
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
/**
+ * :: DeveloperApi ::
+ *
* Generate RDD(s) containing data for Matrix Factorization.
*
* This method samples training entries according to the oversampling factor
@@ -47,9 +50,8 @@ import org.apache.spark.rdd.RDD
* test (Boolean) Whether to create testing RDD.
* testSampFact (Double) Percentage of training data to use as test data.
*/
-
-object MFDataGenerator{
-
+@DeveloperApi
+object MFDataGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: MFDataGenerator " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 83d1bd3fd5..7f9804deaf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.util
import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
squaredDistance => breezeSquaredDistance}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
@@ -122,6 +123,8 @@ object MLUtils {
loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
/**
+ * :: Experimental ::
+ *
* Load labeled data from a file. The data format used here is
* <L>, <f1> <f2> ...
* where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
@@ -131,6 +134,7 @@ object MLUtils {
* @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
* the label, and the second element represents the feature values (an array of Double).
*/
+ @Experimental
def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
sc.textFile(dir).map { line =>
val parts = line.split(',')
@@ -141,6 +145,8 @@ object MLUtils {
}
/**
+ * :: Experimental ::
+ *
* Save labeled data to a file. The data format used here is
* <L>, <f1> <f2> ...
* where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
@@ -148,6 +154,7 @@ object MLUtils {
* @param data An RDD of LabeledPoints containing data to be saved.
* @param dir Directory to save the data.
*/
+ @Experimental
def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" "))
dataStr.saveAsTextFile(dir)
@@ -165,7 +172,7 @@ object MLUtils {
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector standard deviation for every column (or feature) of the input data.
*/
- def computeStats(
+ private[mllib] def computeStats(
data: RDD[LabeledPoint],
numFeatures: Int,
numExamples: Long): (Double, Vector, Vector) = {
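A short sketch of the labeled-data text format the load/save helpers above read and write, assuming an existing SparkContext named `sc` (the paths are made up):

import org.apache.spark.mllib.util.MLUtils

// Each input line is "<label>,<f1> <f2> ...", for example "1.0,2.5 0.0 3.1".
val points = MLUtils.loadLabeledData(sc, "data/labeled_points.txt")
MLUtils.saveLabeledData(points, "data/labeled_points_out")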
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index e300c3dbe1..87a6f2a0c3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -21,15 +21,19 @@ import scala.util.Random
import org.jblas.DoubleMatrix
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
+ * :: DeveloperApi ::
+ *
* Generate sample data used for SVM. This class generates uniform random values
* for the features and adds Gaussian noise with weight 0.1 to generate labels.
*/
+@DeveloperApi
object SVMDataGenerator {
def main(args: Array[String]) {