about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala 27
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala 27
2 files changed, 50 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 060fd5b859..8a57ebc387 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -23,7 +23,7 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -115,7 +116,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
override protected def formatVersion: String = "1.0"
}
-object DecisionTreeModel extends Loader[DecisionTreeModel] {
+object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
private[tree] object SaveLoadV1_0 {
@@ -187,6 +188,28 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
+ // SPARK-6120: Heuristic check that warns up front when the configured memory is small, so
+ // users running the ML guide example understand why save() may fail.
+ // TODO: Replace this heuristic with a real fix for the underlying failure.
+ val memThreshold = 768
+ if (sc.isLocal) {
+ val driverMemory = sc.getConf.getOption("spark.driver.memory")
+ .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ if (driverMemory <= memThreshold) {
+ logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
+ s" driver memory (${driverMemory}m)." +
+ s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
+ }
+ } else {
+ if (sc.executorMemory <= memThreshold) {
+ logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
+ s" executor memory (${sc.executorMemory}m)." +
+ s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
+ }
+ }
+
// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 4897906aea..30a8f7ca30 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -24,7 +24,7 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -250,7 +251,7 @@ private[tree] sealed class TreeEnsembleModel(
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
-private[tree] object TreeEnsembleModel {
+private[tree] object TreeEnsembleModel extends Logging {
object SaveLoadV1_0 {
@@ -277,6 +278,28 @@ private[tree] object TreeEnsembleModel {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
+ // SPARK-6120: Heuristic check that warns up front when the configured memory is small, so
+ // users running the ML guide example understand why save() may fail.
+ // TODO: Replace this heuristic with a real fix for the underlying failure.
+ val memThreshold = 768
+ if (sc.isLocal) {
+ val driverMemory = sc.getConf.getOption("spark.driver.memory")
+ .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ if (driverMemory <= memThreshold) {
+ logWarning(s"$className.save() was called, but it may fail because of too little" +
+ s" driver memory (${driverMemory}m)." +
+ s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
+ }
+ } else {
+ if (sc.executorMemory <= memThreshold) {
+ logWarning(s"$className.save() was called, but it may fail because of too little" +
+ s" executor memory (${sc.executorMemory}m)." +
+ s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
+ }
+ }
+
// Create JSON metadata.
implicit val format = DefaultFormats
val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,