aboutsummaryrefslogblamecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
blob: b73be035e29b5863ef6ec5d03717cc9c73ac9f0e (plain) (tree)



















                                                                           
                                  
                                         
                                                                        
                                                                       
                                     
                                       


                                            

                                  
                                                                                            
                                                                           
 

                                                               

   

                       
                       
                                                             
   



                                                                                      
 
                        
                                                                  

                        
                                                             

                        


                                                                           
                                                                    
 
                                                                               
                                                                                      
                                                  





                                                                            
                                              
                                    
                

                                   
                                     
                                                                                     
 


                            
                   


   
 

                       
  




                                            




                                                                           
                        
                                                                    
 






                                                        

   
                                                           





                                                                                               
                                                   







                                                                              
                                    
                                                                             

                                                                  

                        

                                      
                                                  
                                                          
                                                   
         
                                                                                           
              
                                                                                                 
                                                                                           

                        
     


                                    
                                            

                                        

                                                                                       
                                               


                                                          
                                                                                         
              
                                                                        
                                                                                    

                        
     













                                                                                       
                                                











                                                                           

                                                                              

   
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

/**
 * Params for logistic regression.
 */
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
  with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {

  setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
}

/**
 * :: AlphaComponent ::
 *
 * Logistic regression.
 * Currently, this class only supports binary classification.
 */
@AlphaComponent
class LogisticRegression
  extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
  with LogisticRegressionParams {

  /** @group setParam */
  def setRegParam(value: Double): this.type = set(regParam, value)

  /** @group setParam */
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)

  /** @group setParam */
  def setThreshold(value: Double): this.type = set(threshold, value)

  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
    val oldDataset = extractLabeledPoints(dataset)
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) {
      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
    }

    // Train model
    val lr = new LogisticRegressionWithLBFGS()
      .setIntercept($(fitIntercept))
    lr.optimizer
      .setRegParam($(regParam))
      .setNumIterations($(maxIter))
    val oldModel = lr.run(oldDataset)
    val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)

    if (handlePersistence) {
      oldDataset.unpersist()
    }
    copyValues(lrm)
  }
}


/**
 * :: AlphaComponent ::
 *
 * Model produced by [[LogisticRegression]].
 */
@AlphaComponent
class LogisticRegressionModel private[ml] (
    override val parent: LogisticRegression,
    val weights: Vector,
    val intercept: Double)
  extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
  with LogisticRegressionParams {

  /** @group setParam */
  def setThreshold(value: Double): this.type = set(threshold, value)

  private val margin: Vector => Double = (features) => {
    BLAS.dot(features, weights) + intercept
  }

  private val score: Vector => Double = (features) => {
    val m = margin(features)
    1.0 / (1.0 + math.exp(-m))
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // This is overridden (a) to be more efficient (avoiding re-computing values when creating
    // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
    // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
    // can call whichever UDFs are needed to create the output columns.

    // Check schema
    transformSchema(dataset.schema, logging = true)

    // Output selected columns only.
    // This is a bit complicated since it tries to avoid repeated computation.
    //   rawPrediction (-margin, margin)
    //   probability (1.0-score, score)
    //   prediction (max margin)
    var tmpData = dataset
    var numColsOutput = 0
    if ($(rawPredictionCol) != "") {
      val features2raw: Vector => Vector = (features) => predictRaw(features)
      tmpData = tmpData.withColumn($(rawPredictionCol),
        callUDF(features2raw, new VectorUDT, col($(featuresCol))))
      numColsOutput += 1
    }
    if ($(probabilityCol) != "") {
      if ($(rawPredictionCol) != "") {
        val raw2prob = udf { (rawPreds: Vector) =>
          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
          Vectors.dense(1.0 - prob1, prob1): Vector
        }
        tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol))))
      } else {
        val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
        tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol))))
      }
      numColsOutput += 1
    }
    if ($(predictionCol) != "") {
      val t = $(threshold)
      if ($(probabilityCol) != "") {
        val predict = udf { probs: Vector =>
          if (probs(1) > t) 1.0 else 0.0
        }
        tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol))))
      } else if ($(rawPredictionCol) != "") {
        val predict = udf { rawPreds: Vector =>
          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
          if (prob1 > t) 1.0 else 0.0
        }
        tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol))))
      } else {
        val predict = udf { features: Vector => this.predict(features) }
        tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol))))
      }
      numColsOutput += 1
    }
    if (numColsOutput == 0) {
      this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
        " since no output columns were set.")
    }
    tmpData
  }

  override val numClasses: Int = 2

  /**
   * Predict label for the given feature vector.
   * The behavior of this can be adjusted using [[threshold]].
   */
  override protected def predict(features: Vector): Double = {
    if (score(features) > getThreshold) 1 else 0
  }

  override protected def predictProbabilities(features: Vector): Vector = {
    val s = score(features)
    Vectors.dense(1.0 - s, s)
  }

  override protected def predictRaw(features: Vector): Vector = {
    val m = margin(features)
    Vectors.dense(0.0, m)
  }

  override def copy(extra: ParamMap): LogisticRegressionModel = {
    copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
  }
}