from typing import TYPE_CHECKING, Any

from django.contrib.auth import get_user_model
from django.db import models
from django.db.models import Value

from rest_framework import serializers

from baserow.api.generative_ai.errors import (
    ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
    ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
)
from baserow.contrib.database.fields.field_filters import (
    contains_filter,
    contains_word_filter,
)
from baserow.contrib.database.fields.field_types import CollationSortMixin, TextField
from baserow.contrib.database.fields.models import Field
from baserow.contrib.database.fields.registries import FieldType
from baserow.contrib.database.formula import BaserowFormulaTextType, BaserowFormulaType
from baserow.core.db import collate_expression
from baserow.core.formula.serializers import FormulaSerializerField
from baserow.core.generative_ai.exceptions import (
    GenerativeAITypeDoesNotExist,
    ModelDoesNotBelongToType,
)
from baserow.core.generative_ai.registries import generative_ai_model_type_registry

from .models import AIField

# Resolve the project's active user model once, at import time, through Django's
# `get_user_model` helper instead of importing a concrete user class directly.
User = get_user_model()

# Only needed for the quoted type annotation on `get_value_for_filter`; the
# `TYPE_CHECKING` guard means this import is not executed at runtime.
if TYPE_CHECKING:
    from baserow.contrib.database.table.models import GeneratedTableModel


