Spring Boot 通过限制一定时间内的接口请求次数,防止恶意评论和批量发布恶意广告推文

wylc123 1年前 ⋅ 671 阅读

1、背景

因为博客开放注册,有人恶意注册账号发布涉灰广告推文。

2、 原理

在一定时间内限制接口(主要是评论、发文接口)的请求次数;超过请求次数后,会发送邮件到站长邮箱,通知删除恶意文章、评论(邮件通知部分未包含在下文的代码中)。

3、用到的技术

自定义注解、AOP、ExpiringMap(带有有效期的映射)

4、具体代码实现

4.1 引入依赖
<!-- AOP support: needed for the @Aspect-based LimitRequestAspect.
     NOTE(review): if the project inherits spring-boot-starter-parent, the <version> can usually be
     omitted so it follows the project's Spring Boot version; 2.1.5.RELEASE is pinned here. -->
<dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
            <version>2.1.5.RELEASE</version>
</dependency>
<!-- ExpiringMap: a map whose entries expire; used by the aspect to hold per-IP request counts -->
<dependency>
            <groupId>net.jodah</groupId>
            <artifactId>expiringmap</artifactId>
            <version>0.5.8</version>
</dependency>
4.2 自定义注解@LimitRequest
/**
 * @author SongBin on 2021-11-12.
 */
package com.mtons.mblog.web.interceptor;

import java.lang.annotation.*;

/**
 * Marks a method whose invocations are rate limited.
 *
 * <p>With the defaults, at most one request is allowed per ten-minute window
 * (e.g. "publish at most one article every ten minutes"). How the window and the
 * count are applied is decided by the component that reads this annotation.
 */
@Documented
@Target(ElementType.METHOD) // may only be placed on methods
@Retention(RetentionPolicy.RUNTIME) // must be visible at runtime so an aspect/interceptor can read it
public @interface LimitRequest {
    /** Length of the limiting window in milliseconds; defaults to 10 minutes (not 1 hour). */
    long time() default 1000 * 60 * 10;

    /** Number of requests allowed within one window. */
    int count() default 1;
}
4.3 自定义AOP
/**
 * @author SongBin on 2021-11-12.
 */
package com.mtons.mblog.web.interceptor;

