import { RUN_BY_ID_QUERY } from "@/queries/run";
import { GQLRunByIdQuery, GQLRunByIdQueryVariables, GQLRunByIdRunFragment } from "@/queries/run.generated";
import { cache } from "@/utils/graphql";
import { GQLRunStatus } from "@/utils/graphql.generated";
import { ApolloError, NetworkStatus, useQuery } from "@apollo/client";
import { RefObject, useEffect, useRef, useState } from "react";

/** Values and actions exposed by the `usePollRun` hook. */
interface UsePollRunReturn {
    /** The `run` field of the latest query result; absent until a result is available. */
    data?: GQLRunByIdRunFragment | null;
    /** True while the query reports loading or its network status is `NetworkStatus.loading`. */
    loading: boolean;
    /** Error reported by the underlying Apollo query, if any. */
    error?: ApolloError;
    /** Id of the run currently tracked by the hook, or `null` when none is set. */
    runId: string | null;
    /** Sets (or clears, with `null`) the id of the run to track. */
    updateRunId: (id: string | null) => void;
    /** Scrolls the widget container referenced by the hook to its bottom. */
    scrollToLastMessage: () => void;
}

/** Delay, in milliseconds, between two consecutive polls of a run that has not finished yet. */
const POLL_INTERVAL_MS = 500;

/**
 * Tracks a single run and re-queries it until it is no longer queued or in progress.
 *
 * The query is skipped while `shouldFetch` is false or no run id has been set via
 * `updateRunId`. Each time a result arrives with status `Queued` or `InProgress`, another
 * fetch is scheduled after `POLL_INTERVAL_MS`. When the status is `Completed`, the run id is
 * cleared, the cached `thread` field is evicted and the widget is scrolled to its bottom.
 *
 * NOTE(review): statuses other than Queued / InProgress / Completed stop the polling but leave
 * `runId` set — confirm this is the intended handling for failed or cancelled runs.
 *
 * @param assistantId - Assistant id sent as a query variable.
 * @param threadId - Thread id sent as a query variable.
 * @param shouldFetch - When false, the query is skipped.
 * @param widgetInnerRef - Ref to the scrollable container that is scrolled to the bottom.
 * @returns The current run data, loading / error state, the tracked run id and helper actions.
 */
export const usePollRun = (
    assistantId: string,
    threadId: string,
    shouldFetch: boolean,
    widgetInnerRef: RefObject<HTMLDivElement>,
): UsePollRunReturn => {
    const [runId, setRunId] = useState<string | null>(null);
    // Handle of the pending re-poll timer; kept in a ref so it can be cancelled without re-rendering.
    const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);

    // Cancel a pending re-poll when the tracked run changes or the component unmounts, so a
    // stale timer never fires a refetch for a run that is no longer the current one.
    useEffect(() => {
        return () => {
            if (pollTimerRef.current !== null) {
                clearTimeout(pollTimerRef.current);
                pollTimerRef.current = null;
            }
        };
    }, [runId]);

    const scrollToLastMessage = () => {
        const container = widgetInnerRef.current;
        if (container) {
            container.scrollTop = container.scrollHeight;
        }
    };

    // `!runId` is true for both `null` and the empty string.
    const skip = !shouldFetch || !runId;

    const { data, loading, error, refetch, networkStatus } = useQuery<GQLRunByIdQuery, GQLRunByIdQueryVariables>(RUN_BY_ID_QUERY, {
        variables: {
            id: runId ?? "",
            threadId: threadId,
            assistantId: assistantId,
        },
        skip: skip,
        notifyOnNetworkStatusChange: true,
        fetchPolicy: "network-only",
        onCompleted: (result) => {
            const status = result?.run?.status;

            if (status === GQLRunStatus.InProgress || status === GQLRunStatus.Queued) {
                // Never keep more than one re-poll pending at a time.
                if (pollTimerRef.current !== null) {
                    clearTimeout(pollTimerRef.current);
                }
                pollTimerRef.current = setTimeout(() => {
                    pollTimerRef.current = null;
                    // Fire-and-forget: the rejection is handled here only to avoid an unhandled
                    // promise rejection; the failure is expected to surface through the
                    // `error` value returned by `useQuery`.
                    void refetch().catch(() => undefined);
                }, POLL_INTERVAL_MS);
                return;
            }

            if (status === GQLRunStatus.Completed) {
                setRunId(null);
                // Drop the cached `thread` field so it is fetched again — presumably to pick up
                // the messages produced by the finished run.
                cache.evict({ fieldName: "thread" });
                cache.gc();
                scrollToLastMessage();
            }
        },
    });

    const updateRunId = (id: string | null) => {
        setRunId(id);
    };

    return { data: data?.run, loading: loading || networkStatus === NetworkStatus.loading, error, updateRunId, runId, scrollToLastMessage };
};
