|
|
|
@ -6,28 +6,26 @@ import cn.hutool.json.JSONUtil;
|
|
|
|
|
import com.alibaba.fastjson.JSON; |
|
|
|
|
import com.alibaba.fastjson.JSONObject; |
|
|
|
|
import com.google.common.collect.Lists; |
|
|
|
|
import com.google.protobuf.ServiceException; |
|
|
|
|
import com.hnac.hzims.bigmodel.configuration.BigModelInvokeApi; |
|
|
|
|
import com.hnac.hzinfo.exception.HzServiceException; |
|
|
|
|
import io.weaviate.client.base.Result; |
|
|
|
|
import io.weaviate.client.v1.data.api.ObjectCreator; |
|
|
|
|
import io.weaviate.client.v1.data.api.ObjectDeleter; |
|
|
|
|
import io.weaviate.client.v1.data.api.ObjectUpdater; |
|
|
|
|
import io.weaviate.client.v1.data.api.ObjectsGetter; |
|
|
|
|
import io.weaviate.client.v1.data.model.WeaviateObject; |
|
|
|
|
import lombok.RequiredArgsConstructor; |
|
|
|
|
import lombok.extern.slf4j.Slf4j; |
|
|
|
|
import org.springblade.core.tool.utils.BeanUtil; |
|
|
|
|
import org.springblade.core.tool.utils.Func; |
|
|
|
|
import org.springblade.core.tool.utils.StringUtil; |
|
|
|
|
import org.springframework.beans.factory.annotation.Value; |
|
|
|
|
import org.springframework.stereotype.Service; |
|
|
|
|
import org.springframework.util.Assert; |
|
|
|
|
|
|
|
|
|
import java.lang.reflect.Field; |
|
|
|
|
import java.nio.ByteBuffer; |
|
|
|
|
import java.nio.ByteOrder; |
|
|
|
|
import java.util.*; |
|
|
|
|
import java.util.stream.Collector; |
|
|
|
|
import java.util.concurrent.atomic.AtomicInteger; |
|
|
|
|
import java.util.stream.Collectors; |
|
|
|
|
import java.util.stream.IntStream; |
|
|
|
|
|
|
|
|
@ -43,9 +41,11 @@ public class WeaviateService {
|
|
|
|
|
private final ObjectCreator objectCreator; |
|
|
|
|
private final ObjectUpdater objectUpdater; |
|
|
|
|
private final ObjectDeleter objectDeleter; |
|
|
|
|
private final ObjectsGetter objectsGetter; |
|
|
|
|
private final BigModelInvokeApi invokeApi; |
|
|
|
|
|
|
|
|
|
@Value("${gglm.vectorUrl}") |
|
|
|
|
private final String vectorUrl; |
|
|
|
|
private String vectorUrl; |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 对象保存向量数据库 |
|
|
|
@ -89,7 +89,10 @@ public class WeaviateService {
|
|
|
|
|
if(Func.isNotEmpty(vectorStrs)) { |
|
|
|
|
// 若解析出来的向量存在值
|
|
|
|
|
Float[] vectors = this.compute(vectorStrs); |
|
|
|
|
|
|
|
|
|
List<Map<String, Float[]>> vector = this.splitVector(entities.size(), attrsMap, vectors); |
|
|
|
|
for(int i = 0; i < entities.size(); i++) { |
|
|
|
|
creator.withProperties(BeanUtil.toMap(entities.get(i))).withVectors(vector.get(i)).run(); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
entities.forEach(entity -> creator.withProperties(BeanUtil.toMap(entity)).run()); |
|
|
|
|
return true; |
|
|
|
@ -98,6 +101,63 @@ public class WeaviateService {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 删除向量数据库(表名) |
|
|
|
|
* @param className 表名 |
|
|
|
|
* @return 删除结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean deleteByClassName(String className) { |
|
|
|
|
Result<Boolean> result = objectDeleter.withClassName(className).run(); |
|
|
|
|
return !result.hasErrors(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 删除向量数据库(ID) |
|
|
|
|
* @param id 向量数据库ID |
|
|
|
|
* @return 删除结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean deleteById(String id) { |
|
|
|
|
Result<Boolean> result = objectDeleter.withID(id).run(); |
|
|
|
|
return !result.hasErrors(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 更新数据库(通过ID) |
|
|
|
|
* @param id 向量数据库ID |
|
|
|
|
* @return 更新结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean updateById(String id, Object entity, String className, Map<String,String> attrMap) { |
|
|
|
|
ObjectUpdater updater = objectUpdater.withClassName(className).withID(id).withProperties(BeanUtil.toMap(entity)); |
|
|
|
|
// 计算向量
|
|
|
|
|
Map<String, Float[]> vector = new HashMap<>(); |
|
|
|
|
if(Func.isNotEmpty(attrMap)) { |
|
|
|
|
attrMap.forEach((k,v) -> { |
|
|
|
|
String fieldValue = this.getFieldValue(v, entity); |
|
|
|
|
Float[] compute = this.compute(Lists.newArrayList(fieldValue)); |
|
|
|
|
vector.put(k,compute); |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|
if(Func.isNotEmpty(vector)) { |
|
|
|
|
updater.withVectors(vector); |
|
|
|
|
} |
|
|
|
|
Result<Boolean> result = updater.run(); |
|
|
|
|
return !result.hasErrors(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public List<Map> list(String id,String className) { |
|
|
|
|
if(Func.isNotEmpty(id)) { |
|
|
|
|
objectsGetter.withID(id); |
|
|
|
|
} |
|
|
|
|
if(Func.isNotEmpty(className)) { |
|
|
|
|
objectsGetter.withClassName(className); |
|
|
|
|
} |
|
|
|
|
Result<List<WeaviateObject>> result = objectsGetter.run(); |
|
|
|
|
if(result.hasErrors()) { |
|
|
|
|
throw new HzServiceException("查询失败!"); |
|
|
|
|
} |
|
|
|
|
return result.getResult().stream().map(WeaviateObject::getProperties).collect(Collectors.toList()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 拆解计算出来的向量Float[] |
|
|
|
|
* @param entitySize 对象列表size |
|
|
|
|
* @param attrsMap 待计算的列信息 key-向量名 value-实体类对象属性,多个按逗号分隔 |
|
|
|
@ -106,13 +166,22 @@ public class WeaviateService {
|
|
|
|
|
*/ |
|
|
|
|
private List<Map<String,Float[]>> splitVector(Integer entitySize,Map<String,String> attrsMap,Float[] vectorTotal) { |
|
|
|
|
List<Map<String,Float[]>> result = Lists.newArrayList(); |
|
|
|
|
|
|
|
|
|
List<Float> vectorTotalList = Lists.newArrayList(vectorTotal); |
|
|
|
|
// 获取待切割的下标
|
|
|
|
|
List<Integer> indexes = this.getSplitIndex(vectorTotal.length, entitySize); |
|
|
|
|
int step = vectorTotal.length / entitySize; |
|
|
|
|
indexes.forEach(index -> { |
|
|
|
|
|
|
|
|
|
List<Float> vectors = vectorTotalList.subList(index, index + step); |
|
|
|
|
Map<String,Float[]> vectorMap = new HashMap<>(); |
|
|
|
|
List<Integer> splitIndex = this.getSplitIndex(vectors.size(), attrsMap.size()); |
|
|
|
|
AtomicInteger i = new AtomicInteger(); |
|
|
|
|
attrsMap.forEach((k,v) -> { |
|
|
|
|
List<Float> vector = vectors.subList(splitIndex.get(i.get()), splitIndex.get(i.get() + (vectors.size() / attrsMap.size()))); |
|
|
|
|
vectorMap.put(k, vector.toArray(new Float[vector.size()])); |
|
|
|
|
i.getAndIncrement(); |
|
|
|
|
}); |
|
|
|
|
return null; |
|
|
|
|
}); |
|
|
|
|
return result; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|