# Source code for locust.contrib.qdrant

from locust import User, events

import time
from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams


class QdrantLocustClient:
    """Qdrant client wrapper that times each data operation.

    The data operations (``upsert``, ``search``, ``scroll``, ``delete``)
    never raise on a failed call; they return a result dict instead so the
    caller can report it.  Every result dict contains:

    * ``"success"`` -- ``False`` if the call raised or, for ``search`` and
      ``scroll``, if no points were returned.
    * ``"response_time"`` -- elapsed time in milliseconds, measured with
      ``time.perf_counter()``.

    plus ``"result"`` (and ``"empty"`` / ``"next_offset"`` where applicable)
    when the call completed, or ``"exception"`` when it raised.

    Parameters
    ----------
    url : str
        Qdrant server URL.
    collection_name : str
        Collection that every operation targets.
    api_key : str, optional
        API key forwarded to ``QdrantClient``.
    timeout : int
        Timeout forwarded to ``QdrantClient``.
    **kwargs
        Extra keyword arguments forwarded to ``QdrantClient``.
    """

    def __init__(self, url, collection_name, api_key=None, timeout=60, **kwargs):
        self.url = url
        self.collection_name = collection_name
        self.api_key = api_key
        self.timeout = timeout

        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
            timeout=self.timeout,
            **kwargs,
        )

    def close(self):
        """Close the underlying ``QdrantClient``."""
        self.client.close()

    def create_collection(self, vectors_config, **kwargs):
        """Create the collection if it does not already exist.

        This call is not timed. ``kwargs`` are forwarded to
        ``QdrantClient.create_collection``.
        """
        if not self.client.collection_exists(collection_name=self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=vectors_config,
                **kwargs,
            )

    @staticmethod
    def _elapsed_ms(start):
        """Return milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
        return (time.perf_counter() - start) * 1000

    @classmethod
    def _failure(cls, start, exc):
        """Build the result dict for an operation that raised ``exc``."""
        return {
            "success": False,
            "response_time": cls._elapsed_ms(start),
            "exception": exc,
        }

    def upsert(self, points, **kwargs):
        """Upsert ``points`` into the collection.

        Extra ``kwargs`` are forwarded verbatim to ``QdrantClient.upsert``.
        """
        start = time.perf_counter()
        try:
            result = self.client.upsert(
                collection_name=self.collection_name,
                points=points,
                **kwargs,
            )
        except Exception as e:  # reported through the result dict, not raised
            return self._failure(start, e)
        return {"success": True, "response_time": self._elapsed_ms(start), "result": result}

    def search(
        self,
        query,
        limit=10,
        query_filter=None,
        search_params=None,
        with_payload=True,
        **kwargs,
    ):
        """Query the collection via ``QdrantClient.query_points``.

        An empty ``result.points`` is reported with ``"success": False`` and
        ``"empty": True``. Extra ``kwargs`` are forwarded verbatim.
        """
        start = time.perf_counter()
        try:
            result = self.client.query_points(
                collection_name=self.collection_name,
                query=query,
                limit=limit,
                query_filter=query_filter,
                search_params=search_params,
                with_payload=with_payload,
                **kwargs,
            )
            elapsed = self._elapsed_ms(start)
            # Inside the try: a response without ``points`` counts as a failure.
            empty = len(result.points) == 0
        except Exception as e:
            return self._failure(start, e)
        return {
            "success": not empty,
            "response_time": elapsed,
            "empty": empty,
            "result": result,
        }

    def scroll(
        self,
        scroll_filter=None,
        limit=10,
        with_payload=True,
        **kwargs,
    ):
        """Scroll through the collection via ``QdrantClient.scroll``.

        The returned dict also carries ``"next_offset"``. Pass it back as
        ``offset=...`` (forwarded through ``kwargs``) to page further.
        An empty page is reported with ``"success": False``.
        """
        start = time.perf_counter()
        try:
            result, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=limit,
                with_payload=with_payload,
                **kwargs,
            )
            elapsed = self._elapsed_ms(start)
            empty = len(result) == 0
        except Exception as e:
            return self._failure(start, e)
        return {
            "success": not empty,
            "response_time": elapsed,
            "empty": empty,
            "result": result,
            "next_offset": next_offset,
        }

    def delete(self, points_selector, **kwargs):
        """Delete the points matched by ``points_selector``.

        Extra ``kwargs`` are forwarded verbatim to ``QdrantClient.delete``.
        """
        start = time.perf_counter()
        try:
            result = self.client.delete(
                collection_name=self.collection_name,
                points_selector=points_selector,
                **kwargs,
            )
        except Exception as e:
            return self._failure(start, e)
        return {"success": True, "response_time": self._elapsed_ms(start), "result": result}


# ----------------------------------
# Locust User wrapper
# ----------------------------------


class QdrantUser(User):
    """Locust User implementation for Qdrant operations.

    This class wraps the QdrantLocustClient implementation and translates
    client method results into Locust request events so that performance
    statistics are collected properly.

    Configure it by overriding class attributes in a subclass.

    Attributes
    ----------
    url : str
        Qdrant server URL, e.g. ``"http://localhost:6333"``.
    api_key : str | None
        Optional API key for the Qdrant server.
    collection_name : str
        The name of the collection to operate on (required).
    timeout : int
        Request timeout forwarded to the client.
    vectors_config : VectorParams | None
        If set, the collection is created with this config when the user
        is instantiated, unless it already exists.
    client_kwargs : dict | None
        Additional keyword arguments forwarded to the client.
    collection_kwargs : dict | None
        Additional keyword arguments forwarded to ``create_collection``.
    """

    abstract = True

    url: str = "http://localhost:6333"
    api_key: str | None = None
    collection_name: str | None = None
    timeout: int = 60
    vectors_config: VectorParams | None = None
    client_kwargs: dict | None = None
    collection_kwargs: dict | None = None

    def __init__(self, environment):
        super().__init__(environment)

        if self.collection_name is None:
            raise ValueError("'collection_name' must be provided for QdrantUser")

        self.client_type = "qdrant"
        self.client = QdrantLocustClient(
            url=self.url,
            api_key=self.api_key,
            collection_name=self.collection_name,
            timeout=self.timeout,
            **(self.client_kwargs or {}),
        )
        if self.vectors_config is not None:
            self.client.create_collection(
                vectors_config=self.vectors_config,
                **(self.collection_kwargs or {}),
            )

    @staticmethod
    def _fire_event(request_type: str, name: str, result: dict[str, Any]):
        """Emit a Locust request event from a Qdrant client result dict.

        A result with ``"success": False`` is reported as a failure even when
        it carries no exception. The client marks empty search and scroll
        results this way, and they would otherwise be counted as successes.
        """
        exception = result.get("exception")
        if exception is None and not result.get("success", True):
            exception = RuntimeError(
                "empty result" if result.get("empty") else "request unsuccessful"
            )
        events.request.fire(
            request_type=request_type,
            name=name,
            # Not truncated to int, so sub-millisecond timings are preserved.
            response_time=result.get("response_time", 0),
            response_length=0,
            exception=exception,
        )

    def upsert(self, points):
        """Upsert ``points`` and report the timing to Locust."""
        result = self.client.upsert(points)
        self._fire_event(self.client_type, "upsert", result)
        return result

    def search(
        self,
        query,
        limit=10,
        query_filter=None,
        search_params=None,
        with_payload=True,
    ):
        """Run a query against the collection and report the timing to Locust."""
        result = self.client.search(
            query=query,
            limit=limit,
            query_filter=query_filter,
            search_params=search_params,
            with_payload=with_payload,
        )
        self._fire_event(self.client_type, "search", result)
        return result

    def scroll(
        self,
        scroll_filter=None,
        limit=10,
        with_payload=True,
    ):
        """Scroll through the collection and report the timing to Locust."""
        result = self.client.scroll(
            scroll_filter=scroll_filter,
            limit=limit,
            with_payload=with_payload,
        )
        self._fire_event(self.client_type, "scroll", result)
        return result

    def delete(self, points_selector):
        """Delete the selected points and report the timing to Locust."""
        result = self.client.delete(points_selector)
        self._fire_event(self.client_type, "delete", result)
        return result

    def on_stop(self):
        """Close the Qdrant client when the user stops."""
        self.client.close()