// ffb-langchain.service.ts
  1. import { Injectable } from '@nestjs/common';
  2. import { ChatOpenAI } from '@langchain/openai';
  3. import { FFBVectorService } from './ffb-vector.service';
  4. import { z } from "zod";
  5. import { StateGraph, START, END, Annotation } from "@langchain/langgraph";
  6. import { BaseMessage, HumanMessage, AIMessage } from "@langchain/core/messages";
  7. import { forwardRef, Inject } from '@nestjs/common';
  8. import { FFBGateway } from '../ffb.gateway';
  9. // State Definition using Annotation
  10. const AgentState = Annotation.Root({
  11. messages: Annotation<BaseMessage[]>({
  12. reducer: (x, y) => x.concat(y),
  13. default: () => [],
  14. }),
  15. activeIntent: Annotation<string>({
  16. reducer: (x, y) => y ?? x ?? "General",
  17. default: () => "General",
  18. }),
  19. entityStore: Annotation<Record<string, any>>({
  20. reducer: (x, y) => ({ ...x, ...y }),
  21. default: () => ({}),
  22. }),
  23. actionPayload: Annotation<any>({
  24. reducer: (x, y) => y ?? x,
  25. default: () => null,
  26. }),
  27. finalResponse: Annotation<string>({
  28. reducer: (x, y) => y ?? x,
  29. }),
  30. socketId: Annotation<string>({
  31. reducer: (x, y) => y ?? x,
  32. default: () => "default",
  33. })
  34. });
  35. @Injectable()
  36. export class FFBLangChainService {
  37. private model: ChatOpenAI;
  38. private graph: any;
  39. private sessions: Map<string, BaseMessage[]> = new Map();
  40. constructor(
  41. private readonly vectorService: FFBVectorService,
  42. @Inject(forwardRef(() => FFBGateway))
  43. private readonly gateway: FFBGateway
  44. ) {
  45. this.model = new ChatOpenAI({
  46. modelName: 'gpt-4o',
  47. apiKey: process.env.OPENAI_API_KEY,
  48. temperature: 0
  49. });
  50. this.initGraph();
  51. }
  52. private initGraph() {
  53. const graph = new StateGraph(AgentState)
  54. .addNode("router_node", this.routerNode.bind(this))
  55. .addNode("clarifier_node", this.clarifierNode.bind(this))
  56. .addNode("general_node", this.generalNode.bind(this))
  57. .addNode("vector_search_node", this.vectorSearchNode.bind(this))
  58. .addNode("aggregation_node", this.aggregationNode.bind(this))
  59. .addNode("synthesis_node", this.synthesisNode.bind(this));
  60. // Add Edges
  61. graph.addEdge(START, "router_node");
  62. graph.addConditionalEdges(
  63. "router_node",
  64. (state) => state.activeIntent,
  65. {
  66. Clarify: "clarifier_node",
  67. General: "general_node",
  68. Semantic: "vector_search_node",
  69. Aggregate: "aggregation_node"
  70. }
  71. );
  72. graph.addEdge("clarifier_node", END);
  73. graph.addEdge("general_node", END);
  74. graph.addEdge("vector_search_node", "synthesis_node");
  75. graph.addEdge("aggregation_node", "synthesis_node");
  76. graph.addEdge("synthesis_node", END);
  77. this.graph = graph.compile();
  78. }
  79. // --- NODE IMPLEMENTATIONS ---
  80. private async routerNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  81. const lastMessage = state.messages[state.messages.length - 1].content as string;
  82. // Change this in your routerNode:
  83. const routerSchema = z.object({
  84. intent: z.enum(['General', 'Clarify', 'Semantic', 'Aggregate']),
  85. entities: z.object({
  86. // Use .nullable() instead of .optional() for OpenAI Strict mode
  87. // Or ensure they are always provided by the LLM
  88. site: z.string().nullable().describe("The site name mentioned, or null"),
  89. date: z.string().nullable().describe("The date mentioned, or null"),
  90. }), // Remove .optional() here; the object itself must be returned
  91. reasoning: z.string()
  92. });
  93. this.gateway.emitThought(state.socketId, {
  94. node: 'router_node',
  95. status: 'processing',
  96. message: 'Analyzing user intent...',
  97. input: lastMessage
  98. });
  99. const routerPrompt = `
  100. You are an Application Router for a production database.
  101. Analyze the user input and route to: [General, Clarify, Semantic, Aggregate].
  102. INTENT DEFINITIONS:
  103. - Aggregate: Use if the user asks for numbers, totals, averages, or counts (e.g., "How much...", "Total weight").
  104. - Semantic: Use if the user asks for specific records, qualitative descriptions, issues, "what happened", or "find info about" (e.g., "Show me records for Site A", "What were the notes on block X?").
  105. - Clarify: Use ONLY if the user names an entity (like a Site) but provides NO verb or question.
  106. - General: Use for greetings or off-topic chat.
  107. STRICT RULES:
  108. 1. If "Site" is mentioned alone (e.g., "Site A"), route to 'Clarify'.
  109. 2. If the user asks for data or "what happened" regarding a site, route to 'Semantic'.
  110. 3. Do NOT route to 'Clarify' if there is a clear question.
  111. User Input: "${lastMessage}"
  112. `;
  113. const structuredLlm = this.model.withStructuredOutput(routerSchema);
  114. const result = await structuredLlm.invoke(routerPrompt);
  115. // Merge extracted entities with existing store
  116. this.gateway.emitThought(state.socketId, {
  117. node: 'router_node',
  118. status: 'completed',
  119. result: result
  120. });
  121. return {
  122. activeIntent: result.intent as any,
  123. entityStore: result.entities || {},
  124. socketId: state.socketId
  125. };
  126. }
  127. private async clarifierNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  128. const prompt = `User mentioned ${JSON.stringify(state.entityStore)}. Ask them to clarify what they want to know (e.g., total production, specific issues, etc.).`;
  129. this.gateway.emitThought(state.socketId, {
  130. node: 'clarifier_node',
  131. status: 'processing',
  132. message: 'Asking for clarification',
  133. context: state.entityStore
  134. });
  135. const response = await this.model.invoke(prompt);
  136. return {
  137. messages: [response]
  138. };
  139. }
  140. private async generalNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  141. const lastMessage = state.messages[state.messages.length - 1];
  142. const response = await this.model.invoke([
  143. new HumanMessage("You are a helpful assistant. Reply to: " + lastMessage.content)
  144. ]);
  145. return {
  146. messages: [response]
  147. };
  148. }
  149. private async vectorSearchNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  150. const lastMessage = state.messages[state.messages.length - 1].content as string;
  151. const filter: Record<string, any> = {};
  152. if (state.entityStore && state.entityStore.site) {
  153. filter.site = state.entityStore.site;
  154. }
  155. const results = await this.vectorService.vectorSearch(lastMessage, 5, filter);
  156. this.gateway.emitThought(state.socketId, {
  157. node: 'vector_search_node',
  158. status: 'completed',
  159. query: lastMessage,
  160. filter: filter,
  161. resultsCount: results.length
  162. });
  163. return {
  164. actionPayload: { type: 'search', query: lastMessage, results }
  165. };
  166. }
  167. private async aggregationNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  168. const lastMessage = state.messages[state.messages.length - 1].content as string;
  169. const pipelineSchema = z.object({
  170. matchStage: z.object({
  171. site: z.string().nullable(),
  172. startDate: z.string().nullable(),
  173. endDate: z.string().nullable(),
  174. }),
  175. aggregationType: z.enum(["sum", "avg", "count"]),
  176. fieldToAggregate: z.enum(["quantity", "weight"])
  177. });
  178. const structuredLlm = this.model.withStructuredOutput(pipelineSchema);
  179. const params = await structuredLlm.invoke(`Extract aggregation parameters for: "${lastMessage}". Context: ${JSON.stringify(state.entityStore)}`);
  180. const pipeline: any[] = [];
  181. const match: any = {};
  182. // Check for null instead of undefined
  183. if (params.matchStage.site !== null) {
  184. match.site = params.matchStage.site;
  185. }
  186. if (params.matchStage.startDate !== null || params.matchStage.endDate !== null) {
  187. match.productionDate = {};
  188. if (params.matchStage.startDate !== null) {
  189. match.productionDate.$gte = new Date(params.matchStage.startDate);
  190. }
  191. if (params.matchStage.endDate !== null) {
  192. match.productionDate.$lte = new Date(params.matchStage.endDate);
  193. }
  194. }
  195. if (Object.keys(match).length > 0) {
  196. pipeline.push({ $match: match });
  197. }
  198. const group: any = { _id: null };
  199. const operator = `$${params.aggregationType}`;
  200. group.totalValue = { [operator]: `$${params.fieldToAggregate}` };
  201. pipeline.push({ $group: group });
  202. const results = await this.vectorService.aggregate(pipeline);
  203. this.gateway.emitThought(state.socketId, {
  204. node: 'aggregation_node',
  205. status: 'completed',
  206. pipeline: pipeline,
  207. results: results
  208. });
  209. return {
  210. actionPayload: { type: 'aggregate', pipeline, results }
  211. };
  212. }
  213. private async synthesisNode(state: typeof AgentState.State): Promise<Partial<typeof AgentState.State>> {
  214. const lastMessage = state.messages[state.messages.length - 1].content as string;
  215. const payload = state.actionPayload;
  216. const prompt = `
  217. User Question: "${lastMessage}"
  218. Data Context: ${JSON.stringify(payload)}
  219. Synthesize a natural language answer based STRICTLY on the Data Context.
  220. Cite the source (e.g., "Based on aggregation results...").
  221. `;
  222. this.gateway.emitThought(state.socketId, {
  223. node: 'synthesis_node',
  224. status: 'processing',
  225. message: 'Synthesizing final response',
  226. dataContextLength: JSON.stringify(payload).length
  227. });
  228. const response = await this.model.invoke(prompt);
  229. return {
  230. messages: [response]
  231. };
  232. }
  233. // --- MAIN ENTRY POINT ---
  234. createSession(socketId: string) {
  235. this.sessions.set(socketId, []);
  236. console.log(`Session created for ${socketId}`);
  237. }
  238. deleteSession(socketId: string) {
  239. this.sessions.delete(socketId);
  240. console.log(`Session deleted for ${socketId}`);
  241. }
  242. async chat(socketId: string, message: string): Promise<string> {
  243. try {
  244. // Get history or init empty
  245. const history = this.sessions.get(socketId) || [];
  246. const inputs = {
  247. messages: [...history, new HumanMessage(message)],
  248. entityStore: {},
  249. socketId: socketId
  250. };
  251. const result = await this.graph.invoke(inputs);
  252. const allMessages = result.messages as BaseMessage[];
  253. // Update history (keep all messages for context window? Or truncate?)
  254. // For now, keep all. Memory optimization might be needed later.
  255. this.sessions.set(socketId, allMessages);
  256. const agentMessages = allMessages.filter((m: BaseMessage) => m._getType() === 'ai');
  257. const lastResponse = agentMessages[agentMessages.length - 1];
  258. return lastResponse.content as string;
  259. } catch (error) {
  260. console.error('Error calling LangGraph:', error);
  261. throw error;
  262. }
  263. }
  264. }