Source code for locust.contrib.milvus

import gevent.monkey

gevent.monkey.patch_all()
import grpc.experimental.gevent as grpc_gevent

grpc_gevent.init_gevent()

from locust import User, events

import time
from abc import ABC, abstractmethod
from typing import Any

from pymilvus import CollectionSchema, MilvusClient
from pymilvus.milvus_client import IndexParams


class BaseClient(ABC):
    @abstractmethod
    def close(self) -> None:
        pass

    @abstractmethod
    def create_collection(self, schema, index_params, **kwargs) -> None:
        pass

    @abstractmethod
    def insert(self, data) -> dict[str, Any]:
        pass

    @abstractmethod
    def upsert(self, data) -> dict[str, Any]:
        pass

    @abstractmethod
    def search(
        self,
        data,
        anns_field,
        limit,
        filter="",
        search_params=None,
        output_fields=None,
        calculate_recall=False,
        ground_truth=None,
    ) -> dict[str, Any]:
        pass

    @abstractmethod
    def hybrid_search(self, reqs, ranker, limit, output_fields=None) -> dict[str, Any]:
        pass

    @abstractmethod
    def query(self, filter, output_fields=None) -> dict[str, Any]:
        pass

    @abstractmethod
    def delete(self, filter) -> dict[str, Any]:
        pass


class MilvusV2Client(BaseClient):
    """Milvus v2 Python SDK Client Wrapper"""

    def __init__(self, uri, collection_name, token="root:Milvus", db_name="default", timeout=60):
        self.uri = uri
        self.collection_name = collection_name
        self.token = token
        self.db_name = db_name
        self.timeout = timeout

        # Initialize MilvusClient v2
        self.client = MilvusClient(
            uri=self.uri,
            token=self.token,
            db_name=self.db_name,
            timeout=self.timeout,
        )

    def close(self):
        self.client.close()

    def create_collection(self, schema, index_params, **kwargs):
        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            index_params=index_params,
            **kwargs,
        )

    def insert(self, data):
        start = time.time()
        try:
            result = self.client.insert(collection_name=self.collection_name, data=data)
            total_time = (time.time() - start) * 1000
            return {"success": True, "response_time": total_time, "result": result}
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }

    def upsert(self, data):
        start = time.time()
        try:
            result = self.client.upsert(collection_name=self.collection_name, data=data)
            total_time = (time.time() - start) * 1000
            return {"success": True, "response_time": total_time, "result": result}
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }

    def search(
        self,
        data,
        anns_field,
        limit,
        filter="",
        search_params=None,
        output_fields=None,
        calculate_recall=False,
        ground_truth=None,
    ):
        if output_fields is None:
            output_fields = ["id"]

        start = time.time()
        try:
            result = self.client.search(
                collection_name=self.collection_name,
                data=data,
                anns_field=anns_field,
                filter=filter,
                limit=limit,
                search_params=search_params,
                output_fields=output_fields,
            )
            total_time = (time.time() - start) * 1000
            empty = len(result) == 0 or all(len(r) == 0 for r in result)

            # Prepare base result
            search_result = {
                "success": not empty,
                "response_time": total_time,
                "empty": empty,
                "result": result,
            }

            # Calculate recall if requested
            if calculate_recall and ground_truth is not None and not empty:
                recall_value = self.get_recall(result, ground_truth, limit)
                search_result["recall"] = recall_value

            return search_result
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }

    def hybrid_search(self, reqs, ranker, limit, output_fields=None):
        if output_fields is None:
            output_fields = ["id"]

        start = time.time()
        try:
            result = self.client.hybrid_search(
                collection_name=self.collection_name,
                reqs=reqs,
                ranker=ranker,
                limit=limit,
                output_fields=output_fields,
                timeout=self.timeout,
            )
            total_time = (time.time() - start) * 1000
            empty = len(result) == 0 or all(len(r) == 0 for r in result)

            # Prepare base result
            search_result = {
                "success": not empty,
                "response_time": total_time,
                "empty": empty,
                "result": result,
            }

            return search_result
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }

    @staticmethod
    def get_recall(search_results, ground_truth, limit=None):
        """Calculate recall for V2 client search results."""
        try:
            # Extract IDs from V2 search results
            retrieved_ids = []
            if isinstance(search_results, list) and len(search_results) > 0:
                # search_results[0] contains the search results for the first query
                for hit in search_results[0] if isinstance(search_results[0], list) else search_results:
                    if isinstance(hit, dict) and "id" in hit:
                        retrieved_ids.append(hit["id"])
                    elif hasattr(hit, "get"):
                        retrieved_ids.append(hit.get("id"))

            # Apply limit if specified
            if limit is None:
                limit = len(retrieved_ids)

            if len(ground_truth) < limit:
                raise ValueError(f"Ground truth length is less than limit: {len(ground_truth)} < {limit}")

            # Calculate recall
            ground_truth_set = set(ground_truth[:limit])
            retrieved_set = set(retrieved_ids)
            intersect = len(ground_truth_set.intersection(retrieved_set))
            return intersect / len(ground_truth_set)

        except Exception:
            return 0.0

    def query(self, filter, output_fields=None):
        if output_fields is None:
            output_fields = ["id"]

        start = time.time()
        try:
            result = self.client.query(
                collection_name=self.collection_name,
                filter=filter,
                output_fields=output_fields,
            )
            total_time = (time.time() - start) * 1000
            empty = len(result) == 0
            return {
                "success": not empty,
                "response_time": total_time,
                "empty": empty,
                "result": result,
            }
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }

    def delete(self, filter):
        start = time.time()
        try:
            result = self.client.delete(collection_name=self.collection_name, filter=filter)
            total_time = (time.time() - start) * 1000
            return {"success": True, "response_time": total_time, "result": result}
        except Exception as e:
            return {
                "success": False,
                "response_time": (time.time() - start) * 1000,
                "exception": e,
            }


