author    Josh Rosen <joshrosen@databricks.com>  2016-04-14 16:43:28 -0700
committer Josh Rosen <joshrosen@databricks.com>  2016-04-14 16:43:28 -0700
commit    ee4090b60e8b6a350913d1d5049f0770c251cd4a (patch)
tree      7e082fa815430c23e0387461be0726cc3e4d04b5 /mllib/src/main/scala/org
parent    2407f5b14edcdcf750113766d82e78732f9852d6 (diff)
parent    d7e124edfe2578ecdf8e816a4dda3ce430a09172 (diff)
Merge remote-tracking branch 'origin/master' into build-for-2.12
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 662
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala | 124
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 149
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 54
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala | 89
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 179
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 124
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala | 67
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala | 311
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 50
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 135
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 57
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 124
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 79
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala | 85
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 167
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 128
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 155
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 44
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 77
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 121
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 110
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala) | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala) | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala) | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala | 100
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 114
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala) | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 266
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 147
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 169
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala | 135
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala | 117
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala | 117
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 82
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala | 98
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 39
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 96
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala | 47
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 181
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 195
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala | 150
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala | 47
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 73
109 files changed, 3661 insertions, 2160 deletions
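
The recurring change across this merge is the widening of public fit/transform/train signatures from DataFrame to Dataset[_]. A minimal sketch, not part of the commit, of why the widened signature accepts both untyped and typed inputs; it relies only on the fact that Spark 2.0 defines DataFrame as an alias for Dataset[Row]:

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Minimal stand-in for the widened Estimator/Transformer signatures in this diff.
// Dataset[_] only needs .schema and .toDF, which is what the pipeline code uses.
def describe(dataset: Dataset[_]): Unit = {
  val df: DataFrame = dataset.toDF()
  println(df.schema.treeString)
}

val spark = SparkSession.builder().master("local[*]").appName("dataset-sketch").getOrCreate()
import spark.implicits._

val asDataFrame: DataFrame = Seq((1.0, 0.5)).toDF("label", "feature") // DataFrame = Dataset[Row]
val asTyped: Dataset[(Double, Double)] = Seq((1.0, 0.5)).toDS()       // typed Dataset

describe(asDataFrame) // both calls compile against the single Dataset[_] overload
describe(asTyped)
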
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 57e416591d..1247882d6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param.{ParamMap, ParamPair}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
/**
* :: DeveloperApi ::
@@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* Estimator's embedded ParamMap.
* @return fitted model
*/
+ @Since("2.0.0")
@varargs
- def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+ def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
val map = new ParamMap()
.put(firstParamPair)
.put(otherParamPairs: _*)
@@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
- def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
copy(paramMap).fit(dataset)
}
/**
* Fits a model to the input data.
*/
- def fit(dataset: DataFrame): M
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_]): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
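
For context, a minimal sketch of how the @Since("2.0.0") varargs overload above is typically invoked; the estimator, parameter values, and the training Dataset are illustrative assumptions, not taken from this commit:

import org.apache.spark.ml.classification.LogisticRegression

// training is a hypothetical Dataset[_] with "label" and "features" columns.
val lr = new LogisticRegression()
// The extra ParamPairs override lr's embedded ParamMap for this call only.
val model = lr.fit(training, lr.maxIter -> 20, lr.regParam -> 0.01)
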
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3a99979a88..82066726a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
/**
@@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") (
* @param dataset input dataset
* @return fitted pipeline
*/
- @Since("1.2.0")
- override def fit(dataset: DataFrame): PipelineModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PipelineModel = {
transformSchema(dataset.schema, logging = true)
val theStages = $(stages)
// Search for the last estimator.
@@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") (
t
case _ =>
throw new IllegalArgumentException(
- s"Do not support stage $stage of type ${stage.getClass}")
+ s"Does not support stage $stage of type ${stage.getClass}")
}
if (index < indexOfLastEstimator) {
curDataset = transformer.transform(curDataset)
@@ -291,10 +291,10 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}
- @Since("1.2.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
+ stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
}
@Since("1.2.0")
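
The transform above now seeds the foldLeft with dataset.toDF, so the accumulator stays a DataFrame regardless of the caller's element type. A minimal sketch of the same fold, with stages modeled as plain DataFrame => DataFrame functions (an illustrative simplification, not the commit's code):

import org.apache.spark.sql.{DataFrame, Dataset}

// Each stage is modeled as a DataFrame => DataFrame function; PipelineModel does the
// equivalent fold with Transformer.transform, starting from dataset.toDF.
def runStages(dataset: Dataset[_], stages: Seq[DataFrame => DataFrame]): DataFrame =
  stages.foldLeft(dataset.toDF())((cur, stage) => stage(cur))
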
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ebe48700f8..81140d1f7b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
/**
* Validates and transforms the input schema with the provided param map.
+ *
* @param schema input schema
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
- // TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
@@ -83,7 +83,7 @@ abstract class Predictor[
/** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
- override def fit(dataset: DataFrame): M = {
+ override def fit(dataset: Dataset[_]): M = {
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
@@ -100,7 +100,7 @@ abstract class Predictor[
* @param dataset Training dataset
* @return Fitted model
*/
- protected def train(dataset: DataFrame): M
+ protected def train(dataset: Dataset[_]): M
/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
@@ -120,10 +120,9 @@ abstract class Predictor[
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
- protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
- dataset.select($(labelCol), $(featuresCol)).rdd.map {
- case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
+ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
}
@@ -172,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
* @param dataset input dataset
* @return transformed dataset with [[predictionCol]] of type [[Double]]
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
- dataset
+ dataset.toDF
}
}
- protected def transformImpl(dataset: DataFrame): DataFrame = {
+ protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
}
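
extractLabeledPoints now casts the label column to DoubleType instead of requiring it to already be a Double, matching the checkNumericType relaxation above. A minimal sketch of the same cast-then-match pattern on a toy DataFrame (column names and data are illustrative):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[*]").appName("cast-sketch").getOrCreate()
import spark.implicits._

// Integer labels would not have matched the old DoubleType-only schema check.
val data = Seq((1, 0.5), (0, 1.5)).toDF("label", "feature")
val labeled = data.select(col("label").cast(DoubleType), col("feature")).rdd.map {
  case Row(label: Double, feature: Double) => (label, feature)
}
labeled.collect().foreach(println)
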
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 2538c0f477..a3a2b55adc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage {
* @param otherParamPairs other param pairs, overwrite embedded params
* @return transformed dataset
*/
+ @Since("2.0.0")
@varargs
def transform(
- dataset: DataFrame,
+ dataset: Dataset[_],
firstParamPair: ParamPair[_],
otherParamPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
@@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ @Since("2.0.0")
+ def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
this.copy(paramMap).transform(dataset)
}
/**
* Transforms the input dataset.
*/
- def transform(dataset: DataFrame): DataFrame
+ @Since("2.0.0")
+ def transform(dataset: Dataset[_]): DataFrame
override def copy(extra: ParamMap): Transformer
}
@@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
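
UnaryTransformer.transform wraps createTransformFunc in a UDF and appends one output column, as the last hunk shows. A standalone sketch of that pattern, using a hypothetical whitespace tokenizer as the transform function:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().master("local[*]").appName("udf-sketch").getOrCreate()
import spark.implicits._

val df = Seq("spark ml pipelines", "dataset api").toDF("text")
// Stand-in for createTransformFunc; the UDF's output type is inferred from the function.
val tokenize = udf((s: String) => s.split("\\s+").toSeq)
df.withColumn("tokens", tokenize(df("text"))).show(false)
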
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 2cd94fa8f5..a5b84116e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.ann
-import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV,
- Vector => BV}
-import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
+import java.util.Random
+
+import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization._
@@ -32,20 +32,46 @@ import org.apache.spark.util.random.XORShiftRandom
*
*/
private[ann] trait Layer extends Serializable {
+
/**
- * Returns the instance of the layer based on weights provided
- * @param weights vector with layer weights
- * @param position position of weights in the vector
- * @return the layer model
+ * Number of weights that is used to allocate memory for the weights vector
+ */
+ val weightSize: Int
+
+ /**
+ * Returns the output size given the input size (not counting the stack size).
+ * Output size is used to allocate memory for the output.
+ *
+ * @param inputSize input size
+ * @return output size
*/
- def getInstance(weights: Vector, position: Int): LayerModel
+ def getOutputSize(inputSize: Int): Int
/**
+ * If true, the memory is not allocated for the output of this layer.
+ * The memory allocated to the previous layer is used to write the output of this layer.
+ * Developer can set this to true if computing delta of a previous layer
+ * does not involve its output, so the current layer can write there.
+ * This also means that both layers have the same number of outputs.
+ */
+ val inPlace: Boolean
+
+ /**
+ * Returns the instance of the layer based on weights provided.
+ * Size of weights must be equal to weightSize
+ *
+ * @param initialWeights vector with layer weights
+ * @return the layer model
+ */
+ def createModel(initialWeights: BDV[Double]): LayerModel
+ /**
* Returns the instance of the layer with random generated weights
- * @param seed seed
+ *
+ * @param weights vector for weights initialization, must be equal to weightSize
+ * @param random random number generator
* @return the layer model
*/
- def getInstance(seed: Long): LayerModel
+ def initModel(weights: BDV[Double], random: Random): LayerModel
}
/**
@@ -54,92 +80,102 @@ private[ann] trait Layer extends Serializable {
* Can return weights in Vector format.
*/
private[ann] trait LayerModel extends Serializable {
- /**
- * number of weights
- */
- val size: Int
+ val weights: BDV[Double]
/**
* Evaluates the data (process the data through the layer)
+ * Output is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size
+ * Developer is responsible for checking the size of output
+ * when writing to it
+ *
* @param data data
- * @return processed data
+ * @param output output (modified in place)
*/
- def eval(data: BDM[Double]): BDM[Double]
+ def eval(data: BDM[Double], output: BDM[Double]): Unit
/**
* Computes the delta for back propagation
- * @param nextDelta delta of the next layer
- * @param input input data
- * @return delta
+ * Delta is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size
+ * Developer is responsible for checking the size of
+ * prevDelta when writing to it
+ *
+ * @param delta delta of this layer
+ * @param output output of this layer
+ * @param prevDelta the previous delta (modified in place)
*/
- def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]
+ def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit
/**
* Computes the gradient
+ * cumGrad is a wrapper on the part of the weight vector
+ * size of cumGrad is based on weightSize provided by
+ * implementation of LayerModel
+ *
* @param delta delta for this layer
* @param input input data
- * @return gradient
+ * @param cumGrad cumulative gradient (modified in place)
*/
- def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]
-
- /**
- * Returns weights for the layer in a single vector
- * @return layer weights
- */
- def weights(): Vector
+ def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit
}
/**
* Layer properties of affine transformations, that is y=A*x+b
+ *
* @param numIn number of inputs
* @param numOut number of outputs
*/
private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = {
- AffineLayerModel(this, weights, position)
- }
+ override val weightSize = numIn * numOut + numOut
- override def getInstance(seed: Long = 11L): LayerModel = {
- AffineLayerModel(this, seed)
- }
+ override def getOutputSize(inputSize: Int): Int = numOut
+
+ override val inPlace = false
+
+ override def createModel(weights: BDV[Double]): LayerModel = new AffineLayerModel(weights, this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ AffineLayerModel(this, weights, random)
}
/**
- * Model of Affine layer y=A*x+b
- * @param w weights (matrix A)
- * @param b bias (vector b)
+ * Model of Affine layer
+ *
+ * @param weights weights
+ * @param layer layer properties
*/
-private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel {
- val size = w.size + b.length
- val gwb = new Array[Double](size)
- private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb)
- private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size)
- private var z: BDM[Double] = null
- private var d: BDM[Double] = null
+private[ann] class AffineLayerModel private[ann] (
+ val weights: BDV[Double],
+ val layer: AffineLayer) extends LayerModel {
+ val w = new BDM[Double](layer.numOut, layer.numIn, weights.data, weights.offset)
+ val b =
+ new BDV[Double](weights.data, weights.offset + (layer.numOut * layer.numIn), 1, layer.numOut)
+
private var ones: BDV[Double] = null
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols)
- z(::, *) := b
- BreezeUtil.dgemm(1.0, w, data, 1.0, z)
- z
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ output(::, *) := b
+ BreezeUtil.dgemm(1.0, w, data, 1.0, output)
}
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols)
- BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d)
- d
+ override def computePrevDelta(
+ delta: BDM[Double],
+ output: BDM[Double],
+ prevDelta: BDM[Double]): Unit = {
+ BreezeUtil.dgemm(1.0, w.t, delta, 0.0, prevDelta)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = {
- BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ // compute gradient of weights
+ val cumGradientOfWeights = new BDM[Double](w.rows, w.cols, cumGrad.data, cumGrad.offset)
+ BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 1.0, cumGradientOfWeights)
if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols)
- BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb)
- gwb
+ // compute gradient of bias
+ val cumGradientOfBias = new BDV[Double](cumGrad.data, cumGrad.offset + w.size, 1, b.length)
+ BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 1.0, cumGradientOfBias)
}
-
- override def weights(): Vector = AffineLayerModel.roll(w, b)
}
/**
@@ -149,73 +185,40 @@ private[ann] object AffineLayerModel {
/**
* Creates a model of Affine layer
+ *
* @param layer layer properties
- * @param weights vector with weights
- * @param position position of weights in the vector
- * @return model of Affine layer
- */
- def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = {
- val (w, b) = unroll(weights, position, layer.numIn, layer.numOut)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Creates a model of Affine layer
- * @param layer layer properties
- * @param seed seed
+ * @param weights vector for weights initialization
+ * @param random random number generator
* @return model of Affine layer
*/
- def apply(layer: AffineLayer, seed: Long): AffineLayerModel = {
- val (w, b) = randomWeights(layer.numIn, layer.numOut, seed)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Unrolls the weights from the vector
- * @param weights vector with weights
- * @param position position of weights for this layer
- * @param numIn number of layer inputs
- * @param numOut number of layer outputs
- * @return matrix A and vector b
- */
- def unroll(
- weights: Vector,
- position: Int,
- numIn: Int,
- numOut: Int): (BDM[Double], BDV[Double]) = {
- val weightsCopy = weights.toArray
- // TODO: the array is not copied to BDMs, make sure this is OK!
- val a = new BDM[Double](numOut, numIn, weightsCopy, position)
- val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut)
- (a, b)
- }
-
- /**
- * Roll the layer weights into a vector
- * @param a matrix A
- * @param b vector b
- * @return vector of weights
- */
- def roll(a: BDM[Double], b: BDV[Double]): Vector = {
- val result = new Array[Double](a.size + b.length)
- // TODO: make sure that we need to copy!
- System.arraycopy(a.toArray, 0, result, 0, a.size)
- System.arraycopy(b.toArray, 0, result, a.size, b.length)
- Vectors.dense(result)
+ def apply(layer: AffineLayer, weights: BDV[Double], random: Random): AffineLayerModel = {
+ randomWeights(layer.numIn, layer.numOut, weights, random)
+ new AffineLayerModel(weights, layer)
}
/**
- * Generate random weights for the layer
- * @param numIn number of inputs
+ * Initialize weights randomly in the interval
+ * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)]
+ * where a is chosen in such a way that the weight variance corresponds
+ * to the point of maximal curvature of the activation function
+ * (which is approximately 2.38 for a standard sigmoid)
+ *
+ * @param numIn number of inputs
* @param numOut number of outputs
- * @param seed seed
- * @return (matrix A, vector b)
+ * @param weights vector for weights initialization
+ * @param random random number generator
*/
- def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = {
- val rand: XORShiftRandom = new XORShiftRandom(seed)
- val weights = BDM.fill[Double](numOut, numIn) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- val bias = BDV.fill[Double](numOut) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- (weights, bias)
+ def randomWeights(
+ numIn: Int,
+ numOut: Int,
+ weights: BDV[Double],
+ random: Random): Unit = {
+ var i = 0
+ val sqrtIn = math.sqrt(numIn)
+ while (i < weights.length) {
+ weights(i) = (random.nextDouble * 4.8 - 2.4) / sqrtIn
+ i += 1
+ }
}
}
@@ -226,44 +229,21 @@ private[ann] trait ActivationFunction extends Serializable {
/**
* Implements a function
- * @param x input data
- * @param y output data
*/
- def eval(x: BDM[Double], y: BDM[Double]): Unit
+ def eval: Double => Double
/**
* Implements a derivative of a function (needed for the back propagation)
- * @param x input data
- * @param y output data
*/
- def derivative(x: BDM[Double], y: BDM[Double]): Unit
-
- /**
- * Implements a cross entropy error of a function.
- * Needed if the functional layer that contains this function is the output layer
- * of the network.
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return cross-entropy
- */
- def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
-
- /**
- * Implements a mean squared error of a function
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return mean squared error
- */
- def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
+ def derivative: Double => Double
}
/**
- * Implements in-place application of functions
+ * Implements in-place application of functions in the arrays
*/
-private[ann] object ActivationFunction {
+private[ann] object ApplyInPlace {
+ // TODO: use Breeze UFunc
def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
var i = 0
while (i < x.rows) {
@@ -276,6 +256,7 @@ private[ann] object ActivationFunction {
}
}
+ // TODO: use Breeze UFunc
def apply(
x1: BDM[Double],
x2: BDM[Double],
@@ -294,179 +275,86 @@ private[ann] object ActivationFunction {
}
/**
- * Implements SoftMax activation function
- */
-private[ann] class SoftmaxFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- var j = 0
- // find max value to make sure later that exponent is computable
- while (j < x.cols) {
- var i = 0
- var max = Double.MinValue
- while (i < x.rows) {
- if (x(i, j) > max) {
- max = x(i, j)
- }
- i += 1
- }
- var sum = 0.0
- i = 0
- while (i < x.rows) {
- val res = Math.exp(x(i, j) - max)
- y(i, j) = res
- sum += res
- i += 1
- }
- i = 0
- while (i < x.rows) {
- y(i, j) /= sum
- i += 1
- }
- j += 1
- }
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum( target :* Blog(output)) / output.cols
- }
-
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
-
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.")
- }
-}
-
-/**
* Implements Sigmoid activation function
*/
private[ann] class SigmoidFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- def s(z: Double): Double = Bsigmoid(z)
- ActivationFunction(x, y, s)
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum(target :* Blog(output)) / output.cols
- }
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
+ override def eval: (Double) => Double = x => 1.0 / (1 + math.exp(-x))
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- // TODO: make it readable
- def m(o: Double, t: Double): Double = (o - t)
- ActivationFunction(output, target, result, m)
- val e = Bsum(result :* result) / 2 / output.cols
- def m2(x: Double, o: Double) = x * (o - o * o)
- ActivationFunction(result, output, result, m2)
- e
- }
+ override def derivative: (Double) => Double = z => (1 - z) * z
}
/**
* Functional layer properties, y = f(x)
+ *
* @param activationFunction activation function
*/
private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L)
- override def getInstance(seed: Long): LayerModel =
- FunctionalLayerModel(this)
+ override val weightSize = 0
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+
+ override val inPlace = true
+
+ override def createModel(weights: BDV[Double]): LayerModel = new FunctionalLayerModel(this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ createModel(weights)
}
/**
* Functional layer model. Holds no weights.
- * @param activationFunction activation function
+ *
+ * @param layer functional layer
*/
-private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction)
+private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer)
extends LayerModel {
- val size = 0
- // matrices for in-place computations
- // outputs
- private var f: BDM[Double] = null
- // delta
- private var d: BDM[Double] = null
- // matrix for error computation
- private var e: BDM[Double] = null
- // delta gradient
- private lazy val dg = new Array[Double](0)
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols)
- activationFunction.eval(data, f)
- f
- }
+ // empty weights
+ val weights = new BDV[Double](0)
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols)
- activationFunction.derivative(input, d)
- d :*= nextDelta
- d
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ ApplyInPlace(data, output, layer.activationFunction.eval)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg
-
- override def weights(): Vector = Vectors.dense(new Array[Double](0))
-
- def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.crossEntropy(output, target, e)
- (e, error)
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ ApplyInPlace(input, delta, layer.activationFunction.derivative)
+ delta :*= nextDelta
}
- def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.squared(output, target, e)
- (e, error)
- }
-
- def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- // TODO: allow user pick error
- activationFunction match {
- case sigmoid: SigmoidFunction => squared(output, target)
- case softmax: SoftmaxFunction => crossEntropy(output, target)
- }
- }
-}
-
-/**
- * Fabric of functional layer models
- */
-private[ann] object FunctionalLayerModel {
- def apply(layer: FunctionalLayer): FunctionalLayerModel =
- new FunctionalLayerModel(layer.activationFunction)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {}
}
/**
* Trait for the artificial neural network (ANN) topology properties
*/
-private[ann] trait Topology extends Serializable{
- def getInstance(weights: Vector): TopologyModel
- def getInstance(seed: Long): TopologyModel
+private[ann] trait Topology extends Serializable {
+ def model(weights: Vector): TopologyModel
+ def model(seed: Long): TopologyModel
}
/**
* Trait for ANN topology model
*/
-private[ann] trait TopologyModel extends Serializable{
+private[ann] trait TopologyModel extends Serializable {
+
+ val weights: Vector
+ /**
+ * Array of layers
+ */
+ val layers: Array[Layer]
+
+ /**
+ * Array of layer models
+ */
+ val layerModels: Array[LayerModel]
/**
* Forward propagation
+ *
* @param data input data
* @return array of outputs for each of the layers
*/
@@ -474,6 +362,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Prediction of the model
+ *
* @param data input data
* @return prediction
*/
@@ -481,6 +370,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Computes gradient for the network
+ *
* @param data input data
* @param target target output
* @param cumGradient cumulative gradient
@@ -489,22 +379,17 @@ private[ann] trait TopologyModel extends Serializable{
*/
def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector,
blockSize: Int): Double
-
- /**
- * Returns the weights of the ANN
- * @return weights
- */
- def weights(): Vector
}
/**
* Feed forward ANN
+ *
* @param layers
*/
private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology {
- override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
+ override def model(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
- override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed)
+ override def model(seed: Long): TopologyModel = FeedForwardModel(this, seed)
}
/**
@@ -513,6 +398,7 @@ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends
private[ml] object FeedForwardTopology {
/**
* Creates a feed forward topology from the array of layers
+ *
* @param layers array of layers
* @return feed forward topology
*/
@@ -522,18 +408,26 @@ private[ml] object FeedForwardTopology {
/**
* Creates a multi-layer perceptron
+ *
* @param layerSizes sizes of layers including input and output size
- * @param softmax whether to use SoftMax or Sigmoid function for an output layer.
+ * @param softmaxOnTop whether to use SoftMax or Sigmoid function for an output layer.
* Softmax is default
* @return multilayer perceptron topology
*/
- def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = {
+ def multiLayerPerceptron(
+ layerSizes: Array[Int],
+ softmaxOnTop: Boolean = true): FeedForwardTopology = {
val layers = new Array[Layer]((layerSizes.length - 1) * 2)
- for(i <- 0 until layerSizes.length - 1) {
+ for (i <- 0 until layerSizes.length - 1) {
layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1))
layers(i * 2 + 1) =
- if (softmax && i == layerSizes.length - 2) {
- new FunctionalLayer(new SoftmaxFunction())
+ if (i == layerSizes.length - 2) {
+ if (softmaxOnTop) {
+ new SoftmaxLayerWithCrossEntropyLoss()
+ } else {
+ // TODO: squared error is more natural but converges slower
+ new SigmoidLayerWithSquaredError()
+ }
} else {
new FunctionalLayer(new SigmoidFunction())
}
@@ -545,17 +439,45 @@ private[ml] object FeedForwardTopology {
/**
* Model of Feed Forward Neural Network.
* Implements forward, gradient computation and can return weights in vector format.
- * @param layerModels models of layers
- * @param topology topology of the network
+ *
+ * @param weights network weights
+ * @param topology network topology
*/
private[ml] class FeedForwardModel private(
- val layerModels: Array[LayerModel],
+ val weights: Vector,
val topology: FeedForwardTopology) extends TopologyModel {
+
+ val layers = topology.layers
+ val layerModels = new Array[LayerModel](layers.length)
+ private var offset = 0
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).createModel(
+ new BDV[Double](weights.toArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
+ }
+ private var outputs: Array[BDM[Double]] = null
+ private var deltas: Array[BDM[Double]] = null
+
override def forward(data: BDM[Double]): Array[BDM[Double]] = {
- val outputs = new Array[BDM[Double]](layerModels.length)
- outputs(0) = layerModels(0).eval(data)
+ // Initialize output arrays for all layers. Special treatment for InPlace
+ val currentBatchSize = data.cols
+ // TODO: allocate outputs as one big array and then create BDMs from it
+ if (outputs == null || outputs(0).cols != currentBatchSize) {
+ outputs = new Array[BDM[Double]](layers.length)
+ var inputSize = data.rows
+ for (i <- 0 until layers.length) {
+ if (layers(i).inPlace) {
+ outputs(i) = outputs(i - 1)
+ } else {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ outputs(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
+ }
+ layerModels(0).eval(data, outputs(0))
for (i <- 1 until layerModels.length) {
- outputs(i) = layerModels(i).eval(outputs(i-1))
+ layerModels(i).eval(outputs(i - 1), outputs(i))
}
outputs
}
@@ -566,54 +488,36 @@ private[ml] class FeedForwardModel private(
cumGradient: Vector,
realBatchSize: Int): Double = {
val outputs = forward(data)
- val deltas = new Array[BDM[Double]](layerModels.length)
+ val currentBatchSize = data.cols
+ // TODO: allocate deltas as one big array and then create BDMs from it
+ if (deltas == null || deltas(0).cols != currentBatchSize) {
+ deltas = new Array[BDM[Double]](layerModels.length)
+ var inputSize = data.rows
+ for (i <- 0 until layerModels.length - 1) {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ deltas(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
val L = layerModels.length - 1
- val (newE, newError) = layerModels.last match {
- case flm: FunctionalLayerModel => flm.error(outputs.last, target)
+ // TODO: explain why delta of top layer is null (because it might contain loss+layer)
+ val loss = layerModels.last match {
+ case levelWithError: LossFunction => levelWithError.loss(outputs.last, target, deltas(L - 1))
case _ =>
- throw new UnsupportedOperationException("Non-functional layer not supported at the top")
+ throw new UnsupportedOperationException("Top layer is required to have objective.")
}
- deltas(L) = new BDM[Double](0, 0)
- deltas(L - 1) = newE
for (i <- (L - 2) to (0, -1)) {
- deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1))
- }
- val grads = new Array[Array[Double]](layerModels.length)
- for (i <- 0 until layerModels.length) {
- val input = if (i==0) data else outputs(i - 1)
- grads(i) = layerModels(i).grad(deltas(i), input)
+ layerModels(i + 1).computePrevDelta(deltas(i + 1), outputs(i + 1), deltas(i))
}
- // update cumGradient
val cumGradientArray = cumGradient.toArray
var offset = 0
- // TODO: extract roll
- for (i <- 0 until grads.length) {
- val gradArray = grads(i)
- var k = 0
- while (k < gradArray.length) {
- cumGradientArray(offset + k) += gradArray(k)
- k += 1
- }
- offset += gradArray.length
- }
- newError
- }
-
- // TODO: do we really need to copy the weights? they should be read-only
- override def weights(): Vector = {
- // TODO: extract roll
- var size = 0
- for (i <- 0 until layerModels.length) {
- size += layerModels(i).size
- }
- val array = new Array[Double](size)
- var offset = 0
for (i <- 0 until layerModels.length) {
- val layerWeights = layerModels(i).weights().toArray
- System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
- offset += layerWeights.length
+ val input = if (i == 0) data else outputs(i - 1)
+ layerModels(i).grad(deltas(i), input,
+ new BDV[Double](cumGradientArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
}
- Vectors.dense(array)
+ loss
}
override def predict(data: Vector): Vector = {
@@ -630,23 +534,19 @@ private[ann] object FeedForwardModel {
/**
* Creates a model from a topology and weights
+ *
* @param topology topology
* @param weights weights
* @return model
*/
def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
- val layers = topology.layers
- val layerModels = new Array[LayerModel](layers.length)
- var offset = 0
- for (i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(weights, offset)
- offset += layerModels(i).size
- }
- new FeedForwardModel(layerModels, topology)
+ // TODO: check that weights size is equal to sum of layers sizes
+ new FeedForwardModel(weights, topology)
}
/**
* Creates a model given a topology and seed
+ *
* @param topology topology
* @param seed seed for generating the weights
* @return model
@@ -654,17 +554,25 @@ private[ann] object FeedForwardModel {
def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
val layers = topology.layers
val layerModels = new Array[LayerModel](layers.length)
+ var totalSize = 0
+ for (i <- 0 until topology.layers.length) {
+ totalSize += topology.layers(i).weightSize
+ }
+ val weights = BDV.zeros[Double](totalSize)
var offset = 0
- for(i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(seed)
- offset += layerModels(i).size
+ val random = new XORShiftRandom(seed)
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).
+ initModel(new BDV[Double](weights.data, offset, 1, layers(i).weightSize), random)
+ offset += layers(i).weightSize
}
- new FeedForwardModel(layerModels, topology)
+ new FeedForwardModel(Vectors.fromBreeze(weights), topology)
}
}
/**
* Neural network gradient. Does nothing but calling Model's gradient
+ *
* @param topology topology
* @param dataStacker data stacker
*/
@@ -682,7 +590,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
weights: Vector,
cumGradient: Vector): Double = {
val (input, target, realBatchSize) = dataStacker.unstack(data)
- val model = topology.getInstance(weights)
+ val model = topology.model(weights)
model.computeGradient(input, target, cumGradient, realBatchSize)
}
}
@@ -692,6 +600,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
* through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks
* or matrices of inputs and outputs and then stack them in one vector.
* This can be used for further batch computations after unstacking.
+ *
* @param stackSize stack size
* @param inputSize size of the input vectors
* @param outputSize size of the output vectors
@@ -701,6 +610,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Stacks the data
+ *
* @param data RDD of vector pairs
* @return RDD of double (always zero) and vector that contains the stacked vectors
*/
@@ -733,6 +643,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Unstack the stacked vectors into matrices for batch operations
+ *
* @param data stacked vector
* @return pair of matrices holding input and output data and the real stack size
*/
@@ -765,6 +676,7 @@ private[ann] class ANNUpdater extends Updater {
/**
* MLlib-style trainer class that trains a network given the data and topology
+ *
* @param topology topology of ANN
* @param inputSize input size
* @param outputSize output size
@@ -774,8 +686,8 @@ private[ml] class FeedForwardTrainer(
val inputSize: Int,
val outputSize: Int) extends Serializable {
- // TODO: what if we need to pass random seed?
- private var _weights = topology.getInstance(11L).weights()
+ private var _seed = this.getClass.getName.hashCode.toLong
+ private var _weights: Vector = null
private var _stackSize = 128
private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize)
private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
@@ -783,27 +695,41 @@ private[ml] class FeedForwardTrainer(
private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100)
/**
+ * Returns seed
+ */
+ def getSeed: Long = _seed
+
+ /**
+ * Sets seed
+ */
+ def setSeed(value: Long): this.type = {
+ _seed = value
+ this
+ }
+
+ /**
* Returns weights
- * @return weights
*/
def getWeights: Vector = _weights
/**
* Sets weights
+ *
* @param value weights
* @return trainer
*/
- def setWeights(value: Vector): FeedForwardTrainer = {
+ def setWeights(value: Vector): this.type = {
_weights = value
this
}
/**
* Sets the stack size
+ *
* @param value stack size
* @return trainer
*/
- def setStackSize(value: Int): FeedForwardTrainer = {
+ def setStackSize(value: Int): this.type = {
_stackSize = value
dataStacker = new DataStacker(value, inputSize, outputSize)
this
@@ -811,6 +737,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the SGD optimizer
+ *
* @return SGD optimizer
*/
def SGDOptimizer: GradientDescent = {
@@ -821,6 +748,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the LBFGS optimizer
+ *
* @return LBGS optimizer
*/
def LBFGSOptimizer: LBFGS = {
@@ -831,10 +759,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the updater
+ *
* @param value updater
* @return trainer
*/
- def setUpdater(value: Updater): FeedForwardTrainer = {
+ def setUpdater(value: Updater): this.type = {
_updater = value
updateUpdater(value)
this
@@ -842,10 +771,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the gradient
+ *
* @param value gradient
* @return trainer
*/
- def setGradient(value: Gradient): FeedForwardTrainer = {
+ def setGradient(value: Gradient): this.type = {
_gradient = value
updateGradient(value)
this
@@ -871,12 +801,20 @@ private[ml] class FeedForwardTrainer(
/**
* Trains the ANN
+ *
* @param data RDD of input and output vector pairs
* @return model
*/
def train(data: RDD[(Vector, Vector)]): TopologyModel = {
- val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
- topology.getInstance(newWeights)
+ val w = if (getWeights == null) {
+ // TODO: will make a copy if vector is a subvector of BDV (see Vectors code)
+ topology.model(_seed).weights
+ } else {
+ getWeights
+ }
+ // TODO: deprecate standard optimizer because it needs Vector
+ val newWeights = optimizer.optimize(dataStacker.stack(data), w)
+ topology.model(newWeights)
}
}
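
The new AffineLayerModel.randomWeights fills the weight vector uniformly in roughly [-2.4/sqrt(numIn), 2.4/sqrt(numIn)], the Bottou-88 heuristic its comment cites. A self-contained sketch of that initialization, without Breeze (function name is illustrative):

import java.util.Random

// Fill weights in place, uniform in [-2.4/sqrt(numIn), 2.4/sqrt(numIn)];
// 2.4 approximates the sigmoid's point of maximal curvature, as the Layer.scala comment notes.
def bottouInit(weights: Array[Double], numIn: Int, random: Random): Unit = {
  val sqrtIn = math.sqrt(numIn)
  var i = 0
  while (i < weights.length) {
    weights(i) = (random.nextDouble() * 4.8 - 2.4) / sqrtIn
    i += 1
  }
}

val w = new Array[Double](6)
bottouInit(w, numIn = 3, random = new Random(11L))
println(w.mkString(", "))
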
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
new file mode 100644
index 0000000000..32d78e9b22
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import java.util.Random
+
+import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV}
+import breeze.numerics.{log => brzlog}
+
+/**
+ * Trait for loss function
+ */
+private[ann] trait LossFunction {
+ /**
+ * Returns the value of loss function.
+ * Computes loss based on target and output.
+ * Writes delta (error) to delta in place.
+ * Delta is allocated based on the outputSize
+ * of model implementation.
+ *
+ * @param output actual output
+ * @param target target output
+ * @param delta delta (updated in place)
+ * @return loss
+ */
+ def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double
+}
+
+private[ann] class SigmoidLayerWithSquaredError extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+}
+
+private[ann] class SigmoidLayerModelWithSquaredError
+ extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction {
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ val error = Bsum(delta :* delta) / 2 / output.cols
+ ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o))
+ error
+ }
+}
+
+private[ann] class SoftmaxLayerWithCrossEntropyLoss extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+}
+
+private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with LossFunction {
+
+ // loss layer models do not have weights
+ val weights = new BDV[Double](0)
+
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ var j = 0
+ // find max value to make sure later that exponent is computable
+ while (j < data.cols) {
+ var i = 0
+ var max = Double.MinValue
+ while (i < data.rows) {
+ if (data(i, j) > max) {
+ max = data(i, j)
+ }
+ i += 1
+ }
+ var sum = 0.0
+ i = 0
+ while (i < data.rows) {
+ val res = math.exp(data(i, j) - max)
+ output(i, j) = res
+ sum += res
+ i += 1
+ }
+ i = 0
+ while (i < data.rows) {
+ output(i, j) /= sum
+ i += 1
+ }
+ j += 1
+ }
+ }
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ /* loss layer model computes delta in loss function */
+ }
+
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ /* loss layer model does not have weights */
+ }
+
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ -Bsum( target :* brzlog(output)) / output.cols
+ }
+}
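
Two details of the model above are worth noting: the per-column maximum is subtracted before exponentiating so the softmax stays numerically stable, and because softmax is paired with cross-entropy the error term reduces to o - t, which is exactly what `loss` writes into `delta`. The same column-wise softmax, as a standalone vectorized Breeze sketch (an illustration, not part of the patch):

    import breeze.linalg.{max, sum, DenseMatrix => BDM}
    import breeze.numerics.exp

    // Numerically stable softmax over each column of `data`, written into `output`.
    def softmaxColumns(data: BDM[Double], output: BDM[Double]): Unit = {
      var j = 0
      while (j < data.cols) {
        val shifted = exp(data(::, j) - max(data(::, j)))  // subtract the max to avoid overflow
        output(::, j) := shifted / sum(shifted)
        j += 1
      }
    }
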
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 8186afc17a..473e801794 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @param dataset input dataset
* @return transformed dataset
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Output selected columns only.
@@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
- outputData
+ outputData.toDF
}
/**
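
The pattern repeated throughout this patch is visible above: public `transform`/`fit`/`train` methods now accept `Dataset[_]` while still returning a `DataFrame` (hence the explicit `.toDF`). Since `DataFrame` is just `Dataset[Row]`, existing callers compile unchanged; a minimal hedged sketch (the helper name `score` is illustrative):

    import org.apache.spark.ml.classification.LogisticRegressionModel
    import org.apache.spark.sql.{DataFrame, Dataset}

    // Both calls resolve to the same transform(dataset: Dataset[_]): DataFrame.
    def score(model: LogisticRegressionModel, df: DataFrame, ds: Dataset[_]): (DataFrame, DataFrame) =
      (model.transform(df), model.transform(ds))
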
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..300ae4339c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/**
@@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = super.setSeed(value)
- override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -203,9 +203,9 @@ final class DecisionTreeClassificationModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..39a698af15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -18,23 +18,24 @@
package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
- TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -43,13 +44,23 @@ import org.apache.spark.sql.functions._
* learning algorithm for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
*/
@Since("1.4.0")
@Experimental
final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
- with GBTParams with TreeClassifierParams with Logging {
+ with GBTClassifierParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtc"))
@@ -106,41 +117,13 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTClassifier:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "logistic"
- * (default = logistic)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "logistic")
+ // Parameters from GBTClassifierParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "logistic" => OldLogLoss
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
- }
- }
-
- override protected def train(dataset: DataFrame): GBTClassificationModel = {
+ override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -166,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object GBTClassifier {
- // The losses below should be lowercase.
+object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
+
/** Accessor for supported loss settings: logistic */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassifier = super.load(path)
}
/**
@@ -190,9 +176,10 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -206,12 +193,12 @@ final class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -227,6 +214,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}
+ /** Number of trees in the ensemble. */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
@@ -238,16 +228,79 @@ final class GBTClassificationModel private[ml](
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
-private[ml] object GBTClassificationModel {
+@Since("2.0.0")
+object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassificationModel = super.load(path)
+
+ private[GBTClassificationModel]
+ class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+ val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
@@ -259,6 +312,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
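
With `GBTClassifier` now `DefaultParamsWritable` and `GBTClassificationModel` now `MLWritable`, a typical round trip looks roughly like the hedged sketch below; the helper name, the `training` DataFrame (assumed to have "label" and "features" columns), the parameter values and the path are all illustrative:

    import org.apache.spark.ml.classification.{GBTClassifier, GBTClassificationModel}
    import org.apache.spark.sql.DataFrame

    // Fit, persist and reload a gradient-boosted trees classifier.
    def gbtRoundTrip(training: DataFrame): GBTClassificationModel = {
      val gbt = new GBTClassifier().setMaxIter(20).setLossType("logistic")
      val model = gbt.fit(training)
      model.write.overwrite().save("/tmp/gbtc-model")   // params plus per-tree data
      GBTClassificationModel.load("/tmp/gbtc-model")
    }
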
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 861b1d4b66..c2b440059b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -36,8 +36,9 @@ import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -256,22 +257,26 @@ class LogisticRegression @Since("1.2.0") (
this
}
- override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
train(dataset, handlePersistence)
}
- protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
+ protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean):
LogisticRegressionModel = {
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(regParam, elasticNetParam, standardization, threshold,
+ maxIter, tol, fitIntercept)
+
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
@@ -290,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
val (coefficients, intercept, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
@@ -361,7 +369,7 @@ class LogisticRegression @Since("1.2.0") (
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
val vec = optInitialModel.get.coefficients
logWarning(
- s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+ s"Initial coefficients provided $vec did not match the expected size $numFeatures")
}
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
@@ -443,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(logRegSummary)
+ val m = model.setSummary(logRegSummary)
+ instr.logSuccess(m)
+ m
}
@Since("1.4.0")
@@ -522,7 +532,7 @@ class LogisticRegressionModel private[spark] (
(LogisticRegressionModel, String) = {
$(probabilityCol) match {
case "" =>
- val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+ val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
case p => (this, p)
}
@@ -539,13 +549,15 @@ class LogisticRegressionModel private[spark] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
- new BinaryLogisticRegressionSummary(
- this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
+ new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
+ probabilityColName, $(labelCol), $(featuresCol))
}
/**
@@ -771,13 +783,13 @@ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary
*/
sealed trait LogisticRegressionSummary extends Serializable {
- /** Dataframe outputted by the model's `transform` method. */
+ /** Dataframe output by the model's `transform` method. */
def predictions: DataFrame
- /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
+ /** Field in "predictions" which gives the probability of each class as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance (if available). */
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
@@ -789,9 +801,9 @@ sealed trait LogisticRegressionSummary extends Serializable {
* :: Experimental ::
* Logistic regression training results.
*
- * @param predictions dataframe outputted by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance as a vector.
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -813,9 +825,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* :: Experimental ::
* Binary Logistic regression results for a given model.
*
- * @param predictions dataframe outputted by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance.
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
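
The newly public `evaluate` (2.0.0) builds the summary on a copy of the model whose probability column is guaranteed to be set, so it works even when `probabilityCol` was cleared. A hedged usage sketch; the helper name and the `training`/`test` DataFrames (assumed to have "label" and "features" columns) are illustrative:

    import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
    import org.apache.spark.sql.DataFrame

    // Train on one DataFrame, evaluate on held-out data, return the area under the ROC curve.
    def evaluateOnHeldOut(training: DataFrame, test: DataFrame): Double = {
      val model = new LogisticRegression().setMaxIter(50).setRegParam(0.01).fit(training)
      // For binary labels the returned summary can be downcast to the binary variant.
      val summary = model.evaluate(test).asInstanceOf[BinaryLogisticRegressionSummary]
      summary.areaUnderROC
    }
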
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index f6de5f2df4..9ff5252e4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,27 +24,27 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
-import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/** Params for Multilayer Perceptron. */
private[ml] trait MultilayerPerceptronParams extends PredictorParams
- with HasSeed with HasMaxIter with HasTol {
+ with HasSeed with HasMaxIter with HasTol with HasStepSize {
/**
* Layer sizes including input size and output size.
* Default: Array(1, 1)
- * @group param
+ *
+ * @group param
*/
final val layers: IntArrayParam = new IntArrayParam(this, "layers",
"Sizes of layers from input layer to output layer" +
" E.g., Array(780, 100, 10) means 780 inputs, " +
"one hidden layer with 100 neurons and output layer of 10 neurons.",
- // TODO: how to check ALSO that all elements are greater than 0?
- ParamValidators.arrayLengthGt(1)
+ (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
)
/** @group getParam */
@@ -56,7 +56,8 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
* a partition then it is adjusted to the size of this data.
* Recommended size is between 10 and 1000.
* Default: 128
- * @group expertParam
+ *
+ * @group expertParam
*/
final val blockSize: IntParam = new IntParam(this, "blockSize",
"Block size for stacking input data in matrices. Data is stacked within partitions." +
@@ -67,7 +68,33 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
/** @group getParam */
final def getBlockSize: Int = $(blockSize)
- setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
+ /**
+ * The solver algorithm used for optimization: minibatch gradient descent ("gd") or "l-bfgs".
+ * Default: "l-bfgs".
+ *
+ * @group expertParam
+ */
+ final val solver: Param[String] = new Param[String](this, "solver",
+ " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " +
+ " l-bfgs is the default one.",
+ ParamValidators.inArray[String](Array("gd", "l-bfgs")))
+
+ /** @group getParam */
+ final def getOptimizer: String = $(solver)
+
+ /**
+ * Model weights. Can be returned either after training or after being set explicitly.
+ *
+ * @group expertParam
+ */
+ final val weights: Param[Vector] = new Param[Vector](this, "weights",
+ " Sets the weights of the model ")
+
+ /** @group getParam */
+ final def getWeights: Vector = $(weights)
+
+
+ setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03)
}
/** Label to vector converter. */
@@ -106,6 +133,7 @@ private object LabelConverter {
* Each layer has sigmoid activation function, output layer has softmax.
* Number of inputs has to be equal to the size of feature vectors.
* Number of outputs has to be equal to the total number of labels.
+ *
*/
@Since("1.5.0")
@Experimental
@@ -128,7 +156,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
@@ -137,18 +166,28 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-4.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setTol(value: Double): this.type = set(tol, value)
/**
- * Set the seed for weights initialization.
- * @group setParam
+ * Set the seed for weights initialization if weights are not set.
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the model weights.
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ def setWeights(value: Vector): this.type = set(weights, value)
+
@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
@@ -160,17 +199,24 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
* @param dataset Training dataset
* @return Fitted model
*/
- override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
+ override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = {
val myLayers = $(layers)
val labels = myLayers.last
val lpData = extractLabeledPoints(dataset)
val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true)
- val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
- FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
- FeedForwardTrainer.setStackSize($(blockSize))
- val mlpModel = FeedForwardTrainer.train(data)
- new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
+ val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
+ if (isDefined(weights)) {
+ trainer.setWeights($(weights))
+ } else {
+ trainer.setSeed($(seed))
+ }
+ trainer.LBFGSOptimizer
+ .setConvergenceTol($(tol))
+ .setNumIterations($(maxIter))
+ trainer.setStackSize($(blockSize))
+ val mlpModel = trainer.train(data)
+ new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
}
}
@@ -186,7 +232,8 @@ object MultilayerPerceptronClassifier
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
* Each layer has sigmoid activation function, output layer has softmax.
- * @param uid uid
+ *
+ * @param uid uid
* @param layers array of layer sizes including input and output layers
* @param weights vector of initial weights for the model that consists of the weights of layers
* @return prediction model
@@ -203,7 +250,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.6.0")
override val numFeatures: Int = layers.head
- private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
+ private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights)
/**
* Returns layers in a Java List.
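
Putting the new params together (a `solver` choice, a `stepSize` default, and optional initial `weights`, with the seed used only when no weights are set), a hedged usage sketch of the classifier; the helper name, layer sizes and the `training`/`test` DataFrames are illustrative:

    import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
    import org.apache.spark.sql.DataFrame

    // 4 inputs, one hidden layer of 5 neurons, 3 output classes; "label" and "features"
    // columns are assumed on both DataFrames.
    def trainMlp(training: DataFrame, test: DataFrame): DataFrame = {
      val mlp = new MultilayerPerceptronClassifier()
        .setLayers(Array(4, 5, 3))
        .setBlockSize(128)
        .setSeed(1234L)   // used for weight initialization only when setWeights is not called
        .setMaxIter(100)
      mlp.fit(training).transform(test)
    }
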
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 483ef0d88c..267d63b51e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/**
* Params for Naive Bayes Classifiers.
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") (
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)
- override protected def train(dataset: DataFrame): NaiveBayesModel = {
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c41a611f1c..4de1b877b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,22 +21,24 @@ import java.util.UUID
import scala.language.existentials
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
@@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams {
type E <: Classifier[F, E, M]
}
// scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
def getClassifier: ClassifierType = $(classifier)
}
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+ def validateParams(instance: OneVsRestParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException("OneVsRest write will fail " +
+ s" because it contains $name which does not implement MLWritable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+
+ instance match {
+ case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+ case _ => // no need to check OneVsRest here
+ }
+
+ checkElement(instance.getClassifier, "classifier")
+ }
+
+ def saveImpl(
+ path: String,
+ instance: OneVsRestParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => p.name != "classifier" }
+ .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+ .toList)
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val classifierPath = new Path(path, "classifier").toString
+ instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+ }
+
+ def loadImpl(
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+ val classifierPath = new Path(path, "classifier").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+ (metadata, estimator)
+ }
+}
+
/**
* :: Experimental ::
* Model produced by [[OneVsRest]].
@@ -73,18 +130,18 @@ private[ml] trait OneVsRestParams extends PredictorParams {
@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
- @Since("1.4.0") override val uid: String,
- @Since("1.4.0") labelMetadata: Metadata,
+ @Since("1.4.0") override val uid: String,
+ private[ml] val labelMetadata: Metadata,
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
- extends Model[OneVsRestModel] with OneVsRestParams {
+ extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Check schema
transformSchema(dataset.schema, logging = true)
@@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] (
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRestModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRestModel]] */
+ private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+ ("numClasses" -> instance.models.length)
+ OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+ instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ model.save(modelPath)
+ }
+ }
+ }
+
+ private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRestModel].getName
+
+ override def load(path: String): OneVsRestModel = {
+ implicit val format = DefaultFormats
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val models = Range(0, numClasses).toArray.map { idx =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+ }
+ val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+ DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+ ovrModel.set("classifier", classifier)
+ ovrModel
+ }
+ }
}
/**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
@Experimental
final class OneVsRest @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Estimator[OneVsRestModel] with OneVsRestParams {
+ extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -186,12 +293,14 @@ final class OneVsRest @Since("1.4.0") (
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def fit(dataset: DataFrame): OneVsRestModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): OneVsRestModel = {
+ transformSchema(dataset.schema)
+
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
- val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
// classes are assumed to be numbered from 0,...,maxLabelIndex
maxLabelIndex.toInt + 1
}
@@ -243,4 +352,40 @@ final class OneVsRest @Since("1.4.0") (
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRest = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRest]] */
+ private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ OneVsRestParams.saveImpl(path, instance, sc)
+ }
+ }
+
+ private class OneVsRestReader extends MLReader[OneVsRest] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRest].getName
+
+ override def load(path: String): OneVsRest = {
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val ovr = new OneVsRest(metadata.uid)
+ DefaultParamsReader.getAndSetParams(ovr, metadata)
+ ovr.setClassifier(classifier)
+ }
+ }
}
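
Because `saveImpl` above writes the base classifier under `classifier/` and each fitted binary model under `model_$idx/`, persistence only works when the base classifier (and, for the model, every fitted sub-model) is `MLWritable`, which `validateParams` checks up front. A hedged usage sketch; the helper name, the `training` DataFrame and the path are illustrative:

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}
    import org.apache.spark.sql.DataFrame

    // The base classifier must be MLWritable for OneVsRest persistence to succeed.
    def fitAndReload(training: DataFrame): OneVsRestModel = {
      val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
      val ovrModel = ovr.fit(training)
      ovrModel.write.overwrite().save("/tmp/ovr-model")   // writes classifier/ and model_$idx/
      OneVsRestModel.load("/tmp/ovr-model")
    }
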
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 865614aa5c..d00fee12b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[
* @param dataset input dataset
* @return transformed dataset
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
- outputData
+ outputData.toDF
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..dfa711b243 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,17 +17,21 @@
package org.apache.spark.ml.classification
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
- with RandomForestParams with TreeClassifierParams {
+ with RandomForestClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
@@ -94,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") (
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object RandomForestClassifier {
+object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -129,6 +133,9 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassifier = super.load(path)
}
/**
@@ -136,8 +143,9 @@ object RandomForestClassifier {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
+ *
* @param _trees Decision trees in the ensemble.
- * Warning: These have null parents.
+ * Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
@@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
/**
* Construct a random forest classification model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(
@@ -162,15 +172,15 @@ final class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}
+ /**
+ * Number of trees in the ensemble.
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0.
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,36 +235,89 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
+ s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}
-private[ml] object RandomForestClassificationModel {
+@Since("2.0.0")
+object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestClassificationModel] =
+ new RandomForestClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassificationModel = super.load(path)
+
+ private[RandomForestClassificationModel]
+ class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestClassificationModelReader
+ extends MLReader[RandomForestClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): RandomForestClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
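
The reworked `featureImportances` (each feature's importance averaged over all trees and normalized to sum to 1) can be exercised roughly as in this hedged sketch; the helper name, parameter values and the `training` DataFrame (with "label" and "features" columns) are illustrative:

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.sql.DataFrame

    // Returns the indices of the five most important features of a fitted forest.
    def topFeatures(training: DataFrame): Array[Int] = {
      val model = new RandomForestClassifier().setNumTrees(20).setMaxDepth(5).fit(training)
      model.featureImportances.toArray.zipWithIndex.sortBy(-_._1).take(5).map(_._2)
    }
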
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f014a1d572..6cc9117da3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -17,15 +17,17 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params
/** @group expertParam */
@Since("2.0.0")
- final val minDivisibleClusterSize = new Param[Double](
+ final val minDivisibleClusterSize = new DoubleParam(
this,
"minDivisibleClusterSize",
"the minimum number of points (if >= 1.0) or the minimum proportion",
@@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params
class BisectingKMeansModel private[ml] (
@Since("2.0.0") override val uid: String,
private val parentModel: MLlibBisectingKMeansModel
- ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+ ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -90,7 +92,7 @@ class BisectingKMeansModel private[ml] (
}
@Since("2.0.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -110,11 +112,49 @@ class BisectingKMeansModel private[ml] (
* centers.
*/
@Since("2.0.0")
- def computeCost(dataset: DataFrame): Double = {
+ def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+}
+
+object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
+ @Since("2.0.0")
+ override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeansModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[BisectingKMeansModel]] */
+ private[BisectingKMeansModel]
+ class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val dataPath = new Path(path, "data").toString
+ instance.parentModel.save(sc, dataPath)
+ }
+ }
+
+ private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[BisectingKMeansModel].getName
+
+ override def load(path: String): BisectingKMeansModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
+ val model = new BisectingKMeansModel(metadata.uid, mllibModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
@@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] (
@Experimental
class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0") override val uid: String)
- extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+ extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable {
setDefault(
k -> 4,
@@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") (
override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
@Since("2.0.0")
- def this() = this(Identifiable.randomUID("bisecting k-means"))
+ def this() = this(Identifiable.randomUID("bisecting-kmeans"))
/** @group setParam */
@Since("2.0.0")
@@ -175,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
@Since("2.0.0")
- override def fit(dataset: DataFrame): BisectingKMeansModel = {
+ override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
val bkm = new MLlibBisectingKMeans()
@@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") (
}
}
+
+@Since("2.0.0")
+object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeans = super.load(path)
+}
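
A hedged round trip through the persistence support added above (params via DefaultParamsWriter, the wrapped spark.mllib model under `data/`); the helper name, the `dataset` DataFrame (assumed to have a "features" column) and the path are illustrative:

    import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel}
    import org.apache.spark.sql.DataFrame

    // Cluster, report the within-cluster cost, then save and reload the model.
    def clusterAndReload(dataset: DataFrame): BisectingKMeansModel = {
      val model = new BisectingKMeans().setK(4).setMaxIter(20).fit(dataset)
      println(s"within-cluster cost: ${model.computeCost(dataset)}")
      model.write.overwrite().save("/tmp/bkm-model")
      BisectingKMeansModel.load("/tmp/bkm-model")
    }
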
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
new file mode 100644
index 0000000000..ead8ad7806
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+
+/**
+ * Common params for GaussianMixture and GaussianMixtureModel
+ */
+private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol
+ with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol {
+
+ /**
+ * Set the number of clusters to create (k). Must be > 1. Default: 2.
+ * @group param
+ */
+ @Since("2.0.0")
+ final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getK: Int = $(k)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by GaussianMixture.
+ * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureModel private[ml] (
+ @Since("2.0.0") override val uid: String,
+ private val parentModel: MLlibGMModel)
+ extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GaussianMixtureModel = {
+ val copied = new GaussianMixtureModel(uid, parentModel)
+ copyValues(copied, extra).setParent(this.parent)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val predUDF = udf((vector: Vector) => predict(vector))
+ val probUDF = udf((vector: Vector) => predictProbability(vector))
+ dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
+ .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
+ }
+
+ @Since("2.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+ private[clustering] def predictProbability(features: Vector): Vector = {
+ Vectors.dense(parentModel.predictSoft(features))
+ }
+
+ @Since("2.0.0")
+ def weights: Array[Double] = parentModel.weights
+
+ @Since("2.0.0")
+ def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
+
+ private var trainingSummary: Option[GaussianMixtureSummary] = None
+
+ private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * Return true if a training summary exists for this model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.0.0")
+ def summary: GaussianMixtureSummary = trainingSummary.getOrElse {
+ throw new RuntimeException(
+ s"No training summary available for the ${this.getClass.getSimpleName}")
+ }
+}
+
+@Since("2.0.0")
+object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GaussianMixtureModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[GaussianMixtureModel]] */
+ private[GaussianMixtureModel] class GaussianMixtureModelWriter(
+ instance: GaussianMixtureModel) extends MLWriter {
+
+ private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: weights and gaussians
+ val weights = instance.weights
+ val gaussians = instance.gaussians
+ val mus = gaussians.map(_.mu)
+ val sigmas = gaussians.map(_.sigma)
+ val data = Data(weights, mus, sigmas)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GaussianMixtureModel].getName
+
+ override def load(path: String): GaussianMixtureModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+ val weights = row.getSeq[Double](0).toArray
+ val mus = row.getSeq[Vector](1).toArray
+ val sigmas = row.getSeq[Matrix](2).toArray
+ require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
+ require(mus.length == weights.length, "Length of weight and Gaussian array must match")
+
+ val gaussians = (mus zip sigmas).map {
+ case (mu, sigma) =>
+ new MultivariateGaussian(mu, sigma)
+ }
+ val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
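
The writer above stores the mixture as a single Parquet row of weights, means, and covariance matrices, and the reader rebuilds the MultivariateGaussians from that row. A hedged round-trip sketch (the `gmModel` variable and path are hypothetical; see the fitting sketch further down):

    import org.apache.spark.ml.clustering.GaussianMixtureModel

    // Assuming `gmModel` is a fitted GaussianMixtureModel.
    gmModel.write.overwrite().save("/tmp/gmm-model")
    val reloaded = GaussianMixtureModel.load("/tmp/gmm-model")

    // The reloaded model exposes the same mixture components.
    reloaded.weights.foreach(println)
    reloaded.gaussians.foreach(g => println(s"mu = ${g.mu}\nsigma =\n${g.sigma}"))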
+
+/**
+ * :: Experimental ::
+ * GaussianMixture clustering.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixture @Since("2.0.0") (
+ @Since("2.0.0") override val uid: String)
+ extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable {
+
+ setDefault(
+ k -> 2,
+ maxIter -> 100,
+ tol -> 0.01)
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra)
+
+ @Since("2.0.0")
+ def this() = this(Identifiable.randomUID("GaussianMixture"))
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+ val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
+
+ val algo = new MLlibGM()
+ .setK($(k))
+ .setMaxIterations($(maxIter))
+ .setSeed($(seed))
+ .setConvergenceTol($(tol))
+ val parentModel = algo.run(rdd)
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this))
+ val summary = new GaussianMixtureSummary(model.transform(dataset),
+ $(predictionCol), $(probabilityCol), $(featuresCol), $(k))
+ model.setSummary(summary)
+ }
+
+ @Since("2.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+}
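
The estimator simply bridges to spark.mllib's GaussianMixture and copies the fitted model back into the ML API. A small fitting sketch, assuming an existing SparkContext `sc` (hypothetical) and using the spark.mllib vectors this snapshot still relies on:

    import org.apache.spark.ml.clustering.GaussianMixture
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)   // sc: SparkContext assumed to be in scope
    val training = sqlContext.createDataFrame(Seq(
      Vectors.dense(-0.10, -0.05), Vectors.dense(-0.01, -0.10), Vectors.dense(-0.05, -0.02),
      Vectors.dense(0.90, 0.80), Vectors.dense(0.75, 0.935), Vectors.dense(0.83, 0.68)
    ).map(Tuple1.apply)).toDF("features")

    val gm = new GaussianMixture()
      .setK(2)
      .setMaxIter(50)
      .setTol(0.01)
      .setSeed(17L)
    val gmModel = gm.fit(training)   // delegates to mllib.clustering.GaussianMixture.run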
+
+@Since("2.0.0")
+object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
+
+ @Since("2.0.0")
+ override def load(path: String): GaussianMixture = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Summary of GaussianMixture.
+ *
+ * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param probabilityCol Name for column of predicted probability of each cluster in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureSummary private[clustering] (
+ @Since("2.0.0") @transient val predictions: DataFrame,
+ @Since("2.0.0") val predictionCol: String,
+ @Since("2.0.0") val probabilityCol: String,
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
+
+ /**
+ * Cluster centers of the transformed data.
+ */
+ @Since("2.0.0")
+ @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+ /**
+ * Probability of each cluster.
+ */
+ @Since("2.0.0")
+ @transient lazy val probability: DataFrame = predictions.select(probabilityCol)
+
+ /**
+ * Size of (number of data points in) each cluster.
+ */
+ @Since("2.0.0")
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+}
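
fit() attaches a GaussianMixtureSummary to the returned model, so cluster sizes and soft assignments can be read back without re-running transform(). Continuing the sketch above (all variable names hypothetical):

    val summary = gmModel.summary            // throws if no training summary is attached
    println(s"k = ${summary.k}")
    summary.clusterSizes.zipWithIndex.foreach { case (size, i) =>
      println(s"cluster $i: $size points")   // empty clusters report 0, not a missing entry
    }
    summary.probability.show(truncate = false)   // per-row vectors of cluster membership probabilities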
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 38428826a8..b324196842 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -105,8 +105,8 @@ class KMeansModel private[ml] (
copyValues(copied, extra)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -126,8 +126,8 @@ class KMeansModel private[ml] (
* model on the given data.
*/
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
- @Since("1.6.0")
- def computeCost(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
@@ -144,6 +144,12 @@ class KMeansModel private[ml] (
}
/**
+ * Return true if a training summary exists for this model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@@ -254,8 +260,8 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): KMeansModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): KMeansModel = {
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
val algo = new MLlibKMeans()
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
- val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+ val summary = new KMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
}
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
- @Since("2.0.0") val featuresCol: String) extends Serializable {
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
/**
* Cluster centers of the transformed data.
@@ -296,10 +315,15 @@ class KMeansSummary private[clustering] (
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
/**
- * Size of each cluster.
+ * Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
- lazy val size: Array[Int] = cluster.rdd.map {
- case Row(clusterIdx: Int) => (clusterIdx, 1)
- }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+
}
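
KMeansSummary drops the old `size: Array[Int]` (built from an RDD reduceByKey that silently skipped empty clusters) in favor of `clusterSizes: Array[Long]`, which is indexed by cluster id and reports 0 for clusters that received no points, matching GaussianMixtureSummary above. A hedged usage sketch, reusing the hypothetical `training` DataFrame from the earlier sketches:

    import org.apache.spark.ml.clustering.KMeans

    val kmeansModel = new KMeans().setK(3).setSeed(1L).fit(training)   // fit() now accepts Dataset[_]
    val sizes: Array[Long] = kmeansModel.summary.clusterSizes          // length k; empty clusters stay 0
    println(sizes.mkString(", "))
    println(s"cost = ${kmeansModel.computeCost(training)}")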
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index fe6a37fd6d..c57ceba4a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,21 +17,22 @@
package org.apache.spark.ml.clustering
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
- EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
- LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
- OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+ EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
+ LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
+ OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
import org.apache.spark.sql.types.StructType
@@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
/**
* Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
+ *
* @group param
*/
@Since("1.6.0")
@@ -173,10 +175,11 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* This uses a variational approximation following Hoffman et al. (2010), where the approximate
* distribution is called "gamma." Technically, this method returns this approximation "gamma"
* for each document.
+ *
* @group param
*/
@Since("1.6.0")
- final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" +
+ final val topicDistributionCol = new Param[String](this, "topicDistributionCol", "Output column" +
" with estimates of the topic mixture distribution for each document (often called \"theta\"" +
" in the literature). Returns a vector of zeros for an empty document.")
@@ -187,15 +190,19 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getTopicDistributionCol: String = $(topicDistributionCol)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
* This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
* Default: 1024, following Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
- final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" +
- " parameter that downweights early iterations. Larger values make early iterations count less.",
+ final val learningOffset = new DoubleParam(this, "learningOffset", "(For online optimizer)" +
+ " A (positive) learning parameter that downweights early iterations. Larger values make early" +
+ " iterations count less.",
ParamValidators.gt(0))
/** @group expertGetParam */
@@ -203,22 +210,27 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getLearningOffset: Double = $(learningOffset)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* Learning rate, set as an exponential decay rate.
* This should be between (0.5, 1.0] to guarantee asymptotic convergence.
* This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
* Default: 0.51, based on Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
- final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" +
- " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" +
- " convergence.", ParamValidators.gt(0))
+ final val learningDecay = new DoubleParam(this, "learningDecay", "(For online optimizer)" +
+ " Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to" +
+ " guarantee asymptotic convergence.", ParamValidators.gt(0))
/** @group expertGetParam */
@Since("1.6.0")
def getLearningDecay: Double = $(learningDecay)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent,
* in range (0, 1].
*
@@ -230,11 +242,13 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
*
* Default: 0.05, i.e., 5% of total documents.
+ *
* @group param
*/
@Since("1.6.0")
- final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" +
- " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].",
+ final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "(For online optimizer)" +
+ " Fraction of the corpus to be sampled and used in each iteration of mini-batch" +
+ " gradient descent, in range (0, 1].",
ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true))
/** @group getParam */
@@ -242,23 +256,52 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getSubsamplingRate: Double = $(subsamplingRate)
/**
+ * For Online optimizer only (currently): [[optimizer]] = "online".
+ *
* Indicates whether the docConcentration (Dirichlet parameter for
* document-topic distribution) will be optimized during training.
* Setting this to true will make the model more expressive and fit the training data better.
* Default: false
+ *
* @group expertParam
*/
@Since("1.6.0")
final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration",
- "Indicates whether the docConcentration (Dirichlet parameter for document-topic" +
- " distribution) will be optimized during training.")
+ "(For online optimizer only, currently) Indicates whether the docConcentration" +
+ " (Dirichlet parameter for document-topic distribution) will be optimized during training.")
/** @group expertGetParam */
@Since("1.6.0")
def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
/**
+ * For EM optimizer only: [[optimizer]] = "em".
+ *
+ * If using checkpointing, this indicates whether to keep the last
+ * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can
+ * cause failures if a data partition is lost, so set this bit with care.
+ * Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and
+ * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints.
+ *
+ * Default: true
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint",
+ "(For EM optimizer) If using checkpointing, this indicates whether to keep the last" +
+ " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" +
+ " cause failures if a data partition is lost, so set this bit with care.")
+
+ /** @group expertGetParam */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)
+
+ /**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -303,6 +346,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
new OldEMLDAOptimizer()
+ .setKeepLastCheckpoint($(keepLastCheckpoint))
}
}
@@ -341,6 +385,7 @@ sealed abstract class LDAModel private[ml] (
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -357,15 +402,15 @@ sealed abstract class LDAModel private[ml] (
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*/
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
- dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
+ dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
- dataset
+ dataset.toDF
}
}
@@ -410,8 +455,8 @@ sealed abstract class LDAModel private[ml] (
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
*/
- @Since("1.6.0")
- def logLikelihood(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def logLikelihood(dataset: Dataset[_]): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logLikelihood(oldDataset)
}
@@ -427,8 +472,8 @@ sealed abstract class LDAModel private[ml] (
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
- @Since("1.6.0")
- def logPerplexity(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def logPerplexity(dataset: Dataset[_]): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logPerplexity(oldDataset)
}
@@ -619,6 +664,35 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
lazy val logPrior: Double = oldDistributedModel.logPrior
+ private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles
+
+ /**
+ * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be
+ * saved checkpoint files. This method is provided so that users can manage those files.
+ *
+ * Note that removing the checkpoints can cause failures if a partition is lost and is needed
+ * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints
+ * when this model and derivative data go out of scope.
+ *
+ * @return Checkpoint files from training
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def getCheckpointFiles: Array[String] = _checkpointFiles
+
+ /**
+ * Remove any remaining checkpoint files from training.
+ *
+ * @see [[getCheckpointFiles]]
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def deleteCheckpointFiles(): Unit = {
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
+ _checkpointFiles = Array.empty[String]
+ }
+
@Since("1.6.0")
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
}
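
With the EM optimizer and checkpointing enabled, keepLastCheckpoint controls whether the final checkpoint survives training, and the two methods above let callers list and delete whatever remains. A hedged sketch, assuming sc.setCheckpointDir(...) has been called and `corpus` is a DataFrame whose "features" column holds term-count vectors (names hypothetical):

    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

    val lda = new LDA()
      .setK(10)
      .setMaxIter(40)
      .setOptimizer("em")
      .setCheckpointInterval(10)
      .setKeepLastCheckpoint(true)    // new expert param; true is the default
    val ldaModel = lda.fit(corpus).asInstanceOf[DistributedLDAModel]   // EM yields a DistributedLDAModel

    ldaModel.getCheckpointFiles.foreach(println)
    ldaModel.deleteCheckpointFiles()  // only once the model (and derived data) is no longer needed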
@@ -696,11 +770,12 @@ class LDA @Since("1.6.0") (
setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
- optimizeDocConcentration -> true)
+ optimizeDocConcentration -> true, keepLastCheckpoint -> true)
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -758,11 +833,15 @@ class LDA @Since("1.6.0") (
@Since("1.6.0")
def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
+ /** @group expertSetParam */
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value)
+
@Since("1.6.0")
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
- @Since("1.6.0")
- override def fit(dataset: DataFrame): LDAModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): LDAModel = {
transformSchema(dataset.schema, logging = true)
val oldLDA = new OldLDA()
.setK($(k))
@@ -794,7 +873,7 @@ class LDA @Since("1.6.0") (
private[clustering] object LDA extends DefaultParamsReadable[LDA] {
/** Get dataset for spark.mllib LDA */
- def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
+ def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = {
dataset
.withColumn("docId", monotonicallyIncreasingId())
.select("docId", featuresCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 337ffbe90f..bde8c275fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
setDefault(metricName -> "areaUnderROC")
- @Since("1.2.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index 0f22cca3a7..5f765c071b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param.{ParamMap, Params}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
/**
* :: DeveloperApi ::
@@ -36,8 +36,8 @@ abstract class Evaluator extends Params {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- @Since("1.5.0")
- def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = {
this.copy(paramMap).evaluate(dataset)
}
@@ -46,8 +46,8 @@ abstract class Evaluator extends Params {
* @param dataset a dataset that contains labels/observations and predictions.
* @return metric
*/
- @Since("1.5.0")
- def evaluate(dataset: DataFrame): Double
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): Double
/**
* Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 55ff44323a..3acfc221c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
setDefault(metricName -> "f1")
- @Since("1.5.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 9976d7ed43..ed04b67bcc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType}
@@ -39,11 +39,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
def this() = this(Identifiable.randomUID("regEval"))
/**
- * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ * Param for metric name in evaluation. Supports:
+ * - `"rmse"` (default): root mean squared error
+ * - `"mse"`: mean squared error
+ * - `"r2"`: R^2^ metric
+ * - `"mae"`: mean absolute error
*
- * Because we will maximize evaluation value (ref: `CrossValidator`),
- * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
- * we take and output the negative of this metric.
* @group param
*/
@Since("1.4.0")
@@ -70,8 +71,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
setDefault(metricName -> "rmse")
- @Since("1.4.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
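
The metric-name doc is now an explicit list and no longer carries the old note about negating minimization metrics. A short sketch of switching metrics, assuming a `predictions` DataFrame with double "prediction" and "label" columns (hypothetical data):

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")

    val rmse = evaluator.setMetricName("rmse").evaluate(predictions)
    val r2   = evaluator.setMetricName("r2").evaluate(predictions)
    println(s"rmse = $rmse, r2 = $r2")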
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2f8e3a0371..898ac2cc89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -64,7 +64,8 @@ final class Binarizer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val schema = dataset.schema
val inputType = schema($(inputCol)).dataType
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c99d..10e622ace6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val bucketizer = udf { feature: Double =>
Bucketizer.binarySearchForBuckets($(splits), feature)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index b9e9d56853..cfecae7e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String)
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
case Row(label: Double, features: Vector) =>
@@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] (
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val newField = transformedSchema.last
val selector = udf { chiSqSelector.transform _ }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 5694b3890f..922670a41b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -100,6 +100,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** @group getParam */
def getMinTF: Double = $(minTF)
+
+ /**
+ * Binary toggle to control the output vector values.
+ * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+ * discrete probabilistic models that model binary events rather than integer counts.
+ * Default: false
+ * @group param
+ */
+ val binary: BooleanParam =
+ new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.")
+
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ setDefault(binary -> false)
}
/**
@@ -127,9 +142,13 @@ class CountVectorizer(override val uid: String)
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
setDefault(vocabSize -> (1 << 18), minDF -> 1)
- override def fit(dataset: DataFrame): CountVectorizerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CountVectorizerModel = {
transformSchema(dataset.schema, logging = true)
val vocSize = $(vocabSize)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
@@ -152,16 +171,10 @@ class CountVectorizer(override val uid: String)
(word, count)
}.cache()
val fullVocabSize = wordCounts.count()
- val vocab: Array[String] = {
- val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocSize)
- }
- tmpSortedWC.map(_._1)
- }
+
+ val vocab = wordCounts
+ .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+ .map(_._1)
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
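
Vocabulary selection collapses the old two-branch collect/sortBy versus sortBy/take logic into a single RDD.top call ordered by count. The same idiom in isolation (helper name ours):

    import org.apache.spark.rdd.RDD

    // Take the n highest-count terms in one pass, descending by count.
    def topTerms(wordCounts: RDD[(String, Long)], n: Int): Array[String] =
      wordCounts.top(n)(Ordering.by(_._2)).map(_._1)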
@@ -206,30 +219,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
- /**
- * Binary toggle to control the output vector values.
- * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
- * discrete probabilistic models that model binary events rather than integer counts.
- * Default: false
- * @group param
- */
- val binary: BooleanParam =
- new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
- "This is useful for discrete probabilistic models that model binary events rather " +
- "than integer counts")
-
- /** @group getParam */
- def getBinary: Boolean = $(binary)
-
/** @group setParam */
def setBinary(value: Boolean): this.type = set(binary, value)
- setDefault(binary -> false)
-
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
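
Because the binary param now lives in CountVectorizerParams, it can be set on the estimator and is inherited by the fitted model. A hedged sketch, assuming `docs` is a DataFrame with a "words" column of Seq[String] (hypothetical):

    import org.apache.spark.ml.feature.CountVectorizer

    val cv = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(1000)
      .setBinary(true)                 // nonzero counts become 1 after the minTF filter
    val cvModel = cv.fit(docs)
    cvModel.transform(docs).select("features").show(truncate = false)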
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 2c7ffdb7ba..1b0a9a12e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -38,9 +38,9 @@ class ElementwiseProduct(override val uid: String)
def this() = this(Identifiable.randomUID("elemProd"))
/**
- * the vector to multiply with input vectors
- * @group param
- */
+ * the vector to multiply with input vectors
+ * @group param
+ */
val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..467ad73074 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,11 +20,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}
@@ -52,7 +52,18 @@ class HashingTF(override val uid: String)
val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
ParamValidators.gt(0))
- setDefault(numFeatures -> (1 << 18))
+ /**
+ * Binary toggle to control term frequency counts.
+ * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic
+ * models that model binary events rather than integer counts.
+ * (default = false)
+ * @group param
+ */
+ val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " +
+ "This is useful for discrete probabilistic models that model binary events rather " +
+ "than integer counts")
+
+ setDefault(numFeatures -> (1 << 18), binary -> false)
/** @group getParam */
def getNumFeatures: Int = $(numFeatures)
@@ -60,9 +71,16 @@ class HashingTF(override val uid: String)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
- val hashingTF = new feature.HashingTF($(numFeatures))
+ val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f36cf503a0..5075b78c98 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
- override def fit(dataset: DataFrame): IDFModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val idf = new feature.IDF($(minDocFreq)).fit(input)
@@ -115,7 +116,8 @@ class IDFModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val idf = udf { vec: Vector => idfModel.transform(vec) }
dataset.withColumn($(outputCol), idf(col($(inputCol))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index d3fe6e528f..9ca34e9ae2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 7de5a4d5d3..e9df600c8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): MaxAbsScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val summary = Statistics.colStats(input)
@@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// TODO: this looks hack, we may have to handle sparse and dense vectors separately.
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index b13684a1cb..125becbb8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String)
/** @group setParam */
def setMax(value: Double): this.type = set(max, value)
- override def fit(dataset: DataFrame): MinMaxScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val summary = Statistics.colStats(input)
@@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] (
/** @group setParam */
def setMax(value: Double): this.type = set(max, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
val minArray = originalMin.toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 4f67042629..99357793db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
@@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// schema transformation
val inputColName = $(inputCol)
val outputColName = $(outputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 305c3d187f..9cf722e121 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
/**
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
*/
- override def fit(dataset: DataFrame): PCAModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PCAModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
@@ -124,7 +125,8 @@ class PCAModel private[ml] (
* NOTE: Vectors to be transformed must be the same length
* as the source vectors given to [[PCA.fit()]].
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
val pcaOp = udf { pcaModel.transform _ }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e486e92c12..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.param.{IntParam, _}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.util.random.XORShiftRandom
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+ * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -87,12 +104,13 @@ final class QuantileDiscretizer(override val uid: String)
StructType(outputFields)
}
- override def fit(dataset: DataFrame): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Bucketizer = {
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
@@ -103,90 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
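
QuantileDiscretizer.fit now asks DataFrameStatFunctions.approxQuantile for the split points (with the new relativeError param controlling precision) and pins the outer splits to ±Infinity, replacing the sampling and split-search helpers deleted above. A usage sketch, assuming `df` has a continuous double column "hour" (hypothetical):

    import org.apache.spark.ml.feature.QuantileDiscretizer

    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("bucket")
      .setNumBuckets(4)
      .setRelativeError(0.001)          // precision target passed to approxQuantile
    val bucketizer = discretizer.fit(df) // Bucketizer with splits (-Inf, q25, q50, q75, +Inf)
    bucketizer.transform(df).show()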
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 12a76dbbfb..3ac6c77669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
/**
@@ -103,7 +103,8 @@ class RFormula(override val uid: String)
RFormulaParser.parse($(formula)).hasIntercept
}
- override def fit(dataset: DataFrame): RFormulaModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): RFormulaModel = {
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
@@ -204,7 +205,8 @@ class RFormulaModel private[feature](
private[ml] val pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
checkCanTransform(dataset.schema)
transformLabel(pipelineModel.transform(dataset))
}
@@ -232,10 +234,10 @@ class RFormulaModel private[feature](
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
- private def transformLabel(dataset: DataFrame): DataFrame = {
+ private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
- dataset
+ dataset.toDF
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
case _: NumericType | BooleanType =>
@@ -246,7 +248,7 @@ class RFormulaModel private[feature](
} else {
// Ignore the label field. This is a hack so that this transformer can also work on test
// datasets in a Pipeline.
- dataset
+ dataset.toDF
}
}
@@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
def this(columnsToPrune: Set[String]) =
this(Identifiable.randomUID("columnPruner"), columnsToPrune)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
dataset.select(columnsToKeep.map(dataset.col): _*)
}
@@ -396,7 +398,7 @@ private class VectorAttributeRewriter(
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val metadata = {
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
val attrs = group.attributes.get.map { attr =>
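
The RFormula changes above follow the pattern repeated throughout this merge: fit and transform now take Dataset[_], and code paths that previously returned the input unchanged now return dataset.toDF so the result stays a DataFrame. A minimal usage sketch, assuming an active sqlContext; the column names y, x1, x2 are illustrative:

    import org.apache.spark.ml.feature.RFormula

    val training = sqlContext.createDataFrame(Seq(
      (1.0, 2.0, 3.0),
      (0.0, 4.0, 1.0)
    )).toDF("y", "x1", "x2")

    val model = new RFormula()
      .setFormula("y ~ x1 + x2")
      .fit(training)                       // fit(Dataset[_]) after this change

    model.transform(training).select("features", "label").show()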
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b9a6..2002d15745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.types.StructType
/**
@@ -63,13 +63,12 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
private val tableIdentifier: String = "__THIS__"
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- val outputDF = dataset.sqlContext.sql(realStatement)
- outputDF
+ dataset.sqlContext.sql(realStatement)
}
@Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
- dummyDF.registerTempTable(tableIdentifier)
- val outputSchema = sqlContext.sql($(statement)).schema
+ val tableName = Identifiable.randomUID(uid)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ dummyDF.registerTempTable(tableName)
+ val outputSchema = sqlContext.sql(realStatement).schema
+ sqlContext.dropTempTable(tableName)
outputSchema
}
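
Besides the Dataset[_] signature, SQLTransformer.transformSchema now registers its dummy table under a unique name and drops it afterwards instead of reusing the fixed __THIS__ identifier. Usage is unchanged; a short sketch with illustrative column names, assuming an active sqlContext:

    import org.apache.spark.ml.feature.SQLTransformer

    val df = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0)))
      .toDF("id", "v1", "v2")

    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

    // __THIS__ is replaced by a temp table backed by `df` at transform time.
    sqlTrans.transform(df).show()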
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 26ee8e1bf1..118a6e3e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
- override def fit(dataset: DataFrame): StandardScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
@@ -135,7 +136,8 @@ class StandardScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }
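
Only the StandardScaler signatures change here. For context, a minimal estimator/model sketch assuming an active sqlContext; the data is illustrative:

    import org.apache.spark.ml.feature.StandardScaler
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (0, Vectors.dense(1.0, 0.5, -1.0)),
      (1, Vectors.dense(2.0, 1.0, 1.0))
    )).toDF("id", "features")

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setWithMean(false)   // leave sparse vectors sparse
      .setWithStd(true)

    scaler.fit(df).transform(df).show()   // fit/transform now take Dataset[_]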
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0a0e0b0960..b96bc48566 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
@@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String)
setDefault(stopWords -> StopWords.English, caseSensitive -> false)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
val t = if ($(caseSensitive)) {
val stopWordsSet = $(stopWords).toSet
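
Again only the transform signature changes for StopWordsRemover. A quick sketch; the input column must be ArrayType(StringType), and the names are illustrative:

    import org.apache.spark.ml.feature.StopWordsRemover

    val docs = sqlContext.createDataFrame(Seq(
      (0, Seq("I", "saw", "the", "red", "balloon")),
      (1, Seq("Mary", "had", "a", "little", "lamb"))
    )).toDF("id", "raw")

    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setCaseSensitive(false)

    remover.transform(docs).show(false)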
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index faa0f6f407..7e0d374f02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): StringIndexerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
@@ -144,11 +145,12 @@ class StringIndexerModel (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
- return dataset
+ return dataset.toDF
}
validateAndTransformSchema(dataset.schema)
@@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String)
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
val values = if ($(labels).isEmpty) {
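
Note the behaviour visible above: when the input column is absent, StringIndexerModel.transform now returns dataset.toDF rather than the Dataset itself, preserving the DataFrame return type. A round-trip sketch with IndexToString, column names illustrative:

    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    val indexerModel = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)

    val indexed = indexerModel.transform(df)

    // IndexToString recovers the original labels from the column metadata.
    new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")
      .transform(indexed)
      .show()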
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 957e8e7a59..4d3e46e488 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.first()
+ lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>
val field = schema(c)
val index = schema.fieldIndex(c)
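
The `.toDF.first()` change above keeps the lazily computed first row as a Row now that the input is an untyped Dataset[_]; that row is only consulted to size vector input columns that lack metadata. A usage sketch with illustrative column names:

    import org.apache.spark.ml.feature.VectorAssembler

    val df = sqlContext.createDataFrame(Seq(
      (0, 18, 1.0, 0.5),
      (1, 20, 0.0, 1.5)
    )).toDF("id", "hour", "mobile", "score")

    val assembler = new VectorAssembler()
      .setInputCols(Array("hour", "mobile", "score"))
      .setOutputCol("features")

    assembler.transform(df).select("id", "features").show()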
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index bf4aef2a74..68b699d569 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.collection.OpenHashSet
@@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): VectorIndexerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): VectorIndexerModel = {
transformSchema(dataset.schema, logging = true)
val firstRow = dataset.select($(inputCol)).take(1)
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
@@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index b60e82de00..7a9468b87b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
@@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 95bae1c8a3..a72692960f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
- override def fit(dataset: DataFrame): Word2VecModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Word2VecModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
val wordVectors = new feature.Word2Vec()
@@ -219,7 +220,8 @@ class Word2VecModel private[ml] (
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
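
As the doc comment above says, Word2VecModel.transform averages the vectors of all words in a sentence to produce one vector per row. A minimal sketch assuming an active sqlContext; the `text` column name and tiny corpus are illustrative:

    import org.apache.spark.ml.feature.Word2Vec

    val documentDF = sqlContext.createDataFrame(Seq(
      "Hi I heard about Spark".split(" "),
      "Logistic regression models are neat".split(" ")
    ).map(Tuple1.apply)).toDF("text")

    val word2Vec = new Word2Vec()
      .setInputCol("text")
      .setOutputCol("result")
      .setVectorSize(3)
      .setMinCount(0)

    val model = word2Vec.fit(documentDF)      // fit(Dataset[_]) after this change
    model.transform(documentDF).show(false)   // one averaged vector per sentence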
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7837b6730..c368aadd23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.param
import java.lang.reflect.Modifier
+import java.util.{List => JList}
import java.util.NoSuchElementException
import scala.annotation.varargs
@@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
this
}
+ /** Put param pairs with a [[java.util.List]] of values for Python. */
+ private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
+ put(paramPairs.asScala: _*)
+ }
+
/**
* Optionally returns the value associated with a param.
*/
@@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
}
}
+ /** Java-friendly method for Python API */
+ private[ml] def toList: java.util.List[ParamPair[_]] = {
+ this.toSeq.asJava
+ }
+
/**
* Number of param pairs in this map.
*/
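
The two helpers added to ParamMap are thin asScala/asJava bridges so the Python API can pass a java.util.List of param pairs across the JVM boundary. The conversion they rely on is the standard one; a tiny illustration with arbitrary values:

    import scala.collection.JavaConverters._

    // Round trip between java.util.List and Scala collections, the same
    // bridge used by the new private[ml] put/toList helpers.
    val javaList: java.util.List[Int] = Seq(1, 2, 3).asJava
    val scalaSeq: Seq[Int] = javaList.asScala

    // In user code, ParamMaps are still built directly from param pairs,
    // e.g. ParamMap(estimator.maxIter -> 30, estimator.regParam -> 0.1)
    // where `estimator` is any Params instance exposing those params.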
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 3ce129b12c..1d03a5b4f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -62,7 +62,7 @@ private[shared] object SharedParamsCodeGen {
"every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
- "will filter out rows with bad values), or error (which will throw an errror). More " +
+ "will filter out rows with bad values), or error (which will throw an error). More " +
"options may be added later",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 96263c5baf..64d6af2766 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -270,10 +270,10 @@ private[ml] trait HasFitIntercept extends Params {
private[ml] trait HasHandleInvalid extends Params {
/**
- * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.
+ * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
* @group param
*/
- final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
+ final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
/** @group getParam */
final def getHandleInvalid: String = $(handleInvalid)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 40590e71c4..7835468626 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class AFTSurvivalRegressionWrapper private (
pipeline: PipelineModel,
@@ -43,8 +43,8 @@ private[r] class AFTSurvivalRegressionWrapper private (
features ++ Array("Log(scale)")
}
- def transform(dataset: DataFrame): DataFrame = {
- pipeline.transform(dataset)
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
new file mode 100644
index 0000000000..475a308385
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression._
+import org.apache.spark.sql._
+
+private[r] class GeneralizedLinearRegressionWrapper private (
+ pipeline: PipelineModel,
+ val features: Array[String]) {
+
+ private val glm: GeneralizedLinearRegressionModel =
+ pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+
+ lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+ Array(glm.intercept) ++ glm.coefficients.toArray
+ } else {
+ glm.coefficients.toArray
+ }
+
+ lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
+ Array("(Intercept)") ++ features
+ } else {
+ features
+ }
+
+ def transform(dataset: DataFrame): DataFrame = {
+ pipeline.transform(dataset).drop(glm.getFeaturesCol)
+ }
+}
+
+private[r] object GeneralizedLinearRegressionWrapper {
+
+ def fit(
+ formula: String,
+ data: DataFrame,
+ family: String,
+ link: String,
+ epsilon: Double,
+ maxit: Int): GeneralizedLinearRegressionWrapper = {
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ val rFormulaModel = rFormula.fit(data)
+ // get labels and feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+ // assemble and fit the pipeline
+ val glm = new GeneralizedLinearRegression()
+ .setFamily(family)
+ .setLink(link)
+ .setFitIntercept(rFormula.hasIntercept)
+ .setTol(epsilon)
+ .setMaxIter(maxit)
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, glm))
+ .fit(data)
+ new GeneralizedLinearRegressionWrapper(pipeline, features)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
new file mode 100644
index 0000000000..9e2b81ee20
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class KMeansWrapper private (
+ pipeline: PipelineModel) {
+
+ private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+
+ lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
+
+ private lazy val attrs = AttributeGroup.fromStructField(
+ kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+
+ lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
+
+ lazy val k: Int = kMeansModel.getK
+
+ lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
+
+ lazy val cluster: DataFrame = kMeansModel.summary.cluster
+
+ def fitted(method: String): DataFrame = {
+ if (method == "centers") {
+ kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
+ } else if (method == "classes") {
+ kMeansModel.summary.cluster
+ } else {
+ throw new UnsupportedOperationException(
+ s"Method (centers or classes) required but $method found.")
+ }
+ }
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
+ }
+
+}
+
+private[r] object KMeansWrapper {
+
+ def fit(
+ data: DataFrame,
+ k: Double,
+ maxIter: Double,
+ initMode: String,
+ columns: Array[String]): KMeansWrapper = {
+
+ val assembler = new VectorAssembler()
+ .setInputCols(columns)
+ .setOutputCol("features")
+
+ val kMeans = new KMeans()
+ .setK(k.toInt)
+ .setMaxIter(maxIter.toInt)
+ .setInitMode(initMode)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(assembler, kMeans))
+ .fit(data)
+
+ new KMeansWrapper(pipeline)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 07383d393d..b17207e99b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class NaiveBayesWrapper private (
pipeline: PipelineModel,
@@ -36,8 +36,10 @@ private[r] class NaiveBayesWrapper private (
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
- def transform(dataset: DataFrame): DataFrame = {
- pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(naiveBayesModel.getFeaturesCol)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
deleted file mode 100644
index d23e4fc9d1..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.api.r
-
-import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
-import org.apache.spark.sql.DataFrame
-
-private[r] object SparkRWrappers {
- def fitRModelFormula(
- value: String,
- df: DataFrame,
- family: String,
- lambda: Double,
- alpha: Double,
- standardize: Boolean,
- solver: String): PipelineModel = {
- val formula = new RFormula().setFormula(value)
- val estimator = family match {
- case "gaussian" => new LinearRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- .setSolver(solver)
- case "binomial" => new LogisticRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- }
- val pipeline = new Pipeline().setStages(Array(formula, estimator))
- pipeline.fit(df)
- }
-
- def fitKMeans(
- df: DataFrame,
- initMode: String,
- maxIter: Double,
- k: Double,
- columns: Array[String]): PipelineModel = {
- val assembler = new VectorAssembler().setInputCols(columns)
- val kMeans = new KMeans()
- .setInitMode(initMode)
- .setMaxIter(maxIter.toInt)
- .setK(k.toInt)
- .setFeaturesCol(assembler.getOutputCol)
- val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
- pipeline.fit(df)
- }
-
- def getModelCoefficients(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel => {
- val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
- m.summary.coefficientStandardErrors.dropRight(1)
- val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
- val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
- tValuesR ++ pValuesR
- } else {
- m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
- }
- }
- case m: LogisticRegressionModel => {
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray
- } else {
- m.coefficients.toArray
- }
- }
- case m: KMeansModel =>
- m.clusterCenters.flatMap(_.toArray)
- }
- }
-
- def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- m.summary.devianceResiduals
- case m: LogisticRegressionModel =>
- throw new UnsupportedOperationException(
- "No deviance residuals available for LogisticRegressionModel")
- }
- }
-
- def getKMeansModelSize(model: PipelineModel): Array[Int] = {
- model.stages.last match {
- case m: KMeansModel => Array(m.getK) ++ m.summary.size
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
- def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
- model.stages.last match {
- case m: KMeansModel =>
- if (method == "centers") {
- // Drop the assembled vector for easy-print to R side.
- m.summary.predictions.drop(m.summary.featuresCol)
- } else if (method == "classes") {
- m.summary.cluster
- } else {
- throw new UnsupportedOperationException(
- s"Method (centers or classes) required but $method found.")
- }
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
- def getModelFeatures(model: PipelineModel): Array[String] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- case m: LogisticRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- case m: KMeansModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- attrs.attributes.get.map(_.name.get)
- }
- }
-
- def getModelName(model: PipelineModel): String = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- "LinearRegressionModel"
- case m: LogisticRegressionModel =>
- "LogisticRegressionModel"
- case m: KMeansModel =>
- "KMeansModel"
- }
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4a3ad662a0..36dce01590 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -200,8 +200,8 @@ class ALSModel private[ml] (
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- @Since("1.3.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
this
}
- @Since("1.3.0")
- override def fit(dataset: DataFrame): ALSModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): ALSModel = {
import dataset.sqlContext.implicits._
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
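
As the surrounding context shows, ALS.fit casts the rating column to FloatType (or substitutes a literal 1.0f when ratingCol is empty) before building the ratings RDD; only the fit/transform signatures change in this hunk. A usage sketch; `ratings` and `test` are assumed to be DataFrames with the columns named below:

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")
      .setRank(10)
      .setMaxIter(10)
      .setRegParam(0.1)

    val model = als.fit(ratings)              // fit(Dataset[_]) after this change
    val predictions = model.transform(test)   // adds the "prediction" column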
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d..89ba6ab5d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -31,8 +31,9 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -183,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
*/
- protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
- dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
- case Row(features: Vector, label: Double, censor: Double) =>
- AFTPoint(features, label, censor)
- }
+ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+ .rdd.map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
}
- @Since("1.6.0")
- override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val costFun = new AFTCostFun(instances, $(fitIntercept))
+ val featuresSummarizer = {
+ val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+ val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+ c1.merge(c2)
+ }
+ instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+ }
+
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
- val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+ val numFeatures = featuresStd.size
/*
The parameters vector has three parts:
the first element: Double, log(sigma), the log of scale parameter
@@ -229,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+ val rawCoefficients = parameters.slice(2, parameters.length)
+ var i = 0
+ while (i < numFeatures) {
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ i += 1
+ }
+ val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
@@ -298,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] (
math.exp(BLAS.dot(coefficients, features) + intercept)
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
@@ -433,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
* @param parameters including three part: The log of scale parameter, the intercept and
* regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
*/
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
- extends Serializable {
+private class AFTAggregator(
+ parameters: BDV[Double],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends Serializable {
// the regression coefficients to the covariates
private val coefficients = parameters.slice(2, parameters.length)
- private val intercept = parameters.valueAt(1)
+ private val intercept = parameters(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
- private var gradientInterceptSum = 0.0
- private var gradientLogSigmaSum = 0.0
+ // Here we optimize loss function over log(sigma), intercept and coefficients
+ private val gradientSumArray = Array.ofDim[Double](parameters.length)
def count: Long = totalCnt
+ def loss: Double = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ lossSum / totalCnt
+ }
+ def gradient: BDV[Double] = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+ }
- def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
- // Here we optimize loss function over coefficients, intercept and log(sigma)
- def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
- BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
/**
* Add a new training data to this AFTAggregator, and update the loss and gradient
@@ -465,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* @return This AFTAggregator object.
*/
def add(data: AFTPoint): this.type = {
-
- val interceptFlag = if (fitIntercept) 1.0 else 0.0
-
- val xi = data.features.toBreeze
+ val xi = data.features
val ti = data.label
val delta = data.censor
- val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
- lossSum += math.log(sigma) * delta
- lossSum += (math.exp(epsilon) - delta * epsilon)
+ val margin = {
+ var sum = 0.0
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ sum += coefficients(index) * (value / featuresStd(index))
+ }
+ }
+ sum + intercept
+ }
+ val epsilon = (math.log(ti) - margin) / sigma
+
+ lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
- // Sanity check (should never occur):
- assert(!lossSum.isInfinity,
- s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+ val multiplier = (delta - math.exp(epsilon)) / sigma
- val deltaMinusExpEps = delta - math.exp(epsilon)
- gradientCoefficientSum += xi * deltaMinusExpEps / sigma
- gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
- gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+ gradientSumArray(0) += delta + multiplier * sigma * epsilon
+ gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+ }
+ }
totalCnt += 1
this
@@ -502,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
totalCnt += other.totalCnt
lossSum += other.lossSum
- gradientCoefficientSum += other.gradientCoefficientSum
- gradientInterceptSum += other.gradientInterceptSum
- gradientLogSigmaSum += other.gradientLogSigmaSum
+ var i = 0
+ val len = this.gradientSumArray.length
+ while (i < len) {
+ this.gradientSumArray(i) += other.gradientSumArray(i)
+ i += 1
+ }
}
this
}
@@ -515,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* It returns the loss and gradient at a particular point (parameters).
* It's used in Breeze's convex optimization routines.
*/
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
- extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+ data: RDD[AFTPoint],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
- val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+ val aftAggregator = data.treeAggregate(
+ new AFTAggregator(parameters, fitIntercept, featuresStd))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
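
The rewritten AFTAggregator standardizes each feature by its standard deviation when computing the margin and folds log(sigma), the intercept and the coefficients into a single gradient array; after optimization the raw coefficients are divided by featuresStd so the model is reported on the original feature scale. For reference, the per-instance loss and gradient accumulated by the code above are (delta the censoring indicator, epsilon the scaled log-residual):

    \varepsilon_i = \frac{\log t_i - \big(\sum_j \beta_j x_{ij}/\sigma_{x_j} + b\big)}{\sigma}
    \qquad
    \ell_i = \delta_i \log\sigma - \delta_i \varepsilon_i + e^{\varepsilon_i}

    \frac{\partial \ell_i}{\partial \log\sigma} = \delta_i + (\delta_i - e^{\varepsilon_i})\,\varepsilon_i
    \qquad
    \frac{\partial \ell_i}{\partial b} = \frac{\delta_i - e^{\varepsilon_i}}{\sigma}
    \qquad
    \frac{\partial \ell_i}{\partial \beta_j} = \frac{\delta_i - e^{\varepsilon_i}}{\sigma} \cdot \frac{x_{ij}}{\sigma_{x_j}}

with loss and gradient returning these sums divided by the instance count, and the intercept term contributing only when fitIntercept is set.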
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..c04c416aaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
- override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).impurityStats.calculate()
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
transformImpl(dataset)
}
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
- var output = dataset
+ var output = dataset.toDF
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
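
Beyond the Dataset[_] signatures, the DecisionTreeRegressor hunk wires the optional variance column through transformImpl and delegates featureImportances to the shared TreeEnsembleModel helper. A usage sketch; `training` and `test` are assumed to carry "label" and Vector "features" columns:

    import org.apache.spark.ml.regression.DecisionTreeRegressor

    val dt = new DecisionTreeRegressor()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setVarianceCol("variance")   // per-prediction impurity-based variance

    val dtModel = dt.fit(training)
    dtModel.transform(test).select("prediction", "variance").show()
    dtModel.featureImportances      // normalized importance vector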
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..741724d7a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,23 +18,23 @@
package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
- TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -42,12 +42,24 @@ import org.apache.spark.sql.functions._
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for regression.
* It supports both continuous and categorical features.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
*/
@Since("1.4.0")
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
- with GBTParams with TreeRegressorParams with Logging {
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtr"))
@@ -101,42 +113,13 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTRegressor:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "squared" (L2) and "absolute" (L1)
- * (default = squared)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "squared")
+ // Parameters from GBTRegressorParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "squared" => OldSquaredError
- case "absolute" => OldAbsoluteError
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
- }
- }
-
- override protected def train(dataset: DataFrame): GBTRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -153,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
@Experimental
-object GBTRegressor {
- // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
}
/**
@@ -177,9 +163,10 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -193,12 +180,12 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -213,6 +200,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -224,16 +214,81 @@ final class GBTRegressionModel private[ml](
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel]
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
categoricalFeatures: Map[Int, Int],
@@ -245,6 +300,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
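
As a rough illustration of the new GBTRegressionModel surface above (featureImportances plus the MLWriter/MLReader persistence path), a minimal sketch follows; the toy data, parameter values and the /tmp path are placeholders, and an existing sqlContext is assumed.

    import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
    import org.apache.spark.mllib.linalg.Vectors

    // Illustrative training data with a numeric label and a feature vector.
    val training = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (1.5, Vectors.dense(2.0, 1.3, 1.0)),
      (0.2, Vectors.dense(0.0, 1.2, -0.5))
    )).toDF("label", "features")

    val model = new GBTRegressor().setMaxIter(5).fit(training)

    // Average, normalized importance of each feature across the ensemble.
    println(model.featureImportances)

    // Round-trip the model through the new MLWriter/MLReader path.
    model.write.overwrite().save("/tmp/gbtr-model")            // placeholder path
    val reloaded = GBTRegressionModel.load("/tmp/gbtr-model")
    assert(reloaded.numTrees == model.numTrees)
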
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d8e1..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -31,9 +31,9 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* to be used in the model.
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
* Default is "gaussian".
+ *
* @group param
*/
@Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* Param for the name of link function which provides the relationship
* between the linear predictor and the mean of the distribution function.
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ *
* @group param
*/
@Since("2.0.0")
@@ -163,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
setDefault(tol -> 1E-6)
/**
- * Sets the regularization parameter.
+ * Sets the regularization parameter for L2 regularization.
+ * The regularization term is
+ * {{{
+ * 0.5 * regParam * L2norm(coefficients)^2
+ * }}}
* Default is 0.0.
* @group setParam
*/
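
As a quick worked example of the regularization term described above, with made-up values for regParam and the coefficient vector:

    // 0.5 * regParam * ||coefficients||_2^2, with illustrative values.
    val regParam = 0.3
    val coefficients = Array(0.5, -1.0, 2.0)
    val l2Penalty = 0.5 * regParam * coefficients.map(c => c * c).sum
    // l2Penalty == 0.5 * 0.3 * (0.25 + 1.0 + 4.0) == 0.7875
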
@@ -190,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
- override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
Link.fromName($(link))
@@ -210,9 +216,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
if (familyObj == Gaussian && linkObj == Identity) {
@@ -230,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -250,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -698,7 +707,7 @@ class GeneralizedLinearRegressionModel private[ml] (
: (GeneralizedLinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -769,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* :: Experimental ::
 * Summarizing Generalized Linear Regression fits.
*
- * @param predictions predictions outputted by the model's `transform` method
+ * @param predictions predictions output by the model's `transform` method
* @param predictionCol field in "predictions" which gives the prediction value of each instance
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -782,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
@@ -930,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val coefficientStandardErrors: Array[Double] = {
@@ -938,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val tValues: Array[Double] = {
@@ -951,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val pValues: Array[Double] = {
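
A minimal sketch of inspecting the enriched summary after fitting a GeneralizedLinearRegression; it assumes the model exposes its training summary through a summary accessor (as LinearRegressionModel does), and the toy data, family and link settings are illustrative only.

    import org.apache.spark.ml.regression.GeneralizedLinearRegression
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(1.0, 0.2)),
      (2.1, Vectors.dense(2.0, 0.9)),
      (2.9, Vectors.dense(3.0, 0.4)),
      (4.2, Vectors.dense(4.0, 0.8)),
      (5.1, Vectors.dense(5.0, 0.1))
    )).toDF("label", "features")

    val model = new GeneralizedLinearRegression()
      .setFamily("gaussian")
      .setLink("identity")
      .fit(df)

    // The summary now records the solver, and when fitIntercept is true the last
    // element of each statistic below corresponds to the intercept.
    val summary = model.summary                    // accessor name assumed
    println(summary.solver)                        // "irls" (the default)
    println(summary.numIterations)
    println(summary.coefficientStandardErrors.mkString(", "))
    println(summary.tValues.mkString(", "))
    println(summary.pValues.mkString(", "))
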
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9a34..7a78ecbdf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
* Extracts (label, feature, weight) from input dataset.
*/
protected[ml] def extractWeightedLabeledPoints(
- dataset: DataFrame): RDD[(Double, Double, Double)] = {
+ dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
val idx = $(featureIndex)
val extract = udf { v: Vector => v(idx) }
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
lit(1.0)
}
- dataset.select(col($(labelCol)), f, w).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
}
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
schema: StructType,
fitting: Boolean): StructType = {
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
} else {
@@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
@@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }
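
Because the label column is now cast to DoubleType rather than required to already be a double, integer-typed labels should fit without a manual cast. A minimal sketch with illustrative data and the default column names:

    import org.apache.spark.ml.regression.IsotonicRegression

    // Integer labels: previously the label column had to already be DoubleType.
    val data = sqlContext.createDataFrame(Seq(
      (1, 0.0), (2, 1.0), (3, 2.0), (1, 3.0), (6, 4.0), (7, 5.0)
    )).toDF("label", "features")

    val model = new IsotonicRegression().setIsotonic(true).fit(data)
    println(model.boundaries)
    println(model.predictions)
    model.transform(data).show()
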
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b81c588e44..71e02730c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -38,8 +38,9 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -57,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* The specific squared error loss function used is:
* L = 1/2n ||A coefficients - y||^2^
*
- * This support multiple types of regularization:
+ * This supports multiple types of regularization:
* - none (a.k.a. ordinary least squares)
* - L2 (ridge regression)
* - L1 (Lasso)
@@ -157,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "auto")
- override protected def train(dataset: DataFrame): LinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
case Row(features: Vector) => features.size
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     // For low dimensional data, WeightedLeastSquares is more efficient since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -189,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
- $(featuresCol),
Array(0D))
return lrModel.setSummary(trainingSummary)
@@ -248,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
@@ -355,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -412,15 +413,15 @@ class LinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), this, Array(0D))
+ $(labelCol), $(featuresCol), summaryModel, Array(0D))
}
/**
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -510,9 +511,9 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
/**
* :: Experimental ::
* Linear regression training results. Currently, the training summary ignores the
- * training coefficients except for the objective trace.
+ * training weights except for the objective trace.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Since("1.5.0")
@@ -521,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
- val featuresCol: String,
val objectiveHistory: Array[Double])
- extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
+ extends LinearRegressionSummary(
+ predictions,
+ predictionCol,
+ labelCol,
+ featuresCol,
+ model,
+ diagInvAtWA) {
- /** Number of training iterations until termination */
+ /**
+ * Number of training iterations until termination
+ *
+ * This value is only available when using the "l-bfgs" solver.
+ * @see [[LinearRegression.solver]]
+ */
@Since("1.5.0")
val totalIterations = objectiveHistory.length
@@ -537,7 +549,11 @@ class LinearRegressionTrainingSummary private[regression] (
* :: Experimental ::
* Linear regression results evaluated on a dataset.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
+ * @param predictionCol Field in "predictions" which gives the predicted value of the label at
+ * each instance.
+ * @param labelCol Field in "predictions" which gives the true label of each instance.
+ * @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
@@ -545,12 +561,13 @@ class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
val labelCol: String,
+ val featuresCol: String,
val model: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
@transient private val metrics = new RegressionMetrics(
predictions
- .select(predictionCol, labelCol)
+ .select(col(predictionCol), col(labelCol).cast(DoubleType))
.rdd
.map { case Row(pred: Double, label: Double) => (pred, label) },
!model.getFitIntercept)
@@ -638,6 +655,12 @@ class LinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -653,12 +676,18 @@ class LinearRegressionSummary private[regression] (
col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
}
val sigma2 = rss / degreesOfFreedom
- diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+ diagInvAtWA.map(_ * sigma2).map(math.sqrt)
}
}
/**
* T-statistic of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -676,6 +705,12 @@ class LinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -826,7 +861,7 @@ private class LeastSquaresAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
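
A minimal sketch combining the now-public evaluate method with the statistics that are only populated by the "normal" solver; the data is illustrative and an existing sqlContext is assumed.

    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.mllib.linalg.Vectors

    val train = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(1.0, 0.2)),
      (1.9, Vectors.dense(2.0, 0.7)),
      (3.2, Vectors.dense(3.0, 0.4)),
      (3.9, Vectors.dense(4.0, 0.9)),
      (5.1, Vectors.dense(5.0, 0.3))
    )).toDF("label", "features")

    // The "normal" solver populates the standard-error, t- and p-value statistics.
    val model = new LinearRegression().setSolver("normal").fit(train)

    // Training summary: the last element of each array is the intercept term.
    println(model.summary.coefficientStandardErrors.mkString(", "))
    println(model.summary.tValues.mkString(", "))
    println(model.summary.pValues.mkString(", "))

    // evaluate() is now public and summarizes the model on an arbitrary DataFrame.
    val testSummary = model.evaluate(train)        // reusing the data as a stand-in test set
    println(testSummary.rootMeanSquaredError)
    println(testSummary.r2)
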
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..4c4ff278d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,18 +17,22 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor] {
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}
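
A minimal sketch of the new RandomForestRegressionModel persistence and the getNumTrees accessor; the data, parameters and /tmp path are placeholders, and an existing sqlContext is assumed.

    import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (2.0, Vectors.dense(2.0, 1.0)),
      (3.0, Vectors.dense(3.0, 0.0))
    )).toDF("label", "features")

    val model = new RandomForestRegressor().setNumTrees(3).fit(df)

    // getNumTrees replaces the deprecated numTrees member on the model.
    println(model.getNumTrees)
    println(model.featureImportances)

    // New ML persistence for random forest regression models.
    model.write.overwrite().save("/tmp/rfr-model")         // placeholder path
    val reloaded = RandomForestRegressionModel.load("/tmp/rfr-model")
    assert(reloaded.getNumTrees == model.getNumTrees)
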
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 13a13f0a7e..2f1f2523fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,23 +19,25 @@ package org.apache.spark.ml.source.libsvm
import java.io.IOException
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.annotation.Since
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
-import org.apache.spark.util.collection.BitSet
private[libsvm] class LibSVMOutputWriter(
path: String,
@@ -110,13 +112,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
+ override def toString: String = "LibSVM"
+
private def verifySchema(dataSchema: StructType): Unit = {
if (dataSchema.size != 2 ||
(!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
|| !dataSchema(1).dataType.sameType(new VectorUDT()))) {
- throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+ throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
}
+
override def inferSchema(
sqlContext: SQLContext,
options: Map[String, String],
@@ -127,6 +132,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
StructField("features", new VectorUDT(), nullable = false) :: Nil))
}
+ override def prepareRead(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Map[String, String] = {
+ def computeNumFeatures(): Int = {
+ val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+ val path = if (dataFiles.length == 1) {
+ dataFiles.head.getPath.toUri.toString
+ } else if (dataFiles.isEmpty) {
+ throw new IOException("No input path specified for libsvm data")
+ } else {
+ throw new IOException("Multiple input paths are not supported for libsvm data.")
+ }
+
+ val sc = sqlContext.sparkContext
+ val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+ MLUtils.computeNumFeatures(parsed)
+ }
+
+ val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+ computeNumFeatures()
+ }
+
+ new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+ }
+
override def prepareWrite(
sqlContext: SQLContext,
job: Job,
@@ -144,36 +175,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- // TODO: This does not handle cases where column pruning has been performed.
-
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
verifySchema(dataSchema)
- val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
-
- val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
- else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
- else throw new IOException("Multiple input paths are not supported for libsvm data.")
-
- val numFeatures = options.getOrElse("numFeatures", "-1").toInt
- val vectorType = options.getOrElse("vectorType", "sparse")
-
- val sc = sqlContext.sparkContext
- val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
- val sparse = vectorType == "sparse"
- baseRdd.map { pt =>
- val features = if (sparse) pt.features.toSparse else pt.features.toDense
- Row(pt.label, features)
- }.mapPartitions { externalRows =>
- val converter = RowEncoder(dataSchema)
- externalRows.map(converter.toRow)
+ val numFeatures = options("numFeatures").toInt
+ assert(numFeatures > 0)
+
+ val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+ val broadcastedConf = sqlContext.sparkContext.broadcast(
+ new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+ )
+
+ (file: PartitionedFile) => {
+ val points =
+ new HadoopFileLinesReader(file, broadcastedConf.value.value)
+ .map(_.toString.trim)
+ .filterNot(line => line.isEmpty || line.startsWith("#"))
+ .map { line =>
+ val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+ LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+ }
+
+ val converter = RowEncoder(requiredSchema)
+
+ val unsafeRowIterator = points.map { pt =>
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ converter.toRow(Row(pt.label, features))
+ }
+
+ def toAttribute(f: StructField): AttributeReference =
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+ // Appends partition values
+ val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
}
}
}
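
A minimal sketch of reading LIBSVM data through this source; the input path is a placeholder, and both options may be omitted (numFeatures is then inferred by prepareRead, vectorType defaults to "sparse").

    // An existing sqlContext is assumed.
    val df = sqlContext.read
      .format("libsvm")
      .option("numFeatures", "780")     // optional: inferred from the single input file when absent
      .option("vectorType", "dense")    // "sparse" (the default) or "dense"
      .load("data/mllib/sample_libsvm_data.txt")   // placeholder path

    df.printSchema()    // label: double, features: vector
    df.show(5)
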
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 572815df0b..4e372702f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.commons.math3.distribution.PoissonDistribution
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index c745e9f8db..61091bb803 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
@@ -86,6 +86,7 @@ private[spark] class DTStatsAggregator(
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ *
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
*/
@@ -118,6 +119,7 @@ private[spark] class DTStatsAggregator(
/**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
*/
@@ -138,6 +140,7 @@ private[spark] class DTStatsAggregator(
/**
* For a given feature, merge the stats for two bins.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
* @param binIndex The other bin is merged into this bin.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 4f27dc44ef..c7cde1563f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}
+
+ val isIntRegex = "^([1-9]\\d*)$".r
+ val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
+ case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
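
A self-contained sketch of how the two new regexes map a numeric featureSubsetStrategy string to a per-node feature count; the strategy strings and numFeatures value are illustrative, and unsupported strings are assumed to have been resolved earlier (as the surrounding code does for "auto").

    // Mirrors the numeric forms accepted by featureSubsetStrategy above.
    val isIntRegex = "^([1-9]\\d*)$".r
    val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r

    def numFeaturesPerNode(strategy: String, numFeatures: Int): Int = strategy match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
      case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
    }

    numFeaturesPerNode("3", 10)      // 3   explicit count, capped at numFeatures
    numFeaturesPerNode("100", 10)    // 10
    numFeaturesPerNode("0.5", 10)    // 5   fraction of the features, rounded up
    numFeaturesPerNode("1.0", 10)    // 10
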
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b6334762c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -20,16 +20,17 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-private[ml] object GradientBoostedTrees extends Logging {
+
+private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
@@ -106,7 +107,7 @@ private[ml] object GradientBoostedTrees extends Logging {
initTree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
data.map { lp =>
- val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
+ val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
val error = loss.computeError(pred, lp.label)
(pred, error)
}
@@ -132,7 +133,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
iter.map { case (lp, (pred, error)) =>
- val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
+ val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
val newError = loss.computeError(newPred, lp.label)
(newPred, newError)
}
@@ -141,6 +142,97 @@ private[ml] object GradientBoostedTrees extends Logging {
}
/**
+ * Add prediction from a new boosting iteration to an existing prediction.
+ *
+ * @param features Vector of features representing a single data point.
+ * @param prediction The existing prediction.
+ * @param tree New Decision Tree model.
+ * @param weight Tree weight.
+ * @return Updated prediction.
+ */
+ def updatePrediction(
+ features: Vector,
+ prediction: Double,
+ tree: DecisionTreeRegressionModel,
+ weight: Double): Double = {
+ prediction + tree.rootNode.predictImpl(features).prediction * weight
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @return Measure of model error on data
+ */
+ def computeError(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss): Double = {
+ data.map { lp =>
+ val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
+ updatePrediction(lp.features, acc, model, weight)
+ }
+ loss.computeError(predicted, lp.label)
+ }.mean()
+ }
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @param algo algorithm for the ensemble, either Classification or Regression
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss,
+ algo: OldAlgo.Value): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+ val localTreeWeights = treeWeights
+
+ var predictionAndError = computeInitialPredictionAndError(
+ remappedData, localTreeWeights(0), trees(0), loss)
+
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ val broadcastTrees = sc.broadcast(trees)
+ (1 until numIterations).foreach { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = localTreeWeights(nTree)
+ iter.map { case (point, (pred, error)) =>
+ val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight)
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ evaluationArray
+ }
+
+ /**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param validationInput validation dataset, ignored if validate is set to false.
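
A dependency-free sketch of the prediction-update recurrence behind updatePrediction, computeError and evaluateEachIteration: the ensemble prediction after k + 1 trees is the previous prediction plus the new tree's weighted output. The per-tree outputs, weights and label below are made up, and squared error stands in for the pluggable OldLoss.

    val treeOutputs = Array(0.8, -0.1, 0.05)        // stand-ins for rootNode.predictImpl(...).prediction
    val treeWeights = Array(1.0, 0.1, 0.1)          // learning rate applied at each iteration
    val label = 0.9

    def updatePrediction(prediction: Double, treeOutput: Double, weight: Double): Double =
      prediction + treeOutput * weight

    def squaredError(pred: Double, label: Double): Double = {
      val d = pred - label
      d * d
    }

    // Error of the growing ensemble after each boosting iteration,
    // as evaluateEachIteration computes it per data point.
    val errorPerIteration = treeOutputs.zip(treeWeights)
      .scanLeft(0.0) { case (pred, (out, w)) => updatePrediction(pred, out, w) }
      .tail                                          // drop the initial 0.0 prediction
      .map(pred => squaredError(pred, label))

    // predictions 0.8, 0.79, 0.795 -> errors 0.01, 0.0121, 0.011025
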
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index 2c8286766f..9d697a36b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
-import org.apache.spark.mllib.tree.impl.BaggedPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..7b1fd089f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,16 +26,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
- TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -332,7 +328,7 @@ private[spark] object RandomForest extends Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+ * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
@@ -1105,112 +1101,4 @@ private[spark] object RandomForest extends Logging {
}
}
- /**
- * Given a Random Forest model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
- *
- * @param trees Unweighted forest of trees
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
- val totalImportances = new OpenHashMap[Int, Double]()
- trees.foreach { tree =>
- // Aggregate feature importance vector for this tree
- val importances = new OpenHashMap[Int, Double]()
- computeFeatureImportance(tree.rootNode, importances)
- // Normalize importance vector for this tree, and add it to total.
- // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
- val treeNorm = importances.map(_._2).sum
- if (treeNorm != 0) {
- importances.foreach { case (idx, impt) =>
- val normImpt = impt / treeNorm
- totalImportances.changeValue(idx, normImpt, _ + normImpt)
- }
- }
- }
- // Normalize importances
- normalizeMapValues(totalImportances)
- // Construct vector
- val d = if (numFeatures != -1) {
- numFeatures
- } else {
- // Find max feature index used in trees
- val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
- maxFeatureIndex + 1
- }
- if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing feature" +
- s" importance: No splits found, but some non-zero importances.")
- }
- val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
- Vectors.sparse(d, indices.toArray, values.toArray)
- }
-
- /**
- * Given a Decision Tree model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- *
- * @param tree Decision tree to compute importances for.
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
- featureImportances(Array(tree), numFeatures)
- }
-
- /**
- * Recursive method for computing feature importances for one tree.
- * This walks down the tree, adding to the importance of 1 feature at each node.
- * @param node Current node in recursion
- * @param importances Aggregate feature importances, modified by this method
- */
- private[impl] def computeFeatureImportance(
- node: Node,
- importances: OpenHashMap[Int, Double]): Unit = {
- node match {
- case n: InternalNode =>
- val feature = n.split.featureIndex
- val scaledGain = n.gain * n.impurityStats.count
- importances.changeValue(feature, scaledGain, _ + scaledGain)
- computeFeatureImportance(n.leftChild, importances)
- computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
- // do nothing
- }
- }
-
- /**
- * Normalize the values of this map to sum to 1, in place.
- * If all values are 0, this method does nothing.
- * @param map Map with non-negative values.
- */
- private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
- val total = map.map(_._2).sum
- if (total != 0) {
- val keys = map.iterator.map(_._1).toArray
- keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
- }
- }
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
index 70afaa162b..4cc250aa46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable.{HashMap => MutableHashMap}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index 9fa27e5e1f..3a2bf3c725 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..f38e1ec7c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,16 +17,22 @@
package org.apache.spark.ml.tree
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.DefaultParamsReader
+import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, SQLContext}
+import org.apache.spark.util.collection.OpenHashMap
/**
* Abstraction for Decision Tree models.
@@ -70,7 +76,7 @@ private[spark] trait DecisionTreeModel {
*/
private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
private[spark] def toOld: OldDecisionTreeModel
}
@@ -78,14 +84,21 @@ private[spark] trait DecisionTreeModel {
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ *
+ * @tparam M Type of tree model in this ensemble
*/
-private[ml] trait TreeEnsembleModel {
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
// DecisionTreeModel.
/** Trees in this ensemble. Warning: These have null parent Estimators. */
- def trees: Array[DecisionTreeModel]
+ def trees: Array[M]
+
+ /**
+ * Number of trees in ensemble
+ */
+ val getNumTrees: Int = trees.length
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
@@ -97,7 +110,7 @@ private[ml] trait TreeEnsembleModel {
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
- s"TreeEnsembleModel with $numTrees trees"
+ s"TreeEnsembleModel with ${trees.length} trees"
}
/** Full description of model */
@@ -108,13 +121,129 @@ private[ml] trait TreeEnsembleModel {
}.fold("")(_ + _)
}
- /** Number of trees in ensemble */
- val numTrees: Int = trees.length
-
/** Total number of nodes, summed over all trees in the ensemble. */
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
+private[ml] object TreeEnsembleModel {
+
+ /**
+ * Given a tree ensemble model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * References:
+ * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+ *
+ * @param trees Unweighted collection of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ *
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ *
+ * @param map Map with non-negative values.
+ */
+ def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+}
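
To make the averaging-and-normalization scheme described above concrete, here is a minimal plain-Scala sketch (illustration only, not part of the patch; the per-tree importances are toy maps of featureIndex -> sum of gain * instance count, rather than values computed from real trees):

// Toy per-tree importances: featureIndex -> sum over splits of (gain * instance count).
val treeImportances = Seq(
  Map(0 -> 4.0, 2 -> 1.0),   // tree 1
  Map(0 -> 1.0, 1 -> 1.0))   // tree 2

// Step 1: normalize each tree's importances to sum to 1, then accumulate across trees.
val totals = scala.collection.mutable.Map.empty[Int, Double]
treeImportances.foreach { imps =>
  val norm = imps.values.sum
  if (norm != 0) imps.foreach { case (idx, v) =>
    totals(idx) = totals.getOrElse(idx, 0.0) + v / norm
  }
}

// Step 2: normalize the aggregate vector to sum to 1.
val grandTotal = totals.values.sum
val importances = totals.map { case (idx, v) => idx -> v / grandTotal }.toMap
// -> roughly Map(0 -> 0.65, 1 -> 0.25, 2 -> 0.1)
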
+
/** Helper classes for tree model persistence */
private[ml] object DecisionTreeModelReadWrite {
@@ -196,6 +325,10 @@ private[ml] object DecisionTreeModelReadWrite {
}
}
+ /**
+ * Load a decision tree from a file.
+ * @return Root node of reconstructed tree
+ */
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
@@ -211,9 +344,18 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
+ buildTreeFromNodes(data.collect(), impurityType)
+ }
+ /**
+ * Given all data for all nodes in a tree, rebuild the tree.
+ * @param data Unsorted node data
+ * @param impurityType Impurity type for this tree
+ * @return Root node of reconstructed tree
+ */
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
// Load all nodes, sorted by ID.
- val nodes: Array[NodeData] = data.collect().sortBy(_.id)
+ val nodes = data.sortBy(_.id)
// Sanity checks; could remove
assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
s" but found ${nodes.head.id}")
@@ -238,3 +380,105 @@ private[ml] object DecisionTreeModelReadWrite {
finalNodes.head
}
}
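
The node IDs written by this format are assigned so that every child's ID is larger than its parent's; buildTreeFromNodes exploits that by filling an array in reverse-ID order, so children already exist when their parent is constructed. A simplified standalone sketch of the same idea (hypothetical FlatNode records standing in for the real NodeData):

// Hypothetical flattened record: leftChild/rightChild are -1 for leaves.
case class FlatNode(id: Int, prediction: Double, leftChild: Int, rightChild: Int)

sealed trait SimpleNode
case class Leaf(prediction: Double) extends SimpleNode
case class Internal(left: SimpleNode, right: SimpleNode) extends SimpleNode

// Assumes IDs are exactly 0 .. data.length - 1, with 0 as the root.
def buildTree(data: Array[FlatNode]): SimpleNode = {
  val nodes = data.sortBy(_.id)
  val built = new Array[SimpleNode](nodes.length)
  nodes.reverseIterator.foreach { n =>          // largest IDs (deepest nodes) first
    built(n.id) =
      if (n.leftChild == -1) Leaf(n.prediction)
      else Internal(built(n.leftChild), built(n.rightChild))
  }
  built(0)
}

// buildTree(Array(FlatNode(0, 0.0, 1, 2), FlatNode(2, 2.0, -1, -1), FlatNode(1, 1.0, -1, -1)))
// => Internal(Leaf(1.0), Leaf(2.0))
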
+
+private[ml] object EnsembleModelReadWrite {
+
+ /**
+ * Helper method for saving a tree ensemble to disk.
+ *
+ * @param instance Tree ensemble model
+ * @param path Path to which to save the ensemble model.
+ * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
+ */
+ def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
+ instance: M,
+ path: String,
+ sql: SQLContext,
+ extraMetadata: JObject): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
+ val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
+ case (tree, treeID) =>
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+ instance.treeWeights(treeID))
+ }
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
+ .write.parquet(treesMetadataPath)
+ val dataPath = new Path(path, "data").toString
+ val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
+ case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
+ }
+ sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ }
+
+ /**
+ * Helper method for loading a tree ensemble from disk.
+ * This reconstructs all trees, returning the root nodes.
+ * @param path Path given to [[saveImpl()]]
+ * @param className Class name for ensemble model type
+ * @param treeClassName Class name for tree model type in the ensemble
+ * @return (ensemble metadata, array over trees of (tree metadata, root node)),
+ * where the root node is linked with all descendants
+ * @see [[saveImpl()]] for how the model was saved
+ */
+ def loadImpl(
+ path: String,
+ sql: SQLContext,
+ className: String,
+ treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+ import sql.implicits._
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
+ case (treeID: Int, json: String, weights: Double) =>
+ treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights)
+ }
+
+ val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+ val treesMetadata = treesMetadataWeights.map(_._1)
+ val treesWeights = treesMetadataWeights.map(_._2)
+
+ val dataPath = new Path(path, "data").toString
+ val nodeData: Dataset[EnsembleNodeData] =
+ sql.read.parquet(dataPath).as[EnsembleNodeData]
+ val rootNodesRDD: RDD[(Int, Node)] =
+ nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
+ case (treeID: Int, nodeData: Iterable[NodeData]) =>
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ }
+ val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
+ (metadata, treesMetadata.zip(rootNodes), treesWeights)
+ }
+
+ /**
+ * Info for one [[Node]] in a tree ensemble
+ *
+ * @param treeID Tree index
+ * @param nodeData Data for this node
+ */
+ case class EnsembleNodeData(
+ treeID: Int,
+ nodeData: NodeData)
+
+ object EnsembleNodeData {
+ /**
+ * Create [[EnsembleNodeData]] instances for the given tree.
+ *
+ * @return Sequence of nodes for this tree
+ */
+ def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = {
+ val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0)
+ nodeData.map(nd => EnsembleNodeData(treeID, nd))
+ }
+ }
+}
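
Putting saveImpl and loadImpl together, the on-disk layout for an ensemble looks roughly like this (directory and column names taken from the code above; a sketch, not a formal spec):

  <path>/metadata/         ensemble-level class, uid and Params (JSON text, via DefaultParamsWriter)
  <path>/treesMetadata/    Parquet rows (treeID, metadata, weights): per-tree metadata JSON plus tree weight
  <path>/data/             Parquet rows of EnsembleNodeData (treeID, nodeData), one row per node of every tree

loadImpl reverses this: it groups the node rows by treeID, rebuilds each tree with DecisionTreeModelReadWrite.buildTreeFromNodes, and zips the resulting root nodes with the per-tree metadata sorted by treeID.
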
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 4fbd957677..b6783911ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/**
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
+/** Used for [[RandomForestParams]] */
+private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* The number of features to consider for splits at each tree node.
@@ -343,6 +329,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -360,29 +348,71 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+ setDefault(featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getNumTrees: Int = $(numTrees)
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+/**
+ * Used for [[RandomForestParams]].
+ * This is separated out from [[RandomForestParams]] because of an issue with the
+ * `numTrees` method conflicting with this Param in the Estimator.
+ */
+private[ml] trait HasNumTrees extends Params {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+ final def getNumTrees: Int = $(numTrees)
}
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with HasNumTrees
+
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+ // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeClassifierParams
+
+private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeRegressorParams
+
+private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeRegressorParams
+
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
@@ -432,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** Get old Gradient Boosting Loss type */
private[ml] def getOldLossType: OldLoss
}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}
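
A hedged usage sketch of these Params through the public estimators (assuming the usual GBTClassifier/GBTRegressor setters, which simply set the lossType Param defined above):

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.regression.GBTRegressor

// Validation is case-insensitive, so "Logistic" and "logistic" are both accepted.
val gbtC = new GBTClassifier().setLossType("logistic")   // mapped to OldLogLoss internally
val gbtR = new GBTRegressor().setLossType("absolute")    // mapped to OldAbsoluteError (L1)
// An unsupported value such as "huber" is rejected by the Param validator.
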
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..de563d4fad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,27 +17,25 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
-
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -45,6 +43,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
+ *
* @group param
*/
val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -91,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.4.0")
- override def fit(dataset: DataFrame): CrossValidatorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
@@ -101,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
+ val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
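
For context, a typical (hedged) use of the k-fold flow above; trainingDF is a placeholder DataFrame with label/features columns:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
  .setSeed(42L)                       // the seed Param now drives MLUtils.kFold above

// val cvModel = cv.fit(trainingDF)   // trains grid.length models per fold, keeps the best ParamMap
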
@@ -163,10 +162,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit =
- SharedReadWrite.saveImpl(path, instance, sc)
+ ValidatorParams.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +174,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
private val className = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -184,123 +186,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
.setNumFolds(numFolds)
}
}
-
- private object CrossValidatorReader {
- /**
- * Examine the given estimator (which may be a compound estimator) and extract a mapping
- * from UIDs to corresponding [[Params]] instances.
- */
- def getUidMap(instance: Params): Map[String, Params] = {
- val uidList = getUidMapImpl(instance)
- val uidMap = uidList.toMap
- if (uidList.size != uidMap.size) {
- throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
- s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
- }
- uidMap
- }
-
- def getUidMapImpl(instance: Params): List[(String, Params)] = {
- val subStages: Array[Params] = instance match {
- case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
- case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
- case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
- case ovr: OneVsRestParams =>
- // TODO: SPARK-11892: This case may require special handling.
- throw new UnsupportedOperationException("CrossValidator write will fail because it" +
- " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
- case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
- case _: Params => Array()
- }
- val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
- List((instance.uid, instance)) ++ subStageMaps
- }
- }
-
- private[tuning] object SharedReadWrite {
-
- /**
- * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
- * This does not check [[CrossValidator.estimatorParamMaps]].
- */
- def validateParams(instance: ValidatorParams): Unit = {
- def checkElement(elem: Params, name: String): Unit = elem match {
- case stage: MLWritable => // good
- case other =>
- throw new UnsupportedOperationException("CrossValidator write will fail " +
- s" because it contains $name which does not implement Writable." +
- s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
- }
- checkElement(instance.getEvaluator, "evaluator")
- checkElement(instance.getEstimator, "estimator")
- // Check to make sure all Params apply to this estimator. Throw an error if any do not.
- // Extraneous Params would cause problems when loading the estimatorParamMaps.
- val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
- instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
- pMap.toSeq.foreach { case ParamPair(p, v) =>
- require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
- s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
- s" Evaluator. An extraneous Param was found: $p")
- }
- }
- }
-
- private[tuning] def saveImpl(
- path: String,
- instance: CrossValidatorParams,
- sc: SparkContext,
- extraMetadata: Option[JObject] = None): Unit = {
- import org.json4s.JsonDSL._
-
- val estimatorParamMapsJson = compact(render(
- instance.getEstimatorParamMaps.map { case paramMap =>
- paramMap.toSeq.map { case ParamPair(p, v) =>
- Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
- }
- }.toSeq
- ))
- val jsonParams = List(
- "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
- "estimatorParamMaps" -> parse(estimatorParamMapsJson)
- )
- DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
- val evaluatorPath = new Path(path, "evaluator").toString
- instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
- val estimatorPath = new Path(path, "estimator").toString
- instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
- }
-
- private[tuning] def load[M <: Model[M]](
- path: String,
- sc: SparkContext,
- expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
- val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
- implicit val format = DefaultFormats
- val evaluatorPath = new Path(path, "evaluator").toString
- val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
- val estimatorPath = new Path(path, "estimator").toString
- val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
- val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
- val numFolds = (metadata.params \ "numFolds").extract[Int]
- val estimatorParamMaps: Array[ParamMap] =
- (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
- pMap =>
- val paramPairs = pMap.map { case pInfo: Map[String, String] =>
- val est = uidToParams(pInfo("parent"))
- val param = est.getParam(pInfo("name"))
- val value = param.jsonDecode(pInfo("value"))
- param -> value
- }
- ParamMap(paramPairs: _*)
- }.toArray
- (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
- }
- }
}
/**
@@ -319,8 +204,13 @@ class CrossValidatorModel private[ml] (
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+ this(uid, bestModel, avgMetrics.asScala.toArray)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
@@ -346,8 +236,6 @@ class CrossValidatorModel private[ml] (
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
- import CrossValidator.SharedReadWrite
-
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
@@ -357,12 +245,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
private[CrossValidatorModel]
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
- SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
}
@@ -376,8 +264,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 70fa5f0234..12d6905510 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,22 +17,32 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
*/
-private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
+ *
* @group param
*/
val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +65,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
@Experimental
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[TrainValidationSplitModel]
- with TrainValidationSplitParams with Logging {
+ with TrainValidationSplitParams with MLWritable with Logging {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("tvs"))
@@ -76,8 +86,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): TrainValidationSplitModel = {
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
@@ -87,10 +101,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val Array(training, validation) =
- dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
- val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
- val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+ val Array(trainingDataset, validationDataset) =
+ dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
+ trainingDataset.cache()
+ validationDataset.cache()
// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
@@ -130,6 +144,47 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplit = super.load(path)
+
+ private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit)
+ extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit =
+ ValidatorParams.saveImpl(path, instance, sc)
+ }
+
+ private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplit].getName
+
+ override def load(path: String): TrainValidationSplit = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ new TrainValidationSplit(metadata.uid)
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimatorParamMaps)
+ .setTrainRatio(trainRatio)
+ }
+ }
}
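
A similar hedged sketch for the split-based validator; df is a placeholder DataFrame:

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator())
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.75)                // 75% train / 25% validation
  .setSeed(42L)                       // new seed Param feeds Dataset.randomSplit above

// val tvsModel = tvs.fit(df)         // fit now splits the Dataset directly, and the
// tvsModel.write.save("/tmp/tvs")    // fitted model is MLWritable per this patch
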
/**
@@ -146,10 +201,15 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val bestModel: Model[_],
@Since("1.5.0") val validationMetrics: Array[Double])
- extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+ extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
+ this(uid, bestModel, validationMetrics.asScala.toArray)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
@@ -167,4 +227,53 @@ class TrainValidationSplitModel private[ml] (
validationMetrics.clone())
copyValues(copied, extra)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+ private[TrainValidationSplitModel]
+ class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ import org.json4s.JsonDSL._
+ val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplitModel].getName
+
+ override def load(path: String): TrainValidationSplitModel = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+ val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
+ tvs.set(tvs.estimator, estimator)
+ .set(tvs.evaluator, evaluator)
+ .set(tvs.estimatorParamMaps, estimatorParamMaps)
+ .set(tvs.trainRatio, trainRatio)
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e8f0..7a4e106aeb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
package org.apache.spark.ml.tuning
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
+ MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
/**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
est.copy(firstEstimatorParamMap).transformSchema(schema)
}
}
+
+private[ml] object ValidatorParams {
+ /**
+ * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
+ * This does not check [[ValidatorParams.estimatorParamMaps]].
+ */
+ def validateParams(instance: ValidatorParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " +
+ s" because it contains $name which does not implement Writable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+ checkElement(instance.getEvaluator, "evaluator")
+ checkElement(instance.getEstimator, "estimator")
+ // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+ // Extraneous Params would cause problems when loading the estimatorParamMaps.
+ val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
+ instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+ s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+ s" Evaluator. An extraneous Param was found: $p")
+ }
+ }
+ }
+
+ /**
+ * Generic implementation of save for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing
+ * class needs to handle model data.
+ */
+ def saveImpl(
+ path: String,
+ instance: ValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+ import org.json4s.JsonDSL._
+
+ val estimatorParamMapsJson = compact(render(
+ instance.getEstimatorParamMaps.map { case paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ }
+ }.toSeq
+ ))
+
+ val validatorSpecificParams = instance match {
+ case cv: CrossValidatorParams =>
+ List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+ case tvs: TrainValidationSplitParams =>
+ List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+ case _ =>
+ // This should not happen.
+ throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
+ instance.getClass.getCanonicalName)
+ }
+
+ val jsonParams = validatorSpecificParams ++ List(
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ /**
+ * Generic implementation of load for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields, but the implementing
+ * class needs to handle model data and special [[Param]] values.
+ */
+ def loadImpl[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+ val estimatorParamMaps: Array[ParamMap] =
+ (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+ pMap =>
+ val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+
+ (metadata, estimator, evaluator, estimatorParamMaps)
+ }
+}
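
For orientation, the metadata that saveImpl writes under <path>/metadata ends up shaped roughly as follows (field names from the code above and from DefaultParamsWriter; values are placeholders), with the evaluator and estimator saved separately under <path>/evaluator and <path>/estimator:

  {
    "class": "org.apache.spark.ml.tuning.CrossValidator",
    "timestamp": ..., "sparkVersion": ..., "uid": "...",
    "paramMap": {
      "numFolds": 3,                   // or "trainRatio": 0.75 for TrainValidationSplit
      "estimatorParamMaps": [[{"parent": "<uid>", "name": "<param>", "value": "<json-encoded>"}], ...]
    }
  }

loadImpl reads numFolds/trainRatio back out of metadata.params and resolves each (parent, name) pair against the UID map of the reloaded estimator and evaluator.
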
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
new file mode 100644
index 0000000000..7e57cefc44
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+
+/**
+ * A small wrapper that defines a training session for an estimator, and some methods to log
+ * useful information during this session.
+ *
+ * A new instance is expected to be created within fit().
+ *
+ * @param estimator the estimator that is being fit
+ * @param dataset the training dataset
+ * @tparam E the type of the estimator
+ */
+private[ml] class Instrumentation[E <: Estimator[_]] private (
+ estimator: E, dataset: RDD[_]) extends Logging {
+
+ private val id = Instrumentation.counter.incrementAndGet()
+ private val prefix = {
+ val className = estimator.getClass.getSimpleName
+ s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+ }
+
+ init()
+
+ private def init(): Unit = {
+ log(s"training: numPartitions=${dataset.partitions.length}" +
+ s" storageLevel=${dataset.getStorageLevel}")
+ }
+
+ /**
+ * Logs a message with a prefix that uniquely identifies the training session.
+ */
+ def log(msg: String): Unit = {
+ logInfo(prefix + msg)
+ }
+
+ /**
+ * Logs the value of the given parameters for the estimator being used in this session.
+ */
+ def logParams(params: Param[_]*): Unit = {
+ val pairs: Seq[(String, JValue)] = for {
+ p <- params
+ value <- estimator.get(p)
+ } yield {
+ val cast = p.asInstanceOf[Param[Any]]
+ p.name -> parse(cast.jsonEncode(value))
+ }
+ log(compact(render(map2jvalue(pairs.toMap))))
+ }
+
+ def logNumFeatures(num: Long): Unit = {
+ log(compact(render("numFeatures" -> num)))
+ }
+
+ def logNumClasses(num: Long): Unit = {
+ log(compact(render("numClasses" -> num)))
+ }
+
+ /**
+ * Logs the successful completion of the training session and the value of the learned model.
+ */
+ def logSuccess(model: Model[_]): Unit = {
+ log(s"training finished")
+ }
+}
+
+/**
+ * Some common methods for logging information about a training session.
+ */
+private[ml] object Instrumentation {
+ private val counter = new AtomicLong(0)
+
+ /**
+ * Creates an instrumentation object for a training session.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
+ create[E](estimator, dataset.rdd)
+ }
+
+ /**
+ * Creates an instrumentation object for a training session.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: RDD[_]): Instrumentation[E] = {
+ new Instrumentation[E](estimator, dataset)
+ }
+
+}
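
A hedged sketch of how an estimator's fit() is expected to use this wrapper (a fragment, not runnable on its own; maxIter, regParam, numFeatures and train are placeholders for whatever the concrete estimator defines):

// Inside some Estimator[...]#fit(dataset: Dataset[_]):
val instr = Instrumentation.create(this, dataset)
instr.logParams(maxIter, regParam)        // log selected Params of this estimator
instr.logNumFeatures(numFeatures)         // dimensions of the training data
val model = train(dataset)                // estimator-specific training step
instr.logSuccess(model)                   // marks the end of the session in the logs
model
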
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
+import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
@@ -139,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
/**
* Abstract class for utility classes that can load ML instances.
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -157,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
/**
* Trait for objects that provide [[MLReader]].
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -187,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
* Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -206,6 +214,7 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap
* - (optionally, extra metadata)
+ *
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -217,6 +226,22 @@ private[ml] object DefaultParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+
+ /**
+ * Helper for [[saveMetadata()]] which extracts the JSON to save.
+ * This is useful for ensemble models which need to save metadata for many sub-models.
+ *
+ * @see [[saveMetadata()]] for details on what this includes.
+ */
+ def getMetadataToSave(
+ instance: Params,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -234,9 +259,8 @@ private[ml] object DefaultParamsWriter {
case None =>
basicMetadata
}
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ val metadataJson: String = compact(render(metadata))
+ metadataJson
}
}
@@ -244,6 +268,7 @@ private[ml] object DefaultParamsWriter {
* Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
@@ -263,6 +288,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
+ *
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
@@ -299,13 +325,26 @@ private[ml] object DefaultParamsReader {
}
/**
- * Load metadata from file.
+ * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
+ *
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
+ parseMetadata(metadataStr, expectedClassName)
+ }
+
+ /**
+ * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+ * This is a helper function for [[loadMetadata()]].
+ *
+ * @param metadataStr JSON string of metadata
+ * @param expectedClassName If non empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
@@ -352,3 +391,36 @@ private[ml] object DefaultParamsReader {
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+ s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+ }
+ uidMap
+ }
+
+ private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRest => Array(ovr.getClassifier)
+ case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
+ case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+}
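
The recursion above flattens an arbitrarily nested meta-estimator (a Pipeline inside a CrossValidator, and so on) into (uid, instance) pairs and rejects duplicates. A toy plain-Scala analogue of the traversal, for illustration only:

sealed trait Stage { def uid: String }
case class Simple(uid: String) extends Stage
case class Compound(uid: String, children: Seq[Stage]) extends Stage

def uidMap(root: Stage): Map[String, Stage] = {
  def walk(s: Stage): List[(String, Stage)] = s match {
    case c: Compound  => (c.uid -> c) :: c.children.toList.flatMap(walk)
    case leaf: Simple => List(leaf.uid -> leaf)
  }
  val pairs = walk(root)
  require(pairs.map(_._1).distinct.size == pairs.size,
    s"duplicate UIDs found: ${pairs.map(_._1).mkString(", ")}")
  pairs.toMap
}

// uidMap(Compound("cv", Seq(Compound("pipeline", Seq(Simple("lr"), Simple("tf"))), Simple("eval"))))
// => Map("cv" -> ..., "pipeline" -> ..., "lr" -> ..., "tf" -> ..., "eval" -> ...)
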
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76021ad8f4..334410c962 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.util
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
/**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
}
/**
- * Check whether the given schema contains a column of one of the require data types.
- * @param colName column name
- * @param dataTypes required column data types
- */
+ * Check whether the given schema contains a column of one of the required data types.
+ * @param colName column name
+ * @param dataTypes required column data types
+ */
def checkColumnTypes(
schema: StructType,
colName: String,
@@ -61,6 +61,20 @@ private[spark] object SchemaUtils {
}
/**
+ * Check whether the given schema contains a column of the numeric data type.
+ * @param colName column name
+ */
+ def checkNumericType(
+ schema: StructType,
+ colName: String,
+ msg: String = ""): Unit = {
+ val actualDataType = schema(colName).dataType
+ val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+ require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
+ s"NumericType but was actually of type $actualDataType.$message")
+ }
+
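
Since SchemaUtils is private[spark], an equivalent standalone check using only the public sql.types API looks like this (a sketch mirroring the helper above):

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("label", DoubleType),
  StructField("text", StringType)))

def requireNumeric(schema: StructType, colName: String): Unit = {
  val dt = schema(colName).dataType
  require(dt.isInstanceOf[NumericType],
    s"Column $colName must be of type NumericType but was actually of type $dt.")
}

requireNumeric(schema, "label")    // ok: DoubleType is a NumericType
// requireNumeric(schema, "text")  // would throw IllegalArgumentException
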
+ /**
* Appends a new column to the input schema. This fails if the given output column already exists.
* @param schema input schema
* @param colName new column name. If this column name is an empty string "", this method returns
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index a689b09341..364d5eea08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -24,15 +24,15 @@ import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
- * Wrapper around GaussianMixtureModel to provide helper methods in Python
- */
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python
+ */
private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
val weights: Vector = Vectors.dense(model.weights)
val k: Int = weights.size
/**
- * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian
- */
+ * Returns gaussians as a List of Vectors and Matrices corresponding to each MultivariateGaussian
+ */
val gaussians: Array[Byte] = {
val modelGaussians = model.gaussians.map { gaussian =>
Array[Any](gaussian.mu, gaussian.sigma)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
index 073f03e16f..05273c3434 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -27,8 +27,8 @@ import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
- * Wrapper around Word2VecModel to provide helper methods in Python
- */
+ * Wrapper around Word2VecModel to provide helper methods in Python
+ */
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
def transform(word: String): Vector = {
model.transform(word)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index c0404be019..f10570e662 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -418,7 +418,7 @@ class LogisticRegressionWithLBFGS
private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean):
LogisticRegressionModel = {
- // ml's Logisitic regression only supports binary classifcation currently.
+ // ml's Logistic regression only supports binary classification currently.
if (numOfLinearPredictor == 1) {
def runWithMlLogisitcRegression(elasticNetParam: Double) = {
// Prepare the ml LogisticRegression based on our settings
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 64b838a1db..e4bd0dc25e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable {
private[clustering] class ClusteringTreeNode private[clustering] (
val index: Int,
val size: Long,
- private val centerWithNorm: VectorWithNorm,
+ private[clustering] val centerWithNorm: VectorWithNorm,
val cost: Double,
val height: Double,
val children: Array[ClusteringTreeNode]) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 01a0d31f14..c3b5b8b790 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -17,11 +17,19 @@
package org.apache.spark.mllib.clustering
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
/**
* Clustering model produced by [[BisectingKMeans]].
@@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD
@Experimental
class BisectingKMeansModel private[clustering] (
private[clustering] val root: ClusteringTreeNode
- ) extends Serializable with Logging {
+ ) extends Serializable with Saveable with Logging {
/**
* Leaf cluster centers.
@@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+ val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val rootId = (metadata \ "rootId").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, formatVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path, rootId)
+ model
+ case _ => throw new Exception(
+ s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $formatVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double,
+ height: Double, children: Seq[Int])
+
+ private object Data {
+ def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3),
+ r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
+ }
+
+ private[clustering] object SaveLoadV1_0 {
+ private val thisFormatVersion = "1.0"
+
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+ def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+ ~ ("rootId" -> model.root.index)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val data = getNodes(model.root).map(node => Data(node.index, node.size,
+ node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+ node.children.map(_.index)))
+ val dataRDD = sc.parallelize(data).toDF()
+ dataRDD.write.parquet(Loader.dataPath(path))
+ }
+
+ private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+ if (node.children.isEmpty) {
+ Array(node)
+ } else {
+ node.children.flatMap(getNodes(_)) ++ Array(node)
+ }
+ }
+
+ def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val rows = sqlContext.read.parquet(Loader.dataPath(path))
+ Loader.checkSchema[Data](rows.schema)
+ val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+ val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+ val rootNode = buildTree(rootId, nodes)
+ new BisectingKMeansModel(rootNode)
+ }
+
+ private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+ val root = nodes.get(rootId).get
+ if (root.children.isEmpty) {
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, new Array[ClusteringTreeNode](0))
+ } else {
+ val children = root.children.map(c => buildTree(c, nodes))
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, children.toArray)
+ }
+ }
+ }
}
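A minimal save/load sketch for the persistence support added above; this assumes a spark-shell session (so `sc` is already defined), an output path that does not exist yet, and purely illustrative data points:

import org.apache.spark.mllib.clustering.{BisectingKMeans, BisectingKMeansModel}
import org.apache.spark.mllib.linalg.Vectors

val points = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)))

val model = new BisectingKMeans().setK(2).run(points)

// Writes JSON metadata (class, version, rootId) and one Parquet row per tree node.
model.save(sc, "/tmp/bisecting-kmeans-model")

// Rebuilds the clustering tree from the stored rootId and node rows.
val sameModel = BisectingKMeansModel.load(sc, "/tmp/bisecting-kmeans-model")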
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 03eb903bb8..f04c87259c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -181,13 +181,12 @@ class GaussianMixture private (
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weights, gmm.gaussians)
- case None => {
+ case None =>
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
- }
}
var llh = Double.MinValue // current log-likelihood
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 02417b1124..f87613cc72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -183,7 +183,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
val k = (metadata \ "k").extract[Int]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
(loadedClassName, version) match {
- case (classNameV1_0, "1.0") => {
+ case (classNameV1_0, "1.0") =>
val model = SaveLoadV1_0.load(sc, path)
require(model.weights.length == k,
s"GaussianMixtureModel requires weights of length $k " +
@@ -192,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
s"GaussianMixtureModel requires gaussians of length $k" +
s"got gaussians of length ${model.gaussians.length}")
model
- }
case _ => throw new Exception(
s"GaussianMixtureModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index a7beb81980..8ff0b83e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -253,16 +253,14 @@ class KMeans private (
}
val centers = initialModel match {
- case Some(kMeansCenters) => {
+ case Some(kMeansCenters) =>
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
- }
- case None => {
+ case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
- }
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
@@ -390,6 +388,8 @@ class KMeans private (
// Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
+ // Could be empty if data is empty; fail with a better message early:
+ require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
/** Merges new centers to centers. */
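For context, a hedged sketch of how the initialModel branch above is exercised; `data` is a hypothetical RDD[Vector] supplied by the caller and the center values are illustrative:

import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

def fitWithInitialModel(data: RDD[Vector]): KMeansModel = {
  // Providing an initial model takes the `case Some(kMeansCenters)` branch above,
  // so neither random nor k-means|| initialization (including the new sample-size check) runs.
  val initial = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)))
  new KMeans().setK(2).setInitialModel(initial).run(data)
}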
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 12813fd412..d999b9be8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -130,7 +130,8 @@ class LDA private (
*/
@Since("1.5.0")
def setDocConcentration(docConcentration: Vector): this.type = {
- require(docConcentration.size > 0, "docConcentration must have > 0 elements")
+ require(docConcentration.size == 1 || docConcentration.size == k,
+ s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}")
this.docConcentration = docConcentration
this
}
@@ -260,15 +261,18 @@ class LDA private (
def getCheckpointInterval: Int = checkpointInterval
/**
- * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+ * Parameter to set the checkpoint interval (>= 1) or disable checkpointing (-1). E.g. 10 means that
+ * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery
* (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
* important when LDA is run for many iterations. If the checkpoint directory is not set in
- * [[org.apache.spark.SparkContext]], this setting is ignored.
+ * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10)
*
* @see [[org.apache.spark.SparkContext#setCheckpointDir]]
*/
@Since("1.3.0")
def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ require(checkpointInterval == -1 || checkpointInterval > 0,
+ s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}")
this.checkpointInterval = checkpointInterval
this
}
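A short sketch of the stricter parameter validation introduced above, assuming k = 3; the concentration values are illustrative:

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA().setK(3)

// Accepted: a single value (symmetric Dirichlet prior) or exactly k values.
lda.setDocConcentration(Vectors.dense(1.1))
lda.setDocConcentration(Vectors.dense(1.1, 1.1, 1.1))

// Accepted: -1 disables checkpointing, any positive value is the period in iterations.
lda.setCheckpointInterval(-1)
lda.setCheckpointInterval(10)

// Rejected by the new require (size 2 is neither 1 nor k):
// lda.setDocConcentration(Vectors.dense(1.1, 1.1))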
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 25d67a3756..27b4004927 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0") override val docConcentration: Vector,
@Since("1.5.0") override val topicConcentration: Double,
private[spark] val iterationTimes: Array[Double],
- override protected[clustering] val gammaShape: Double = 100)
+ override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
+ private[spark] val checkpointFiles: Array[String] = Array.empty[String])
extends LDAModel {
import LDA._
@@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] (
override protected def formatVersion = "1.0"
- /**
- * Java-friendly version of [[topicDistributions]]
- */
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
+ // Note: This intentionally does not save checkpointFiles.
DistributedLDAModel.SaveLoadV1_0.save(
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
iterationTimes, gammaShape)
@@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0")
object DistributedLDAModel extends Loader[DistributedLDAModel] {
+ /**
+ * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
+ * to ensure equivalence in LDAModel.toLocal conversion.
+ */
+ private[clustering] val defaultGammaShape: Double = 100
+
private object SaveLoadV1_0 {
val thisFormatVersion = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7491ab0d51..6418f0d3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer {
import LDA._
+ // Adjustable parameters
+ private var keepLastCheckpoint: Boolean = true
+
/**
- * The following fields will only be initialized through the initialize() method
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint
+
+ /**
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with
+ * care. Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * Default: true
*/
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = {
+ this.keepLastCheckpoint = keepLastCheckpoint
+ this
+ }
+
+ // The following fields will only be initialized through the initialize() method
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
private[clustering] var k: Int = 0
private[clustering] var vocabSize: Int = 0
@@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
- this.graphCheckpointer.deleteAllCheckpoints()
+ val checkpointFiles: Array[String] = if (keepLastCheckpoint) {
+ this.graphCheckpointer.deleteAllCheckpointsButLast()
+ this.graphCheckpointer.getAllCheckpointFiles
+ } else {
+ this.graphCheckpointer.deleteAllCheckpoints()
+ Array.empty[String]
+ }
// The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
- // LDAModel.toLocal conversion
+ // LDAModel.toLocal conversion.
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
- iterationTimes)
+ iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles)
}
}
@@ -451,10 +477,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
Iterator((stat, gammaPart))
}
- val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
+ _ += _, _ += _)
expElogbetaBc.unpersist()
val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
- stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
+ stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count
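A hedged sketch of the new EM checkpoint option, assuming `sc.setCheckpointDir(...)` has been called and `corpus` is a hypothetical RDD[(Long, Vector)] of document term-count vectors:

import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA, LDAModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def fitEmLda(corpus: RDD[(Long, Vector)]): LDAModel = {
  // Keeping the last checkpoint lets the returned DistributedLDAModel survive the loss of
  // cached partitions; the retained files are tracked via the new checkpointFiles argument
  // added earlier in this diff.
  val optimizer = new EMLDAOptimizer().setKeepLastCheckpoint(true)
  new LDA()
    .setK(10)
    .setCheckpointInterval(10)
    .setOptimizer(optimizer)
    .run(corpus)
}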
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 4eb8fc049e..24e1cff0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ require(centers.size == weights.size,
+ "Number of initial centers must be equal to number of weights")
+ require(centers.size == k,
+ s"Number of initial centers must be ${k} but got ${centers.size}")
+ require(weights.forall(_ >= 0),
+ s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]")
model = new StreamingKMeansModel(centers, weights)
this
}
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ require(dim > 0,
+ s"Number of dimensions must be positive but got ${dim}")
+ require(weight >= 0,
+ s"Weight for each center must be nonnegative but got ${weight}")
val random = new XORShiftRandom(seed)
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
val weights = Array.fill(k)(weight)
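A brief sketch of the new argument checks, with illustrative values:

import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val skm = new StreamingKMeans().setK(2).setDecayFactor(1.0)

// Passes the new requires: exactly k centers and a matching number of nonnegative weights.
skm.setInitialCenters(
  Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)),
  Array(1.0, 1.0))

// These would now fail fast with IllegalArgumentException instead of failing later:
// skm.setInitialCenters(Array(Vectors.dense(0.0, 0.0)), Array(1.0, 1.0)) // wrong center count
// skm.setRandomCenters(dim = 0, weight = -1.0)                           // bad dim / weight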
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index c93ed64183..47c9e850a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -36,12 +36,24 @@ import org.apache.spark.util.Utils
@Since("1.1.0")
class HashingTF(val numFeatures: Int) extends Serializable {
+ private var binary = false
+
/**
*/
@Since("1.1.0")
def this() = this(1 << 20)
/**
+ * If true, term frequency vector will be binary such that non-zero term counts will be set to 1
+ * (default: false)
+ */
+ @Since("2.0.0")
+ def setBinary(value: Boolean): this.type = {
+ binary = value
+ this
+ }
+
+ /**
* Returns the index of the input term.
*/
@Since("1.1.0")
@@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable {
@Since("1.1.0")
def transform(document: Iterable[_]): Vector = {
val termFrequencies = mutable.HashMap.empty[Int, Double]
+ val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
document.foreach { term =>
val i = indexOf(term)
- termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
+ termFrequencies.put(i, setTF(i))
}
Vectors.sparse(numFeatures, termFrequencies.toSeq)
}
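A small illustration of the new binary mode; the document and feature count are arbitrary:

import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vector

val doc = Seq("a", "a", "b", "c", "c", "c")

// Default behaviour: raw term frequencies at the hashed indices
// (2.0, 1.0, 3.0 here, barring hash collisions).
val tf: Vector = new HashingTF(1 << 10).transform(doc)

// With setBinary(true), every term that occurs at least once maps to 1.0.
val binaryTf: Vector = new HashingTF(1 << 10).setBinary(true).transform(doc)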
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4455681e50..4344ab1bad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
/**
@@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
@Since("1.5.0")
class PrefixSpanModel[Item] @Since("1.5.0") (
@Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
- extends Serializable
+ extends Saveable with Serializable {
+
+ /**
+ * Save this model to the given path.
+ * It only works for Item datatypes supported by DataFrames.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[PrefixSpanModel.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ */
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ PrefixSpanModel.SaveLoadV1_0.save(this, path)
+ }
+
+ override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ PrefixSpanModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[fpm] object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"
+
+ def save(model: PrefixSpanModel[_], path: String): Unit = {
+ val sc = model.freqSequences.sparkContext
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Get the type of item class
+ val sample = model.freqSequences.first().sequence(0)(0)
+ val className = sample.getClass.getCanonicalName
+ val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+ val tpe = classSymbol.selfType
+
+ val itemType = ScalaReflection.schemaFor(tpe).dataType
+ val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
+ StructField("freq", LongType))
+ val schema = StructType(fields)
+ val rowDataRDD = model.freqSequences.map { x =>
+ Row(x.sequence, x.freq)
+ }
+ sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ implicit val formats = DefaultFormats
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
+ val sample = freqSequences.select("sequence").head().get(0)
+ loadImpl(freqSequences, sample)
+ }
+
+ def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
+ val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
+ val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
+ val freq = x.getLong(1)
+ new PrefixSpan.FreqSequence(sequence, freq)
+ }
+ new PrefixSpanModel(freqSequencesRDD)
+ }
+ }
+}
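A usage sketch for the new persistence support, assuming a spark-shell session (`sc` already bound), a fresh output path, and a toy set of sequences:

import org.apache.spark.mllib.fpm.{PrefixSpan, PrefixSpanModel}

val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6))), 2).cache()

val model = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5).run(sequences)

// Works only for Item types that DataFrames can store (Int here); metadata is written to
// path/metadata as JSON and the frequent sequences to path/data as Parquet.
model.save(sc, "/tmp/prefixspan-model")
val sameModel = PrefixSpanModel.load(sc, "/tmp/prefixspan-model")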
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..5c12c9305b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel
* - This class removes checkpoint files once later Datasets have been checkpointed.
* However, references to the older Datasets will still return isCheckpointed = true.
*
- * @param checkpointInterval Datasets will be checkpointed at this interval
+ * @param checkpointInterval Datasets will be checkpointed at this interval.
+ * If this interval is set to -1, then checkpointing will be disabled.
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
@@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T](
updateCount += 1
// Handle checkpointing (after persisting)
- if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
+ && sc.getCheckpointDir.nonEmpty) {
// Add new checkpoint before removing old checkpoints.
checkpoint(newData)
checkpointQueue.enqueue(newData)
@@ -134,6 +136,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
}
/**
+ * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+ * Note that there may not be any checkpoints at all.
+ */
+ def deleteAllCheckpointsButLast(): Unit = {
+ while (checkpointQueue.size > 1) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Get all current checkpoint files.
+ * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+ */
+ def getAllCheckpointFiles: Array[String] = {
+ checkpointQueue.flatMap(getCheckpointFiles).toArray
+ }
+
+ /**
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
@@ -141,15 +161,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
- getCheckpointFiles(old).foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
}
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+ /** Delete a checkpoint file, and log a warning if deletion fails. */
+ def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+ try {
+ fs.delete(new Path(path), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ path)
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 11a059536c..20db6084d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -69,7 +69,8 @@ import org.apache.spark.storage.StorageLevel
* // checkpointed: graph4
* }}}
*
- * @param checkpointInterval Graphs will be checkpointed at this interval
+ * @param checkpointInterval Graphs will be checkpointed at this interval.
+ * If this interval is set to -1, then checkpointing will be disabled.
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index c6de7751f5..8c09b69b3c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -123,14 +123,18 @@ sealed trait Matrix extends Serializable {
@Since("1.4.0")
def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth)
- /** Map the values of this matrix using a function. Generates a new matrix. Performs the
- * function on only the backing array. For example, an operation such as addition or
- * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
+ /**
+ * Map the values of this matrix using a function. Generates a new matrix. Performs the
+ * function on only the backing array. For example, an operation such as addition or
+ * subtraction will only be performed on the non-zero values in a `SparseMatrix`.
+ */
private[spark] def map(f: Double => Double): Matrix
- /** Update all the values of this matrix using the function f. Performed in-place on the
- * backing array. For example, an operation such as addition or subtraction will only be
- * performed on the non-zero values in a `SparseMatrix`. */
+ /**
+ * Update all the values of this matrix using the function f. Performed in-place on the
+ * backing array. For example, an operation such as addition or subtraction will only be
+ * performed on the non-zero values in a `SparseMatrix`.
+ */
private[mllib] def update(f: Double => Double): Matrix
/**
@@ -613,7 +617,7 @@ class SparseMatrix @Since("1.3.0") (
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
val ind = index(i, j)
- if (ind == -1) {
+ if (ind < 0) {
throw new NoSuchElementException("The given row and column indices correspond to a zero " +
"value. Only non-zero elements in Sparse Matrices can be updated.")
} else {
@@ -940,8 +944,16 @@ object Matrices {
case dm: BDM[Double] =>
new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose)
case sm: BSM[Double] =>
+ // SPARK-11507: work around Breeze issue 479.
+ val mat = if (sm.colPtrs.last != sm.data.length) {
+ val matCopy = sm.copy
+ matCopy.compact()
+ matCopy
+ } else {
+ sm
+ }
// There is no isTranspose flag for sparse matrices in Breeze
- new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
+ new SparseMatrix(mat.rows, mat.cols, mat.colPtrs, mat.rowIndices, mat.data)
case _ =>
throw new UnsupportedOperationException(
s"Do not support conversion from type ${breeze.getClass.getName}.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 0f0c3a2df5..5812cdde2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -186,7 +186,7 @@ sealed trait Vector extends Serializable {
* :: AlphaComponent ::
*
* User-defined type for [[Vector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.DataFrame]].
+ * via [[org.apache.spark.sql.Dataset]].
*/
@AlphaComponent
class VectorUDT extends UserDefinedType[Vector] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index e8f4422fd4..84764963b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -81,8 +81,8 @@ class StreamingLinearRegressionWithSGD private[mllib] (
}
/**
- * Set the number of iterations of gradient descent to run per update. Default: 50.
- */
+ * Set the number of iterations of gradient descent to run per update. Default: 50.
+ */
@Since("1.1.0")
def setNumIterations(numIterations: Int): this.type = {
this.algorithm.optimizer.setNumIterations(numIterations)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index 052b5b1d65..6c6e9fb7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -61,15 +61,17 @@ class MultivariateGaussian @Since("1.3.0") (
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
- /** Returns density of this multivariate Gaussian at given point, x
- */
+ /**
+ * Returns density of this multivariate Gaussian at given point, x
+ */
@Since("1.3.0")
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
- /** Returns the log-density of this multivariate Gaussian at given point, x
- */
+ /**
+ * Returns the log-density of this multivariate Gaussian at given point, x
+ */
@Since("1.3.0")
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
index baf9e5e7d1..9748fbf2c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
@@ -166,7 +166,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
: KolmogorovSmirnovTestResult = {
val distObj =
distName match {
- case "norm" => {
+ case "norm" =>
if (params.nonEmpty) {
// parameters are passed, then can only be 2
require(params.length == 2, "Normal distribution requires mean and standard " +
@@ -178,7 +178,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
"initialized to standard normal (i.e. N(0, 1))")
new NormalDistribution(0, 1)
}
- }
case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" +
s" convenience method. Current options are:['norm'].")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index d166dc7905..7fe60e2d99 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,15 +20,11 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
/**
* A class that implements
@@ -70,17 +66,8 @@ class GradientBoostedTrees private[spark] (
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
- seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -107,20 +94,9 @@ class GradientBoostedTrees private[spark] (
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- val remappedValidationInput = validationInput.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true, seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy,
+ seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -162,147 +138,4 @@ object GradientBoostedTrees extends Logging {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
train(input.rdd, boostingStrategy)
}
-
- /**
- * Internal method for performing regression using trees as base learners.
- * @param input Training dataset.
- * @param validationInput Validation dataset, ignored if validate is set to false.
- * @param boostingStrategy Boosting parameters.
- * @param validate Whether or not to use the validation dataset.
- * @param seed Random seed.
- * @return GradientBoostedTreesModel that can be used for prediction.
- */
- private def boost(
- input: RDD[LabeledPoint],
- validationInput: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy,
- validate: Boolean,
- seed: Int): GradientBoostedTreesModel = {
- val timer = new TimeTracker()
- timer.start("total")
- timer.start("init")
-
- boostingStrategy.assertValid()
-
- // Initialize gradient boosting parameters
- val numIterations = boostingStrategy.numIterations
- val baseLearners = new Array[DecisionTreeModel](numIterations)
- val baseLearnerWeights = new Array[Double](numIterations)
- val loss = boostingStrategy.loss
- val learningRate = boostingStrategy.learningRate
- // Prepare strategy for individual trees, which use regression with variance impurity.
- val treeStrategy = boostingStrategy.treeStrategy.copy
- val validationTol = boostingStrategy.validationTol
- treeStrategy.algo = Regression
- treeStrategy.impurity = Variance
- treeStrategy.assertValid()
-
- // Cache input
- val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
- input.persist(StorageLevel.MEMORY_AND_DISK)
- true
- } else {
- false
- }
-
- // Prepare periodic checkpointers
- val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
- val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
-
- timer.stop("init")
-
- logDebug("##########")
- logDebug("Building tree 0")
- logDebug("##########")
-
- // Initialize tree
- timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
- val firstTreeWeight = 1.0
- baseLearners(0) = firstTreeModel
- baseLearnerWeights(0) = firstTreeWeight
-
- var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- // Note: A model of type regression is used since we require raw prediction
- timer.stop("building tree 0")
-
- var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
- if (validate) validatePredErrorCheckpointer.update(validatePredError)
- var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
- var bestM = 1
-
- var m = 1
- var doneLearning = false
- while (m < numIterations && !doneLearning) {
- // Update data with pseudo-residuals
- val data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
- timer.start(s"building tree $m")
- logDebug("###################################################")
- logDebug("Gradient boosting tree iteration " + m)
- logDebug("###################################################")
- val model = new DecisionTree(treeStrategy, seed + m).run(data)
- timer.stop(s"building tree $m")
- // Update partial model
- baseLearners(m) = model
- // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
- // Technically, the weight should be optimized for the particular loss.
- // However, the behavior should be reasonable, though not optimal.
- baseLearnerWeights(m) = learningRate
-
- predError = GradientBoostedTreesModel.updatePredictionError(
- input, predError, baseLearnerWeights(m), baseLearners(m), loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- if (validate) {
- // Stop training early if
- // 1. Reduction in error is less than the validationTol or
- // 2. If the error increases, that is if the model is overfit.
- // We want the model returned corresponding to the best validation error.
-
- validatePredError = GradientBoostedTreesModel.updatePredictionError(
- validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
- validatePredErrorCheckpointer.update(validatePredError)
- val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol * Math.max(
- currentValidateError, 0.01)) {
- doneLearning = true
- } else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
- }
- }
- m += 1
- }
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- predErrorCheckpointer.deleteAllCheckpoints()
- validatePredErrorCheckpointer.deleteAllCheckpoints()
- if (persistedInput) input.unpersist()
-
- if (validate) {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
- } else {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
- }
- }
-
}
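Nothing changes for callers of the public API; a hedged sketch, with `trainingData` a hypothetical RDD[LabeledPoint]:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD

def trainGbt(trainingData: RDD[LabeledPoint]): GradientBoostedTreesModel = {
  val boostingStrategy = BoostingStrategy.defaultParams("Classification")
  boostingStrategy.numIterations = 10
  boostingStrategy.treeStrategy.maxDepth = 5
  // The label remapping and boosting loop removed above now run inside
  // org.apache.spark.ml.tree.impl.GradientBoostedTrees; the old model type is still returned.
  GradientBoostedTrees.train(trainingData, boostingStrategy)
}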
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4a95..26755849ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
+ * If a real value "n" in the range (0, 1.0] is set,
+ * use n * number of features.
+ * If an integer value "n" in the range (1, num features) is set,
+ * use n features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
- require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+ || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
- s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD
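A hedged sketch of the newly accepted numeric strategies (`trainingData` is a hypothetical RDD[LabeledPoint]); per the updated doc above, "0.5" is taken to mean "consider 50% of the features at each split":

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD

def trainForest(trainingData: RDD[LabeledPoint]): RandomForestModel =
  RandomForest.trainClassifier(
    trainingData,
    numClasses = 2,
    categoricalFeaturesInfo = Map[Int, Int](),
    numTrees = 20,
    featureSubsetStrategy = "0.5", // now also accepts a fraction in (0.0, 1.0] or an integer count
    impurity = "gini",
    maxDepth = 5,
    maxBins = 32)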
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
deleted file mode 100644
index dc7e969f7b..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: DeveloperApi ::
- * This is used by the node id cache to find the child id that a data point would belong to.
- * @param split Split information.
- * @param nodeIndex The current node index of a data point that this will update.
- */
-@DeveloperApi
-private[tree] case class NodeIndexUpdater(
- split: Split,
- nodeIndex: Int) {
- /**
- * Determine a child node index based on the feature value and the split.
- * @param binnedFeatures Binned feature values.
- * @param bins Bin information to convert the bin indices to approximate feature values.
- * @return Child node index to update to.
- */
- def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
- if (split.featureType == Continuous) {
- val featureIndex = split.feature
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- if (featureValueUpperBound <= split.threshold) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- } else {
- if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * A given TreePoint would belong to a particular node per tree.
- * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
- * in each tree. Initially, values should all be 1 for root node.
- * The nodeIdsForInstances RDD needs to be updated at each iteration.
- * @param nodeIdsForInstances The initial values in the cache
- * (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- */
-@DeveloperApi
-private[spark] class NodeIdCache(
- var nodeIdsForInstances: RDD[Array[Int]],
- val checkpointInterval: Int) {
-
- // Keep a reference to a previous node Ids for instances.
- // Because we will keep on re-persisting updated node Ids,
- // we want to unpersist the previous RDD.
- private var prevNodeIdsForInstances: RDD[Array[Int]] = null
-
- // To keep track of the past checkpointed RDDs.
- private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
- private var rddUpdateCount = 0
-
- /**
- * Update the node index values in the cache.
- * This updates the RDD and its lineage.
- * TODO: Passing bin information to executors seems unnecessary and costly.
- * @param data The RDD of training rows.
- * @param nodeIdUpdaters A map of node index updaters.
- * The key is the indices of nodes that we want to update.
- * @param bins Bin information needed to find child node indices.
- */
- def updateNodeIndices(
- data: RDD[BaggedPoint[TreePoint]],
- nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
- bins: Array[Array[Bin]]): Unit = {
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
-
- prevNodeIdsForInstances = nodeIdsForInstances
- nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
- case (point, node) => {
- var treeId = 0
- while (treeId < nodeIdUpdaters.length) {
- val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
- if (nodeIdUpdater != null) {
- val newNodeIndex = nodeIdUpdater.updateNodeIndex(
- binnedFeatures = point.datum.binnedFeatures,
- bins = bins)
- node(treeId) = newNodeIndex
- }
-
- treeId += 1
- }
-
- node
- }
- }
-
- // Keep on persisting new ones.
- nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
- rddUpdateCount += 1
-
- // Handle checkpointing if the directory is not None.
- if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
- (rddUpdateCount % checkpointInterval) == 0) {
- // Let's see if we can delete previous checkpoints.
- var canDelete = true
- while (checkpointQueue.size > 1 && canDelete) {
- // We can delete the oldest checkpoint iff
- // the next checkpoint actually exists in the file system.
- if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
- val old = checkpointQueue.dequeue()
-
- // Since the old checkpoint is not deleted by Spark,
- // we'll manually delete it here.
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(old.getCheckpointFile.get), true)
- } else {
- canDelete = false
- }
- }
-
- nodeIdsForInstances.checkpoint()
- checkpointQueue.enqueue(nodeIdsForInstances)
- }
- }
-
- /**
- * Call this after training is finished to delete any remaining checkpoints.
- */
- def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.nonEmpty) {
- val old = checkpointQueue.dequeue()
- for (checkpointFile <- old.getCheckpointFile) {
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(checkpointFile), true)
- }
- }
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
- }
-}
-
-private[spark] object NodeIdCache {
- /**
- * Initialize the node Id cache with initial node Id values.
- * @param data The RDD of training rows.
- * @param numTrees The number of trees that we want to create cache for.
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- * @param initVal The initial values in the cache.
- * @return A node Id cache containing an RDD of initial root node Indices.
- */
- def init(
- data: RDD[BaggedPoint[TreePoint]],
- numTrees: Int,
- checkpointInterval: Int,
- initVal: Int = 1): NodeIdCache = {
- new NodeIdCache(
- data.map(_ => Array.fill[Int](numTrees)(initVal)),
- checkpointInterval)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
deleted file mode 100644
index 21919d69a3..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.Bin
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Internal representation of LabeledPoint for DecisionTree.
- * This bins feature values based on a subsampled of data as follows:
- * (a) Continuous features are binned into ranges.
- * (b) Unordered categorical features are binned based on subsets of feature values.
- * "Unordered categorical features" are categorical features with low arity used in
- * multiclass classification.
- * (c) Ordered categorical features are binned based on feature values.
- * "Ordered categorical features" are categorical features with high arity,
- * or any categorical feature used in regression or binary classification.
- *
- * @param label Label from LabeledPoint
- * @param binnedFeatures Binned feature values.
- * Same length as LabeledPoint.features, but values are bin indices.
- */
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
- extends Serializable {
-}
-
-private[spark] object TreePoint {
-
- /**
- * Convert an input dataset into its TreePoint representation,
- * binning feature values in preparation for DecisionTree training.
- * @param input Input dataset.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param metadata Learning and dataset metadata
- * @return TreePoint dataset representation
- */
- def convertToTreeRDD(
- input: RDD[LabeledPoint],
- bins: Array[Array[Bin]],
- metadata: DecisionTreeMetadata): RDD[TreePoint] = {
- // Construct arrays for featureArity for efficiency in the inner loop.
- val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- var featureIndex = 0
- while (featureIndex < metadata.numFeatures) {
- featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- featureIndex += 1
- }
- input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, featureArity)
- }
- }
-
- /**
- * Convert one LabeledPoint into its TreePoint representation.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
- * for categorical features.
- */
- private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
- bins: Array[Array[Bin]],
- featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
- val arr = new Array[Int](numFeatures)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
- bins)
- featureIndex += 1
- }
- new TreePoint(labeledPoint.label, arr)
- }
-
- /**
- * Find bin for one (labeledPoint, feature).
- *
- * @param featureArity 0 for continuous features; number of categories for categorical features.
- * @param bins Bins for features, of size (numFeatures, numBins).
- */
- private def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- featureArity: Int,
- bins: Array[Array[Bin]]): Int = {
-
- /**
- * Binary search helper method for continuous feature.
- */
- def binarySearchForBins(): Int = {
- val binForFeatures = bins(featureIndex)
- val feature = labeledPoint.features(featureIndex)
- var left = 0
- var right = binForFeatures.length - 1
- while (left <= right) {
- val mid = left + (right - left) / 2
- val bin = binForFeatures(mid)
- val lowThreshold = bin.lowSplit.threshold
- val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)) {
- return mid
- } else if (lowThreshold >= feature) {
- right = mid - 1
- } else {
- left = mid + 1
- }
- }
- -1
- }
-
- if (featureArity == 0) {
- // Perform binary search for finding bin for continuous features.
- val binIndex = binarySearchForBins()
- if (binIndex == -1) {
- throw new RuntimeException("No bin was found for continuous feature." +
- " This error can occur when given invalid data values (such as NaN)." +
- s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
- }
- binIndex
- } else {
- // Categorical feature bins are indexed by feature values.
- val featureValue = labeledPoint.features(featureIndex)
- if (featureValue < 0 || featureValue >= featureArity) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureArity - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- featureValue.toInt
- }
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 13aff11007..ff7700d2d1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -85,7 +85,7 @@ object Entropy extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 39c7f9c3be..58dc79b739 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -81,7 +81,7 @@ object Gini extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 92d74a1b83..2423516123 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -71,7 +71,7 @@ object Variance extends Impurity {
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
-private[tree] class VarianceAggregator()
+private[spark] class VarianceAggregator()
extends ImpurityAggregator(statsSize = 3) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
deleted file mode 100644
index 0cad473782..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-
-/**
- * Used for "binning" the feature values for faster best split calculation.
- *
- * For a continuous feature, the bin is determined by a low and a high split,
- * where an example with featureValue falls into the bin s.t.
- * lowSplit.threshold < featureValue <= highSplit.threshold.
- *
- * For ordered categorical features, there is a 1-1-1 correspondence between
- * bins, splits, and feature values. The bin is determined by category/feature value.
- * However, the bins are not necessarily ordered by feature value;
- * they are ordered using impurity.
- *
- * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
- * where bins and splits correspond to subsets of feature values (in highSplit.categories).
- * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
- * partitionings of categories into 2 disjoint, non-empty sets.
- *
- * @param lowSplit signifying the lower threshold for the continuous feature to be
- * accepted in the bin
- * @param highSplit signifying the upper threshold for the continuous feature to be
- * accepted in the bin
- * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for ordered features
- */
-private[tree]
-case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
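
The deleted doc comment states that an unordered categorical feature with k categories uses (1 << k - 1) - 1 bins, i.e. 2^(k-1) - 1, one for each way of splitting the categories into two disjoint, non-empty sets. A short sketch of that enumeration via bit masks; illustrative code, not the Spark implementation:

object UnorderedSplitsSketch {
  // Each bit mask over the first k-1 categories picks the "left" subset; fixing
  // category k-1 on the right counts every 2-way partition exactly once.
  def candidateSplits(k: Int): Seq[Set[Int]] = {
    (1 until (1 << (k - 1))).map { mask =>
      (0 until k - 1).filter(c => (mask & (1 << c)) != 0).toSet
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = candidateSplits(3)
    println(splits)        // Vector(Set(0), Set(1), Set(0, 1))
    println(splits.length) // 3, i.e. (1 << (3 - 1)) - 1
  }
}
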
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ea68ff64a8..a87f8a6cde 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -156,7 +156,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
feature: Int,
threshold: Double,
featureType: Int,
- categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+ categories: Seq[Double]) {
def toSplit: Split = {
new Split(feature, threshold, FeatureType(featureType), categories.toList)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5cdd7..774170ff40 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
path: String,
numFeatures: Int,
minPartitions: Int): RDD[LabeledPoint] = {
- val parsed = sc.textFile(path, minPartitions)
- .map(_.trim)
- .filter(line => !(line.isEmpty || line.startsWith("#")))
- .map { line =>
- val items = line.split(' ')
- val label = items.head.toDouble
- val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
- val indexAndValue = item.split(':')
- val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
- val value = indexAndValue(1).toDouble
- (index, value)
- }.unzip
-
- // check if indices are one-based and in ascending order
- var previous = -1
- var i = 0
- val indicesLength = indices.length
- while (i < indicesLength) {
- val current = indices(i)
- require(current > previous, s"indices should be one-based and in ascending order;"
- + " found current=$current, previous=$previous; line=\"$line\"")
- previous = current
- i += 1
- }
-
- (label, indices.toArray, values.toArray)
- }
+ val parsed = parseLibSVMFile(sc, path, minPartitions)
// Determine number of features.
val d = if (numFeatures > 0) {
numFeatures
} else {
parsed.persist(StorageLevel.MEMORY_ONLY)
- parsed.map { case (label, indices, values) =>
- indices.lastOption.getOrElse(0)
- }.reduce(math.max) + 1
+ computeNumFeatures(parsed)
}
parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
}
}
+ private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+ rdd.map { case (label, indices, values) =>
+ indices.lastOption.getOrElse(0)
+ }.reduce(math.max) + 1
+ }
+
+ private[spark] def parseLibSVMFile(
+ sc: SparkContext,
+ path: String,
+ minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+ sc.textFile(path, minPartitions)
+ .map(_.trim)
+ .filter(line => !(line.isEmpty || line.startsWith("#")))
+ .map(parseLibSVMRecord)
+ }
+
+ private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+ val items = line.split(' ')
+ val label = items.head.toDouble
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+ val indexAndValue = item.split(':')
+ val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+ val value = indexAndValue(1).toDouble
+ (index, value)
+ }.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, s"indices should be one-based and in ascending order;"
+        + s" found current=$current, previous=$previous; line=$line")
+ previous = current
+ i += 1
+ }
+
+ (label, indices.toArray, values.toArray)
+ }
+
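
For readers unfamiliar with the format these extracted helpers handle: a LIBSVM record is "<label> <index1>:<value1> <index2>:<value2> ..." with 1-based, strictly increasing indices. The standalone sketch below mirrors the parsing logic added above, but it uses illustrative names and is not the Spark API:

object LibSVMLineSketch {
  def parse(line: String): (Double, Array[Int], Array[Double]) = {
    val items = line.trim.split(' ').filter(_.nonEmpty)
    val label = items.head.toDouble
    val (indices, values) = items.tail.map { item =>
      val Array(i, v) = item.split(':')
      (i.toInt - 1, v.toDouble) // convert the 1-based index to 0-based
    }.unzip
    require(indices.sliding(2).forall(w => w.length < 2 || w(0) < w(1)),
      s"indices must be strictly ascending; line=$line")
    (label, indices, values)
  }

  def main(args: Array[String]): Unit = {
    val (label, idx, vals) = parse("1.0 1:0.5 3:2.0")
    println(label)              // 1.0
    println(idx.mkString(","))  // 0,2
    println(vals.mkString(",")) // 0.5,2.0
  }
}
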
/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
* partitions.