aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
blob: 8daca6c390635211039d5375638dd1c35f813e27 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.serializer

import java.io._
import java.lang.reflect.{Field, Method}
import java.security.AccessController

import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging

private[spark] object SerializationDebugger extends Logging {

  /**
   * Augment a NotSerializableException with the chain of references leading from `obj` down
   * to the object that could not be serialized.
   *
   * The original exception is handed back untouched when debugging is disabled (the default
   * when the JVM runs with the `sun.io.serialization.extendedDebugInfo` flag), when the
   * reflection helpers could not be initialized, or when the debugger itself fails with a
   * non-fatal error.
   *
   * @param obj root object whose serialization failed
   * @param e   exception raised by the failed serialization attempt
   * @return a new exception whose message also carries the serialization stack, or `e` itself
   */
  def improveException(obj: Any, e: NotSerializableException): NotSerializableException = {
    val debuggerUsable = enableDebugging && reflect != null
    if (!debuggerUsable) {
      e
    } else {
      try {
        val header = e.getMessage + "\nSerialization stack:\n"
        val trace = find(obj).map("\t- " + _).mkString("\n")
        new NotSerializableException(header + trace)
      } catch {
        case NonFatal(t) =>
          // The debugger is best effort only: keep the caller's original exception.
          logWarning("Exception in serialization debugger", t)
          e
      }
    }
  }

  /**
   * Compute the path leading from `obj` to the first non-serializable object reachable from
   * it. The traversal is modeled after OpenJDK's serialization mechanism and covers:
   *
   *  - primitives and strings
   *  - arrays of primitives
   *  - arrays of non-primitive objects
   *  - Serializable objects, including classes that define writeObject
   *  - Externalizable objects
   *  - writeReplace
   *
   * @param obj root of the object graph to inspect
   * @return the path entries with the offending object first and the root last, or an empty
   *         list when nothing non-serializable was found
   */
  private[serializer] def find(obj: Any): List[String] = {
    // A fresh debugger per call, so the set of visited objects starts out empty.
    val debugger = new SerializationDebugger()
    debugger.visit(obj, List.empty)
  }

  /**
   * Whether the serialization debugger is active. It is initialized to the negation of the
   * `sun.io.serialization.extendedDebugInfo` system property, so the debugger is off when the
   * JVM already produces its own extended serialization debug information. Declared as a `var`
   * so that code in this package can override the default.
   */
  private[serializer] var enableDebugging: Boolean = {
    // Read the property through the public Boolean.getBoolean API inside a privileged action,
    // instead of depending on the JDK-internal class sun.security.action.GetBooleanAction.
    val extendedDebugInfo = AccessController.doPrivileged(
      new java.security.PrivilegedAction[java.lang.Boolean] {
        override def run(): java.lang.Boolean = {
          java.lang.Boolean.valueOf(
            java.lang.Boolean.getBoolean("sun.io.serialization.extendedDebugInfo"))
        }
      })
    !extendedDebugInfo.booleanValue()
  }

  private class SerializationDebugger {

    /** A set to track the list of objects we have visited, to avoid cycles in the graph. */
    private val visited = new mutable.HashSet[Any]

    /**
     * Visit the object and everything it would serialize, and stop at the first object that is
     * not serializable.
     *
     * @param o     object to inspect; may be null
     * @param stack descriptions of the references that lead to `o`, innermost first
     * @return the path to the offending object (offender first), or an empty list if `o` and
     *         everything reachable from it can be serialized
     */
    def visit(o: Any, stack: List[String]): List[String] = {
      if (o == null || visited.contains(o)) {
        // null is always serializable, and an object seen before is not inspected twice; this
        // is also what terminates the traversal on cyclic object graphs.
        List.empty
      } else {
        visited += o
        o match {
          // Primitive value, string, and primitive arrays are always serializable
          case _ if o.getClass.isPrimitive => List.empty
          case _: String => List.empty
          case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty

          // Traverse non primitive array (primitive arrays were handled by the case above).
          case a: Array[_] =>
            val elem = s"array (class ${a.getClass.getName}, size ${a.length})"
            visitArray(a, elem :: stack)

          case e: java.io.Externalizable =>
            val elem = s"externalizable object (class ${e.getClass.getName}, ${describe(e)})"
            visitExternalizable(e, elem :: stack)

          case s: Object with java.io.Serializable =>
            val elem = s"object (class ${s.getClass.getName}, ${describe(s)})"
            visitSerializable(s, elem :: stack)

          case _ =>
            // Found an object that is not serializable!
            s"object not serializable (class: ${o.getClass.getName}, value: ${describe(o)})" ::
              stack
        }
      }
    }

    /**
     * Render `o` for a path entry. User-defined toString() implementations may throw; such a
     * failure must not abort the search, so it is replaced by a placeholder text.
     */
    private def describe(o: Any): String = {
      try {
        s"$o"
      } catch {
        case NonFatal(t) => s"<toString failed with ${t.getClass.getName}>"
      }
    }

    /**
     * Visit the elements of a non-primitive array in index order.
     *
     * @param o     array whose elements are inspected
     * @param stack path entries leading to the array, innermost first
     * @return the path of the first element that leads to a non-serializable object, or an
     *         empty list if every element is fine
     */
    private def visitArray(o: Array[_], stack: List[String]): List[String] = {
      // The iterator is lazy, so elements after the first failing one are never visited.
      Iterator.range(0, o.length)
        .map(idx => visit(o(idx), s"element of array (index: $idx)" :: stack))
        .find(_.nonEmpty)
        .getOrElse(List.empty)
    }

    /**
     * Visit an externalizable object.
     * writeExternal() is free to write arbitrary objects at serialization time, so the object
     * is asked to write itself into a dummy ObjectOutput that merely records the objects it
     * receives; each recorded object is then visited in the order it was written.
     *
     * @param o     externalizable object to inspect
     * @param stack path entries leading to `o`, innermost first
     * @return the path of the first written object that leads to a non-serializable object,
     *         or an empty list if there is none
     */
    private def visitExternalizable(
        o: java.io.Externalizable, stack: List[String]): List[String] = {
      val recorder = new ListObjectOutput
      o.writeExternal(recorder)
      // Lazy traversal: stop at the first recorded object that yields a non-empty path.
      recorder.outputArray.iterator
        .map(written => visit(written, "writeExternal data" :: stack))
        .find(_.nonEmpty)
        .getOrElse(List.empty)
    }

    /**
     * Visit a java.io.Serializable object: resolve writeReplace(), then inspect what every
     * class-data slot of the resulting object would write.
     *
     * @param o     serializable object to inspect
     * @param stack path entries leading to `o`, innermost first
     * @return the path to the first non-serializable object reachable from `o`, or an empty
     *         list if there is none
     */
    private def visitSerializable(o: Object, stack: List[String]): List[String] = {
      // Resolve the object that is really written (after writeReplace) and its descriptor.
      val (finalObj, desc) = findObjectAndDescriptor(o)

      // A null replacement is serialized as a plain null reference, which always succeeds.
      if (finalObj == null) {
        return List.empty
      }

      // If writeReplace() substituted an object of a different class, call visit() on it to
      // test its type again. As in ObjectOutputStream, the decision is based on the class and
      // not on reference identity: an object of the same class is written as-is with `desc`,
      // so its fields are inspected below instead of resolving writeReplace() once more.
      if (finalObj.getClass != o.getClass) {
        return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
      }

      // Every class is associated with one or more "slots", each slot refers to the parent
      // classes of this class. These slots are used by the ObjectOutputStream
      // serialization code to recursively serialize the fields of an object and
      // its parent classes. For example, if there are the following classes.
      //
      //     class ParentClass(parentField: Int)
      //     class ChildClass(childField: Int) extends ParentClass(1)
      //
      // Then serializing an object Obj of type ChildClass requires first serializing the fields
      // of ParentClass (that is, parentField), and then serializing the fields of ChildClass
      // (that is, childField). Correspondingly, there will be two slots related to this object:
      //
      // 1. ParentClass slot, which will be used to serialize parentField of Obj
      // 2. ChildClass slot, which will be used to serialize childField of Obj
      //
      // The following code uses the description of each slot to find the fields in the
      // corresponding object to visit.
      //
      val slotDescs = desc.getSlotDescs
      var i = 0
      while (i < slotDescs.length) {
        val slotDesc = slotDescs(i)
        if (slotDesc.hasWriteObjectMethod) {
          // If the class type corresponding to the current slot has writeObject() defined,
          // then it is not obvious which fields of the class will be serialized, as
          // writeObject() can choose arbitrary fields. This case is handled separately.
          val elem = s"writeObject data (class: ${slotDesc.getName})"
          val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
          if (childStack.nonEmpty) {
            return childStack
          }
        } else {
          // Visit all the object-typed fields of the class corresponding to the current slot.
          val fields: Array[ObjectStreamField] = slotDesc.getFields
          val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
          // The object fields are looked up after the primitive ones in `fields`, so this
          // offset maps an index in objFieldValues to the matching field description.
          val numPrims = fields.length - objFieldValues.length
          slotDesc.getObjFieldValues(finalObj, objFieldValues)

          var j = 0
          while (j < objFieldValues.length) {
            val fieldDesc = fields(numPrims + j)
            val elem = s"field (class: ${slotDesc.getName}" +
              s", name: ${fieldDesc.getName}" +
              s", type: ${fieldDesc.getType})"
            val childStack = visit(objFieldValues(j), elem :: stack)
            if (childStack.nonEmpty) {
              return childStack
            }
            j += 1
          }
        }
        i += 1
      }
      return List.empty
    }

    /**
     * Visit a serializable object whose class defines writeObject().
     * writeObject() may write arbitrary objects at serialization time, so the object is written
     * to a dummy ObjectOutputStream that records every object passing through it. This is
     * similar to how externalizable objects are visited.
     *
     * @param o     object to inspect
     * @param stack path entries leading to `o`, innermost first
     * @return the path of the first recorded object that leads to a non-serializable object,
     *         or an empty list if there is none
     */
    private def visitSerializableWithWriteObjectMethod(
        o: Object, stack: List[String]): List[String] = {
      val recorder = new ListObjectOutputStream
      // Any IOException raised while writing is taken as "something was not serializable".
      val failed = try {
        recorder.writeObject(o)
        false
      } catch {
        case _: IOException => true
      }

      if (!failed) {
        // Everything that was recorded serialized fine, so there is no need to visit it.
        // As an optimization, just mark all of it as visited.
        visited ++= recorder.outputArray
        List.empty
      } else {
        // Something failed: inspect the recorded objects in order, stopping at the first hit.
        recorder.outputArray.iterator
          .map(inner => visit(inner, stack))
          .find(_.nonEmpty)
          .getOrElse(List.empty)
      }
    }
  }

  /**
   * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
   * writeReplace in Serializable. It starts with the object itself and keeps calling the
   * writeReplace method, following the rule used by `ObjectOutputStream`: the chain ends as
   * soon as the current class has no writeReplace method, the replacement is null, or the
   * replacement has the same class as the object it replaces. Without the last two conditions
   * a writeReplace that returns `this` (or another instance of its own class) would make this
   * method loop forever, and a null replacement would fail on `getClass`.
   *
   * @param o object to resolve; must not be null
   * @return the object that is written in the end (null if a writeReplace method returned
   *         null) together with the descriptor of the last class that was examined
   */
  @tailrec
  private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
    val cl = o.getClass
    val desc = ObjectStreamClass.lookupAny(cl)
    if (!desc.hasWriteReplaceMethod) {
      (o, desc)
    } else {
      val replaced = desc.invokeWriteReplace(o)
      if (replaced == null || replaced.getClass == cl) {
        // End of the replacement chain: the replacement is written as-is using `desc`.
        (replaced, desc)
      } else {
        // The replacement has a different class, which may define its own writeReplace.
        findObjectAndDescriptor(replaced)
      }
    }
  }

  /**
   * A dummy [[ObjectOutput]] that records every object handed to `writeObject` (typically from
   * a writeExternal call) and exposes the recorded objects, in write order, through
   * `outputArray`. All other writes are silently discarded.
   */
  private class ListObjectOutput extends ObjectOutput {
    private val captured = mutable.ArrayBuffer.empty[Any]

    /** Snapshot of the objects recorded so far, in the order they were written. */
    def outputArray: Array[Any] = captured.toArray

    // The only call of interest: remember the object reference.
    override def writeObject(o: Any): Unit = { captured += o }

    // Raw bytes carry no object references, so they are dropped.
    override def write(i: Int): Unit = ()
    override def write(bytes: Array[Byte]): Unit = ()
    override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = ()

    // Primitive values and strings are dropped as well.
    override def writeBoolean(b: Boolean): Unit = ()
    override def writeByte(i: Int): Unit = ()
    override def writeShort(i: Int): Unit = ()
    override def writeChar(i: Int): Unit = ()
    override def writeInt(i: Int): Unit = ()
    override def writeLong(l: Long): Unit = ()
    override def writeFloat(v: Float): Unit = ()
    override def writeDouble(v: Double): Unit = ()
    override def writeBytes(s: String): Unit = ()
    override def writeChars(s: String): Unit = ()
    override def writeUTF(s: String): Unit = ()

    // Nothing is buffered and no resource is held.
    override def flush(): Unit = ()
    override def close(): Unit = ()
  }

  /** An output stream that emulates /dev/null: every byte written to it is discarded. */
  private class NullOutputStream extends OutputStream {
    override def write(b: Int): Unit = {}

    // Discard bulk writes directly; the implementation inherited from OutputStream would
    // otherwise call write(Int) once for every single byte.
    override def write(b: Array[Byte], off: Int, len: Int): Unit = {}
  }

  /**
   * A dummy [[ObjectOutputStream]] that records the objects written to it and returns them
   * through `outputArray`. It relies on the `replaceObject()` hook of [[ObjectOutputStream]],
   * which is called for the objects being written, but only once replacing has been enabled.
   * This subclass therefore enables replacing and uses `replaceObject()` purely to observe the
   * objects being serialized, handing each one back unchanged. The serialized bytes themselves
   * are thrown away by sending them to a [[NullOutputStream]], which acts like /dev/null.
   */
  private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
    private val seen = mutable.ArrayBuffer.empty[Any]

    // replaceObject() is only consulted after replacement has been switched on.
    enableReplaceObject(true)

    /** Snapshot of the objects observed so far, in the order they were encountered. */
    def outputArray: Array[Any] = seen.toArray

    override def replaceObject(obj: Object): Object = {
      // Record the object and return it as-is, so nothing is actually replaced.
      seen += obj
      obj
    }
  }

  /**
   * An implicit class that allows us to call private methods of ObjectStreamClass. Every
   * method delegates to the matching reflective handle held by `reflect`.
   */
  implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
    /** Descriptors of the class-data slots of `desc`, in the order the JDK reports them. */
    def getSlotDescs: Array[ObjectStreamClass] = {
      val layout = reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]]
      // Each layout entry is a ClassDataSlot; its `desc` field holds the slot's descriptor.
      layout.map(slot => reflect.DescField.get(slot).asInstanceOf[ObjectStreamClass])
    }

    /** Result of the private `ObjectStreamClass.hasWriteObjectMethod`. */
    def hasWriteObjectMethod: Boolean = {
      val answer = reflect.HasWriteObjectMethod.invoke(desc)
      answer.asInstanceOf[Boolean]
    }

    /** Result of the private `ObjectStreamClass.hasWriteReplaceMethod`. */
    def hasWriteReplaceMethod: Boolean = {
      val answer = reflect.HasWriteReplaceMethod.invoke(desc)
      answer.asInstanceOf[Boolean]
    }

    /** Calls the private `ObjectStreamClass.invokeWriteReplace` on `obj`. */
    def invokeWriteReplace(obj: Object): Object = {
      reflect.InvokeWriteReplace.invoke(desc, obj)
    }

    /** Result of the private `ObjectStreamClass.getNumObjFields`. */
    def getNumObjFields: Int = {
      val count = reflect.GetNumObjFields.invoke(desc)
      count.asInstanceOf[Int]
    }

    /** Calls the private `ObjectStreamClass.getObjFieldValues`, passing `obj` and `out`. */
    def getObjFieldValues(obj: Object, out: Array[Object]): Unit = {
      reflect.GetObjFieldValues.invoke(desc, obj, out)
    }
  }

  /**
   * Holder of all the reflection handles. If we run on a JVM whose internals we cannot
   * understand, a warning is logged, this field is null, and the debug helper should be
   * disabled.
   */
  private val reflect: ObjectStreamClassReflection = {
    try {
      new ObjectStreamClassReflection
    } catch {
      case cause: Exception =>
        logWarning("Cannot find private methods using reflection", cause)
        null
    }
  }

  /**
   * Reflective handles to private members of java.io.ObjectStreamClass. All handles are looked
   * up and made accessible when an instance is constructed, so construction throws if any of
   * them cannot be found or cannot be made accessible.
   */
  private class ObjectStreamClassReflection {
    /** ObjectStreamClass.getClassDataLayout */
    val GetClassDataLayout: Method = accessibleMethod("getClassDataLayout")

    /** ObjectStreamClass.hasWriteObjectMethod */
    val HasWriteObjectMethod: Method = accessibleMethod("hasWriteObjectMethod")

    /** ObjectStreamClass.hasWriteReplaceMethod */
    val HasWriteReplaceMethod: Method = accessibleMethod("hasWriteReplaceMethod")

    /** ObjectStreamClass.invokeWriteReplace */
    val InvokeWriteReplace: Method = accessibleMethod("invokeWriteReplace", classOf[Object])

    /** ObjectStreamClass.getNumObjFields */
    val GetNumObjFields: Method = accessibleMethod("getNumObjFields")

    /** ObjectStreamClass.getObjFieldValues */
    val GetObjFieldValues: Method =
      accessibleMethod("getObjFieldValues", classOf[Object], classOf[Array[Object]])

    /** ObjectStreamClass$ClassDataSlot.desc field */
    val DescField: Field = {
      // scalastyle:off classforname
      val slotClass = Class.forName("java.io.ObjectStreamClass$ClassDataSlot")
      // scalastyle:on classforname
      val field = slotClass.getDeclaredField("desc")
      field.setAccessible(true)
      field
    }

    /** Look up a declared method of ObjectStreamClass by signature and make it accessible. */
    private def accessibleMethod(name: String, paramTypes: Class[_]*): Method = {
      val method = classOf[ObjectStreamClass].getDeclaredMethod(name, paramTypes: _*)
      method.setAccessible(true)
      method
    }
  }
}