author    sueann <sueann@databricks.com>  2017-01-06 18:53:16 -0800
committer Joseph K. Bradley <joseph@databricks.com>  2017-01-06 18:53:16 -0800
commit    d60f6f62d00ffccc40ed72e15349358fe3543311 (patch)
tree      15adf4be1de5772cd1339cbf7b7db651af7a2953
parent    b59cddaba01cbdf50dbe8fe7ef7b9913bad9552d (diff)
[SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator, TrainValidationSplit
## What changes were proposed in this pull request?

Added instrumentation logging to the fit() methods of the OneVsRest classifier, CrossValidator, and TrainValidationSplit.

## How was this patch tested?

Ran unit tests and checked the log file (see output in comments).

Author: sueann <sueann@databricks.com>

Closes #16480 from sueann/SPARK-18194.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala    |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala          |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala       |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala |  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala      | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala        |  8

6 files changed, 37 insertions(+), 6 deletions(-)
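The same four-step instrumentation pattern recurs in each fit() method touched below: create an Instrumentation for the estimator and dataset, log the relevant params and named values, train, and record success with the fitted model. A minimal sketch of that pattern; MyModel and trainModel are hypothetical stand-ins, and only the Instrumentation calls mirror the actual diffs:

```scala
// Sketch only: MyModel and trainModel are hypothetical stand-ins;
// the Instrumentation calls mirror the pattern in the diffs below.
override def fit(dataset: Dataset[_]): MyModel = {
  transformSchema(dataset.schema)
  val instr = Instrumentation.create(this, dataset)      // tie log lines to this estimator and dataset
  instr.logParams(labelCol, featuresCol, predictionCol)  // record the params this fit depends on
  val model = trainModel(dataset).setParent(this)        // actual training elided
  instr.logSuccess(model)                                // mark the run as completed
  copyValues(model)
}
```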
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index e58b30d665..cbd508ae79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -308,6 +308,10 @@ final class OneVsRest @Since("1.4.0") (
override def fit(dataset: Dataset[_]): OneVsRestModel = {
transformSchema(dataset.schema)
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(labelCol, featuresCol, predictionCol)
+ instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
+
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
@@ -316,6 +320,7 @@ final class OneVsRest @Since("1.4.0") (
maxLabelIndex.toInt + 1
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+ instr.logNumClasses(numClasses)
val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
@@ -339,6 +344,7 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
+ instr.logNumFeatures(models.head.numFeatures)
if (handlePersistence) {
multiclassLabeled.unpersist()
@@ -352,6 +358,7 @@ final class OneVsRest @Since("1.4.0") (
case attr: Attribute => attr
}
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
+ instr.logSuccess(model)
copyValues(model)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index b466e2ed35..cdea90ec1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -457,8 +457,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
- val instrLog = Instrumentation.create(this, ratings)
- instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
+ val instr = Instrumentation.create(this, ratings)
+ instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
userCol, itemCol, ratingCol, predictionCol, maxIter,
regParam, nonnegative, checkpointInterval, seed)
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
@@ -471,7 +471,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
- instrLog.logSuccess(model)
+ instr.logSuccess(model)
copyValues(model)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 85191d46fd..2012d6ca8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -101,6 +101,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
+
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(numFolds, seed)
+ logTuningParams(instr)
+
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
@@ -127,6 +132,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 5d1a39f7c1..db7c9d13d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -97,6 +97,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length
val metrics = new Array[Double](epm.length)
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(trainRatio, seed)
+ logTuningParams(instr)
+
val Array(trainingDataset, validationDataset) =
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
trainingDataset.cache()
@@ -123,6 +127,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 26fd73814d..d55eb14d03 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasSeed
-import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable}
+import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
@@ -76,6 +76,15 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
}
est.copy(firstEstimatorParamMap).transformSchema(schema)
}
+
+ /**
+ * Instrumentation logging for tuning params including the inner estimator and evaluator info.
+ */
+ protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+ instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
+ instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
+ instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
+ }
}
private[ml] object ValidatorParams {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 71a626647a..a2794368db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -87,8 +87,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
/**
* Logs the value with customized name field.
*/
- def logNamedValue(name: String, num: Long): Unit = {
- log(compact(render(name -> num)))
+ def logNamedValue(name: String, value: String): Unit = {
+ log(compact(render(name -> value)))
+ }
+
+ def logNamedValue(name: String, value: Long): Unit = {
+ log(compact(render(name -> value)))
}
/**
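Both logNamedValue overloads above render a single JSON object via json4s's compact(render(...)). A minimal sketch of what hypothetical call sites would log (ignoring any logging prefix), assuming instr comes from Instrumentation.create as in the diffs above:

```scala
// Hypothetical call sites; the JSON shapes follow directly from
// compact(render(name -> value)) in the overloads above.
instr.logNamedValue("classifier", "org.apache.spark.ml.classification.LogisticRegression")
// emits: {"classifier":"org.apache.spark.ml.classification.LogisticRegression"}
instr.logNamedValue("estimatorParamMapsLength", 4L)
// emits: {"estimatorParamMapsLength":4}
```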