aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimothy Hunter <timhunter@databricks.com>2016-04-13 11:06:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 11:06:42 -0700
commit1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8 (patch)
tree3edcb71e82e7b2c2bdc5428bf9b7817434635a97
parent323e7390a5c123c48cc7d6d9be44bee3a7eecd99 (diff)
downloadspark-1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8.tar.gz
spark-1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8.tar.bz2
spark-1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8.zip
[SPARK-14568][ML] Instrumentation framework for logistic regression
## What changes were proposed in this pull request? This adds extra logging information about a `LogisticRegression` estimator while it is being fit on a dataset. With this PR, you will see the following extra lines when running the example in the documentation: ``` 16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training: numPartitions=1 storageLevel=StorageLevel(disk=true, memory=true, offheap=false, deserialized=true, replication=1) 16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): {"regParam":0.3,"elasticNetParam":0.8,"maxIter":10} ... 16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numClasses=2 16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numFeatures=692 ... 16/04/13 07:19:01 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training finished ``` ## How was this patch tested? This PR was manually tested. Author: Timothy Hunter <timhunter@databricks.com> Closes #12331 from thunterdb/1604-instrumentation.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala117
2 files changed, 127 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 4a3fe5c663..c2b440059b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(regParam, elasticNetParam, standardization, threshold,
+ maxIter, tol, fitIntercept)
+
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
@@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
val (coefficients, intercept, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
@@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(logRegSummary)
+ val m = model.setSummary(logRegSummary)
+ instr.logSuccess(m)
+ m
}
@Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
new file mode 100644
index 0000000000..7e57cefc44
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+
+/**
+ * A small wrapper that defines a training session for an estimator, and some methods to log
+ * useful information during this session.
+ *
+ * A new instance is expected to be created within fit(); construction logs an initial message.
+ *
+ * @param estimator the estimator that is being fit
+ * @param dataset the training dataset
+ * @tparam E the type of the estimator
+ */
+private[ml] class Instrumentation[E <: Estimator[_]] private (
+ estimator: E, dataset: RDD[_]) extends Logging {
+
+ private val id = Instrumentation.counter.incrementAndGet()
+ private val prefix = {
+ val className = estimator.getClass.getSimpleName
+ s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+ }
+
+ init()
+
+ private def init(): Unit = {
+ log(s"training: numPartitions=${dataset.partitions.length}" +
+ s" storageLevel=${dataset.getStorageLevel}")
+ }
+
+ /**
+ * Logs a message via logInfo, prepending the prefix that identifies this training session.
+ */
+ def log(msg: String): Unit = {
+ logInfo(prefix + msg)
+ }
+
+ /**
+ * Logs, as one compact JSON object, the values of the given parameters for this estimator.
+ */
+ def logParams(params: Param[_]*): Unit = {
+ val pairs: Seq[(String, JValue)] = for {
+ p <- params
+ value <- estimator.get(p)
+ } yield {
+ val cast = p.asInstanceOf[Param[Any]]
+ p.name -> parse(cast.jsonEncode(value))
+ }
+ log(compact(render(map2jvalue(pairs.toMap))))
+ }
+
+ def logNumFeatures(num: Long): Unit = {
+ log(compact(render("numFeatures" -> num)))
+ }
+
+ def logNumClasses(num: Long): Unit = {
+ log(compact(render("numClasses" -> num)))
+ }
+
+ /**
+ * Logs the successful completion of the training session; the model itself is not logged.
+ */
+ def logSuccess(model: Model[_]): Unit = {
+ log(s"training finished")
+ }
+}
+
+/**
+ * Factory methods for [[Instrumentation]], plus the counter that numbers each new instance.
+ */
+private[ml] object Instrumentation {
+ private val counter = new AtomicLong(0)
+
+ /**
+ * Creates an instrumentation object for a training session, using the RDD of the dataset.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
+ create[E](estimator, dataset.rdd)
+ }
+
+ /**
+ * Creates an instrumentation object for a training session on the given RDD.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: RDD[_]): Instrumentation[E] = {
+ new Instrumentation[E](estimator, dataset)
+ }
+
+}