Browse Source

fix: 大模型向量数据库同步

zhongwei
haungxing 2 months ago
parent
commit
f1f97fe9ee
  1. 10
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/configuration/BigModelInvokeApi.java
  2. 103
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/database/service/WeaviateService.java
  3. 11
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/interactive/controller/FontEndInteractiveController.java
  4. 2
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/DataSourceService.java
  5. 2
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/TablePropertyService.java
  6. 3
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/impl/VectorParamServiceImpl.java
  7. 7
      hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/utils/RequestClientUtil.java
  8. 3
      hzims-service/gglm-big-model/src/main/resources/template/template.yml

10
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/configuration/BigModelInvokeApi.java

@ -72,4 +72,14 @@ public class BigModelInvokeApi {
*/ */
private String identifyForm; private String identifyForm;
/**
* 增加向量
*/
private String insertVectors;
/**
* 新建向量表
*/
private String createTable;
} }

103
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/database/service/WeaviateService.java

@ -7,6 +7,7 @@ import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.hnac.gglm.bigmodel.configuration.BigModelInvokeApi; import com.hnac.gglm.bigmodel.configuration.BigModelInvokeApi;
import com.hnac.gglm.bigmodel.utils.RequestClientUtil;
import com.hnac.hzinfo.exception.HzServiceException; import com.hnac.hzinfo.exception.HzServiceException;
import io.weaviate.client.WeaviateClient; import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result; import io.weaviate.client.base.Result;
@ -44,6 +45,8 @@ public class WeaviateService {
@Value("${gglm.vectorUrl}") @Value("${gglm.vectorUrl}")
private String vectorUrl; private String vectorUrl;
@Value("${gglm.url}")
private String gglmUrl;
/** /**
* 对象保存向量数据库 * 对象保存向量数据库
@ -64,44 +67,74 @@ public class WeaviateService {
return !result.hasErrors(); return !result.hasErrors();
} }
/** // /**
* 对象批量保存向量数据库 // * 对象批量保存向量数据库
* @param entities 保存对象列表 // * @param entities 保存对象列表
* @param className 保存表名 // * @param className 保存表名
* @param attrsMap 待计算的列信息 key-向量名 value-实体类对象属性,多个按逗号分隔 // * @param attrsMap 待计算的列信息 key-向量名 value-实体类对象属性,多个按逗号分隔
* @return 保存操作结果 // * @return 保存操作结果
*/ // */
// public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) {
// entities = entities.subList(0, 1);
// ObjectCreator creator = weaviateClient.data().creator().withClassName(className);
// List<String> vectorStrs = Lists.newArrayList();
// List<String> attrs = Lists.newArrayList();
// if(Func.isNotEmpty(attrsMap)) {
// // 格式化数据
// attrsMap.forEach((k,v) -> attrs.add(v));
// // 解析待计算的向量字段
// entities.forEach(entity -> {
// List<String> vectorStr = attrs.stream().map(fields -> this.getFieldValue(fields, entity)).filter(Func::isNotEmpty).collect(Collectors.toList());
// vectorStrs.addAll(vectorStr);
// });
// }
// if(Func.isNotEmpty(vectorStrs)) {
// // 若解析出来的向量存在值
// Float[] vectors = this.compute(vectorStrs);
// List<Map<String, Float[]>> vector = this.splitVector(entities.size(), attrsMap, vectors);
// for(int i = 0; i < entities.size(); i++) {
// // log.info("vector:{}",JSON.toJSONString(vector.get(i)));
// Map<String, Object> properties = this.objectToMap(entities.get(i));
// log.info("properties:{}",JSON.toJSONString(properties));
// Result<WeaviateObject> run = creator.withProperties(properties).withVectors(vector.get(i)).run();
// if(run.hasErrors()) {
// log.error("保存失败!,保存结果为:{}",JSON.toJSONString(run));
// }
// }
// } else {
// entities.forEach(entity -> creator.withProperties(this.objectToMap(entity)).run());
// return true;
// }
// return false;
// }
public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) { public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) {
ObjectCreator creator = weaviateClient.data().creator().withClassName(className); Map<String,String> createTableParams = new HashMap<>(1);
List<String> vectorStrs = Lists.newArrayList(); createTableParams.put("table_name",className);
List<String> attrs = Lists.newArrayList(); RequestClientUtil.postCall(gglmUrl + invokeApi.getCreateTable(),createTableParams);
if(Func.isNotEmpty(attrsMap)) { Map<String,Object> params = new HashMap<>(2);
// 格式化数据 params.put("table_name", className);
attrsMap.forEach((k,v) -> attrs.add(v)); List<Map<String, Object>> data = new ArrayList<>();
// 解析待计算的向量字段 entities.forEach(entity -> data.add(this.getVectorData(entity,attrsMap)));
entities.forEach(entity -> { log.info("data:{}",JSON.toJSONString(data));
List<String> vectorStr = attrs.stream().map(fields -> this.getFieldValue(fields, entity)).filter(Func::isNotEmpty).collect(Collectors.toList()); params.put("data",data);
vectorStrs.addAll(vectorStr); String url = gglmUrl + invokeApi.getInsertVectors();
}); RequestClientUtil.postCall(url,params);
}
if(Func.isNotEmpty(vectorStrs)) {
// 若解析出来的向量存在值
Float[] vectors = this.compute(vectorStrs);
List<Map<String, Float[]>> vector = this.splitVector(entities.size(), attrsMap, vectors);
for(int i = 0; i < entities.size(); i++) {
// log.info("vector:{}",JSON.toJSONString(vector.get(i)));
Map<String, Object> properties = this.objectToMap(entities.get(i));
log.info("properties:{}",JSON.toJSONString(properties));
Result<WeaviateObject> run = creator.withProperties(properties).withVectors(vector.get(i)).run();
if(run.hasErrors()) {
log.error("保存失败!,保存结果为:{}",JSON.toJSONString(run));
}
}
} else {
entities.forEach(entity -> creator.withProperties(this.objectToMap(entity)).run());
return true; return true;
} }
return false;
/**
 * Builds the payload entry for a single entity: the entity itself plus one
 * key/content pair per entry of {@code attrsMap}.
 *
 * @param entity   the object to include in the payload; stored as-is under "object"
 * @param attrsMap vector definitions; each key is emitted as "key" and each value is
 *                 passed to {@code getFieldValue} together with the entity
 *                 (per the saveBatch comment in this file: key = vector name,
 *                 value = entity property names, comma-separated)
 * @return a map with two entries: "object" (the entity) and "vector" (a list of
 *         maps, each holding "key" and "content")
 */
