import { WebSocketContext } from 'ui/common/WebSocketProvider';
import {
  ChatCompleteErrorPayload,
  ChatCompleteErrorTypes,
  ChatCompleteRequestPayload,
  ChatCompleteResponsePayload,
  ChatMessage,
  ChatMessageRole,
  FinishReason,
  GptFunctionName,
} from './third_party_types/chat-types';
import { WebSocketMessage } from './third_party_types/websocket/websocket-message';
import { WebSocketMessageType } from './third_party_types/websocket/websocket-message-type';

/** A chat model entry offered in `OpenAiModels`. */
type OpenAiModel = {
  /** Identifier sent as `model` in the completion request payload. */
  id: string;
  /** Display label for the model. */
  name: string;
  /** Token limit recorded for the model (not read within this file). */
  maxTokens: number;
};

/**
 * Outcome a `GptFunction` reports through its `onComplete` callback.
 * `callCompletion` uses `result` (or, if that is empty, `error`) as the
 * content of the function turn, and flags the turn as an error whenever
 * `error` is defined.
 */
export interface CallbackResult {
  error?: string;
  result?: string;
}

/**
 * A function the model may call. `callback` receives the JSON-parsed
 * arguments; each invocation of `onComplete` makes `callCompletion` record
 * the outcome and request a follow-up completion.
 */
export interface GptFunction {
  callback: (
    parsedArgs: object,
    onComplete: (outcome: CallbackResult) => void,
  ) => void;
}

/**
 * Chat models available for completion, in display order.
 *
 * The first entry is the default `modelId` of `callCompletion`. The 'none'
 * entry makes `callCompletion` return without sending a request.
 *
 * Models deliberately left out:
 * - No function-calling support:
 *   { id: 'gpt-4', name: 'GPT-4', maxTokens: 8192 }
 *   { id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo', maxTokens: 4096 }
 * - Returned 404 Not Found:
 *   { id: 'gpt-4-32k', name: 'GPT-4 32k', maxTokens: 32768 }
 *   { id: 'gpt-4-32k-0314', name: 'GPT-4 32k (03/14)', maxTokens: 32768 }
 * - Not chat models:
 *   { id: 'code-davinci-002', name: 'Davinci 002 (Code)', maxTokens: 8001 }
 *   { id: 'text-davinci-002', name: 'Davinci 002', maxTokens: 4097 }
 *   { id: 'text-davinci-003', name: 'Davinci 003', maxTokens: 4097 }
 */
export const OpenAiModels: OpenAiModel[] = [
  // Default model — keep first.
  {
    id: 'gpt-4-0613',
    name: 'GPT-4 (06/13)',
    maxTokens: 8192,
  },
  {
    id: 'gpt-3.5-turbo-16k-0613',
    name: 'GPT-3.5 Turbo 16K (06/13)',
    maxTokens: 16384,
  },
  {
    id: 'gpt-3.5-turbo-0613',
    name: 'GPT-3.5 Turbo (06/13)',
    maxTokens: 4096,
  },
  {
    id: 'none',
    name: '(None)',
    maxTokens: 0,
  },
];

/**
 * Collects the streamed chunks of one chat-completion request received over
 * the websocket and reports progress through the supplied callbacks.
 *
 * On construction it subscribes to CHAT_COMPLETE_RESPONSE and
 * CHAT_COMPLETE_ERROR under the subscriber key 'response_parser'. It
 * unsubscribes and sets `isDone` once a chunk carries a finish reason, an
 * error message arrives, or the request is aborted (either via the abort
 * signal's event or when a chunk arrives after the signal was aborted).
 */
class ResponseParser {
  /** Text content accumulated across response chunks. */
  fullContent = '';

  /** Name of the function the model asked to call ('' if none yet). */
  functionCallName = '';

  /** Function-call argument text accumulated across response chunks. */
  functionCallArguments = '';

  /** True once the parser has unsubscribed and stopped processing. */
  isDone = false;

  websocket: WebSocketContext;

  /** Receives the turn role and the accumulated content after each text chunk. */
  updateOutput: (r: ChatMessageRole, o: string) => void;

  /** Called when a chunk names the function being called. */
  onFunctionCallParseStart: (fname: string) => void;

  /** Called with each new fragment of function-call arguments. */
  onFunctionCallParseUpdate: (fargs: string) => void;

  /** Called with the function name and full argument text after the stream ends. */
  onFunctionCallParseDone: (f: string, args: string) => void;

  /** Called once with the finish reason, a user-facing error message, or 'abort'. */
  onFinish: (reason: string) => void;

  /** Receives the stream UUID carried by each response chunk. */
  setStreamUuid?: (uuid: string) => void;