import com.alibaba.fastjson.JSON;
import com.mtons.mblog.base.lang.Result;
import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Aspect that enforces {@link LimitRequest} on annotated methods invoked during a servlet request.
 *
 * <p>Counts are kept in memory (so per JVM instance), keyed by request URI and then by the
 * client address from {@code HttpServletRequest#getRemoteAddr()}. The first counted request of a
 * client creates an entry with {@code ExpirationPolicy.CREATED} and a lifetime of
 * {@code limitRequest.time()} milliseconds; once it expires, counting starts over.
 * NOTE(review): whether updating the count via {@code put(ip, n)} restarts that lifetime depends on
 * ExpiringMap's put semantics — confirm if count() > 1 is used.
 *
 * <p>NOTE(review): when the quota is exhausted the advice returns a {@link ModelAndView} instead of
 * invoking the target; confirm every annotated method's return type can carry a ModelAndView.
 */
@Aspect
@Component
public class LimitRequestAspect {
    /** View rendered when a client exceeds its quota. */
    private static final String errorView = "/error";

    /** Request URI -> (client address -> requests counted in the current window). */
    private static final ConcurrentHashMap<String, ExpiringMap<String, Integer>> book = new ConcurrentHashMap<>();

    /** Pointcut matching every method annotated with {@link LimitRequest}, binding the annotation. */
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    /**
     * Counts the call against the caller's quota and proceeds, or short-circuits with the error view.
     *
     * @param pjp          the intercepted invocation
     * @param limitRequest the annotation found on the intercepted method
     * @return the target's result, or a ModelAndView for {@code /error} when the quota is exhausted
     * @throws Throwable whatever the target method throws
     */
    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        if (!(ra instanceof ServletRequestAttributes)) {
            // No servlet request bound to this thread (e.g. an internal or async call): there is no
            // client to attribute the call to, so proceed instead of failing with an NPE.
            return pjp.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) ra).getRequest();

        if (!tryAcquire(request.getRequestURI(), request.getRemoteAddr(), limitRequest)) {
            // Over quota: do not run the target method.
            Map<String, Object> model = new HashMap<String, Object>();
            model.put("error", "接口请求超过次数");
            return new ModelAndView(errorView, model);
        }
        // The call is counted before it runs, so invocations that throw still use up quota.
        return pjp.proceed();
    }

    /**
     * Checks the quota of {@code ip} on {@code uri} and, if not exhausted, counts this request.
     *
     * @return {@code true} if the request is within quota (and has been counted)
     */
    private static boolean tryAcquire(String uri, String ip, LimitRequest limitRequest) {
        // computeIfAbsent creates the per-URI map exactly once; getOrDefault(..., new map) + put
        // allocated a map on every call and let concurrent first requests overwrite each other.
        ExpiringMap<String, Integer> counters =
                book.computeIfAbsent(uri, k -> ExpiringMap.builder().variableExpiration().build());
        // Serialize check-then-put for this URI so concurrent requests cannot all slip through.
        synchronized (counters) {
            int used = counters.getOrDefault(ip, 0);
            if (used >= limitRequest.count()) {
                return false;
            }
            if (used == 0) {
                // First request: the entry expires time() ms after it was created.
                counters.put(ip, 1, ExpirationPolicy.CREATED, limitRequest.time(), TimeUnit.MILLISECONDS);
            } else {
                counters.put(ip, used + 1);
            }
            return true;
        }
    }
}

第一个静态Map是线程安全的Map(ConcurrentHashMap),它的key是接口对应的URL,它的value是一个线程安全且键值对带有效期的Map(ExpiringMap)。

ExpiringMap的key是请求的ip地址,value是已经请求的次数。

ExpiringMap更多的使用方法可以参考:https://github.com/jhalterman/expiringmap

4.4 上面的LimitRequestAspect也可以用一个拦截器代替,并将请求次数存到Redis里
import cnki.bdms.aop.LimitRequest;
import cnki.bdms.common.enums.ReturnStatus;
import cnki.bdms.common.util.JsonHelper;
import cnki.bdms.util.IPUtils;
import cnki.bdms.util.RedisUtil;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * User: SongBin
 * Date: 11:31 2022/4/25
 * Description: 限制用户接口访问频次
 * Version 1.0
 */
/**
 * Spring MVC interceptor that limits how often a client IP may call handler methods annotated
 * with {@link LimitRequest}.
 *
 * <p>Counts are stored via {@code RedisUtil} under {@code KEY_PREFIX + requestURI + ":" + ip},
 * so each annotated endpoint has its own quota and the keys cannot collide with other cache
 * entries. When the quota is exhausted a JSON error body is written and the handler is skipped.
 *
 * <p>NOTE(review): the read-then-write of the counter is not atomic, so concurrent requests from
 * the same IP can exceed the limit; an atomic Redis INCR would close that gap.
 */
@Component
public class ApiIPVisitInterception extends HandlerInterceptorAdapter {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Namespace for counter keys. */
    private static final String KEY_PREFIX = "limitRequest:";

