# backend/app.py
import os
# Pin the native math libraries (OpenBLAS, OpenMP, MKL, numexpr) to one
# thread each. These variables are generally read when those libraries load,
# which is why they are set before any other import in this module.
os.environ.update(dict.fromkeys(
    ('OPENBLAS_NUM_THREADS', 'OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS'),
    '1',
))

import sys
from pathlib import Path 
from dotenv import load_dotenv


class PrefixMiddleware:
    """WSGI middleware that mounts the wrapped app under a URL prefix.

    For requests whose path is ``prefix`` or starts with ``prefix + '/'``
    (e.g. ``/app/health``), the prefix is moved from ``PATH_INFO`` into
    ``SCRIPT_NAME``. The inner app then routes on ``/health``, and URL building
    (which uses ``SCRIPT_NAME``) includes the prefix again. All other requests
    are passed through unchanged.
    """

    def __init__(self, app, prefix='/app'):
        self.app = app
        self.prefix = prefix

    def __call__(self, environ, start_response):
        # PATH_INFO may legitimately be missing from the WSGI environ.
        path = environ.get('PATH_INFO', '')
        prefix = self.prefix
        # Only match on a path-segment boundary so '/apple' or '/application'
        # are not mistaken for '/app' followed by 'le' / 'lication'.
        if path == prefix or path.startswith(prefix + '/'):
            environ['PATH_INFO'] = path[len(prefix):]
            # Append rather than overwrite: an upstream server may already
            # have mounted this application below its own SCRIPT_NAME.
            environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + prefix
        return self.app(environ, start_response)


def find_env_file():
    """
    Find the .env file in order of preference:
    1. ENV_FILE environment variable
    2. .env in current working directory
    3. .env in the same directory as app.py
    4. .env.production in the same directory as app.py
    5. /var/www/platform/backend/.env (production default)

    Only regular files qualify. A directory (or other non-file) at a
    higher-priority location therefore cannot shadow a real .env further
    down the list.

    Returns:
        Path | None: the first candidate that is an existing file, or None
        when none exists (the caller then relies on system env vars).
    """
    def _candidates():
        # Yielded lazily, so later locations are only computed when needed.
        env_file = os.environ.get('ENV_FILE')
        if env_file:
            yield Path(env_file)
        yield Path.cwd() / '.env'
        app_dir = Path(__file__).parent.absolute()
        yield app_dir / '.env'
        yield app_dir / '.env.production'
        # Production default path (for the platform folder)
        yield Path('/var/www/platform/backend/.env')

    return next((path for path in _candidates() if path.is_file()), None)

# Locate and load the .env file. When one is found, override=True lets its
# values replace variables already present in the process environment.
env_path = find_env_file()

if not env_path:
    print("[WARNING] No .env file found. Using system environment variables.")
    print(f"   Looking in: {Path(__file__).parent.absolute()}")
    # Last resort: let python-dotenv run its own .env discovery.
    load_dotenv()
else:
    print(f"[FIRE] Loading .env from: {env_path}")
    load_dotenv(dotenv_path=env_path, override=True)


from middleware.request_logger import request_logger_middleware
from extensions import limiter


import pymysql
# Register PyMySQL under the ``MySQLdb`` module name, so code that imports
# MySQLdb transparently gets the pure-Python PyMySQL driver.
pymysql.install_as_MySQLdb()
import logging
from flask import Flask, jsonify
from flask_migrate import Migrate
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from datetime import datetime, timedelta 
from flask_caching import Cache
from extensions import cache, socketio, mail 
from events import register_socket_events
from flask_socketio import SocketIO, emit
from sockets.calendar_sockets import register_calendar_handlers
from sockets.alert_sockets import register_alert_handlers
from utils.config_helper import get_frontend_url, get_cors_origins
from cli_commands import (
    create_super_admin,
    create_test_users,
    list_users,
    assign_role_to_user,
    check_user_permissions,
    cleanup_expired_tokens,
    init_rbac_system,
    check_relationships
)

# Load environment variables
#load_dotenv()

# Switch to the Redis-backed cache below once the app is production-ready.
#cache = Cache(config={
#    'CACHE_TYPE': 'RedisCache',
#    'CACHE_REDIS_HOST': 'localhost',
#    'CACHE_REDIS_PORT': 6379,
#    'CACHE_DEFAULT_TIMEOUT': 300
#})

# Import db from models package (this will be the single source of truth)
from models import db


# Module-level JWTManager, exported through ``__all__``. It is created without
# an app; it only takes effect for an app after ``jwt.init_app(app)``.
jwt = JWTManager()


