ffb-langchain.service.ts 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import { Injectable, Inject, forwardRef } from '@nestjs/common';
  2. import { ChatOpenAI } from '@langchain/openai';
  3. import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
  4. import { BaseChatModel } from "@langchain/core/language_models/chat_models";
  5. import { StateGraph, START, END } from "@langchain/langgraph";
  6. import { BaseMessage, HumanMessage } from "@langchain/core/messages";
  7. import { FFBVectorService } from './ffb-vector.service';
  8. import { FFBGateway } from '../ffb.gateway';
  9. // Config & Utils
  10. import { AgentState } from './config/agent-state';
  11. import { SessionManager } from './utils/session-manager';
  12. // Nodes
  13. import { entryNode } from './nodes/entry.node';
  14. import { routerNode } from './nodes/router.node';
  15. import { guidanceNode } from './nodes/guidance.node';
  16. import { metaNode } from './nodes/meta.node';
  17. import { refusalNode } from './nodes/refusal.node';
  18. import { vectorSearchNode } from './nodes/vector-search.node';
  19. import { aggregationNode } from './nodes/aggregation.node';
  20. import { synthesisNode } from './nodes/synthesis.node';
  /**
   * NestJS service that wires two chat models (OpenAI and Google Gemini) into a
   * LangGraph state graph and runs chat turns through it, keyed by socket id.
   *
   * Graph topology (see `initGraph`):
   *   START -> entry_node -> { router_node | guidance_node | meta_node | refusal_node }
   *   router_node -> { vector_search_node | aggregation_node } -> synthesis_node -> END
   *   guidance_node, meta_node and refusal_node go straight to END.
   */
  @Injectable()
  export class FFBLangChainService {
    /**
     * Single source of truth for the model identifiers, so the name reported by
     * `getCurrentModel` cannot drift from the model actually constructed.
     */
    private static readonly MODEL_NAMES = {
      openai: 'gpt-4o-mini',
      gemini: 'gemini-2.5-flash'
    } as const;

    /** Reply returned when the graph run yields no usable AI text. */
    private static readonly FALLBACK_REPLY = "I'm sorry, I encountered an error.";

    private readonly openaiModel: BaseChatModel;
    private readonly geminiModel: BaseChatModel;
    // The compiled graph's generic type depends on AgentState and the node names;
    // it is kept as `any` because that type is not importable from this file.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    private graph: any;
    private readonly sessionManager: SessionManager;

    /**
     * Builds both chat models (temperature 0, API keys read from the
     * environment), creates the session manager and compiles the graph.
     *
     * @param vectorService Vector service handed to the nodes that need it.
     * @param gateway Gateway handed to the nodes; injected through `forwardRef`.
     */
    constructor(
      private readonly vectorService: FFBVectorService,
      @Inject(forwardRef(() => FFBGateway))
      private readonly gateway: FFBGateway
    ) {
      this.openaiModel = new ChatOpenAI({
        modelName: FFBLangChainService.MODEL_NAMES.openai,
        apiKey: process.env.OPENAI_API_KEY,
        temperature: 0
      });
      this.geminiModel = new ChatGoogleGenerativeAI({
        model: FFBLangChainService.MODEL_NAMES.gemini,
        apiKey: process.env.GOOGLE_API_KEY,
        temperature: 0
      });
      this.sessionManager = new SessionManager();
      this.initGraph();
    }

    /**
     * Flattens LangChain message content into plain text.
     *
     * Message content is either a string or an array of content parts; only
     * string parts and parts carrying a string `text` field contribute.
     *
     * @param content Raw `content` of a message (or `undefined`).
     * @returns The textual content, or an empty string when there is none.
     */
    private static extractText(content: unknown): string {
      if (typeof content === 'string') {
        return content;
      }
      if (!Array.isArray(content)) {
        return '';
      }
      return content
        .map((part: unknown): string => {
          if (typeof part === 'string') {
            return part;
          }
          if (part !== null && typeof part === 'object') {
            const text = (part as { text?: unknown }).text;
            return typeof text === 'string' ? text : '';
          }
          return '';
        })
        .join('');
    }

    /**
     * Picks the chat model for a socket: Gemini when the session's provider is
     * 'gemini', OpenAI for anything else.
     */
    private getModel(socketId: string): BaseChatModel {
      const provider = this.sessionManager.getModelProvider(socketId);
      return provider === 'gemini' ? this.geminiModel : this.openaiModel;
    }

    /**
     * Stores the provider choice for a socket in the session manager.
     *
     * @param socketId Socket whose session is updated.
     * @param provider Provider to use for subsequent turns.
     */
    switchModel(socketId: string, provider: 'openai' | 'gemini') {
      this.sessionManager.setModelProvider(socketId, provider);
    }

    /**
     * Reports the provider recorded for a socket and the matching model name.
     *
     * @param socketId Socket to look up.
     * @returns `provider` as stored by the session manager, and `modelName`
     *          (the Gemini name for 'gemini', the OpenAI name otherwise).
     */
    getCurrentModel(socketId: string) {
      const provider = this.sessionManager.getModelProvider(socketId);
      return {
        provider: provider,
        modelName:
          provider === 'gemini'
            ? FFBLangChainService.MODEL_NAMES.gemini
            : FFBLangChainService.MODEL_NAMES.openai
      };
    }

    /**
     * Declares the nodes and edges of the agent graph and stores the compiled
     * graph on `this.graph`. Models are resolved per invocation from
     * `state.socketId`, so a provider switch takes effect on the next turn.
     */
    private initGraph() {
      const graph = new StateGraph(AgentState)
        .addNode("entry_node", (state) => entryNode(state, this.getModel(state.socketId), this.gateway))
        .addNode("guidance_node", (state) => guidanceNode(state))
        .addNode("meta_node", (state) => {
          const socketId = state.socketId;
          const provider = this.sessionManager.getModelProvider(socketId);
          // Human-readable provider label passed to the meta node.
          const providerName = provider === 'gemini' ? 'Google Gemini' : 'OpenAI';
          return metaNode(state, this.getModel(socketId), providerName, this.vectorService);
        })
        .addNode("refusal_node", (state) => refusalNode(state))
        .addNode("router_node", (state) => routerNode(state, this.getModel(state.socketId), this.gateway))
        .addNode("vector_search_node", (state) => vectorSearchNode(state, this.vectorService, this.gateway))
        .addNode("aggregation_node", (state) => aggregationNode(state, this.getModel(state.socketId), this.vectorService, this.gateway))
        .addNode("synthesis_node", (state) => synthesisNode(state, this.getModel(state.socketId), this.gateway));

      // Every run starts at the entry node, which sets `entryCategory`.
      graph.addEdge(START, "entry_node");

      // Branch on the entry classification.
      graph.addConditionalEdges(
        "entry_node",
        (state) => state.entryCategory,
        {
          "InScope-Actionable": "router_node",
          "InScope-NeedsGuidance": "guidance_node",
          "InScope-Meta": "meta_node",
          "OutOfScope": "refusal_node"
        }
      );

      // Branch on the intent chosen by the router.
      graph.addConditionalEdges(
        "router_node",
        (state) => state.activeIntent,
        {
          Semantic: "vector_search_node",
          Aggregate: "aggregation_node"
        }
      );

      // Terminal branches.
      graph.addEdge("guidance_node", END);
      graph.addEdge("meta_node", END);
      graph.addEdge("refusal_node", END);

      // Both retrieval branches converge on synthesis before ending.
      graph.addEdge("vector_search_node", "synthesis_node");
      graph.addEdge("aggregation_node", "synthesis_node");
      graph.addEdge("synthesis_node", END);

      this.graph = graph.compile();
    }

    // --- MAIN ENTRY POINT ---

    /** Creates the session-manager entry for a socket. */
    createSession(socketId: string) {
      this.sessionManager.createSession(socketId);
    }

    /** Removes the session-manager entry for a socket. */
    deleteSession(socketId: string) {
      this.sessionManager.deleteSession(socketId);
    }

    /**
     * Runs one chat turn through the graph and records it in the session.
     *
     * @param socketId Socket whose session supplies history and entity store.
     * @param message Raw user message for this turn.
     * @returns Text of the last AI message in the graph result, or a fixed
     *          apology when no AI text is available.
     * @throws Error when no session exists for `socketId`; any error raised by
     *         the graph or session manager is logged and rethrown.
     */
    async chat(socketId: string, message: string): Promise<string> {
      try {
        // Get session & filter valid history
        const session = this.sessionManager.getSession(socketId);
        if (!session) {
          // Fail with a descriptive error instead of a TypeError on `.entityStore`.
          throw new Error(`No session found for socket ${socketId}`);
        }
        const validHistory = this.sessionManager.getValidHistory(socketId);
        const userMsg = new HumanMessage(message);

        const inputs = {
          messages: [...validHistory, userMsg],
          entityStore: session.entityStore,
          socketId: socketId
        };
        const result = await this.graph.invoke(inputs);

        const allMessages = result.messages as BaseMessage[];
        const updatedEntityStore = result.entityStore as Record<string, any>;
        const classification = result.entryCategory as string;

        // Get the AI response (last message)
        // NOTE(review): if a run appends no AI message, this may pick up an AI
        // message carried over from history — confirm against the AgentState
        // `messages` reducer.
        const agentMessages = allMessages.filter((m: BaseMessage) => m._getType() === 'ai');
        const lastResponse = agentMessages[agentMessages.length - 1];

        // Update Session Storage
        this.sessionManager.updateSession(
          socketId,
          userMsg,
          lastResponse,
          classification,
          updatedEntityStore
        );

        // Content may be a string or an array of parts; always hand back text.
        // `||` (not `??`) is deliberate: an empty reply also gets the fallback.
        return FFBLangChainService.extractText(lastResponse?.content) || FFBLangChainService.FALLBACK_REPLY;
      } catch (error) {
        console.error('Error calling LangGraph:', error);
        throw error;
      }
    }
  }