haungxing
5 months ago
10 changed files with 313 additions and 56 deletions
@ -0,0 +1,23 @@
|
||||
package com.hnac.hzims.bigmodel.database.entity; |
||||
|
||||
import com.alibaba.fastjson.annotation.JSONField; |
||||
import lombok.Data; |
||||
import lombok.EqualsAndHashCode; |
||||
|
||||
import java.io.Serializable; |
||||
|
||||
/** |
||||
* @Author: huangxing |
||||
* @Date: 2024/08/22 19:35 |
||||
*/ |
||||
@Data |
||||
@EqualsAndHashCode |
||||
public class WeaviateEntity implements Serializable { |
||||
|
||||
@JSONField(name = "item_id") |
||||
private String itemId; |
||||
|
||||
@JSONField(name = "item_name") |
||||
private String itemName; |
||||
|
||||
} |
@ -1,20 +0,0 @@
|
||||
package com.hnac.hzims.bigmodel.configuration; |
||||
|
||||
import org.springframework.context.annotation.Bean; |
||||
import org.springframework.context.annotation.Configuration; |
||||
import org.springframework.web.reactive.function.client.WebClient; |
||||
import reactor.core.publisher.Mono; |
||||
|
||||
/** |
||||
* @Author: huangxing |
||||
* @Date: 2024/08/21 15:22 |
||||
*/ |
||||
@Configuration |
||||
public class WeaviateConfig { |
||||
|
||||
@Bean |
||||
public WebClient weaviateClient() { |
||||
return WebClient.create("http://192.168.60.16:9992"); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,56 @@
|
||||
package com.hnac.hzims.bigmodel.configuration; |
||||
|
||||
import io.weaviate.client.Config; |
||||
import io.weaviate.client.WeaviateAuthClient; |
||||
import io.weaviate.client.WeaviateClient; |
||||
import io.weaviate.client.v1.auth.exception.AuthException; |
||||
import io.weaviate.client.v1.data.api.ObjectCreator; |
||||
import io.weaviate.client.v1.data.api.ObjectDeleter; |
||||
import io.weaviate.client.v1.data.api.ObjectUpdater; |
||||
import io.weaviate.client.v1.data.api.ObjectsGetter; |
||||
import org.springframework.context.annotation.Bean; |
||||
import org.springframework.context.annotation.Configuration; |
||||
|
||||
/** |
||||
* @Author: huangxing |
||||
* @Date: 2024/08/22 18:38 |
||||
*/ |
||||
@Configuration |
||||
public class WeaviateConfigure { |
||||
|
||||
private final WeaviateProperties weaviateProperties; |
||||
|
||||
public WeaviateConfigure(WeaviateProperties weaviateProperties) { |
||||
this.weaviateProperties = weaviateProperties; |
||||
} |
||||
|
||||
@Bean |
||||
public WeaviateClient weaviateClient() throws AuthException { |
||||
Config config = new Config(this.weaviateProperties.getSchema(), this.weaviateProperties.getHost() + ":" + this.weaviateProperties.getPort()); |
||||
return WeaviateAuthClient.apiKey(config,this.weaviateProperties.getApiKey()); |
||||
} |
||||
|
||||
@Bean |
||||
public ObjectsGetter objectsGetter() throws AuthException { |
||||
WeaviateClient weaviateClient = weaviateClient(); |
||||
return weaviateClient.data().objectsGetter(); |
||||
} |
||||
|
||||
@Bean |
||||
public ObjectCreator objectCreator() throws AuthException { |
||||
WeaviateClient weaviateClient = weaviateClient(); |
||||
return weaviateClient.data().creator(); |
||||
} |
||||
|
||||
@Bean |
||||
public ObjectDeleter deleter() throws AuthException { |
||||
WeaviateClient weaviateClient = weaviateClient(); |
||||
return weaviateClient.data().deleter(); |
||||
} |
||||
|
||||
@Bean |
||||
public ObjectUpdater updater() throws AuthException { |
||||
WeaviateClient weaviateClient = weaviateClient(); |
||||
return weaviateClient.data().updater(); |
||||
} |
||||
} |
@ -0,0 +1,32 @@
|
||||
package com.hnac.hzims.bigmodel.configuration; |
||||
|
||||
import lombok.Data; |
||||
import org.springframework.boot.context.properties.ConfigurationProperties; |
||||
import org.springframework.stereotype.Component; |
||||
|
||||
/** |
||||
* @Author: huangxing |
||||
* @Date: 2024/08/21 15:22 |
||||
*/ |
||||
@Data |
||||
@Component |
||||
@ConfigurationProperties(prefix = "weaviate.datasource") |
||||
public class WeaviateProperties { |
||||
|
||||
private String schema; |
||||
|
||||
private String host; |
||||
|
||||
private String port; |
||||
|
||||
/** |
||||
* 登录认证KEY |
||||
*/ |
||||
private String apiKey; |
||||
|
||||
/** |
||||
* 数据库表名前缀 |
||||
*/ |
||||
private String classNamePrefix; |
||||
|
||||
} |
@ -0,0 +1,183 @@
|
||||
package com.hnac.hzims.bigmodel.database.service; |
||||
|
||||
import cn.hutool.http.HttpRequest; |
||||
import cn.hutool.http.HttpResponse; |
||||
import cn.hutool.json.JSONUtil; |
||||
import com.alibaba.fastjson.JSON; |
||||
import com.alibaba.fastjson.JSONObject; |
||||
import com.google.common.collect.Lists; |
||||
import com.google.protobuf.ServiceException; |
||||
import com.hnac.hzims.bigmodel.configuration.BigModelInvokeApi; |
||||
import com.hnac.hzinfo.exception.HzServiceException; |
||||
import io.weaviate.client.base.Result; |
||||
import io.weaviate.client.v1.data.api.ObjectCreator; |
||||
import io.weaviate.client.v1.data.api.ObjectDeleter; |
||||
import io.weaviate.client.v1.data.api.ObjectUpdater; |
||||
import io.weaviate.client.v1.data.model.WeaviateObject; |
||||
import lombok.RequiredArgsConstructor; |
||||
import lombok.extern.slf4j.Slf4j; |
||||
import org.springblade.core.tool.utils.BeanUtil; |
||||
import org.springblade.core.tool.utils.Func; |
||||
import org.springblade.core.tool.utils.StringUtil; |
||||
import org.springframework.beans.factory.annotation.Value; |
||||
import org.springframework.stereotype.Service; |
||||
import org.springframework.util.Assert; |
||||
|
||||
import java.lang.reflect.Field; |
||||
import java.nio.ByteBuffer; |
||||
import java.nio.ByteOrder; |
||||
import java.util.*; |
||||
import java.util.stream.Collector; |
||||
import java.util.stream.Collectors; |
||||
import java.util.stream.IntStream; |
||||
|
||||
/** |
||||
* @Author: huangxing |
||||
* @Date: 2024/08/22 19:17 |
||||
*/ |
||||
@RequiredArgsConstructor |
||||
@Service |
||||
@Slf4j |
||||
public class WeaviateService { |
||||
|
||||
private final ObjectCreator objectCreator; |
||||
private final ObjectUpdater objectUpdater; |
||||
private final ObjectDeleter objectDeleter; |
||||
private final BigModelInvokeApi invokeApi; |
||||
@Value("${gglm.vectorUrl}") |
||||
private final String vectorUrl; |
||||
|
||||
/** |
||||
* 对象保存向量数据库 |
||||
* @param entity 保存对象 |
||||
* @param className 保存表名 |
||||
* @param attrs 待计算的列信息 |
||||
* @return 保存操作结果 |
||||
*/ |
||||
public Boolean save(Object entity, String className, List<String> attrs) { |
||||
ObjectCreator creator = objectCreator.withClassName(className); |
||||
if(Func.isNotEmpty(attrs)) { |
||||
JSONObject jsonObject = JSONObject.parseObject(JSON.toJSONString(entity)); |
||||
List<String> vectors = attrs.stream().map(attr -> jsonObject.getString(attr)).collect(Collectors.toList()); |
||||
Float[] compute = this.compute(vectors); |
||||
creator.withVector(compute); |
||||
} |
||||
Result<WeaviateObject> result = creator.withProperties(BeanUtil.toMap(entity)).run(); |
||||
return !result.hasErrors(); |
||||
} |
||||
|
||||
/** |
||||
* 对象批量保存向量数据库 |
||||
* @param entities 保存对象列表 |
||||
* @param className 保存表名 |
||||
* @param attrsMap 待计算的列信息 key-向量名 value-实体类对象属性,多个按逗号分隔 |
||||
* @return 保存操作结果 |
||||
*/ |
||||
public Boolean saveBatch(List entities,String className, Map<String,String> attrsMap) { |
||||
ObjectCreator creator = objectCreator.withClassName(className); |
||||
List<String> vectorStrs = Lists.newArrayList(); |
||||
List<String> attrs = Lists.newArrayList(); |
||||
if(Func.isNotEmpty(attrsMap)) { |
||||
// 格式化数据
|
||||
attrsMap.forEach((k,v) -> attrs.add(v)); |
||||
// 解析待计算的向量字段
|
||||
entities.forEach(entity -> { |
||||
List<String> vectorStr = attrs.stream().map(fields -> this.getFieldValue(fields, entity)).filter(Func::isNotEmpty).collect(Collectors.toList()); |
||||
vectorStrs.addAll(vectorStr); |
||||
}); |
||||
} |
||||
if(Func.isNotEmpty(vectorStrs)) { |
||||
// 若解析出来的向量存在值
|
||||
Float[] vectors = this.compute(vectorStrs); |
||||
|
||||
} else { |
||||
entities.forEach(entity -> creator.withProperties(BeanUtil.toMap(entity)).run()); |
||||
return true; |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
/** |
||||
* 拆解计算出来的向量Float[] |
||||
* @param entitySize 对象列表size |
||||
* @param attrsMap 待计算的列信息 key-向量名 value-实体类对象属性,多个按逗号分隔 |
||||
* @param vectorTotal 计算出的向量总量 |
||||
* @return 拆解结果 |
||||
*/ |
||||
private List<Map<String,Float[]>> splitVector(Integer entitySize,Map<String,String> attrsMap,Float[] vectorTotal) { |
||||
List<Map<String,Float[]>> result = Lists.newArrayList(); |
||||
|
||||
// 获取待切割的下标
|
||||
List<Integer> indexes = this.getSplitIndex(vectorTotal.length, entitySize); |
||||
indexes.forEach(index -> { |
||||
|
||||
}); |
||||
return null; |
||||
} |
||||
|
||||
/** |
||||
* 获取将list等量分隔成若干份的列表下标 |
||||
* @param size 总数 |
||||
* @param splitNum 分隔数量 |
||||
* @return 下标集合 |
||||
*/ |
||||
private List<Integer> getSplitIndex(int size,int splitNum) { |
||||
if(size % splitNum != 0) { |
||||
throw new HzServiceException("向量计算失败,无法根据同步对象进行等量分隔!"); |
||||
} |
||||
return IntStream.iterate(0, index -> index + 1) |
||||
.limit(splitNum) |
||||
.mapToObj(index -> index * (size / splitNum)) |
||||
.collect(Collectors.toList()); |
||||
} |
||||
|
||||
|
||||
private String getFieldValue(String fields,Object object) { |
||||
Class clazz = object.getClass(); |
||||
return Func.toStrList(",", fields).stream().map(field -> { |
||||
try { |
||||
Field declaredField = clazz.getDeclaredField(field); |
||||
declaredField.setAccessible(true); |
||||
return declaredField.get(object).toString(); |
||||
} catch (NoSuchFieldException | IllegalAccessException e) { |
||||
return null; |
||||
} |
||||
}).collect(Collectors.joining(" ")); |
||||
} |
||||
|
||||
/** |
||||
* 计算向量值 |
||||
* @param vectors 待计算的向量 |
||||
* @return 向量值Float[] |
||||
*/ |
||||
private Float[] compute(List<String> vectors) { |
||||
// 向量计算
|
||||
String url = vectorUrl + invokeApi.getCompute(); |
||||
String jsonData = JSONUtil.toJsonStr(vectors); |
||||
HttpResponse response = HttpRequest.post(url) |
||||
.header("Content-Type", "application/json; charset=utf-8") |
||||
.body(jsonData) |
||||
.execute(); |
||||
byte[] bytes = response.bodyBytes(); |
||||
if (bytes.length % 4 != 0) { |
||||
throw new HzServiceException("向量计算失败!响应数据长度不是4的倍数"); |
||||
} |
||||
List<byte[]> chunks = new ArrayList<>(); |
||||
int range = bytes.length / 4; |
||||
IntStream.range(0, range) |
||||
.forEach(index -> { |
||||
byte[] chunk = new byte[4]; |
||||
int page = index * 4; |
||||
chunk[0] = bytes[page]; |
||||
chunk[1] = bytes[page + 1]; |
||||
chunk[2] = bytes[page + 2]; |
||||
chunk[3] = bytes[page + 3]; |
||||
chunks.add(chunk); |
||||
}); |
||||
List<Float> floats = chunks.stream().map(b -> { |
||||
ByteBuffer buffer = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN); |
||||
return buffer.getFloat(); |
||||
}).collect(Collectors.toList()); |
||||
return floats.toArray(new Float[floats.size()]); |
||||
} |
||||
} |
@ -0,0 +1,11 @@
|
||||
package com.hnac.hzims.bigmodel.database.util; |
||||
|
||||
/**
 * Placeholder for Weaviate helper functions; it currently declares none.
 *
 * <p>The class is final with a private constructor so that it is used only as a
 * holder of static helpers and is never instantiated or subclassed.
 *
 * @Author: huangxing
 * @Date: 2024/08/23 10:09
 */
public final class WeaviateUtil {

    private WeaviateUtil() {
        // utility class: no instances
    }

}
Loading…
Reference in new issue