diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-10-25 13:11:21 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-10-25 13:11:21 -0700 |
commit | 2c7394ad096201cd721be7f532da9d97028cc577 (patch) | |
tree | 53383e3dbfa7e92762c929590f90865a38d24d94 /mllib | |
parent | a21791e3164f4e6546fbe0a90017a4394a05deb1 (diff) | |
download | spark-2c7394ad096201cd721be7f532da9d97028cc577.tar.gz spark-2c7394ad096201cd721be7f532da9d97028cc577.tar.bz2 spark-2c7394ad096201cd721be7f532da9d97028cc577.zip |
[SPARK-18019][ML] Add instrumentation to GBTs
## What changes were proposed in this pull request?
Add instrumentation for logging in ML GBT, part of umbrella ticket [SPARK-14567](https://issues.apache.org/jira/browse/SPARK-14567)
## How was this patch tested?
Tested locally:
````
16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: training: numPartitions=1 storageLevel=StorageLevel(1 replicas)
16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"maxIter":1}
16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"numFeatures":2}
16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"numClasses":0}
...
16/10/20 15:54:21 INFO Instrumentation: GBTRegressor-gbtr_065fad465377-1922077832-22: training finished
````
Author: sethah <seth.hendrickson16@gmail.com>
Closes #15574 from sethah/gbt_instr.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 10 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 9 |
2 files changed, 17 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ba70293273..8bffe0cda0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -137,9 +137,17 @@ class GBTClassifier @Since("1.4.0") ( } val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + instr.logNumClasses(2) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bb01f9d5a3..fa69d60836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -123,9 +123,16 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.0") |