def verify_system_health():
    """Verify database connectivity at startup and log a short summary.

    Queries all companies (including inactive ones) through
    ``DBSessionManager.get_companies`` and logs up to the first three.

    Returns:
        bool: True if the check ran (even when no companies were found),
        False if anything inside the check raised.
    """
    from utils.db_session_manager import DBSessionManager
    logger = logging.getLogger(__name__)

    logger.info("[SEARCH] Running system health check...")

    try:
        # Test database connection
        companies = DBSessionManager.get_companies(include_inactive=True)

        if companies:
            logger.info("[OK] Database connection verified: %d companies found", len(companies))
            for company in companies[:3]:  # Show first 3 companies
                logger.info("   - Company %s: %s (active: %s)",
                            company['id'], company['name'], company['is_active'])
            if len(companies) > 3:
                logger.info("   ... and %d more", len(companies) - 3)
        else:
            logger.warning("[WARNING] No companies found in database. Seeder may not work until companies are created.")
            logger.info("   Run 'python run.py' to initialize default companies")

        return True
    except Exception as e:
        # Startup boundary: record the failure and its traceback through the
        # logging system (rather than printing straight to stderr) and report
        # it via the return value instead of aborting startup.
        logger.exception("[ERROR] Database health check failed: %s", e)
        return False


def _build_database_uri():
    """Return the SQLAlchemy database URI for the app.

    ``DATABASE_URL`` is used verbatim when set and non-empty. Otherwise a
    ``mysql+pymysql`` URI is assembled from ``DB_USER``, ``DB_PASSWORD``,
    ``DB_HOST`` (default ``localhost``), ``DB_PORT`` (default ``3306``) and
    ``DB_NAME`` (default ``service_platform_db``). These are the same values
    that are printed, so the log line and the URI agree.

    User and password are percent-encoded, so characters such as ``@``, ``:``
    or ``/`` cannot corrupt the URI. NOTE(review): a DB_PASSWORD that was
    already hand-encoded in .env would now be encoded twice. Confirm none is.
    """
    from urllib.parse import quote

    db_user = os.environ.get('DB_USER')
    db_pass = os.environ.get('DB_PASSWORD')
    db_host = os.environ.get('DB_HOST', 'localhost')
    db_port = os.environ.get('DB_PORT', '3306')
    db_name = os.environ.get('DB_NAME', 'service_platform_db')

    print(f"[FIX] Building connection for: {db_user}@{db_host}:{db_port}/{db_name}")

    database_url = os.environ.get('DATABASE_URL')
    if database_url:
        return database_url

    user = quote(db_user or '', safe='')
    password = quote(db_pass or '', safe='')
    return f"mysql+pymysql://{user}:{password}@{db_host}:{db_port}/{db_name}"


def _register_socket_handlers():
    """Attach every Socket.IO event-handler module to the shared ``socketio``."""
    register_socket_events(socketio)

    # AI agent handlers
    from sockets.ai_agent_sockets import register_ai_agent_handlers
    register_ai_agent_handlers(socketio)
    register_calendar_handlers(socketio)

    from sockets.approval_sockets import register_approval_handlers
    register_approval_handlers(socketio)

    from sockets.admin_sockets import register_admin_handlers
    register_admin_handlers(socketio)

    # register_alert_handlers is already imported at module level.
    register_alert_handlers(socketio)


def _register_cli_commands(app):
    """Register the project's Flask CLI commands on ``app``."""
    for command in (
        create_super_admin,
        create_test_users,
        list_users,
        assign_role_to_user,
        check_user_permissions,
        init_rbac_system,
        cleanup_expired_tokens,
        check_relationships,
    ):
        app.cli.add_command(command)