  abortSignal?: AbortSignal;

  constructor(
    websocket: WebSocketContext,
    updateOutput: (r: ChatMessageRole, o: string) => void,
    onFunctionCallParseStart: (fname: string) => void,
    onFunctionCallParseUpdate: (fargs: string) => void,
    onFunctionCallParseDone: (f: string, args: string) => void,
    onFinish: (reason: string) => void,
    abortSignal?: AbortSignal,
    setStreamUuid?: (uuid: string) => void,
  ) {
    this.websocket = websocket;
    this.updateOutput = updateOutput;
    this.onFunctionCallParseStart = onFunctionCallParseStart;
    this.onFunctionCallParseUpdate = onFunctionCallParseUpdate;
    this.onFunctionCallParseDone = onFunctionCallParseDone;
    this.onFinish = onFinish;
    this.setStreamUuid = setStreamUuid;
    this.abortSignal = abortSignal;

    websocket.subscribe(
      WebSocketMessageType.CHAT_COMPLETE_RESPONSE,
      'response_parser',
      this.onChatResponseReceived,
    );

    websocket.subscribe(
      WebSocketMessageType.CHAT_COMPLETE_ERROR,
      'response_parser',
      this.onError,
    );

    // React to an abort immediately instead of waiting for the next chunk,
    // which may never arrive once the stream has been stopped. The listener
    // is removed again in cleanup().
    this.abortSignal?.addEventListener('abort', this.abort, { once: true });
  }

  /** Handles one CHAT_COMPLETE_RESPONSE chunk. */
  onChatResponseReceived = (data: WebSocketMessage) => {
    if (this.isDone) return;

    if (this.abortSignal && this.abortSignal.aborted) {
      // Drop the chunk: applying it after aborting could emit more output,
      // dispatch a function call, or replace the 'abort' finish reason.
      this.abort();
      return;
    }

    const { content, finishReason, functionCall, streamUuid } =
      data.payload as ChatCompleteResponsePayload;

    if (this.setStreamUuid) {
      this.setStreamUuid(streamUuid);
    }

    if (finishReason) {
      this.onFinish(finishReason);
      // Unsubscribe before dispatching the function call: its completion
      // handler may start a new completion (callCompletion does), whose
      // parser subscribes under the same 'response_parser' key.
      this.cleanup();
      if (finishReason === FinishReason.FunctionCall) {
        this.handleFunctionCall(this.onFunctionCallParseDone).catch((e) => {
          console.error('Error while handling function call:', e);
        });
      }
    } else if (functionCall) {
      const { name, arguments: args } = functionCall;
      if (name) {
        this.functionCallName = name;
        this.onFunctionCallParseStart(name);
      }
      if (args) {
        this.functionCallArguments = `${this.functionCallArguments}${args}`;
        this.onFunctionCallParseUpdate(args);
      }
    } else if (content) {
      this.fullContent = `${this.fullContent}${content}`;
      this.updateOutput(
        this.functionCallName
          ? ChatMessageRole.Function
          : ChatMessageRole.Assistant,
        this.fullContent,
      );
    }
  };

  /**
   * Maps a CHAT_COMPLETE_ERROR payload (a JSON string) to a user-facing
   * message passed to `onFinish`, then stops listening. Unparseable or
   * unrecognised errors are reported as 'Internal Error'.
   */
  onError = (data: WebSocketMessage) => {
    try {
      const payload = JSON.parse(
        data.payload as string,
      ) as ChatCompleteErrorPayload;
      // Tolerate payloads without an `error` string instead of relying on a
      // TypeError from `.includes` to reach the catch below.
      const message = payload.error ?? '';
      if (payload.type === ChatCompleteErrorTypes.OpenAiKeyNotFound) {
        this.onFinish('Please configure your OpenAI API key.');
      } else if (
        payload.type === ChatCompleteErrorTypes.ChatCallCountExceeded
      ) {
        this.onFinish('Call count exceeded. Please register an API key.');
      } else if (message.includes('status code: 404')) {
        // the user doesn't have access to the gpt model
        this.onFinish(message);
      } else if (message.includes('status code: 401')) {
        this.onFinish('Invalid OpenAI API key.');
      } else if (
        message.includes('status code: 400') &&
        message.includes(
          'Please reduce the length of the messages or functions.',
        )
      ) {
        this.onFinish('Maximum token count reached');
      } else {
        this.onFinish('Internal Error');
      }
    } catch (e) {
      this.onFinish('Internal Error');
    }
    this.cleanup();
  };

