about summary refs log tree commit diff
diff options
context:
space:
mode:
authorBryan Cutler <bjcutler@us.ibm.com>2015-07-17 14:10:16 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-17 14:10:16 -0700
commit8b8be1f5d698e796b96a92f1ed2c13162a90944e (patch)
treed51d35a04d2ee2849962e189f0b1bd9a209acf75
parent830666f6fe1e77faa39eed2c1c3cd8e83bc93ef9 (diff)
downloadspark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.tar.gz
spark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.tar.bz2
spark-8b8be1f5d698e796b96a92f1ed2c13162a90944e.zip
[SPARK-7127] [MLLIB] Adding broadcast of model before prediction for ensembles
Broadcast of ensemble models in transformImpl before call to predict Author: Bryan Cutler <bjcutler@us.ibm.com> Closes #6300 from BryanCutler/bcast-ensemble-models-7127 and squashes the following commits: 86e73de [Bryan Cutler] [SPARK-7127] Replaced deprecated callUDF with udf 40a139d [Bryan Cutler] Merge branch 'master' into bcast-ensemble-models-7127 9afad56 [Bryan Cutler] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction 1f34be4 [Bryan Cutler] [SPARK-7127] Removed accidental newline 171a6ce [Bryan Cutler] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model 6fd153c [Bryan Cutler] [SPARK-7127] Applied broadcasting to remaining ensemble models aaad77b [Bryan Cutler] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform 83904bb [Bryan Cutler] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Predictor.scala12
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala11
5 files changed, 48 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 333b42711e..19fe039b8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
- val predictUDF = udf { (features: Any) =>
- predict(features.asInstanceOf[FeaturesType])
- }
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
@@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
}
}
+ protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val predictUDF = udf { (features: Any) =>
+ predict(features.asInstanceOf[FeaturesType])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
/**
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 554e3b8e05..eb0b1a0a40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -177,8 +179,15 @@ final class GBTClassificationModel(
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model: SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 490f04c7c7..fc0693f67c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -143,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 47c110d027..e38dc73ee0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -167,8 +169,15 @@ final class GBTRegressionModel(
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 5fd5c7c7bd..506a878c25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -129,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.