Source code for haive.core.engine.document.loaders.sources.registry

"""Source registry with decorator-based registration.

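This module provides a registry for document sources that maps:
- File extensions to source classes
- MIME types to source classes
- URL patterns to source classes
- Schemes to source classes
- Source classes to their associated loaders

A source may also supply a custom matcher function that is consulted
during detection.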

The registry enables automatic source detection and loader selection.
"""

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from haive.core.engine.document.config import LoaderPreference

if TYPE_CHECKING:
    from haive.core.engine.document.loaders.path_analyzer import (
        PathAnalysisResult,
        analyze_path,
    )
from haive.core.engine.document.loaders.sources.source_base import (
    BaseSource,
    DatabaseSource,
    DirectorySource,
    LocalSource,
    RemoteSource,
)

logger = logging.getLogger(__name__)


@dataclass
class LoaderMapping:
    """Describe one loader that can be used for a registered source.

    Attributes:
        name: Loader class name (by default looked up in
            ``langchain_community``).
        module: Dotted module path that provides the loader class.
        speed: Relative speed label: ``"fast"``, ``"medium"`` or ``"slow"``.
        quality: Relative quality label: ``"low"``, ``"medium"`` or ``"high"``.
        requires_packages: Extra packages the loader needs installed.
        requires_auth: Whether the loader needs credentials.
        best_for: Free-form tags describing when this loader is a good choice.
        conditions: Extra usage conditions, e.g. ``{"file_size": "<100MB"}``.
    """

    # Identity of the loader class.
    name: str
    module: str = "langchain_community.document_loaders"

    # Characteristics used when choosing between loaders.
    speed: str = "medium"
    quality: str = "medium"

    # Requirements. Mutable defaults go through default_factory so that
    # instances never share the same list.
    requires_packages: list[str] = field(default_factory=list)
    requires_auth: bool = False

    # Hints about when this loader is appropriate.
    best_for: list[str] = field(default_factory=list)
    conditions: dict[str, Any] = field(default_factory=dict)
@dataclass
class SourceRegistration:
    """Hold everything the registry records about one registered source.

    Attributes:
        name: Name the source is registered under.
        source_class: The ``BaseSource`` subclass to instantiate.
        file_extensions: File extensions this source handles.
        mime_types: MIME types this source handles.
        url_patterns: URL patterns this source handles.
        schemes: URL schemes this source handles.
        path_patterns: Path patterns associated with this source.
        loaders: Loader mappings keyed by a short loader name.
        default_loader: Key in ``loaders`` to use by default, if any.
        priority: Matching priority; a higher value is preferred.
        custom_matcher: Optional predicate that receives a path analysis
            result and returns whether this source matches.
    """

    # Identity.
    name: str
    source_class: type[BaseSource]

    # Pattern-matching criteria. Each defaults to a fresh empty set.
    file_extensions: set[str] = field(default_factory=set)
    mime_types: set[str] = field(default_factory=set)
    url_patterns: set[str] = field(default_factory=set)
    schemes: set[str] = field(default_factory=set)
    path_patterns: set[str] = field(default_factory=set)

    # Loaders associated with this source.
    loaders: dict[str, LoaderMapping] = field(default_factory=dict)
    default_loader: str | None = None

    # Matching priority (higher = preferred).
    priority: int = 0

    # Custom matcher function.
    custom_matcher: Callable[["PathAnalysisResult"], bool] | None = None
class SourceRegistry:
    """Registry for document sources and their loaders.

    Registrations are stored by name. Reverse indexes keyed by file
    extension, URL pattern, scheme and MIME type map each key to the names
    of the sources registered for it, so detection does not need to scan
    every registration for those criteria.
    """

    def __init__(self) -> None:
        """Create an empty registry with empty lookup indexes."""
        # Source name -> registration record.
        self._sources: dict[str, SourceRegistration] = {}

        # Reverse indexes: match key -> names of the sources registered for it.
        self._extension_index: dict[str, set[str]] = {}
        self._url_pattern_index: dict[str, set[str]] = {}
        self._scheme_index: dict[str, set[str]] = {}
        self._mime_index: dict[str, set[str]] = {}

    def register(
        self,
        name: str,
        source_class: type[BaseSource],
        file_extensions: list[str] | None = None,
        mime_types: list[str] | None = None,
        url_patterns: list[str] | None = None,
        schemes: list[str] | None = None,
        path_patterns: list[str] | None = None,
        loaders: dict[str, str | dict[str, Any]] | None = None,
        default_loader: str | None = None,
        priority: int = 0,
        custom_matcher: Callable[["PathAnalysisResult"], bool] | None = None,
    ) -> SourceRegistration:
        """Register a source with the registry.

        Registering a name that already exists replaces the stored
        registration. Index entries added by the earlier registration are
        not removed.

        Args:
            name: Name to register the source under.
            source_class: Source class to instantiate for matching paths.
            file_extensions: File extensions handled by the source.
            mime_types: MIME types handled by the source.
            url_patterns: URL patterns handled by the source.
            schemes: URL schemes handled by the source.
            path_patterns: Path patterns associated with the source.
            loaders: Loader definitions keyed by a short loader name. A
                string value is taken as the loader class name. A dict
                value may provide ``class``, ``module``, ``speed``,
                ``quality``, ``requires_packages``, ``requires_auth``,
                ``best_for`` and ``conditions``. Values of any other type
                are ignored.
            default_loader: Key in ``loaders`` to use by default.
            priority: Matching priority; higher values win.
            custom_matcher: Optional predicate over a path analysis result.

        Returns:
            The stored ``SourceRegistration``.
        """
        registration = SourceRegistration(
            name=name,
            source_class=source_class,
            file_extensions=set(file_extensions or []),
            mime_types=set(mime_types or []),
            url_patterns=set(url_patterns or []),
            schemes=set(schemes or []),
            path_patterns=set(path_patterns or []),
            default_loader=default_loader,
            priority=priority,
            custom_matcher=custom_matcher,
        )

        # Normalise both accepted loader spellings into LoaderMapping objects.
        for loader_key, loader_info in (loaders or {}).items():
            if isinstance(loader_info, str):
                # Simple form: the value is the loader class name.
                registration.loaders[loader_key] = LoaderMapping(name=loader_info)
            elif isinstance(loader_info, dict):
                # Detailed form: every missing key falls back to a default.
                registration.loaders[loader_key] = LoaderMapping(
                    name=loader_info.get("class", loader_key),
                    module=loader_info.get(
                        "module", "langchain_community.document_loaders"
                    ),
                    speed=loader_info.get("speed", "medium"),
                    quality=loader_info.get("quality", "medium"),
                    requires_packages=loader_info.get("requires_packages", []),
                    requires_auth=loader_info.get("requires_auth", False),
                    best_for=loader_info.get("best_for", []),
                    conditions=loader_info.get("conditions", {}),
                )

        self._sources[name] = registration
        self._update_indexes(name, registration)

        logger.info(
            f"Registered source '{name}' with {len(registration.loaders)} loaders, "
            f"{len(registration.file_extensions)} extensions"
        )

        return registration

    def _update_indexes(self, name: str, registration: SourceRegistration) -> None:
        """Add ``name`` to every reverse index that ``registration`` matches on."""
        for ext in registration.file_extensions:
            self._extension_index.setdefault(ext, set()).add(name)

        for pattern in registration.url_patterns:
            self._url_pattern_index.setdefault(pattern, set()).add(name)

        for scheme in registration.schemes:
            self._scheme_index.setdefault(scheme, set()).add(name)

        for mime in registration.mime_types:
            self._mime_index.setdefault(mime, set()).add(name)

    def find_source_for_path(
        self, path: str, analysis: "PathAnalysisResult | None" = None
    ) -> SourceRegistration | None:
        """Find the best source for a given path.

        Candidates are collected from the file-extension, URL-pattern,
        scheme and MIME indexes and from custom matchers. The candidate
        with the highest ``priority`` is returned. ``path_patterns`` are not
        consulted here.

        Args:
            path: Path or URL to find a source for.
            analysis: Pre-computed analysis of ``path``. When ``None``, the
                path is analyzed here.

        Returns:
            The highest-priority matching registration, or ``None`` when
            nothing matches.
        """
        if analysis is None:
            # Imported here to avoid a circular import.
            from haive.core.engine.document.loaders.path_analyzer import analyze_path

            analysis = analyze_path(path)

        candidates: list[SourceRegistration] = []

        # File extension.
        if analysis.file_extension:
            for source_name in self._extension_index.get(analysis.file_extension, ()):
                candidates.append(self._sources[source_name])

        # URL patterns: a pattern matches when it is a substring of the domain.
        if analysis.domain:
            for pattern, source_names in self._url_pattern_index.items():
                if pattern in analysis.domain:
                    for source_name in source_names:
                        candidates.append(self._sources[source_name])

        # Scheme.
        if analysis.url_components and analysis.url_components.get("scheme"):
            scheme = analysis.url_components["scheme"]
            for source_name in self._scheme_index.get(scheme, ()):
                candidates.append(self._sources[source_name])

        # MIME type.
        if analysis.mime_type:
            for source_name in self._mime_index.get(analysis.mime_type, ()):
                candidates.append(self._sources[source_name])

        # Custom matchers are not indexed, so every registration is checked.
        for registration in self._sources.values():
            if registration.custom_matcher and registration.custom_matcher(analysis):
                candidates.append(registration)

        if not candidates:
            return None
        return max(candidates, key=lambda reg: reg.priority)

    def create_source(
        self, path: str, source_type: str | None = None, **kwargs
    ) -> BaseSource | None:
        """Create a source instance for a path.

        Args:
            path: Path, URL or connection string the source points at.
            source_type: Registered source name to use. When missing or not
                registered, the source is auto-detected from ``path``.
            **kwargs: Extra constructor arguments; these override the
                arguments derived from ``path``.

        Returns:
            The new source instance, or ``None`` when no registration
            matches or constructing the source raises.
        """
        if source_type and source_type in self._sources:
            registration = self._sources[source_type]
        else:
            registration = self.find_source_for_path(path)

        if not registration:
            return None

        # The previous implementation called ``analyze_path(path)`` here and
        # discarded the result. That name is only imported under
        # TYPE_CHECKING, so the call raised NameError at runtime, which the
        # handler below swallowed -- every call returned None. The call is
        # removed because nothing used its result.
        try:
            source_kwargs = {
                "source_type": registration.name,
                "source_id": f"{registration.name}:{path}",
            }

            # The base class decides which constructor field receives the path.
            source_class = registration.source_class
            if issubclass(source_class, LocalSource):
                source_kwargs["file_path"] = path
            elif issubclass(source_class, DirectorySource):
                source_kwargs["directory_path"] = path
            elif issubclass(source_class, RemoteSource):
                source_kwargs["url"] = path
            elif issubclass(source_class, DatabaseSource):
                source_kwargs["connection_string"] = path

            # Caller-supplied arguments take precedence.
            source_kwargs.update(kwargs)

            return source_class(**source_kwargs)

        except Exception as e:
            # Best-effort: a failed construction is logged, not raised.
            logger.exception(f"Failed to create source for {path}: {e}")
            return None

    def get_loader_for_source(
        self,
        source: BaseSource,
        loader_name: str | None = None,
        preference: LoaderPreference = LoaderPreference.BALANCED,
    ) -> LoaderMapping | None:
        """Get the best loader for a source.

        Selection order: the explicitly requested ``loader_name``, the
        source's own ``preferred_loader``, a loader matching ``preference``
        (``SPEED`` picks one marked ``"fast"``, ``QUALITY`` picks one marked
        ``"high"``), the registration's ``default_loader``, and finally the
        first registered loader.

        Args:
            source: Source instance; its ``source_type`` selects the
                registration.
            loader_name: Key of a specific loader to use, if registered.
            preference: Speed/quality preference used when no specific
                loader applies.

        Returns:
            The chosen ``LoaderMapping``, or ``None`` when the source type
            is not registered or has no loaders.
        """
        registration = self._sources.get(source.source_type)
        if not registration or not registration.loaders:
            return None

        available = registration.loaders

        # Explicit request.
        if loader_name and loader_name in available:
            return available[loader_name]

        # The source's own preference.
        if source.preferred_loader and source.preferred_loader in available:
            return available[source.preferred_loader]

        # Preference-based choice is tried before the registration default.
        if preference == LoaderPreference.SPEED:
            fastest = next(
                (mapping for mapping in available.values() if mapping.speed == "fast"),
                None,
            )
            if fastest is not None:
                return fastest
        elif preference == LoaderPreference.QUALITY:
            best = next(
                (mapping for mapping in available.values() if mapping.quality == "high"),
                None,
            )
            if best is not None:
                return best

        # Registration default.
        if registration.default_loader and registration.default_loader in available:
            return available[registration.default_loader]

        # Fall back to the first registered loader.
        return next(iter(available.values()))

    def list_sources(self) -> list[str]:
        """List all registered source names."""
        return list(self._sources)

    def get_source_info(self, name: str) -> SourceRegistration | None:
        """Get registration info for a source, or ``None`` if not registered."""
        return self._sources.get(name)
# Global registry instance source_registry = SourceRegistry() # Import to ensure LocalSource is available
def register_source(
    name: str | None = None,
    file_extensions: list[str] | None = None,
    mime_types: list[str] | None = None,
    url_patterns: list[str] | None = None,
    schemes: list[str] | None = None,
    path_patterns: list[str] | None = None,
    loaders: dict[str, str | dict[str, Any]] | None = None,
    default_loader: str | None = None,
    priority: int = 0,
    custom_matcher: Callable[["PathAnalysisResult"], bool] | None = None,
) -> Callable[[type[BaseSource]], type[BaseSource]]:
    """Build a class decorator that registers a source with the global registry.

    All arguments are forwarded to ``source_registry.register``. When
    ``name`` is not given, the registered name is the lower-cased class name
    with every occurrence of ``"source"`` removed.

    The decorated class is returned with two attributes attached:
    ``_registry_name`` (the registered name) and ``_registration`` (the
    ``SourceRegistration`` returned by the registry).

    Examples:
        @register_source(
            name="pdf",
            file_extensions=[".pdf"],
            mime_types=["application/pdf"],
            loaders={
                "fast": "PyPDFLoader",
                "quality": {
                    "class": "UnstructuredPDFLoader",
                    "quality": "high",
                    "requires_packages": ["unstructured", "pdf2image"],
                },
                "ocr": {
                    "class": "PDFPlumberLoader",
                    "speed": "slow",
                    "quality": "high",
                    "best_for": ["tables", "complex_layouts"],
                }
            },
            default_loader="fast",
            priority=10
        )
        class PDFSource(LocalSource):
            '''Source for PDF documents.'''
            pass
    """
    # Everything except the name and class is passed through untouched.
    registration_options = {
        "file_extensions": file_extensions,
        "mime_types": mime_types,
        "url_patterns": url_patterns,
        "schemes": schemes,
        "path_patterns": path_patterns,
        "loaders": loaders,
        "default_loader": default_loader,
        "priority": priority,
        "custom_matcher": custom_matcher,
    }

    def decorator(source_class: type[BaseSource]) -> type[BaseSource]:
        """Register ``source_class`` and return it with registry info attached.

        Args:
            source_class: The source class being decorated.

        Returns:
            The same class object, with ``_registry_name`` and
            ``_registration`` set.
        """
        if name:
            source_name = name
        else:
            # Derive the name from the class name when none was provided.
            source_name = source_class.__name__.lower().replace("source", "")

        registration = source_registry.register(
            name=source_name,
            source_class=source_class,
            **registration_options,
        )

        # Let the class find its own registration.
        source_class._registry_name = source_name
        source_class._registration = registration

        return source_class

    return decorator
# Names exported by ``from ... import *``.
__all__ = [
    "LoaderMapping",
    "SourceRegistration",
    "SourceRegistry",
    "register_source",
    "source_registry",
]