MyBatis 插件修改 Sql 语句
大约 4 分钟
MyBatis 插件动态修改查询 sql
Mybatis 插件
插件(plugins)
MyBatis 允许你在映射语句执行过程中的某一点进行拦截调用。默认情况下,MyBatis 允许使用插件来拦截的方法调用包括:
- Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
- ParameterHandler (getParameterObject, setParameters)
- ResultSetHandler (handleResultSets, handleOutputParameters)
- StatementHandler (prepare, parameterize, batch, update, query)
这些类中方法的细节可以通过查看每个方法的签名来发现,或者直接查看 MyBatis 发行包中的源代码。 如果你想做的不仅仅是监控方法的调用,那么你最好相当了解要重写的方法的行为。 因为在试图修改或重写已有方法的行为时,很可能会破坏 MyBatis 的核心模块。 这些都是更底层的类和方法,所以使用插件的时候要特别当心。
通过 MyBatis 提供的强大机制,使用插件是非常简单的,只需实现 Interceptor 接口,并指定想要拦截的方法签名即可。(以上内容摘自 MyBatis 官方文档「配置 → 插件(plugins)」一节。)
MyBatisUtil
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * MyBatis helper methods used by the SQL-rewriting interceptors: building rewritten {@link BoundSql}s and
 * {@link MappedStatement}s, and appending a predicate to a statement's WHERE clause.
 */
public class MyBatisUtil {

    /** Matches the WHERE keyword in any letter case as a whole word. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bWHERE\\b", Pattern.CASE_INSENSITIVE);

    /** Clauses that may follow a WHERE clause; an appended predicate has to be inserted before them. */
    private static final Pattern TAIL_KEYWORD = Pattern.compile(
            "\\b(?:GROUP\\s+BY|HAVING|ORDER\\s+BY|LIMIT|FOR\\s+UPDATE|LOCK\\s+IN\\s+SHARE\\s+MODE)\\b",
            Pattern.CASE_INSENSITIVE);

    /**
     * SqlSource that always returns one pre-built BoundSql, regardless of the parameter object.
     * Used to build a MappedStatement whose SQL has already been rewritten.
     */
    public static class RewriteSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public RewriteSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Builds the BoundSql of {@code ms} for {@code parameterObject} and returns a copy whose SQL text has been
     * passed through {@code rewriteFunction}.
     *
     * @param ms              statement to render
     * @param parameterObject parameter object of the current call
     * @param rewriteFunction maps the original SQL text to the new SQL text
     * @return a new BoundSql with the rewritten SQL, the same parameter mappings and the additional parameters
     */
    public static BoundSql rewriteSql(MappedStatement ms, Object parameterObject, Function<String, String> rewriteFunction) {
        return rewriteSql(ms, ms.getBoundSql(parameterObject), rewriteFunction);
    }

