【LangChain】How to create a custom Memory class 如何自定义一个记忆类

发布时间 2023-07-18 11:56:54作者: ryukirin

How to create a custom Memory class 如何自定义一个记忆类

本文主要自定义了一个在LangChain中使用的Memory类

原文:How to create a custom Memory class

翻译

尽管在LangChain中有了一些预定义好的记忆类型,但是还是很有可能会有人想为自己的应用添加自己的记忆类型。这个笔记会介绍怎么添加。

在这个笔记中,我们会给ConversationChain添加一个自定义的记忆类型。为了添加这个自定义记忆类,我们需要import基础的记忆类然后创建它的子类。

from langchain import OpenAI, ConversationChain
from langchain.schema import BaseMemory
from pydantic import BaseModel
from typing import List, Dict, Any

在这个例子中,我们将写一个自定义记忆类,这个类使用spacy提取实体,并将实体信息保存在一个简单的哈希表中。接着,在进行会话的过程中,我们会关注input的文本,提取所有实体,并将有关他们的所有信息放入上下文中。

  • 请注意,该实现非常简单和脆弱,在生产环境中可能没有用处。其目的是展示您可以添加自定义记忆。

为此,我们需要spaCy

spaCy(简单介绍)

一个可以快速上手的nlp开发库,简单来讲,这里用到的spaCy就是先加载一个语言模型,之后把一个句子放进去跑一遍,同时完成了好几个nlp任务,包括分词、词性标注等,之后结果放在了一个类doc类中。

快速入门教程:使用spaCy做进阶自然语言处理

自定义记忆例子

# !pip install spacy
# !python -m spacy download en_core_web_lg
import spacy

