Files
Automata/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/OpenAiFunctionsAgent/execute.ts
2024-02-21 14:59:37 +01:00

114 lines
3.4 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import {
type IExecuteFunctions,
type INodeExecutionData,
NodeConnectionType,
NodeOperationError,
} from 'n8n-workflow';
import type { AgentExecutorInput } from 'langchain/agents';
import { AgentExecutor, OpenAIAgent } from 'langchain/agents';
import type { Tool } from 'langchain/tools';
import type { BaseOutputParser } from 'langchain/schema/output_parser';
import { PromptTemplate } from 'langchain/prompts';
import { CombiningOutputParser } from 'langchain/output_parsers';
import { BufferMemory, type BaseChatMemory } from 'langchain/memory';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { getOptionalOutputParsers, getPromptInputByType } from '../../../../../utils/helpers';
/**
 * Runs the OpenAI Functions agent once for every input item of the node.
 *
 * Reads the connected language model, memory, tools and optional output
 * parsers, builds an `AgentExecutor` from them, and calls it with the text
 * resolved for each item. When at least one output parser is connected, its
 * format instructions are appended to the input text and the agent's `output`
 * field is parsed before being returned.
 *
 * @returns The per-item agent responses, passed through `prepareOutputData`.
 * @throws NodeOperationError if the connected model is not a `ChatOpenAI`
 *   instance, or if the text resolved for an item is `undefined`.
 */
export async function openAiFunctionsAgentExecute(
  this: IExecuteFunctions,
): Promise<INodeExecutionData[][]> {
  this.logger.verbose('Executing OpenAi Functions Agent');

  // Only a ChatOpenAI instance is accepted as the language model.
  const chatModel = (await this.getInputConnectionData(
    NodeConnectionType.AiLanguageModel,
    0,
  )) as ChatOpenAI;
  const isOpenAiChatModel = chatModel instanceof ChatOpenAI;
  if (!isOpenAiChatModel) {
    throw new NodeOperationError(
      this.getNode(),
      'OpenAI Functions Agent requires OpenAI Chat Model',
    );
  }

  const connectedMemory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
    | BaseChatMemory
    | undefined;
  const connectedTools = (await this.getInputConnectionData(
    NodeConnectionType.AiTool,
    0,
  )) as Tool[];
  const connectedParsers = await getOptionalOutputParsers(this);

  const nodeOptions = this.getNodeParameter('options', 0, {}) as {
    systemMessage?: string;
    maxIterations?: number;
    returnIntermediateSteps?: boolean;
  };

  const agent = OpenAIAgent.fromLLMAndTools(chatModel, connectedTools, {
    prefix: nodeOptions.systemMessage,
  });

  // When no memory node is connected, fall back to a fresh BufferMemory.
  const chatMemory =
    connectedMemory ??
    new BufferMemory({
      returnMessages: true,
      memoryKey: 'chat_history',
      inputKey: 'input',
      outputKey: 'output',
    });

  const executorConfig: AgentExecutorInput = {
    tags: ['openai-functions'],
    agent,
    tools: connectedTools,
    maxIterations: nodeOptions.maxIterations ?? 10,
    returnIntermediateSteps: nodeOptions.returnIntermediateSteps === true,
    memory: chatMemory,
  };
  const executor = AgentExecutor.fromAgentAndTools(executorConfig);

  // A single parser is used as-is; several parsers are merged into one.
  // Both stay undefined when no parser is connected.
  let activeParser: BaseOutputParser | undefined;
  let formatPrompt: PromptTemplate | undefined;
  if (connectedParsers.length > 0) {
    if (connectedParsers.length === 1) {
      activeParser = connectedParsers[0];
    } else {
      activeParser = new CombiningOutputParser(...connectedParsers);
    }
    formatPrompt = new PromptTemplate({
      template: '{input}\n{formatInstructions}',
      inputVariables: ['input'],
      partialVariables: { formatInstructions: activeParser.getFormatInstructions() },
    });
  }

  const results: INodeExecutionData[] = [];
  const inputItems = this.getInputData();

  for (const itemIndex of inputItems.keys()) {
    // Node versions up to 1.2 read the 'text' parameter directly; later
    // versions resolve the text through getPromptInputByType.
    const readsTextDirectly = this.getNode().typeVersion <= 1.2;
    let agentInput = readsTextDirectly
      ? (this.getNodeParameter('text', itemIndex) as string)
      : getPromptInputByType({
          ctx: this,
          i: itemIndex,
          inputKey: 'text',
          promptTypeKey: 'promptType',
        });

    if (agentInput === undefined) {
      throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');
    }

    // Append the parser's format instructions to the text sent to the agent.
    if (formatPrompt) {
      const formatted = await formatPrompt.invoke({ input: agentInput });
      agentInput = formatted.value;
    }

    let response = await executor.call({ input: agentInput, outputParsers: connectedParsers });

    // With a parser present, only the parsed `output` field is kept.
    if (activeParser) {
      const parsedOutput = await activeParser.parse(response.output as string);
      response = { output: parsedOutput };
    }

    results.push({ json: response });
  }

  return await this.prepareOutputData(results);
}