PyTorch в production — тяжёлый и медленный. Мы конвертируем модели в ONNX и раздаём через FastAPI:
# serving/app.py — production model serving
import io
import numpy as np
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime as ort
from PIL import Image
from prometheus_client import Counter, Histogram, generate_latest
from pydantic import BaseModel
import mlflow
import time
app = FastAPI(title="RetailVision People Counter API", version="2.3.0")
# Prometheus metrics: inference latency histogram plus request/detection counters.
INFERENCE_LATENCY = Histogram(
    'model_inference_seconds',
    'Inference latency',
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)
REQUESTS_TOTAL = Counter(
    'model_requests_total',
    'Total inference requests',
    ['model_name', 'status'],
)
PREDICTIONS_TOTAL = Counter('model_predictions_total', 'Detections count')
class ModelManager:
    """Loading and versioning of ONNX models served by the API.

    Models are pulled from the MLflow Model Registry, compiled into
    ONNX Runtime sessions, and cached in memory keyed by "name:version".
    """

    def __init__(self) -> None:
        # Ready-to-use inference sessions, keyed by "name:version".
        self.sessions: dict[str, ort.InferenceSession] = {}
        # Key used when callers do not request a specific model.
        self.active_model: str | None = None

    def load_model(self, model_name: str, version: str) -> str:
        """Download a model from the MLflow Model Registry and make it active.

        Args:
            model_name: Registered model name in MLflow.
            version: Version number or stage alias (e.g. "Production").

        Returns:
            The cache key ("name:version") under which the session is stored.
        """
        model_uri = f"models:/{model_name}/{version}"
        local_path = mlflow.artifacts.download_artifacts(model_uri)
        onnx_path = f"{local_path}/model.onnx"
        # CUDA is tried first; ONNX Runtime falls back to the CPU provider
        # automatically when no GPU is available.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 4
        session = ort.InferenceSession(onnx_path, sess_options, providers=providers)
        model_key = f"{model_name}:{version}"
        self.sessions[model_key] = session
        self.active_model = model_key
        return model_key

    def predict(self, image_array: np.ndarray, model_key: str | None = None,
                conf_threshold: float = 0.5) -> dict:
        """Run detection on a preprocessed image batch.

        Args:
            image_array: Float32 NCHW batch; the [0]-indexing below assumes
                batch size 1 — confirm against the caller's preprocessing.
            model_key: Explicit "name:version" key; defaults to the active model.
            conf_threshold: Minimum score for a detection to be kept
                (previously hard-coded to 0.5).

        Returns:
            Dict with "boxes" and "scores" lists and an integer "count".

        Raises:
            RuntimeError: If no model matching the key has been loaded.
        """
        key = model_key or self.active_model
        # Fail loudly instead of an opaque KeyError / NoneType lookup error.
        if key is None or key not in self.sessions:
            raise RuntimeError(f"Model {key!r} is not loaded")
        session = self.sessions[key]
        with INFERENCE_LATENCY.time():
            outputs = session.run(None, {"input": image_array})
        boxes, scores, labels = outputs
        # Keep only detections above the confidence threshold.
        mask = scores[0] > conf_threshold
        return {
            "boxes": boxes[0][mask].tolist(),
            "scores": scores[0][mask].tolist(),
            "count": int(mask.sum()),
        }
# Singleton model manager shared by all request handlers.
manager = ModelManager()


# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — confirm against the pinned FastAPI version.
@app.on_event("startup")
async def startup():
    # Load the Production-stage model before the app starts serving traffic;
    # a failure here aborts startup rather than serving 500s later.
    manager.load_model("people-counter", "Production")
class PredictionResponse(BaseModel):
    """Response schema for POST /api/v1/detect."""

    # Number of detections above the confidence threshold.
    count: int
    # One entry per detection: {"box": [...], "confidence": float}.
    detections: list[dict]
    # "name:version" key of the model that produced the result.
    model_version: str
    # End-to-end handler latency (preprocess + inference), in milliseconds.
    inference_ms: float
@app.post("/api/v1/detect", response_model=PredictionResponse)
async def detect_people(file: UploadFile):
    """Detect people on an uploaded image.

    Accepts any image/* upload, resizes it to the model's 640x640 input,
    and returns detections together with timing metadata.

    Raises:
        HTTPException 400: non-image upload or an undecodable image.
    """
    # content_type is None when the client omits the header — the original
    # .startswith call would raise AttributeError (HTTP 500) in that case.
    if not file.content_type or not file.content_type.startswith("image/"):
        REQUESTS_TOTAL.labels(model_name="people-counter", status="rejected").inc()
        raise HTTPException(400, "File must be an image")
    start = time.monotonic()
    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except Exception:
        # Corrupt/truncated uploads are a client error, not a server crash.
        REQUESTS_TOTAL.labels(model_name="people-counter", status="rejected").inc()
        raise HTTPException(400, "Could not decode image")
    # Preprocess: 640x640, float32 scaled to [0, 1], NCHW with batch dim.
    image = image.resize((640, 640))
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    img_array = np.transpose(img_array, (2, 0, 1))
    img_array = np.expand_dims(img_array, axis=0)
    try:
        result = manager.predict(img_array)
    except Exception:
        # Count failures too, so the success-rate metric is meaningful.
        REQUESTS_TOTAL.labels(model_name="people-counter", status="error").inc()
        raise
    elapsed_ms = (time.monotonic() - start) * 1000
    REQUESTS_TOTAL.labels(model_name="people-counter", status="success").inc()
    PREDICTIONS_TOTAL.inc(result["count"])
    return PredictionResponse(
        count=result["count"],
        detections=[
            {"box": box, "confidence": score}
            for box, score in zip(result["boxes"], result["scores"])
        ],
        model_version=manager.active_model,
        inference_ms=round(elapsed_ms, 2),
    )
@app.get("/metrics")
async def metrics():
    """Expose Prometheus metrics in the exposition text format."""
    from starlette.responses import Response
    # CONTENT_TYPE_LATEST carries the version and charset parameters the
    # Prometheus scraper expects; a bare "text/plain" omits them.
    from prometheus_client import CONTENT_TYPE_LATEST
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
Оставить комментарий