|
|
|
@ -8,9 +8,12 @@ import com.alibaba.fastjson.JSONArray;
|
|
|
|
|
import com.alibaba.fastjson.JSONObject; |
|
|
|
|
import com.alibaba.fastjson.TypeReference; |
|
|
|
|
import com.alibaba.fastjson.serializer.SerializerFeature; |
|
|
|
|
import com.baomidou.mybatisplus.core.metadata.IPage; |
|
|
|
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; |
|
|
|
|
import com.google.common.collect.Lists; |
|
|
|
|
import com.hnac.gglm.bigmodel.BigModelConstants; |
|
|
|
|
import com.hnac.gglm.bigmodel.configuration.BigModelInvokeApi; |
|
|
|
|
import com.hnac.gglm.bigmodel.configuration.WeaviateProperties; |
|
|
|
|
import com.hnac.gglm.bigmodel.utils.RequestClientUtil; |
|
|
|
|
import com.hnac.hzims.fdp.constants.ScheduledConstant; |
|
|
|
|
import com.hnac.hzinfo.exception.HzServiceException; |
|
|
|
@ -32,6 +35,7 @@ import io.weaviate.client.v1.graphql.query.Get;
|
|
|
|
|
import io.weaviate.client.v1.graphql.query.fields.Field; |
|
|
|
|
import lombok.RequiredArgsConstructor; |
|
|
|
|
import lombok.extern.slf4j.Slf4j; |
|
|
|
|
import org.springblade.core.mp.support.Query; |
|
|
|
|
import org.springblade.core.tool.api.ResultCode; |
|
|
|
|
import org.springblade.core.tool.utils.BeanUtil; |
|
|
|
|
import org.springblade.core.tool.utils.Func; |
|
|
|
@ -57,39 +61,36 @@ public class WeaviateService {
|
|
|
|
|
|
|
|
|
|
private final WeaviateClient weaviateClient; |
|
|
|
|
private final BigModelInvokeApi invokeApi; |
|
|
|
|
private final WeaviateProperties weaviateProperties; |
|
|
|
|
|
|
|
|
|
@Value("${gglm.vectorUrl}") |
|
|
|
|
private String vectorUrl; |
|
|
|
|
@Value("${gglm.url}") |
|
|
|
|
private String gglmUrl; |
|
|
|
|
|
|
|
|
|
private static final String QUERY_TEMPLATE = "{Aggregate {%s {meta {count}}}}"; |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 对象保存向量数据库 |
|
|
|
|
* @param entity 保存对象 |
|
|
|
|
* @param className 保存表名 |
|
|
|
|
* @param attrs 待计算的列信息 |
|
|
|
|
* @param attrsMap |
|
|
|
|
* @return 保存操作结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean save(Object entity, String className, List<String> attrs) { |
|
|
|
|
ObjectCreator creator = weaviateClient.data().creator().withClassName(className); |
|
|
|
|
if(Func.isNotEmpty(attrs)) { |
|
|
|
|
JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(entity)); |
|
|
|
|
List<String> vectors = attrs.stream().map(attr -> jsonObject.getString(attr)).collect(Collectors.toList()); |
|
|
|
|
Float[] compute = this.compute(vectors); |
|
|
|
|
creator.withVector(compute); |
|
|
|
|
} |
|
|
|
|
Result<WeaviateObject> result = creator.withProperties(BeanUtil.toMap(entity)).run(); |
|
|
|
|
return !result.hasErrors(); |
|
|
|
|
public String save(Object entity, String className, Map<String,String> attrsMap) { |
|
|
|
|
return this.saveBatch(Lists.newArrayList(entity),className, attrsMap); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public String saveBatch(List entities,String className, Map<String,String> attrsMap) { |
|
|
|
|
String tableName = getRealTableName(className); |
|
|
|
|
String realClassName = this.getRealClassName(className); |
|
|
|
|
// 查询表是否存在 若不存则新建表
|
|
|
|
|
Result<Boolean> existResult = weaviateClient.schema().exists().withClassName(BigModelConstants.PREFIX + "_" + className).run(); |
|
|
|
|
Result<Boolean> existResult = weaviateClient.schema().exists().withClassName(tableName).run(); |
|
|
|
|
if(existResult.hasErrors() || !existResult.getResult()) { |
|
|
|
|
Map<java.lang.String,Object> createTableParams = new HashMap<>(2); |
|
|
|
|
Map<String,String> deleteTableParams = new HashMap<>(1); |
|
|
|
|
deleteTableParams.put("table_name",className); |
|
|
|
|
createTableParams.put("table_name",className); |
|
|
|
|
deleteTableParams.put("table_name", realClassName); |
|
|
|
|
createTableParams.put("table_name", realClassName); |
|
|
|
|
List<String> vectorStr = Lists.newArrayList(); |
|
|
|
|
attrsMap.keySet().forEach(key -> vectorStr.add(key)); |
|
|
|
|
createTableParams.put("vector_names",vectorStr.toArray(new String[vectorStr.size()])); |
|
|
|
@ -97,18 +98,26 @@ public class WeaviateService {
|
|
|
|
|
RequestClientUtil.postCall(gglmUrl + invokeApi.getCreateTable(),createTableParams); |
|
|
|
|
} |
|
|
|
|
Map<String,Object> params = new HashMap<>(2); |
|
|
|
|
params.put("table_name", className); |
|
|
|
|
params.put("table_name", realClassName); |
|
|
|
|
// 将entities按size截断为1000个一组
|
|
|
|
|
List<List> entitiesList = splitList(entities, 1000); |
|
|
|
|
int total = 0; |
|
|
|
|
for (List entityList : entitiesList) { |
|
|
|
|
Integer insert = this.insert(entityList, attrsMap, params); |
|
|
|
|
total += insert; |
|
|
|
|
total = insert; |
|
|
|
|
} |
|
|
|
|
// 查询weaviate 中该表的数据量
|
|
|
|
|
return String.format("传入数据总量为:%s 保存成功数量为:%s", entities.size(), total); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private String getRealTableName(String className) { |
|
|
|
|
return className.contains(weaviateProperties.getDatabase()) ? className : weaviateProperties.getDatabase() + "_" + className; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private String getRealClassName(String className) { |
|
|
|
|
return className.contains(weaviateProperties.getDatabase()) ? className.replace(weaviateProperties.getDatabase() + "_","") : className; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* 根据条件删除数据 |
|
|
|
|
* @param className 表名 |
|
|
|
@ -116,15 +125,18 @@ public class WeaviateService {
|
|
|
|
|
* @return 删除结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean deleteCondition(String className,Map<String,String> condition) { |
|
|
|
|
// 获取到实际表名
|
|
|
|
|
String tableName = this.getRealTableName(className); |
|
|
|
|
// 查询到相关数据
|
|
|
|
|
Object query = this.query(null, className, condition); |
|
|
|
|
Object query = this.query(null, tableName, condition); |
|
|
|
|
if(Func.isEmpty(query)) { |
|
|
|
|
throw new HzServiceException("暂无数据,删除失败!"); |
|
|
|
|
} |
|
|
|
|
JSONObject queryJson = JSONObject.parseObject(JSON.toJSONString(query)); |
|
|
|
|
JSONArray data = Optional.ofNullable(queryJson).map(json -> json.getJSONObject("Get")) |
|
|
|
|
.map(json -> json.getJSONArray(className)).orElse(null); |
|
|
|
|
.map(json -> json.getJSONArray(tableName)).orElse(null); |
|
|
|
|
if(Func.isNotEmpty(data)) { |
|
|
|
|
// 获取到主键id,根据id删除数据
|
|
|
|
|
List<String> ids = data.stream().map(item -> { |
|
|
|
|
JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(item)); |
|
|
|
|
return Optional.ofNullable(jsonObject) |
|
|
|
@ -133,7 +145,7 @@ public class WeaviateService {
|
|
|
|
|
.orElse(""); |
|
|
|
|
}).filter(Func::isNotEmpty).collect(Collectors.toList()); |
|
|
|
|
if(Func.isNotEmpty(ids)) { |
|
|
|
|
this.delete(ids.stream().collect(Collectors.joining(",")), className); |
|
|
|
|
this.delete(ids.stream().collect(Collectors.joining(",")), tableName); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return true; |
|
|
|
@ -164,7 +176,7 @@ public class WeaviateService {
|
|
|
|
|
private Integer insert(List entities, Map<String,String> attrsMap, Map<String,Object> params) { |
|
|
|
|
List<Map<String, Object>> data = new ArrayList<>(); |
|
|
|
|
entities.forEach(entity -> { |
|
|
|
|
// 将entity转换为Map<String,String>
|
|
|
|
|
// 将entity转换为 Map<String,String> 根据数据以及向量字段对向量进行拼接后调用大模型接口进行保存操作
|
|
|
|
|
JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(entity, SerializerFeature.WriteMapNullValue)); |
|
|
|
|
Map<String,Object> map = new HashMap<>(); |
|
|
|
|
jsonObject.forEach((k,v) -> map.put(k,jsonObject.get(k))); |
|
|
|
@ -211,15 +223,17 @@ public class WeaviateService {
|
|
|
|
|
* @return 删除结果 |
|
|
|
|
*/ |
|
|
|
|
public Boolean delete(String ids,String className) { |
|
|
|
|
if(Func.isEmpty(ids) && Func.isNotEmpty(className)) { |
|
|
|
|
String tableName = this.getRealTableName(className); |
|
|
|
|
String realClassName = this.getRealClassName(className); |
|
|
|
|
if(Func.isEmpty(ids)) { |
|
|
|
|
// 删除className
|
|
|
|
|
Map<String,String> deleteTableParams = new HashMap<>(1); |
|
|
|
|
deleteTableParams.put("table_name",className.replace(BigModelConstants.PREFIX + "_","")); |
|
|
|
|
deleteTableParams.put("table_name", realClassName); |
|
|
|
|
RequestClientUtil.postCall(gglmUrl + invokeApi.getDeleteTable(),deleteTableParams); |
|
|
|
|
} else { |
|
|
|
|
// 删除记录
|
|
|
|
|
ObjectDeleter deleter = weaviateClient.data().deleter(); |
|
|
|
|
deleter.withClassName(className); |
|
|
|
|
deleter.withClassName(tableName); |
|
|
|
|
Func.toStrList(",",ids).forEach(id -> { |
|
|
|
|
Result<Boolean> result = deleter.withID(id).run(); |
|
|
|
|
if(result.hasErrors()) { |
|
|
|
@ -235,6 +249,7 @@ public class WeaviateService {
|
|
|
|
|
* @param id 向量数据库ID |
|
|
|
|
* @return 更新结果 |
|
|
|
|
*/ |
|
|
|
|
@Deprecated |
|
|
|
|
public Boolean updateById(String id, Object entity, String className, Map<String,String> attrMap) { |
|
|
|
|
ObjectUpdater updater = weaviateClient.data().updater().withClassName(className).withID(id).withProperties(BeanUtil.toMap(entity)); |
|
|
|
|
// 计算向量
|
|
|
|
@ -253,13 +268,49 @@ public class WeaviateService {
|
|
|
|
|
return !result.hasErrors(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public IPage<WeaviateObject> page(String className, Query query) { |
|
|
|
|
IPage<WeaviateObject> result = new Page<>(); |
|
|
|
|
Integer offset = (query.getCurrent() - 1) * query.getSize(); |
|
|
|
|
ObjectsGetter objectsGetter = weaviateClient.data().objectsGetter(); |
|
|
|
|
String realTableName = this.getRealTableName(className); |
|
|
|
|
// 查询数据库表数据总量
|
|
|
|
|
Integer classTotal = this.getClassTotal(className); |
|
|
|
|
Result<List<WeaviateObject>> run = objectsGetter.withClassName(realTableName).withLimit(query.getSize()).withOffset(offset).run(); |
|
|
|
|
if(run.getResult() == null && run.hasErrors()) { |
|
|
|
|
throw new HzServiceException("数据库暂无" + realTableName + "数据表信息,同步数据后查询"); |
|
|
|
|
} |
|
|
|
|
result.setRecords(run.getResult()); |
|
|
|
|
result.setTotal(classTotal); |
|
|
|
|
result.setSize(query.getSize()); |
|
|
|
|
result.setCurrent(query.getCurrent()); |
|
|
|
|
return result; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public Integer getClassTotal(String className) { |
|
|
|
|
String realTableName = this.getRealTableName(className); |
|
|
|
|
String url = weaviateProperties.getSchema() + "://" + weaviateProperties.getHost() + ":" + weaviateProperties.getPort() + "/v1/graphql"; |
|
|
|
|
String query = "{Aggregate {" + realTableName + " {meta {count}}}}"; |
|
|
|
|
Map<String,Object> params = new HashMap<>(); |
|
|
|
|
params.put("query",query); |
|
|
|
|
String authorization = "Bearer " + weaviateProperties.getApiKey(); |
|
|
|
|
try { |
|
|
|
|
String body = HttpRequest.post(url).header("Authorization", authorization).body(JSON.toJSONString(params, SerializerFeature.WriteMapNullValue)).execute().body(); |
|
|
|
|
JSONArray dataArray = JSONObject.parseObject(body).getJSONObject("data").getJSONObject("Aggregate").getJSONArray(realTableName); |
|
|
|
|
return Optional.ofNullable(dataArray).map(array -> array.getJSONObject(0).getJSONObject("meta").getInteger("count")).orElse(0); |
|
|
|
|
} catch (Exception e) { |
|
|
|
|
e.printStackTrace(); |
|
|
|
|
throw new HzServiceException("获取" + className + "数据总量失败,查询失败!"); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public List<WeaviateObject> list(String id,String className,Integer current,Integer pageSize) { |
|
|
|
|
String tableName = this.getRealTableName(className); |
|
|
|
|
ObjectsGetter objectsGetter = weaviateClient.data().objectsGetter(); |
|
|
|
|
if(Func.isNotEmpty(id)) { |
|
|
|
|
objectsGetter.withID(id); |
|
|
|
|
} |
|
|
|
|
if(Func.isNotEmpty(className)) { |
|
|
|
|
objectsGetter.withClassName(className); |
|
|
|
|
objectsGetter.withClassName(tableName); |
|
|
|
|
} |
|
|
|
|
Result<List<WeaviateObject>> result = objectsGetter.withLimit(pageSize).withOffset((current-1) * pageSize).run(); |
|
|
|
|
if(result.hasErrors()) { |
|
|
|
@ -361,9 +412,10 @@ public class WeaviateService {
|
|
|
|
|
* @return 查询结果 |
|
|
|
|
*/ |
|
|
|
|
public Object query(String resultFields,String className,Map<String,String> query) { |
|
|
|
|
String realTableName = this.getRealTableName(className); |
|
|
|
|
List<String> fieldList = Func.toStrList(",", resultFields); |
|
|
|
|
Get get = weaviateClient.graphQL().get(); |
|
|
|
|
get.withClassName(className); |
|
|
|
|
get.withClassName(realTableName); |
|
|
|
|
List<Field> fields = fieldList.stream().map(fieldStr -> Field.builder().name(fieldStr).build()).collect(Collectors.toList()); |
|
|
|
|
Field additionalId = Field.builder().name("_additional { id }").build(); |
|
|
|
|
fields.add(additionalId); |
|
|
|
|