class AIFieldType(CollationSortMixin, FieldType):
    """
    The AI field can automatically query a generative AI model based on the provided
    prompt. It's possible to reference other fields to generate a unique output.
    """

    type = "ai"
    model_class = AIField
    can_be_in_form_view = False
    keep_data_on_duplication = True
    allowed_fields = ["ai_generative_ai_type", "ai_generative_ai_model", "ai_prompt"]
    serializer_field_names = [
        "ai_generative_ai_type",
        "ai_generative_ai_model",
        "ai_prompt",
    ]
    # The prompt is optional on the API and falls back to an empty string.
    serializer_field_overrides = {
        "ai_prompt": FormulaSerializerField(
            help_text="The prompt that must run for each row. Must be a formula.",
            required=False,
            allow_blank=True,
            default="",
        ),
    }
    # Pairs the exceptions imported from the generative AI module with their API
    # error definitions. `ModelDoesNotBelongToType` is raised by
    # `_validate_field_kwargs` below.
    api_exceptions_map = {
        GenerativeAITypeDoesNotExist: ERROR_GENERATIVE_AI_DOES_NOT_EXIST,
        ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE,
    }
    can_get_unique_values = False

    def get_serializer_field(self, instance, **kwargs):
        """
        Returns the serializer field used for the cell value of this field.

        :param instance: The field instance. Not used by this implementation.
        :param kwargs: Extra arguments forwarded to the `CharField`. They are
            spread last, so an explicitly provided `allow_null` or `allow_blank`
            overrides the defaults derived from `required`.
        :return: A `CharField` that, by default, is not required and accepts
            null and blank values exactly when it is not required.
        """

        required = kwargs.get("required", False)
        return serializers.CharField(
            **{
                "required": required,
                "allow_null": not required,
                "allow_blank": not required,
                **kwargs,
            }
        )

    def get_model_field(self, instance, **kwargs):
        """
        Returns the nullable Django `TextField` that stores the cell value.

        :param instance: The field instance. Not used by this implementation.
        :param kwargs: Extra arguments forwarded to `models.TextField`.
        """

        return models.TextField(null=True, **kwargs)

    def get_serializer_help_text(self, instance):
        """
        Returns the description of the cell value used as serializer help text.

        :param instance: The field instance. Not used by this implementation.
        """

        return (
            "Holds a text value that is generated by a generative AI model using a "
            "dynamic prompt."
        )

    def random_value(self, instance, fake, cache):
        """
        Returns a random cell value, here the result of `fake.name()`.

        :param instance: The field instance. Not used by this implementation.
        :param fake: Object providing a `name()` method (presumably a Faker
            instance -- confirm against the callers).
        :param cache: Not used by this implementation.
        """

        return fake.name()

    def to_baserow_formula_type(self, field) -> BaserowFormulaType:
        """
        Returns the formula type of this field, which is a nullable text type.

        :param field: The field instance. Not used by this implementation.
        """

        return BaserowFormulaTextType(nullable=True)

    def from_baserow_formula_type(
        self, formula_type: BaserowFormulaTextType
    ) -> TextField:
        """
        Returns a new, unsaved `TextField` for the given formula type.

        :param formula_type: The formula type. Not used by this implementation.
        """

        return TextField()

    def get_value_for_filter(self, row: "GeneratedTableModel", field: Field) -> Any:
        """
        Returns the row's value for the field in a form usable for filtering.

        :param row: The row to read the value from.
        :param field: The field whose `db_column` attribute is read from the row.
        :return: The raw value wrapped in a `Value` expression and passed
            through `collate_expression`.
        """

        value = getattr(row, field.db_column)
        return collate_expression(Value(value))

    def contains_query(self, *args):
        """
        Delegates to `contains_filter`, forwarding all positional arguments.
        """

        return contains_filter(*args)

    def contains_word_query(self, *args):
        """
        Delegates to `contains_word_filter`, forwarding all positional arguments.
        """

        return contains_word_filter(*args)

    def _validate_field_kwargs(self, ai_type, model_type, workspace=None):
        """
        Checks that the requested model is one of the models enabled for the
        requested generative AI type.

        :param ai_type: The name used to look up the generative AI type in the
            `generative_ai_model_type_registry`.
        :param model_type: The model name that must be in the enabled models.
        :param workspace: Forwarded to `get_enabled_models` of the looked up type.
        :raises ModelDoesNotBelongToType: If `model_type` is not in the enabled
            models of the generative AI type.

        NOTE(review): the registry lookup presumably raises
        `GenerativeAITypeDoesNotExist` for an unknown `ai_type` (it is listed in
        `api_exceptions_map`) -- confirm against the registry implementation.
        """

        # Distinct local names so the `models` module imported from `django.db`
        # is not shadowed and `ai_type` keeps referring to the provided name.
        generative_ai_type = generative_ai_model_type_registry.get(ai_type)
        enabled_models = generative_ai_type.get_enabled_models(workspace=workspace)
        if model_type not in enabled_models:
            raise ModelDoesNotBelongToType(model_name=model_type)

    def before_create(
        self, table, primary, allowed_field_values, order, user, field_kwargs
    ):
        """
        Validates the generative AI type and model before the field is created.

        Only `table` and `field_kwargs` are used by this implementation. The
        workspace is taken from `table.database.workspace`, and missing keys in
        `field_kwargs` are validated as `None`.
        """

        ai_type = field_kwargs.get("ai_generative_ai_type", None)
        model_type = field_kwargs.get("ai_generative_ai_model", None)
        workspace = table.database.workspace
        self._validate_field_kwargs(ai_type, model_type, workspace=workspace)

    def before_update(self, from_field, to_field_values, user, field_kwargs):
        """
        Validates the generative AI type and model before the field is updated.

        Values missing (or falsy) in `field_kwargs` fall back to the ones stored
        on `from_field`, but only if `from_field` is already an `AIField`.
        Otherwise the fallback is `None`. The workspace is taken from
        `from_field.table.database.workspace`.
        """

        # Only an existing AI field can provide fallback values. When converting
        # from another field type `getattr(None, ...)` yields the `None` default.
        update_field = None
        if isinstance(from_field, AIField):
            update_field = from_field

        ai_type = field_kwargs.get("ai_generative_ai_type", None) or getattr(
            update_field, "ai_generative_ai_type", None
        )
        model_type = field_kwargs.get("ai_generative_ai_model", None) or getattr(
            update_field, "ai_generative_ai_model", None
        )
        workspace = from_field.table.database.workspace
        self._validate_field_kwargs(ai_type, model_type, workspace=workspace)