  /** Unsubscribes from the websocket and the abort signal and marks the parser done. */
  cleanup = () => {
    this.websocket.unsubscribe(
      WebSocketMessageType.CHAT_COMPLETE_RESPONSE,
      'response_parser',
    );
    this.websocket.unsubscribe(
      WebSocketMessageType.CHAT_COMPLETE_ERROR,
      'response_parser',
    );
    this.abortSignal?.removeEventListener('abort', this.abort);
    this.isDone = true;
  };

  /**
   * Reports 'abort' and stops listening. Reachable from both the abort
   * event and an incoming chunk, so it is a no-op once the parser is done.
   */
  abort = () => {
    if (this.isDone) return;
    this.onFinish('abort');
    this.cleanup();
  };

  /**
   * Passes the accumulated function call to `onFunctionCallParseDone`.
   * The callback is invoked synchronously (before the first await), so any
   * state it sets is visible as soon as this method is called.
   */
  handleFunctionCall = async (
    onFunctionCallParseDone: (f: string, args: string) => void,
  ) => {
    // HACK: sometimes chatgpt returns functionName = python and functionArgs
    // will contain the raw python code instead of a JSON object. Rewrite it
    // into an execute_python call with a { "code": ... } JSON payload.
    if (this.functionCallName === 'python') {
      console.warn('Force execute_python');
      this.functionCallName = 'execute_python';
      // JSON.stringify escapes backslashes, quotes and control characters;
      // escaping only quotes and newlines by hand corrupts code containing
      // escape sequences (e.g. "\n" in a string literal) or tabs.
      const code = this.functionCallArguments.replace(/\r\n/g, '\n');
      this.functionCallArguments = JSON.stringify({ code });
    }
    await onFunctionCallParseDone(
      this.functionCallName,
      this.functionCallArguments,
    );
  };
}

/**
 * Streams a chat completion for `messages` over the websocket and keeps the
 * caller's view of the conversation current via `setOutput`.
 *
 * When the model requests a known function, that function's callback is run
 * with the JSON-parsed arguments. When it reports back through
 * `onComplete`, the function turn is filled in and a follow-up completion is
 * requested recursively with the extended conversation. In that case this
 * invocation skips `doneCallback`; the recursive invocation calls it.
 *
 * Every update hands `setOutput` a new array, and changed turns are new
 * objects, so arrays/messages previously passed to `setOutput` (and the
 * caller's `messages`) are never mutated.
 *
 * @param messages Conversation sent to the model.
 * @param setOutput Receives the updated conversation after every change.
 * @param websocket Connection used to publish the request and receive chunks.
 * @param abortSignal Aborting ends the request with finish reason 'abort'.
 * @param doneCallback Called once with `failed` and the finish reason
 *   (see above for the function-call hand-off).
 * @param temperature Sampling temperature sent with the request.
 * @param modelId An `OpenAiModels` id; unknown ids are reported as failures.
 * @param functions Callable functions keyed by the name the model uses.
 * @param setStreamUuid Receives the stream UUID of each response chunk.
 */
