From 948000e500b98792a37f177877d06024a4f64c6f Mon Sep 17 00:00:00 2001 From: haungxing <1203316822@qq.com> Date: Fri, 5 Jul 2024 14:39:00 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E5=A4=A7=E6=A8=A1=E5=9E=8Bsql?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../datasource/service/DataSourceService.java | 63 +++++++++++++--------- .../controller/HznlmInteractiveController.java | 11 ++++ .../bigmodel/interactive/dto/AuthDataDTO.java | 3 +- 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/datasource/service/DataSourceService.java b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/datasource/service/DataSourceService.java index 8509f2b..68dd57f 100644 --- a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/datasource/service/DataSourceService.java +++ b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/datasource/service/DataSourceService.java @@ -45,33 +45,44 @@ public class DataSourceService { }); String sql = sqlVO.getSql(); String userAuthDataSQL = userAuthDataService.getUserAuthDataSQL(Long.parseLong(sqlVO.getUserId())); - List> tempViewList = Lists.newArrayList(); - try { - for (TableAuthVO tableAuthVO : sqlVO.getTableAuthVOList()) { - // 创建视图语句 - String viewName = "V_TEMP_" + UUID.randomUUID().toString().replace("-", ""); - String createView = "CREATE VIEW " + viewName + " AS SELECT * FROM " + tableAuthVO.getTableName() + " where " + userAuthDataSQL; - this.updateOnSpecificDataSource(createView,tableAuthVO.getDatasourceName()); - Map viewMap = new HashMap(2); - viewMap.put("datasource",tableAuthVO.getDatasourceName()); - viewMap.put("viewName",viewName); - tempViewList.add(viewMap); - sql = sql.replace(tableAuthVO.getTableName(),viewName); - } - log.info("执行sql:{}",sql); - return this.queryListOnSpecificDataSource(sql, sqlVO.getTableAuthVOList().get(0).getDatasourceName()); - } - catch(Exception e) { - log.error("An Error occurred!",e); - throw new ServiceException("sql执行失败!"); - } - finally { - if(CollectionUtil.isNotEmpty(tempViewList)) { - tempViewList.forEach(viewMap -> { - this.updateOnSpecificDataSource("DROP VIEW IF EXISTS `" + viewMap.get("viewName")+"`;",viewMap.get("datasource")); - }); - } + for (TableAuthVO tableAuthVO : sqlVO.getTableAuthVOList()) { + String tableSubStr = "(SELECT * FROM " + tableAuthVO.getTableName() + " where" + userAuthDataSQL +") temp"; + sql = sql.replace(tableAuthVO.getTableName(),tableSubStr); } + return this.queryListOnSpecificDataSource(sql, sqlVO.getTableAuthVOList().get(0).getDatasourceName()); + // 过滤更新、删除语句 +// Assert.isTrue(!DataSourceService.isUpdateOrDelete(sqlVO.getSql()),() -> { +// throw new ServiceException("执行sql语句包含更新/删除操作,执行失败!"); +// }); +// String sql = sqlVO.getSql(); +// String userAuthDataSQL = userAuthDataService.getUserAuthDataSQL(Long.parseLong(sqlVO.getUserId())); +// List> tempViewList = Lists.newArrayList(); +// try { +// for (TableAuthVO tableAuthVO : sqlVO.getTableAuthVOList()) { +// // 创建视图语句 +// String viewName = "V_TEMP_" + UUID.randomUUID().toString().replace("-", ""); +// String createView = "CREATE VIEW " + viewName + " AS SELECT * FROM " + tableAuthVO.getTableName() + " where " + userAuthDataSQL; +// this.updateOnSpecificDataSource(createView,tableAuthVO.getDatasourceName()); +// Map viewMap = new HashMap(2); +// viewMap.put("datasource",tableAuthVO.getDatasourceName()); +// viewMap.put("viewName",viewName); +// tempViewList.add(viewMap); +// sql = sql.replace(tableAuthVO.getTableName(),viewName); +// } +// log.info("执行sql:{}",sql); +// return this.queryListOnSpecificDataSource(sql, sqlVO.getTableAuthVOList().get(0).getDatasourceName()); +// } +// catch(Exception e) { +// log.error("An Error occurred!",e); +// throw new ServiceException("sql执行失败!"); +// } +// finally { +// if(CollectionUtil.isNotEmpty(tempViewList)) { +// tempViewList.forEach(viewMap -> { +// this.updateOnSpecificDataSource("DROP VIEW IF EXISTS `" + viewMap.get("viewName")+"`;",viewMap.get("datasource")); +// }); +// } +// } } /** diff --git a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/controller/HznlmInteractiveController.java b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/controller/HznlmInteractiveController.java index 28b9ad3..1ebbacf 100644 --- a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/controller/HznlmInteractiveController.java +++ b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/controller/HznlmInteractiveController.java @@ -2,6 +2,8 @@ package com.hnac.hzims.bigmodel.interactive.controller; import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport; import com.hnac.hzims.bigmodel.BigModelConstants; +import com.hnac.hzims.bigmodel.datasource.service.DataSourceService; +import com.hnac.hzims.bigmodel.datasource.vo.SqlVO; import com.hnac.hzims.bigmodel.interactive.dto.AuthDataDTO; import com.hnac.hzims.bigmodel.interactive.req.ModelFunctionReq; import com.hnac.hzims.bigmodel.interactive.service.IHznlmInteractiveService; @@ -17,6 +19,7 @@ import org.springframework.web.bind.annotation.*; import javax.validation.Valid; import java.util.List; +import java.util.Map; /** * @Author: huangxing @@ -30,6 +33,7 @@ import java.util.List; public class HznlmInteractiveController { private final IHznlmInteractiveService interactiveService; + private final DataSourceService dataSourceService; @PostMapping(value = "/get_auth_data") @ApiOperation("获取鉴权数据") @@ -44,4 +48,11 @@ public class HznlmInteractiveController { public R resolve(@RequestBody ModelFunctionReq req) { return R.data(interactiveService.resolve(req)); } + + @PostMapping("/execute_query") + @ApiOperation("执行大模型sql") + @ApiOperationSupport(order = 1) + public R>> executeQuery(@RequestBody @Valid SqlVO sqlVO) { + return R.data(dataSourceService.queryListOnSpecificDataSource(sqlVO)); + } } diff --git a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/dto/AuthDataDTO.java b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/dto/AuthDataDTO.java index 3d4fba5..b891ec0 100644 --- a/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/dto/AuthDataDTO.java +++ b/hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/interactive/dto/AuthDataDTO.java @@ -19,12 +19,11 @@ import java.io.Serializable; @EqualsAndHashCode public class AuthDataDTO implements Serializable { - @JsonProperty("chat_id") + @JsonProperty("chatId") @ApiModelProperty("问答ID,用于获取前端发起问答传入缓存中数据") @NotBlank private String sessionId; - @JsonProperty("user_id") @ApiModelProperty("提问用户ID") @NotBlank private String userId;