From d60f6f62d00ffccc40ed72e15349358fe3543311 Mon Sep 17 00:00:00 2001
From: sueann
Date: Fri, 6 Jan 2017 18:53:16 -0800
Subject: [SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator,
 TrainValidationSplit

## What changes were proposed in this pull request?

Added instrumentation logging to the fit() functions of the OneVsRest classifier, CrossValidator, and TrainValidationSplit.

## How was this patch tested?

Ran unit tests and checked the log file (see output in comments).

Author: sueann

Closes #16480 from sueann/SPARK-18194.
---
 .../scala/org/apache/spark/ml/classification/OneVsRest.scala  |  7 +++++++
 .../main/scala/org/apache/spark/ml/recommendation/ALS.scala   |  6 +++---
 .../scala/org/apache/spark/ml/tuning/CrossValidator.scala     |  6 ++++++
 .../org/apache/spark/ml/tuning/TrainValidationSplit.scala     |  5 +++++
 .../scala/org/apache/spark/ml/tuning/ValidatorParams.scala    | 11 ++++++++++-
 .../main/scala/org/apache/spark/ml/util/Instrumentation.scala |  8 ++++++--
 6 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index e58b30d665..cbd508ae79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -308,6 +308,10 @@ final class OneVsRest @Since("1.4.0") (
   override def fit(dataset: Dataset[_]): OneVsRestModel = {
     transformSchema(dataset.schema)
 
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(labelCol, featuresCol, predictionCol)
+    instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
+
     // determine number of classes either from metadata if provided, or via computation.
     val labelSchema = dataset.schema($(labelCol))
     val computeNumClasses: () => Int = () => {
@@ -316,6 +320,7 @@ final class OneVsRest @Since("1.4.0") (
       maxLabelIndex.toInt + 1
     }
     val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+    instr.logNumClasses(numClasses)
 
     val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
 
@@ -339,6 +344,7 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.predictionCol -> getPredictionCol)
       classifier.fit(trainingDataset, paramMap)
     }.toArray[ClassificationModel[_, _]]
+    instr.logNumFeatures(models.head.numFeatures)
 
     if (handlePersistence) {
       multiclassLabeled.unpersist()
@@ -352,6 +358,7 @@ final class OneVsRest @Since("1.4.0") (
       case attr: Attribute => attr
     }
     val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
+    instr.logSuccess(model)
     copyValues(model)
   }
 
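For context, the sketch below shows a call path that exercises the new OneVsRest instrumentation. The local SparkSession, toy dataset, and column names are illustrative assumptions and not part of the patch; only fit()'s logging behavior comes from the diff above.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("ovr-instr").getOrCreate()
import spark.implicits._

// Toy three-class dataset; "label"/"features" match the OneVsRest defaults.
val training = Seq(
  (0.0, Vectors.dense(0.0, 1.1)),
  (1.0, Vectors.dense(2.0, 1.0)),
  (2.0, Vectors.dense(2.0, -1.3))
).toDF("label", "features")

val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
// With this patch, fit() also logs the label/features/prediction params, the
// classifier class name, numClasses, and numFeatures before logSuccess(model).
val ovrModel = ovr.fit(training)
```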
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index b466e2ed35..cdea90ec1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -457,8 +457,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
-    val instrLog = Instrumentation.create(this, ratings)
-    instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
+    val instr = Instrumentation.create(this, ratings)
+    instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
       userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative,
       checkpointInterval, seed)
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
@@ -471,7 +471,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
     val userDF = userFactors.toDF("id", "features")
     val itemDF = itemFactors.toDF("id", "features")
     val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
-    instrLog.logSuccess(model)
+    instr.logSuccess(model)
     copyValues(model)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 85191d46fd..2012d6ca8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -101,6 +101,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
+
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(numFolds, seed)
+    logTuningParams(instr)
+
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
@@ -127,6 +132,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    instr.logSuccess(bestModel)
     copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
 
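A corresponding sketch for the tuning side. The estimator, param grid, and evaluator are illustrative assumptions; a DataFrame with a binary label (e.g. a two-class variant of `training` above) is assumed. The logged values come from the CrossValidator diff above plus the logTuningParams helper added to ValidatorParams below.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)

// fit() now logs numFolds and seed, and logTuningParams records the estimator
// and evaluator class names plus estimatorParamMapsLength (here, 2).
val cvModel = cv.fit(binaryTraining)  // binaryTraining: assumed binary-label DataFrame
```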
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 5d1a39f7c1..db7c9d13d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -97,6 +97,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
 
+    val instr = Instrumentation.create(this, dataset)
+    instr.logParams(trainRatio, seed)
+    logTuningParams(instr)
+
     val Array(trainingDataset, validationDataset) =
       dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
     trainingDataset.cache()
@@ -123,6 +127,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    instr.logSuccess(bestModel)
     copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 26fd73814d..d55eb14d03 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.HasSeed
-import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.sql.types.StructType
 
@@ -76,6 +76,15 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
     }
     est.copy(firstEstimatorParamMap).transformSchema(schema)
   }
+
+  /**
+   * Instrumentation logging for tuning params including the inner estimator and evaluator info.
+   */
+  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+    instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
+    instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
+    instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
+  }
 }
 
 private[ml] object ValidatorParams {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 71a626647a..a2794368db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -87,8 +87,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   /**
    * Logs the value with customized name field.
    */
-  def logNamedValue(name: String, num: Long): Unit = {
-    log(compact(render(name -> num)))
+  def logNamedValue(name: String, value: String): Unit = {
+    log(compact(render(name -> value)))
+  }
+
+  def logNamedValue(name: String, value: Long): Unit = {
+    log(compact(render(name -> value)))
   }
 
   /**
-- 
cgit v1.2.3
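The logNamedValue overloads render single-field JSON objects via json4s, which is why a separate String variant was needed alongside the Long one (strings and numbers render differently). A standalone sketch of the output shape, with println standing in for Instrumentation's private log method:

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Approximation of the overload pair added in the patch; println replaces
// the private log method, everything else mirrors the diff above.
def logNamedValue(name: String, value: String): Unit =
  println(compact(render(name -> value)))

def logNamedValue(name: String, value: Long): Unit =
  println(compact(render(name -> value)))

logNamedValue("classifier", "org.apache.spark.ml.classification.LogisticRegression")
// prints: {"classifier":"org.apache.spark.ml.classification.LogisticRegression"}
logNamedValue("estimatorParamMapsLength", 2L)
// prints: {"estimatorParamMapsLength":2}
```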