def _register_jwt_callbacks(app):
    """Register identity, user-lookup, claims, blocklist and error callbacks
    on the module-level ``jwt`` manager.

    ``app`` is only used for its logger inside the user-lookup callback.
    """

    @jwt.user_identity_loader
    def user_identity_lookup(user):
        """Convert a user object or raw id into the string stored in ``sub``."""
        if user is None:
            return None
        if hasattr(user, 'id'):
            return str(user.id)
        return str(user)

    @jwt.user_lookup_loader
    def user_lookup_callback(_jwt_header, jwt_data):
        """Load the User named by the token's ``sub`` claim (None if absent)."""
        import re
        from models import User

        identity = jwt_data["sub"]
        try:
            # Tolerate legacy identity shapes such as ['1'] or "['1']".
            if isinstance(identity, (list, tuple)):
                identity = identity[0] if identity else None

            # Convert to int if it's a string number
            if isinstance(identity, str):
                if identity.isdigit():
                    identity = int(identity)
                else:
                    numbers = re.findall(r'\d+', identity)
                    if numbers:
                        identity = int(numbers[0])

            if identity is None:
                return None
            return User.query.get(identity)
        except Exception as e:
            app.logger.error(f"User lookup error: {str(e)}")
            return None

    # JWT error handlers
    @jwt.unauthorized_loader
    def unauthorized_callback(error):
        """Handle missing or invalid JWT"""
        return {
            'success': False,
            'error': 'Authorization required',
            'message': str(error)
        }, 401

    @jwt.invalid_token_loader
    def invalid_token_callback(error):
        """Handle invalid JWT"""
        return {
            'success': False,
            'error': 'Invalid token',
            'message': str(error)
        }, 422

    @jwt.expired_token_loader
    def expired_token_callback(jwt_header, jwt_data):
        """Handle expired JWT"""
        return {
            'success': False,
            'error': 'Token has expired',
            'message': 'Please log in again'
        }, 401

    @jwt.revoked_token_loader
    def revoked_token_callback(jwt_header, jwt_data):
        """Handle revoked JWT"""
        return {
            'success': False,
            'error': 'Token has been revoked',
            'message': 'Please log in again'
        }, 401

    @jwt.needs_fresh_token_loader
    def needs_fresh_token_callback(jwt_header, jwt_data):
        """Handle request that needs fresh token"""
        return {
            'success': False,
            'error': 'Fresh token required',
            'message': 'Please re-authenticate'
        }, 401

    @jwt.additional_claims_loader
    def add_claims_to_access_token(user):
        """Return extra claims for a new token, or {} if the user is not found."""
        # Import User here to avoid circular imports
        from models import User

        # Raw ids (int or numeric string) must be resolved to a User first;
        # an int has no ``.id`` attribute.
        if isinstance(user, (str, int)):
            user = User.query.get(int(user))

        if not user:
            return {}

        return {
            'user_id': user.id,
            'email': user.email,
            'username': user.username,
            'role': user.role,
            'company_id': user.company_id,
            'permissions': user.permissions  # This helps frontend
        }

    @jwt.token_in_blocklist_loader
    def check_if_token_revoked(jwt_header, jwt_payload):
        """Return True when the token's ``jti`` is present in the blocklist."""
        from models.token_blocklist import TokenBlocklist

        jti = jwt_payload['jti']
        # first() instead of scalar(): duplicate blocklist rows for the same
        # jti must still mean "revoked", not raise MultipleResultsFound.
        token = db.session.query(TokenBlocklist.id).filter_by(jti=jti).first()
        return token is not None


def _init_database(app):
    """Inside an app context: print registered blueprints, create missing
    tables and run the startup health check."""
    with app.app_context():
        print("\n[LIST] BLUEPRINT REGISTRATION:")
        for name, blueprint in app.blueprints.items():
            print(f"  {name}: {blueprint.url_prefix}")

        # Create tables if they don't exist
        db.create_all()
        print("[OK] Database tables created/verified")

        # Verify system health after DB is ready
        verify_system_health()