    /** Optional override of the limiting window, in minutes; ignored when null or not positive. */
    @Value("${sms.expireAfterSend}")
    private Integer expireAfterSend;

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
                           ModelAndView modelAndView) throws Exception {
        super.postHandle(request, response, handler, modelAndView);
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        return checkUrl(request, response, handler);
    }

    /**
     * Applies the rate limit to the current request.
     *
     * @return {@code true} to continue processing; {@code false} when the limit was exceeded and an
     *         error response has already been written
     * @throws IOException if writing the error response fails
     */
    private boolean checkUrl(HttpServletRequest request, HttpServletResponse response, Object handler) throws IOException {
        // Static resources and other non-controller handlers are not HandlerMethods; casting them
        // unconditionally threw ClassCastException.
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        LimitRequest accessLimit = ((HandlerMethod) handler).getMethodAnnotation(LimitRequest.class);
        if (accessLimit == null) {
            return true;
        }

        // Window length in milliseconds (LimitRequest.time() is in milliseconds); long avoids
        // both the narrowing conversion and int overflow.
        long time = expireAfterSend != null && expireAfterSend > 0 ? 1000L * 60 * expireAfterSend : accessLimit.time();
        // RedisUtil.set(key, value, ttl) expects SECONDS; clamp sub-second windows to 1s.
        int ttlSeconds = (int) Math.min(Integer.MAX_VALUE, Math.max(1L, time / 1000));
        int maxCount = accessLimit.count();

        String ip = IPUtils.getIP();
        String key = KEY_PREFIX + request.getRequestURI() + ":" + ip;
        // Read straight from Redis: RedisUtil.get() goes through a local cache whose expiry is
        // independent of this key's TTL, which could keep a stale count alive after the window.
        Object stored = RedisUtil.getFromRedisCache(key, 0);
        Integer count = stored instanceof Number ? ((Number) stored).intValue() : null;

        if (count == null) {
            // First request in this window: start the counter with the window's TTL.
            RedisUtil.set(key, 1, ttlSeconds);
            return true;
        }
        if (count < maxCount) {
            // Rewriting the value also restarts the TTL (same as before).
            RedisUtil.set(key, count + 1, ttlSeconds);
            return true;
        }

        // Quota exhausted.
        logger.info("访问过快ip  ===> {} 且在   {} 毫秒内超过最大限制  ===> {} 次数达到    ====> {}", ip, time, maxCount, count);
        ObjectNode result = JsonHelper.createNode();
        result.put("code", ReturnStatus.ApiLock.getValue());
        result.put("msg", "接口请求太频繁,系统认定恶意攻击!");
        response.setHeader("Access-Control-Allow-Origin", "*");
        response.setHeader("Access-Control-Allow-Headers", "token, Accept, Origin, X-Requested-With, Content-Type, Last-Modified");
        response.addHeader("Content-Type", "application/json;charset=UTF-8");
        response.setStatus(200);
        response.getWriter().write(result.toString());
        return false;
    }
}

然后在 WebMvcConfig 中注册该拦截器:

import cnki.bdms.web.ui.security.ApiIPVisitInterception;
import cnki.bdms.web.ui.security.CorsInterceptor;
import cnki.bdms.web.ui.security.URLInterception;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.alibaba.fastjson.support.config.FastJsonConfig;
import com.alibaba.fastjson.support.spring.FastJsonHttpMessageConverter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;

/**
 * Spring MVC configuration that wires the application's interceptors.
 */
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    @Autowired
    private URLInterception urlInterception;
    @Autowired
    private ApiIPVisitInterception apiIPVisitInterception;
    @Autowired
    CorsInterceptor corsInterceptor;

    /**
     * Registers interceptors in order: URL interception and CORS handling on every path,
     * followed by the per-IP API rate limiter.
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        final String everyPath = "/**";
        registry.addInterceptor(urlInterception).addPathPatterns(everyPath);
        // cross-origin handling for search requests
        registry.addInterceptor(corsInterceptor).addPathPatterns(everyPath);
        // Registered without path patterns.
        registry.addInterceptor(apiIPVisitInterception);
    }
}

RedisUtil代码:

import cnki.bdms.base.StringUtil;
import cnki.bdms.common.log.LogOpr;
import cnki.bdms.common.util.JsonHelper;
import com.fasterxml.jackson.core.type.TypeReference;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Static helper around a {@link RedisTemplate} with a local Caffeine read-through cache in front.
 *
 * <p>Spring populates the static configuration fields and the template through the instance
 * setters; {@link #init()} then builds the local cache, whose misses are loaded from Redis.
 * The {@code set*} methods write to both the local cache and Redis.
 *
 * <p>NOTE(review): local entries expire {@code guava.cache.expireAfterWrite} seconds after being
 * written, independently of the TTL of the same key in Redis, so the {@code get*} methods may
 * return a value that has already expired in Redis.
 */
