feat(Postgres Chat Memory Node): Implement Postgres Chat Memory node (#10071)

This commit is contained in:
oleg
2024-07-17 08:25:37 +02:00
committed by GitHub
parent e5dda5731d
commit 9cbbb6335d
6 changed files with 200 additions and 6 deletions

View File

@@ -0,0 +1,103 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { NodeConnectionType } from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory';
import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres';
import type pg from 'pg';
import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport';
import type { PostgresNodeCredentials } from 'n8n-nodes-base/dist/nodes/Postgres/v2/helpers/interfaces';
import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
/**
 * Postgres Chat Memory sub-node: stores and retrieves chat history in a
 * Postgres table so AI agents keep conversational context between runs.
 * Supplies an AiMemory connection instead of regular node input/output.
 */
export class MemoryPostgresChat implements INodeType {
	description: INodeTypeDescription = {
		displayName: 'Postgres Chat Memory',
		name: 'memoryPostgresChat',
		icon: 'file:postgres.svg',
		group: ['transform'],
		version: [1],
		description: 'Stores the chat history in Postgres table.',
		defaults: {
			name: 'Postgres Chat Memory',
		},
		credentials: [
			{
				name: 'postgres',
				required: true,
				// Resolved against methods.credentialTest below.
				testedBy: 'postgresConnectionTest',
			},
		],
		codex: {
			categories: ['AI'],
			subcategories: {
				AI: ['Memory'],
			},
			resources: {
				primaryDocumentation: [
					{
						url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.memorypostgreschat/',
					},
				],
			},
		},
		// Sub-node: it feeds a parent AI node rather than consuming regular items,
		// hence the intentionally empty inputs and the lint suppressions.
		// eslint-disable-next-line n8n-nodes-base/node-class-description-inputs-wrong-regular-node
		inputs: [],
		// eslint-disable-next-line n8n-nodes-base/node-class-description-outputs-wrong
		outputs: [NodeConnectionType.AiMemory],
		outputNames: ['Memory'],
		properties: [
			getConnectionHintNoticeField([NodeConnectionType.AiAgent]),
			sessionIdOption,
			sessionKeyProperty,
			{
				displayName: 'Table Name',
				name: 'tableName',
				type: 'string',
				default: 'n8n_chat_histories',
				description:
					'The table name to store the chat history in. If table does not exist, it will be created.',
			},
		],
	};

	methods = {
		credentialTest: {
			// Exposed under the name referenced by credentials[].testedBy above.
			postgresConnectionTest,
		},
	};

	/**
	 * Builds a BufferMemory backed by PostgresChatMessageHistory and hands it
	 * to the consuming node.
	 *
	 * @param itemIndex index of the item whose parameters and session id are read
	 * @returns the log-wrapped memory plus a closeFunction that ends the pg pool
	 */
	async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
		const credentials = (await this.getCredentials('postgres')) as PostgresNodeCredentials;
		const tableName = this.getNodeParameter('tableName', itemIndex, 'n8n_chat_histories') as string;
		const sessionId = getSessionId(this, itemIndex);

		const pgConf = await configurePostgres.call(this, credentials);
		// NOTE(review): assumes pgConf.db.$pool is the underlying node-postgres
		// Pool (pg-promise convention) — the double cast bridges the typings.
		const pool = pgConf.db.$pool as unknown as pg.Pool;

		const pgChatHistory = new PostgresChatMessageHistory({
			pool,
			sessionId,
			tableName,
		});

		const memory = new BufferMemory({
			memoryKey: 'chat_history',
			chatHistory: pgChatHistory,
			returnMessages: true,
			inputKey: 'input',
			outputKey: 'output',
		});

		// Fire-and-forget pool shutdown when the workflow releases this memory;
		// callers do not await the returned promise.
		async function closeFunction() {
			void pool.end();
		}

		return {
			closeFunction,
			// logWrapper instruments the memory instance (see utils/logWrapper).
			response: logWrapper(memory, this),
		};
	}
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.9 KiB

View File

@@ -77,6 +77,7 @@
"dist/nodes/llms/LMOpenHuggingFaceInference/LmOpenHuggingFaceInference.node.js",
"dist/nodes/memory/MemoryBufferWindow/MemoryBufferWindow.node.js",
"dist/nodes/memory/MemoryMotorhead/MemoryMotorhead.node.js",
"dist/nodes/memory/MemoryPostgresChat/MemoryPostgresChat.node.js",
"dist/nodes/memory/MemoryRedisChat/MemoryRedisChat.node.js",
"dist/nodes/memory/MemoryManager/MemoryManager.node.js",
"dist/nodes/memory/MemoryChatRetriever/MemoryChatRetriever.node.js",
@@ -153,6 +154,7 @@
"@pinecone-database/pinecone": "2.2.1",
"@qdrant/js-client-rest": "1.9.0",
"@supabase/supabase-js": "2.43.4",
"@types/pg": "^8.11.3",
"@xata.io/client": "0.28.4",
"basic-auth": "2.0.1",
"cheerio": "1.0.0-rc.12",

View File

@@ -10,6 +10,18 @@ import type { BaseOutputParser } from '@langchain/core/output_parsers';
import type { BaseMessage } from '@langchain/core/messages';
import { DynamicTool, type Tool } from '@langchain/core/tools';
import type { BaseLLM } from '@langchain/core/language_models/llms';
import type { BaseChatMemory } from 'langchain/memory';
import type { BaseChatMessageHistory } from '@langchain/core/chat_history';
/**
 * Runtime duck-typing guard: true when `obj` is a non-null object exposing
 * every listed member as a function. Narrows `obj` to `T` for the caller.
 * With an empty method list this is vacuously true for any input.
 */
function hasMethods<T>(obj: unknown, ...methodNames: Array<string | symbol>): obj is T {
	const candidate =
		typeof obj === 'object' && obj !== null ? (obj as Record<string | symbol, unknown>) : null;
	return methodNames.every(
		(name) => candidate !== null && name in candidate && typeof candidate[name] === 'function',
	);
}
export function getMetadataFiltersValues(
ctx: IExecuteFunctions,
@@ -38,8 +50,16 @@ export function getMetadataFiltersValues(
return undefined;
}
/**
 * Duck-typed check for LangChain's BaseChatMemory by its
 * loadMemoryVariables/saveContext methods (used in logWrapper in place of
 * `instanceof`). Explicit `obj is BaseChatMemory` predicate so call sites
 * narrow on any TS version — without it, pre-5.5 compilers infer plain boolean.
 */
export function isBaseChatMemory(obj: unknown): obj is BaseChatMemory {
	return hasMethods<BaseChatMemory>(obj, 'loadMemoryVariables', 'saveContext');
}
/**
 * Duck-typed check for LangChain's BaseChatMessageHistory by its
 * getMessages/addMessage methods (used in logWrapper in place of
 * `instanceof`). Explicit `obj is BaseChatMessageHistory` predicate so call
 * sites narrow on any TS version — pre-5.5 compilers would infer boolean.
 */
export function isBaseChatMessageHistory(obj: unknown): obj is BaseChatMessageHistory {
	return hasMethods<BaseChatMessageHistory>(obj, 'getMessages', 'addMessage');
}
// Type predicate: a chat model advertises 'chat_models' in its lc_namespace.
export function isChatInstance(model: unknown): model is BaseChatModel {
	// NOTE(review): the next two lines are the removed/added pair from the diff
	// rendering — only the second (the narrower `BaseLLM` cast) exists in the
	// final source; the first is the pre-change line and must not both compile.
	const namespace = (model as BaseLLM | BaseChatModel)?.lc_namespace ?? [];
	const namespace = (model as BaseLLM)?.lc_namespace ?? [];
	return namespace.includes('chat_models');
}

View File

@@ -4,21 +4,21 @@ import type { ConnectionTypes, IExecuteFunctions, INodeExecutionData } from 'n8n
import type { Tool } from '@langchain/core/tools';
import type { BaseMessage } from '@langchain/core/messages';
import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory';
import { BaseChatMessageHistory } from '@langchain/core/chat_history';
import type { BaseChatMessageHistory } from '@langchain/core/chat_history';
import type { BaseCallbackConfig, Callbacks } from '@langchain/core/callbacks/manager';
import { Embeddings } from '@langchain/core/embeddings';
import { VectorStore } from '@langchain/core/vectorstores';
import type { Document } from '@langchain/core/documents';
import { TextSplitter } from '@langchain/textsplitters';
import { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import { BaseRetriever } from '@langchain/core/retrievers';
import { BaseOutputParser, OutputParserException } from '@langchain/core/output_parsers';
import { isObject } from 'lodash';
import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base';
import { N8nJsonLoader } from './N8nJsonLoader';
import { N8nBinaryLoader } from './N8nBinaryLoader';
import { logAiEvent, isToolsInstance } from './helpers';
import { logAiEvent, isToolsInstance, isBaseChatMemory, isBaseChatMessageHistory } from './helpers';
const errorsMap: { [key: string]: { message: string; description: string } } = {
'You exceeded your current quota, please check your plan and billing details.': {
@@ -125,7 +125,7 @@ export function logWrapper(
get: (target, prop) => {
let connectionType: ConnectionTypes | undefined;
// ========== BaseChatMemory ==========
if (originalInstance instanceof BaseChatMemory) {
if (isBaseChatMemory(originalInstance)) {
if (prop === 'loadMemoryVariables' && 'loadMemoryVariables' in target) {
return async (values: InputValues): Promise<MemoryVariables> => {
connectionType = NodeConnectionType.AiMemory;
@@ -177,7 +177,7 @@ export function logWrapper(
}
// ========== BaseChatMessageHistory ==========
if (originalInstance instanceof BaseChatMessageHistory) {
if (isBaseChatMessageHistory(originalInstance)) {
if (prop === 'getMessages' && 'getMessages' in target) {
return async (): Promise<BaseMessage[]> => {
connectionType = NodeConnectionType.AiMemory;