Browse Source

[DS-6849][MasterServer] fetch more commands and handle in parallel (#6850)

* [DS-6849][MasterServer] fetch more commands and handle in parallel

* add return

* handle command with check

* add test

* delete master prefix

Co-authored-by: caishunfeng <534328519@qq.com>
wind 3 years ago
parent
commit
595e4843d0

+ 18 - 0
dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/master/config/MasterConfig.java

@@ -28,6 +28,8 @@ import org.springframework.stereotype.Component;
 @ConfigurationProperties("master")
 public class MasterConfig {
     private int listenPort;
+    private int fetchCommandNum;
+    private int preExecThreads;
     private int execThreads;
     private int execTaskNum;
     private int dispatchTaskNumber;
@@ -48,6 +50,22 @@ public class MasterConfig {
         this.listenPort = listenPort;
     }
 
+    /**
+     * Returns the fetch-command-num setting of the "master" configuration properties.
+     * MasterSchedulerService passes it as the page size when querying commands.
+     *
+     * @return configured fetch command number
+     */
+    public int getFetchCommandNum() {
+        return fetchCommandNum;
+    }
+
+    /**
+     * Sets the fetch-command-num setting.
+     *
+     * @param fetchCommandNum value to store
+     */
+    public void setFetchCommandNum(int fetchCommandNum) {
+        this.fetchCommandNum = fetchCommandNum;
+    }
+
+    /**
+     * Returns the pre-exec-threads setting of the "master" configuration properties.
+     * MasterSchedulerService passes it to ThreadUtils.newDaemonFixedThreadExecutor when it
+     * creates the "Master-Pre-Exec-Thread" executor.
+     *
+     * @return configured prepare-exec thread number
+     */
+    public int getPreExecThreads() {
+        return preExecThreads;
+    }
+
+    /**
+     * Sets the pre-exec-threads setting.
+     *
+     * @param preExecThreads value to store
+     */
+    public void setPreExecThreads(int preExecThreads) {
+        this.preExecThreads = preExecThreads;
+    }
+
     public int getExecThreads() {
         return execThreads;
     }

+ 105 - 50
dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/master/runner/MasterSchedulerService.java

@@ -36,9 +36,14 @@ import org.apache.dolphinscheduler.server.master.registry.ServerNodeManager;
 import org.apache.dolphinscheduler.service.alert.ProcessAlertManager;
 import org.apache.dolphinscheduler.service.process.ProcessService;
 
+import org.apache.commons.collections4.CollectionUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 
@@ -90,6 +95,11 @@ public class MasterSchedulerService extends Thread {
     @Autowired
     NettyExecutorManager nettyExecutorManager;
 
+    /**
+     * master prepare exec service
+     */
+    private ThreadPoolExecutor masterPrepareExecService;
+
     /**
      * master exec service
      */
@@ -120,6 +130,7 @@ public class MasterSchedulerService extends Thread {
      * constructor of MasterSchedulerService
      */
     public void init() {
+        this.masterPrepareExecService = (ThreadPoolExecutor) ThreadUtils.newDaemonFixedThreadExecutor("Master-Pre-Exec-Thread", masterConfig.getPreExecThreads());
         this.masterExecService = (ThreadPoolExecutor) ThreadUtils.newDaemonFixedThreadExecutor("Master-Exec-Thread", masterConfig.getExecThreads());
         NettyClientConfig clientConfig = new NettyClientConfig();
         this.nettyRemotingClient = new NettyRemotingClient(clientConfig);
@@ -175,74 +186,110 @@ public class MasterSchedulerService extends Thread {
     /**
      * 1. get command by slot
      * 2. do not handle command if slot is empty
-     *
-     * @throws Exception
      */
     private void scheduleProcess() throws Exception {
+        // fetch the commands that pass this master's slot check
+        List<Command> commands = findCommands();
+        if (CollectionUtils.isEmpty(commands)) {
+            // no command found, sleep for 1s
+            Thread.sleep(Constants.SLEEP_TIME_MILLIS);
+            return;
+        }
 
-        // make sure to scan and delete command  table in one transaction
-        Command command = findOneCommand();
-        if (command != null) {
-            logger.info("find one command: id: {}, type: {}", command.getId(), command.getCommandType());
-            try {
-                ProcessInstance processInstance = processService.handleCommand(logger,
-                        getLocalAddress(),
-                        command,
-                        processDefinitionCacheMaps);
-                if (!masterConfig.isCacheProcessDefinition()
-                        && processDefinitionCacheMaps.size() > 0) {
-                    processDefinitionCacheMaps.clear();
+        // when process-definition caching is disabled, empty the cache before handling this batch
+        if (!masterConfig.isCacheProcessDefinition() && processDefinitionCacheMaps.size() > 0) {
+            processDefinitionCacheMaps.clear();
+        }
+
+        // handle the commands in parallel; the result may contain null entries
+        List<ProcessInstance> processInstances = command2ProcessInstance(commands);
+        if (CollectionUtils.isEmpty(processInstances)) {
+            return;
+        }
+
+        for (ProcessInstance processInstance : processInstances) {
+            // null marks a command that produced no process instance
+            if (processInstance == null) {
+                continue;
+            }
+
+            WorkflowExecuteThread workflowExecuteThread = new WorkflowExecuteThread(
+                    processInstance
+                    , processService
+                    , nettyExecutorManager
+                    , processAlertManager
+                    , masterConfig
+                    , taskTimeoutCheckList);
+
+            // cache the workflow thread by process instance id before submitting it
+            this.processInstanceExecCacheManager.cache(processInstance.getId(), workflowExecuteThread);
+            // instances with a positive timeout are also recorded in processTimeoutCheckList
+            if (processInstance.getTimeout() > 0) {
+                this.processTimeoutCheckList.put(processInstance.getId(), processInstance);
+            }
+            masterExecService.execute(workflowExecuteThread);
+        }
+    }
+
+    /**
+     * Converts commands into process instances in parallel on masterPrepareExecService and
+     * waits until every command has been processed.
+     *
+     * @param commands commands to handle
+     * @return a fixed-size list aligned with {@code commands} by index; an entry is null when the
+     *         command failed the slot check, handleCommand returned null, or handling threw.
+     *         Empty when {@code commands} is null or empty.
+     */
+    private List<ProcessInstance> command2ProcessInstance(List<Command> commands) {
+        if (CollectionUtils.isEmpty(commands)) {
+            return new ArrayList<>();
+        }
+
+        ProcessInstance[] processInstances = new ProcessInstance[commands.size()];
+        CountDownLatch latch = new CountDownLatch(commands.size());
+        // NOTE(review): processDefinitionCacheMaps is a HashMap (see handleCommand's signature) that is
+        // handed to concurrently running tasks - confirm handleCommand never mutates it, or guard it
+        for (int i = 0; i < commands.size(); i++) {
+            int index = i;
+            this.masterPrepareExecService.execute(() -> {
+                Command command = commands.get(index);
+                // slot check again; if the check fails or throws, release the latch here,
+                // otherwise latch.await() below would never return
+                boolean inSlot = false;
+                try {
+                    inSlot = slotCheck(command);
+                } finally {
+                    if (!inSlot) {
+                        latch.countDown();
+                    }
+                }
+                if (!inSlot) {
+                    return;
                 }
-                if (processInstance != null) {
-                    WorkflowExecuteThread workflowExecuteThread = new WorkflowExecuteThread(
-                            processInstance
-                            , processService
-                            , nettyExecutorManager
-                            , processAlertManager
-                            , masterConfig
-                            , taskTimeoutCheckList);
-
-                    this.processInstanceExecCacheManager.cache(processInstance.getId(), workflowExecuteThread);
-                    if (processInstance.getTimeout() > 0) {
-                        this.processTimeoutCheckList.put(processInstance.getId(), processInstance);
+
+                try {
+                    ProcessInstance processInstance = processService.handleCommand(logger,
+                            getLocalAddress(),
+                            command,
+                            processDefinitionCacheMaps);
+                    if (processInstance != null) {
+                        // each task writes only its own index, so tasks never share a slot
+                        processInstances[index] = processInstance;
+                        logger.info("handle command command {} end, create process instance {}",
+                                command.getId(), processInstance.getId());
                     }
-                    logger.info("handle command end, command {} process {} start...",
-                            command.getId(), processInstance.getId());
-                    masterExecService.execute(workflowExecuteThread);
+                } catch (Exception e) {
+                    logger.error("scan command error ", e);
+                    processService.moveToErrorCommand(command, e.toString());
+                } finally {
+                    latch.countDown();
                 }
-            } catch (Exception e) {
-                logger.error("scan command error ", e);
-                processService.moveToErrorCommand(command, e.toString());
-            }
-        } else {
-            //indicate that no command ,sleep for 1s
-            Thread.sleep(Constants.SLEEP_TIME_MILLIS);
+            });
+        }
+
+        try {
+            // make sure to finish handling command each time before next scan
+            latch.await();
+        } catch (InterruptedException e) {
+            logger.error("countDownLatch await error ", e);
+            // restore the interrupt flag so the calling code can still observe the interruption
+            Thread.currentThread().interrupt();
         }
+
+        return Arrays.asList(processInstances);
     }
 
-    private Command findOneCommand() {
+    private List<Command> findCommands() {
         int pageNumber = 0;
-        Command result = null;
+        int pageSize = masterConfig.getFetchCommandNum();
+        List<Command> result = new ArrayList<>();
         while (Stopper.isRunning()) {
             if (ServerNodeManager.MASTER_SIZE == 0) {
-                return null;
+                return result;
             }
-            List<Command> commandList = processService.findCommandPage(ServerNodeManager.MASTER_SIZE, pageNumber);
+            List<Command> commandList = processService.findCommandPage(pageSize, pageNumber);
             if (commandList.size() == 0) {
-                return null;
+                return result;
             }
             for (Command command : commandList) {
-                int slot = ServerNodeManager.getSlot();
-                if (ServerNodeManager.MASTER_SIZE != 0
-                        && command.getId() % ServerNodeManager.MASTER_SIZE == slot) {
-                    result = command;
-                    break;
+                if (slotCheck(command)) {
+                    result.add(command);
                 }
             }
-            if (result != null) {
-                logger.info("find command {}, slot:{} :",
-                        result.getId(),
-                        ServerNodeManager.getSlot());
+            if (CollectionUtils.isNotEmpty(result)) {
+                logger.info("find {} commands, slot:{}", result.size(), ServerNodeManager.getSlot());
                 break;
             }
             pageNumber += 1;
@@ -250,6 +297,14 @@ public class MasterSchedulerService extends Thread {
         return result;
     }
 
+    /** Returns true when the command id modulo the master count equals this master's slot. */
+    private boolean slotCheck(Command command) {
+        // read MASTER_SIZE once so the zero check and the modulo use the same value
+        int masterSize = ServerNodeManager.MASTER_SIZE;
+        int slot = ServerNodeManager.getSlot();
+        return masterSize != 0 && command.getId() % masterSize == slot;
+    }
+
     private String getLocalAddress() {
         return NetUtils.getAddr(masterConfig.getListenPort());
     }

+ 4 - 0
dolphinscheduler-server/src/main/resources/application-master.yaml

@@ -20,6 +20,10 @@ spring:
 
 master:
   listen-port: 5678
+  # number of commands the master requests per command-page query
+  fetch-command-num: 10
+  # number of master pre-execute threads, which limits how many commands are handled in parallel
+  pre-exec-threads: 10
   # master execute thread number to limit process instances in parallel
   exec-threads: 100
   # master execute task number in parallel per process instance

+ 106 - 108
dolphinscheduler-service/src/main/java/org/apache/dolphinscheduler/service/process/ProcessService.java

@@ -37,7 +37,6 @@ import org.apache.dolphinscheduler.common.enums.ExecutionStatus;
 import org.apache.dolphinscheduler.common.enums.FailureStrategy;
 import org.apache.dolphinscheduler.common.enums.Flag;
 import org.apache.dolphinscheduler.common.enums.ReleaseState;
-import org.apache.dolphinscheduler.spi.enums.ResourceType;
 import org.apache.dolphinscheduler.common.enums.TaskDependType;
 import org.apache.dolphinscheduler.common.enums.TimeoutFlag;
 import org.apache.dolphinscheduler.common.enums.WarningType;
@@ -102,8 +101,10 @@ import org.apache.dolphinscheduler.dao.utils.DagHelper;
 import org.apache.dolphinscheduler.remote.command.StateEventChangeCommand;
 import org.apache.dolphinscheduler.remote.processor.StateEventCallbackService;
 import org.apache.dolphinscheduler.remote.utils.Host;
+import org.apache.dolphinscheduler.service.exceptions.ServiceException;
 import org.apache.dolphinscheduler.service.log.LogClientService;
 import org.apache.dolphinscheduler.service.quartz.cron.CronUtils;
+import org.apache.dolphinscheduler.spi.enums.ResourceType;
 
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.lang.StringUtils;
@@ -140,10 +141,10 @@ public class ProcessService {
     private final Logger logger = LoggerFactory.getLogger(getClass());
 
     private final int[] stateArray = new int[]{ExecutionStatus.SUBMITTED_SUCCESS.ordinal(),
-        ExecutionStatus.RUNNING_EXECUTION.ordinal(),
-        ExecutionStatus.DELAY_EXECUTION.ordinal(),
-        ExecutionStatus.READY_PAUSE.ordinal(),
-        ExecutionStatus.READY_STOP.ordinal()};
+            ExecutionStatus.RUNNING_EXECUTION.ordinal(),
+            ExecutionStatus.DELAY_EXECUTION.ordinal(),
+            ExecutionStatus.READY_PAUSE.ordinal(),
+            ExecutionStatus.READY_STOP.ordinal()};
 
     @Autowired
     private UserMapper userMapper;
@@ -215,9 +216,9 @@ public class ProcessService {
      * @param logger logger
      * @param host host
      * @param command found command
-     * @param processDefinitionCacheMaps
      * @return process instance
      */
+    @Transactional
     public ProcessInstance handleCommand(Logger logger, String host, Command command, HashMap<String, ProcessDefinition> processDefinitionCacheMaps) {
         ProcessInstance processInstance = constructProcessInstance(command, host, processDefinitionCacheMaps);
         // cannot construct process instance, return null
@@ -231,21 +232,21 @@ public class ProcessService {
         // if the process definition is serial
         ProcessDefinition processDefinition = this.findProcessDefinition(processInstance.getProcessDefinitionCode(), processInstance.getProcessDefinitionVersion());
         if (processDefinition.getExecutionType().typeIsSerial()) {
-            saveSerialProcess(processInstance,processDefinition);
+            saveSerialProcess(processInstance, processDefinition);
             if (processInstance.getState() != ExecutionStatus.SUBMITTED_SUCCESS) {
-                this.setSubProcessParam(processInstance);
-                this.commandMapper.deleteById(command.getId());
+                setSubProcessParam(processInstance);
+                deleteCommandWithCheck(command.getId());
                 return null;
             }
         } else {
             saveProcessInstance(processInstance);
         }
-        this.setSubProcessParam(processInstance);
-        this.commandMapper.deleteById(command.getId());
+        setSubProcessParam(processInstance);
+        deleteCommandWithCheck(command.getId());
         return processInstance;
     }
 
-    private void saveSerialProcess(ProcessInstance processInstance,ProcessDefinition processDefinition) {
+    private void saveSerialProcess(ProcessInstance processInstance, ProcessDefinition processDefinition) {
         processInstance.setState(ExecutionStatus.SERIAL_WAIT);
         saveProcessInstance(processInstance);
         //serial wait
@@ -253,7 +254,7 @@ public class ProcessService {
         if (processDefinition.getExecutionType().typeIsSerialWait()) {
             while (true) {
                 List<ProcessInstance> runningProcessInstances = this.processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(processInstance.getProcessDefinitionCode(),
-                        Constants.RUNNING_PROCESS_STATE,processInstance.getId());
+                        Constants.RUNNING_PROCESS_STATE, processInstance.getId());
                 if (CollectionUtils.isEmpty(runningProcessInstances)) {
                     processInstance.setState(ExecutionStatus.SUBMITTED_SUCCESS);
                     saveProcessInstance(processInstance);
@@ -266,14 +267,14 @@ public class ProcessService {
             }
         } else if (processDefinition.getExecutionType().typeIsSerialDiscard()) {
             List<ProcessInstance> runningProcessInstances = this.processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(processInstance.getProcessDefinitionCode(),
-                    Constants.RUNNING_PROCESS_STATE,processInstance.getId());
+                    Constants.RUNNING_PROCESS_STATE, processInstance.getId());
             if (CollectionUtils.isEmpty(runningProcessInstances)) {
                 processInstance.setState(ExecutionStatus.STOP);
                 saveProcessInstance(processInstance);
             }
         } else if (processDefinition.getExecutionType().typeIsSerialPriority()) {
             List<ProcessInstance> runningProcessInstances = this.processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(processInstance.getProcessDefinitionCode(),
-                    Constants.RUNNING_PROCESS_STATE,processInstance.getId());
+                    Constants.RUNNING_PROCESS_STATE, processInstance.getId());
             if (CollectionUtils.isNotEmpty(runningProcessInstances)) {
                 for (ProcessInstance info : runningProcessInstances) {
                     info.setCommandType(CommandType.STOP);
@@ -345,10 +346,6 @@ public class ProcessService {
 
     /**
      * get command page
-     *
-     * @param pageSize
-     * @param pageNumber
-     * @return
      */
     public List<Command> findCommandPage(int pageSize, int pageNumber) {
         return commandMapper.queryCommandPage(pageSize, pageNumber * pageSize);
@@ -569,21 +566,21 @@ public class ProcessService {
         // process instance quit by "waiting thread" state
         if (originCommand == null) {
             Command command = new Command(
-                CommandType.RECOVER_WAITING_THREAD,
-                processInstance.getTaskDependType(),
-                processInstance.getFailureStrategy(),
-                processInstance.getExecutorId(),
-                processInstance.getProcessDefinition().getCode(),
-                JSONUtils.toJsonString(cmdParam),
-                processInstance.getWarningType(),
-                processInstance.getWarningGroupId(),
-                processInstance.getScheduleTime(),
-                processInstance.getWorkerGroup(),
-                processInstance.getEnvironmentCode(),
-                processInstance.getProcessInstancePriority(),
-                processInstance.getDryRun(),
-                processInstance.getId(),
-                processInstance.getProcessDefinitionVersion()
+                    CommandType.RECOVER_WAITING_THREAD,
+                    processInstance.getTaskDependType(),
+                    processInstance.getFailureStrategy(),
+                    processInstance.getExecutorId(),
+                    processInstance.getProcessDefinition().getCode(),
+                    JSONUtils.toJsonString(cmdParam),
+                    processInstance.getWarningType(),
+                    processInstance.getWarningGroupId(),
+                    processInstance.getScheduleTime(),
+                    processInstance.getWorkerGroup(),
+                    processInstance.getEnvironmentCode(),
+                    processInstance.getProcessInstancePriority(),
+                    processInstance.getDryRun(),
+                    processInstance.getId(),
+                    processInstance.getProcessDefinitionVersion()
             );
             saveCommand(command);
             return;
@@ -675,10 +672,10 @@ public class ProcessService {
 
         // curing global params
         processInstance.setGlobalParams(ParameterUtils.curingGlobalParams(
-            processDefinition.getGlobalParamMap(),
-            processDefinition.getGlobalParamList(),
-            getCommandTypeIfComplement(processInstance, command),
-            processInstance.getScheduleTime()));
+                processDefinition.getGlobalParamMap(),
+                processDefinition.getGlobalParamList(),
+                getCommandTypeIfComplement(processInstance, command),
+                processInstance.getScheduleTime()));
 
         // set process instance priority
         processInstance.setProcessInstancePriority(command.getProcessInstancePriority());
@@ -705,7 +702,7 @@ public class ProcessService {
         startParamMap.putAll(fatherParamMap);
         // set start param into global params
         if (startParamMap.size() > 0
-            && processDefinition.getGlobalParamMap() != null) {
+                && processDefinition.getGlobalParamMap() != null) {
             for (Map.Entry<String, String> param : processDefinition.getGlobalParamMap().entrySet()) {
                 String val = startParamMap.get(param.getKey());
                 if (val != null) {
@@ -767,8 +764,8 @@ public class ProcessService {
     private Boolean checkCmdParam(Command command, Map<String, String> cmdParam) {
         if (command.getTaskDependType() == TaskDependType.TASK_ONLY || command.getTaskDependType() == TaskDependType.TASK_PRE) {
             if (cmdParam == null
-                || !cmdParam.containsKey(Constants.CMD_PARAM_START_NODES)
-                || cmdParam.get(Constants.CMD_PARAM_START_NODES).isEmpty()) {
+                    || !cmdParam.containsKey(Constants.CMD_PARAM_START_NODES)
+                    || cmdParam.get(Constants.CMD_PARAM_START_NODES).isEmpty()) {
                 logger.error("command node depend type is {}, but start nodes is null ", command.getTaskDependType());
                 return false;
             }
@@ -779,9 +776,8 @@ public class ProcessService {
     /**
      * construct process instance according to one command.
      *
-     * @param command                    command
-     * @param host                       host
-     * @param processDefinitionCacheMaps
+     * @param command command
+     * @param host host
      * @return process instance
      */
     private ProcessInstance constructProcessInstance(Command command, String host, HashMap<String, ProcessDefinition> processDefinitionCacheMaps) {
@@ -954,7 +950,7 @@ public class ProcessService {
                 }
 
                 return processDefineLogMapper.queryByDefinitionCodeAndVersion(
-                    processInstance.getProcessDefinitionCode(), processInstance.getProcessDefinitionVersion());
+                        processInstance.getProcessDefinitionCode(), processInstance.getProcessDefinitionVersion());
             }
         }
 
@@ -1000,9 +996,9 @@ public class ProcessService {
             processInstance.setScheduleTime(complementDate.get(0));
         }
         processInstance.setGlobalParams(ParameterUtils.curingGlobalParams(
-            processDefinition.getGlobalParamMap(),
-            processDefinition.getGlobalParamList(),
-            CommandType.COMPLEMENT_DATA, processInstance.getScheduleTime()));
+                processDefinition.getGlobalParamMap(),
+                processDefinition.getGlobalParamList(),
+                CommandType.COMPLEMENT_DATA, processInstance.getScheduleTime()));
     }
 
     /**
@@ -1020,7 +1016,7 @@ public class ProcessService {
         Map<String, String> paramMap = JSONUtils.toMap(cmdParam);
         // write sub process id into cmd param.
         if (paramMap.containsKey(CMD_PARAM_SUB_PROCESS)
-            && CMD_PARAM_EMPTY_SUB_PROCESS.equals(paramMap.get(CMD_PARAM_SUB_PROCESS))) {
+                && CMD_PARAM_EMPTY_SUB_PROCESS.equals(paramMap.get(CMD_PARAM_SUB_PROCESS))) {
             paramMap.remove(CMD_PARAM_SUB_PROCESS);
             paramMap.put(CMD_PARAM_SUB_PROCESS, String.valueOf(subProcessInstance.getId()));
             subProcessInstance.setCommandParam(JSONUtils.toJsonString(paramMap));
@@ -1033,7 +1029,7 @@ public class ProcessService {
             ProcessInstance parentInstance = findProcessInstanceDetailById(Integer.parseInt(parentInstanceId));
             if (parentInstance != null) {
                 subProcessInstance.setGlobalParams(
-                    joinGlobalParams(parentInstance.getGlobalParams(), subProcessInstance.getGlobalParams()));
+                        joinGlobalParams(parentInstance.getGlobalParams(), subProcessInstance.getGlobalParams()));
                 this.saveProcessInstance(subProcessInstance);
             } else {
                 logger.error("sub process command params error, cannot find parent instance: {} ", cmdParam);
@@ -1080,7 +1076,7 @@ public class ProcessService {
     private void initTaskInstance(TaskInstance taskInstance) {
 
         if (!taskInstance.isSubProcess()
-            && (taskInstance.getState().typeIsCancel() || taskInstance.getState().typeIsFailure())) {
+                && (taskInstance.getState().typeIsCancel() || taskInstance.getState().typeIsFailure())) {
             taskInstance.setFlag(Flag.NO);
             updateTaskInstance(taskInstance);
             return;
@@ -1091,11 +1087,6 @@ public class ProcessService {
 
     /**
      * retry submit task to db
-     *
-     * @param taskInstance
-     * @param commitRetryTimes
-     * @param commitInterval
-     * @return
      */
     public TaskInstance submitTask(TaskInstance taskInstance, int commitRetryTimes, int commitInterval) {
 
@@ -1135,12 +1126,12 @@ public class ProcessService {
     public TaskInstance submitTask(TaskInstance taskInstance) {
         ProcessInstance processInstance = this.findProcessInstanceDetailById(taskInstance.getProcessInstanceId());
         logger.info("start submit task : {}, instance id:{}, state: {}",
-            taskInstance.getName(), taskInstance.getProcessInstanceId(), processInstance.getState());
+                taskInstance.getName(), taskInstance.getProcessInstanceId(), processInstance.getState());
         //submit to db
         TaskInstance task = submitTaskInstanceToDB(taskInstance, processInstance);
         if (task == null) {
             logger.error("end submit task to db error, task name:{}, process id:{} state: {} ",
-                taskInstance.getName(), taskInstance.getProcessInstance(), processInstance.getState());
+                    taskInstance.getName(), taskInstance.getProcessInstance(), processInstance.getState());
             return task;
         }
         if (!task.getState().typeIsFinished()) {
@@ -1206,7 +1197,7 @@ public class ProcessService {
             }
         }
         logger.info("sub process instance is not found,parent task:{},parent instance:{}",
-            parentTask.getId(), parentProcessInstance.getId());
+                parentTask.getId(), parentProcessInstance.getId());
         return null;
     }
 
@@ -1298,21 +1289,21 @@ public class ProcessService {
         String processParam = getSubWorkFlowParam(instanceMap, parentProcessInstance, fatherParams);
         int subProcessInstanceId = childInstance == null ? 0 : childInstance.getId();
         return new Command(
-            commandType,
-            TaskDependType.TASK_POST,
-            parentProcessInstance.getFailureStrategy(),
-            parentProcessInstance.getExecutorId(),
-            subProcessDefinition.getCode(),
-            processParam,
-            parentProcessInstance.getWarningType(),
-            parentProcessInstance.getWarningGroupId(),
-            parentProcessInstance.getScheduleTime(),
-            task.getWorkerGroup(),
-            task.getEnvironmentCode(),
-            parentProcessInstance.getProcessInstancePriority(),
-            parentProcessInstance.getDryRun(),
-            subProcessInstanceId,
-            subProcessDefinition.getVersion()
+                commandType,
+                TaskDependType.TASK_POST,
+                parentProcessInstance.getFailureStrategy(),
+                parentProcessInstance.getExecutorId(),
+                subProcessDefinition.getCode(),
+                processParam,
+                parentProcessInstance.getWarningType(),
+                parentProcessInstance.getWarningGroupId(),
+                parentProcessInstance.getScheduleTime(),
+                task.getWorkerGroup(),
+                task.getEnvironmentCode(),
+                parentProcessInstance.getProcessInstancePriority(),
+                parentProcessInstance.getDryRun(),
+                subProcessInstanceId,
+                subProcessDefinition.getVersion()
         );
     }
 
@@ -1349,7 +1340,7 @@ public class ProcessService {
      */
     private void updateSubProcessDefinitionByParent(ProcessInstance parentProcessInstance, long childDefinitionCode) {
         ProcessDefinition fatherDefinition = this.findProcessDefinition(parentProcessInstance.getProcessDefinitionCode(),
-            parentProcessInstance.getProcessDefinitionVersion());
+                parentProcessInstance.getProcessDefinitionVersion());
         ProcessDefinition childDefinition = this.findProcessDefinitionByCode(childDefinitionCode);
         if (childDefinition != null && fatherDefinition != null) {
             childDefinition.setWarningGroupId(fatherDefinition.getWarningGroupId());
@@ -1372,7 +1363,7 @@ public class ProcessService {
                 taskInstance.setRetryTimes(taskInstance.getRetryTimes() + 1);
             } else {
                 if (processInstanceState != ExecutionStatus.READY_STOP
-                    && processInstanceState != ExecutionStatus.READY_PAUSE) {
+                        && processInstanceState != ExecutionStatus.READY_PAUSE) {
                     // failure task set invalid
                     taskInstance.setFlag(Flag.NO);
                     updateTaskInstance(taskInstance);
@@ -1425,9 +1416,9 @@ public class ProcessService {
         // the task already exists in task queue
         // return state
         if (
-            state == ExecutionStatus.RUNNING_EXECUTION
-                || state == ExecutionStatus.DELAY_EXECUTION
-                || state == ExecutionStatus.KILL
+                state == ExecutionStatus.RUNNING_EXECUTION
+                        || state == ExecutionStatus.DELAY_EXECUTION
+                        || state == ExecutionStatus.KILL
         ) {
             return state;
         }
@@ -1436,7 +1427,7 @@ public class ProcessService {
         if (processInstanceState == ExecutionStatus.READY_PAUSE) {
             state = ExecutionStatus.PAUSE;
         } else if (processInstanceState == ExecutionStatus.READY_STOP
-            || !checkProcessStrategy(taskInstance)) {
+                || !checkProcessStrategy(taskInstance)) {
             state = ExecutionStatus.KILL;
         } else {
             state = ExecutionStatus.SUBMITTED_SUCCESS;
@@ -1460,7 +1451,7 @@ public class ProcessService {
 
         for (TaskInstance task : taskInstances) {
             if (task.getState() == ExecutionStatus.FAILURE
-                && task.getRetryTimes() >= task.getMaxRetryTimes()) {
+                    && task.getRetryTimes() >= task.getMaxRetryTimes()) {
                 return false;
             }
         }
@@ -1589,7 +1580,8 @@ public class ProcessService {
     private void updateTaskDefinitionResources(TaskDefinition taskDefinition) {
         Map<String, Object> taskParameters = JSONUtils.parseObject(
                 taskDefinition.getTaskParams(),
-                new TypeReference<Map<String, Object>>() { });
+                new TypeReference<Map<String, Object>>() {
+                });
         if (taskParameters != null) {
             // if contains mainJar field, query resource from database
             // Flink, Spark, MR
@@ -1815,8 +1807,6 @@ public class ProcessService {
 
     /**
      * for show in page of taskInstance
-     *
-     * @param taskInstance
      */
     public void changeOutParam(TaskInstance taskInstance) {
         if (StringUtils.isEmpty(taskInstance.getVarPool())) {
@@ -1827,7 +1817,8 @@ public class ProcessService {
             return;
         }
         //if the result more than one line,just get the first .
-        Map<String, Object> taskParams = JSONUtils.parseObject(taskInstance.getTaskParams(), new TypeReference<Map<String, Object>>() {});
+        Map<String, Object> taskParams = JSONUtils.parseObject(taskInstance.getTaskParams(), new TypeReference<Map<String, Object>>() {
+        });
         Object localParams = taskParams.get(LOCAL_PARAMS);
         if (localParams == null) {
             return;
@@ -1928,7 +1919,7 @@ public class ProcessService {
      */
     public List<TaskInstance> queryNeedFailoverTaskInstances(String host) {
         return taskInstanceMapper.queryByHostAndStatus(host,
-            stateArray);
+                stateArray);
     }
 
     /**
@@ -2024,8 +2015,8 @@ public class ProcessService {
      */
     public ProcessInstance findLastSchedulerProcessInterval(Long definitionCode, DateInterval dateInterval) {
         return processInstanceMapper.queryLastSchedulerProcess(definitionCode,
-            dateInterval.getStartTime(),
-            dateInterval.getEndTime());
+                dateInterval.getStartTime(),
+                dateInterval.getEndTime());
     }
 
     /**
@@ -2037,8 +2028,8 @@ public class ProcessService {
      */
     public ProcessInstance findLastManualProcessInterval(Long definitionCode, DateInterval dateInterval) {
         return processInstanceMapper.queryLastManualProcess(definitionCode,
-            dateInterval.getStartTime(),
-            dateInterval.getEndTime());
+                dateInterval.getStartTime(),
+                dateInterval.getEndTime());
     }
 
     /**
@@ -2051,9 +2042,9 @@ public class ProcessService {
      */
     public ProcessInstance findLastRunningProcess(Long definitionCode, Date startTime, Date endTime) {
         return processInstanceMapper.queryLastRunningProcess(definitionCode,
-            startTime,
-            endTime,
-            stateArray);
+                startTime,
+                endTime,
+                stateArray);
     }
 
     /**
@@ -2259,10 +2250,10 @@ public class ProcessService {
         AbstractParameters params = TaskParametersUtils.getParameters(taskDefinition.getTaskType(), taskDefinition.getTaskParams());
         if (params != null && CollectionUtils.isNotEmpty(params.getResourceFilesList())) {
             resourceIds = params.getResourceFilesList().
-                stream()
-                .filter(t -> t.getId() != 0)
-                .map(ResourceInfo::getId)
-                .collect(Collectors.toSet());
+                    stream()
+                    .filter(t -> t.getId() != 0)
+                    .map(ResourceInfo::getId)
+                    .collect(Collectors.toSet());
         }
         if (CollectionUtils.isEmpty(resourceIds)) {
             return StringUtils.EMPTY;
@@ -2282,7 +2273,7 @@ public class ProcessService {
             taskDefinitionLog.setResourceIds(getResourceIds(taskDefinitionLog));
             if (taskDefinitionLog.getCode() > 0 && taskDefinitionLog.getVersion() > 0) {
                 TaskDefinitionLog definitionCodeAndVersion = taskDefinitionLogMapper
-                    .queryByDefinitionCodeAndVersion(taskDefinitionLog.getCode(), taskDefinitionLog.getVersion());
+                        .queryByDefinitionCodeAndVersion(taskDefinitionLog.getCode(), taskDefinitionLog.getVersion());
                 if (definitionCodeAndVersion != null) {
                     if (!taskDefinitionLog.equals(definitionCodeAndVersion)) {
                         taskDefinitionLog.setUserId(definitionCodeAndVersion.getUserId());
@@ -2356,7 +2347,7 @@ public class ProcessService {
         Map<Long, TaskDefinitionLog> taskDefinitionLogMap = null;
         if (CollectionUtils.isNotEmpty(taskDefinitionLogs)) {
             taskDefinitionLogMap = taskDefinitionLogs.stream()
-                .collect(Collectors.toMap(TaskDefinition::getCode, taskDefinitionLog -> taskDefinitionLog));
+                    .collect(Collectors.toMap(TaskDefinition::getCode, taskDefinitionLog -> taskDefinitionLog));
         }
         Date now = new Date();
         for (ProcessTaskRelationLog processTaskRelationLog : taskRelationList) {
@@ -2394,9 +2385,9 @@ public class ProcessService {
         List<ProcessTaskRelation> processTaskRelationList = processTaskRelationMapper.queryByTaskCode(taskCode);
         if (!processTaskRelationList.isEmpty()) {
             Set<Long> processDefinitionCodes = processTaskRelationList
-                .stream()
-                .map(ProcessTaskRelation::getProcessDefinitionCode)
-                .collect(Collectors.toSet());
+                    .stream()
+                    .map(ProcessTaskRelation::getProcessDefinitionCode)
+                    .collect(Collectors.toSet());
             List<ProcessDefinition> processDefinitionList = processDefineMapper.queryByCodes(processDefinitionCodes);
             // check process definition is already online
             for (ProcessDefinition processDefinition : processDefinitionList) {
@@ -2429,8 +2420,8 @@ public class ProcessService {
         List<ProcessTaskRelation> processTaskRelations = processTaskRelationMapper.queryByProcessCode(processDefinition.getProjectCode(), processDefinition.getCode());
         List<TaskDefinitionLog> taskDefinitionLogList = genTaskDefineList(processTaskRelations);
         List<TaskDefinition> taskDefinitions = taskDefinitionLogList.stream()
-            .map(taskDefinitionLog -> JSONUtils.parseObject(JSONUtils.toJsonString(taskDefinitionLog), TaskDefinition.class))
-            .collect(Collectors.toList());
+                .map(taskDefinitionLog -> JSONUtils.parseObject(JSONUtils.toJsonString(taskDefinitionLog), TaskDefinition.class))
+                .collect(Collectors.toList());
         return new DagData(processDefinition, processTaskRelations, taskDefinitions);
     }
 
@@ -2493,7 +2484,7 @@ public class ProcessService {
             taskDefinitionLogs = genTaskDefineList(taskRelationList);
         }
         Map<Long, TaskDefinitionLog> taskDefinitionLogMap = taskDefinitionLogs.stream()
-            .collect(Collectors.toMap(TaskDefinitionLog::getCode, taskDefinitionLog -> taskDefinitionLog));
+                .collect(Collectors.toMap(TaskDefinitionLog::getCode, taskDefinitionLog -> taskDefinitionLog));
         List<TaskNode> taskNodeList = new ArrayList<>();
         for (Entry<Long, List<Long>> code : taskCodeMap.entrySet()) {
             TaskDefinitionLog taskDefinitionLog = taskDefinitionLogMap.get(code.getKey());
@@ -2518,8 +2509,8 @@ public class ProcessService {
                 taskNode.setWorkerGroup(taskDefinitionLog.getWorkerGroup());
                 taskNode.setEnvironmentCode(taskDefinitionLog.getEnvironmentCode());
                 taskNode.setTimeout(JSONUtils.toJsonString(new TaskTimeoutParameter(taskDefinitionLog.getTimeoutFlag() == TimeoutFlag.OPEN,
-                    taskDefinitionLog.getTimeoutNotifyStrategy(),
-                    taskDefinitionLog.getTimeout())));
+                        taskDefinitionLog.getTimeoutNotifyStrategy(),
+                        taskDefinitionLog.getTimeout())));
                 taskNode.setDelayTime(taskDefinitionLog.getDelayTime());
                 taskNode.setPreTasks(JSONUtils.toJsonString(code.getValue().stream().map(taskDefinitionLogMap::get).map(TaskDefinition::getCode).collect(Collectors.toList())));
                 taskNodeList.add(taskNode);
@@ -2545,6 +2536,13 @@ public class ProcessService {
     }
 
     public ProcessInstance loadNextProcess4Serial(long code, int state) {
-        return this.processInstanceMapper.loadNextProcess4Serial(code,state);
+        return this.processInstanceMapper.loadNextProcess4Serial(code, state);
+    }
+
+    private void deleteCommandWithCheck(int commandId) { // deletes the command by id; throws ServiceException unless deleteById returns exactly 1
+        int delete = this.commandMapper.deleteById(commandId); // presumably the number of deleted rows - confirm against the mapper
+        if (delete != 1) { // any other result is treated as a failed delete
+            throw new ServiceException("delete command fail, id:" + commandId);
+        }
     }
 }

+ 64 - 1
dolphinscheduler-service/src/test/java/org/apache/dolphinscheduler/service/process/ProcessServiceTest.java

@@ -60,6 +60,7 @@ import org.apache.dolphinscheduler.dao.mapper.TaskDefinitionLogMapper;
 import org.apache.dolphinscheduler.dao.mapper.TaskDefinitionMapper;
 import org.apache.dolphinscheduler.dao.mapper.TaskInstanceMapper;
 import org.apache.dolphinscheduler.dao.mapper.UserMapper;
+import org.apache.dolphinscheduler.service.exceptions.ServiceException;
 import org.apache.dolphinscheduler.service.quartz.cron.CronUtilsTest;
 
 import java.util.ArrayList;
@@ -72,7 +73,9 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.mockito.InjectMocks;
 import org.mockito.Mock;
@@ -92,6 +95,9 @@ public class ProcessServiceTest {
 
     private static final Logger logger = LoggerFactory.getLogger(CronUtilsTest.class);
 
+    @Rule
+    public final ExpectedException exception = ExpectedException.none(); // NOTE(review): the test added in this patch uses @Test(expected = ...) instead of this rule - confirm it is used elsewhere or drop it
+
     @InjectMocks
     private ProcessService processService;
     @Mock
@@ -255,10 +261,12 @@ public class ProcessServiceTest {
         int processInstanceId = 222;
         //there is not enough thread for this command
         Command command1 = new Command();
+        command1.setId(1);
         command1.setProcessDefinitionCode(definitionCode);
         command1.setProcessDefinitionVersion(definitionVersion);
         command1.setCommandParam("{\"ProcessInstanceId\":222}");
         command1.setCommandType(CommandType.START_PROCESS);
+        Mockito.when(commandMapper.deleteById(1)).thenReturn(1);
 
         ProcessDefinition processDefinition = new ProcessDefinition();
         processDefinition.setId(123);
@@ -284,31 +292,37 @@ public class ProcessServiceTest {
         Assert.assertNotNull(processService.handleCommand(logger, host, command1, processDefinitionCacheMaps));
 
         Command command2 = new Command();
+        command2.setId(2);
         command2.setCommandParam("{\"ProcessInstanceId\":222,\"StartNodeIdList\":\"n1,n2\"}");
         command2.setProcessDefinitionCode(definitionCode);
         command2.setProcessDefinitionVersion(definitionVersion);
         command2.setCommandType(CommandType.RECOVER_SUSPENDED_PROCESS);
         command2.setProcessInstanceId(processInstanceId);
-
+        Mockito.when(commandMapper.deleteById(2)).thenReturn(1);
         Assert.assertNotNull(processService.handleCommand(logger, host, command2, processDefinitionCacheMaps));
 
         Command command3 = new Command();
+        command3.setId(3);
         command3.setProcessDefinitionCode(definitionCode);
         command3.setProcessDefinitionVersion(definitionVersion);
         command3.setProcessInstanceId(processInstanceId);
         command3.setCommandParam("{\"WaitingThreadInstanceId\":222}");
         command3.setCommandType(CommandType.START_FAILURE_TASK_PROCESS);
+        Mockito.when(commandMapper.deleteById(3)).thenReturn(1);
         Assert.assertNotNull(processService.handleCommand(logger, host, command3, processDefinitionCacheMaps));
 
         Command command4 = new Command();
+        command4.setId(4);
         command4.setProcessDefinitionCode(definitionCode);
         command4.setProcessDefinitionVersion(definitionVersion);
         command4.setCommandParam("{\"WaitingThreadInstanceId\":222,\"StartNodeIdList\":\"n1,n2\"}");
         command4.setCommandType(CommandType.REPEAT_RUNNING);
         command4.setProcessInstanceId(processInstanceId);
+        Mockito.when(commandMapper.deleteById(4)).thenReturn(1);
         Assert.assertNotNull(processService.handleCommand(logger, host, command4, processDefinitionCacheMaps));
 
         Command command5 = new Command();
+        command5.setId(5);
         command5.setProcessDefinitionCode(definitionCode);
         command5.setProcessDefinitionVersion(definitionVersion);
         HashMap<String, String> startParams = new HashMap<>();
@@ -318,6 +332,7 @@ public class ProcessServiceTest {
         command5.setCommandParam(JSONUtils.toJsonString(commandParams));
         command5.setCommandType(CommandType.START_PROCESS);
         command5.setDryRun(Constants.DRY_RUN_FLAG_NO);
+        Mockito.when(commandMapper.deleteById(5)).thenReturn(1);
         ProcessInstance processInstance1 = processService.handleCommand(logger, host, command5, processDefinitionCacheMaps);
         Assert.assertTrue(processInstance1.getGlobalParams().contains("\"testStartParam1\""));
 
@@ -342,14 +357,18 @@ public class ProcessServiceTest {
         processInstance2.setProcessDefinitionVersion(1);
         Mockito.when(processInstanceMapper.queryDetailById(223)).thenReturn(processInstance2);
         Mockito.when(processDefineMapper.queryByCode(11L)).thenReturn(processDefinition1);
+        Mockito.when(commandMapper.deleteById(1)).thenReturn(1);
         Assert.assertNotNull(processService.handleCommand(logger, host, command1, processDefinitionCacheMaps));
+
         Command command6 = new Command();
+        command6.setId(6);
         command6.setProcessDefinitionCode(11L);
         command6.setCommandParam("{\"ProcessInstanceId\":223}");
         command6.setCommandType(CommandType.RECOVER_SERIAL_WAIT);
         command6.setProcessDefinitionVersion(1);
         Mockito.when(processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(11L,Constants.RUNNING_PROCESS_STATE,223)).thenReturn(lists);
         Mockito.when(processInstanceMapper.updateNextProcessIdById(223, 222)).thenReturn(true);
+        Mockito.when(commandMapper.deleteById(6)).thenReturn(1);
         ProcessInstance processInstance6 = processService.handleCommand(logger, host, command6, processDefinitionCacheMaps);
         Assert.assertTrue(processInstance6 != null);
 
@@ -362,10 +381,12 @@ public class ProcessServiceTest {
         Mockito.when(processInstanceMapper.queryDetailById(224)).thenReturn(processInstance7);
 
         Command command7 = new Command();
+        command7.setId(7);
         command7.setProcessDefinitionCode(11L);
         command7.setCommandParam("{\"ProcessInstanceId\":224}");
         command7.setCommandType(CommandType.RECOVER_SERIAL_WAIT);
         command7.setProcessDefinitionVersion(1);
+        Mockito.when(commandMapper.deleteById(7)).thenReturn(1);
         Mockito.when(processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(11L,Constants.RUNNING_PROCESS_STATE,224)).thenReturn(null);
         ProcessInstance processInstance8 = processService.handleCommand(logger, host, command7, processDefinitionCacheMaps);
         Assert.assertTrue(processInstance8 == null);
@@ -382,6 +403,7 @@ public class ProcessServiceTest {
         processInstance9.setProcessDefinitionCode(11L);
         processInstance9.setProcessDefinitionVersion(1);
         Command command9 = new Command();
+        command9.setId(9);
         command9.setProcessDefinitionCode(12L);
         command9.setCommandParam("{\"ProcessInstanceId\":225}");
         command9.setCommandType(CommandType.RECOVER_SERIAL_WAIT);
@@ -389,10 +411,51 @@ public class ProcessServiceTest {
         Mockito.when(processInstanceMapper.queryDetailById(225)).thenReturn(processInstance9);
         Mockito.when(processInstanceMapper.queryByProcessDefineCodeAndStatusAndNextId(12L,Constants.RUNNING_PROCESS_STATE,0)).thenReturn(lists);
         Mockito.when(processInstanceMapper.updateById(processInstance)).thenReturn(1);
+        Mockito.when(commandMapper.deleteById(9)).thenReturn(1);
         ProcessInstance processInstance10 = processService.handleCommand(logger, host, command9, processDefinitionCacheMaps);
         Assert.assertTrue(processInstance10 == null);
     }
 
+    @Test(expected = ServiceException.class)
+    public void testDeleteNotExistCommand() { // expects handleCommand to fail with ServiceException when the command delete does not succeed
+        String host = "127.0.0.1";
+        int definitionVersion = 1;
+        long definitionCode = 123;
+        int processInstanceId = 222;
+
+        Command command1 = new Command();
+        command1.setId(1);
+        command1.setProcessDefinitionCode(definitionCode);
+        command1.setProcessDefinitionVersion(definitionVersion);
+        command1.setCommandParam("{\"ProcessInstanceId\":222}");
+        command1.setCommandType(CommandType.START_PROCESS);
+
+        ProcessDefinition processDefinition = new ProcessDefinition();
+        processDefinition.setId(123);
+        processDefinition.setName("test");
+        processDefinition.setVersion(definitionVersion);
+        processDefinition.setCode(definitionCode);
+        processDefinition.setGlobalParams("[{\"prop\":\"startParam1\",\"direct\":\"IN\",\"type\":\"VARCHAR\",\"value\":\"\"}]");
+        processDefinition.setExecutionType(ProcessExecutionTypeEnum.PARALLEL);
+
+        ProcessInstance processInstance = new ProcessInstance();
+        processInstance.setId(222);
+        processInstance.setProcessDefinitionCode(11L); // NOTE(review): overwritten by setProcessDefinitionCode(definitionCode) a few lines below
+        processInstance.setHost("127.0.0.1:5678");
+        processInstance.setProcessDefinitionVersion(1);
+        processInstance.setId(processInstanceId);
+        processInstance.setProcessDefinitionCode(definitionCode);
+        processInstance.setProcessDefinitionVersion(definitionVersion);
+
+        Mockito.when(processDefineMapper.queryByCode(command1.getProcessDefinitionCode())).thenReturn(processDefinition);
+        Mockito.when(processDefineLogMapper.queryByDefinitionCodeAndVersion(processInstance.getProcessDefinitionCode(),
+                processInstance.getProcessDefinitionVersion())).thenReturn(new ProcessDefinitionLog(processDefinition));
+        Mockito.when(processInstanceMapper.queryDetailById(222)).thenReturn(processInstance);
+
+        // command id is 1 and commandMapper.deleteById(1) is left unstubbed, so the mock returns 0 (Mockito default); handleCommand is expected to throw ServiceException
+        processService.handleCommand(logger, host, command1, processDefinitionCacheMaps);
+    }
+
     @Test
     public void testGetUserById() {
         User user = new User();