import { useMutation, useQueryClient } from "@tanstack/react-query";
import { ThreadService } from "@/services/thread";
import { getThreadQueryKey } from "../useThread";
import { THREADS_QUERY_KEY } from "../usePaginatedThreads";
import { Thread } from "@/services/thread/types";
import { getCasesQueryKey } from "@/hooks/lexZap/useLexZapCases";
import { Case } from "@/hooks/lexZap/types";

/**
 * Mutation hook that attaches a thread to a case.
 *
 * Flow:
 * - `onMutate`: cancels in-flight thread fetches, then optimistically sets `caseId` on the
 *   cached thread (if it is cached) and snapshots the previous value for rollback.
 * - `onSuccess`: appends the thread id to the target case in the cached cases list, without
 *   duplicating it and without touching the cache when no cases are cached.
 * - `onError`: restores the thread snapshot under the key derived from the mutation variables.
 * - `onSettled`: invalidates the threads list, the thread and the cases list so the cache is
 *   reconciled with the server on both success and failure.
 *
 * Variables: `{ threadId, caseId }`. Resolves to the same `{ threadId, caseId }` pair.
 */
export const useAttachThreadToCase = () => {
  const queryClient = useQueryClient();

  return useMutation({
    mutationKey: ["attachThreadToCase"],
    mutationFn: async ({ threadId, caseId }: { threadId: string; caseId: string }) => {
      await ThreadService.attachCase({ threadId, caseId });
      return { threadId, caseId };
    },
    onMutate: async ({ threadId, caseId }: { threadId: string; caseId: string }) => {
      const threadQueryKey = getThreadQueryKey({ threadId });

      // Cancel (rather than invalidate) in-flight fetches: a refetch started here would race the
      // mutation and could overwrite the optimistic update below with pre-mutation data.
      await Promise.all([
        queryClient.cancelQueries({ queryKey: THREADS_QUERY_KEY }),
        queryClient.cancelQueries({ queryKey: threadQueryKey }),
      ]);

      const previousThread = queryClient.getQueryData<Thread>(threadQueryKey);

      // Only patch a thread that is already cached; returning undefined leaves the cache untouched.
      queryClient.setQueryData<Thread>(threadQueryKey, (old) => (old ? { ...old, caseId } : undefined));

      return { previousThread };
    },
    onSuccess: ({ threadId, caseId }: { threadId: string; caseId: string }) => {
      // The cases list may not be cached yet; `old?.map` avoids a TypeError on `undefined`.
      queryClient.setQueryData<Case[]>(getCasesQueryKey(), (old) =>
        old?.map((c) => {
          if (c.id !== caseId) return c;
          const threadIds = c.threadIds ?? [];
          // Skip when already linked so the thread id is never duplicated.
          return threadIds.includes(threadId) ? c : { ...c, threadIds: [...threadIds, threadId] };
        })
      );
    },
    onError: (_error, { threadId }, context) => {
      // Derive the key from the mutation variables: previousThread is undefined when the thread
      // was not cached, so keying on previousThread.id would target the wrong query.
      queryClient.setQueryData(getThreadQueryKey({ threadId }), context?.previousThread);
    },
    // Reconcile with the server whether the mutation succeeded or failed. Returning the promise
    // keeps the mutation pending until the refetches finish.
    onSettled: (_data, _error, { threadId }) =>
      Promise.all([
        queryClient.invalidateQueries({ queryKey: THREADS_QUERY_KEY }),
        queryClient.invalidateQueries({ queryKey: getThreadQueryKey({ threadId }) }),
        queryClient.invalidateQueries({ queryKey: getCasesQueryKey() }),
      ]),
  });
};
