aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-25 23:14:35 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-25 23:14:35 -0700
commitdc06b528790c69b2e6de85cba84266fea81dd4f4 (patch)
treed3886b58068ca8f0d46a42e9c575d3c187426c68 /mllib
parentc874625354de7117da9586cfbbe919bb6801a932 (diff)
downloadspark-dc06b528790c69b2e6de85cba84266fea81dd4f4.tar.gz
spark-dc06b528790c69b2e6de85cba84266fea81dd4f4.tar.bz2
spark-dc06b528790c69b2e6de85cba84266fea81dd4f4.zip
Add an option to turn off data validation, test it.
Also moves addIntercept into the base class with a default of true, making it consistent with the validateData option
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala9
-rw-r--r--mllib/src/main/scala/spark/mllib/classification/SVM.scala9
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala16
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/Lasso.scala9
-rw-r--r--mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala3
5 files changed, 28 insertions, 18 deletions
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 474ca6e97c..482e4a6745 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -54,8 +54,7 @@ class LogisticRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {
@@ -71,7 +70,7 @@ class LogisticRegressionWithSGD private (
/**
* Construct a LogisticRegression object with default parameters
*/
- def this() = this(1.0, 100, 0.0, 1.0, true)
+ def this() = this(1.0, 100, 0.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
@@ -108,7 +107,7 @@ object LogisticRegressionWithSGD {
initialWeights: Array[Double])
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input, initialWeights)
}
@@ -131,7 +130,7 @@ object LogisticRegressionWithSGD {
miniBatchFraction: Double)
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input)
}
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index b680d81e86..69393cd7b0 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -54,8 +54,7 @@ class SVMWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
val gradient = new HingeGradient()
@@ -71,7 +70,7 @@ class SVMWithSGD private (
/**
* Construct a SVM object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new SVMModel(weights, intercept)
@@ -107,7 +106,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -131,7 +130,7 @@ object SVMWithSGD {
miniBatchFraction: Double)
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 03f991df39..d164d415d6 100644
--- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -87,13 +87,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
val optimizer: Optimizer
+ protected var addIntercept: Boolean = true
+
+ protected var validateData: Boolean = true
+
/**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Array[Double], intercept: Double): M
- protected var addIntercept: Boolean
-
/**
* Set if the algorithm should add an intercept. Default true.
*/
@@ -103,6 +105,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
/**
+ * Set if the algorithm should validate data before training. Default true.
+ */
+ def setValidateData(validateData: Boolean): this.type = {
+ this.validateData = validateData
+ this
+ }
+
+ /**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
*/
@@ -119,7 +129,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
// Check the data properties before running the optimizer
- if (!validators.forall(func => func(input))) {
+ if (validateData && !validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
}
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
index 6bbc990a5a..89f791e85a 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -48,8 +48,7 @@ class LassoWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
@@ -63,7 +62,7 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
@@ -98,7 +97,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -121,7 +120,7 @@ object LassoWithSGD {
miniBatchFraction: Double)
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 8fa9e4639b..894ae458ad 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -162,5 +162,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
intercept[spark.SparkException] {
val model = SVMWithSGD.train(testRDDInvalid, 100)
}
+
+ // Turning off data validation should not throw an exception
+ val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}