from db_config import connection_pool
import json
import mysql.connector


def _decode_message_content(message_text):
    """Decode a stored message body into chat-API ``content``.

    Messages stored as a JSON array/object (the multimodal format) are returned
    as the decoded structure. Anything else is returned unchanged. This covers
    legacy plain text, NULL columns, and plain text that merely happens to be
    valid JSON.
    """
    try:
        parsed = json.loads(message_text)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError (bytes
        # input); TypeError covers NULL/None. Either way it is not multimodal.
        return message_text
    # Only structured values count as multimodal content. Without this check a
    # plain-text message such as "42", "true", "null" or '"hi"' would be turned
    # into an int/bool/None/unquoted str instead of staying the original text.
    if isinstance(parsed, (list, dict)):
        return parsed
    return message_text


def get_last_messages(chat_id, limit=100):
    """Return the most recent messages of a chat, oldest first, in chat-API form.

    Selects up to ``limit`` rows from ``chat_message`` for ``chat_id``, newest
    first by ``created_at``, and returns them in chronological order.

    Args:
        chat_id: Value matched against ``chat_message.chat_id``.
        limit: Maximum number of messages to return (bound to SQL ``LIMIT``).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. ``role`` is
        ``"user"`` when ``author == 1`` and ``"assistant"`` otherwise.
        ``content`` is the decoded JSON array/object for messages stored in
        that form, otherwise the raw ``message_text`` value.
        Returns ``[]`` if a ``mysql.connector.Error`` is raised; the error is
        printed.
    """
    conn = connection_pool.get_connection()
    cursor = None
    try:
        # Created inside the try so the connection is still closed (returned to
        # the pool) if cursor creation itself fails.
        cursor = conn.cursor(dictionary=True)
        # Fetch newest-first so LIMIT keeps the latest rows; reversed below to
        # restore chronological order.
        cursor.execute(
            "SELECT message_text, author FROM chat_message WHERE chat_id = %s ORDER BY created_at DESC LIMIT %s",
            (chat_id, limit)
        )
        rows_newest_first = cursor.fetchall()

        return [
            {
                "role": "user" if row['author'] == 1 else "assistant",
                "content": _decode_message_content(row['message_text']),
            }
            for row in reversed(rows_newest_first)
        ]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return []
    finally:
        if cursor is not None:
            cursor.close()
        conn.close()
