Browse Source

fix: 向量数据库操作api

zhongwei
haungxing 3 months ago
parent
commit
94fd6ace14
  1. 38
      hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/controller/WeaviateController.java
  2. 33
      hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/dto/WeaviateSaveDTO.java
  3. 48
      hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/service/WeaviateService.java
  4. 9
      hzims-service/hzims-big-model/src/main/resources/template/template.yml

38
hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/controller/WeaviateController.java

@ -0,0 +1,38 @@
package com.hnac.hzims.bigmodel.database.controller;
import com.hnac.hzims.bigmodel.database.dto.WeaviateSaveDTO;
import com.hnac.hzims.bigmodel.database.service.WeaviateService;
import io.weaviate.client.v1.data.model.WeaviateObject;
import lombok.AllArgsConstructor;
import org.springblade.core.tool.api.R;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* @Author: huangxing
* @Date: 2024/09/04 14:16
*/
@RestController
@AllArgsConstructor
@RequestMapping("/weaviate")
public class WeaviateController {
private final WeaviateService weaviateService;
@PostMapping("/saveBatch")
public R<Boolean> saveBatch(@RequestBody WeaviateSaveDTO req) {
weaviateService.saveBatch(req.getEntities(), req.getClassName(), req.getAttrsMap());
return R.success("操作成功!");
}
@GetMapping("/list")
public R<List<WeaviateObject>> list(@RequestParam(value = "id",required = false) String id, @RequestParam("className") String className) {
return R.data(weaviateService.list(id,className));
}
@DeleteMapping("/removeById")
public R<Boolean> removeById(@RequestParam(value = "id",required = false) String id, @RequestParam("className") String className) {
return R.status(weaviateService.delete(id,className));
}
}

33
hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/dto/WeaviateSaveDTO.java

@ -0,0 +1,33 @@
package com.hnac.hzims.bigmodel.database.dto;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
* @Author: huangxing
* @Date: 2024/09/04 14:36
*/
@Data
@EqualsAndHashCode
public class WeaviateSaveDTO implements Serializable {
/**
* 向量数据库表名
*/
private String className;
/**
* 向量数据库属性名
*/
private Map<String,String> attrsMap;
/**
* 向量数据库存入对象列表
*/
private List entities;
}

48
hzims-service/hzims-big-model/src/main/java/com/hnac/hzims/bigmodel/database/service/WeaviateService.java

