Files
Backend/db/repositories/ai_prediction_repository.py
T

119 lines
5.5 KiB
Python

from typing import List, Optional
from datetime import datetime, date
import logging
from db.connection import get_connection
from model.ai_prediction import AIPrediction
logger = logging.getLogger(__name__)
class AIPredictionsRepository:
def get_all(self) -> List[AIPrediction]:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("SELECT * FROM ai_predictions ORDER BY prediction_date DESC, product_id")
predictions = [
AIPrediction(
id=row[0],
product_id=row[1],
prediction_date=row[2],
days_until_stockout=row[3],
recommended_order=row[4],
confidence_score=row[5],
created_at=row[6]
) for row in cur.fetchall()
]
logger.debug(f"Получено {len(predictions)} прогнозов")
return predictions
except Exception as e:
logger.error(f"Ошибка получения прогнозов: {e}")
return []
def get_by_id(self, prediction_id: int) -> Optional[AIPrediction]:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("SELECT * FROM ai_predictions WHERE id = %s", (prediction_id,))
row = cur.fetchone()
if row:
logger.debug(f"Прогноз {prediction_id} найден")
return AIPrediction(*row)
logger.warning(f"Прогноз {prediction_id} не найден")
return None
except Exception as e:
logger.error(f"Ошибка получения прогноза {prediction_id}: {e}")
return None
def get_by_product(self, product_id: str, limit: int = 10) -> List[AIPrediction]:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("""
SELECT * FROM ai_predictions
WHERE product_id = %s
ORDER BY prediction_date DESC
LIMIT %s
""", (product_id, limit))
predictions = [AIPrediction(*row) for row in cur.fetchall()]
logger.debug(f"Получено {len(predictions)} прогнозов для товара {product_id}")
return predictions
except Exception as e:
logger.error(f"Ошибка получения прогноза по товару {product_id}: {e}")
return []
def get_latest_predictions(self) -> List[AIPrediction]:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("""
SELECT DISTINCT ON (product_id) *
FROM ai_predictions
ORDER BY product_id, prediction_date DESC
""")
predictions = [AIPrediction(*row) for row in cur.fetchall()]
logger.debug(f"Получено {len(predictions)} последних прогнозов по товарам")
return predictions
except Exception as e:
logger.error(f"Ошибка получения последних прогнозов по товарам: {e}")
return []
def create_prediction(self,
product_id: str,
prediction_date: date,
days_until_stockout: int,
recommended_order: int,
confidence_score: float) -> Optional[int]:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("""
INSERT INTO ai_predictions
(product_id, prediction_date, days_until_stockout, recommended_order, confidence_score)
VALUES (%s, %s, %s, %s, %s)
RETURNING id
""", (product_id, prediction_date, days_until_stockout, recommended_order, confidence_score))
prediction_id = cur.fetchone()[0]
conn.commit()
logger.debug(f"Создан новый прогноз ID: {prediction_id} для товара {product_id}")
return prediction_id
except Exception as e:
logger.error(f"Ошибка создания прогноза: {e}")
return None
def delete_old_predictions(self, older_than_days: int = 90) -> int:
try:
with get_connection() as conn:
with conn.cursor() as cur:
cur.execute("""
DELETE FROM ai_predictions
WHERE prediction_date < CURRENT_DATE - INTERVAL '%s days'
""", (older_than_days,))
deleted_count = cur.rowcount
conn.commit()
logger.debug(f"Удалено {deleted_count} старых прогнозов старше {older_than_days} дней")
return deleted_count
except Exception as e:
logger.error(f"Ошибка удаления старых прогнозов: {e}")
return 0