nlp = spacy.load("en_core_web_lg")
class SpacyEntityMemory(BaseMemory, BaseModel):
    """为了保存实体信息的记忆类"""

    # 定义用来保存实体信息的字典
    entities: dict = {}
    # 定义用来筛选添加到prompt中的实体信息的key
    memory_key: str = "entities"

    def clear(self):
        self.entities = {}

    @property
    def memory_variables(self) -> List[str]:
        """定义加进prompt的变量"""
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """加载记忆变量,本例子中为entities"""
        # 得到input文本并且用spacy跑一遍
        doc = nlp(inputs[list(inputs.keys())[0]])
        # 提取实体的已知信息,如果存在的话
        entities = [
            self.entities[str(ent)] for ent in doc.ents if str(ent) in self.entities
        ]
        # 返回实体的综合信息,用来放进上下文中
        return {self.memory_key: "\n".join(entities)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """将本次对话的上下文保存到缓冲区"""
        # 得到input文本并且用spacy跑一遍
        text = inputs[list(inputs.keys())[0]]
        doc = nlp(text)
        # 对于提到的每个实体,将信息保存到字典中
        for ent in doc.ents:
            ent_str = str(ent)
            if ent_str in self.entities:
                self.entities[ent_str] += f"\n{text}"
            else:
                self.entities[ent_str] = text

现在我们定义一个prompt来接收实体信息和用户input

from langchain.prompts.prompt import PromptTemplate

template = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.

Relevant entity information:
{entities}

Conversation:
Human: {input}
AI:"""
prompt = PromptTemplate(input_variables=["entities", "input"], template=template)

然后现在我们把它们放在一起!

llm = OpenAI(temperature=0)
conversation = ConversationChain(
    llm=llm, prompt=prompt, verbose=True, memory=SpacyEntityMemory()
)

在第一个例子中,对于没有任何预先知识的Harrison,"Relevant entity information"字段是空的。

conversation.predict(input="Harrison likes machine learning")
    
    
    > Entering new ConversationChain chain...
    Prompt after formatting:
    The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.
    
    Relevant entity information:
    
    
    Conversation:
    Human: Harrison likes machine learning
    AI:
    
    > Finished ConversationChain chain.





    " That's great to hear! Machine learning is a fascinating field of study. It involves using algorithms to analyze data and make predictions. Have you ever studied machine learning, Harrison?"

现在是第二个例子,我们可以看到它存入了Harrison的信息。

conversation.predict(
    input="What do you think Harrison's favorite subject in college was?"
)
    
    
    > Entering new ConversationChain chain...
    Prompt after formatting:
    The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. You are provided with information about entities the Human mentions, if relevant.
    
    Relevant entity information:
    Harrison likes machine learning
    
    Conversation:
    Human: What do you think Harrison's favorite subject in college was?
    AI:
    
    > Finished ConversationChain chain.





    ' From what I know about Harrison, I believe his favorite subject in college was machine learning. He has expressed a strong interest in the subject and has mentioned it often.'

请再次注意,该实现非常简单和脆弱,在生产环境中可能并不实用。其目的是展示您可以添加自定义记忆。

个人测试例子

我希望得到一个自定义记忆能够将每次GPT3.5生成的JSON中的某一键值。

在本例子中的JSON格式为:

{
  "current_step": 当前步骤名;
  "result_info":
	{
    	"step": 默认步骤中有几个步骤,step列表中就生成几个json
    	[{
      		"name":步骤名,默认步骤中生成的步骤名;
      		"ok": 是否向用户确认完毕,是/否;
		}, ...];
    	"is_over": 是否都向用户确认完毕,是/否;
    	"next_step": 若"is_over"为"是",则写下一步骤名,若"is_over"为"否",则为"";
	};
  "reply": 会议促进者的发言,必须有内容;
}

我想得到的键值为 "reply" ,故创建新的 RecentKConversationMemory 类。

*注意此处的 RecentKConversationMemory 类泛用性十分差,仅限于在我这个任务中可以使用。

from langchain.schema import BaseMemory
from typing import List, Dict, Any


class RecentKConversationMemory(BaseMemory):
    """保存最近5轮对话中JSON文件特定键值的记忆类"""
    # 定义用来保存最近对话的列表,默认为空
    recent_conversations: list = []
    # 定义用来筛选添加到prompt中的实体信息的key,默认为"recent_conversations"
    memory_key: str = "recent_conversations"
    # 定义前缀,默认为"Human"和"AI"
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    # 定义保留轮数,默认5
    k: int = 5

    def clear(self):
        self.recent_conversations = []

    @property
    def memory_variables(self) -> List[str]:
        """定义加入prompt的变量"""
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """加载记忆变量,这里为特定键值的内容"""
        recent_json_value = ""

        # 遍历最近k轮对话
        for conversation in self.recent_conversations[-self.k:]:
            # 检查是否存在键值为"reply"的内容
            if "reply" in eval(conversation["output"]["response"]).keys():
                human = f"{self.human_prefix}: " + conversation["input"]['input']
                ai = f"{self.ai_prefix}: " + eval(conversation["output"]["response"])["reply"]
                recent_json_value += "\n" + "\n".join([human, ai])

        return {self.memory_key: recent_json_value}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """将本次对话的上下文保存到缓冲区"""
        # 每轮对话的输入和输出作为一个会话
        conversation = {
            "input": inputs,
            "output": outputs
        }
		# 如果超出5个就把最开始的conversation pop出去
        self.recent_conversations.append(conversation)
        if len(self.recent_conversations) > self.k:
            self.recent_conversations.pop(0)

测试

from langchain.prompts.prompt import PromptTemplate

template = """
你将扮演一个优秀的判断精准的会议促进者,现在需要你参与进会议中来

此处为自己的prompt
重点是接下来的格式所以就将过于长的prompt省掉了

【生成JSON格式】
[
  "current_step": 当前步骤名;
  "result_info":类型为JSON
	[
        "end_condition": 参加者发言中提到的终了条件;
        "reason": 能否达标的理由;
        "is_ok": 终了条件是否达标,是/否;
        "meeting_type": 如果"is_ok"为"是",就把判断好的会议类型写下来;
        "is_next": 可否进行下一步骤,在本步骤,如果'can_type'为'是'则为'是',是/否;
        "next_step": 若"is_next"为"是",则写下一步骤名,若"is_next"为"否",则为"";
	]

  "reply": 会议促进者的发言,必须有内容;
]
###停止生成###

【对话历史】
{recent_conversations}

【当前对话】
参会者:{input}
促进者JSON:"""
prompt = PromptTemplate(input_variables=["recent_conversations", "input"], template=template)
chat = OpenAI(temperature=0, model_name="gpt-3.5-turbo")
llm_chain = ConversationChain(
    llm=chat,
    prompt=prompt,
    # 此处设置为2,目的是方便快速看到结果以及检验对错
    memory=RecentJSONMemory(ai_prefix="会议促进者", human_prefix="参会者", k=2),
    verbose=True,
)
llm_chain.predict(input="开始吧")

此处只截取了【对话历史】以后的结果

【对话历史】


【当前对话】
参会者:开始吧
促进者JSON:

> Finished chain.
{\n  "current_step": "设置终了条件",\n  "result_info": {\n    "end_condition": "",\n    "reason": "",\n    "is_ok": "",\n    "meeting_type": "",\n    "is_next": "",\n    "next_step": ""\n  },\n  "reply": "好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?"\n}
llm_chain.predict(input="不知道")
【对话历史】

参会者: 开始吧
会议促进者: 好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?

【当前对话】
参会者:不知道
促进者JSON:

> Finished chain.
{\n  "current_step": "设置终了条件",\n  "result_info": {\n    "end_condition": "不知道",\n    "reason": "无法明确表明结束的时候的状态",\n    "is_ok": "否",\n    "meeting_type": "",\n    "is_next": "是",\n    "next_step": "明确参会者"\n  },\n  "reply": "终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。"\n}

此处【对话历史】中正如一开始所想的那样只把reply给截了出来放在了记忆中,那么下边验证是否可以做到只保留两轮对话

llm_chain.predict(input="希望能获得几个关于研究室规则的看法")
【对话历史】

参会者: 开始吧
会议促进者: 好的,请问在讨论研究室的规则制定时,你认为应该设置什么样的终了条件呢?
参会者: 不知道
会议促进者: 终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。

【当前对话】
参会者:希望能获得几个关于研究室规则的看法
促进者JSON:

> Finished chain.
{\n  "current_step": "设置终了条件",\n  "result_info": {\n    "end_condition": "希望能获得几个关于研究室规则的看法",\n    "reason": "明确说明要达成人的某种状态",\n    "is_ok": "是",\n    "meeting_type": "信息收集",\n    "is_next": "是",\n    "next_step": "明确参会者"\n  },\n  "reply": "非常好,你提出了一个明确的终了条件,即希望能获得几个关于研究室规则的看法。这符合信息收集类型的会议。接下来,我们需要明确参会者,请问在场有多少人参加这次会议呢?"\n}

此时【对话历史】中已经有2轮对话的消息了,看接着对话是否会将最开始的“开始吧”给去掉

llm_chain.predict(input="3人")
【对话历史】

参会者: 不知道
会议促进者: 终了条件需要能够明确表明结束的时候的状态,请继续讨论并尝试给出一个明确的终了条件。
参会者: 希望能获得几个关于研究室规则的看法
会议促进者: 非常好,你提出了一个明确的终了条件,即希望能获得几个关于研究室规则的看法。这符合信息收集类型的会议。接下来,我们需要明确参会者,请问在场有多少人参加这次会议呢?

【当前对话】
参会者:3人
促进者JSON:

所以是成功了