// Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import './styles.scss';
import React, { useEffect, useState } from 'react';
import { connect } from 'react-redux';
import Modal from 'antd/lib/modal';
import notification from 'antd/lib/notification';

import { ThunkDispatch } from 'utils/redux';
import { modelsActions, startInferenceAsync } from 'actions/models-actions';
import CVATLoadingSpinner from 'components/common/loading-spinner';
import { CombinedState } from 'reducers';
import MLModel from 'cvat-core/src/ml-model';
import { getCore, Label, ShapeType, Task } from 'cvat-core-wrapper';
import DetectorRunner from './detector-runner';
import { RectDrawingMethod } from 'cvat-canvas/src/typescript/canvasModel';
import { ObjectType } from 'cvat-core/src/enums';
import { rememberObject } from 'actions/annotation-actions';

// cvat-core API handle returned by getCore(); the dialog below uses core.tasks.get() to load task details.
const core = getCore();

/** Props that mapStateToProps derives from the Redux store for the model runner dialog. */
interface StateToProps {
    /** Mirrors `state.models.modelRunnerIsVisible`; forwarded to the modal's `visible` prop. */
    visible: boolean;
    /** Mirrors `state.models.modelRunnerTask`; only its `id` is read here, to fetch the task from cvat-core. */
    task: any;
    /** Models taken from `state.models.detectors`. */
    detectors: MLModel[];
    /** Models taken from `state.models.reid`. */
    reid: MLModel[];
    /** Models taken from `state.models.classifiers`. */
    classifiers: MLModel[];
    /** Mirrors `state.annotation.canvas.instance`; the dialog component in this file does not use it. */
    canvasInstance: any;
    /** Mirrors `state.annotation.job.labels`; the dialog component in this file does not read it. */
    labels: Label[];
}

/** Callbacks that mapDispatchToProps binds to the Redux store's dispatch. */
interface DispatchToProps {
    /**
     * Dispatches `startInferenceAsync` with the given arguments.
     * NOTE(review): the first parameter is declared as `task: any`, but mapDispatchToProps
     * implements it as `taskID: number` and the dialog passes a task id — consider aligning the types.
     */
    runInference(task: any, model: MLModel, body: object): void;
    /** Dispatches `modelsActions.closeRunModelDialog()`. */
    closeDialog(): void;
    /**
     * Dispatches `rememberObject` with the given shape type, label id, object type and the
     * optional number of points / rectangle drawing method.
     * The dialog component in this file does not call it.
     */
    onDrawStart(
        shapeType: ShapeType,
        labelID: number,
        objectType: ObjectType,
        points?: number,
        rectDrawingMethod?: RectDrawingMethod,
    ): void;
}

/**
 * Selects the slice of the Redux store that the model runner dialog needs.
 *
 * @param state - combined application state
 * @returns dialog visibility and task from `state.models`, the three model lists,
 * plus the canvas instance and job labels from `state.annotation`
 */
function mapStateToProps(state: CombinedState): StateToProps {
    const { models: modelsState, annotation } = state;

    return {
        visible: modelsState.modelRunnerIsVisible,
        task: modelsState.modelRunnerTask,
        reid: modelsState.reid,
        detectors: modelsState.detectors,
        classifiers: modelsState.classifiers,
        canvasInstance: annotation.canvas.instance,
        labels: annotation.job.labels,
    };
}

/**
 * Builds the dialog's callbacks, each of which forwards exactly one action to the store.
 *
 * @param dispatch - the store's thunk-aware dispatch function
 * @returns `runInference`, `closeDialog` and `onDrawStart` bound to `dispatch`
 */
function mapDispatchToProps(dispatch: ThunkDispatch): DispatchToProps {
    // Start inference of `model` on the task with the given id.
    const runInference = (taskID: number, model: MLModel, body: object): void => {
        dispatch(startInferenceAsync(taskID, model, body));
    };

    // Hide the model runner dialog.
    const closeDialog = (): void => {
        dispatch(modelsActions.closeRunModelDialog());
    };

    // Store the drawing settings via rememberObject; the cuboid drawing method is always passed as undefined.
    const onDrawStart = (
        shapeType: ShapeType,
        labelID: number,
        objectType: ObjectType,
        points?: number,
        rectDrawingMethod?: RectDrawingMethod,
    ): void => {
        dispatch(
            rememberObject({
                activeObjectType: objectType,
                activeShapeType: shapeType,
                activeLabelID: labelID,
                activeNumOfPoints: points,
                activeRectDrawingMethod: rectDrawingMethod,
                activeCuboidDrawingMethod: undefined,
            }),
        );
    };

    return { runInference, closeDialog, onDrawStart };
}

/**
 * Modal dialog ("Automatic annotation") that lets the user pick a model and run it on a task.
 *
 * The task stored in Redux is re-fetched through cvat-core whenever `visible` or `task`
 * changes. Until a fetched instance matching the requested task is available, a loading
 * spinner is rendered instead of the runner form.
 *
 * @param props - store-derived state plus bound action callbacks
 * @returns the modal element
 */
function ModelRunnerDialog(props: StateToProps & DispatchToProps): JSX.Element {
    const {
        reid, detectors, classifiers, task, visible, runInference, closeDialog,
    } = props;

    const models = [...reid, ...detectors, ...classifiers];
    const [taskInstance, setTaskInstance] = useState<Task | null>(null);

    useEffect(() => {
        if (!task) {
            return undefined;
        }

        // Flipped by the cleanup below so that a response which arrives after the dialog was
        // unmounted, or after a newer request was started, cannot overwrite the state.
        let cancelled = false;

        core.tasks
            .get({ id: task.id })
            .then(([fetchedTask]: Task[]) => {
                if (!cancelled && fetchedTask) {
                    setTaskInstance(fetchedTask);
                }
            })
            .catch((error: unknown) => {
                // String() also copes with null/undefined rejection values, unlike error.toString().
                notification.error({ message: 'Could not get task details', description: String(error) });
            });

        return () => {
            cancelled = true;
        };
    }, [visible, task]);

    // Never show an instance fetched for a different task: after reopening the dialog for
    // another task the previous instance is still in state until the new request resolves,
    // and using it would display wrong labels and run inference on the wrong task id.
    // While `task` is empty the last instance is kept, so the content does not change.
    const shownTask = taskInstance && (!task || taskInstance.id === task.id) ? taskInstance : null;

    return (
        <Modal
            destroyOnClose
            visible={visible}
            footer={[]}
            onCancel={(): void => closeDialog()}
            maskClosable
            title='Automatic annotation'
        >
            {shownTask ? (
                <DetectorRunner
                    withCleanup
                    models={models}
                    labels={shownTask.labels}
                    dimension={shownTask.dimension}
                    runInference={(...args) => {
                        closeDialog();
                        runInference(shownTask.id, ...args);
                    }}
                />
            ) : (
                <CVATLoadingSpinner />
            )}
        </Modal>
    );
}

// Enhancer that injects the store-derived state and the bound action callbacks as props.
const withModelRunnerStore = connect(mapStateToProps, mapDispatchToProps);

export default withModelRunnerStore(ModelRunnerDialog);
