你好,游客 登录 注册 搜索
背景:
阅读新闻

Spark源码分析之七:Task运行(一)

[日期:2016-02-26] 来源:辰辰爸的技术博客  作者: [字体: ]

        在Task调度相关的两篇文章《Spark源码分析之五:Task调度(一)》《Spark源码分析之六:Task调度(二)》中,我们大致了解了Task调度相关的主要逻辑,并且在Task调度逻辑的最后,CoarseGrainedSchedulerBackend的内部类DriverEndpoint中的makeOffers()方法的最后,我们通过调用TaskSchedulerImpl的resourceOffers()方法,得到了TaskDescription序列的序列Seq[Seq[TaskDescription]],相关代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. // 调用scheduler的resourceOffers()方法,分配资源,并在得到资源后,调用launchTasks()方法,启动tasks 
  2.       // 这个scheduler就是TaskSchedulerImpl 
  3.       launchTasks(scheduler.resourceOffers(workOffers)) 
[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** 
  2.    * Called by cluster manager to offer resources on slaves. We respond by asking our active task 
  3.    * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so 
  4.    * that tasks are balanced across the cluster. 
  5.    * 
  6.    * 被集群manager调用以提供slaves上的资源。我们通过按照优先顺序询问活动task集中的task来回应。 
  7.    * 我们通过循环的方式将task调度到每个节点上以便tasks在集群中可以保持大致的均衡。 
  8.    */ 
  9.   def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { 

        这个TaskDescription很简单,是传递到executor上即将被执行的Task的描述,通常由TaskSetManager的resourceOffer()方法生成。代码如下:

 

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** 
  2.  * Description of a task that gets passed onto executors to be executed, usually created by 
  3.  * [[TaskSetManager.resourceOffer]]. 
  4.  */ 
  5. private[spark] class TaskDescription( 
  6.     val taskId: Long, 
  7.     val attemptNumber: Int, 
  8.     val executorId: String, 
  9.     val name: String, 
  10.     val index: Int,    // Index within this task's TaskSet 
  11.     _serializedTask: ByteBuffer) 
  12.   extends Serializable { 
  13.  
  14.   // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer 
  15.   // 由于ByteBuffers不可以被序列化,所以将task包装在SerializableBuffer中,_serializedTask为ByteBuffer类型的Task 
  16.   private val buffer = new SerializableBuffer(_serializedTask) 
  17.    
  18.   // 序列化后的Task, 取buffer的value 
  19.   def serializedTask: ByteBuffer = buffer.value 
  20.  
  21.  
  22.   override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) 

        此时,得到Seq[Seq[TaskDescription]],即Task被调度到相应executor上后(仅是逻辑调度,实际上并未分配到executor上执行),接下来要做的,便是真正的将Task分配到指定的executor上去执行,也就是本篇我们将要讲的Task的运行。而这部分的开端,源于上述提到的CoarseGrainedSchedulerBackend的内部类DriverEndpoint中的launchTasks()方法,代码如下:

 

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. // Launch tasks returned by a set of resource offers 
  2.     private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { 
  3.      
  4.       // 循环每个task 
  5.       for (task <- tasks.flatten) { 
  6.          
  7.         // 序列化Task 
  8.         val serializedTask = ser.serialize(task) 
  9.          
  10.         // 序列化后的task的大小超出规定的上限 
  11.         // 即如果序列化后task的大小大于等于框架配置的Akka消息最大大小减去除序列化task或task结果外,一个Akka消息需要保留的额外大小的值 
  12.         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { 
  13.            
  14.           // 根据task的taskId,在TaskSchedulerImpl的taskIdToTaskSetManager中获取对应的TaskSetManager 
  15.           scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => 
  16.             try { 
  17.               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + 
  18.                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + 
  19.                 "spark.akka.frameSize or using broadcast variables for large values." 
  20.               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, 
  21.                 AkkaUtils.reservedSizeBytes) 
  22.                
  23.               // 调用TaskSetManager的abort()方法,标记对应TaskSetManager为失败 
  24.               taskSetMgr.abort(msg) 
  25.             } catch { 
  26.               case e: Exception => logError("Exception in error callback", e) 
  27.             } 
  28.           } 
  29.         } 
  30.         else {// 序列化后task的大小在规定的大小内 
  31.            
  32.           // 从executorDataMap中,根据task.executorId获取executor描述信息executorData 
  33.           val executorData = executorDataMap(task.executorId) 
  34.            
  35.           // executorData中,freeCores做相应减少 
  36.           executorData.freeCores -= scheduler.CPUS_PER_TASK 
  37.            
  38.           // 利用executorData中的executorEndpoint,发送LaunchTask事件,LaunchTask事件中包含序列化后的task 
  39.           executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) 
  40.         } 
  41.       } 
  42.     } 

        launchTasks的执行逻辑很简单,针对传入的TaskDescription序列,循环每个Task,做以下处理:

 

        1、首先对Task进行序列化,得到serializedTask;

        2、针对序列化后的Task:serializedTask,判断其大小:

              2.1、序列化后的task的大小达到或超出规定的上限,即框架配置的Akka消息最大大小,减去除序列化task或task结果外,一个Akka消息需要保留的额外大小的值,则根据task的taskId,在TaskSchedulerImpl的taskIdToTaskSetManager中获取对应的TaskSetManager,并调用其abort()方法,标记对应TaskSetManager为失败;

              2.2、序列化后的task的大小未达到上限,在规定的大小范围内,则:

                       2.2.1、从executorDataMap中,根据task.executorId获取executor描述信息executorData;

                       2.2.2、在executorData中,freeCores做相应减少;

                       2.2.3、利用executorData中的executorEndpoint,即Driver端executor通讯端点的引用,发送LaunchTask事件,LaunchTask事件中包含序列化后的task,将Task传递到executor中去执行。

        接下来,我们重点分析下上述流程。

        先说下异常流程,即序列化后Task的大小超过上限时,对TaskSet标记为失败的处理。入口方法为TaskSetManager的abort()方法,代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { 
  2.      
  3.     // TODO: Kill running tasks if we were not terminated due to a Mesos error 
  4.     // 调用DAGScheduler的taskSetFailed()方法,标记TaskSet运行失败 
  5.     sched.dagScheduler.taskSetFailed(taskSet, message, exception) 
  6.      
  7.     // 标志位isZombie设置为true 
  8.     isZombie = true 
  9.      
  10.     // 满足一定条件的情况下,将TaskSet标记为Finished 
  11.     maybeFinishTaskSet() 
  12.   } 

        abort()方法处理逻辑共分三步:

 

        第一,调用DAGScheduler的taskSetFailed()方法,标记TaskSet运行失败;

        第二,标志位isZombie设置为true;

        第三,满足一定条件的情况下,将TaskSet标记为Finished。

        首先看下DAGScheduler的taskSetFailed()方法,代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** 
  2.    * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or 
  3.    * cancellation of the job itself. 
  4.    */ 
  5.   def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { 
  6.     eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) 
  7.   } 

        和第二篇文章《Spark源码分析之二:Job的调度模型与运行反馈》中Job的调度模型一致,都是依靠事件队列eventProcessLoop来完成事件的调度执行的,这里,我们在事件队列eventProcessLoop中放入了一个TaskSetFailed事件。在DAGScheduler的事件处理调度函数doOnReceive()方法中,明确规定了事件的处理方法,代码如下:

 

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. // 如果是TaskSetFailed事件,调用dagScheduler.handleTaskSetFailed()方法处理 
  2.     case TaskSetFailed(taskSet, reason, exception) => 
  3.       dagScheduler.handleTaskSetFailed(taskSet, reason, exception) 

        下面,我们看下handleTaskSetFailed()这个方法。

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. private[scheduler] def handleTaskSetFailed( 
  2.       taskSet: TaskSet, 
  3.       reason: String, 
  4.       exception: Option[Throwable]): Unit = { 
  5.      
  6.     // 根据taskSet的stageId获取到对应的Stage,循环调用abortStage,终止该Stage 
  7.     stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } 
  8.      
  9.     // 提交等待的Stages 
  10.     submitWaitingStages() 
  11.   } 

        很简单,首先通过taskSet的stageId获取到对应的Stage,针对Stage,循环调用abortStage()方法,终止该Stage,然后调用submitWaitingStages()方法提交等待的Stages。我们先看下abortStage()方法,代码如下:

 

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** 
  2.    * Aborts all jobs depending on a particular Stage. This is called in response to a task set 
  3.    * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. 
  4.    * 终止给定Stage上的所有Job。 
  5.    */ 
  6.   private[scheduler] def abortStage( 
  7.       failedStage: Stage, 
  8.       reason: String, 
  9.       exception: Option[Throwable]): Unit = { 
  10.      
  11.     // 如果stageIdToStage中不存在对应的stage,说明stage已经被移除,直接返回 
  12.     if (!stageIdToStage.contains(failedStage.id)) { 
  13.       // Skip all the actions if the stage has been removed. 
  14.       return 
  15.     } 
  16.      
  17.     // 遍历activeJobs中的ActiveJob,逐个调用stageDependsOn()方法,找出存在failedStage的祖先stage的activeJob,即dependentJobs 
  18.     val dependentJobs: Seq[ActiveJob] = 
  19.       activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq 
  20.      
  21.     // 标记failedStage的完成时间completionTime 
  22.     failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) 
  23.      
  24.     // 遍历dependentJobs,调用failJobAndIndependentStages() 
  25.     for (job <- dependentJobs) { 
  26.       failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) 
  27.     } 
  28.     if (dependentJobs.isEmpty) { 
  29.       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done"
  30.     } 
  31.   } 

        这个方法的处理逻辑主要分为四步:

 

        1、如果stageIdToStage中不存在对应的stage,说明stage已经被移除,直接返回,这是对异常情况下的一种特殊处理;

        2、遍历activeJobs中的ActiveJob,逐个调用stageDependsOn()方法,找出存在failedStage的祖先stage的activeJob,即dependentJobs;

        3、标记failedStage的完成时间completionTime;

        4、遍历dependentJobs,调用failJobAndIndependentStages()。

        其它都好说,我们主要看下stageDependsOn()和failJobAndIndependentStages()这两个方法。首先看下stageDependsOn()方法,代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** Return true if one of stage's ancestors is target. */ 
  2.   // 如果参数stage的祖先是target,返回true 
  3.   private def stageDependsOn(stage: Stage, target: Stage): Boolean = { 
  4.      
  5.     // 如果stage即为target,返回true 
  6.     if (stage == target) { 
  7.       return true 
  8.     } 
  9.      
  10.     // 存储处理过的RDD 
  11.     val visitedRdds = new HashSet[RDD[_]] 
  12.      
  13.     // We are manually maintaining a stack here to prevent StackOverflowError 
  14.     // caused by recursively visiting 
  15.     // 存储待处理的RDD 
  16.     val waitingForVisit = new Stack[RDD[_]] 
  17.      
  18.     // 定义一个visit()方法 
  19.     def visit(rdd: RDD[_]) { 
  20.       // 如果该RDD未被处理过的话,继续处理 
  21.       if (!visitedRdds(rdd)) { 
  22.         // 将RDD添加到visitedRdds中 
  23.         visitedRdds += rdd 
  24.          
  25.         // 遍历RDD的依赖 
  26.         for (dep <- rdd.dependencies) { 
  27.           dep match { 
  28.             // 如果是ShuffleDependency 
  29.             case shufDep: ShuffleDependency[_, _, _] => 
  30.              
  31.               // 获得mapStage,并且如果stage的isAvailable为false的话,将其压入waitingForVisit 
  32.               val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) 
  33.               if (!mapStage.isAvailable) { 
  34.                 waitingForVisit.push(mapStage.rdd) 
  35.               }  // Otherwise there's no need to follow the dependency back 
  36.             // 如果是NarrowDependency,直接将其压入waitingForVisit 
  37.             case narrowDep: NarrowDependency[_] => 
  38.               waitingForVisit.push(narrowDep.rdd) 
  39.           } 
  40.         } 
  41.       } 
  42.     } 
  43.      
  44.     // 从stage的rdd开始处理,将其入栈waitingForVisit 
  45.     waitingForVisit.push(stage.rdd) 
  46.      
  47.     // 当waitingForVisit中存在数据,就调用visit()方法进行处理 
  48.     while (waitingForVisit.nonEmpty) { 
  49.       visit(waitingForVisit.pop()) 
  50.     } 
  51.      
  52.     // 根据visitedRdds中是否存在target的rdd判断参数stage的祖先是否为target 
  53.     visitedRdds.contains(target.rdd) 
  54.   } 

        这个方法主要是判断参数stage是否为参数target的祖先stage,其代码风格与stage划分和提交中的部分代码一样,这在前面的两篇文章中也提到过,在此不再赘述。而它主要是通过stage的rdd,并遍历其上层依赖的rdd链,将每个stage的rdd加入到visitedRdds中,最后根据visitedRdds中是否存在target的rdd判断参数stage的祖先是否为target。值得一提的是,如果RDD的依赖是NarrowDependency,直接将其压入waitingForVisit,如果为ShuffleDependency,则需要判断stage的isAvailable,如果为false,则将对应RDD压入waitingForVisit。关于isAvailable,我在《Spark源码分析之四:Stage提交》一文中具体阐述过,这里不再赘述。

        接下来,我们再看下failJobAndIndependentStages()方法,这个方法的主要作用就是使得一个Job和仅被该Job使用的所有stages失败,并清空有关状态。代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ 
  2.   // 使得一个Job和仅被该Job使用的所有stages失败,并清空有关状态 
  3.   private def failJobAndIndependentStages( 
  4.       job: ActiveJob, 
  5.       failureReason: String, 
  6.       exception: Option[Throwable] = None): Unit = { 
  7.      
  8.     // 构造一个异常,内容为failureReason 
  9.     val error = new SparkException(failureReason, exception.getOrElse(null)) 
  10.      
  11.     // 标志位,是否能取消Stages 
  12.     var ableToCancelStages = true 
  13.  
  14.     // 标志位,是否应该中断线程 
  15.     val shouldInterruptThread = 
  16.       if (job.properties == nullfalse 
  17.       else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean 
  18.  
  19.     // Cancel all independent, running stages. 
  20.     // 取消所有独立的,正在运行的stages 
  21.      
  22.     // 根据Job的jobId,获取其stages 
  23.     val stages = jobIdToStageIds(job.jobId) 
  24.      
  25.     // 如果stages为空,记录错误日志 
  26.     if (stages.isEmpty) { 
  27.       logError("No stages registered for job " + job.jobId) 
  28.     } 
  29.      
  30.     // 遍历stages,循环处理 
  31.     stages.foreach { stageId => 
  32.        
  33.       // 根据stageId,获取jobsForStage,即每个Job所包含的Stage信息 
  34.       val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds) 
  35.        
  36.       // 首先处理异常情况,即jobsForStage为空,或者jobsForStage中不包含当前Job 
  37.       if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { 
  38.         logError( 
  39.           "Job %d not registered for stage %d even though that stage was registered for the job" 
  40.             .format(job.jobId, stageId)) 
  41.       } else if (jobsForStage.get.size == 1) { 
  42.         // 如果stageId对应的stage不存在 
  43.         if (!stageIdToStage.contains(stageId)) { 
  44.           logError(s"Missing Stage for stage with id $stageId"
  45.         } else { 
  46.           // This is the only job that uses this stage, so fail the stage if it is running. 
  47.           //  
  48.           val stage = stageIdToStage(stageId) 
  49.           if (runningStages.contains(stage)) { 
  50.             try { // cancelTasks will fail if a SchedulerBackend does not implement killTask 
  51.                
  52.               // 调用taskScheduler的cancelTasks()方法,取消stage内的tasks 
  53.               taskScheduler.cancelTasks(stageId, shouldInterruptThread) 
  54.                
  55.               // 标记Stage为完成 
  56.               markStageAsFinished(stage, Some(failureReason)) 
  57.             } catch { 
  58.               case e: UnsupportedOperationException => 
  59.                 logInfo(s"Could not cancel tasks for stage $stageId", e) 
  60.               ableToCancelStages = false 
  61.             } 
  62.           } 
  63.         } 
  64.       } 
  65.     } 
  66.  
  67.     if (ableToCancelStages) {// 如果能取消Stages 
  68.      
  69.       // 调用job监听器的jobFailed()方法 
  70.       job.listener.jobFailed(error) 
  71.        
  72.       // 为Job和独立Stages清空状态,独立Stages的意思为该stage仅为该Job使用 
  73.       cleanupStateForJobAndIndependentStages(job) 
  74.        
  75.       // 发送一个SparkListenerJobEnd事件 
  76.       listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) 
  77.     } 
  78.   } 

        处理过程还是很简单的,读者可以通过上述源码和注释自行补脑,这里就先略过了。

 

        下面,再说下正常情况下,即序列化后Task大小未超过上限时,LaunchTask事件的发送及executor端的响应。代码再跳转到CoarseGrainedSchedulerBackend的内部类DriverEndpoint中的launchTasks()方法。正常情况下处理流程主要分为三大部分:

        1、从executorDataMap中,根据task.executorId获取executor描述信息executorData;

        2、在executorData中,freeCores做相应减少;

        3、利用executorData中的executorEndpoint,即Driver端executor通讯端点的引用,发送LaunchTask事件,LaunchTask事件中包含序列化后的task,将Task传递到executor中去执行。

        我们重点看下第3步,利用Driver端持有的executor描述信息executorData中的executorEndpoint,即Driver端executor通讯端点的引用,发送LaunchTask事件给executor,将Task传递到executor中去执行。那么executor中是如何接收LaunchTask事件的呢?答案就在CoarseGrainedExecutorBackend中。

        我们先说下这个CoarseGrainedExecutorBackend,类的定义如下所示:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. private[spark] class CoarseGrainedExecutorBackend( 
  2.     override val rpcEnv: RpcEnv, 
  3.     driverUrl: String, 
  4.     executorId: String, 
  5.     hostPort: String, 
  6.     cores: Int, 
  7.     userClassPath: Seq[URL], 
  8.     env: SparkEnv) 
  9.   extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { 

        由上面的代码我们可以知道,它实现了ThreadSafeRpcEndpoint和ExecutorBackend两个trait,而ExecutorBackend的定义如下:

 

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /** 
  2.  * A pluggable interface used by the Executor to send updates to the cluster scheduler. 
  3.  * 一个被Executor用来发送更新到集群调度器的可插拔接口。 
  4.  */ 
  5. private[spark] trait ExecutorBackend { 
  6.    
  7.   // 唯一的一个statusUpdate()方法 
  8.   // 需要Long类型的taskId、TaskState类型的state、ByteBuffer类型的data三个参数 
  9.   def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) 

 

        那么它自然就有两种主要的任务,第一,作为endpoint提供driver与executor间的通讯功能;第二,提供了executor任务执行时状态汇报的功能。

        CoarseGrainedExecutorBackend到底是什么呢?这里我们先不深究,留到以后分析,你只要知道它是Executor的一个后台辅助进程,和Executor是一对一的关系,向Executor提供了与Driver通讯、任务执行时状态汇报两个基本功能即可。

        接下来,我们看下CoarseGrainedExecutorBackend是如何处理LaunchTask事件的。做为RpcEndpoint,在其处理各类事件或消息的receive()方法中,定义如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. case LaunchTask(data) => 
  2.       if (executor == null) { 
  3.         logError("Received LaunchTask command but executor was null"
  4.         System.exit(1
  5.       } else { 
  6.        
  7.         // 反序列话task,得到taskDesc 
  8.         val taskDesc = ser.deserialize[TaskDescription](data.value) 
  9.         logInfo("Got assigned task " + taskDesc.taskId) 
  10.          
  11.         // 调用executor的launchTask()方法加载task 
  12.         executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, 
  13.           taskDesc.name, taskDesc.serializedTask) 
  14.       } 

        首先,会判断对应的executor是否为空,为空的话,记录错误日志并退出,不为空的话,则按照如下流程处理:

 

        1、反序列话task,得到taskDesc;

        2、调用executor的launchTask()方法加载task。

        那么,重点就落在了Executor的launchTask()方法中,代码如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. def launchTask( 
  2.       context: ExecutorBackend, 
  3.       taskId: Long, 
  4.       attemptNumber: Int, 
  5.       taskName: String, 
  6.       serializedTask: ByteBuffer): Unit = { 
  7.        
  8.     // 新建一个TaskRunner 
  9.     val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, 
  10.       serializedTask) 
  11.        
  12.     // 将taskId与TaskRunner的对应关系存入runningTasks 
  13.     runningTasks.put(taskId, tr) 
  14.      
  15.     // 线程池执行TaskRunner 
  16.     threadPool.execute(tr) 
  17.   } 

        非常简单,创建一个TaskRunner对象,然后将taskId与TaskRunner的对应关系存入runningTasks,将TaskRunner扔到线程池中去执行即可。

 

        我们先看下这个TaskRunner类。我们先看下Class及其成员变量的定义,如下:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. class TaskRunner( 
  2.       execBackend: ExecutorBackend, 
  3.       val taskId: Long, 
  4.       val attemptNumber: Int, 
  5.       taskName: String, 
  6.       serializedTask: ByteBuffer) 
  7.     extends Runnable { 
  8.      
  9.     // TaskRunner继承了Runnable 
  10.  
  11.     /** Whether this task has been killed. */ 
  12.     // 标志位,task是否被杀掉 
  13.     @volatile private var killed = false 
  14.  
  15.     /** How much the JVM process has spent in GC when the task starts to run. */ 
  16.     @volatile var startGCTime: Long = _ 
  17.  
  18.     /** 
  19.      * The task to run. This will be set in run() by deserializing the task binary coming 
  20.      * from the driver. Once it is set, it will never be changed. 
  21.      *  
  22.      * 需要运行的task。它将在反序列化来自driver的task二进制数据时在run()方法被设置,一旦被设置,它将不会再发生改变。 
  23.      */ 
  24.     @volatile var task: Task[Any] = _ 

        由类的定义我们可以看出,TaskRunner继承了Runnable,所以它本质上是一个线程,故其可以被放到线程池中去运行。它所包含的成员变量,主要有以下几个:

 

        1、execBackend:Executor后台辅助进程,提供了与Driver通讯、状态汇报等两大基本功能,实际上传入的是CoarseGrainedExecutorBackend实例;

        2、taskId:Task的唯一标识;

        3、attemptNumber:Task运行的序列号,Spark与MapReduce一样,可以为拖后腿任务启动备份任务,即推测执行原理,如此,就需要通过taskId加attemptNumber来唯一标识一个Task运行实例;

        4、serializedTask:ByteBuffer类型,序列化后的Task,包含的是Task的内容,通过发序列化它来得到Task,并运行其中的run()方法来执行Task;

        5、killed:Task是否被杀死的标志位;

        6、task:Task[Any]类型,需要运行的Task,它将在反序列化来自driver的task二进制数据时在run()方法被设置,一旦被设置,它将不会再发生改变;

       7、startGCTime:JVM在task开始运行后,进行垃圾回收的时间。

        另外,既然是一个线程,TaskRunner必须得提供run()方法,该run()方法就是TaskRunner线程在线程池中被调度时,需要执行的方法,我们来看下它的定义:

 

[java] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. override def run(): Unit = { 
  2.      
  3.       // Step1:Task及其运行时需要的辅助对象构造 
  4.        
  5.       // 获取任务内存管理器 
  6.       val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) 
  7.        
  8.       // 反序列化开始时间 
  9.       val deserializeStartTime = System.currentTimeMillis() 
  10.        
  11.       // 当前线程设置上下文类加载器 
  12.       Thread.currentThread.setContextClassLoader(replClassLoader) 
  13.        
  14.       // 从SparkEnv中获取序列化器 
  15.       val ser = env.closureSerializer.newInstance() 
  16.       logInfo(s"Running $taskName (TID $taskId)"
  17.        
  18.       // execBackend更新状态TaskState.RUNNING 
  19.       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) 
  20.       var taskStart: Long = 0 
  21.        
  22.       // 计算开始垃圾回收的时间 
  23.       startGCTime = computeTotalGcTime() 
  24.  
  25.       try { 
  26.         // 调用Task的deserializeWithDependencies()方法,反序列化Task,得到Task运行需要的文件taskFiles、jar包taskFiles和Task二进制数据taskBytes 
  27.         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) 
  28.         updateDependencies(taskFiles, taskJars) 
  29.          
  30.         // 反序列化Task二进制数据taskBytes,得到task实例 
  31.         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) 
  32.          
  33.         // 设置Task的任务内存管理器 
  34.         task.setTaskMemoryManager(taskMemoryManager) 
  35.  
  36.         // If this task has been killed before we deserialized it, let's quit now. Otherwise, 
  37.         // continue executing the task. 
  38.         // 如果此时Task被kill,抛出异常,快速退出 
  39.         if (killed) { 
  40.           // Throw an exception rather than returning, because returning within a try{} block 
  41.           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl 
  42.           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure 
  43.           // for the task. 
  44.           throw new TaskKilledException 
  45.         } 
  46.  
  47.         logDebug("Task " + taskId + "'s epoch is " + task.epoch) 
  48.         // mapOutputTracker更新Epoch 
  49.         env.mapOutputTracker.updateEpoch(task.epoch) 
  50.  
  51.         // Run the actual task and measure its runtime. 
  52.         // 运行真正的task,并度量它的运行时间 
  53.          
  54.         // Step2:Task运行 
  55.          
  56.         // task开始时间 
  57.         taskStart = System.currentTimeMillis() 
  58.          
  59.         // 标志位threwException设置为true,标识Task真正执行过程中是否抛出异常 
  60.         var threwException = true 
  61.          
  62.         // 调用Task的run()方法,真正执行Task,并获得运行结果value 
  63.         val (value, accumUpdates) = try { 
  64.          
  65.           // 调用Task的run()方法,真正执行Task 
  66.           val res = task.run( 
  67.             taskAttemptId = taskId, 
  68.             attemptNumber = attemptNumber, 
  69.             metricsSystem = env.metricsSystem) 
  70.            
  71.           // 标志位threwException设置为false 
  72.           threwException = false 
  73.            
  74.           // 返回res,Task的run()方法中,res的定义为(T, AccumulatorUpdates) 
  75.           // 这里,前者为任务运行结果,后者为累加器更新 
  76.           res 
  77.         } finally { 
  78.            
  79.           // 通过任务内存管理器清理所有的分配的内存 
  80.           val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() 
  81.           if (freedMemory > 0) { 
  82.             val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" 
  83.             if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak"false) && !threwException) { 
  84.               throw new SparkException(errMsg) 
  85.             } else { 
  86.               logError(errMsg) 
  87.             } 
  88.           } 
  89.         } 
  90.          
  91.         // task完成时间 
  92.         val taskFinish = System.currentTimeMillis() 
  93.  
  94.         // If the task has been killed, let's fail it. 
  95.         // 如果task被杀死,抛出TaskKilledException异常 
  96.         if (task.killed) { 
  97.           throw new TaskKilledException 
  98.         } 
  99.  
  100.         // Step3:Task运行结果处理 
  101.          
  102.         // 通过Spark获取Task运行结果序列化器 
  103.         val resultSer = env.serializer.newInstance() 
  104.          
  105.         // 结果序列化前的时间点 
  106.         val beforeSerialization = System.currentTimeMillis() 
  107.          
  108.         // 利用Task运行结果序列化器序列化Task运行结果,得到valueBytes 
  109.         val valueBytes = resultSer.serialize(value) 
  110.          
  111.         // 结果序列化后的时间点 
  112.         val afterSerialization = System.currentTimeMillis() 
  113.  
  114.         // 度量指标体系相关,暂不介绍 
  115.         for (m <- task.metrics) { 
  116.           // Deserialization happens in two parts: first, we deserialize a Task object, which 
  117.           // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. 
  118.           m.setExecutorDeserializeTime( 
  119.             (taskStart - deserializeStartTime) + task.executorDeserializeTime) 
  120.           // We need to subtract Task.run()'s deserialization time to avoid double-counting 
  121.           m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) 
  122.           m.setJvmGCTime(computeTotalGcTime() - startGCTime) 
  123.           m.setResultSerializationTime(afterSerialization - beforeSerialization) 
  124.           m.updateAccumulators() 
  125.         } 
  126.  
  127.         // 构造DirectTaskResult,同时包含Task运行结果valueBytes和累加器更新值accumulator updates 
  128.         val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) 
  129.          
  130.         // 序列化DirectTaskResult,得到serializedDirectResult 
  131.         val serializedDirectResult = ser.serialize(directResult) 
  132.          
  133.         // 获取Task运行结果大小 
  134.         val resultSize = serializedDirectResult.limit 
  135.  
  136.         // directSend = sending directly back to the driver 
  137.         // directSend的意思就是直接发送结果至Driver端 
  138.         val serializedResult: ByteBuffer = { 
  139.          
  140.           // 如果Task运行结果大小大于所有Task运行结果的最大大小,序列化IndirectTaskResult 
  141.           // IndirectTaskResult为存储在Worker上BlockManager中DirectTaskResult的一个引用 
  142.           if (maxResultSize > 0 && resultSize > maxResultSize) { 
  143.             logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + 
  144.               s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + 
  145.               s"dropping it."
  146.             ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) 
  147.           } 
  148.           // 如果 Task运行结果大小超过Akka除去需要保留的字节外最大大小,则将结果写入BlockManager 
  149.           // 即运行结果无法通过消息传递 
  150.           else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { 
  151.              
  152.             val blockId = TaskResultBlockId(taskId) 
  153.             env.blockManager.putBytes( 
  154.               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) 
  155.             logInfo( 
  156.               s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)"
  157.             ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) 
  158.           }   
  159.           // Task运行结果比较小的话,直接返回,通过消息传递 
  160.           else { 
  161.             logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver"
  162.             serializedDirectResult 
  163.           } 
  164.         } 
  165.  
  166.         // execBackend更新状态TaskState.FINISHED 
  167.         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) 
  168.  
  169.       } catch {// 处理各种异常信息 
  170.          
  171.         case ffe: FetchFailedException => 
  172.           val reason = ffe.toTaskEndReason 
  173.           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) 
  174.  
  175.         case _: TaskKilledException | _: InterruptedException if task.killed => 
  176.           logInfo(s"Executor killed $taskName (TID $taskId)"
  177.           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) 
  178.  
  179.         case cDE: CommitDeniedException => 
  180.           val reason = cDE.toTaskEndReason 
  181.           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) 
  182.  
  183.         case t: Throwable => 
  184.           // Attempt to exit cleanly by informing the driver of our failure. 
  185.           // If anything goes wrong (or this was a fatal exception), we will delegate to 
  186.           // the default uncaught exception handler, which will terminate the Executor. 
  187.           logError(s"Exception in $taskName (TID $taskId)", t) 
  188.  
  189.           val metrics: Option[TaskMetrics] = Option(task).flatMap { task => 
  190.             task.metrics.map { m => 
  191.               m.setExecutorRunTime(System.currentTimeMillis() - taskStart) 
  192.               m.setJvmGCTime(computeTotalGcTime() - startGCTime) 
  193.               m.updateAccumulators() 
  194.               m 
  195.             } 
  196.           } 
  197.           val serializedTaskEndReason = { 
  198.             try { 
  199.               ser.serialize(new ExceptionFailure(t, metrics)) 
  200.             } catch { 
  201.               case _: NotSerializableException => 
  202.                 // t is not serializable so just send the stacktrace 
  203.                 ser.serialize(new ExceptionFailure(t, metrics, false)) 
  204.             } 
  205.           } 
  206.            
  207.           // execBackend更新状态TaskState.FAILED 
  208.           execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) 
  209.  
  210.           // Don't forcibly exit unless the exception was inherently fatal, to avoid 
  211.           // stopping other tasks unnecessarily. 
  212.           if (Utils.isFatalError(t)) { 
  213.             SparkUncaughtExceptionHandler.uncaughtException(t) 
  214.           } 
  215.  
  216.       } finally { 
  217.        
  218.         // 最后,无论运行成功还是失败,将task从runningTasks中移除 
  219.         runningTasks.remove(taskId) 
  220.       } 
  221.     } 

        如此长的一个方法,好长好大,哈哈!不过,纵观全篇,无非三个Step就可搞定:

 

        1、Step1:Task及其运行时需要的辅助对象构造;

        2、Step2:Task运行;

        3、Step3:Task运行结果处理。

        对, 就这么简单!鉴于时间与篇幅问题,我们这里先讲下主要流程,细节方面的东西留待下节继续。

        下面,我们一个个Step来看,首先看下Step1:Task及其运行时需要的辅助对象构造,主要包括以下步骤:

        1.1、构造TaskMemoryManager任务内存管理器,即taskMemoryManager;

        1.2、记录反序列化开始时间;

        1.3、当前线程设置上下文类加载器;

        1.4、从SparkEnv中获取序列化器ser;

        1.5、execBackend更新状态TaskState.RUNNING;

        1.6、计算开始垃圾回收的时间;

        1.7、调用Task的deserializeWithDependencies()方法,反序列化Task,得到Task运行需要的文件taskFiles、jar包taskFiles和Task二进制数据taskBytes;

        1.8、反序列化Task二进制数据taskBytes,得到task实例;

        1.9、设置Task的任务内存管理器;

        1.10、如果此时Task被kill,抛出异常,快速退出;

       

        接下来,是Step2:Task运行,主要流程如下:

        2.1、获取task开始时间;

        2.2、标志位threwException设置为true,标识Task真正执行过程中是否抛出异常;

        2.3、调用Task的run()方法,真正执行Task,并获得运行结果value,和累加器更新accumUpdates;

        2.4、标志位threwException设置为false;

        2.5、通过任务内存管理器taskMemoryManager清理所有的分配的内存;

        2.6、获取task完成时间;

        2.7、如果task被杀死,抛出TaskKilledException异常。

       

        最后一步,Step3:Task运行结果处理,大体流程如下:

        3.1、通过SparkEnv获取Task运行结果序列化器;

        3.2、获取结果序列化前的时间点;

        3.3、利用Task运行结果序列化器序列化Task运行结果value,得到valueBytes;

        3.4、获取结果序列化后的时间点;

        3.5、度量指标体系相关,暂不介绍;

        3.6、构造DirectTaskResult,同时包含Task运行结果valueBytes和累加器更新值accumulator updates;

        3.7、序列化DirectTaskResult,得到serializedDirectResult;

        3.8、获取Task运行结果大小;

        3.9、处理Task运行结果:

                 3.9.1、如果Task运行结果大小大于所有Task运行结果的最大大小,序列化IndirectTaskResult,IndirectTaskResult为存储在Worker上BlockManager中DirectTaskResult的一个引用;

                 3.9.2、如果 Task运行结果大小超过Akka除去需要保留的字节外最大大小,则将结果写入BlockManager,Task运行结果比较小的话,直接返回,通过消息传递;

                 3.9.3、Task运行结果比较小的话,直接返回,通过消息传递

        3.10、execBackend更新状态TaskState.FINISHED;

        最后,无论运行成功还是失败,将task从runningTasks中移除。

        至此,Task的运行主体流程已经介绍完毕,剩余的部分细节,包括Task内run()方法的具体执行,还有任务内存管理器、序列化器、累加更新,还有部分异常情况处理,状态汇报等等其他更为详细的内容留到下篇再讲吧!

        明天还要工作,洗洗睡了!





收藏 推荐 打印 | 录入:elainebo | 阅读:
本文评论   查看全部评论 (0)
表情: 表情 姓名: 字数
点评:
       
评论声明
  • 尊重网上道德,遵守中华人民共和国的各项有关法律法规
  • 承担一切因您的行为而直接或间接导致的民事或刑事法律责任
  • 本站管理人员有权保留或删除其管辖留言中的任意内容
  • 本站有权在网站内转载或引用您的评论
  • 参与本评论即表明您已经阅读并接受上述条款