import { Connection, Node, NodeEditor, Socket } from 'rete';

import { getComponentByName } from '../nodes';

/**
 * Installs the "inherited output type" behavior on a Rete editor.
 *
 * For nodes whose component sets `inheritOutputFromInput`:
 * - when a connection into the node is created, every output socket of the
 *   node is switched to the socket of the connected upstream output;
 * - when a connection into the node is removed and none of its inputs remain
 *   connected, every output socket is reset to the first entry of the
 *   component's `outputParams`.
 *
 * @param editor Editor to attach the `connectioncreated` / `connectionremoved`
 *   listeners to.
 */
function install(editor: NodeEditor) {
  // Rete removes connections while first building the pipeline.
  // In that case, we don't need to reset the original output type, so the
  // first removal observed for each node id is ignored.
  const initializedNodes = new Set<string>();

  // Reset output type to default when node inputs are fully disconnected
  editor.on('connectionremoved', ({ input: { node } }) => {
    if (!node || !isDynamicSocketNode(node)) {
      return;
    }

    const nodeId = String(node.id);
    if (!initializedNodes.has(nodeId)) {
      initializedNodes.add(nodeId);
      return;
    }

    const Component = getComponentByName(node.name);

    if (!Component || !Component.outputParams) {
      return;
    }

    // The first declared output socket is treated as the default type.
    const [target] = Component.outputParams;

    // An empty `outputParams` leaves nothing to reset to.
    if (!target) {
      return;
    }

    const hasConnection = Array.from(node.inputs.values()).some((input) =>
      input.hasConnection()
    );
    const isDefaultSocket = Array.from(node.outputs.values()).every(
      ({ socket }) => socket.name === target.name
    );

    // Skip nodes that still have connections or already use the default socket
    if (hasConnection || isDefaultSocket) {
      return;
    }

    // Fire-and-forget: updateNodeOutputSockets reports its own errors.
    void updateNodeOutputSockets(node, target);
  });

  // Propagate the upstream socket type to the outputs of dynamic nodes.
  editor.on('connectioncreated', ({ input: { node }, output }) => {
    if (!node || !isDynamicSocketNode(node)) {
      return;
    }

    void updateNodeOutputSockets(node, output.socket);
  });

  /**
   * Switches every output socket of `node` to `target`.
   *
   * Existing output connections are removed, the sockets are replaced, and the
   * saved connections are then re-created through `editor.connect`. A
   * connection that fails to re-create is silently dropped, so only
   * connections still valid for the new socket type are restored.
   *
   * @param node Node receiving the connection.
   * @param target Socket type the node's outputs should adopt.
   */
  async function updateNodeOutputSockets(node: Node, target: Socket) {
    // Connections per output key, saved so they can be restored after the
    // socket type has changed.
    const previousConnections = new Map<string, Connection[]>();

    for (const output of node.outputs.values()) {
      const connections = [...output.connections];
      previousConnections.set(output.key, connections);

      // Iterate the snapshot, not the live `output.connections` array:
      // removing a connection detaches it from that array, and iterating the
      // array while it shrinks would skip entries.
      for (const connection of connections) {
        editor.removeConnection(connection);
      }

      output.socket = target;
      node.outputs.set(output.key, output);
    }

    // Restore saved connections against the new socket type.
    for (const output of node.outputs.values()) {
      for (const connection of previousConnections.get(output.key) ?? []) {
        try {
          editor.connect(output, connection.input);
        } catch {
          // Do nothing -- let Rete remove connections if it fails
        }
      }
    }

    try {
      // NOTE: await is required to be able to access the new connections
      // see https://github.com/retejs/rete/issues/417#issuecomment-592930912
      await node.update();

      // NOTE(review): the view is only refreshed when the node has more than
      // one connection -- confirm `> 1` (rather than `> 0`) is intended.
      if (node.getConnections().length > 1) {
        editor.view.updateConnections({ node });
      }
    } catch (error) {
      editor.trigger(
        'warn',
        error instanceof Error ? error : new Error(String(error))
      );
      console.error(error);
    }
  }
}

/**
 * Tells whether `node` is registered under a component that sets
 * `inheritOutputFromInput`, i.e. its output sockets follow its input.
 * A missing node, or a node with no matching component, yields `false`.
 */
function isDynamicSocketNode(node: Node | null) {
  const component = node ? getComponentByName(node.name) : undefined;
  return component?.inheritOutputFromInput ? true : false;
}

/** Plugin descriptor: the plugin's name together with its `install` hook. */
const plugin = {
  name: 'InheritedOutputType',
  install,
};

export { plugin as default };