# ----------------------------------
# Locust User wrapper
# ----------------------------------


[docs] class MilvusUser(User): """Locust User implementation for Milvus operations. This class wraps the MilvusV2Client implementation and translates client method results into Locust request events so that performance statistics are collected properly. Parameters ---------- host : str Milvus server URI, e.g. ``"http://localhost:19530"``. collection_name : str The name of the collection to operate on. **client_kwargs Additional keyword arguments forwarded to the client. """ abstract = True def __init__( self, environment, uri: str = "http://localhost:19530", token: str = "root:Milvus", collection_name: str = "test_collection", db_name: str = "default", timeout: int = 60, schema: CollectionSchema | None = None, index_params: IndexParams | None = None, **kwargs, # enable_dynamic_field, num_shards, consistency_level etc. ref: https://milvus.io/api-reference/pymilvus/v2.6.x/MilvusClient/Collections/create_collection.md ): super().__init__(environment) if uri is None: raise ValueError("'uri' must be provided for MilvusUser") if collection_name is None: raise ValueError("'collection_name' must be provided for MilvusUser") self.client_type = "milvus" self.client = MilvusV2Client( uri=uri, token=token, collection_name=collection_name, db_name=db_name, timeout=timeout, ) if schema is not None: self.client.create_collection(schema=schema, index_params=index_params, **kwargs) @staticmethod def _fire_event(request_type: str, name: str, result: dict[str, Any]): """Emit a Locust request event from a Milvus client result dict.""" response_time = int(result.get("response_time", 0)) events.request.fire( request_type=f"{request_type}", name=name, response_time=response_time, response_length=0, exception=result.get("exception"), ) @staticmethod def _fire_recall_event(request_type: str, name: str, result: dict[str, Any]): """Emit a Locust request event for recall metric using recall value instead of response time.""" recall_value = result.get("recall", 0.0) # Use recall value as response_time for metric display (scaled by 100 for better visualization) percentage response_time_as_recall = int(recall_value * 100) events.request.fire( request_type=f"{request_type}", name=name, response_time=response_time_as_recall, response_length=result.get("retrieved_count", 0), exception=result.get("exception"), ) def insert(self, data): result = self.client.insert(data) self._fire_event(self.client_type, "insert", result) return result def upsert(self, data): result = self.client.upsert(data) self._fire_event(self.client_type, "upsert", result) return result def search( self, data, anns_field, limit, filter="", search_params=None, output_fields=None, calculate_recall=False, ground_truth=None, ): result = self.client.search( data, anns_field, limit, filter=filter, search_params=search_params, output_fields=output_fields, calculate_recall=calculate_recall, ground_truth=ground_truth, ) # Fire search event self._fire_event(self.client_type, "search", result) # Fire recall event if recall was calculated if calculate_recall and "recall" in result: self._fire_recall_event(self.client_type, "recall", result) return result def hybrid_search(self, reqs, ranker, limit, output_fields=None): result = self.client.hybrid_search(reqs, ranker, limit, output_fields) self._fire_event(self.client_type, "hybrid_search", result) return result def query(self, filter, output_fields=None): result = self.client.query( filter=filter, output_fields=output_fields, ) self._fire_event(self.client_type, "query", result) return result def delete(self, filter): result = self.client.delete(filter) self._fire_event(self.client_type, "delete", result) return result def on_stop(self): self.client.close()