diff options
author | Bryan Cutler <bjcutler@us.ibm.com> | 2015-07-17 14:10:16 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-17 14:10:16 -0700 |
commit | 8b8be1f5d698e796b96a92f1ed2c13162a90944e (patch) | |
tree | d51d35a04d2ee2849962e189f0b1bd9a209acf75 | |
parent | 830666f6fe1e77faa39eed2c1c3cd8e83bc93ef9 (diff) | |
download | spark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.tar.gz spark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.tar.bz2 spark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.zip |
[SPARK-7127] [MLLIB] Adding broadcast of model before prediction for ensembles
Broadcast of ensemble models in transformImpl before call to predict
Author: Bryan Cutler <bjcutler@us.ibm.com>
Closes #6300 from BryanCutler/bcast-ensemble-models-7127 and squashes the following commits:
86e73de [Bryan Cutler] [SPARK-7127] Replaced deprecated callUDF with udf
40a139d [Bryan Cutler] Merge branch 'master' into bcast-ensemble-models-7127
9afad56 [Bryan Cutler] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction
1f34be4 [Bryan Cutler] [SPARK-7127] Removed accidental newline
171a6ce [Bryan Cutler] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model
6fd153c [Bryan Cutler] [SPARK-7127] Applied broadcasting to remaining ensemble models
aaad77b [Bryan Cutler] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform
83904bb [Bryan Cutler] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier
5 files changed, 48 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 333b42711e..19fe039b8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if ($(predictionCol).nonEmpty) {
-      val predictUDF = udf { (features: Any) =>
-        predict(features.asInstanceOf[FeaturesType])
-      }
-      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+      transformImpl(dataset)
     } else {
       this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
         " since no output columns were set.")
@@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
     }
   }
 
+  protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val predictUDF = udf { (features: Any) =>
+      predict(features.asInstanceOf[FeaturesType])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   /**
    * Predict label for the given features.
    * This internal method is used to implement [[transform()]] and output [[predictionCol]].
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 554e3b8e05..eb0b1a0a40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -177,8 +179,15 @@ final class GBTClassificationModel(
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model: SPARK-7127
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 490f04c7c7..fc0693f67c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -143,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
     // Ignore the weights since all are 1.0 for now.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 47c110d027..e38dc73ee0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -167,8 +169,15 @@ final class GBTRegressionModel(
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 5fd5c7c7bd..506a878c25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -129,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
     // Predict average of tree predictions.
     // Ignore the weights since all are 1.0 for now.