import os
from typing import List

import requests
from dotenv import load_dotenv
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import JSONLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain_deepseek import ChatDeepSeek
load_dotenv()
class SiliconFlowEmbeddings(Embeddings):
    """LangChain Embeddings wrapper around SiliconFlow's embedding API."""

    def __init__(self, api_url: str, api_token: str):
        self.api_url = api_url
        self.api_token = api_token
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
        })

    def _get_embedding(self, text: str) -> List[float]:
        payload = {
            "model": "BAAI/bge-large-zh-v1.5",
            "input": text,
            "encoding_format": "float",
            "dimensions": 1024,
        }
        try:
            response = self.session.post(self.api_url, json=payload, timeout=10)
            response.raise_for_status()
            return response.json()["data"][0]["embedding"]
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"API request failed: {e}") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One HTTP request per chunk; see the batching note below.
        return [self._get_embedding(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._get_embedding(text)
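
# Note: embed_documents above makes one request per chunk. SiliconFlow's
# endpoint is OpenAI-compatible, and such endpoints typically accept a list
# of strings under "input" — if yours does, a batched variant is a small swap
# (a sketch, not verified against the API docs):
#
#     def embed_documents(self, texts: List[str]) -> List[List[float]]:
#         payload = {"model": "BAAI/bge-large-zh-v1.5", "input": texts,
#                    "encoding_format": "float", "dimensions": 1024}
#         response = self.session.post(self.api_url, json=payload, timeout=30)
#         response.raise_for_status()
#         return [item["embedding"] for item in response.json()["data"]]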
embeddings = SiliconFlowEmbeddings(
    api_url="https://api.siliconflow.cn/v1/embeddings",
    api_token=os.environ["SILICONFLOW_API_KEY"],
)
def metadata_fun(record: dict, metadata: dict) -> dict:
    # Copy per-film fields into each document's metadata.
    metadata["title"] = record.get("title")
    metadata["year"] = record.get("year")
    return metadata
loader = JSONLoader(
    file_path="data_example.json",
    jq_schema=".notable_films[]",
    content_key="title",
    text_content=False,
    metadata_func=metadata_fun,
)
data = loader.load()
print(data)
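
# For reference, the jq_schema and metadata_func above imply data_example.json
# contains a "notable_films" array whose entries carry "title" and "year"
# keys, roughly like this (hypothetical example data):
#
# {
#   "notable_films": [
#     {"title": "低俗小说", "year": 1994},
#     {"title": "杀死比尔", "year": 2003}
#   ]
# }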
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
split_documents = text_splitter.split_documents(data)
db = Chroma.from_documents(
    documents=split_documents,
    embedding=embeddings,
    persist_directory="./chroma_db",
)
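
# On later runs the persisted index can be reloaded instead of rebuilt
# (assumes the same embedding model was used to build it):
#
#     db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)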
query = "低俗小说"  # "Pulp Fiction"
docs = db.similarity_search(query)
print(docs)
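
# The RetrievalQAWithSourcesChain and ChatDeepSeek imports suggest wiring the
# store into a question-answering chain. A minimal sketch, assuming
# DEEPSEEK_API_KEY is set in the environment and "deepseek-chat" as the model;
# the question below is hypothetical example input:

llm = ChatDeepSeek(model="deepseek-chat")
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(),
)
result = qa_chain.invoke({"question": "《低俗小说》是哪一年上映的？"})  # "What year was Pulp Fiction released?"
print(result["answer"], result["sources"])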