#!/usr/bin/env python3
"""
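MCP server for database-backed study questions.
Provides a line-delimited JSON-RPC interface over STDIO for reading questions and topics (tu_problems_collection, tu_problem, tu_problems_topic) and for recording student progress (tu_paragraph_progress, tu_solved_problem).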
"""
import io
import json
import logging
import os
import sys
from contextlib import contextmanager
from typing import Dict, Any, Optional, List

import mysql.connector

# Force UTF-8 on the standard streams on Windows, whose default may be a
# non-UTF-8 code page. Responses are serialized with ensure_ascii=False, so
# non-ASCII text must be encodable on stdout.
if sys.platform == "win32":
    for _stream_name in ("stdin", "stdout", "stderr"):
        _stream = getattr(sys, _stream_name)
        if hasattr(_stream, "reconfigure"):
            # Change the encoding in place (Python 3.7+): keeps the same
            # stream object and its buffering settings instead of stacking a
            # second TextIOWrapper on the shared binary buffer.
            _stream.reconfigure(encoding='utf-8')
        else:
            setattr(sys, _stream_name, io.TextIOWrapper(_stream.buffer, encoding='utf-8'))
    del _stream_name, _stream

# Configure logging to stderr; stdout is reserved for JSON-RPC responses.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)

# Directory containing this script; db_config is imported from its parent.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# The database configuration module lives in the parent directory, so put
# that directory first on the import path before importing it.
sys.path.insert(0, os.path.dirname(_SCRIPT_DIR))
try:
    from db_config import db_config
    DB_AVAILABLE = True
    logger.info("Database configuration loaded successfully")
except ImportError as e:
    # Degrade instead of crashing: DB-backed methods check DB_AVAILABLE and
    # return an error result. Log the cause (missing module, missing name,
    # or a failing import inside db_config) so it can be diagnosed.
    DB_AVAILABLE = False
    logger.error("Failed to import database configuration: %s", e)
    db_config = None

# Subject ID -> folder name, returned to clients as 'subject_folder'.
# IDs missing from this table are rejected as unknown subjects.
SUBJECT_MAPPING = dict((
    (1, 'physics'),          # Physics
    (2, 'algebra'),          # Algebra
    (3, 'geometry'),         # Geometry
    (4, 'chemistry'),        # Chemistry
    (5, 'biology'),          # Biology
    (6, 'history'),          # History
    (7, 'russian'),          # Russian language
    (8, 'social_studies'),   # Social studies
    (9, 'literature'),       # Literature
    (10, 'geography'),       # Geography
    (11, 'english'),         # English language
    (12, 'informatics'),     # Computer science
))

class MCPServer:
    """Database operations behind the JSON-RPC methods.

    Every public method returns a plain dict with a ``'status'`` key
    (``'success'``, or ``'error'`` together with a ``'message'``) instead of
    raising, so the caller can forward it unchanged as a JSON-RPC ``result``.
    """

    def __init__(self):
        # No per-instance state: each call borrows its own connection.
        pass

    def get_subject_folder(self, subject_id: int) -> Optional[str]:
        """Return the folder name mapped to *subject_id*, or None if unknown."""
        return SUBJECT_MAPPING.get(subject_id)

    def get_db_connection(self):
        """Borrow a connection from ``db_config.connection_pool``.

        Raises:
            Exception: if the ``db_config`` module could not be imported.
        """
        if not DB_AVAILABLE or not db_config:
            raise Exception("Database configuration not available")
        from db_config import connection_pool
        return connection_pool.get_connection()

    @contextmanager
    def _db_cursor(self, dictionary: bool = True):
        """Yield ``(conn, cursor)`` and always release both on exit.

        The connection is closed even if opening or closing the cursor fails,
        and uncommitted work is rolled back if the ``with`` body raises.

        Args:
            dictionary: if true, open the cursor with ``dictionary=True``
                (rows as dicts); otherwise use the driver's default cursor.
        """
        conn = self.get_db_connection()
        try:
            cursor = conn.cursor(dictionary=True) if dictionary else conn.cursor()
            try:
                yield conn, cursor
            except Exception:
                # Best effort only: a failing rollback must not mask the
                # original error, which is re-raised below.
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    logger.warning("Rollback after failed DB operation failed: %s", rollback_error)
                raise
            finally:
                cursor.close()
        finally:
            conn.close()

    @staticmethod
    def _find_testing_collection(cursor, subject_id: int, grade: int) -> Optional[Dict[str, Any]]:
        """Return one ``tu_problems_collection`` row (``id``, ``name``) with
        ``testing = 1`` for the subject and grade, or None.

        ``LIMIT 1`` without ``ORDER BY``: if several rows match, which one is
        returned is unspecified.
        """
        cursor.execute("""
            SELECT id, name
            FROM tu_problems_collection
            WHERE subject_id = %s AND grade = %s AND testing = 1
            LIMIT 1
        """, (subject_id, grade))
        return cursor.fetchone()

    def get_question_from_db(self, subject_id: int, grade: int, question_order: int) -> Dict[str, Any]:
        """Fetch the question with ``q_order == question_order`` from the
        testing collection of the given subject and grade.

        Returns:
            On success: the question, answer, problem id, collection name,
            total question count and an ``is_last_question`` flag.
            Otherwise ``{'status': 'error', 'message': ...}``.
        """
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            subject_folder = self.get_subject_folder(subject_id)
            if not subject_folder:
                return {'status': 'error', 'message': f'Unknown subject ID: {subject_id}'}

            with self._db_cursor() as (_conn, cursor):
                collection = self._find_testing_collection(cursor, subject_id, grade)
                if not collection:
                    return {
                        'status': 'error',
                        'message': f'No problems collection found for subject {subject_folder}, grade {grade}'
                    }

                cursor.execute("""
                    SELECT COUNT(*) as total_questions
                    FROM tu_problem
                    WHERE collection_id = %s
                """, (collection['id'],))
                total_count = cursor.fetchone()['total_questions']

                cursor.execute("""
                    SELECT id, question, answer, q_order
                    FROM tu_problem
                    WHERE collection_id = %s AND q_order = %s
                    LIMIT 1
                """, (collection['id'], question_order))
                problem = cursor.fetchone()
                if not problem:
                    return {
                        'status': 'error',
                        'message': f'No question found with order {question_order} in collection {collection["name"]}'
                    }

                return {
                    'status': 'success',
                    'subject_id': subject_id,
                    'subject_folder': subject_folder,
                    'grade': grade,
                    'question_order': question_order,
                    'total_questions': total_count,
                    'collection_name': collection['name'],
                    'question': problem['question'],
                    'answer': problem['answer'],
                    'problem_id': problem['id'],
                    # NOTE(review): assumes q_order runs 1..total_questions
                    # without gaps -- confirm against the data.
                    'is_last_question': question_order >= total_count
                }

        except mysql.connector.Error as e:
            logger.error("Database error getting question %s/%s/%s: %s", subject_id, grade, question_order, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error getting question %s/%s/%s: %s", subject_id, grade, question_order, e)
            return {'status': 'error', 'message': str(e)}

    def list_available_questions(self, subject_id: int, grade: int) -> Dict[str, Any]:
        """[DEPRECATED] List all questions (id, q_order, first 100 chars) of
        the testing collection for a subject and grade, ordered by q_order."""
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            subject_folder = self.get_subject_folder(subject_id)
            if not subject_folder:
                return {'status': 'error', 'message': f'Unknown subject ID: {subject_id}'}

            with self._db_cursor() as (_conn, cursor):
                collection = self._find_testing_collection(cursor, subject_id, grade)
                if not collection:
                    return {
                        'status': 'error',
                        'message': f'No problems collection found for subject {subject_folder}, grade {grade}'
                    }

                cursor.execute("""
                    SELECT id, q_order, LEFT(question, 100) as question_preview
                    FROM tu_problem
                    WHERE collection_id = %s
                    ORDER BY q_order
                """, (collection['id'],))
                questions = cursor.fetchall()

                return {
                    'status': 'success',
                    'subject_id': subject_id,
                    'subject_folder': subject_folder,
                    'grade': grade,
                    'collection_name': collection['name'],
                    'total_questions': len(questions),
                    'questions': questions
                }

        except mysql.connector.Error as e:
            logger.error("Database error listing questions %s/%s: %s", subject_id, grade, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error listing questions %s/%s: %s", subject_id, grade, e)
            return {'status': 'error', 'message': str(e)}

    def list_available_topics(self, subject_id: int) -> Dict[str, Any]:
        """List all topics (id, topic) of a subject, ordered by topic name.

        No grade filtering is applied.
        """
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            subject_folder = self.get_subject_folder(subject_id)
            if not subject_folder:
                return {'status': 'error', 'message': f'Unknown subject ID: {subject_id}'}

            with self._db_cursor() as (_conn, cursor):
                cursor.execute("""
                    SELECT id, topic
                    FROM tu_problems_topic
                    WHERE subject_id = %s
                    ORDER BY topic
                """, (subject_id,))
                topics = [{'id': t['id'], 'topic': t['topic']} for t in cursor.fetchall()]

                return {
                    'status': 'success',
                    'subject_id': subject_id,
                    'subject_folder': subject_folder,
                    'total_topics': len(topics),
                    'topics': topics
                }
        except mysql.connector.Error as e:
            logger.error("Database error listing topics %s: %s", subject_id, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error listing topics %s: %s", subject_id, e)
            return {'status': 'error', 'message': str(e)}

    def save_paragraph_progress(self, user_id: int, textbook_id: int, seq_number: str, progress: str) -> Dict[str, Any]:
        """Append a progress record for one textbook section to
        ``tu_paragraph_progress``.

        The section is resolved via ``contents`` (by ``seq_number`` and
        ``textbook_id``). A new row is inserted on every call; the most recent
        earlier value is reported back as ``previous_progress``.
        """
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            # Validate before borrowing a connection. contents.seq_number is
            # an integer column; TypeError covers JSON lists/objects.
            try:
                seq_number_int = int(seq_number)
            except (TypeError, ValueError):
                return {
                    'status': 'error',
                    'message': f'Invalid seq_number: {seq_number}. Must be a number.'
                }

            with self._db_cursor() as (conn, cursor):
                cursor.execute("""
                    SELECT id, textbook_id, seq_number, chapter_title as title
                    FROM contents
                    WHERE seq_number = %s AND textbook_id = %s
                    LIMIT 1
                """, (seq_number_int, textbook_id))
                content = cursor.fetchone()
                if not content:
                    return {
                        'status': 'error',
                        'message': f'No content found for section {seq_number} in textbook {textbook_id}'
                    }

                contents_id = content['id']

                # Most recent earlier record for this user and section, if any.
                cursor.execute("""
                    SELECT id, progress, study_time
                    FROM tu_paragraph_progress
                    WHERE user_id = %s AND contents_id = %s
                    ORDER BY study_time DESC
                    LIMIT 1
                """, (user_id, contents_id))
                existing_progress = cursor.fetchone()

                cursor.execute("""
                    INSERT INTO tu_paragraph_progress (user_id, contents_id, textbook_id, progress, study_time)
                    VALUES (%s, %s, %s, %s, NOW())
                """, (user_id, contents_id, textbook_id, progress))
                conn.commit()
                new_progress_id = cursor.lastrowid

                return {
                    'status': 'success',
                    'progress_id': new_progress_id,
                    'user_id': user_id,
                    'contents_id': contents_id,
                    'textbook_id': textbook_id,
                    'seq_number': seq_number,
                    'section_title': content['title'],
                    'progress': progress,
                    'previous_progress': existing_progress['progress'] if existing_progress else None,
                    'message': 'Progress saved successfully'
                }

        except mysql.connector.Error as e:
            logger.error("Database error saving progress for user %s, section %s: %s", user_id, seq_number, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error saving progress for user %s, section %s: %s", user_id, seq_number, e)
            return {'status': 'error', 'message': str(e)}

    def get_problems_by_topic(self, subject_id: int, topic_id: int, user_id: int) -> Dict[str, Any]:
        """Return the problems of *topic_id* that *user_id* has not solved yet,
        ordered by q_order.

        Only ``topic_id`` filters the query; ``subject_id`` is validated
        against SUBJECT_MAPPING and echoed back. Returns an error result when
        no unsolved problems remain.
        """
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            subject_folder = self.get_subject_folder(subject_id)
            if not subject_folder:
                return {'status': 'error', 'message': f'Unknown subject ID: {subject_id}'}

            with self._db_cursor() as (_conn, cursor):
                # Anti-join: keep problems with no tu_solved_problem row for this user.
                cursor.execute("""
                    SELECT p.id, p.question, p.answer, p.q_order
                    FROM tu_problem p
                    LEFT JOIN tu_solved_problem sp ON p.id = sp.problem_id AND sp.user_id = %s
                    WHERE p.topic_id = %s AND sp.id IS NULL
                    ORDER BY p.q_order ASC
                """, (user_id, topic_id))
                problems = cursor.fetchall()

                if not problems:
                    return {
                        'status': 'error',
                        'message': f'No unsolved problems found for topic ID: {topic_id}'
                    }

                return {
                    'status': 'success',
                    'subject_id': subject_id,
                    'subject_folder': subject_folder,
                    'topic_id': topic_id,
                    'total_problems': len(problems),
                    'problems': problems
                }

        except mysql.connector.Error as e:
            logger.error("Database error getting problems by topic ID %s/%s: %s", subject_id, topic_id, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error getting problems by topic ID %s/%s: %s", subject_id, topic_id, e)
            return {'status': 'error', 'message': str(e)}

    def mark_problem_as_solved(self, user_id: int, problem_id: int) -> Dict[str, Any]:
        """Record that *user_id* solved *problem_id* in ``tu_solved_problem``.

        Succeeds without inserting when a matching row already exists.
        """
        try:
            if not DB_AVAILABLE:
                return {'status': 'error', 'message': 'Database not available'}

            with self._db_cursor(dictionary=False) as (conn, cursor):
                # LIMIT 1: only existence matters, and leaving rows unread on an
                # unbuffered mysql-connector cursor makes close() raise
                # "Unread result found".
                cursor.execute("""
                    SELECT id FROM tu_solved_problem WHERE user_id = %s AND problem_id = %s LIMIT 1
                """, (user_id, problem_id))
                if cursor.fetchone():
                    return {
                        'status': 'success',
                        'message': 'Problem already marked as solved for this user'
                    }

                # NOTE(review): check-then-insert is not atomic; concurrent
                # calls can insert duplicates unless the schema has a unique
                # key on (user_id, problem_id).
                cursor.execute("""
                    INSERT INTO tu_solved_problem (user_id, problem_id) VALUES (%s, %s)
                """, (user_id, problem_id))
                conn.commit()

                affected_rows = cursor.rowcount
                logger.info("Transaction committed. Rows affected: %s", affected_rows)
                if affected_rows == 0:
                    logger.warning(
                        "Commit was successful, but no rows were inserted for user %s, problem %s.",
                        user_id, problem_id,
                    )

                return {
                    'status': 'success',
                    'message': f'Problem {problem_id} marked as solved for user {user_id}'
                }

        except mysql.connector.Error as e:
            logger.error("Database error marking problem as solved for user %s, problem %s: %s", user_id, problem_id, e)
            return {'status': 'error', 'message': f'Database error: {str(e)}'}
        except Exception as e:
            logger.exception("Error marking problem as solved for user %s, problem %s: %s", user_id, problem_id, e)
            return {'status': 'error', 'message': str(e)}




class MCPStdioServer:
    """Line-delimited JSON-RPC 2.0 server over STDIO.

    Each non-empty stdin line is parsed as one request and answered with
    exactly one JSON line on stdout. Application-level failures (missing
    parameters, DB errors, unknown methods) are reported inside ``result`` as
    ``{'status': 'error', ...}``; only protocol-level failures use a JSON-RPC
    ``error`` object.
    """

    def __init__(self):
        self.mcp_server = MCPServer()
        logger.info("MCP Problems Server (STDIO) initialized")

    # ------------------------------------------------------------------
    # Method handlers: each receives the request's ``params`` dict and
    # returns the ``result`` payload.
    # ------------------------------------------------------------------

    def _handle_get_question_from_db(self, params: Dict[str, Any]) -> Dict[str, Any]:
        subject_id = params.get('subject_id')
        grade = params.get('grade')
        question_order = params.get('question_order')
        # NOTE: `not grade` also rejects grade 0 and "", unlike the
        # `is None` checks used for the other parameters.
        if subject_id is None or not grade or question_order is None:
            return {'status': 'error', 'message': 'subject_id, grade and question_order parameters required'}
        return self.mcp_server.get_question_from_db(subject_id, grade, question_order)

    def _handle_list_available_questions(self, params: Dict[str, Any]) -> Dict[str, Any]:
        subject_id = params.get('subject_id')
        if subject_id is None:
            return {'status': 'error', 'message': 'subject_id parameter required'}
        # Deprecated method kept for backward compatibility: forward the
        # caller's grade when given; default to 0 when it is omitted.
        grade = params.get('grade')
        if grade is None:
            grade = 0
        return self.mcp_server.list_available_questions(subject_id, grade)

    def _handle_list_available_topics(self, params: Dict[str, Any]) -> Dict[str, Any]:
        subject_id = params.get('subject_id')
        if subject_id is None:
            return {'status': 'error', 'message': 'subject_id parameter required'}
        return self.mcp_server.list_available_topics(subject_id)

    def _handle_save_paragraph_progress(self, params: Dict[str, Any]) -> Dict[str, Any]:
        user_id = params.get('user_id')
        textbook_id = params.get('textbook_id')
        seq_number = params.get('seq_number')
        progress = params.get('progress')
        if user_id is None or textbook_id is None or seq_number is None or progress is None:
            return {'status': 'error', 'message': 'user_id, textbook_id, seq_number and progress parameters required'}
        return self.mcp_server.save_paragraph_progress(user_id, textbook_id, seq_number, progress)

    def _handle_get_problems_by_topic(self, params: Dict[str, Any]) -> Dict[str, Any]:
        subject_id = params.get('subject_id')
        topic_id = params.get('topic_id')
        if subject_id is None or topic_id is None:
            return {'status': 'error', 'message': 'subject_id and topic_id parameters required'}
        user_id = params.get('user_id')
        if user_id is None:
            return {'status': 'error', 'message': 'user_id parameter is required'}
        return self.mcp_server.get_problems_by_topic(subject_id, topic_id, user_id)

    def _handle_mark_problem_as_solved(self, params: Dict[str, Any]) -> Dict[str, Any]:
        user_id = params.get('user_id')
        problem_id = params.get('problem_id')
        if user_id is None or problem_id is None:
            return {'status': 'error', 'message': 'user_id and problem_id parameters required'}
        return self.mcp_server.mark_problem_as_solved(user_id, problem_id)

    # JSON-RPC method name -> handler (plain function, called with self).
    _HANDLERS = {
        'get_question_from_db': _handle_get_question_from_db,
        'list_available_questions': _handle_list_available_questions,
        'list_available_topics': _handle_list_available_topics,
        'save_paragraph_progress': _handle_save_paragraph_progress,
        'get_problems_by_topic': _handle_get_problems_by_topic,
        'mark_problem_as_solved': _handle_mark_problem_as_solved,
    }

    @staticmethod
    def _error_response(request_id: Any, code: int, message: str, data: Any) -> Dict[str, Any]:
        """Build a JSON-RPC 2.0 error response object."""
        return {
            'jsonrpc': '2.0',
            'error': {
                'code': code,
                'message': message,
                'data': data
            },
            'id': request_id
        }

    def process_jsonrpc_request(self, request: Any) -> Dict[str, Any]:
        """Dispatch one decoded JSON-RPC request and return the response object.

        Never raises. A non-object request yields -32600, non-object
        ``params`` yields -32602, and an unexpected exception yields -32603.
        A missing ``id`` defaults to 1 and a null ``params`` is treated as
        ``{}``.
        """
        if not isinstance(request, dict):
            return self._error_response(None, -32600, 'Invalid Request', 'Request must be a JSON object')

        # Read once, outside the try, so the error path can always reuse it.
        request_id = request.get('id', 1)
        try:
            method = request.get('method')
            params = request.get('params')
            if params is None:
                params = {}
            if not isinstance(params, dict):
                return self._error_response(request_id, -32602, 'Invalid params', 'params must be a JSON object')

            logger.debug("Processing method: %s with params: %s", method, params)

            # isinstance guard: a non-string method (e.g. a list) is unhashable.
            handler = self._HANDLERS.get(method) if isinstance(method, str) else None
            if handler is None:
                result = {'status': 'error', 'message': f'Unknown method: {method}'}
            else:
                result = handler(self, params)

            return {
                'jsonrpc': '2.0',
                'result': result,
                'id': request_id
            }

        except Exception as e:
            logger.exception("Error processing JSON-RPC request: %s", e)
            return self._error_response(request_id, -32603, 'Internal error', str(e))

    @staticmethod
    def _send(response: Dict[str, Any]) -> None:
        """Write *response* as a single JSON line to stdout and flush it."""
        print(json.dumps(response, ensure_ascii=False), flush=True)

    def _handle_line(self, line: str) -> None:
        """Parse one request line, dispatch it and write one response line."""
        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            logger.error("Invalid JSON received: %s", e)
            self._send(self._error_response(None, -32700, 'Parse error', str(e)))
            return

        logger.debug("Received request: %s", request)
        try:
            self._send(self.process_jsonrpc_request(request))
        except Exception as e:
            # process_jsonrpc_request handles its own errors, so this mainly
            # catches results json.dumps cannot encode (presumably driver
            # values such as Decimal/datetime). Keep the caller's id.
            logger.exception("Error processing request: %s", e)
            request_id = request.get('id', 1) if isinstance(request, dict) else None
            self._send(self._error_response(request_id, -32603, 'Internal error', str(e)))

    def run(self):
        """Serve requests from stdin until EOF or Ctrl+C.

        Prints ``SERVER_READY`` first so the parent process knows it can start
        sending requests. Exits with status 1 on an unexpected server error.
        """
        logger.info("Starting STDIO server...")
        # Signal that server is ready
        print("SERVER_READY", flush=True)
        logger.info("Waiting for JSON-RPC requests on stdin...")

        try:
            for line in sys.stdin:
                line = line.strip()
                if not line:
                    continue
                self._handle_line(line)

        except KeyboardInterrupt:
            logger.info("Server shutting down...")
        except Exception as e:
            logger.exception("Server error: %s", e)
            sys.exit(1)

def main():
    """Create the STDIO server and run it until stdin closes.

    If the server cannot be constructed, log the full traceback and exit with
    status 1. ``run`` handles its own runtime errors.
    """
    try:
        server = MCPStdioServer()
        server.run()
    except Exception as e:
        # logger.exception keeps the traceback, which a bare message loses.
        logger.exception("Failed to start server: %s", e)
        sys.exit(1)


if __name__ == '__main__':
    main()