def create_app(config_name=None):
    """Application factory.

    Builds the Flask app and applies configuration from ``config.Config`` and
    the environment. It then initialises extensions (Socket.IO, SQLAlchemy,
    Migrate, JWT, cache, mail, rate limiter) and installs middleware and CORS.
    It registers blueprints, Socket.IO handlers, CLI commands, the ``/health``
    route and the JWT callbacks, and finally creates missing database tables.

    Args:
        config_name: Currently unused; kept so existing callers keep working.

    Returns:
        Flask: the fully configured application.
    """
    app = Flask(__name__)

    # Serve the app under /app; PrefixMiddleware strips it before routing.
    app.wsgi_app = PrefixMiddleware(app.wsgi_app, prefix='/app')

    # TODO(production): choose the cache backend from the environment, e.g.
    # CACHE_TYPE=RedisCache with CACHE_REDIS_URL taken from REDIS_URL, instead
    # of the SimpleCache configured below.

    # ========== 1. Core configuration (must come first) ==========
    app.config.from_object('config.Config')
    # NOTE(review): DEBUG defaults to on, and both secret keys fall back to
    # fixed development values. Production must set FLASK_DEBUG=0,
    # SECRET_KEY and JWT_SECRET_KEY explicitly.
    app.config['DEBUG'] = os.getenv('FLASK_DEBUG', '1') == '1'
    app.config['APP_URL'] = os.getenv('APP_URL', 'http://localhost:3000')
    app.config['PROPAGATE_EXCEPTIONS'] = True
    app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'dev-secret-key')
    app.config['JWT_SECRET_KEY'] = os.environ.get('JWT_SECRET_KEY', 'jwt-secret-key')
    app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=int(os.getenv('JWT_ACCESS_TOKEN_EXPIRES', 3600)))
    app.config['JWT_TOKEN_LOCATION'] = ['headers']
    app.config['JWT_HEADER_NAME'] = 'Authorization'
    app.config['JWT_HEADER_TYPE'] = 'Bearer'

    # Identities are carried as strings in the 'sub' claim; the identity and
    # user-lookup callbacks convert integer ids to and from that form.
    app.config['JWT_IDENTITY_CLAIM'] = 'sub'
    app.config['JWT_JSON_KEY'] = 'access_token'

    # Additional JWT security settings
    app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(days=30)
    app.config['JWT_BLACKLIST_ENABLED'] = False
    app.config['JWT_ERROR_MESSAGE_KEY'] = 'message'

    # Database configuration
    app.config['SQLALCHEMY_DATABASE_URI'] = _build_database_uri()
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    app.config['SQLALCHEMY_ENGINE_OPTIONS'] = {
        'pool_recycle': 300,
        'pool_pre_ping': True,
        'pool_size': 10,
        'max_overflow': 20
    }

    # ========== 2. Socket.IO (initialised before the other extensions) ==========
    frontend_url = get_frontend_url()
    socketio.init_app(app,
        cors_allowed_origins=frontend_url,
        logger=True,
        engineio_logger=True,
        ping_timeout=60,
        ping_interval=25,
        transports=['websocket', 'polling']
    )

    # ========== 3. Other extensions ==========
    db.init_app(app)
    Migrate(app, db)
    # Bind the module-level JWTManager exported via __all__. Creating a new,
    # function-local JWTManager here would shadow it and leave the exported
    # ``jwt`` unbound to any app.
    jwt.init_app(app)

    cache.init_app(app, config={
        'CACHE_TYPE': 'SimpleCache',  # Use SimpleCache for development
        'CACHE_DEFAULT_TIMEOUT': 300
    })

    mail.init_app(app)

    limiter.init_app(app)

    # ========== 4. Middleware (after extensions) ==========
    from middleware.performance import PerformanceMiddleware
    from middleware.tenant_context import TenantContextMiddleware

    PerformanceMiddleware(app)
    TenantContextMiddleware(app)

    # ========== 5. CORS ==========
    # NOTE(review): origins are pinned here rather than taken from
    # get_cors_origins() / CORS_ORIGINS. Confirm this is intended and that it
    # stays in sync with the Socket.IO origin configured above.
    cors_origins = ["https://app.synzhi.com"]  # Set explicitly
    CORS(app,
         supports_credentials=os.getenv('CORS_SUPPORTS_CREDENTIALS', 'true').lower() == 'true',
         origins=cors_origins,
         methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
         allow_headers=["Content-Type", "Authorization", "X-Requested-With", "Accept", "Origin"],
         expose_headers=["Content-Type", "Authorization", "Content-Length"],
         max_age=600)

    # ========== 6. Blueprints ==========
    from routes.api_v1 import api_v1
    app.register_blueprint(api_v1)

    # ========== 7. Socket.IO event handlers ==========
    _register_socket_handlers()

    # ========== 8. CLI commands ==========
    _register_cli_commands(app)

    # ========== 9. Routes ==========
    @app.route('/health', methods=['GET'])
    def health_check():
        """Simple health check endpoint"""
        return jsonify({
            'status': 'healthy',
            'timestamp': datetime.now().isoformat(),
            'version': '1.0.0'
        }), 200

    # ========== 10. JWT callbacks ==========
    _register_jwt_callbacks(app)

    # ========== 11. Logging ==========
    logging.basicConfig(
        level=logging.DEBUG if app.config['DEBUG'] else logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # ========== 12. Database tables + startup health check ==========
    _init_database(app)

    # ========== 13. Request logging middleware ==========
    logger_middleware = request_logger_middleware()
    app.before_request(logger_middleware['before_request'])
    app.after_request(logger_middleware['after_request'])

    return app


# Build the single application instance at import time. run.py imports
# ``app``; ``application`` is the same object, for cPanel WSGI compatibility.
app = application = create_app()

# Names re-exported by ``from app import *``.
__all__ = ['app', 'db', 'jwt']
