import json
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async


class NotificationConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer for real-time notifications.

    Each authenticated user joins a per-user channel-layer group named
    ``notifications_<user id>``. On connect the client is sent its unread
    notifications. Afterwards it may send JSON objects with an ``action`` key
    (``mark_read``, ``mark_all_read``, ``get_unread``). It also receives a
    ``new_notification`` frame for every ``notification`` event broadcast to
    its group.
    """

    # Maximum number of unread notifications included in an unread payload.
    unread_limit = 20

    async def connect(self):
        """Authenticate the user, join their group, accept, and send unread state.

        Connections without an authenticated ``scope['user']`` are closed
        before the handshake is accepted.
        """
        self.user = self.scope.get('user')
        if not self.user or not self.user.is_authenticated:
            await self.close()
            return

        self.room_group_name = f"notifications_{self.user.id}"

        await self.channel_layer.group_add(
            self.room_group_name,
            self.channel_name
        )
        await self.accept()

        await self._send_unread('connection_established')

    async def disconnect(self, close_code):
        """Leave the user's group if ``connect`` got far enough to join one.

        Args:
            close_code: WebSocket close code supplied by Channels (unused).
        """
        # room_group_name is only set for authenticated users, so rejected
        # connections have no group membership to discard.
        if hasattr(self, 'room_group_name'):
            await self.channel_layer.group_discard(
                self.room_group_name,
                self.channel_name
            )

    async def receive(self, text_data=None, bytes_data=None):
        """Handle a frame sent by the client.

        Text frames must contain a JSON object with an ``action`` key.
        Unknown actions are silently ignored. Invalid JSON or a non-object
        payload is answered with an ``error`` frame instead of crashing the
        consumer.

        Args:
            text_data: Payload of a text frame, or ``None`` for binary frames.
            bytes_data: Payload of a binary frame. This protocol defines no
                binary messages, so it is ignored.
        """
        if text_data is None:
            # Channels delivers binary frames as receive(bytes_data=...);
            # they carry no actions in this protocol.
            return

        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self._send_error('Invalid JSON')
            return
        if not isinstance(data, dict):
            await self._send_error('Expected a JSON object')
            return

        action = data.get('action')

        if action == 'mark_read':
            notification_id = data.get('notification_id')
            if notification_id:
                success = await self.mark_notification_read(notification_id)
                # Only confirm when the handler reports the update succeeded.
                if success:
                    await self._send_json({
                        'type': 'notification_read',
                        'notification_id': notification_id
                    })

        elif action == 'mark_all_read':
            count = await self.mark_all_notifications_read()
            await self._send_json({
                'type': 'all_notifications_read',
                'count': count
            })

        elif action == 'get_unread':
            await self._send_unread('unread_notifications')

    async def notification(self, event):
        """Forward a ``notification`` group event to the client.

        Channels routes channel-layer messages whose ``type`` is
        ``notification`` to this method.

        Args:
            event: Channel-layer message; its ``notification`` value is
                relayed to the client unchanged.
        """
        await self._send_json({
            'type': 'new_notification',
            'notification': event['notification']
        })

    async def _send_json(self, payload):
        """Serialize *payload* to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(payload))

    async def _send_error(self, message):
        """Send an ``error`` frame carrying a human-readable *message*."""
        await self._send_json({'type': 'error', 'message': message})

    async def _send_unread(self, message_type):
        """Fetch the user's unread notifications and send them as *message_type*."""
        unread = await self.get_unread_notifications()
        await self._send_json({
            'type': message_type,
            'unread_count': unread['count'],
            'unread': unread['notifications']
        })

    @database_sync_to_async
    def get_unread_notifications(self):
        """Return the user's unread count and their most recent unread items.

        Wrapped with ``database_sync_to_async`` so the synchronous handler
        calls run off the event loop.

        Returns:
            dict with ``count`` (from ``NotificationHandler.get_unread_count``)
            and ``notifications``, a list of JSON-serializable dicts for at
            most ``unread_limit`` items from ``NotificationHandler.get_unread``.
        """
        from .handler import NotificationHandler
        # NOTE(review): count and list come from separate handler calls, and
        # the list is capped, so ``count`` may exceed len(notifications).
        return {
            'count': NotificationHandler.get_unread_count(self.user),
            'notifications': [
                {
                    'id': str(n.id),
                    'title': n.title,
                    'message': n.message,
                    'notification_type': n.notification_type,
                    'priority': n.priority,
                    'icon': n.icon,
                    'link': n.link,
                    'is_read': n.is_read,
                    'created_at': n.created_at.isoformat(),
                }
                for n in NotificationHandler.get_unread(
                    self.user, limit=self.unread_limit
                )
            ]
        }

    @database_sync_to_async
    def mark_notification_read(self, notification_id):
        """Mark one of the user's notifications as read.

        Args:
            notification_id: Identifier sent by the client, passed through
                unchanged to ``NotificationHandler.mark_as_read``.

        Returns:
            Whatever ``NotificationHandler.mark_as_read`` returns. The caller
            treats a truthy value as success.
        """
        from .handler import NotificationHandler
        return NotificationHandler.mark_as_read(notification_id, self.user)

    @database_sync_to_async
    def mark_all_notifications_read(self):
        """Mark all of the user's notifications as read.

        Returns:
            The value of ``NotificationHandler.mark_all_as_read``. It is sent
            to the client as ``count``, presumably the number of rows updated
            (confirm against the handler).
        """
        from .handler import NotificationHandler
        return NotificationHandler.mark_all_as_read(self.user)