private Map<String,Object> getVectorData(Object entity,Map<String,String> attrsMap) {
Map<String,Object> result = new HashMap<>(2);
result.put("object", entity);
List<Map<String,String>> vectors = new ArrayList<>();
// NOTE(review): attrsMap is dereferenced without a null check — confirm callers never pass null.
attrsMap.forEach((k,fields) -> {
Map<String,String> vector = new HashMap<>();
vector.put("key",k);
// getFieldValue is defined outside this view; presumably it resolves the listed fields on the entity to text — verify.
vector.put("content", this.getFieldValue(fields, entity));
vectors.add(vector);
});
result.put("vector", vectors);
return result;
} }
private Map<String,Object> objectToMap(Object object) { private Map<String,Object> objectToMap(Object object) {

11
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/interactive/controller/FontEndInteractiveController.java

@ -14,10 +14,7 @@ import org.springblade.core.secure.utils.AuthUtil;
import org.springblade.core.tool.api.IResultCode; import org.springblade.core.tool.api.IResultCode;
import org.springblade.core.tool.api.R; import org.springblade.core.tool.api.R;
import org.springblade.core.tool.utils.Func; import org.springblade.core.tool.utils.Func;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import java.util.List; import java.util.List;
@ -75,11 +72,11 @@ public class FontEndInteractiveController {
@ApiOperation("识别智能表单") @ApiOperation("识别智能表单")
@ApiOperationSupport(order = 6) @ApiOperationSupport(order = 6)
@GetMapping("/identifyForm") @PostMapping("/identifyForm")
public R identifyForm(@RequestParam("formStructure") @ApiParam("表单结构") String formStructure public R identifyForm(@RequestParam("formStructure") @ApiParam("表单结构") String formStructure
, @RequestParam(value = "file",required = false) @ApiParam("表单文件") MultipartFile file , @RequestParam(value = "file",required = false) @ApiParam("表单文件") MultipartFile file
, @RequestParam(value = "content",required = false) @ApiParam("用于提取的文本") String content , @RequestParam(value = "content",required = false) @ApiParam("用于提取的文本") String content
, @RequestParam(value = "chatId") String chatId) { , @RequestParam(value = "chatId",required = false) String chatId) {
if(Func.isEmpty(chatId)) { if(Func.isEmpty(chatId)) {
chatId = UUID.randomUUID().toString(); chatId = UUID.randomUUID().toString();
} }
@ -95,6 +92,6 @@ public class FontEndInteractiveController {
if(Func.isNotEmpty(answers) && answers.size() == 1 && answers.get(0).getStatus().intValue() == 0) { if(Func.isNotEmpty(answers) && answers.size() == 1 && answers.get(0).getStatus().intValue() == 0) {
return R.data(answers.get(0)); return R.data(answers.get(0));
} }
return R.data(null); return R.success("操作成功!");
} }
} }

2
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/DataSourceService.java

@ -1,5 +1,6 @@
package com.hnac.gglm.bigmodel.maintenance.service; package com.hnac.gglm.bigmodel.maintenance.service;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.hnac.gglm.bigmodel.maintenance.mapper.DatasourceMapper; import com.hnac.gglm.bigmodel.maintenance.mapper.DatasourceMapper;
import com.hnac.gglm.bigmodel.maintenance.entity.DatasourceEntity; import com.hnac.gglm.bigmodel.maintenance.entity.DatasourceEntity;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -14,6 +15,7 @@ import org.springframework.stereotype.Service;
@Service("dataSourceMaintenanceService") @Service("dataSourceMaintenanceService")
@Slf4j @Slf4j
@AllArgsConstructor @AllArgsConstructor
@DS("hznlm")
public class DataSourceService extends BaseServiceImpl<DatasourceMapper, DatasourceEntity> { public class DataSourceService extends BaseServiceImpl<DatasourceMapper, DatasourceEntity> {
} }

2
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/TablePropertyService.java

@ -1,5 +1,6 @@
package com.hnac.gglm.bigmodel.maintenance.service; package com.hnac.gglm.bigmodel.maintenance.service;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.hnac.gglm.bigmodel.maintenance.mapper.TablePropertyMapper; import com.hnac.gglm.bigmodel.maintenance.mapper.TablePropertyMapper;
import com.hnac.gglm.bigmodel.maintenance.entity.TablePropertyEntity; import com.hnac.gglm.bigmodel.maintenance.entity.TablePropertyEntity;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -14,6 +15,7 @@ import org.springframework.stereotype.Service;
@Service @Service
@AllArgsConstructor @AllArgsConstructor
@Slf4j @Slf4j
@DS("hznlm")
public class TablePropertyService extends BaseServiceImpl<TablePropertyMapper, TablePropertyEntity> { public class TablePropertyService extends BaseServiceImpl<TablePropertyMapper, TablePropertyEntity> {
} }

3
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/maintenance/service/impl/VectorParamServiceImpl.java

@ -135,7 +135,8 @@ public class VectorParamServiceImpl extends ServiceImpl<VectorParamMapper, Vecto
String key = iterator.next(); String key = iterator.next();
attrMap.put(key, rootNode.findValue(key).textValue()); attrMap.put(key, rootNode.findValue(key).textValue());
} }
weaviateService.saveBatch(response.getOriginalData(), entity.getTableName(), attrMap); String tableName = entity.getTableName().replace(entity.getProjectPrefix() + "_","");
weaviateService.saveBatch(response.getOriginalData(), tableName, attrMap);
this.update(Wrappers.<VectorParamEntity>lambdaUpdate().eq(VectorParamEntity::getId, id).set(VectorParamEntity::getUpdateTime, new Date())); this.update(Wrappers.<VectorParamEntity>lambdaUpdate().eq(VectorParamEntity::getId, id).set(VectorParamEntity::getUpdateTime, new Date()));
} }
} }

