"""Base classes for document sources.
This module provides base classes for different types of document sources.
Sources represent the location/type of documents, while loaders handle the actual loading.
"""
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from haive.core.common.mixins.secure_config import SecureConfigMixin
[docs]
class BaseSource(BaseModel, ABC):
"""Abstract base class for all document sources."""
# Source identification
source_type: str | None = Field(None, description="Type identifier")
source_path: str | None = Field(None, description="Path or URL to source")
# Metadata
description: str | None = Field(None, description="Source description")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
[docs]
@abstractmethod
def validate_source(self) -> bool:
"""Validate that the source is accessible/valid."""
[docs]
@abstractmethod
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs to pass to the loader."""
[docs]
class LocalSource(BaseSource):
"""Base class for local file sources."""
file_path: str = Field(..., description="Path to local file")
encoding: str = Field("utf-8", description="File encoding")
[docs]
def validate_source(self) -> bool:
"""Check if file exists."""
from pathlib import Path
return Path(self.file_path).exists()
[docs]
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs for local file loaders."""
return {
"file_path": self.file_path,
"encoding": self.encoding,
}
[docs]
class DirectorySource(LocalSource):
"""Source for directory of files."""
file_path: str | None = Field(None, description="Not used for directories")
directory_path: str = Field(..., description="Path to directory")
glob_pattern: str = Field("**/*", description="File glob pattern")
recursive: bool = Field(True, description="Recursive search")
exclude_patterns: list[str] = Field(
default_factory=list, description="Patterns to exclude"
)
[docs]
def validate_source(self) -> bool:
"""Check if directory exists."""
from pathlib import Path
return Path(self.directory_path).exists()
[docs]
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs for directory loaders."""
return {
"path": self.directory_path,
"glob": self.glob_pattern,
"recursive": self.recursive,
"exclude": self.exclude_patterns,
}
[docs]
class RemoteSource(BaseSource, SecureConfigMixin):
"""Base class for remote sources with credential support."""
url: str = Field(..., description="Remote URL")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers")
# For SecureConfigMixin
provider: str = Field("generic", description="Provider name for credentials")
api_key: SecretStr | None = Field(None, description="API key if required")
[docs]
def validate_source(self) -> bool:
"""Validate URL format."""
from urllib.parse import urlparse
try:
result = urlparse(self.url)
return all([result.scheme, result.netloc])
except Exception:
return False
[docs]
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs for remote loaders."""
kwargs = {
"url": self.url,
"headers": self.headers,
}
# Add API key if available
api_key = self.get_api_key()
if api_key:
kwargs["api_key"] = api_key
return kwargs
[docs]
class DatabaseSource(BaseSource, SecureConfigMixin):
"""Base class for database sources."""
connection_string: str = Field(..., description="Database connection string")
query: str | None = Field(None, description="Query to execute")
table_name: str | None = Field(None, description="Table to load from")
# For SecureConfigMixin
provider: str = Field("database", description="Database provider")
[docs]
def validate_source(self) -> bool:
"""Basic validation of connection string."""
return bool(self.connection_string)
[docs]
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs for database loaders."""
kwargs = {
"connection_string": self.connection_string,
}
if self.query:
kwargs["query"] = self.query
if self.table_name:
kwargs["table_name"] = self.table_name
return kwargs
[docs]
class CloudSource(RemoteSource):
"""Base class for cloud storage sources."""
bucket_name: str = Field(..., description="Bucket/container name")
object_key: str | None = Field(None, description="Specific object key")
prefix: str | None = Field(None, description="Object prefix for listing")
# Override provider default
provider: str = Field("aws", description="Cloud provider")
[docs]
def get_loader_kwargs(self) -> dict[str, Any]:
"""Get kwargs for cloud storage loaders."""
kwargs = super().get_loader_kwargs()
kwargs.update(
{
"bucket": self.bucket_name,
}
)
if self.object_key:
kwargs["key"] = self.object_key
if self.prefix:
kwargs["prefix"] = self.prefix
return kwargs
# Concrete source implementations can be registered
__all__ = [
"BaseSource",
"CloudSource",
"DatabaseSource",
"DirectorySource",
"LocalSource",
"RemoteSource",
]