aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-08 16:24:31 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-08 16:24:31 -0700
commit2812e722008b772756cbd0ef0600a55b65d6ee0e (patch)
treeede282204a385189a976a6fffc24c755b2630146 /mllib
parent338b7a7455c02371890590fb71eefaee587f9d0e (diff)
downloadspark-2812e722008b772756cbd0ef0600a55b65d6ee0e.tar.gz
spark-2812e722008b772756cbd0ef0600a55b65d6ee0e.tar.bz2
spark-2812e722008b772756cbd0ef0600a55b65d6ee0e.zip
Add setters for optimizer, gradient in SGD.
Also remove java-specific constructor for LabeledPoint.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala19
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala8
2 files changed, 19 insertions, 8 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 54793ca74d..1f04398d0c 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -24,7 +24,7 @@ import org.jblas.DoubleMatrix
import scala.collection.mutable.ArrayBuffer
-class GradientDescent(gradient: Gradient, updater: Updater) extends Optimizer {
+class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer {
var stepSize: Double = 1.0
var numIterations: Int = 100
@@ -63,6 +63,23 @@ class GradientDescent(gradient: Gradient, updater: Updater) extends Optimizer {
this
}
+ /**
+ * Set the gradient function to be used for SGD.
+ */
+ def setGradient(gradient: Gradient): this.type = {
+ this.gradient = gradient
+ this
+ }
+
+
+ /**
+ * Set the updater function to be used for SGD.
+ */
+ def setUpdater(updater: Updater): this.type = {
+ this.updater = updater
+ this
+ }
+
def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double])
: Array[Double] = {
diff --git a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
index 592f0b5414..3de60482c5 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
@@ -23,10 +23,4 @@ package spark.mllib.regression
* @param label Label for this data point.
* @param features List of features for this data point.
*/
-case class LabeledPoint(val label: Double, val features: Array[Double]) {
-
- /**
- * Construct a labeled point using java.lang.Double.
- */
- def this(label: java.lang.Double, features: Array[Double]) = this(label.doubleValue(), features)
-}
+case class LabeledPoint(val label: Double, val features: Array[Double])