MyBatis 拦截器:使用拦截器优雅地实现数据权限拦截
项目需求:比如系统中有一张机构表,所有业务表都按照机构 ID(orgId)过滤数据权限。
常规做法是在每条 SQL 语句的最后加上 where orgId=? 来控制权限,但这样做非常麻烦:一旦权限规则需要修改,就要改动所有包含 orgId 字段的地方。下面演示如何使用拦截器统一完成这件事,让开发人员专注于业务代码,不用再关心权限过滤。
1. 建立一个数据权限枚举,统一管理每个权限字段对应的 SQL 片段,直接上代码。
package com.mkx.cn.web.entity.enums;
import lombok.Data;
import java.util.List;
/**
 * Data-permission rules, one constant per filtered field.
 *
 * <p>Each constant holds SQL templates intended for {@link String#format}: the first {@code %s}
 * receives a table name or alias, the second the value to filter on.
 *
 * @author wangh
 **/
public enum DataAuthEnum {
    /** Organisation rule: matches rows whose organisation's {@code path} JSON contains the value. */
    orgId("orgId",
            "EXISTS(SELECT * FROM `sys_org` o WHERE %s.org_id=o.id AND JSON_CONTAINS(o.`path`, '%s'))",
            "selectOrgId",
            "%s.org_id=%s",
            "JSON_CONTAINS(%s.`path`, '%s')"),
    /** Shop rule on a {@code shop_id} column: matches rows whose shop's {@code real_path} contains the value. */
    shopId("shopId",
            "EXISTS(SELECT * FROM `shops` s WHERE %s.shop_id=s.id and JSON_CONTAINS(s.`real_path`, '%s'))",
            "selectShopId",
            "%s.shop_id=%s",
            "JSON_CONTAINS(%s.`real_path`, '%s')"),
    /** Shop rule on a {@code sid} column: matches rows whose shop's {@code real_path} contains the value. */
    sid("sid",
            "EXISTS(SELECT * FROM `shops` s WHERE %s.sid=s.id and JSON_CONTAINS(s.`real_path`, '%s'))",
            "selectShopId",
            "%s.sid=%s",
            "JSON_CONTAINS(%s.`real_path`, '%s')");

    /** Name of the parameter/entity field whose presence triggers this rule. */
    private final String filed;
    /** Default condition template: (table or alias, value). */
    private final String filedSql;
    /** Name of the parameter field carrying an explicitly selected value. */
    private final String selectFiled;
    /** Condition template used when an explicit value is selected: (table or alias, value). */
    private final String selectFiledSql;
    /** Condition template used when the filtered table is addressed through a join alias. */
    private final String joinTableSql;

    DataAuthEnum(String filed, String filedSql, String selectFiled, String selectFiledSql, String joinTableSql) {
        this.filed = filed;
        this.filedSql = filedSql;
        this.selectFiled = selectFiled;
        this.selectFiledSql = selectFiledSql;
        this.joinTableSql = joinTableSql;
    }

    /**
     * Looks up the rule whose trigger field equals {@code filed}.
     *
     * @param filed field name to look up; may be {@code null}
     * @return the matching rule, or the first declared rule ({@link #orgId}) when nothing matches
     */
    public static DataAuthEnum getAuthFiled(String filed) {
        for (DataAuthEnum candidate : values()) {
            if (candidate.filed.equals(filed)) {
                return candidate;
            }
        }
        return values()[0];
    }

    /** Returns a fresh array containing every rule, in declaration order. */
    public static DataAuthEnum[] getAuthAll() {
        return values();
    }

    /** Returns the name of the field that triggers this rule. */
    public String getFiled() {
        return filed;
    }

    /** Returns the default condition template. */
    public String getFiledSql() {
        return filedSql;
    }

    /** Returns the name of the field carrying an explicitly selected value. */
    public String getSelectFiled() {
        return selectFiled;
    }

    /** Returns the condition template used when the table is addressed via a join alias. */
    public String getJoinTableSql() {
        return joinTableSql;
    }

