from typing import Sequence, Optional
from datetime import date, datetime, timedelta

from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.domain.models import Appointment


class AppointmentsRepository:
    """Async data-access layer for ``Appointment`` rows.

    Every read eagerly loads the ``type``, ``patient`` and ``doctor``
    relationships via ``joinedload``; every write commits the session
    immediately.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _base_query():
        """Return ``select(Appointment)`` with its relationships joined-loaded."""
        return select(Appointment).options(
            joinedload(Appointment.type),
            joinedload(Appointment.patient),
            joinedload(Appointment.doctor),
        )

    @staticmethod
    def _apply_date_range(stmt, start_date: date | None, end_date: date | None):
        """Restrict *stmt* to appointments whose datetime lies in the given range.

        Both bounds are inclusive and optional (``None`` means unbounded).
        A plain ``date`` upper bound covers that whole calendar day: comparing
        the datetime column with ``<= end_date`` treats the date as midnight
        and would drop every appointment later on that day, so a half-open
        ``< end_date + 1 day`` is used instead. A ``datetime`` upper bound
        (``datetime`` subclasses ``date``) is applied as an exact instant.
        """
        column = Appointment.appointment_datetime
        if start_date is not None:
            stmt = stmt.filter(column >= start_date)
        if end_date is not None:
            if isinstance(end_date, datetime):
                stmt = stmt.filter(column <= end_date)
            else:
                stmt = stmt.filter(column < end_date + timedelta(days=1))
        return stmt

    async def get_all(
        self,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> Sequence[Appointment]:
        """Return all appointments, newest first, optionally limited to a date range.

        :param start_date: inclusive lower bound, or ``None`` for no bound.
        :param end_date: inclusive upper bound (whole day for a ``date``), or ``None``.
        """
        stmt = self._apply_date_range(self._base_query(), start_date, end_date)
        stmt = stmt.order_by(desc(Appointment.appointment_datetime))
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_by_doctor_id(
        self,
        doctor_id: int,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> Sequence[Appointment]:
        """Return one doctor's appointments, newest first, optionally within a date range.

        :param doctor_id: value matched against ``Appointment.doctor_id``.
        :param start_date: inclusive lower bound, or ``None`` for no bound.
        :param end_date: inclusive upper bound (whole day for a ``date``), or ``None``.
        """
        stmt = self._base_query().filter_by(doctor_id=doctor_id)
        stmt = self._apply_date_range(stmt, start_date, end_date)
        stmt = stmt.order_by(desc(Appointment.appointment_datetime))
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_upcoming_by_doctor_id(self, doctor_id: int, limit: int = 5) -> Sequence[Appointment]:
        """Return a doctor's next appointments, soonest first.

        "Upcoming" means ``appointment_datetime >= func.now()``, i.e. the
        current time as reported by the database, not by this process.

        :param doctor_id: value matched against ``Appointment.doctor_id``.
        :param limit: maximum number of rows returned (default 5).
        """
        stmt = (
            self._base_query()
            .filter_by(doctor_id=doctor_id)
            .filter(Appointment.appointment_datetime >= func.now())
            .order_by(Appointment.appointment_datetime)
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_by_patient_id(
        self,
        patient_id: int,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> Sequence[Appointment]:
        """Return one patient's appointments, newest first, optionally within a date range.

        :param patient_id: value matched against ``Appointment.patient_id``.
        :param start_date: inclusive lower bound, or ``None`` for no bound.
        :param end_date: inclusive upper bound (whole day for a ``date``), or ``None``.
        """
        stmt = self._base_query().filter_by(patient_id=patient_id)
        stmt = self._apply_date_range(stmt, start_date, end_date)
        stmt = stmt.order_by(desc(Appointment.appointment_datetime))
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def get_by_id(self, appointment_id: int) -> Optional[Appointment]:
        """Return the appointment with the given id, or ``None`` if none exists."""
        stmt = self._base_query().filter_by(id=appointment_id)
        result = await self.db.execute(stmt)
        return result.scalars().first()

    async def create(self, appointment: Appointment) -> Appointment:
        """Insert *appointment*, commit, and return it refreshed from the database."""
        self.db.add(appointment)
        await self.db.commit()
        await self.db.refresh(appointment)
        return appointment

    async def update(self, appointment: Appointment) -> Appointment:
        """Merge *appointment* into the session, commit, and return the persistent instance.

        ``merge`` returns the session-attached instance; when *appointment* is
        detached that is a different object from the argument, so the merged
        one is returned (refreshed after commit, as ``create`` does) rather
        than the caller's possibly-detached copy.
        """
        merged = await self.db.merge(appointment)
        await self.db.commit()
        await self.db.refresh(merged)
        return merged

    async def delete(self, appointment: Appointment) -> Appointment:
        """Delete *appointment*, commit, and return the (now deleted) instance."""
        await self.db.delete(appointment)
        await self.db.commit()
        return appointment