@Component
@Lazy(false)
public class RedisUtil {
    private static LogOpr log = new LogOpr(RedisUtil.class);

    /** COUNT hint for SCAN; bounded so each SCAN step stays cheap for the Redis server. */
    private static final int SCAN_COUNT = 1000;

    private static RedisTemplate<String, Object> redisTemplate;
    private static Integer maxCacheSize;
    private static Integer initialCapacity;
    // NOTE(review): injected but not used; Caffeine has no concurrency-level option.
    private static Integer concurrencyLevel;
    // Seconds; used as the local cache expiry and as the default Redis TTL in set(key, value).
    private static Integer expireAfterWrite;

    // Local in-memory cache in front of Redis.
    private static LoadingCache<String, Object> objectCache;

    @Value("${guava.cache.maxCacheSize}")
    public void setMaxCacheSize(Integer maxCacheSize) {
        RedisUtil.maxCacheSize = maxCacheSize;
    }

    @Value("${guava.cache.initialCapacity}")
    public void setInitialCapacity(Integer initialCapacity) {
        RedisUtil.initialCapacity = initialCapacity;
    }

    @Value("${guava.cache.concurrencyLevel}")
    public void setConcurrencyLevel(Integer concurrencyLevel) {
        RedisUtil.concurrencyLevel = concurrencyLevel;
    }

    @Value("${guava.cache.expireAfterWrite}")
    public void setExpireAfterWrite(Integer expireAfterWrite) {
        RedisUtil.expireAfterWrite = expireAfterWrite;
    }

    @Autowired
    public void setRedisTemplate(RedisTemplate<String, Object> redisTemplate) {
        RedisUtil.redisTemplate = redisTemplate;
    }

    /**
     * Builds the local cache once configuration has been injected.
     */
    @PostConstruct
    public void init() {
        objectCache = Caffeine.newBuilder()
                // Previously maxCacheSize was passed here and the configured initialCapacity was ignored.
                .initialCapacity(initialCapacity)
                .maximumSize(maxCacheSize)
                .expireAfterWrite(expireAfterWrite, TimeUnit.SECONDS)
                .build(key -> RedisUtil.getWithoutCache(key)
                );
    }