    /**
     * Returns a copy of an existing {@code boundSql} whose SQL text has been passed through {@code rewriteFunction}.
     *
     * @param ms              statement the BoundSql belongs to (supplies the Configuration)
     * @param boundSql        BoundSql to copy
     * @param rewriteFunction maps the original SQL text to the new SQL text
     * @return a new BoundSql with the rewritten SQL, the same parameter mappings and the additional parameters
     */
    public static BoundSql rewriteSql(MappedStatement ms, BoundSql boundSql, Function<String, String> rewriteFunction) {
        BoundSql rewritten = new BoundSql(ms.getConfiguration(), rewriteFunction.apply(boundSql.getSql()),
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        // The new BoundSql does not carry the source's additional parameters (presumably values produced by
        // dynamic SQL such as <foreach>); the parameter mappings may reference them, so copy them over.
        copyAdditionalParam(boundSql, rewritten);
        return rewritten;
    }

    /**
     * Creates a copy of {@code ms} whose SqlSource always yields {@code boundSql}; used to hand a statement
     * with already-rewritten SQL to {@code Executor.update}/{@code Executor.query}.
     *
     * <p>All statement settings are copied, including the generated-key configuration (otherwise inserts with
     * generated keys would no longer populate the key properties) and {@code flushCacheRequired} (otherwise
     * updates would stop flushing the statement cache).
     *
     * @param ms       statement to copy
     * @param boundSql BoundSql (typically already rewritten) the copy should execute
     * @return the new MappedStatement
     */
    public static MappedStatement rewriteStatement(MappedStatement ms, BoundSql boundSql) {
        return new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), new RewriteSqlSource(boundSql), ms.getSqlCommandType())
                .resource(ms.getResource())
                .fetchSize(ms.getFetchSize())
                .statementType(ms.getStatementType())
                .resultSetType(ms.getResultSetType())
                .keyGenerator(ms.getKeyGenerator())
                // The builder accepts comma-separated names while the statement exposes arrays.
                .keyProperty(joinOrNull(ms.getKeyProperties()))
                .keyColumn(joinOrNull(ms.getKeyColumns()))
                .timeout(ms.getTimeout())
                .parameterMap(ms.getParameterMap())
                .resultMaps(ms.getResultMaps())
                .resultOrdered(ms.isResultOrdered())
                .resultSets(joinOrNull(ms.getResultSets()))
                .cache(ms.getCache())
                .flushCacheRequired(ms.isFlushCacheRequired())
                .useCache(ms.isUseCache())
                .databaseId(ms.getDatabaseId())
                .lang(ms.getLang())
                .build();
    }

    /** Joins {@code parts} with commas, or returns {@code null} when there are none. */
    private static String joinOrNull(String[] parts) {
        return parts == null ? null : String.join(",", parts);
    }

    /**
     * A simple {@code <column> <operator> '<value>'} condition.
     * The column ({@code filed}) and the operator are inlined verbatim and must come from trusted code.
     * The accessor names (getFiled/setFiled) are part of the public API and are therefore kept.
     */
    @Data
    @AllArgsConstructor
    public static class Predicate {
        private String filed;
        private String operator;
        private String value;

        /**
         * Renders the condition as SQL. The value becomes a quoted literal.
         * Backslashes are doubled and single quotes are doubled, so the value cannot terminate the literal
         * (this assumes MySQL-style literals; in standard-conforming databases a backslash would be doubled in
         * the compared value). A {@code null} value is rendered as {@code NULL}.
         * Prefer bound parameters for untrusted input.
         *
         * @return the SQL fragment
         */
        public String buildSql() {
            String literal = value == null
                    ? "NULL"
                    : "'" + value.replace("\\", "\\\\").replace("'", "''") + "'";
            return String.format("%s %s %s", filed, operator, literal);
        }
    }

    /**
     * Adds {@code predicate} to the WHERE clause of the outermost query in {@code sql}.
     *
     * <p>If the outermost query already has a WHERE clause, the result is
     * {@code WHERE <predicate> AND (<existing condition>)}. The parentheses keep an existing {@code OR} from
     * bypassing the predicate. Otherwise a new WHERE clause is inserted.
     *
     * <p>In both cases the predicate is placed before any trailing GROUP BY / HAVING / ORDER BY / LIMIT /
     * FOR UPDATE / LOCK IN SHARE MODE clause. Keywords are matched case-insensitively. Keywords inside
     * parentheses (subqueries) or quoted literals/identifiers are ignored.
     *
     * <p>Limitations: this is a lightweight scanner, not an SQL parser. UNION queries only get the predicate in
     * one branch, and backslash-escaped quotes inside literals are not recognised.
     *
     * @param sql       original SQL statement
     * @param predicate condition to add
     * @return the rewritten SQL statement
     */
    public static String appendWhere(String sql, Predicate predicate) {
        boolean[] topLevel = topLevelMask(sql);
        String condition = predicate.buildSql();

        int whereEnd = lastTopLevelEnd(WHERE_KEYWORD, sql, topLevel);
        int tailStart = firstTopLevelStart(TAIL_KEYWORD, sql, topLevel, Math.max(whereEnd, 0));
        if (tailStart < 0) {
            tailStart = sql.length();
        }
        String tail = sql.substring(tailStart);

        if (whereEnd < 0) {
            return sql.substring(0, tailStart) + " WHERE " + condition + " " + tail;
        }
        String existing = sql.substring(whereEnd, tailStart).trim();
        String merged = existing.isEmpty() ? condition : condition + " AND (" + existing + ")";
        return sql.substring(0, whereEnd) + " " + merged + " " + tail;
    }

    /**
     * Marks every character of {@code sql} that lies at parenthesis depth 0 and outside quoted
     * literals/identifiers ({@code '}, {@code "}, {@code `}). Quote characters and parentheses themselves are
     * marked {@code false}.
     */
    private static boolean[] topLevelMask(String sql) {
        boolean[] topLevel = new boolean[sql.length()];
        int depth = 0;
        char quote = 0;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (quote != 0) {
                // A doubled quote ('') closes and immediately reopens the literal, which keeps the state right.
                if (c == quote) {
                    quote = 0;
                }
            } else if (c == '\'' || c == '"' || c == '`') {
                quote = c;
            } else if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            } else {
                topLevel[i] = depth == 0;
            }
        }
        return topLevel;
    }

    /** Returns the end index of the last top-level match of {@code keyword} in {@code sql}, or -1 if none. */
    private static int lastTopLevelEnd(Pattern keyword, String sql, boolean[] topLevel) {
        int end = -1;
        Matcher matcher = keyword.matcher(sql);
        while (matcher.find()) {
            if (topLevel[matcher.start()]) {
                end = matcher.end();
            }
        }
        return end;
    }

    /**
     * Returns the start index of the first top-level match of {@code keyword} at or after {@code from}, or -1
     * if none.
     */
    private static int firstTopLevelStart(Pattern keyword, String sql, boolean[] topLevel, int from) {
        Matcher matcher = keyword.matcher(sql);
        while (matcher.find()) {
            if (matcher.start() >= from && topLevel[matcher.start()]) {
                return matcher.start();
            }
        }
        return -1;
    }

    /**
     * Copies the additional parameters referenced by {@code srcBoundSql}'s parameter mappings into
     * {@code tarBoundSql}.
     *
     * @param srcBoundSql BoundSql to read additional parameters from
     * @param tarBoundSql BoundSql to write them into
     */
    public static void copyAdditionalParam(BoundSql srcBoundSql, BoundSql tarBoundSql) {
        if (srcBoundSql.getParameterMappings() == null) {
            return;
        }
        for (ParameterMapping parameterMapping : srcBoundSql.getParameterMappings()) {
            String propName = parameterMapping.getProperty();
            if (srcBoundSql.hasAdditionalParameter(propName)) {
                tarBoundSql.setAdditionalParameter(propName, srcBoundSql.getAdditionalParameter(propName));
            }
        }
    }
}
插件源码
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.lang.reflect.Method;
import java.util.function.Function;
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
        ),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