@ -8,6 +8,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.hnac.hzims.bigmodel.configuration.BigModelInvokeApi; import com.hnac.hzims.bigmodel.configuration.BigModelInvokeApi;
import com.hnac.hzinfo.exception.HzServiceException; import com.hnac.hzinfo.exception.HzServiceException;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result; import io.weaviate.client.base.Result;
import io.weaviate.client.v1.data.api.ObjectCreator; import io.weaviate.client.v1.data.api.ObjectCreator;
import io.weaviate.client.v1.data.api.ObjectDeleter; import io.weaviate.client.v1.data.api.ObjectDeleter;
@ -38,10 +39,7 @@ import java.util.stream.IntStream;
@Slf4j @Slf4j
public class WeaviateService { public class WeaviateService {
private final ObjectCreator objectCreator; private final WeaviateClient weaviateClient;
private final ObjectUpdater objectUpdater;
private final ObjectDeleter objectDeleter;
private final ObjectsGetter objectsGetter;
private final BigModelInvokeApi invokeApi; private final BigModelInvokeApi invokeApi;
@Value("${gglm.vectorUrl}") @Value("${gglm.vectorUrl}")
@ -55,7 +53,7 @@ public class WeaviateService {
* @return 保存操作结果 * @return 保存操作结果
*/ */
public Boolean save(Object entity, String className, List<String> attrs) { public Boolean save(Object entity, String className, List<String> attrs) {
ObjectCreator creator = objectCreator.withClassName(className); ObjectCreator creator = weaviateClient.data().creator().withClassName(className);
if(Func.isNotEmpty(attrs)) { if(Func.isNotEmpty(attrs)) {
JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(entity)); JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(entity));
List<String> vectors = attrs.stream().map(attr -> jsonObject.getString(attr)).collect(Collectors.toList()); List<String> vectors = attrs.stream().map(attr -> jsonObject.getString(attr)).collect(Collectors.toList());
@ -74,7 +72,7 @@ public class WeaviateService {
* @return 保存操作结果 * @return 保存操作结果
*/ */
public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) { public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) {
ObjectCreator creator = objectCreator.withClassName(className); ObjectCreator creator = weaviateClient.data().creator().withClassName(className);
List<String> vectorStrs = Lists.newArrayList(); List<String> vectorStrs = Lists.newArrayList();
List<String> attrs = Lists.newArrayList(); List<String> attrs = Lists.newArrayList();
if(Func.isNotEmpty(attrsMap)) { if(Func.isNotEmpty(attrsMap)) {
@ -91,7 +89,12 @@ public class WeaviateService {
Float[] vectors = this.compute(vectorStrs); Float[] vectors = this.compute(vectorStrs);
List<Map<String, Float[]>> vector = this.splitVector(entities.size(), attrsMap, vectors); List<Map<String, Float[]>> vector = this.splitVector(entities.size(), attrsMap, vectors);
for(int i = 0; i < entities.size(); i++) { for(int i = 0; i < entities.size(); i++) {
creator.withProperties(BeanUtil.toMap(entities.get(i))).withVectors(vector.get(i)).run(); // log.info("vector:{}",JSON.toJSONString(vector.get(i)));
JSONObject object = JSONObject.parseObject(JSON.toJSONString(entities.get(i)));
Map<String,Object> properties = new HashMap<>();
object.forEach((k,v) -> properties.put(k,v));
log.info("properties:{}",JSON.toJSONString(properties));
creator.withProperties(properties).withVectors(vector.get(i)).run();
} }
} else { } else {
entities.forEach(entity -> creator.withProperties(BeanUtil.toMap(entity)).run()); entities.forEach(entity -> creator.withProperties(BeanUtil.toMap(entity)).run());
@ -105,18 +108,15 @@ public class WeaviateService {
* @param className 表名 * @param className 表名
* @return 删除结果 * @return 删除结果
*/ */
public Boolean deleteByClassName(String className) { public Boolean delete(String id,String className) {
Result<Boolean> result = objectDeleter.withClassName(className).run(); ObjectDeleter deleter = weaviateClient.data().deleter();
return !result.hasErrors(); if(Func.isNotEmpty(id)) {
} deleter.withID(id);
}
/** if(Func.isNotEmpty(className)) {
* 删除向量数据库ID deleter.withClassName(className);
* @param id 向量数据库ID }
* @return 删除结果 Result<Boolean> result = deleter.run();
*/
public Boolean deleteById(String id) {
Result<Boolean> result = objectDeleter.withID(id).run();
return !result.hasErrors(); return !result.hasErrors();
} }
@ -126,7 +126,7 @@ public class WeaviateService {
* @return 更新结果 * @return 更新结果
*/ */
public Boolean updateById(String id, Object entity, String className, Map<String,String> attrMap) { public Boolean updateById(String id, Object entity, String className, Map<String,String> attrMap) {
ObjectUpdater updater = objectUpdater.withClassName(className).withID(id).withProperties(BeanUtil.toMap(entity)); ObjectUpdater updater = weaviateClient.data().updater().withClassName(className).withID(id).withProperties(BeanUtil.toMap(entity));
// 计算向量 // 计算向量
Map<String, Float[]> vector = new HashMap<>(); Map<String, Float[]> vector = new HashMap<>();
if(Func.isNotEmpty(attrMap)) { if(Func.isNotEmpty(attrMap)) {
@ -143,7 +143,8 @@ public class WeaviateService {
return !result.hasErrors(); return !result.hasErrors();
} }
public List<Map> list(String id,String className) { public List<WeaviateObject> list(String id,String className) {
ObjectsGetter objectsGetter = weaviateClient.data().objectsGetter();
if(Func.isNotEmpty(id)) { if(Func.isNotEmpty(id)) {
objectsGetter.withID(id); objectsGetter.withID(id);
} }
@ -154,7 +155,7 @@ public class WeaviateService {
if(result.hasErrors()) { if(result.hasErrors()) {
throw new HzServiceException("查询失败!"); throw new HzServiceException("查询失败!");
} }
return result.getResult().stream().map(WeaviateObject::getProperties).collect(Collectors.toList()); return result.getResult();
} }
/** /**
@ -176,10 +177,11 @@ public class WeaviateService {
List<Integer> splitIndex = this.getSplitIndex(vectors.size(), attrsMap.size()); List<Integer> splitIndex = this.getSplitIndex(vectors.size(), attrsMap.size());
AtomicInteger i = new AtomicInteger(); AtomicInteger i = new AtomicInteger();
attrsMap.forEach((k,v) -> { attrsMap.forEach((k,v) -> {
List<Float> vector = vectors.subList(splitIndex.get(i.get()), splitIndex.get(i.get() + (vectors.size() / attrsMap.size()))); List<Float> vector = vectors.subList(splitIndex.get(i.get()), splitIndex.get(i.get()) + (vectors.size() / attrsMap.size()));
vectorMap.put(k, vector.toArray(new Float[vector.size()])); vectorMap.put(k, vector.toArray(new Float[vector.size()]));
i.getAndIncrement(); i.getAndIncrement();
}); });
result.add(vectorMap);
}); });
return result; return result;
} }

9
hzims-service/hzims-big-model/src/main/resources/template/template.yml

@ -63,7 +63,7 @@ gglm:
smartReportGeneratePower: "/custom/smart_report_generate_power" smartReportGeneratePower: "/custom/smart_report_generate_power"
assistantAnalyseAsk: "/qa/assistant_analyse_ask" assistantAnalyseAsk: "/qa/assistant_analyse_ask"
updateKnowledge: "/kn/update_knowledge" updateKnowledge: "/kn/update_knowledge"
compute: "compute" compute: "/compute"
swagger: swagger:
base-packages: com.hnac.hzims.bigmodel base-packages: com.hnac.hzims.bigmodel
@ -85,3 +85,10 @@ bigmodel:
zhipuai: zhipuai:
url: https://open.bigmodel.cn/api/paas/v4/chat/completions url: https://open.bigmodel.cn/api/paas/v4/chat/completions
apiSecret: dfd23052747674818c7ac6f9922beff1.n2o5JEdfnrLbFU53 apiSecret: dfd23052747674818c7ac6f9922beff1.n2o5JEdfnrLbFU53
weaviate:
datasource:
schema: http
host: 192.168.60.16
port: 9992
apiKey: 123

Loading…
Cancel
Save