    /** Returns the condition template used for an explicitly selected value. */
    public String getSelectFiledSql() {
        return selectFiledSql;
    }
}
2. 实现 MyBatis-Plus 拦截器(InnerInterceptor):在执行查询 SQL 之前,如果查询参数中含有 orgId、shopId 等字段,就自动为该查询追加 orgId=? 或枚举中定义好的权限 SQL 片段。效果如下:
执行前的原始 SQL(原文为截图)。
经过拦截器处理后,SQL 中自动追加了数据权限条件(原文为截图)。
package com.mkx.cn.web.interceptor;
import cn.hutool.core.annotation.AnnotationUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.ReflectUtil;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.mkx.cn.common.utils.JsoConst;
import com.mkx.cn.web.annotation.DataAuthInject;
import com.mkx.cn.web.annotation.JoinTableAlias;
import com.mkx.cn.web.config.UserContextHolder;
import com.mkx.cn.web.entity.enums.DataAuthEnum;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.io.StringReader;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
/**
* TODO 数据权限拦截器,自动增加sql段
*
* @author wangh
* @datetime 2022/10/18 16:32
*/
@Slf4j
public class MyBatisPlusAuthInterceptor extends JsqlParserSupport implements InnerInterceptor {
private String targetTableAlias = null;
@Override
public void beforeQuery(Executor executor MappedStatement ms Object parameter RowBounds rowBounds ResultHandler resultHandler BoundSql boundSql) {
try {
if ((isFilterMethod(ms) || getAuthAnnotation(ms))) {
DataAuthEnum[] authALL = DataAuthEnum.getAuthAll();
for (DataAuthEnum dataAuthEnum : authALL) {
//判断-是否需要添加企业过滤权限
boolean isAddShopAuth = JsoConst.mkxShop.equals(UserContextHolder.getInstance().getUserType()) || !ListUtil.of("shopId" "sid").contains(dataAuthEnum.getFiled());
if (isAddShopAuth && paramContainKey(parameter dataAuthEnum.getFiled())) {
makeSql(boundSql dataAuthEnum parameter);
}
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* @Author wangh
* @Description 判断包含需要过滤的字段
* @Date 2022/10/18 16:32
**/
private boolean paramContainKey(Object parameter String key) {
boolean haveKey = false;
try {
if (parameter instanceof MapperMethod.ParamMap) {
Map paramMap = (Map) parameter;
for (Object k : paramMap.keySet()) {
Object o = k.equals("ew") ? ((AbstractWrapper) paramMap.get(k)).getEntity() : paramMap.get(k);
haveKey = fieldIsExist(o key);
break;
}
} else {
haveKey = fieldIsExist(parameter key);
}
} catch (Exception e) {
e.printStackTrace();
}
return haveKey;
}
private boolean fieldIsExist(Object o String key) {
boolean haveField = ReflectUtil.hasField(o.getClass() key);
if (haveField) {
targetTableAlias = null;
Field field = ReflectUtil.getField(o.getClass() key);
TableField annotation = AnnotationUtil.getAnnotation(field TableField.class);
JoinTableAlias aliasNameAnnotation = AnnotationUtil.getAnnotation(field JoinTableAlias.class);
if (!Objects.isNull(aliasNameAnnotation)) {
targetTableAlias = aliasNameAnnotation.value();
}
return Objects.isNull(annotation) ? true : annotation.exist();
}
return false;
}
private Object paramGetValue(Object parameter String key) {
Object v = null;
try {
if (parameter instanceof MapperMethod.ParamMap) {
Map paramMap = (Map) parameter;
for (Object k : paramMap.keySet()) {
Object o = paramMap.get(k);
if (!ReflectUtil.hasField(o.getClass() key)) {
continue;
}
v = ReflectUtil.getFieldValue(o.getClass() key);
break;
}
} else {
ReflectUtil.getFieldValue(parameter.getClass() key);
}
} catch (SecurityException e) {
}
return v;
}
/**
* @Author wangh
* @Description 需要过滤的mybatis-plus 内部提供的方法
* @Date 2022/10/18 16:30
**/
private boolean isFilterMethod(MappedStatement mappedStatement) {
String[] methods = new String[]{"selectList" "selectPage" "selectListPage_COUNT"};
boolean isTrue = false;
try {
String id = mappedStatement.getId();
String methodName = id.substring(id.lastIndexOf(".") 1);
isTrue = Arrays.asList(methods).contains(methodName);
} catch (SecurityException e) {
}
return isTrue;
}
private String makeSql(BoundSql boundSql DataAuthEnum dataAuthEnum Object parameter) throws JSQLParserException {
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
String sql = mpBs.sql();
Select select = (Select) new CCJSqlParserManager().parse(new StringReader(sql));
PlainSelect plain = (PlainSelect) select.getSelectBody();
Table fromItem = (Table) plain.getFromItem();
//有别名用别名,无别名用表名,防止字段冲突报错
String mainTableName = !Objects.isNull(targetTableAlias) ? targetTableAlias :
Objects.isNull(fromItem.getAlias()) ? fromItem.getName() : fromItem.getAlias().getName();
boolean isOrg = dataAuthEnum.getFiled().equals("orgId") ? true : false;
//构建子查询
boolean isSelectValue = paramContainKey(parameter dataAuthEnum.getSelectFiled());
String formatSql = !Objects.isNull(targetTableAlias) ? dataAuthEnum.getJoinTableSql() : isSelectValue ? dataAuthEnum.getSelectFiledSql() : dataAuthEnum.getFiledSql();
String dataAuthSql = String.format(formatSql mainTableName
isSelectValue ? paramGetValue(parameter dataAuthEnum.getSelectFiled()) : isOrg ? UserContextHolder.getInstance().getOrgId() : UserContextHolder.getInstance().getShopId());
if (plain.getWhere() == null) {
plain.setWhere(CCJSqlParserUtil.parseCondExpression(dataAuthSql true));
} else {
plain.setWhere(new AndExpression(plain.getWhere() CCJSqlParserUtil.parseCondExpression(dataAuthSql)));
}
//重新构建sql
mpBs.sql(select.toString());
return select.toString();
}
/**
* 通过反射获取mapper方法是否加了自定义注解
*/
private boolean getAuthAnnotation(MappedStatement mappedStatement) {
DataAuthInject dataAuthInject = null;
try {
String id = mappedStatement.getId();
String className = id.substring(0 id.lastIndexOf("."));
String methodName = id.substring(id.lastIndexOf(".") 1);
final Class<?> cls = Class.forName(className);
final Method[] methods = cls.getMethods();
for (Method method : methods) {
if (method.getName().equals(methodName) && method.isAnnotationPresent(DataAuthInject.class)) {
dataAuthInject = method.getAnnotation(DataAuthInject.class);
break;
}
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
return dataAuthInject != null;
}
@Override
public void setProperties(Properties properties) {
}
}