public class MyInterceptor implements Interceptor {

    /**
     * Rewrites the SQL of the intercepted {@code Executor.query}/{@code Executor.update} call and lets the call
     * continue with a statement that executes the rewritten SQL.
     *
     * @param invocation the intercepted executor call
     * @return whatever the intercepted executor call returns
     * @throws Throwable anything thrown by the executor
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object param = args[1];
        Method method = ReflectUtils.getUniqueMethodByName(ms.getId());
        if (method != null) {
            // TODO inspect the mapper method here (e.g. its annotations); return invocation.proceed() to skip it
        }

        boolean isQuery = "query".equals(invocation.getMethod().getName());
        // The 6-argument query overload already carries the BoundSql that will be executed (possibly modified by
        // another plugin), so rewrite that one instead of regenerating it from the statement.
        boolean hasBoundSqlArg = isQuery && args.length == 6;
        BoundSql boundSql = hasBoundSqlArg ? (BoundSql) args[5] : ms.getBoundSql(param);

        BoundSql newSql;
        if (isQuery) {
            newSql = MyBatisUtil.rewriteSql(ms, boundSql, sql -> {
                // TODO rewrite the query SQL here
                return sql;
            });
        } else {
            newSql = MyBatisUtil.rewriteSql(ms, boundSql, sql -> {
                // TODO rewrite the update SQL here
                return sql;
            });
        }
        if (boundSql.getSql().equals(newSql.getSql())) {
            // Nothing was rewritten: avoid building a new MappedStatement for every call.
            return invocation.proceed();
        }
        MyBatisUtil.copyAdditionalParam(boundSql, newSql);

        if (hasBoundSqlArg) {
            Executor executor = (Executor) invocation.getTarget();
            RowBounds rowBounds = (RowBounds) args[2];
            // The cache key must describe the SQL that is actually executed.
            args[4] = executor.createCacheKey(ms, param, rowBounds, newSql);
            // This overload executes the BoundSql argument, not ms.getBoundSql(...), so it must be replaced too;
            // otherwise the original SQL would still run.
            args[5] = newSql;
        }
        args[0] = MyBatisUtil.rewriteStatement(ms, newSql);
        return invocation.proceed();
    }
}
注册插件
import lombok.AllArgsConstructor;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import java.util.List;
@Configuration
@AllArgsConstructor
public class MybatisConfig implements ApplicationListener<ContextRefreshedEvent> {
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Registers a fresh {@link MyInterceptor} with every known SqlSessionFactory.
     * Factories that already hold an interceptor of that class (or a subclass) are skipped.
     *
     * @param event the context-refreshed event (not inspected)
     */
    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        sqlSessionFactoryList.forEach(factory -> {
            org.apache.ibatis.session.Configuration mybatisConfiguration = factory.getConfiguration();
            MyInterceptor candidate = new MyInterceptor();
            if (containsInterceptor(mybatisConfiguration, candidate)) {
                return;
            }
            mybatisConfiguration.addInterceptor(candidate);
        });
    }

    /**
     * Tells whether {@code configuration} already has an interceptor whose class is {@code interceptor}'s class
     * or a subclass of it.
     * Any exception while inspecting the list is treated as "not present" (best effort).
     */
    private boolean containsInterceptor(org.apache.ibatis.session.Configuration configuration, Interceptor interceptor) {
        try {
            for (Interceptor registered : configuration.getInterceptors()) {
                if (interceptor.getClass().isAssignableFrom(registered.getClass())) {
                    return true;
                }
            }
            return false;
        } catch (Exception ignored) {
            return false;
        }
    }
}
使用示例:为带有 @IsOwner 注解的 DAO 方法自动添加查询条件 create_by = 当前登录用户
Dao方法
/**
 * Queries sys_role rows filtered by the non-null fields of {@code role} (see the mapper XML below).
 * Marked with {@code @IsOwner}, so the interceptor in this article adds {@code create_by = <current user>}.
 *
 * @param role filter values, referenced as {@code obj} in the mapper
 * @return the matching roles
 */
