Browse Source

Refactoring the code for file uploads

kl 2 years ago
parent
commit
c2c870668b

+ 26 - 0
server/src/main/java/cn/keking/utils/WebUtils.java

@@ -3,6 +3,8 @@ package cn.keking.utils;
 import io.mola.galimatias.GalimatiasParseException;
 import org.apache.commons.lang3.StringUtils;
 import org.springframework.util.Base64Utils;
+import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.util.HtmlUtils;
 
 import javax.servlet.ServletRequest;
 import java.io.UnsupportedEncodingException;
@@ -103,6 +105,30 @@ public class WebUtils {
         return noQueryUrl.substring(noQueryUrl.lastIndexOf("/") + 1);
     }
 
+    /**
+     * 从url中剥离出文件名
+     * @param file 文件
+     * @return 文件名
+     */
+    public static String getFileNameFromMultipartFile(MultipartFile file) {
+        String fileName = file.getOriginalFilename();
+        //判断是否为IE浏览器的文件名,IE浏览器下文件名会带有盘符信
+        // escaping dangerous characters to prevent XSS
+        assert fileName != null;
+        fileName = HtmlUtils.htmlEscape(fileName, KkFileUtils.DEFAULT_FILE_ENCODING);
+
+        // Check for Unix-style path
+        int unixSep = fileName.lastIndexOf('/');
+        // Check for Windows-style path
+        int winSep = fileName.lastIndexOf('\\');
+        // Cut off at latest possible point
+        int pos = (Math.max(winSep, unixSep));
+        if (pos != -1) {
+            fileName = fileName.substring(pos + 1);
+        }
+        return fileName;
+    }
+
 
     /**
      * 从url中获取文件后缀

+ 84 - 52
server/src/main/java/cn/keking/web/controller/FileController.java

@@ -6,22 +6,26 @@ import cn.keking.utils.KkFileUtils;
 import cn.keking.utils.WebUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.springframework.util.ObjectUtils;
 import org.springframework.util.StreamUtils;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestParam;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.multipart.MultipartFile;
-import org.springframework.web.util.HtmlUtils;
 
 import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
-import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Paths;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
 
 /**
  * @author yudian-it
@@ -36,50 +40,22 @@ public class FileController {
     private final String demoDir = "demo";
     private final String demoPath = demoDir + File.separator;
     public static final String BASE64_DECODE_ERROR_MSG = "Base64解码失败,请检查你的 %s 是否采用 Base64 + urlEncode 双重编码了!";
+    private static final String[] not_allowed = { "dll", "exe", "msi" }; // 不允许上传的文件扩展名
+
     @PostMapping("/fileUpload")
     public ReturnResponse<Object> fileUpload(@RequestParam("file") MultipartFile file) {
-        if (ConfigConstants.getFileUploadDisable()) {
-            return ReturnResponse.failure("文件传接口已禁用");
-        }
-        // 获取文件名
-        String fileName = file.getOriginalFilename();
-        //判断是否为IE浏览器的文件名,IE浏览器下文件名会带有盘符信息
-
-        // escaping dangerous characters to prevent XSS
-        assert fileName != null;
-        fileName = HtmlUtils.htmlEscape(fileName, StandardCharsets.UTF_8.name());
-
-        // Check for Unix-style path
-        int unixSep = fileName.lastIndexOf('/');
-        // Check for Windows-style path
-        int winSep = fileName.lastIndexOf('\\');
-        // Cut off at latest possible point
-        int pos = (Math.max(winSep, unixSep));
-        if (pos != -1) {
-            fileName = fileName.substring(pos + 1);
-        }
-        String fileType= "";
-        int i = fileName.lastIndexOf('.');
-        if (i > 0) {
-            fileType= fileName.substring(i+1);
-            fileType= fileType.toLowerCase();
-        }
-        if (fileType.length() == 0 || fileType.equals("dll") || fileType.equals("exe") || fileType.equals("msi") ){
-            return ReturnResponse.failure(fileName+"不允许上传的文件");
-        }
-        // 判断是否存在同名文件
-        if (existsFile(fileName)) {
-            return ReturnResponse.failure("存在同名文件,请先删除原有文件再次上传");
+        ReturnResponse<Object> checkResult = this.fileUploadCheck(file);
+        if (checkResult.isFailure()) {
+            return checkResult;
         }
         File outFile = new File(fileDir + demoPath);
         if (!outFile.exists() && !outFile.mkdirs()) {
             logger.error("创建文件夹【{}】失败,请检查目录权限!", fileDir + demoPath);
         }
-        logger.info("上传文件:{}", fileDir + demoPath + fileName);
+        String fileName = checkResult.getContent().toString();
+        logger.info("上传文件:{}{}{}", fileDir, demoPath, fileName);
         try (InputStream in = file.getInputStream(); OutputStream out = Files.newOutputStream(Paths.get(fileDir + demoPath + fileName))) {
             StreamUtils.copy(in, out);
-            in.close();
-            out.close();
             return ReturnResponse.success(null);
         } catch (IOException e) {
             logger.error("文件上传失败", e);
@@ -89,21 +65,11 @@ public class FileController {
 
     @GetMapping("/deleteFile")
     public ReturnResponse<Object> deleteFile(String fileName) {
-        if (fileName == null || fileName.length() == 0) {
-            return ReturnResponse.failure("文件名为空,删除失败!");
-        }
-        try {
-            fileName = WebUtils.decodeUrl(fileName);
-        } catch (Exception ex) {
-            String errorMsg = String.format(BASE64_DECODE_ERROR_MSG, "url");
-            return ReturnResponse.failure(errorMsg+"删除失败!");
-        }
-        if (fileName.contains("/")) {
-            fileName = fileName.substring(fileName.lastIndexOf("/") + 1);
-        }
-        if (KkFileUtils.isIllegalFileName(fileName)) {
-            return ReturnResponse.failure("非法文件名,删除失败!");
+        ReturnResponse<Object> checkResult = this.deleteFileCheck(fileName);
+        if (checkResult.isFailure()) {
+            return checkResult;
         }
+        fileName = checkResult.getContent().toString();
         File file = new File(fileDir + demoPath + fileName);
         logger.info("删除文件:{}", file.getAbsolutePath());
         if (file.exists() && !file.delete()) {
@@ -130,6 +96,72 @@ public class FileController {
         return list;
     }
 
+    /**
+     * 上传文件前校验
+     *
+     * @param file 文件
+     * @return 校验结果
+     */
+    private ReturnResponse<Object> fileUploadCheck(MultipartFile file) {
+        if (ConfigConstants.getFileUploadDisable()) {
+            return ReturnResponse.failure("文件传接口已禁用");
+        }
+        String fileName = WebUtils.getFileNameFromMultipartFile(file);
+
+        if (!isAllowedUpload(fileName)) {
+            return ReturnResponse.failure("不允许上传的文件类型: " + fileName);
+        }
+        if (KkFileUtils.isIllegalFileName(fileName)) {
+            return ReturnResponse.failure("不允许上传的文件名: " + fileName);
+        }
+        // 判断是否存在同名文件
+        if (existsFile(fileName)) {
+            return ReturnResponse.failure("存在同名文件,请先删除原有文件再次上传");
+        }
+        return ReturnResponse.success(fileName);
+    }
+
+    /**
+     * 判断文件是否允许上传
+     *
+     * @param file 文件扩展名
+     * @return 是否允许上传
+     */
+    private boolean isAllowedUpload(String file) {
+        String fileType = KkFileUtils.suffixFromFileName(file);
+        for (String type : not_allowed) {
+            if (type.equals(fileType))
+                return false;
+        }
+        return !ObjectUtils.isEmpty(fileType);
+    }
+
+    /**
+     * 删除文件前校验
+     *
+     * @param fileName 文件名
+     * @return 校验结果
+     */
+    private ReturnResponse<Object> deleteFileCheck(String fileName) {
+        if (ObjectUtils.isEmpty(fileName)) {
+            return ReturnResponse.failure("文件名为空,删除失败!");
+        }
+        try {
+            fileName = WebUtils.decodeUrl(fileName);
+        } catch (Exception ex) {
+            String errorMsg = String.format(BASE64_DECODE_ERROR_MSG, fileName);
+            return ReturnResponse.failure(errorMsg + "删除失败!");
+        }
+        assert fileName != null;
+        if (fileName.contains("/")) {
+            fileName = fileName.substring(fileName.lastIndexOf("/") + 1);
+        }
+        if (KkFileUtils.isIllegalFileName(fileName)) {
+            return ReturnResponse.failure("非法文件名,删除失败!");
+        }
+        return ReturnResponse.success(fileName);
+    }
+
     private boolean existsFile(String fileName) {
         File file = new File(fileDir + demoPath + fileName);
         return file.exists();