    /**
     * Sets the Redis expiry of a key. The local cache is not touched.
     *
     * @param key  key
     * @param time expiry in seconds; ignored when not positive
     * @return {@code true} unless the Redis call threw
     */
    public static boolean expire(String key, long time) {
        try {
            if (time > 0) {
                redisTemplate.expire(key, time, TimeUnit.SECONDS);
            }
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Returns the remaining time to live of a key, as reported by Redis.
     *
     * @param key key, must not be null
     * @return remaining TTL in seconds. Per Redis TTL semantics this is -1 when the key has no
     *         expiry and -2 when it does not exist (not 0 as previously documented).
     */
    public static long getExpire(String key) {
        return redisTemplate.getExpire(key, TimeUnit.SECONDS);
    }

    /**
     * Checks whether a key exists in Redis.
     *
     * @param key key
     * @return {@code true} if it exists; {@code false} if it does not or the call failed
     */
    public static boolean hasKey(String key) {
        try {
            return Boolean.TRUE.equals(redisTemplate.hasKey(key));
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Deletes cache entries from Redis and the local cache.
     *
     * <p>NOTE(review): each argument is treated as a prefix. If any keys start with it, all of them
     * are deleted (e.g. {@code remove("user:1")} also deletes {@code "user:10"}). Otherwise only
     * the exact key is deleted.
     *
     * @param key one or more keys / prefixes
     */
    public static void remove(String... key) {
        if (key == null || key.length == 0) {
            return;
        }
        try {
            for (String item : key) {
                Set<String> keys = getRedisKeys(item);
                if (keys.isEmpty()) {
                    redisTemplate.delete(item);
                    objectCache.invalidate(item);
                } else {
                    redisTemplate.delete(keys);
                    objectCache.invalidateAll(keys);
                }
            }
        } catch (Exception e) {
            log.error(e);
        }
    }

    /**
     * Returns all keys starting with {@code realKey}.
     *
     * <p>Uses SCAN instead of KEYS. Redis processes commands on a single thread, and KEYS is O(N)
     * over the whole keyspace, so it blocks other clients. The COUNT hint is bounded: a huge COUNT
     * makes a single SCAN call walk the whole keyspace, which is just as blocking.
     *
     * @param realKey key prefix
     * @return matching keys; empty (never {@code null}) when nothing matches or the scan failed
     */
    public static Set<String> getRedisKeys(String realKey) {
        try {
            Set<String> keys = redisTemplate.execute((RedisCallback<Set<String>>) connection -> {
                Set<String> found = new HashSet<>();
                ScanOptions options = new ScanOptions.ScanOptionsBuilder().match(realKey + "*").count(SCAN_COUNT).build();
                // try-with-resources closes the cursor even if iteration throws.
                try (Cursor<byte[]> cursor = connection.scan(options)) {
                    while (cursor.hasNext()) {
                        found.add(new String(cursor.next(), java.nio.charset.StandardCharsets.UTF_8));
                    }
                } catch (IOException e) {
                    log.error(e);
                }
                return found;
            });
            return keys != null ? keys : new HashSet<>();
        } catch (Exception e) {
            log.error(e);
            return new HashSet<>();
        }
    }

    /**
     * Deletes every key starting with {@code keyPrefix} from Redis and the local cache.
     *
     * @param keyPrefix key prefix
     */
    public static void removeAll(String keyPrefix) {
        try {
            Set<String> keys = getRedisKeys(keyPrefix);
            if (!keys.isEmpty()) {
                redisTemplate.delete(keys);
                // Keep the local cache consistent with Redis, as remove() does.
                objectCache.invalidateAll(keys);
            }
        } catch (Exception e) {
            log.error(e);
        }
    }

    /**
     * Deletes every key starting with {@code keyPrefix}. It matches the same keys as
     * {@link #removeAll(String)} and now uses SCAN instead of the blocking KEYS command.
     *
     * @param keyPrefix key prefix
     */
    public static void removeByPre(String keyPrefix) {
        removeAll(keyPrefix);
    }
    // ============================String=============================

    /**
     * Loads a value directly from Redis; used as the local cache's loader.
     *
     * @param key key
     * @return the stored value, or {@code null}
     */
    private static Object getWithoutCache(String key) {
        log.debug("从redis里取值:" + key);
        if (key == null) {
            return null;
        }
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * Reads a value directly from Redis, bypassing the local cache.
     *
     * @param key  key
     * @param type {@code 1} to gunzip a String value; any other value (including {@code null})
     *             returns the raw value
     * @return the (possibly decompressed) value, or {@code null}
     */
    public static Object getFromRedisCache(String key, Integer type) {
        if (key == null) {
            return null;
        }
        Object v = redisTemplate.opsForValue().get(key);
        // Null-safe comparison; `type == 1` threw NPE for a null type.
        if (Integer.valueOf(1).equals(type) && v instanceof String) {
            return GZIPUtils.uncompress((String) v);
        }
        return v;
    }

    /**
     * Reads a value through the local cache, loading it from Redis on a miss.
     *
     * @param key key
     * @return the value, or {@code null}
     */
    public static Object get(String key) {
        log.debug("从redis里取值:" + key);
        if (key == null) {
            return null;
        }
        return objectCache.get(key);
    }

    /**
     * Reads a value through the local cache and casts it to {@code clazz}.
     *
     * @param key   key
     * @param clazz expected type
     * @return the value, or {@code null} if absent or not of the expected type (the failed cast is logged)
     */
    public static <T> T get(String key, Class<T> clazz) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        if (null == v) {
            return null;
        }
        try {
            // Class.cast performs a real check, so the catch below can actually fire; the previous
            // unchecked (T) cast never threw here and pushed the failure to the caller.
            return clazz.cast(v);
        } catch (Exception ex) {
            log.error(ex);
        }
        return null;
    }

    /**
     * Reads a value through the local cache. A String value is treated as gzip-compressed JSON:
     * it is decompressed, parsed into {@code type}, and the parsed object replaces it in the local
     * cache.
     *
     * @param key  key
     * @param type target type of the JSON
     * @return the parsed value, the cached object itself, or {@code null}
     */
    @SuppressWarnings("unchecked")
    public static <T> T getUseUnzip(String key, TypeReference<T> type) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        if (null == v) {
            return null;
        }
        try {
            if (v instanceof String) {
                String searchNavJson = GZIPUtils.uncompress((String) v);
                if (StringUtil.isNotBlank(searchNavJson)) {
                    T result = JsonHelper.parseObject(searchNavJson, type);
                    objectCache.put(key, result);
                    return result;
                }
                return null;
            }
            // Unchecked: the generic type cannot be verified at runtime.
            return (T) v;
        } catch (Exception ex) {
            log.error(ex);
        }
        return null;
    }

    /**
     * Reads a value through the local cache.
     *
     * <p>The cast is unchecked (the generic type is erased), so a type mismatch surfaces as a
     * ClassCastException in the caller rather than here.
     *
     * @param key  key
     * @param type target type (used only for inference)
     * @return the value, or {@code null}
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key, TypeReference<T> type) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        if (null == v) {
            return null;
        }
        return (T) v;
    }

    /**
     * Stores a value in the local cache and in Redis, with a TTL of {@code expireAfterWrite} seconds.
     *
     * @param key   key
     * @param value value
     * @return {@code true} on success; {@code false} for null arguments or on failure
     */
    public static boolean set(String key, Object value) {
        if (key == null || value == null) {
            return false;
        }
        try {
            objectCache.put(key, value);
            redisTemplate.opsForValue().set(key, value, expireAfterWrite, TimeUnit.SECONDS);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Stores {@code value} in the local cache and its gzip-compressed JSON form in Redis,
     * with a TTL of {@code expireAfterWrite} seconds.
     *
     * @param key   key
     * @param value value
     * @return {@code true} on success; {@code false} for null arguments or on failure
     */
    public static boolean setUseZip(String key, Object value) {
        if (key == null || value == null) {
            return false;
        }
        try {
            String jsonString = JsonHelper.toJSONString(value);
            String zipStr = GZIPUtils.compress(jsonString);
            objectCache.put(key, value);
            redisTemplate.opsForValue().set(key, zipStr, expireAfterWrite, TimeUnit.SECONDS);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Stores a value in the local cache and in Redis with an explicit Redis TTL.
     * The local cache entry still expires after {@code expireAfterWrite} seconds.
     *
     * @param key             key
     * @param value           value
     * @param expireAfterSend Redis TTL in seconds
     * @return {@code true} on success; {@code false} for null arguments or on failure
     */
    public static boolean set(String key, Object value, Integer expireAfterSend) {
        if (key == null || value == null) {
            return false;
        }
        try {
            objectCache.put(key, value);
            redisTemplate.opsForValue().set(key, value, expireAfterSend, TimeUnit.SECONDS);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }
}
4.5 最后在方法上面加上@LimitRequest就行了

 


相关文章推荐

全部评论: 0

    我有话说: