| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- import { Injectable } from '@nestjs/common';
- import { ChatOpenAI } from '@langchain/openai';
- import { FFBVectorService } from './ffb-vector.service';
- import { z } from "zod";
- import { StateGraph, START, END, Annotation } from "@langchain/langgraph";
- import { BaseMessage, HumanMessage, AIMessage } from "@langchain/core/messages";
- import { forwardRef, Inject } from '@nestjs/common';
- import { FFBGateway } from '../ffb.gateway';
// State Definition using Annotation.
// Each channel declares a reducer (how a node's partial update merges into the
// existing value) and a default. Nodes return Partial<State>; LangGraph applies
// the reducers after each node runs.
const AgentState = Annotation.Root({
  // Full conversation history; node updates are appended (concat reducer).
  messages: Annotation<BaseMessage[]>({
    reducer: (x, y) => x.concat(y),
    default: () => [],
  }),
  // Intent chosen by router_node; last non-null write wins, falls back to "General".
  activeIntent: Annotation<string>({
    reducer: (x, y) => y ?? x ?? "General",
    default: () => "General",
  }),
  // Extracted entities (e.g. site, date). Shallow-merged, so keys written by a
  // later node overwrite earlier values for the same key.
  entityStore: Annotation<Record<string, any>>({
    reducer: (x, y) => ({ ...x, ...y }),
    default: () => ({}),
  }),
  // Raw tool/search output handed from retrieval nodes to synthesis_node.
  actionPayload: Annotation<any>({
    reducer: (x, y) => y ?? x,
    default: () => null,
  }),
  // Optional final answer override; last non-null write wins. No default.
  finalResponse: Annotation<string>({
    reducer: (x, y) => y ?? x,
  }),
  // Socket.IO client id used to stream "thought" events back to the UI.
  socketId: Annotation<string>({
    reducer: (x, y) => y ?? x,
    default: () => "default",
  })
});
- @Injectable()
- export class FFBLangChainService {
- private model: ChatOpenAI;
- private graph: any;
- private sessions: Map<string, BaseMessage[]> = new Map();
- constructor(
- private readonly vectorService: FFBVectorService,
- @Inject(forwardRef(() => FFBGateway))
- private readonly gateway: FFBGateway
- ) {
- this.model = new ChatOpenAI({
- modelName: 'gpt-4o',
- apiKey: process.env.OPENAI_API_KEY,
- temperature: 0
- });
- this.initGraph();
- }
- private initGraph() {
- const graph = new StateGraph(AgentState)
- .addNode("router_node", this.routerNode.bind(this))
- .addNode("clarifier_node", this.clarifierNode.bind(this))
- .addNode("general_node", this.generalNode.bind(this))
- .addNode("vector_search_node", this.vectorSearchNode.bind(this))
- .addNode("aggregation_node", this.aggregationNode.bind(this))
- .addNode("synthesis_node", this.synthesisNode.bind(this));
- // Add Edges
- graph.addEdge(START, "router_node");
- graph.addConditionalEdges(
- "router_node",
- (state) => state.activeIntent,
- {
- Clarify: "clarifier_node",
- General: "general_node",
- Semantic: "vector_search_node",
- Aggregate: "aggregation_node"
- }
- );
- graph.addEdge("clarifier_node", END);
- graph.addEdge("general_node", END);
- graph.addEdge("vector_search_node", "synthesis_node");
- graph.addEdge("aggregation_node", "synthesis_node");
- graph.addEdge("synthesis_node", END);
- this.graph = graph.compile();
- }
- // --- NODE IMPLEMENTATIONS ---
- private async routerNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const lastMessage = state.messages[state.messages.length - 1].content as string;
- // Change this in your routerNode:
- const routerSchema = z.object({
- intent: z.enum(['General', 'Clarify', 'Semantic', 'Aggregate']),
- entities: z.object({
- // Use .nullable() instead of .optional() for OpenAI Strict mode
- // Or ensure they are always provided by the LLM
- site: z.string().nullable().describe("The site name mentioned, or null"),
- date: z.string().nullable().describe("The date mentioned, or null"),
- }), // Remove .optional() here; the object itself must be returned
- reasoning: z.string()
- });
- this.gateway.emitThought(state.socketId, {
- node: 'router_node',
- status: 'processing',
- message: 'Analyzing user intent...',
- input: lastMessage
- });
- const routerPrompt = `
- You are an Application Router for a production database.
- Analyze the user input and route to: [General, Clarify, Semantic, Aggregate].
- INTENT DEFINITIONS:
- - Aggregate: Use if the user asks for numbers, totals, averages, or counts (e.g., "How much...", "Total weight").
- - Semantic: Use if the user asks for specific records, qualitative descriptions, issues, "what happened", or "find info about" (e.g., "Show me records for Site A", "What were the notes on block X?").
- - Clarify: Use ONLY if the user names an entity (like a Site) but provides NO verb or question.
- - General: Use for greetings or off-topic chat.
- STRICT RULES:
- 1. If "Site" is mentioned alone (e.g., "Site A"), route to 'Clarify'.
- 2. If the user asks for data or "what happened" regarding a site, route to 'Semantic'.
- 3. Do NOT route to 'Clarify' if there is a clear question.
- User Input: "${lastMessage}"
- `;
- const structuredLlm = this.model.withStructuredOutput(routerSchema);
- const result = await structuredLlm.invoke(routerPrompt);
- // Merge extracted entities with existing store
- this.gateway.emitThought(state.socketId, {
- node: 'router_node',
- status: 'completed',
- result: result
- });
- return {
- activeIntent: result.intent as any,
- entityStore: result.entities || {},
- socketId: state.socketId
- };
- }
- private async clarifierNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const prompt = `User mentioned ${JSON.stringify(state.entityStore)}. Ask them to clarify what they want to know (e.g., total production, specific issues, etc.).`;
- this.gateway.emitThought(state.socketId, {
- node: 'clarifier_node',
- status: 'processing',
- message: 'Asking for clarification',
- context: state.entityStore
- });
- const response = await this.model.invoke(prompt);
- return {
- messages: [response]
- };
- }
- private async generalNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const lastMessage = state.messages[state.messages.length - 1];
- const response = await this.model.invoke([
- new HumanMessage("You are a helpful assistant. Reply to: " + lastMessage.content)
- ]);
- return {
- messages: [response]
- };
- }
- private async vectorSearchNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const lastMessage = state.messages[state.messages.length - 1].content as string;
- const filter: Record<string, any> = {};
- if (state.entityStore && state.entityStore.site) {
- filter.site = state.entityStore.site;
- }
- const results = await this.vectorService.vectorSearch(lastMessage, 5, filter);
- this.gateway.emitThought(state.socketId, {
- node: 'vector_search_node',
- status: 'completed',
- query: lastMessage,
- filter: filter,
- resultsCount: results.length
- });
- return {
- actionPayload: { type: 'search', query: lastMessage, results }
- };
- }
- private async aggregationNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const lastMessage = state.messages[state.messages.length - 1].content as string;
- const pipelineSchema = z.object({
- matchStage: z.object({
- site: z.string().nullable(),
- startDate: z.string().nullable(),
- endDate: z.string().nullable(),
- }),
- aggregationType: z.enum(["sum", "avg", "count"]),
- fieldToAggregate: z.enum(["quantity", "weight"])
- });
- const structuredLlm = this.model.withStructuredOutput(pipelineSchema);
- const params = await structuredLlm.invoke(`Extract aggregation parameters for: "${lastMessage}". Context: ${JSON.stringify(state.entityStore)}`);
- const pipeline: any[] = [];
- const match: any = {};
- // Check for null instead of undefined
- if (params.matchStage.site !== null) {
- match.site = params.matchStage.site;
- }
- if (params.matchStage.startDate !== null || params.matchStage.endDate !== null) {
- match.productionDate = {};
- if (params.matchStage.startDate !== null) {
- match.productionDate.$gte = new Date(params.matchStage.startDate);
- }
- if (params.matchStage.endDate !== null) {
- match.productionDate.$lte = new Date(params.matchStage.endDate);
- }
- }
- if (Object.keys(match).length > 0) {
- pipeline.push({ $match: match });
- }
- const group: any = { _id: null };
- const operator = `$${params.aggregationType}`;
- group.totalValue = { [operator]: `$${params.fieldToAggregate}` };
- pipeline.push({ $group: group });
- const results = await this.vectorService.aggregate(pipeline);
- this.gateway.emitThought(state.socketId, {
- node: 'aggregation_node',
- status: 'completed',
- pipeline: pipeline,
- results: results
- });
- return {
- actionPayload: { type: 'aggregate', pipeline, results }
- };
- }
- private async synthesisNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
- const lastMessage = state.messages[state.messages.length - 1].content as string;
- const payload = state.actionPayload;
- const prompt = `
- User Question: "${lastMessage}"
- Data Context: ${JSON.stringify(payload)}
-
- Synthesize a natural language answer based STRICTLY on the Data Context.
- Cite the source (e.g., "Based on aggregation results...").
- `;
- this.gateway.emitThought(state.socketId, {
- node: 'synthesis_node',
- status: 'processing',
- message: 'Synthesizing final response',
- dataContextLength: JSON.stringify(payload).length
- });
- const response = await this.model.invoke(prompt);
- return {
- messages: [response]
- };
- }
- // --- MAIN ENTRY POINT ---
- createSession(socketId: string) {
- this.sessions.set(socketId, []);
- console.log(`Session created for ${socketId}`);
- }
- deleteSession(socketId: string) {
- this.sessions.delete(socketId);
- console.log(`Session deleted for ${socketId}`);
- }
- async chat(socketId: string, message: string): Promise<string> {
- try {
- // Get history or init empty
- const history = this.sessions.get(socketId) || [];
- const inputs = {
- messages: [...history, new HumanMessage(message)],
- entityStore: {},
- socketId: socketId
- };
- const result = await this.graph.invoke(inputs);
- const allMessages = result.messages as BaseMessage[];
- // Update history (keep all messages for context window? Or truncate?)
- // For now, keep all. Memory optimization might be needed later.
- this.sessions.set(socketId, allMessages);
- const agentMessages = allMessages.filter((m: BaseMessage) => m._getType() === 'ai');
- const lastResponse = agentMessages[agentMessages.length - 1];
- return lastResponse.content as string;
- } catch (error) {
- console.error('Error calling LangGraph:', error);
- throw error;
- }
- }
- }
|