7
hzims-service/gglm-big-model/src/main/java/com/hnac/gglm/bigmodel/utils/RequestClientUtil.java

@ -8,6 +8,7 @@ import com.alibaba.fastjson.TypeReference;
import com.hnac.hzinfo.exception.HzServiceException; import com.hnac.hzinfo.exception.HzServiceException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springblade.core.tool.api.ResultCode; import org.springblade.core.tool.api.ResultCode;
import org.springblade.core.tool.utils.Func;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -28,7 +29,9 @@ public class RequestClientUtil {
*/ */
public static void postCall(String url, String body) { public static void postCall(String url, String body) {
HttpResponse response = HttpRequest.post(url).body(body).execute(); HttpResponse response = HttpRequest.post(url).body(body).execute();
if(Func.isNotEmpty(response.body()) && !"[]".equals(response.body())) {
log.info("接口调用结果为:{}",response.body()); log.info("接口调用结果为:{}",response.body());
}
Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> { Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> {
throw new HzServiceException(ResultCode.FAILURE, "远程调用接口" + url + "失败!"); throw new HzServiceException(ResultCode.FAILURE, "远程调用接口" + url + "失败!");
}); });
@ -41,7 +44,9 @@ public class RequestClientUtil {
*/ */
public static void postCall(String url, Map body) { public static void postCall(String url, Map body) {
HttpResponse response = HttpRequest.post(url).body(JSON.toJSONString(body)).execute(); HttpResponse response = HttpRequest.post(url).body(JSON.toJSONString(body)).execute();
if(Func.isNotEmpty(response.body()) && !"[]".equals(response.body())) {
log.info("接口调用结果为:{}",response.body()); log.info("接口调用结果为:{}",response.body());
}
Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> { Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> {
throw new HzServiceException(ResultCode.FAILURE, "远程调用接口" + url + "失败!"); throw new HzServiceException(ResultCode.FAILURE, "远程调用接口" + url + "失败!");
}); });
@ -60,7 +65,9 @@ public class RequestClientUtil {
Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> { Assert.isTrue(response.getStatus() == HttpServletResponse.SC_OK, () -> {
throw new HzServiceException(ResultCode.FAILURE, "远程调用大模型接口" + url + "失败!"); throw new HzServiceException(ResultCode.FAILURE, "远程调用大模型接口" + url + "失败!");
}); });
if(Func.isNotEmpty(response.body()) && !"[]".equals(response.body())) {
log.info("接口调用结果为:{}",response.body()); log.info("接口调用结果为:{}",response.body());
}
return JSONObject.parseObject(response.body(), typeRef); return JSONObject.parseObject(response.body(), typeRef);
} }

3
hzims-service/gglm-big-model/src/main/resources/template/template.yml

@ -64,6 +64,9 @@ gglm:
assistantAnalyseAsk: "/qa/assistant_analyse_ask" assistantAnalyseAsk: "/qa/assistant_analyse_ask"
updateKnowledge: "/kn/update_knowledge" updateKnowledge: "/kn/update_knowledge"
compute: "/compute" compute: "/compute"
identifyForm: "/custom/auto_form"
insertVectors: "/vector/insert_vectors"
createTable: "/vector/create_table"
swagger: swagger:
base-packages: com.hnac.hzims.bigmodel base-packages: com.hnac.hzims.bigmodel

Loading…
Cancel
Save