"""
Base Agent class for multi-agent education system
Provides common interface and methods for all specialized agents
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Iterator, Tuple, Union
from dataclasses import dataclass
import logging
import json
import requests  # used to fetch user images by URL
import base64  # used to encode images to base64
import mimetypes  # used to determine correct MIME types
import time  # used for timing in log messages
from google import genai
from google.genai import types
from openai import OpenAI

# Re-use utility from deepseek_handler to fetch model info from DB
from deepseek_handler import get_model_info
from message_utils import get_last_messages
from llm_providers.gemini_provider import convert_to_gemini_parts
from prompts.tg_suffix import TG_PLAIN_TEXT_SUFFIX

logger = logging.getLogger(__name__)


@dataclass
class AgentState:
    """Shared state between agents"""
    user_id: Optional[int] = None
    subject_id: int = 1  # Default to physics
    grade: int = 8
    textbook_id: Optional[int] = None
    context: Optional[Dict[str, Any]] = None
    chat_id: Optional[int] = None
    plain_text_mode: bool = False
    
    def __post_init__(self):
        if self.context is None:
            self.context = {}


@dataclass  
class AgentResponse:
    """Response from an agent"""
    agent_name: str
    content: Optional[str] = None  # For non-streaming text
    stream: Optional[Iterator[str]] = None # For streaming content
    next_agent: Optional[str] = None  # Which agent should handle next
    tool_calls: Optional[List[Dict[str, Any]]] = None
    update_state: Optional[Dict[str, Any]] = None  # Updates to apply to AgentState
    is_final: bool = False  # If True, this ends the conversation

    def __post_init__(self):
        if self.tool_calls is None:
            self.tool_calls = []
        if self.update_state is None:
            self.update_state = {}
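
# Typical AgentResponse shapes used in this module (illustrative; "SomeAgent" is a
# placeholder name): the Router hands off with
#     AgentResponse(agent_name="Router", next_agent="SomeAgent", update_state={...}),
# streaming agents return AgentResponse(agent_name=..., stream=<generator>, is_final=True),
# and error paths return AgentResponse(agent_name=..., content="<text>", is_final=True).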


class RouteSignal(Exception):
    """Special exception to signal a routing action from within the stream handler."""
    def __init__(self, response: AgentResponse):
        self.response = response


class DirectResponseSignal(Exception):
    """Signal to immediately return final text without further LLM generation."""
    def __init__(self, text: str):
        self.text = text


class BaseAgent(ABC):
    """
    Base class for all agents in the multi-agent system
    
    Each agent has:
    - A specific role and responsibility 
    - Access to tools (database, other agents)
    - Ability to process messages and generate responses
    - Integration with Gemini and OpenAI LLMs
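
    Example subclass (illustrative sketch; real tool schemas depend on the provider
    and on each agent's responsibilities):

        class EchoAgent(BaseAgent):
            def __init__(self):
                super().__init__(name="Echo", system_prompt="Repeat the user.", mode=0)

            def _get_tools(self) -> List:
                return []  # no tools

            def _process_function_call(self, function_call, state: AgentState):
                return {}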
    """
    
    def __init__(self, name: str, system_prompt: str, mode: int = 0, history_limit: int = 100):
        """mode = 0 (chat-model), 1 (reasoner-model)"""
        self.name = name
        self.system_prompt = system_prompt
        self.mode = mode  # 0 chat, 1 reasoner
        self.history_limit = history_limit  # per-agent history limit

        # cache: {(agent_name, subject_id, mode): (provider, client, model_name)}
        self._model_cache: Dict[Tuple[str, int, int], Tuple[str, object, str]] = {}
        logger.info(f"Initialized {self.name} agent")
    
    @abstractmethod
    def _get_tools(self) -> List:
        """
        Get tools available to this agent
        Must be implemented by each agent subclass
        
        Returns:
            List of tool definitions (format depends on provider)
        """
        pass

    def _get_gemini_tools(self) -> List:
        """Convert tools to Gemini format - default implementation"""
        return []

    def _get_system_prompt(self, subject_id: int, state: Optional["AgentState"] = None) -> str:
        """Child classes can override to supply subject-specific prompts"""
        prompt = self.system_prompt
        if state and getattr(state, "plain_text_mode", False):
            if TG_PLAIN_TEXT_SUFFIX not in (prompt or ""):
                prompt = (prompt or "") + TG_PLAIN_TEXT_SUFFIX
        return prompt
    
    @abstractmethod  
    def _process_function_call(self, function_call, state: AgentState) -> Any:
        """
        Process a function call from the LLM
        Must be implemented by each agent subclass
        
        Args:
            function_call: Function call object (format depends on provider)
            state: Current agent state
            
        Returns:
            Function call result
        """
        pass

    def can_handle(self, message: str, state: AgentState) -> bool:
        """
        Check if this agent can handle the given message
        Default implementation returns True (can handle anything)
        Override in subclasses for specific routing logic
        
        Args:
            message: User message to check
            state: Current agent state
            
        Returns:
            True if agent can handle this message
        """
        return True
    
    def process_message(self, history: List[Dict[str, Any]], state: AgentState) -> AgentResponse:
        """
        Process a message and generate response
        """
        try:
            current_timestamp = time.time()
            logger.info(f"{self.name}: Processing message at {current_timestamp:.2f}")
            logger.info(f"{self.name}: Subject ID: {state.subject_id}, Mode: {self.mode}")

            # Get thread-safe model and client
            provider, client, model_name = self._get_model_info_safe(state)

            # Send message based on provider
            if provider == "gemini_client":
                return self._send_message_gemini(history, state, client, model_name)
            elif provider in ["openai_client", "openrouter_client", "deepseek_client"]:
                return self._send_message_openai(history, state, client, model_name)
            else:
                raise NotImplementedError(f"Provider {provider} not implemented")
            
        except Exception as e:
            logger.error(f"{self.name}: Error processing message: {e}", exc_info=True)
            return AgentResponse(
                agent_name=self.name,
                content=f"Извините, произошла ошибка при обработке вашего запроса: {repr(e)}",
                is_final=True
            )

    def _get_model_info_safe(self, state: AgentState) -> Tuple[str, Any, str]:
        """Thread-safe way to get model info without modifying self state."""
        subject_id = state.subject_id
        key = (self.name, subject_id, self.mode)
        
        if key in self._model_cache:
            return self._model_cache[key]

        model_info = get_model_info(subject_id, self.mode)
        if not model_info:
            raise RuntimeError(f"Model not found for subject {subject_id} mode {self.mode}")

        api_url, api_key, model_name, provider = model_info
        
        if provider == "gemini_client":
            # Removing http_options for now as it causes 400 INVALID_ARGUMENT in some environments
            client = genai.Client(api_key=api_key)
        elif provider in ["openai_client", "openrouter_client", "deepseek_client"]:
            client = OpenAI(api_key=api_key, base_url=api_url)
        else:
            raise NotImplementedError(f"Provider {provider} not supported")

        # Cache it for next time
        self._model_cache[key] = (provider, client, model_name)
        return provider, client, model_name

    def _send_message_gemini(self, history: List[Dict[str, Any]], state: AgentState, client: Any, model_name: str) -> AgentResponse:
        """Send message using Gemini API with streaming and full tool call loop support."""
        try:
            # Geometry safeguard: disable ProblemRetriever for subject_id == 3
            if self.name == "ProblemRetriever" and int(getattr(state, 'subject_id', 0)) == 3:
                logger.info("ProblemRetriever is disabled for subject_id=3 (geometry). Returning informative message.")
                return AgentResponse(
                    agent_name=self.name,
                    content="Для геометрии выдача задач из базы отключена. Пришлите свою задачу или запросите объяснение темы.",
                    is_final=True
                )

            gemini_history = self._prepare_gemini_history(history, state)
            
            # Get system prompt and tools for this call
            system_prompt = self._get_system_prompt(state.subject_id, state)
            tools = self._get_gemini_tools()

            # Combine system prompt and tools into the config object
            config_params = {}
            if tools:
                config_params['tools'] = tools
            
            if system_prompt:
                config_params['system_instruction'] = system_prompt

            # Configure generation based on agent mode
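            # Mode 1 ("reasoner") keeps output deterministic (temperature=0); mode 0 ("chat")
            # requests a zero thinking budget, which skips Gemini's thinking phase to keep
            # responses fast (assumes a thinking-capable Gemini model).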
            if self.mode == 1:  # Reasoner mode
                config_params['temperature'] = 0
            else:  # Chat mode
                config_params['thinking_config'] = types.ThinkingConfig(thinking_budget=0)

            # For ProblemRetriever: force calling DB tools first and suppress any pre-tool text
            suppress_text_until_tool = False
            if self.name == "ProblemRetriever":
                suppress_text_until_tool = True
                try:
                    config_params['tool_config'] = types.ToolConfig(
                        function_calling_config=types.FunctionCallingConfig(
                            mode="ANY",
                            allowed_function_names=["llm_list_available_topics", "llm_get_problems_by_topic"]
                        )
                    )
                    config_params['temperature'] = 0
                except Exception as e:
                    logger.warning(f"{self.name}: Failed to apply tool_config for ProblemRetriever: {e}")

            final_config = types.GenerateContentConfig(**config_params)

            def _internal_stream_handler():
                """
                Handles the streaming and tool-calling logic iteratively.
                If a tool call is found, it processes it and continues the conversation loop.
                """
                nonlocal gemini_history

                # Suppress any model text until the first successful tool call is processed
                allow_text_yield = not suppress_text_until_tool

                while True:
                    # Get a streaming response from the model
                    response_stream = client.models.generate_content_stream(
                        model=model_name,
                        contents=gemini_history,
                        config=final_config
                    )

                    if not response_stream:
                        logger.warning(f"{self.name}: Model returned an empty stream. Halting.")
                        break

                    # Iterate through the stream, yield text, and find any function call
                    found_function_call_part = None
                    for chunk in response_stream:
                        try:
                            if not getattr(chunk, 'candidates', None):
                                continue

                            # Check all parts in the chunk for a function call
                            candidate0 = chunk.candidates[0]
                            content_obj = getattr(candidate0, 'content', None)
                            parts_iter = getattr(content_obj, 'parts', None) or []
                            for part in parts_iter:
                                if getattr(part, 'function_call', None):
                                    # We found a function call. We will process it after the stream is done.
                                    # Note: This assumes the entire function call is in a single part.
                                    found_function_call_part = part

                            # Yield text only if allowed (prevents pre-tool chatter for specific agents)
                            if hasattr(chunk, 'text') and chunk.text and allow_text_yield:
                                yield chunk.text
                        except Exception as stream_chunk_exc:
                            logger.debug(f"{self.name}: Skipped malformed stream chunk: {stream_chunk_exc}")
                            continue

                    # After consuming the stream, check if a tool needs to be called
                    if found_function_call_part:
                        function_call = found_function_call_part.function_call
                        logger.info(f"{self.name}: 🔧 Gemini stream detected function call: {function_call.name}")

                        # Add the model's turn (which contains the function call) to the history
                        gemini_history.append({'role': 'model', 'parts': [found_function_call_part]})
                        
                        # Execute the function
                        function_result = self._process_function_call(function_call, state)

                        # Special handling for the Router agent
                        if self.name == "Router":
                            if isinstance(function_result, dict) and "next_agent" in function_result:
                                raise RouteSignal(AgentResponse(
                                    agent_name=self.name,
                                    next_agent=function_result["next_agent"],
                                    update_state=function_result.get("update_state", {}),
                                    is_final=False
                                ))
                            else:
                                logger.warning(f"{self.name}: Router received non-routing result. Halting.")
                                return

                        # If tool requested direct response, return immediately
                        if isinstance(function_result, dict) and function_result.get("__direct_response__") is not None:
                            raise DirectResponseSignal(str(function_result.get("__direct_response__")))

                        # For other agents, add the tool's result to history and continue the loop
                        logger.info(f"{self.name}: ↪️ Sending function result back to Gemini for {function_call.name}")
                        
                        response_data = function_result
                        if isinstance(function_result, str):
                            try:
                                response_data = json.loads(function_result)
                            except json.JSONDecodeError:
                                response_data = function_result

                        # Create the tool response part
                        function_response_part = types.Part.from_function_response(
                            name=function_call.name,
                            response={'result': response_data}
                        )
                        gemini_history.append({'role': 'tool', 'parts': [function_response_part]})
                        # For ProblemRetriever we never allow free text streaming; we only return direct problem text
                        if self.name == "ProblemRetriever":
                            allow_text_yield = False
                        else:
                            allow_text_yield = True
                        
                        # Continue to the next iteration of the while loop to get the model's final response
                        continue
                    else:
                        # The model's response did not contain a function call, so we're done.
                        break

            if self.name == "Router":
                try:
                    # For Router, we must consume the stream to get the routing decision
                    stream = _internal_stream_handler()
                    for chunk in stream:
                        # The router is not supposed to yield text. Log if it does.
                        if chunk:
                            logger.warning(f"Router agent generated unexpected text output: {chunk}")
                    
                    # If the loop completes without a RouteSignal, it means no function was called.
                    # In this case, we return a response that will trigger the default fallback in RouterAgent.
                    return AgentResponse(agent_name=self.name, is_final=False)
                
                except RouteSignal as e:
                    # This is the expected path for the router.
                    logger.info(f"{self.name}: Caught route signal to {e.response.next_agent}")
                    return e.response
            else:
                # For all other agents, we return a stream. However, tools may request
                # an immediate direct response by raising DirectResponseSignal from the
                # internal handler. Since that exception is thrown during iteration of
                # the generator (not at creation time), we must wrap the generator and
                # convert the signal to normal text output to prevent it from bubbling
                # up to the SSE layer as an error.
                def _wrapped_stream():
                    try:
                        for chunk in _internal_stream_handler():
                            yield chunk
                    except DirectResponseSignal as e:
                        # Yield the final text as a normal chunk so upper layers
                        # can handle it uniformly without treating it as an error.
                        yield e.text

                return AgentResponse(
                    agent_name=self.name,
                    stream=_wrapped_stream(),
                    is_final=True
                )

        except Exception as e:
            logger.error(f"{self.name}: Error in Gemini message sending (outer): {e}", exc_info=True)
            return AgentResponse(
                agent_name=self.name,
                content=f"Извините, произошла ошибка в Gemini: {e}",
                is_final=True
            )

    def _prepare_gemini_history(self, history: List[Dict[str, Any]], state: AgentState):
        """
        Prepare chat session with history for Gemini, using the robust
        conversion logic from the gemini_provider.
        """
        try:
            gemini_history = []
            for msg in history:
                role = 'model' if msg.get('role') == 'assistant' else 'user'
                parts = convert_to_gemini_parts(msg['content'])
                gemini_history.append({'role': role, 'parts': parts})
            
            logger.debug(f"{self.name}: Prepared history with {len(gemini_history)} messages")
            return gemini_history
            
        except Exception as e:
            logger.error(f"{self.name}: Error preparing chat history: {e}")
            return [] # Return empty history on error

    def _send_message_openai(self, history: List[Dict[str, Any]], state: AgentState, client: Any, model_name: str) -> AgentResponse:
        """Send message using OpenAI API with streaming and full tool call loop support."""
        try:
            messages = []
            system_prompt = self._get_system_prompt(state.subject_id, state)
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            
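            # Convert internal history items into OpenAI chat messages; content may be a
            # plain string or a list of multimodal parts (text and image_url entries).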
            for msg in history:
                role = 'assistant' if msg.get('role') == 'assistant' else 'user'
                content = msg.get('content')
                if isinstance(content, list):
                    openai_content = []
                    for item in content:
                        if item.get('type') == 'text':
                            openai_content.append({"type": "text", "text": item['text']})
                        elif item.get('type') == 'image_url':
                            openai_content.append({
                                "type": "image_url",
                                "image_url": {"url": item['image_url']['url']}
                            })
                    messages.append({"role": role, "content": openai_content})
                else:
                    messages.append({"role": role, "content": str(content)})

            tools = self._get_tools()
            
            def _internal_openai_handler():
                nonlocal messages
                while True:
                    kwargs = {
                        "model": model_name,
                        "messages": messages,
                        "stream": True
                    }
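                    # OpenAI-style reasoning models (o1/o3 families) reject a custom
                    # temperature, so it is only set for other models.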
                    if not ("o1" in model_name or "o3" in model_name):
                        kwargs["temperature"] = 0 if self.mode == 1 else 0.7
                    if tools:
                        kwargs["tools"] = tools

                    response = client.chat.completions.create(**kwargs)
                    
                    full_content = ""
                    tool_calls = []
                    
                    for chunk in response:
                        if not chunk.choices:
                            continue
                        delta = chunk.choices[0].delta
                        if delta.content:
                            full_content += delta.content
                            yield delta.content
                        
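                        # Tool calls arrive as streamed fragments (id/name first, arguments
                        # in pieces); accumulate them by index until the stream completes.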
                        if delta.tool_calls:
                            for tc in delta.tool_calls:
                                while len(tool_calls) <= tc.index:
                                    tool_calls.append({
                                        "id": None,
                                        "type": "function",
                                        "function": {"name": "", "arguments": ""}
                                    })
                                if tc.id:
                                    tool_calls[tc.index]["id"] = tc.id
                                if tc.function:
                                    if tc.function.name:
                                        tool_calls[tc.index]["function"]["name"] = tc.function.name
                                    if tc.function.arguments:
                                        tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments

                    if tool_calls:
                        messages.append({
                            "role": "assistant",
                            "content": full_content or None,
                            "tool_calls": tool_calls
                        })
                        
                        for tool_call in tool_calls:
                            func_name = tool_call["function"]["name"]
                            func_args_str = tool_call["function"]["arguments"]
                            try:
                                func_args = json.loads(func_args_str)
                            except json.JSONDecodeError:
                                func_args = {}
                                
                            logger.info(f"{self.name}: 🔧 OpenAI detected function call: {func_name}")
                            
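                            # Minimal adapter exposing the Gemini-style .name / .args
                            # interface so the shared _process_function_call implementation
                            # also handles OpenAI tool calls.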
                            class MockFunctionCall:
                                def __init__(self, name, args):
                                    self.name = name
                                    self.args = args
                            
                            result = self._process_function_call(MockFunctionCall(func_name, func_args), state)
                            
                            if self.name == "Router":
                                if isinstance(result, dict) and "next_agent" in result:
                                    raise RouteSignal(AgentResponse(
                                        agent_name=self.name,
                                        next_agent=result["next_agent"],
                                        update_state=result.get("update_state", {}),
                                        is_final=False
                                    ))

                            if isinstance(result, dict) and result.get("__direct_response__") is not None:
                                raise DirectResponseSignal(str(result.get("__direct_response__")))

                            messages.append({
                                "role": "tool",
                                "tool_call_id": tool_call["id"],
                                "content": json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
                            })
                        continue
                    else:
                        break

            if self.name == "Router":
                try:
                    stream = _internal_openai_handler()
                    for chunk in stream:
                        if chunk:
                            logger.warning(f"Router agent generated unexpected text output: {chunk}")
                    return AgentResponse(agent_name=self.name, is_final=False)
                except RouteSignal as e:
                    return e.response
            else:
                def _wrapped_stream():
                    try:
                        for chunk in _internal_openai_handler():
                            yield chunk
                    except DirectResponseSignal as e:
                        yield e.text

                return AgentResponse(
                    agent_name=self.name,
                    stream=_wrapped_stream(),
                    is_final=True
                )

        except Exception as e:
            logger.error(f"{self.name}: Error in OpenAI message sending: {e}", exc_info=True)
            return AgentResponse(
                agent_name=self.name,
                content=f"Извините, произошла ошибка в OpenAI: {e}",
                is_final=True
            )

    def get_agent_info(self) -> Dict[str, Any]:
        """Get information about this agent"""
        tools = self._get_tools()
        return {
            'name': self.name,
            'model': getattr(self, 'model_name', 'unknown'),
            'provider': getattr(self, 'provider', 'unknown'),
            'tools': len(tools) if tools else 0,
            'description': self.system_prompt[:100] + "..." if len(self.system_prompt) > 100 else self.system_prompt
        }

    def _ensure_model(self, state: AgentState):
        """Fetch or reuse model for given subject and self.mode."""
        subject_id = state.subject_id
        key = (self.name, subject_id, self.mode)
        if key in self._model_cache:
            self.provider, self.client, self.model_name = self._model_cache[key]
            logger.debug(f"{self.name}: Using cached model for subject {subject_id}, mode {self.mode}")
            return

        model_info = get_model_info(subject_id, self.mode)
        if not model_info:
            raise RuntimeError(f"Model not found for subject {subject_id} mode {self.mode}")

        api_url, api_key, model_name, provider = model_info
        
        self.provider = provider

        logger.info(f"{self.name}: Initializing {provider} with model {model_name}")
        
        if provider == "gemini_client":
            # Removing http_options for now as it causes 400 INVALID_ARGUMENT in some environments
            client_obj = genai.Client(api_key=api_key)
            logger.info(f"{self.name}: Gemini client for {model_name} configured")
            self._model_cache[key] = (provider, client_obj, model_name)
            self.client = client_obj
            self.model_name = model_name
            
        elif provider in ["openai_client", "openrouter_client", "deepseek_client"]:
            client_obj = OpenAI(api_key=api_key, base_url=api_url)
            logger.info(f"{self.name}: OpenAI-compatible client ({provider}) for {model_name} configured")
            self._model_cache[key] = (provider, client_obj, model_name)
            self.client = client_obj
            self.model_name = model_name

        else:
            raise NotImplementedError(f"Provider {provider} not yet supported in BaseAgent") 