export const callCompletion = async (
  messages: ChatMessage[],
  setOutput: (o: ChatMessage[]) => void,
  websocket: WebSocketContext,
  abortSignal?: AbortSignal,
  doneCallback?: (failed: boolean, finishReason: string) => void,
  temperature = 0.1,
  modelId = OpenAiModels[0].id,
  functions?: { [k: string]: GptFunction },
  setStreamUuid?: (uuid: string) => void,
) => {
  const pollIntervalMs = 1000;
  let failed = true;
  let finishReason = '';
  let skipDoneCallback = false; // in case of function call

  try {
    const aiModel = OpenAiModels.find((m) => m.id === modelId);
    if (!aiModel) throw new Error(`Unknown model ID: ${modelId}`);

    if (modelId === 'none') {
      // NOTE(review): `failed` is never cleared on this path, so doneCallback
      // receives (true, 'done') — confirm that is intended.
      finishReason = 'done';
      return;
    }

    let conversation: ChatMessage[] = [...messages];

    // Swaps the last turn for `turn` without mutating the current array.
    const replaceLastTurn = (turn: ChatMessage) => {
      conversation = [...conversation.slice(0, -1), turn];
      setOutput(conversation);
    };

    const updateOutput = (role: ChatMessageRole, content: string) => {
      // The parser passes the accumulated text, so a trailing turn with the
      // same role is replaced rather than appended to.
      const lastTurn = conversation[conversation.length - 1];
      const base =
        lastTurn?.role === role ? conversation.slice(0, -1) : conversation;
      conversation = [...base, { role, content }];
      setOutput(conversation);
    };

    const onFunctionCallParseStart = (fname: string) => {
      const functionCall: ChatMessage = {
        role: ChatMessageRole.Function,
        content: '',
        functionName: fname as GptFunctionName,
        functionArgs: '',
        functionResult: '...',
      };
      conversation = [...conversation, functionCall];
      setOutput(conversation);
    };

    const functionDoneCallback = async (callbackResult: CallbackResult) => {
      const lastTurn = conversation[conversation.length - 1];
      const output = callbackResult.result || callbackResult.error;
      replaceLastTurn({
        ...lastTurn,
        functionResult: output,
        content: output || '',
        functionHasError: callbackResult.error !== undefined,
      });

      await callCompletion(
        conversation,
        setOutput,
        websocket,
        abortSignal,
        doneCallback,
        temperature,
        modelId,
        functions,
        setStreamUuid,
      );
    };

    const onFunctionCallParseUpdate = (fargs: string) => {
      const lastTurn = conversation[conversation.length - 1];
      if (!lastTurn) return;
      replaceLastTurn({
        ...lastTurn,
        // `?? ''` avoids an "undefined" prefix if no start turn was added.
        functionArgs: `${lastTurn.functionArgs ?? ''}${fargs}`,
      });
    };

    const onFunctionCallParseDone = async (f: string, args: string) => {
      if (!functions) return;
      // The name comes from the model; only accept own properties so names
      // such as "toString" don't resolve to inherited members.
      const func = Object.prototype.hasOwnProperty.call(functions, f)
        ? functions[f]
        : undefined;
      if (!func) {
        console.error(`Unknown function: ${f}`);
        return;
      }
      // Must be set synchronously: the wait loop below may exit as soon as
      // the parser is done, and the finally block reads this flag.
      skipDoneCallback = true;

      let argsJson: object;
      try {
        argsJson = JSON.parse(args);
      } catch (e) {
        void functionDoneCallback({
          error: `Error while parsing ${f} args:\n${e}\nMake sure the args are valid JSON.`,
        });
        return;
      }

      // Report callback exceptions separately so they are not mislabelled
      // as argument-parsing errors.
      try {
        func.callback(argsJson, functionDoneCallback);
      } catch (e) {
        void functionDoneCallback({
          error: `Error while running ${f}:\n${e}`,
        });
      }
    };

    const responseParser = new ResponseParser(
      websocket,
      updateOutput,
      onFunctionCallParseStart,
      onFunctionCallParseUpdate,
      onFunctionCallParseDone,
      (r: string) => (finishReason = r),
      abortSignal,
      setStreamUuid,
    );

    const payload: ChatCompleteRequestPayload = {
      messages,
      temperature,
      model: aiModel.id,
    };
    websocket.publish({
      id: '', // FIXME: set the conversation ID + turn index.
      type: WebSocketMessageType.CHAT_COMPLETE_REQUEST,
      payload,
    });

    while (!responseParser.isDone) {
      // Don't rely solely on the parser observing the abort (it may only do
      // so when another chunk arrives, and none may come); end the wait here.
      if (abortSignal?.aborted) {
        responseParser.abort();
        break;
      }
      // eslint-disable-next-line no-await-in-loop
      await new Promise((resolve) => {
        setTimeout(resolve, pollIntervalMs);
      });
    }

    failed =
      finishReason !== FinishReason.Stop &&
      finishReason !== FinishReason.FunctionCall;
  } catch (e) {
    failed = true;
    console.error('Error while talking to ChatGPT:', e);
  } finally {
    try {
      if (doneCallback && !skipDoneCallback) {
        doneCallback(failed, finishReason);
      }
    } catch (e) {
      console.error(e);
    }
  }
};

/**
 * Returns the body of the first fenced code block opened with "```" followed
 * by `languageTag` in the last message of `messages`.
 *
 * The result is the text between the opening tag and the next "```" (or the
 * end of the message if the fence is never closed). Anything after the tag
 * on the opening line, including its newline, is part of the result.
 * Returns '' when `messages` is empty or no such fence exists.
 *
 * @param messages Conversation; only the last message is inspected.
 * @param languageTag Tag expected right after the opening fence (default
 *   'C'). Matching is a case-sensitive prefix match, so 'C' also matches
 *   "```C++" and "```CSS".
 */
export const extractCodeFromChatGpt = (
  messages: ChatMessage[],
  languageTag = 'C',
): string => {
  // `?.` guards against an empty conversation, which used to throw.
  const content = messages[messages.length - 1]?.content ?? '';
  const openingFence = `\`\`\`${languageTag}`;
  if (!content.includes(openingFence)) return '';

  // split()[1] is the text between the first opening fence and the next one
  // (or the end); the code block ends at the first "```" inside it.
  const afterOpening = content.split(openingFence)[1];
  return afterOpening.split('```')[0];
};
