
ChromaSQL Multi-Collection Examples

This document provides practical examples of using ChromaSQL with multi-collection setups. These patterns are useful when you have data partitioned across multiple ChromaDB collections.

Table of Contents

  1. Basic Multi-Collection Query
  2. Model-Based Routing (Most Common)
  3. Custom Routing Strategy
  4. Fallback Behavior
  5. Integration with AsyncMultiCollectionQueryClient
  6. Error Handling and Resilience
  7. OR Predicate Routing (Union Behavior)
  8. Performance Tips

Basic Multi-Collection Query

The simplest way to query multiple collections is to provide a static list:

import asyncio
from chromasql.multi_collection import execute_multi_collection, CollectionRouter
from chromasql.adapters import SimpleAsyncClientAdapter
import chromadb

# Simple router that queries specific collections
class StaticRouter(CollectionRouter):
    def __init__(self, collections):
        self.collections = collections

    def route(self, query):
        return self.collections

async def main():
    # Connect to ChromaDB
    client = await chromadb.AsyncHttpClient(host="localhost", port=8000)

    # Create adapter
    adapter = SimpleAsyncClientAdapter(
        client=client,
        collection_names=["collection_1", "collection_2", "collection_3"],
    )

    # Execute query across specific collections
    router = StaticRouter(["collection_1", "collection_2"])

    result = await execute_multi_collection(
        query_str="""
            SELECT id, distance, document
            FROM demo
            USING EMBEDDING (VECTOR [0.1, 0.2, 0.3])
            TOPK 10;
        """,
        router=router,
        collection_provider=adapter,
    )

    # Results are merged and sorted by distance
    for row in result.rows:
        print(f"{row['id']}: {row['distance']}")

asyncio.run(main())
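
The "merged and sorted by distance" step can be pictured with a small standalone sketch. It does not use ChromaSQL internals: `merge_by_distance` and the sample rows are hypothetical and only mimic the shape of `result.rows` shown above, assuming that a smaller distance means a closer match.

```python
import heapq
from typing import Dict, List


def merge_by_distance(per_collection: Dict[str, List[dict]], top_k: int) -> List[dict]:
    """Merge rows from several collections and keep the top_k smallest distances."""
    # Tag each row with the collection it came from (illustrative only).
    all_rows = [
        {**row, "collection": name}
        for name, rows in per_collection.items()
        for row in rows
    ]
    # nsmallest returns the rows already ordered from nearest to farthest.
    return heapq.nsmallest(top_k, all_rows, key=lambda row: row["distance"])


rows = merge_by_distance(
    {
        "collection_1": [{"id": "a", "distance": 0.12}, {"id": "b", "distance": 0.40}],
        "collection_2": [{"id": "c", "distance": 0.05}, {"id": "d", "distance": 0.31}],
    },
    top_k=3,
)
for row in rows:
    print(f"{row['id']}: {row['distance']} ({row['collection']})")
```

The point of the sketch is the ordering: the global top-k is chosen after all per-collection rows are pooled, so a collection with many close matches can contribute most of the final rows.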

Model-Based Routing (Most Common)

Route queries based on the metadata.model filter, using your existing query config:

import asyncio
from pathlib import Path
from chromasql.adapters import MetadataFieldRouter
from chromasql.multi_collection import execute_multi_collection
from idxr.query_lib.async_multi_collection_adapter import AsyncMultiCollectionAdapter
from idxr.vectorize_lib.query_client import AsyncMultiCollectionQueryClient
from idxr.vectorize_lib.query_config import load_query_config

async def query_by_model(query_str: str, embed_fn):
    """Execute a ChromaSQL query with model-based routing."""

    # Load query config (generated by vectorize_lib)
    config = load_query_config(Path("output/query_config.json"))

    # Initialize multi-collection client
    client = AsyncMultiCollectionQueryClient(
        config_path=Path("output/query_config.json"),
        client_type="cloud",
        cloud_api_key="your-api-key",
        cloud_tenant="your-tenant",
        cloud_database="your-database",
    )
    await client.connect()

    try:
        # Create adapters
        adapter = AsyncMultiCollectionAdapter(client)
        router = MetadataFieldRouter(
            query_config=config,
            field_path=("model",),  # Route based on metadata.model
            fallback_to_all=True,   # Query all 37 collections if not specified
        )

        # Execute query
        result = await execute_multi_collection(
            query_str=query_str,
            router=router,
            collection_provider=adapter,
            embed_fn=embed_fn,
        )

        return result

    finally:
        await client.close()

# Example usage
async def main():
    def my_embed(text, model):
        # Your embedding logic here
        return [0.1, 0.2, 0.3] * 128  # Example 384-dim vector (3 values x 128)

    # Query specific models - only queries collections containing Table/Field
    result = await query_by_model(
        query_str="""
            SELECT id, distance, metadata.model, document
            FROM sap_data
            WHERE metadata.model IN ('Table', 'Field')
              AND metadata.environment = 'production'
            USING EMBEDDING (TEXT 'SAP financial tables')
            TOPK 20;
        """,
        embed_fn=my_embed,
    )

    print(f"Found {len(result.rows)} results")
    for row in result.rows:
        print(f"  {row['id']} ({row['metadata.model']}): {row['distance']:.3f}")

    # Query without model filter - queries all 37 collections
    result_all = await query_by_model(
        query_str="""
            SELECT id, distance, document
            FROM sap_data
            WHERE metadata.environment = 'production'
            USING EMBEDDING (TEXT 'configuration settings')
            TOPK 10;
        """,
        embed_fn=my_embed,
    )

    print(f"\nQueried all collections, found {len(result_all.rows)} results")

asyncio.run(main())

Custom Routing Strategy

Implement custom routing logic for complex scenarios:

from typing import Optional, Sequence
from chromasql.multi_collection import CollectionRouter, execute_multi_collection
from chromasql.analysis import extract_metadata_values
from chromasql.ast import Query

class TenantAndRegionRouter(CollectionRouter):
    """Route based on both tenant and region metadata."""

    def __init__(self, collection_mapping: dict):
        """
        collection_mapping example:
        {
            ("tenant_123", "us-east"): ["shard_001", "shard_002"],
            ("tenant_123", "eu-west"): ["shard_003"],
            ("tenant_456", "us-east"): ["shard_004"],
        }
        """
        self.collection_mapping = collection_mapping

    def route(self, query: Query) -> Optional[Sequence[str]]:
        # Extract both discriminators
        tenants = extract_metadata_values(query, field_path=("tenant_id",))
        regions = extract_metadata_values(query, field_path=("region",))

        if not tenants or not regions:
            # If either is missing, query all collections
            return None

        # Find collections for all (tenant, region) pairs
        collections = set()
        for tenant in tenants:
            for region in regions:
                key = (tenant, region)
                if key in self.collection_mapping:
                    collections.update(self.collection_mapping[key])

        return sorted(collections) if collections else None


# Usage (assumes `adapter` and `embed_fn` are set up as in the earlier examples)
async def main():
    mapping = {
        ("tenant_123", "us-east"): ["shard_001", "shard_002"],
        ("tenant_123", "eu-west"): ["shard_003"],
        ("tenant_456", "us-east"): ["shard_004"],
    }

    router = TenantAndRegionRouter(mapping)

    # This will only query shard_001 and shard_002
    result = await execute_multi_collection(
        query_str="""
            SELECT id, distance
            FROM data
            WHERE metadata.tenant_id = 'tenant_123'
              AND metadata.region = 'us-east'
            USING EMBEDDING (TEXT 'search query')
            TOPK 10;
        """,
        router=router,
        collection_provider=adapter,
        embed_fn=embed_fn,
    )

Fallback Behavior

Control what happens when discriminator fields are not filtered:

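# Assumes `config`, `adapter`, and `embed_fn` from the Model-Based Routing example,
# and that the calls below run inside an async function.
from chromasql.adapters import MetadataFieldRouter
from chromasql.multi_collection import execute_multi_collection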

# Option 1: Fallback to all collections (recommended for most cases)
router_with_fallback = MetadataFieldRouter(
    query_config=config,
    field_path=("model",),
    fallback_to_all=True,  # Default behavior
)

# This query doesn't filter on metadata.model, so it queries all collections
result = await execute_multi_collection(
    query_str="""
        SELECT id FROM demo
        WHERE metadata.status = 'active'
        USING EMBEDDING (TEXT 'query')
        TOPK 5;
    """,
    router=router_with_fallback,
    collection_provider=adapter,
    embed_fn=embed_fn,
)
print(f"Queried all collections, found {len(result.rows)} results")


# Option 2: Require discriminator field (strict mode)
router_strict = MetadataFieldRouter(
    query_config=config,
    field_path=("model",),
    fallback_to_all=False,  # Raise error if not filtered
)

try:
    # This will raise ValueError because metadata.model is not filtered
    result = await execute_multi_collection(
        query_str="SELECT id FROM demo WHERE metadata.status = 'active';",
        router=router_strict,
        collection_provider=adapter,
    )
except ValueError as e:
    print(f"Error: {e}")
    # Output: "Query must filter on metadata.model (fallback_to_all is disabled)"
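
The two modes can be summarized in a standalone sketch. This is not the `MetadataFieldRouter` implementation: `route_with_fallback` and the model-to-collection mapping are hypothetical, and the sketch assumes (as in the custom router above) that returning `None` means "query every collection".

```python
from typing import Dict, List, Optional, Set


def route_with_fallback(
    filtered_models: Set[str],
    model_to_collections: Dict[str, List[str]],
    fallback_to_all: bool,
) -> Optional[List[str]]:
    """Pick target collections, or fall back / fail when no model is filtered."""
    if not filtered_models:
        if fallback_to_all:
            return None  # Assumed convention: None means "query all collections"
        raise ValueError(
            "Query must filter on metadata.model (fallback_to_all is disabled)"
        )
    targets = set()
    for model in filtered_models:
        # Models missing from the mapping contribute no collections.
        targets.update(model_to_collections.get(model, []))
    return sorted(targets)


mapping = {"Table": ["coll_1", "coll_2"], "Field": ["coll_3", "coll_4"]}

print(route_with_fallback({"Table"}, mapping, fallback_to_all=True))
print(route_with_fallback(set(), mapping, fallback_to_all=True))
try:
    route_with_fallback(set(), mapping, fallback_to_all=False)
except ValueError as exc:
    print(f"Error: {exc}")
```

Strict mode trades convenience for predictability: an unfiltered query fails fast instead of silently fanning out to every collection.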

Integration with AsyncMultiCollectionQueryClient

Seamlessly integrate ChromaSQL with your existing vectorize_lib infrastructure:

import asyncio
from pathlib import Path
from chromasql import parse
from chromasql.adapters import MetadataFieldRouter
from chromasql.multi_collection import execute_multi_collection
from idxr.query_lib.async_multi_collection_adapter import AsyncMultiCollectionAdapter
from idxr.vectorize_lib.query_client import AsyncMultiCollectionQueryClient
from idxr.vectorize_lib.query_config import load_query_config

class ChromaSQLQueryService:
    """Service that wraps ChromaSQL with your existing infrastructure."""

    def __init__(self, config_path: Path, **client_kwargs):
        self.config_path = config_path
        self.client_kwargs = client_kwargs
        self.client = None
        self.adapter = None
        self.router = None
        self.config = None

    async def connect(self):
        """Initialize connections."""
        # Load config
        self.config = load_query_config(self.config_path)

        # Initialize client
        self.client = AsyncMultiCollectionQueryClient(
            config_path=self.config_path,
            **self.client_kwargs,
        )
        await self.client.connect()

        # Create adapters
        self.adapter = AsyncMultiCollectionAdapter(self.client)
        self.router = MetadataFieldRouter(
            query_config=self.config,
            field_path=("model",),
            fallback_to_all=True,
        )

    async def close(self):
        """Close connections."""
        if self.client:
            await self.client.close()

    async def query(self, sql: str, embed_fn):
        """Execute a ChromaSQL query."""
        return await execute_multi_collection(
            query_str=sql,
            router=self.router,
            collection_provider=self.adapter,
            embed_fn=embed_fn,
        )

    def preview_routing(self, sql: str):
        """Preview which collections would be queried (without executing)."""
        query = parse(sql)
        collections = self.router.route(query)

        if collections is None:
            all_collections = sorted(self.config["collection_to_models"].keys())
            return {
                "mode": "all",
                "collections": all_collections,
                "count": len(all_collections),
            }
        else:
            return {
                "mode": "targeted",
                "collections": list(collections),
                "count": len(collections),
            }

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()


# Usage
async def main():
    async with ChromaSQLQueryService(
        config_path=Path("output/query_config.json"),
        client_type="cloud",
        cloud_api_key="your-api-key",
    ) as service:

        # Preview routing
        sql = """
            SELECT id, distance, document
            FROM sap_data
            WHERE metadata.model IN ('Table', 'Field')
            USING EMBEDDING (TEXT 'financial tables')
            TOPK 10;
        """

        routing_info = service.preview_routing(sql)
        print(f"Will query {routing_info['count']} collection(s):")
        print(f"  Mode: {routing_info['mode']}")
        print(f"  Collections: {routing_info['collections'][:5]}...")

        # Execute query
        def embed(text, model):
            # Your embedding logic
            return [0.1] * 384

        result = await service.query(sql, embed_fn=embed)
        print(f"\nFound {len(result.rows)} results")

asyncio.run(main())

Error Handling and Resilience

Handle partial collection failures gracefully:

from chromasql.errors import ChromaSQLExecutionError
from chromasql.multi_collection import execute_multi_collection

async def resilient_query(query_str: str, router, adapter, embed_fn):
    """Execute a query, warn on partial results, and surface total failures."""

    try:
        result = await execute_multi_collection(
            query_str=query_str,
            router=router,
            collection_provider=adapter,
            embed_fn=embed_fn,
        )

        # Check if we got partial results
        if result.raw.get("merged_from_collections"):
            total_collections = result.raw.get("total_collections_attempted", 0)
            successful = result.raw["merged_from_collections"]

            if successful < total_collections:
                print(f"Warning: Only {successful}/{total_collections} collections responded")

        return result

    except ChromaSQLExecutionError as e:
        if "All collection queries failed" in str(e):
            print("All collections failed - check your connection")
            # Fallback: try with a smaller subset
            # Or: alert monitoring system
            raise
        else:
            print(f"Query error: {e}")
            raise


# The multi-collection executor automatically handles partial failures:
# - If some collections fail, results from successful collections are returned
# - If ALL collections fail, ChromaSQLExecutionError is raised
# - Individual collection errors are logged but don't fail the entire query

# Assumes `router`, `adapter`, and `embed_fn` are set up as in the earlier examples
async def main():
    result = await resilient_query(
        query_str="SELECT id FROM demo USING EMBEDDING (TEXT 'test') TOPK 10;",
        router=router,
        adapter=adapter,
        embed_fn=embed_fn,
    )

    print(f"Successfully retrieved {len(result.rows)} results")
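
The partial-failure rules described in the comments above can be sketched with plain asyncio. This is an illustration under stated assumptions, not the executor's actual code: `query_all` and `fake_query` are hypothetical, and a generic `RuntimeError` stands in for `ChromaSQLExecutionError`.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List


async def query_all(
    collection_names: List[str],
    query_one: Callable[[str], Awaitable[List[dict]]],
) -> Dict[str, object]:
    """Query every collection concurrently and tolerate individual failures."""
    # return_exceptions=True keeps one failing collection from cancelling the rest.
    outcomes = await asyncio.gather(
        *(query_one(name) for name in collection_names),
        return_exceptions=True,
    )
    rows: List[dict] = []
    failed: List[str] = []
    for name, outcome in zip(collection_names, outcomes):
        if isinstance(outcome, BaseException):
            failed.append(name)  # A real executor would also log the error here
        else:
            rows.extend(outcome)
    if collection_names and len(failed) == len(collection_names):
        raise RuntimeError("All collection queries failed")
    rows.sort(key=lambda row: row["distance"])
    return {"rows": rows, "failed": failed}


async def fake_query(name: str) -> List[dict]:
    """Stand-in for a per-collection query; shard_002 always fails."""
    if name == "shard_002":
        raise ConnectionError("shard_002 is unreachable")
    return [{"id": f"{name}-doc", "distance": 0.1 if name == "shard_001" else 0.2}]


outcome = asyncio.run(query_all(["shard_001", "shard_002", "shard_003"], fake_query))
print(f"Failed collections: {outcome['failed']}")
print(f"Rows: {[row['id'] for row in outcome['rows']]}")
```

Collecting exceptions instead of raising on the first one is what allows results from healthy collections to be returned while the failures are reported separately.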

OR Predicate Routing (Union Behavior)

ChromaSQL uses union routing for OR predicates to ensure you never miss results:

# Query with OR predicate (router, adapter, and embed_fn as in the earlier examples)
result = await execute_multi_collection(
    query_str="""
        SELECT id, distance, document
        FROM sap_data
        WHERE metadata.model = 'Table' OR metadata.model = 'Field'
        USING EMBEDDING (TEXT 'financial data')
        TOPK 10;
    """,
    router=router,
    collection_provider=adapter,
    embed_fn=embed_fn,
)

# Router behavior:
# 1. Extracts {'Table', 'Field'} from OR branches
# 2. Maps to collections: Table → [coll_1, coll_2], Field → [coll_3, coll_4]
# 3. Queries UNION: [coll_1, coll_2, coll_3, coll_4]
# 4. Merges results and returns top 10 globally

Important OR Behaviors

✅ Union across multiple OR branches:

-- Queries union of all three models
WHERE metadata.model = 'Table'
   OR metadata.model = 'Field'
   OR metadata.model = 'View'

✅ Works with IN and OR combinations:

-- Extracts {'A', 'B', 'C', 'D'}
WHERE metadata.model IN ('A', 'B')
   OR metadata.model IN ('C', 'D')

✅ Mixed OR (discriminator + other fields):

-- Extracts {'Table'} only (has_sem not a routing field)
-- Still queries all collections that contain 'Table'
WHERE metadata.model = 'Table'
   OR metadata.has_sem = FALSE

⚠️ OR with no discriminator field:

-- No model values → falls back to ALL collections
WHERE metadata.status = 'active'
   OR metadata.has_sem = TRUE
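
The four cases above can be reproduced with a toy predicate tree. The tuple encoding and `collect_values` below are hypothetical and are not ChromaSQL's real AST or its `extract_metadata_values` function; the sketch only shows how collecting discriminator values across branches yields the sets listed in the comments.

```python
from typing import Set

# Hypothetical predicate encoding (not ChromaSQL's AST):
#   ("eq", field, value), ("in", field, [values]), ("or" | "and", left, right)


def collect_values(pred: tuple, field: str) -> Set[object]:
    """Collect every value compared against `field` anywhere in the predicate."""
    op = pred[0]
    if op in ("or", "and"):
        # Taking the union for AND as well may over-route, but never under-routes.
        return collect_values(pred[1], field) | collect_values(pred[2], field)
    if op == "eq":
        return {pred[2]} if pred[1] == field else set()
    if op == "in":
        return set(pred[2]) if pred[1] == field else set()
    raise ValueError(f"unknown predicate operator: {op!r}")


three_way = ("or", ("eq", "model", "Table"),
             ("or", ("eq", "model", "Field"), ("eq", "model", "View")))
in_and_or = ("or", ("in", "model", ["A", "B"]), ("in", "model", ["C", "D"]))
mixed = ("or", ("eq", "model", "Table"), ("eq", "has_sem", False))
no_discriminator = ("or", ("eq", "status", "active"), ("eq", "has_sem", True))

for label, pred in [
    ("three-way OR", three_way),
    ("IN with OR", in_and_or),
    ("mixed OR", mixed),
    ("no discriminator", no_discriminator),
]:
    values = collect_values(pred, "model")
    # An empty set corresponds to the "fall back to all collections" case.
    print(f"{label}: {sorted(values) if values else 'fallback to all'}")
```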

Why Union Routing?

Union routing prevents under-routing, that is, missing results because a collection wasn't queried:

# Without union routing (BAD):
# Query: WHERE model = 'Table' OR has_sem = FALSE
# Router sees: model = 'Table'
# Queries: Only collections with 'Table'
# Problem: Misses records where has_sem = FALSE but model != 'Table'

# With union routing (GOOD):
# Query: WHERE model = 'Table' OR has_sem = FALSE
# Router extracts: {'Table'}
# Queries: All collections containing 'Table'
# The full WHERE clause is still applied to each collection
# No results are missed!

The key insight: The router determines which collections to query, but each collection receives the full WHERE clause. So even if has_sem = FALSE isn't a routing field, records matching that condition will be found in any queried collection.


Performance Tips

  1. Use n_results_per_collection for better recall:

    result = await execute_multi_collection(
        query_str="SELECT id FROM demo USING EMBEDDING (TEXT 'test') TOPK 10;",
        router=router,
        collection_provider=adapter,
        embed_fn=embed_fn,
        n_results_per_collection=50,  # Fetch 50 from each, return top 10 overall
    )
    

  2. Filter aggressively in WHERE clause:

     • Include discriminator fields to reduce collections queried
     • Add other filters to reduce data transfer
     • Use AND with discriminator fields when possible

  3. Monitor routing decisions:

    query = parse(your_sql)
    collections = router.route(query)
    print(f"Querying {len(collections) if collections else 'all'} collection(s)")

  4. Use LIMIT judiciously:

     • LIMIT is applied after merging results from all collections
     • Each collection still returns n_results (or n_results_per_collection)
     • Consider using smaller TOPK instead of large LIMIT

  5. Understand OR performance implications:

     • OR predicates with discriminator fields query union of collections
     • OR with non-discriminator fields queries all collections
     • Use AND when possible to narrow collection scope

For more information, see:

  • CONTRIBUTING.md for architecture details
  • Build Your First Query for ChromaSQL syntax reference
  • chromasql/multi_collection.py for API documentation
