Source code for haive.core.engine.document.loaders.sources.implementation

"""Enhanced Source Implementation for Document Engine.

This module provides enhanced source type implementations adapted from the original
project_notes with proper integration into the Haive document engine framework.
"""

import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


[docs] class SourceType(str, Enum): """Enhanced source type classification.""" # File-based sources LOCAL_FILE = "local_file" LOCAL_DIRECTORY = "local_directory" # Web sources WEB_URL = "web_url" WEB_API = "web_api" # Database sources DATABASE = "database" # Cloud sources CLOUD_STORAGE = "cloud_storage" # Text sources TEXT_INPUT = "text_input"
[docs] class CredentialType(str, Enum): """Types of credentials supported.""" API_KEY = "api_key" OAUTH2 = "oauth2" USERNAME_PASSWORD = "username_password" SERVICE_ACCOUNT = "service_account" ACCESS_TOKEN = "access_token" CONNECTION_STRING = "connection_string"
[docs] class Credential(BaseModel): """Credential information for authenticated sources.""" credential_type: CredentialType value: str = Field(..., description="The credential value") metadata: dict[str, Any] = Field(default_factory=dict) class Config: extra = "allow"
[docs] class CredentialManager: """Manages credentials for various source types.""" def __init__(self) -> None: """Init . Returns: [TODO: Add return description] """ self._credentials: dict[str, Credential] = {} self._env_prefix = "HAIVE_CRED_"
[docs] def add_credential(self, source_id: str, credential: Credential) -> None: """Add a credential for a source.""" self._credentials[source_id] = credential
[docs] def get_credential(self, source_id: str) -> Credential | None: """Get credential for a source.""" # Try direct lookup first if source_id in self._credentials: return self._credentials[source_id] # Try environment variable env_key = f"{self._env_prefix}{source_id.upper()}" env_value = os.getenv(env_key) if env_value: return Credential(credential_type=CredentialType.API_KEY, value=env_value) return None
[docs] def has_credential(self, source_id: str) -> bool: """Check if credential exists for source.""" return self.get_credential(source_id) is not None
[docs] class EnhancedSource(BaseModel, ABC): """Enhanced base class for document sources.""" source_type: SourceType source_path: str = Field(..., description="Path or identifier for the source") metadata: dict[str, Any] = Field(default_factory=dict) credential_manager: CredentialManager | None = Field(default=None, exclude=True) class Config: arbitrary_types_allowed = True
[docs] @abstractmethod def can_handle(self, path: str) -> bool: """Check if this source can handle the given path."""
[docs] @abstractmethod def get_confidence_score(self, path: str) -> float: """Get confidence score (0.0-1.0) for handling this path."""
[docs] def requires_authentication(self) -> bool: """Check if this source requires authentication.""" return False
[docs] def get_credential_requirements(self) -> list[CredentialType]: """Get required credential types.""" return []
[docs] class LocalFileSource(EnhancedSource): """Source for local files.""" source_type: SourceType = Field(default=SourceType.LOCAL_FILE) file_extensions: list[str] = Field(default_factory=list)
[docs] def can_handle(self, path: str) -> bool: """Check if this is a local file.""" try: p = Path(path) return p.exists() and p.is_file() except (OSError, ValueError): return False
[docs] def get_confidence_score(self, path: str) -> float: """Get confidence score for local files.""" if not self.can_handle(path): return 0.0 p = Path(path) if self.file_extensions: if p.suffix.lower() in [ext.lower() for ext in self.file_extensions]: return 0.9 return 0.3 return 0.7
[docs] class LocalDirectorySource(EnhancedSource): """Source for local directories.""" source_type: SourceType = Field(default=SourceType.LOCAL_DIRECTORY) recursive: bool = Field(default=True) include_patterns: list[str] = Field(default_factory=list) exclude_patterns: list[str] = Field(default_factory=list)
[docs] def can_handle(self, path: str) -> bool: """Check if this is a local directory.""" try: p = Path(path) return p.exists() and p.is_dir() except (OSError, ValueError): return False
[docs] def get_confidence_score(self, path: str) -> float: """Get confidence score for local directories.""" if not self.can_handle(path): return 0.0 return 0.8
[docs] class WebUrlSource(EnhancedSource): """Source for web URLs.""" source_type: SourceType = Field(default=SourceType.WEB_URL) allowed_schemes: list[str] = Field(default=["http", "https"]) allowed_domains: list[str] = Field(default_factory=list)
[docs] def can_handle(self, path: str) -> bool: """Check if this is a valid web URL.""" try: parsed = urlparse(path) return parsed.scheme in self.allowed_schemes and bool(parsed.netloc) except Exception: return False
[docs] def get_confidence_score(self, path: str) -> float: """Get confidence score for web URLs.""" if not self.can_handle(path): return 0.0 parsed = urlparse(path) # Higher confidence for known domains if self.allowed_domains: for domain in self.allowed_domains: if domain in parsed.netloc: return 0.9 return 0.4 return 0.7
[docs] class DatabaseSource(EnhancedSource): """Source for database connections.""" source_type: SourceType = Field(default=SourceType.DATABASE) supported_schemes: list[str] = Field( default=["postgresql", "mysql", "sqlite", "mongodb"] )
[docs] def can_handle(self, path: str) -> bool: """Check if this is a database URI.""" try: parsed = urlparse(path) return parsed.scheme in self.supported_schemes except Exception: return False
[docs] def get_confidence_score(self, path: str) -> float: """Get confidence score for database URIs.""" if not self.can_handle(path): return 0.0 return 0.8
[docs] def requires_authentication(self) -> bool: """Database sources typically require authentication.""" return True
[docs] def get_credential_requirements(self) -> list[CredentialType]: """Database sources need connection credentials.""" return [CredentialType.USERNAME_PASSWORD, CredentialType.CONNECTION_STRING]
[docs] class CloudStorageSource(EnhancedSource): """Source for cloud storage.""" source_type: SourceType = Field(default=SourceType.CLOUD_STORAGE) supported_providers: list[str] = Field(default=["s3", "gcs", "azure", "dropbox"])
[docs] def can_handle(self, path: str) -> bool: """Check if this is a cloud storage path.""" try: parsed = urlparse(path) return any( provider in parsed.scheme for provider in self.supported_providers ) except Exception: return False
[docs] def get_confidence_score(self, path: str) -> float: """Get confidence score for cloud storage.""" if not self.can_handle(path): return 0.0 return 0.8
[docs] def requires_authentication(self) -> bool: """Cloud storage typically requires authentication.""" return True
[docs] def get_credential_requirements(self) -> list[CredentialType]: """Cloud storage needs API credentials.""" return [CredentialType.API_KEY, CredentialType.SERVICE_ACCOUNT]
[docs] class TextInputSource(EnhancedSource): """Source for direct text input.""" source_type: SourceType = Field(default=SourceType.TEXT_INPUT)
[docs] def can_handle(self, path: str) -> bool: """Text input can handle anything as fallback.""" return True
[docs] def get_confidence_score(self, path: str) -> float: """Low confidence - fallback option.""" # Only use as fallback if it looks like direct text if not any(char in path for char in ["/", "\\", ":", "."]): return 0.2 return 0.1
[docs] class SourceRegistry: """Registry for managing source types.""" def __init__(self) -> None: """Init . Returns: [TODO: Add return description] """ self._sources: list[EnhancedSource] = [] self._register_default_sources() def _register_default_sources(self): """Register default source types.""" self.register(LocalFileSource(source_path="")) self.register(LocalDirectorySource(source_path="")) self.register(WebUrlSource(source_path="")) self.register(DatabaseSource(source_path="")) self.register(CloudStorageSource(source_path="")) self.register(TextInputSource(source_path=""))
[docs] def register(self, source: EnhancedSource): """Register a new source type.""" self._sources.append(source)
[docs] def find_best_source(self, path: str) -> EnhancedSource | None: """Find the best source for a given path.""" candidates = [] for source in self._sources: if source.can_handle(path): confidence = source.get_confidence_score(path) candidates.append((source, confidence)) if not candidates: return None # Sort by confidence (highest first) candidates.sort(key=lambda x: x[1], reverse=True) best_source, _ = candidates[0] # Create a new instance with the specific path source_class = type(best_source) return source_class( source_path=path, metadata=best_source.metadata.copy(), credential_manager=best_source.credential_manager, )
[docs] def find_all_sources(self, path: str) -> list[tuple[EnhancedSource, float]]: """Find all sources that can handle a path with confidence scores.""" candidates = [] for source in self._sources: if source.can_handle(path): confidence = source.get_confidence_score(path) candidates.append((source, confidence)) # Sort by confidence (highest first) candidates.sort(key=lambda x: x[1], reverse=True) return candidates
# Global registry instance source_registry = SourceRegistry() # Export key components __all__ = [ "CloudStorageSource", "Credential", "CredentialManager", "CredentialType", "DatabaseSource", "EnhancedSource", "LocalDirectorySource", "LocalFileSource", "SourceRegistry", "SourceType", "TextInputSource", "WebUrlSource", "source_registry", ]