@IsOwner
List<Role> queryAllByLimit(@Param("obj") Role role);
Mapper 文件
<!-- Query sys_role rows filtered by the non-null fields of "obj"; the condition "id > 1" is always applied.
     NOTE(review): despite the statement id, no LIMIT clause is applied here. -->
<select id="queryAllByLimit" resultMap="RoleMap">
select
id,name,description,create_time,create_by, update_by, update_time
from sys_role
<where>
<if test="obj.id != null">
and id = #{obj.id}
</if>
<if test="obj.name != null and obj.name != ''">
and name like concat('%', #{obj.name}, '%')
</if>
<if test="obj.description != null and obj.description != ''">
and description = #{obj.description}
</if>
<if test="obj.createTime != null">
and create_time = #{obj.createTime}
</if>
<if test="obj.createBy != null and obj.createBy != ''">
and create_by = #{obj.createBy}
</if>
and id > 1
</where>
</select>
插件
import cn.linkot.utils.RContext;
import cn.linkot.utils.data.ReflectUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.lang.reflect.Method;
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
        ),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
public class MyInterceptor implements Interceptor {

    /** Column holding the user name of a row's creator. */
    private static final String OWNER_COLUMN = "create_by";

    /**
     * Restricts queries of {@code @IsOwner} mapper methods to rows created by the current user.
     * Other statements are passed through unchanged.
     *
     * @param invocation the intercepted executor call
     * @return whatever the executor returns
     * @throws Throwable anything thrown by the executor
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        Executor executor = (Executor) invocation.getTarget();
        MappedStatement ms = (MappedStatement) args[0];
        Method method = ReflectUtils.getUniqueMethodByName(ms.getId());
        // Mapper methods without @IsOwner are left untouched.
        // NOTE(review): when the method cannot be resolved (method == null, e.g. a generated "_COUNT" statement)
        // the owner filter is still applied - confirm this is intended.
        if (method != null && method.getAnnotation(IsOwner.class) == null) {
            return invocation.proceed();
        }
        Object parameter = args[1];

        if ("query".equals(invocation.getMethod().getName())) {
            RowBounds rowBounds = (RowBounds) args[2];
            ResultHandler resultHandler = (ResultHandler) args[3];
            // The 6-argument overload already carries the BoundSql to execute (possibly modified by another
            // plugin); rewrite that one instead of regenerating it from the statement.
            BoundSql original = args.length == 6 ? (BoundSql) args[5] : ms.getBoundSql(parameter);
            BoundSql boundSql = MyBatisUtil.rewriteSql(ms, original, MyInterceptor::addOwnerCondition);
            MyBatisUtil.copyAdditionalParam(original, boundSql);
            // Always derive the cache key from the rewritten SQL. A caller-supplied key describes the unfiltered
            // SQL and would let cached rows of one user be returned to another.
            CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
            return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
        }

        // Update / insert / delete
        BoundSql original = ms.getBoundSql(parameter);
        BoundSql boundSql = MyBatisUtil.rewriteSql(ms, original, sql -> {
            // TODO rewrite the update SQL here
            return sql;
        });
        MyBatisUtil.copyAdditionalParam(original, boundSql);
        return executor.update(MyBatisUtil.rewriteStatement(ms, boundSql), parameter);
    }

    /**
     * Adds {@code create_by = <current user>} to the WHERE clause of {@code sql}.
     * Unlike blind concatenation, this also works when there is no WHERE clause or there is a trailing
     * HAVING/LIMIT clause.
     * The user name is inlined as a SQL literal by {@code Predicate.buildSql()}; make sure that method escapes it.
     */
    private static String addOwnerCondition(String sql) {
        return MyBatisUtil.appendWhere(sql, new MyBatisUtil.Predicate(OWNER_COLUMN, "=", RContext.getUsername()));
    }
}
查询日志
DEBUG c.d.a.d.R.queryAllByLimit_COUNT - ==> Preparing: select id,name,description,create_time,create_by, update_by, update_time from sys_role WHERE id > 1 and create_by='admin'
DEBUG c.d.a.d.R.queryAllByLimit_COUNT - ==> Parameters:
DEBUG c.d.a.d.R.queryAllByLimit_COUNT - <== Total: 0