1、背景
由于博客开放了注册,有人恶意注册账号并发布涉灰(灰色产业)广告推文。
2、 原理
在一定时间内限制同一 IP 对接口(主要是评论、发文接口)的请求次数;超过次数后,会发送邮件到站长邮箱,通知站长删除恶意文章、评论。(下文代码只演示限流部分,不包含发送邮件的逻辑。)
3、用到的技术
自定义注解、AOP、ExpiringMap(带有有效期的映射)
4、具体代码实现
4.1 引入依赖
<!-- AOP依赖 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
<version>2.1.5.RELEASE</version>
</dependency>
<!-- Map依赖 -->
<dependency>
<groupId>net.jodah</groupId>
<artifactId>expiringmap</artifactId>
<version>0.5.8</version>
</dependency>
4.2 自定义注解@LimitRequest
/**
* @author SongBin on 2021-11-12.
*/
package com.mtons.mblog.web.interceptor;
import java.lang.annotation.*;
/**
 * Limits how many times one client may call the annotated handler method within a time window.
 *
 * <p>With the defaults a client may call the method once every ten minutes
 * (e.g. publish one article per ten minutes).
 */
@Documented
@Target(ElementType.METHOD) // only valid on methods
@Retention(RetentionPolicy.RUNTIME) // must be readable at runtime by the aspect/interceptor
public @interface LimitRequest {
    /** Length of the limiting window in milliseconds; defaults to 10 minutes. */
    long time() default 10 * 60 * 1000L;

    /** Number of requests a client may make within one window; defaults to 1. */
    int count() default 1;
}
4.3 自定义AOP
/**
* @author SongBin on 2021-11-12.
*/
package com.mtons.mblog.web.interceptor;
import com.alibaba.fastjson.JSON;
import com.mtons.mblog.base.lang.Result;
import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Aspect that enforces {@link LimitRequest} on annotated methods.
 *
 * <p>Requests are counted per request URI and per client address ({@code getRemoteAddr()}).
 * Once a client reaches {@link LimitRequest#count()} within {@link LimitRequest#time()}
 * milliseconds, the target method is not invoked and an error {@link ModelAndView} is
 * returned instead.
 */
@Aspect
@Component
public class LimitRequestAspect {
    /** View rendered when a client exceeds the limit. */
    private static final String ERROR_VIEW = "/error";

    /**
     * Request URI -> (client address -> requests made in the current window).
     * The inner ExpiringMap drops a client's entry when its window ends.
     * NOTE(review): one inner map is kept per distinct URI and never removed; URIs containing
     * path variables will grow this map without bound — confirm annotated endpoints use fixed paths.
     */
    private static final ConcurrentHashMap<String, ExpiringMap<String, Integer>> COUNTERS =
            new ConcurrentHashMap<>();

    // Pointcut: every method carrying @LimitRequest; binds the annotation instance.
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    /**
     * Counts the current request and either proceeds to the target method or rejects it.
     *
     * @param pjp          the intercepted invocation
     * @param limitRequest the annotation on the intercepted method
     * @return the target method's result, or an error view when the limit is exceeded
     * @throws Throwable whatever the target method throws
     */
    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        if (!(ra instanceof ServletRequestAttributes)) {
            // Not inside a servlet request (e.g. called from a background task):
            // there is no client to count, so just run the method.
            return pjp.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) ra).getRequest();

        // computeIfAbsent installs exactly one map per URI; a getOrDefault/put pair lets
        // concurrent first requests overwrite each other's map and lose counts.
        ExpiringMap<String, Integer> counts = COUNTERS.computeIfAbsent(request.getRequestURI(),
                uri -> ExpiringMap.builder().variableExpiration().build());

        if (!tryAcquire(counts, request.getRemoteAddr(), limitRequest)) {
            Map<String, Object> model = new HashMap<>();
            model.put("error", "接口请求超过次数");
            return new ModelAndView(ERROR_VIEW, model);
        }
        // The target method's own return value.
        return pjp.proceed();
    }

    /**
     * Records one request for {@code client} if it is still under the limit.
     *
     * @return true if the request is allowed (and was counted), false if the limit is reached
     */
    private static boolean tryAcquire(ExpiringMap<String, Integer> counts, String client,
                                      LimitRequest limit) {
        // Make read-increment-write atomic per URI so concurrent requests cannot both pass.
        synchronized (counts) {
            int used = counts.getOrDefault(client, 0);
            if (used >= limit.count()) {
                return false;
            }
            // The first request opens a window that expires `time` ms after creation.
            // replace() only writes while the entry still exists, so an entry that expired after
            // the read above opens a fresh window here instead of getting the map's default TTL.
            if (used == 0 || counts.replace(client, used + 1) == null) {
                counts.put(client, 1, ExpirationPolicy.CREATED, limit.time(), TimeUnit.MILLISECONDS);
            }
            return true;
        }
    }
}
第一个静态Map是线程安全的Map(ConcurrentHashMap),它的key是接口对应的URL,value是一个线程安全、且键值对带有有效期的Map(ExpiringMap)。
ExpiringMap的key是请求的ip地址,value是已经请求的次数。
ExpiringMap更多的使用方法可以参考:https://github.com/jhalterman/expiringmap
4.4 上面的LimitRequestAspect也可以用一个拦截器(HandlerInterceptor)代替,并把请求次数存到Redis里(以下代码取自另一个项目,包名与上文不同)
import cnki.bdms.aop.LimitRequest;
import cnki.bdms.common.enums.ReturnStatus;
import cnki.bdms.common.util.JsonHelper;
import cnki.bdms.util.IPUtils;
import cnki.bdms.util.RedisUtil;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
 * Spring MVC interceptor that limits how often one client IP may call an endpoint annotated
 * with {@link LimitRequest}. Counters live in Redis, keyed by request URI and IP.
 *
 * User: SongBin
 * Date: 11:31 2022/4/25
 * Description: limits how frequently a user may call an API
 * Version 1.0
 */
@Component
public class ApiIPVisitInterception extends HandlerInterceptorAdapter {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Namespace for the counter keys, so they cannot collide with other Redis keys. */
    private static final String KEY_PREFIX = "limit_request:";

    /**
     * Optional window override in minutes; when null or not positive,
     * {@link LimitRequest#time()} (milliseconds) is used instead.
     */
    @Value("${sms.expireAfterSend}")
    private Integer expireAfterSend;

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
                           ModelAndView modelAndView) throws Exception {
        super.postHandle(request, response, handler, modelAndView);
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        return checkUrl(request, response, handler);
    }

    /**
     * Counts the request against its limit and writes a JSON rejection when the limit is reached.
     *
     * @return true to continue handling the request, false if it was rejected
     */
    private boolean checkUrl(HttpServletRequest request, HttpServletResponse response, Object handler) throws IOException {
        // Static resources and other non-controller handlers are not HandlerMethods. The
        // interceptor is registered for every path, so skip them instead of failing the cast.
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        LimitRequest accessLimit = ((HandlerMethod) handler).getMethodAnnotation(LimitRequest.class);
        if (accessLimit == null) {
            return true;
        }
        // Window length in milliseconds (long: LimitRequest.time() is a long).
        long time = expireAfterSend != null && expireAfterSend > 0
                ? 1000L * 60 * expireAfterSend
                : accessLimit.time();
        // RedisUtil.set(key, value, ttl) takes the TTL in SECONDS; round up so short windows
        // still get a positive TTL.
        int windowSeconds = clampSeconds((time + 999L) / 1000L);
        int maxCount = accessLimit.count();

        String ip = IPUtils.getIP();
        // Count per endpoint and IP (like the AOP variant) under a dedicated key prefix.
        String key = KEY_PREFIX + request.getRequestURI() + ":" + ip;

        // Read Redis directly: RedisUtil.get() goes through a local cache whose global TTL can
        // keep returning an old count after the Redis key has already expired.
        Object stored = RedisUtil.getFromRedisCache(key, 0);
        Integer count = stored instanceof Number ? ((Number) stored).intValue() : null;

        if (count == null) {
            // First request: open a window of the configured length.
            RedisUtil.set(key, 1, windowSeconds);
            return true;
        }
        if (count < maxCount) {
            // Keep the remaining TTL so further requests do not extend the window.
            long remaining = RedisUtil.getExpire(key);
            RedisUtil.set(key, count + 1, remaining > 0 ? clampSeconds(remaining) : windowSeconds);
            return true;
        }

        // Limit exceeded.
        logger.info("访问过快ip ===> {} 且在 {} 毫秒内超过最大限制 ===> {} 次数达到 ====> {}", ip, time, maxCount, count);
        ObjectNode result = JsonHelper.createNode();
        result.put("code", ReturnStatus.ApiLock.getValue());
        result.put("msg", "接口请求太频繁,系统认定恶意攻击!");
        response.setHeader("Access-Control-Allow-Origin", "*");
        response.setHeader("Access-Control-Allow-Headers", "token, Accept, Origin, X-Requested-With, Content-Type, Last-Modified");
        response.addHeader("Content-Type", "application/json;charset=UTF-8");
        response.setStatus(200);
        response.getWriter().write(result.toString());
        return false;
    }

    /** Clamps a number of seconds into [1, Integer.MAX_VALUE] for RedisUtil.set's Integer TTL. */
    private static int clampSeconds(long seconds) {
        return (int) Math.max(1L, Math.min(Integer.MAX_VALUE, seconds));
    }
}
然后在WebMvcConfig中注册拦截器:
import cnki.bdms.web.ui.security.ApiIPVisitInterception;
import cnki.bdms.web.ui.security.CorsInterceptor;
import cnki.bdms.web.ui.security.URLInterception;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.alibaba.fastjson.support.config.FastJsonConfig;
import com.alibaba.fastjson.support.spring.FastJsonHttpMessageConverter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
/**
 * Spring MVC configuration that registers the application's interceptors.
 */
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
    @Autowired
    private URLInterception urlInterception;
    @Autowired
    private ApiIPVisitInterception apiIPVisitInterception;
    @Autowired
    private CorsInterceptor corsInterceptor;

    /**
     * Registers the interceptors; they are invoked in registration order:
     * URL check, CORS handling, then per-IP rate limiting.
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(urlInterception).addPathPatterns("/**");
        // CORS handling for search (cross-origin) requests.
        registry.addInterceptor(corsInterceptor).addPathPatterns("/**");
        // No path pattern: applies to every request; the interceptor looks for @LimitRequest itself.
        registry.addInterceptor(apiIPVisitInterception);
    }
}
RedisUtil代码:
import cnki.bdms.base.StringUtil;
import cnki.bdms.common.log.LogOpr;
import cnki.bdms.common.util.JsonHelper;
import com.fasterxml.jackson.core.type.TypeReference;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
 * Static facade over a Spring {@link RedisTemplate} with an in-process Caffeine cache in front.
 *
 * <p>Reads ({@code get*}) go through {@code objectCache}, which loads from Redis on a miss.
 * Writes ({@code set*}) go to Redis and then to the local cache. Local entries expire
 * {@code expireAfterWrite} seconds after being written, independent of the key's Redis TTL.
 */
@Component
@Lazy(false)
public class RedisUtil {
    private static final LogOpr log = new LogOpr(RedisUtil.class);

    /** COUNT hint per SCAN round trip, so one SCAN call never walks the whole keyspace. */
    private static final int SCAN_BATCH_SIZE = 1000;

    private static RedisTemplate<String, Object> redisTemplate;
    private static Integer maxCacheSize;
    private static Integer initialCapacity;
    // Bound from configuration but unused: Caffeine has no concurrency-level setting.
    private static Integer concurrencyLevel;
    // Local-cache TTL and default Redis TTL for set(key, value), in seconds.
    private static Integer expireAfterWrite;
    // In-process cache in front of Redis.
    private static LoadingCache<String, Object> objectCache;

    @Value("${guava.cache.maxCacheSize}")
    public void setMaxCacheSize(Integer maxCacheSize) {
        RedisUtil.maxCacheSize = maxCacheSize;
    }

    @Value("${guava.cache.initialCapacity}")
    public void setInitialCapacity(Integer initialCapacity) {
        RedisUtil.initialCapacity = initialCapacity;
    }

    @Value("${guava.cache.concurrencyLevel}")
    public void setConcurrencyLevel(Integer concurrencyLevel) {
        RedisUtil.concurrencyLevel = concurrencyLevel;
    }

    @Value("${guava.cache.expireAfterWrite}")
    public void setExpireAfterWrite(Integer expireAfterWrite) {
        RedisUtil.expireAfterWrite = expireAfterWrite;
    }

    @Autowired
    public void setRedisTemplate(RedisTemplate<String, Object> redisTemplate) {
        RedisUtil.redisTemplate = redisTemplate;
    }

    /**
     * Builds the local cache once the configuration values above have been injected.
     */
    @PostConstruct
    public void init() {
        objectCache = Caffeine.newBuilder()
                // Previously maxCacheSize was passed here, ignoring the configured initialCapacity.
                .initialCapacity(initialCapacity)
                .maximumSize(maxCacheSize)
                .expireAfterWrite(expireAfterWrite, TimeUnit.SECONDS)
                .build(key -> RedisUtil.getWithoutCache(key));
    }

    /**
     * Sets the time-to-live of a key.
     *
     * @param key  key
     * @param time TTL in seconds; values <= 0 leave the key unchanged
     * @return false if Redis threw, true otherwise
     */
    public static boolean expire(String key, long time) {
        try {
            if (time > 0) {
                redisTemplate.expire(key, time, TimeUnit.SECONDS);
            }
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Returns the remaining time-to-live of a key in seconds.
     *
     * @param key key, must not be null
     * @return remaining seconds; negative per Redis TTL semantics when the key has no expiry (-1)
     *         or does not exist (-2)
     */
    public static long getExpire(String key) {
        Long ttl = redisTemplate.getExpire(key, TimeUnit.SECONDS);
        // RedisTemplate returns null inside a pipeline/transaction; report "no such key".
        return ttl == null ? -2L : ttl;
    }

    /**
     * Checks whether a key exists in Redis.
     *
     * @param key key
     * @return true if it exists; false if it does not or Redis threw
     */
    public static boolean hasKey(String key) {
        try {
            return Boolean.TRUE.equals(redisTemplate.hasKey(key));
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Deletes keys from Redis and from the local cache.
     *
     * <p>Each argument is treated as a PREFIX: every key starting with it is deleted
     * (e.g. "user:1" also removes "user:10"). If no key matches, the exact key is deleted.
     *
     * @param key one or more keys/prefixes
     */
    public static void remove(String... key) {
        if (key == null || key.length == 0) {
            return;
        }
        try {
            for (String item : key) {
                Set<String> keys = getRedisKeys(item);
                if (keys.isEmpty()) {
                    redisTemplate.delete(item);
                    objectCache.invalidate(item);
                } else {
                    redisTemplate.delete(keys);
                    objectCache.invalidateAll(keys);
                }
            }
        } catch (Exception e) {
            log.error(e);
        }
    }

    /**
     * Returns all keys starting with {@code realKey}.
     *
     * <p>Uses SCAN instead of KEYS: Redis executes commands on a single thread, and KEYS is O(N)
     * over the whole keyspace in a single call, blocking every other client while it runs.
     * SCAN with a bounded COUNT spreads the work over many short calls.
     *
     * @param realKey key prefix
     * @return matching keys; empty (never null) if none match or Redis threw
     */
    public static Set<String> getRedisKeys(String realKey) {
        try {
            Set<String> found = redisTemplate.execute((RedisCallback<Set<String>>) connection -> {
                Set<String> keys = new HashSet<>();
                ScanOptions options = new ScanOptions.ScanOptionsBuilder()
                        .match(realKey + "*")
                        .count(SCAN_BATCH_SIZE)
                        .build();
                try (Cursor<byte[]> cursor = connection.scan(options)) {
                    while (cursor.hasNext()) {
                        keys.add(new String(cursor.next(), java.nio.charset.StandardCharsets.UTF_8));
                    }
                } catch (Exception e) {
                    // Broad on purpose: Cursor.close() declares IOException only in some
                    // Spring Data Redis versions.
                    log.error(e);
                }
                return keys;
            });
            return found == null ? new HashSet<>() : found;
        } catch (Throwable e) {
            log.error(e);
        }
        return new HashSet<>();
    }

    /**
     * Deletes every key starting with {@code keyPrefix} from Redis and the local cache.
     *
     * @param keyPrefix key prefix
     */
    public static void removeAll(String keyPrefix) {
        try {
            Set<String> keys = getRedisKeys(keyPrefix);
            if (!keys.isEmpty()) {
                redisTemplate.delete(keys);
                objectCache.invalidateAll(keys);
            }
        } catch (Throwable e) {
            log.error(e);
        }
    }

    /**
     * Deletes every key starting with {@code keyPrefix}. Same as {@link #removeAll(String)};
     * it previously used the blocking KEYS command.
     *
     * @param keyPrefix key prefix
     */
    public static void removeByPre(String keyPrefix) {
        removeAll(keyPrefix);
    }

    // ============================String=============================

    /**
     * Reads a value straight from Redis (loader for the local cache).
     *
     * @param key key
     * @return the stored value, or null for a null key or a missing entry
     */
    private static Object getWithoutCache(String key) {
        log.debug("从redis里取值:" + key);
        if (key == null) {
            return null;
        }
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * Reads a value straight from Redis, bypassing the local cache.
     *
     * @param key  key
     * @param type 1 to GZIP-decompress a String value; any other value (or null) returns it as stored
     * @return the (possibly decompressed) value, or null
     */
    public static Object getFromRedisCache(String key, Integer type) {
        if (key == null) {
            return null;
        }
        Object v = redisTemplate.opsForValue().get(key);
        if (Integer.valueOf(1).equals(type) && v instanceof String) {
            return GZIPUtils.uncompress((String) v);
        }
        return v;
    }

    /**
     * Reads a value through the local cache, loading it from Redis on a miss.
     *
     * @param key key
     * @return the value, or null
     */
    public static Object get(String key) {
        log.debug("从redis里取值:" + key);
        if (key == null) {
            return null;
        }
        return objectCache.get(key);
    }

    /**
     * Reads a value through the local cache and casts it to {@code clazz}.
     *
     * @param key   key
     * @param clazz expected type
     * @return the value, or null if absent or not an instance of {@code clazz} (logged)
     */
    public static <T> T get(String key, Class<T> clazz) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        if (v == null) {
            return null;
        }
        try {
            // Checked cast: a type mismatch is caught here instead of surfacing in the caller.
            return clazz.cast(v);
        } catch (Exception ex) {
            log.error(ex);
            return null;
        }
    }

    /**
     * Reads a value stored by {@link #setUseZip}: a GZIP-compressed JSON string is decompressed,
     * parsed into {@code type} and put back into the local cache as the parsed object.
     *
     * @param key  key
     * @param type target type for JSON parsing
     * @return the parsed value, the cached object, or null
     */
    @SuppressWarnings("unchecked")
    public static <T> T getUseUnzip(String key, TypeReference<T> type) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        if (v == null) {
            return null;
        }
        try {
            if (!(v instanceof String)) {
                // Already the deserialized object (e.g. put locally by setUseZip or a previous call).
                return (T) v;
            }
            String json = GZIPUtils.uncompress((String) v);
            if (StringUtil.isNotBlank(json)) {
                T result = JsonHelper.parseObject(json, type);
                objectCache.put(key, result);
                return result;
            }
            return null;
        } catch (Exception ex) {
            log.error(ex);
        }
        return null;
    }

    /**
     * Reads a value through the local cache.
     * NOTE(review): {@code type} is not used for conversion; the value is returned with an
     * unchecked cast, so a mismatch surfaces as ClassCastException at the caller.
     *
     * @param key  key
     * @param type expected type (informational only)
     * @return the value, or null
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String key, TypeReference<T> type) {
        if (key == null) {
            return null;
        }
        Object v = objectCache.get(key);
        return v == null ? null : (T) v;
    }

    /**
     * Stores a value in Redis with the default TTL ({@code expireAfterWrite} seconds) and in the local cache.
     *
     * @param key   key
     * @param value value
     * @return true on success; false for null arguments or a Redis error
     */
    public static boolean set(String key, Object value) {
        if (key == null || value == null) {
            return false;
        }
        try {
            // Write Redis first so a failed write does not leave a value only in the local cache.
            redisTemplate.opsForValue().set(key, value, expireAfterWrite, TimeUnit.SECONDS);
            objectCache.put(key, value);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Stores a value in Redis as GZIP-compressed JSON (default TTL) and the original object in the local cache.
     *
     * @param key   key
     * @param value value
     * @return true on success; false for null arguments or an error
     */
    public static boolean setUseZip(String key, Object value) {
        if (key == null || value == null) {
            return false;
        }
        try {
            String jsonString = JsonHelper.toJSONString(value);
            String zipStr = GZIPUtils.compress(jsonString);
            redisTemplate.opsForValue().set(key, zipStr, expireAfterWrite, TimeUnit.SECONDS);
            objectCache.put(key, value);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }

    /**
     * Stores a value in Redis with an explicit TTL and in the local cache.
     * NOTE(review): the local cache still uses {@code expireAfterWrite}, so {@link #get(String)}
     * may return the value after the Redis key has expired; use {@link #getFromRedisCache} when
     * the Redis TTL matters.
     *
     * @param key             key
     * @param value           value
     * @param expireAfterSend Redis TTL in seconds
     * @return true on success; false for null arguments or a Redis error
     */
    public static boolean set(String key, Object value, Integer expireAfterSend) {
        if (key == null || value == null) {
            return false;
        }
        try {
            redisTemplate.opsForValue().set(key, value, expireAfterSend, TimeUnit.SECONDS);
            objectCache.put(key, value);
            return true;
        } catch (Exception e) {
            log.error(e);
            return false;
        }
    }
}
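4.5 最后,在需要限流的Controller方法上加上@LimitRequest注解即可。例如@LimitRequest(time = 60000, count = 3)表示同一IP在1分钟内最多请求该接口3次(拦截器版本中,若配置了sms.expireAfterSend,则时间窗口以该配置为准)。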
注意:本文归作者所有,未经作者允许,不得转载