diff --git a/.dockerignore b/.dockerignore index 885316e..a495033 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,4 +3,7 @@ !README.md !pyproject.toml !handler.py -!service.py \ No newline at end of file +!service.py +!setup.py +!tests/fixtures/document.pdf +docketanalyzer_ocr/data/venv \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index b858df1..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,125 +0,0 @@ -name: CI - -on: - push: - branches: [ main, master ] - pull_request: - branches: [ main, master ] - -jobs: - lint-and-autofix: - if: success() - runs-on: ubuntu-latest - permissions: - contents: write - pull-requests: write - env: - RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }} - RUNPOD_OCR_ENDPOINT_ID: ${{ secrets.RUNPOD_OCR_ENDPOINT_ID }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} - S3_ENDPOINT_URL: ${{ secrets.S3_ENDPOINT_URL }} - steps: - - uses: actions/checkout@v3 - with: - ref: ${{ github.event_name == 'pull_request' && github.head_ref || '' }} - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - cache: 'pip' - - - name: Install linting dependencies - run: | - python -m pip install --upgrade pip - pip install ruff - - - name: Check linting - id: lint_check - continue-on-error: true - run: | - ruff format --check . - ruff check . - - - name: Auto-fix with Ruff - if: steps.lint_check.outcome == 'failure' && !contains(github.event.head_commit.message, 'Auto-fix linting issues') - run: | - # Run format and fix what can be fixed automatically - ruff format . - ruff check --fix . - - - name: Verify linting - if: steps.lint_check.outcome == 'failure' && !contains(github.event.head_commit.message, 'Auto-fix linting issues') - run: | - ruff format --check . - ruff check . - - - name: Commit changes - id: auto_commit - if: steps.lint_check.outcome == 'failure' && !contains(github.event.head_commit.message, 'Auto-fix linting issues') - uses: stefanzweifel/git-auto-commit-action@v4 - with: - commit_message: "Auto-fix linting issues" - commit_user_name: "GitHub Actions" - commit_user_email: "actions@github.com" - commit_author: "GitHub Actions " - skip_dirty_check: false - - - name: Add PR comment - if: steps.auto_commit.outputs.changes_detected == 'true' && github.event_name == 'pull_request' - uses: actions/github-script@v6 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: '✅ I automatically fixed some linting issues in this pull request. All tests have already passed.' 
- }) - - test: - runs-on: ubuntu-latest - env: - RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }} - RUNPOD_OCR_ENDPOINT_ID: ${{ secrets.RUNPOD_OCR_ENDPOINT_ID }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} - S3_ENDPOINT_URL: ${{ secrets.S3_ENDPOINT_URL }} - steps: - - uses: actions/checkout@v3 - with: - # For PRs, checkout the head ref; for pushes, use the default behavior - ref: ${{ github.event_name == 'pull_request' && github.head_ref || '' }} - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - cache: 'pip' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install . - pip install pytest-cov - - - name: Test with pytest - run: | - pytest --cov=docketanalyzer_ocr tests/ --cov-report=xml --cov-branch --junitxml=junit.xml -o junit_family=legacy - - - name: Upload test results to Codecov - if: ${{ !cancelled() }} - uses: codecov/test-results-action@v1 - with: - token: ${{ secrets.CODECOV_TOKEN }} - - - name: Upload coverage report - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - slug: docketanalyzer/ocr \ No newline at end of file diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml new file mode 100644 index 0000000..659f448 --- /dev/null +++ b/.github/workflows/code-format.yml @@ -0,0 +1,35 @@ +name: Code Format + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + format-check: + name: Format Check + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set Up + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Check formatting + run: | + ruff format --check . + + - name: Check linting + run: | + ruff check . diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..6259858 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,38 @@ +name: Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + run-tests: + runs-on: ubuntu-latest + env: + RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }} + RUNPOD_OCR_ENDPOINT_ID: ${{ secrets.RUNPOD_OCR_ENDPOINT_ID }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + AWS_S3_BUCKET_NAME: ${{ secrets.AWS_S3_BUCKET_NAME }} + AWS_S3_ENDPOINT_URL: ${{ secrets.AWS_S3_ENDPOINT_URL }} + + steps: + - uses: actions/checkout@v3 + + - name: Set Up + uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install + run: | + python -m pip install --upgrade pip + pip install '.[dev]' + + - name: Test with pytest + run: | + pytest diff --git a/.gitignore b/.gitignore index 843f092..0d58233 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ -archive -temp.py -junit.xml +temp +docketanalyzer_ocr/data/venv ### Flask ### instance/* diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 0000000..c20c60a --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,42 @@ +# Development + +## Install + +``` +pip install '.[dev]' +``` + +## Test + +``` +pytest -vv +``` + +## Format + +``` +ruff format . && ruff check --fix . 
+```
+
+## Build and Push to PyPI
+
+```
+python -m docketanalyzer_core dev build
+python -m docketanalyzer_core dev build --push
+```
+
+## Docker Container
+
+Build and run:
+
+```
+DOCKER_BUILDKIT=1 docker build -t docketanalyzer-ocr .
+docker run --gpus all -p 8000:8000 docketanalyzer-ocr
+```
+
+Push:
+
+```
+docker tag docketanalyzer-ocr nadahlberg/docketanalyzer-ocr:latest
+docker push nadahlberg/docketanalyzer-ocr:latest
+```
diff --git a/Dockerfile b/Dockerfile
index 864fe94..e94f9a5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,6 +3,7 @@ FROM runpod/pytorch:2.0.1-py3.10-cuda11.8.0-devel-ubuntu22.04
 WORKDIR /app
 
 ENV PYTHONUNBUFFERED=1
+ENV FORCE_GPU=1
 
 RUN apt-get update && \
     apt-get install -y libreoffice && \
@@ -14,8 +15,8 @@ RUN python -m pip install --upgrade pip
 
 COPY . .
 
-RUN pip install --no-cache-dir '.[gpu]'
+RUN pip install --no-cache-dir .
 
-RUN python docketanalyzer_ocr/setup/run.py
+RUN python setup.py
 
 CMD [ "uvicorn", "service:app", "--host", "0.0.0.0", "--port", "8000" ]
diff --git a/README.md b/README.md
index 6218737..f3398cc 100644
--- a/README.md
+++ b/README.md
@@ -2,16 +2,8 @@
 
 ## Installation
 
-Requires Python 3.10
-
 ```bash
-pip install git+https://github.com/docketanalyzer/ocr
-```
-
-To install with GPU support (much faster):
-
-```
-pip install 'git+https://github.com/docketanalyzer/ocr[gpu]'
+pip install 'docketanalyzer[ocr]'
 ```
 
 ## Local Usage
@@ -19,7 +11,7 @@ pip install 'git+https://github.com/docketanalyzer/ocr[gpu]'
 Process a document:
 
 ```python
-from docketanalyzer_ocr import pdf_document
+from docketanalyzer.ocr import pdf_document
 
 path = 'path/to/doc.pdf'
 doc = pdf_document(path) # the input can also be raw bytes
@@ -64,13 +56,10 @@ Save and load data:
 
 ```python
 # Saving a document
-with open('saved.json', 'w') as f:
-    f.write(json.dumps(doc.data))
+doc.save('doc.json')
 
 # Loading a document
-with open('saved.json', 'r') as f:
-    data = json.loads(f.read())
-doc = pdf_document(path, load=data)
+doc = pdf_document(path, load='doc.json')
 ```
 
 # Remote Usage
@@ -78,6 +67,7 @@ You can also serve this tool with Docker.
 
 ```
+*** add prebuilt container here
 docker build -t docketanalyzer-ocr .
 docker run --gpus all -p 8000:8000 docketanalyzer-ocr
 ```
 
@@ -95,33 +85,61 @@ for page in doc.stream():
 
 When using the remote service, if you want to avoid sending the file in a POST request, configure your S3 credentials. Your document will be temporarily pushed to your bucket to be retrieved by the service.
 
-Set the following in your environment (both for the client and service):
+To configure your S3 credentials run:
 
 ```
-AWS_ACCESS_KEY_ID=...
-AWS_SECRET_ACCESS_KEY=...
-S3_BUCKET_NAME=...
-S3_ENDPOINT_URL=...
+da configure s3
 ```
 
-Usage is identical. We default to using S3 if credentials are available. You can disable this by passing `s3=False` to `process` or `stream`.
+Or set the following in your env:
+
+```
+AWS_ACCESS_KEY_ID
+AWS_SECRET_ACCESS_KEY
+AWS_S3_BUCKET_NAME
+AWS_S3_ENDPOINT_URL
+```
+
+Usage is identical. We default to using S3 if credentials are available. You can control this explicitly by passing `use_s3=False` to `pdf_document`.
 
 # Serverless Support
 
-For serverless usage you can deploy this to RunPod. Just include a custom run command:
+For serverless usage you can deploy this to RunPod. To get set up:
+
+1. Create a serverless worker on RunPod using the docker container.
+
+```
+*** add prebuilt container here
+```
+
+2. Add the following custom run command.
 ```
 python -u handler.py
 ```
 
-On the client side, add the following variables to your env:
+3. Add your S3 credentials to the RunPod worker.
+
+```
+AWS_ACCESS_KEY_ID
+AWS_SECRET_ACCESS_KEY
+AWS_S3_BUCKET_NAME
+AWS_S3_ENDPOINT_URL
+```
+
+4. On your local machine, configure your RunPod key and the worker id.
+
+You can run:
 
 ```
-RUNPOD_API_KEY=...
-RUNPOD_OCR_ENDPOINT_ID=...
+da configure runpod
 ```
 
-Usage is otherwise identical, just use the remote flag.
+Or set the following in your env:
+```
+RUNPOD_API_KEY
+RUNPOD_OCR_ENDPOINT_ID
+```
 
-[![codecov](https://codecov.io/gh/docketanalyzer/ocr/graph/badge.svg?token=XRATNOME24)](https://codecov.io/gh/docketanalyzer/ocr)
+Usage is otherwise identical, just use `remote=True` with `pdf_document`
diff --git a/docketanalyzer_ocr/__init__.py b/docketanalyzer_ocr/__init__.py
index 284716b..277dd89 100644
--- a/docketanalyzer_ocr/__init__.py
+++ b/docketanalyzer_ocr/__init__.py
@@ -1,10 +1,12 @@
 from .document import PDFDocument, pdf_document
-from .utils import download_from_s3, load_pdf, upload_to_s3
+from .layout import predict_layout
+from .utils import load_pdf, page_needs_ocr, page_to_image
 
 __all__ = [
     "PDFDocument",
-    "pdf_document",
-    "upload_to_s3",
-    "download_from_s3",
     "load_pdf",
+    "page_needs_ocr",
+    "page_to_image",
+    "pdf_document",
+    "predict_layout",
 ]
diff --git a/docketanalyzer_ocr/models/doclayout_yolo_docstructbench_imgsz1280_2501.pt b/docketanalyzer_ocr/data/doclayout_yolo_docstructbench_imgsz1280_2501.pt
similarity index 100%
rename from docketanalyzer_ocr/models/doclayout_yolo_docstructbench_imgsz1280_2501.pt
rename to docketanalyzer_ocr/data/doclayout_yolo_docstructbench_imgsz1280_2501.pt
diff --git a/docketanalyzer_ocr/document.py b/docketanalyzer_ocr/document.py
index ed3f653..425dd36 100644
--- a/docketanalyzer_ocr/document.py
+++ b/docketanalyzer_ocr/document.py
@@ -1,152 +1,19 @@
 import json
+import tempfile
 import uuid
+from collections.abc import Generator, Iterator
 from pathlib import Path
-from typing import Generator, Iterator, Optional, Union
 
 import fitz
-import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
-from .remote import RemoteClient
-from .utils import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, delete_from_s3, upload_to_s3
-
-
-def page_to_image(page: fitz.Page, dpi: int = 200) -> np.ndarray:
-    """Converts a PDF page to a numpy image array.
-
-    This function renders a PDF page at the specified DPI and converts it to a numpy array.
-    If the resulting image would be too large, it falls back to a lower resolution.
-
-    Args:
-        page: The pymupdf Page object to convert.
-        dpi: The dots per inch resolution to render at. Defaults to 200.
-
-    Returns:
-        np.ndarray: The page as a numpy array in RGB format.
-    """
-    mat = fitz.Matrix(dpi / 72, dpi / 72)
-    pm = page.get_pixmap(matrix=mat, alpha=False)
-
-    if pm.width > 4500 or pm.height > 4500:
-        pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
-
-    img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
-    img = np.array(img)
-
-    return img
-
-
-def extract_native_text(page: fitz.Page, dpi: int) -> list[dict]:
-    """Extracts text content and bounding boxes from a PDF page using native PDF text.
-
-    This function extracts text directly from the PDF's internal structure rather than using OCR.
-
-    Args:
-        page: The pymupdf Page object to extract text from.
-        dpi: The resolution to use when scaling bounding boxes.
- - Returns: - list[dict]: A list of dictionaries, each containing: - - 'bbox': The bounding box coordinates [x1, y1, x2, y2] - - 'content': The text content of the line - """ - blocks = page.get_text("dict")["blocks"] - data = [] - for block in blocks: - if "lines" in block: - for line in block["lines"]: - content = "".join([span["text"] for span in line["spans"]]) - if content.strip(): - line["bbox"] = tuple([(dpi / 72) * x for x in line["bbox"]]) - data.append( - { - "bbox": line["bbox"], - "content": content, - } - ) - return data - - -def has_images(page: fitz.Page) -> bool: - """Checks if a page has images that are large enough to potentially contain text. - - Args: - page: The pymupdf Page object to check. - - Returns: - bool: True if the page contains images of a significant size, False otherwise. - """ - image_list = page.get_images(full=True) - - for _, img_info in enumerate(image_list): - xref = img_info[0] - base_image = page.parent.extract_image(xref) - - if base_image: - width = base_image["width"] - height = base_image["height"] - if width > 10 and height > 10: - return True - - return False - - -def has_text_annotations(page: fitz.Page) -> bool: - """Checks if a page has annotations that could contain text. - - Args: - page: The pymupdf Page object to check. - - Returns: - bool: True if the page has text-containing annotations, False otherwise. - """ - annots = page.annots() - - if annots: - for annot in annots: - annot_type = annot.type[1] - if annot_type in [fitz.PDF_ANNOT_FREE_TEXT, fitz.PDF_ANNOT_WIDGET]: - return True - - return False +from docketanalyzer_core import load_s3 - -def page_needs_ocr(page: fitz.Page) -> bool: - """Determines if a page needs OCR processing. - - This function checks various conditions to decide if OCR is needed: - - If the page has no text - - If the page has CID-encoded text (often indicates non-extractable text) - - If the page has text annotations - - If the page has images that might contain text - - If the page has many drawing paths (might be scanned text) - - Args: - page: The pymupdf Page object to check. - - Returns: - bool: True if the page needs OCR processing, False otherwise. - """ - page_text = page.get_text() - - if page_text.strip() == "": - return True - - if "(cid:" in page_text: - return True - - if has_text_annotations(page): - return True - - if has_images(page): - return True - - paths = page.get_drawings() - if len(paths) > 10: - return True - - return False +from .layout import boxes_overlap, predict_layout +from .ocr import extract_ocr_text +from .remote import RemoteClient +from .utils import extract_native_text, page_needs_ocr, page_to_image class DocumentComponent: @@ -180,7 +47,8 @@ def children(self) -> list["DocumentComponent"]: """Gets the child components of this component. Returns: - list[DocumentComponent]: A list of child components, or an empty list if no children exist. + list[DocumentComponent]: A list of child components, or an empty list + if no children exist. """ if self.child_attr is not None: return getattr(self, self.child_attr, []) @@ -234,8 +102,8 @@ def id(self) -> str: def clip( self, - bbox: Optional[tuple[float, float, float, float]] = None, - save: Optional[str] = None, + bbox: tuple[float, float, float, float] | None = None, + save: str | None = None, ): """Clips an image of this component from the parent page. @@ -266,8 +134,7 @@ def __iter__(self) -> Iterator["DocumentComponent"]: Yields: DocumentComponent: Each child component. 
""" - for child in self.children: - yield child + yield from self.children def __len__(self) -> int: """Gets the number of child components. @@ -349,7 +216,7 @@ def __init__( i: int, bbox: tuple[float, float, float, float], block_type: str = "text", - lines: list[dict] = [], + lines: list[dict] | None = None, ): """Initializes a new Block component. @@ -364,7 +231,12 @@ def __init__( self.i = i self.bbox = bbox self.block_type = block_type - self.lines = [Line(self, i, line["bbox"], line["content"]) for i, line in enumerate(lines)] + self.lines = [] + if lines is not None: + self.lines = [ + Line(self, i, line["bbox"], line["content"]) + for i, line in enumerate(lines) + ] @property def data(self) -> dict: @@ -399,7 +271,7 @@ class Page(DocumentComponent): child_attr = "blocks" text_join = "\n\n" - def __init__(self, doc: "PDFDocument", i: int, blocks: list[dict] = []): + def __init__(self, doc: "PDFDocument", i: int, blocks: list[dict] | None = None): """Initializes a new Page component. Args: @@ -413,7 +285,8 @@ def __init__(self, doc: "PDFDocument", i: int, blocks: list[dict] = []): self.img = None self.extracted_text = None self.needs_ocr = None - self.set_blocks(blocks) + if blocks is not None: + self.set_blocks(blocks) @property def fitz(self) -> fitz.Page: @@ -443,8 +316,8 @@ def set_blocks(self, blocks: list[dict]) -> None: def clip( self, - bbox: Optional[tuple[float, float, float, float]] = None, - save: Optional[str] = None, + bbox: tuple[float, float, float, float] | None = None, + save: str | None = None, ) -> Image.Image: """Clips an image from this page. @@ -482,8 +355,9 @@ def data(self) -> dict: class PDFDocument: """Represents a PDF document. - This class handles loading, processing, and extracting text from PDF documents. - It manages the document hierarchy (pages, blocks, lines) and handles OCR when needed. + This class handles loading, processing, and extracting text from PDF + documents. It manages the document hierarchy (pages, blocks, lines) + and handles OCR when needed. Attributes: doc: The underlying PyMuPDF document. @@ -495,12 +369,13 @@ class PDFDocument: def __init__( self, - file_or_path: Union[bytes, str, Path], - filename: Optional[str] = None, + file_or_path: bytes | str | Path, + filename: str | None = None, dpi: int = 200, + use_s3: bool = True, remote: bool = False, - api_key: Optional[str] = None, - endpoint_url: Optional[str] = None, + api_key: str | None = None, + endpoint_url: str | None = None, ): """Initializes a new PDFDocument. @@ -508,7 +383,10 @@ def __init__( file_or_path: The PDF file content as bytes, or a path to the PDF file. filename: Optional name for the PDF file. dpi: The resolution to use when rendering pages for OCR. Defaults to 200. - remote: Whether to use remote processing via RemoteClient. Defaults to False. + use_s3: Whether to upload the PDF to S3 for remote processing. + Defaults to True. + remote: Whether to use remote processing via RemoteClient. + Defaults to False. api_key: Optional API key for remote processing. endpoint_url: Optional full endpoint URL for remote processing. 
""" @@ -516,107 +394,85 @@ def __init__( self.doc = fitz.open("pdf", file_or_path) self.pdf_bytes = file_or_path self.pdf_path = None + self.filename = filename or "document.pdf" else: self.doc = fitz.open(file_or_path) - self.pdf_bytes = None - self.pdf_path = file_or_path - self.filename = filename or getattr(file_or_path, "name", "document.pdf") + self.pdf_bytes = self.doc.tobytes() + self.pdf_path = Path(file_or_path) + self.filename = filename or self.pdf_path.name self.dpi = dpi self.remote = remote self.pages = [Page(self, i) for i in range(len(self.doc))] - self._remote_client = None - self._s3_key = None - self._api_key = api_key - self._endpoint_url = endpoint_url - - @property - def remote_client(self) -> RemoteClient: - """Gets or creates the remote client. - - Returns: - RemoteClient: The remote client instance. - """ - if self._remote_client is None: - self._remote_client = RemoteClient(api_key=self._api_key, endpoint_url=self._endpoint_url) - return self._remote_client - - def _upload_to_s3(self) -> str: - """Uploads the PDF to S3 with a random filename under the 'ocr' folder. - - Returns: - str: The S3 key where the file was uploaded. - - Raises: - ValueError: If the upload to S3 fails. - """ - random_id = str(uuid.uuid4()) - s3_key = f"tmp/{random_id}_{self.filename}" - - if self.pdf_bytes is not None: - with Path(f"/tmp/{random_id}.pdf").open("wb") as f: - f.write(self.pdf_bytes) - temp_path = f.name - else: - temp_path = self.pdf_path - - # Upload to S3 - success = upload_to_s3(temp_path, s3_key, overwrite=True) - if not success: - raise ValueError(f"Failed to upload PDF to S3 at key: {s3_key}") - - # Clean up temporary file if we created one - if self.pdf_bytes is not None: - Path(temp_path).unlink(missing_ok=True) - - self._s3_key = s3_key - return s3_key - - def stream(self, batch_size: int = 1, s3: bool = True) -> Generator[Page, None, None]: + self.remote_client = RemoteClient(api_key=api_key, endpoint_url=endpoint_url) + self.use_s3 = use_s3 + self.s3 = load_s3() + self.s3_available = self.s3.status() + self.s3_key = ( + None if not self.s3_available else f"tmp/{uuid.uuid4()}_{self.filename}" + ) + + def stream(self, batch_size: int = 1) -> Generator[Page, None, None]: """Processes the document page by page and yields each processed page. - This is a generator that processes pages in batches and yields each page - after it has been processed. If remote=True, uses the RemoteClient for processing. + If remote=True, uses the RemoteClient for processing. Args: batch_size: Number of pages to process in each batch. Defaults to 1. - s3: Whether to upload the PDF to S3 for remote processing. Defaults to True. Yields: Page: Each processed page. 
""" if self.remote: + s3_key, file = None, None try: - s3_key, file = None, None - if s3 and AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY: - s3_key = self._upload_to_s3() + if self.use_s3 and self.s3_available: + if self.pdf_path is not None: + self.s3.upload(self.pdf_path, self.s3_key) + else: + with tempfile.NamedTemporaryFile() as f: + f.write(self.pdf_bytes) + self.s3.upload(f.name, self.s3_key) + s3_key = self.s3_key else: - file = self.pdf_bytes or Path(self.pdf_path).read_bytes() + file = self.pdf_bytes for result in self.remote_client( - file=file, s3_key=s3_key, filename=self.filename, batch_size=batch_size + file=file, + s3_key=s3_key, + filename=self.filename, + batch_size=batch_size, ): if "stream" in result: for stream_item in result["stream"]: - if "output" in stream_item and "page" in stream_item["output"]: + if ( + "output" in stream_item + and "page" in stream_item["output"] + ): page_data = stream_item["output"]["page"] page_idx = page_data.get("i", 0) if page_idx < len(self.pages): - self.pages[page_idx].set_blocks(page_data.get("blocks", [])) + self.pages[page_idx].set_blocks( + page_data.get("blocks", []) + ) yield self.pages[page_idx] status = result.get("status") if status in ["COMPLETED", "FAILED", "CANCELLED"]: break - if "stream" in result: + # delete maybe? + if (1 == -1) and "stream" in result: for stream_item in result["stream"]: - if stream_item.get("output", {}).get("status") in ["COMPLETED", "FAILED", "CANCELLED"]: + if stream_item.get("output", {}).get("status") in [ + "COMPLETED", + "FAILED", + "CANCELLED", + ]: break finally: - if self._s3_key: - delete_from_s3(self._s3_key) - self._s3_key = None + if s3_key is not None: + self.s3.delete(self.s3_key) else: for i in tqdm(range(0, len(self.doc), batch_size), desc="Processing PDF"): batch_pages = [] @@ -625,24 +481,22 @@ def stream(self, batch_size: int = 1, s3: bool = True) -> Generator[Page, None, for j in range(i, min(i + batch_size, len(self.doc))): page = self[j] # Get the image representation of the page - if not page.img: + if page.img is None: page.img = page_to_image(page.fitz, self.dpi) # Check if we need to OCR the page page.needs_ocr = page_needs_ocr(page.fitz) if page.needs_ocr: - from .ocr import extract_ocr_text - page.extracted_text = extract_ocr_text(page.img) else: - page.extracted_text = extract_native_text(page.fitz, dpi=self.dpi) + page.extracted_text = extract_native_text( + page.fitz, dpi=self.dpi + ) batch_pages.append(page) batch_imgs.append(page.img) - from .layout import boxes_overlap, predict_layout - layout_data = predict_layout(batch_imgs, batch_size=batch_size) for j, page_layout in enumerate(layout_data): @@ -655,25 +509,27 @@ def stream(self, batch_size: int = 1, s3: bool = True) -> Generator[Page, None, if boxes_overlap(block["bbox"], line["bbox"]): block_lines.append(li) block["lines"] = [lines[li] for li in block_lines] - lines = [line for li, line in enumerate(lines) if li not in block_lines] + lines = [ + line + for li, line in enumerate(lines) + if li not in block_lines + ] blocks.append(block) page.set_blocks(blocks) yield page - def process(self, batch_size: int = 1, s3: bool = True) -> "PDFDocument": + def process(self, batch_size: int = 1) -> "PDFDocument": """Processes the entire document at once. - This method processes all pages in the document and returns the document itself. - If remote=True, uses the RemoteClient for processing. + This just runs stream in a loop and returns the document when done. Args: batch_size: Number of pages to process in each batch. 
Defaults to 1. - s3: Whether to upload the PDF to S3 for remote processing. Defaults to True. Returns: PDFDocument: The processed document (self). """ - for _ in self.stream(batch_size=batch_size, s3=s3): + for _ in self.stream(batch_size=batch_size): pass return self @@ -689,7 +545,7 @@ def data(self) -> dict: "pages": [page.data for page in self.pages], } - def save(self, path: Union[str, Path]) -> None: + def save(self, path: str | Path) -> None: """Saves the document data to a JSON file. Args: @@ -697,7 +553,7 @@ def save(self, path: Union[str, Path]) -> None: """ Path(path).write_text(json.dumps(self.data)) - def load(self, path_or_data: Union[str, Path, dict]) -> "PDFDocument": + def load(self, path_or_data: str | Path | dict) -> "PDFDocument": """Loads document data from a JSON file or dictionary. Args: @@ -706,7 +562,7 @@ def load(self, path_or_data: Union[str, Path, dict]) -> "PDFDocument": Returns: PDFDocument: The loaded document (self). """ - if isinstance(path_or_data, (str, Path)): + if isinstance(path_or_data, str | Path): path = Path(path_or_data) data = json.loads(path.read_text()) else: @@ -740,8 +596,7 @@ def __iter__(self) -> Iterator[Page]: Yields: Page: Each page in the document. """ - for page in self.pages: - yield page + yield from self.pages def __len__(self) -> int: """Gets the number of pages in the document. @@ -753,13 +608,14 @@ def __len__(self) -> int: def pdf_document( - file_or_path: Union[bytes, str, Path], - filename: Optional[str] = None, + file_or_path: bytes | str | Path, + filename: str | None = None, dpi: int = 200, - load: Optional[Union[str, Path, dict]] = None, + use_s3: bool = True, remote: bool = False, - api_key: Optional[str] = None, - endpoint_url: Optional[str] = None, + api_key: str | None = None, + endpoint_url: str | None = None, + load: str | Path | dict | None = None, ) -> PDFDocument: """Processes a PDF file for text extraction. @@ -770,16 +626,25 @@ def pdf_document( file_or_path: The PDF file content as bytes, or a path to the PDF file. filename: Optional name for the PDF file. dpi: The resolution to use when rendering pages for OCR. Defaults to 200. - load: Optional path to a JSON file or dictionary with existing document data to load. + use_s3: Whether to upload the PDF to S3 for remote processing. Defaults + to True. remote: Whether to use remote processing via RemoteClient. Defaults to False. api_key: Optional API key for remote processing. endpoint_url: Optional full endpoint URL for remote processing. + load: Optional path to a JSON file or dictionary with existing document + data to load. Returns: PDFDocument: The created (and possibly processed) document. 
""" doc = PDFDocument( - file_or_path, filename=filename, dpi=dpi, remote=remote, api_key=api_key, endpoint_url=endpoint_url + file_or_path, + filename=filename, + dpi=dpi, + use_s3=use_s3, + remote=remote, + api_key=api_key, + endpoint_url=endpoint_url, ) if load is not None: diff --git a/docketanalyzer_ocr/layout.py b/docketanalyzer_ocr/layout.py index 61c9185..5892545 100644 --- a/docketanalyzer_ocr/layout.py +++ b/docketanalyzer_ocr/layout.py @@ -1,11 +1,7 @@ -import torch -from doclayout_yolo import YOLOv10 - from .utils import BASE_DIR LAYOUT_MODEL = None - LAYOUT_CHOICES = { 0: "title", 1: "text", @@ -37,62 +33,52 @@ def merge_overlapping_blocks(blocks: list[dict]) -> list[dict]: if not blocks: return [] - # Create a priority map for faster lookup - type_priority = {block_type: i for i, block_type in enumerate(LAYOUT_CHOICES.values())} + # Merged blocks with different types will get the type with the highest priority + type_priority = { + block_type: i for i, block_type in enumerate(LAYOUT_CHOICES.values()) + } - # Add default priority for any types not in the list (lowest priority) - max_priority = len(type_priority) - - # Start with all blocks as unprocessed unprocessed = [block.copy() for block in blocks] result = [] while unprocessed: - # Take a block as the current merged block current = unprocessed.pop(0) current_bbox = current["bbox"] - # Flag to check if any merge happened in this iteration merged = True while merged: merged = False - # Check each remaining unprocessed block i = 0 while i < len(unprocessed): other = unprocessed[i] other_bbox = other["bbox"] - # Check for overlap if boxes_overlap(current_bbox, other_bbox): - # Determine which type to keep based on priority - current_priority = type_priority.get(current["type"], max_priority) - other_priority = type_priority.get(other["type"], max_priority) + current_priority = type_priority[current["type"]] + other_priority = type_priority[other["type"]] - # Keep the type with higher priority (lower number) if other_priority < current_priority: current["type"] = other["type"] - # Merge the bounding boxes current_bbox = merge_boxes(current_bbox, other_bbox) current["bbox"] = current_bbox - # Remove the merged block from unprocessed unprocessed.pop(i) merged = True else: i += 1 - # Add the merged block to the result result.append(current) - # Sort by ymin and then xmin result.sort(key=lambda x: (x["bbox"][1], x["bbox"][0])) return result -def boxes_overlap(box1: tuple[float, float, float, float], box2: tuple[float, float, float, float]) -> bool: +def boxes_overlap( + box1: tuple[float, float, float, float], box2: tuple[float, float, float, float] +) -> bool: """Checks if two bounding boxes overlap. Args: @@ -102,17 +88,15 @@ def boxes_overlap(box1: tuple[float, float, float, float], box2: tuple[float, fl Returns: bool: True if the boxes overlap, False otherwise. """ - # Extract coordinates x1_min, y1_min, x1_max, y1_max = box1 x2_min, y2_min, x2_max, y2_max = box2 - # Check for overlap - if x1_max < x2_min or x2_max < x1_min: # No horizontal overlap - return False - if y1_max < y2_min or y2_max < y1_min: # No vertical overlap - return False - - return True + return not ( + x1_max < x2_min + or x2_max < x1_min # No horizontal overlap + or y1_max < y2_min + or y2_max < y1_min # No vertical overlap + ) def merge_boxes( @@ -125,13 +109,12 @@ def merge_boxes( box2: Tuple of (xmin, ymin, xmax, ymax) for the second box. Returns: - tuple[float, float, float, float]: A new bounding box that contains both input boxes. 
+ tuple[float, float, float, float]: A new bounding box that contains both + input boxes. """ - # Extract coordinates x1_min, y1_min, x1_max, y1_max = box1 x2_min, y2_min, x2_max, y2_max = box2 - # Create merged box with min/max coordinates merged_box = ( min(x1_min, x2_min), min(y1_min, y2_min), @@ -142,7 +125,7 @@ def merge_boxes( return merged_box -def load_model() -> tuple[YOLOv10, str]: +def load_model() -> tuple["YOLOv10", str]: # noqa: F821 """Loads and initializes the document layout detection model. Returns: @@ -150,12 +133,18 @@ def load_model() -> tuple[YOLOv10, str]: - The initialized YOLOv10 model - The device string ('cpu' or 'cuda') """ + import torch + from doclayout_yolo import YOLOv10 + global LAYOUT_MODEL device = "cpu" if not torch.cuda.is_available() else "cuda" if LAYOUT_MODEL is None: - LAYOUT_MODEL = YOLOv10(BASE_DIR / "models" / "doclayout_yolo_docstructbench_imgsz1280_2501.pt", verbose=False) + LAYOUT_MODEL = YOLOv10( + BASE_DIR / "data" / "doclayout_yolo_docstructbench_imgsz1280_2501.pt", + verbose=False, + ) LAYOUT_MODEL.to(device) return LAYOUT_MODEL, device @@ -187,6 +176,7 @@ def predict_layout(images: list, batch_size: int = 8) -> list[list[dict]]: for xyxy, cla in zip( pred.boxes.xyxy, pred.boxes.cls, + strict=False, ): bbox = [int(p.item()) for p in xyxy] blocks.append( diff --git a/docketanalyzer_ocr/ocr.py b/docketanalyzer_ocr/ocr.py index 9746e19..d936db6 100644 --- a/docketanalyzer_ocr/ocr.py +++ b/docketanalyzer_ocr/ocr.py @@ -1,43 +1,262 @@ -from typing import Any, Union +import atexit +import base64 +import json +import os +import platform +import shutil +import signal +import subprocess +import sys +import threading +from datetime import datetime +from pathlib import Path +from typing import Any -import torch -from paddleocr import PaddleOCR +import numpy as np +FORCE_GPU = int(os.getenv("FORCE_GPU", 0)) +SCRIPT_PATH = Path(__file__).resolve() +VENV_SCRIPT_PATH = ( + Path.home() / ".cache" / "docketanalyzer" / "ocr" / "venv" / SCRIPT_PATH.name +) OCR_MODEL = None -def load_model() -> tuple[PaddleOCR, str]: - """Loads and initializes the PaddleOCR model. +class OCRService: + """Service for running PaddleOCR in a separate process. - This function initializes the OCR model if it hasn't been loaded yet. - It determines whether to use CPU or CUDA based on availability. - - Returns: - tuple[PaddleOCR, str]: A tuple containing: - - The initialized PaddleOCR model - - The device string ('cpu' or 'cuda') + We do gymnastics to avoid the Python 3.10 requirement for the PaddleOCR package. """ - global OCR_MODEL - device = "cpu" if not torch.cuda.is_available() else "cuda" + def __init__(self, device: str | None = None): + """Initialize the OCR service. + + Args: + device: The device to use for OCR processing. Defaults to 'cuda' if + available, otherwise 'cpu'. + """ + import torch - if OCR_MODEL is None: - OCR_MODEL = PaddleOCR( - lang="en", - use_gpu=device == "cuda", - gpu_mem=5000, - precision="bf16", - show_log=False, + self.device = device or ( + "cpu" if not (torch.cuda.is_available() or FORCE_GPU) else "cuda" ) - return OCR_MODEL, device + def process_image(self, image: np.array) -> list[dict]: + """Extracts text from an image using OCR. + + This function processes an image with the PaddleOCR model to extract text + and bounding boxes for each detected text line. + + Args: + image: The input image. Can be a file path, bytes, or a numpy array. 
+ + Returns: + list[dict]: A list of dictionaries, each containing: + - 'bbox': The bounding box coordinates [x1, y1, x2, y2] + - 'content': The extracted text content + """ + global OCR_MODEL + if OCR_MODEL is None: + print("Loading OCR model...") + from paddleocr import PaddleOCR + + OCR_MODEL = PaddleOCR( + lang="en", + use_gpu=self.device == "cuda", + gpu_mem=5000, + precision="bf16", + show_log=True, + ) + print("OCR model loaded.") + + result = OCR_MODEL.ocr(image, cls=False) + data = [] + for idx in range(len(result)): + res = result[idx] + if res: + for line in res: + data.append( + { + "bbox": line[0][0] + line[0][2], + "content": line[1][0], + } + ) + return data + + def run(self): + """Main service loop listening for image input.""" + while True: + try: + input_line = sys.stdin.readline().strip() + if not input_line: + continue + + request = json.loads(input_line) + if "image" not in request: + continue + + print("Received image", flush=True) + image = np.frombuffer( + base64.b64decode(request["image"]), + dtype=request["dtype"], + ).reshape(request["shape"]) + + data = self.process_image(image) + print("OCR RESULT:" + json.dumps(data), flush=True) + + except Exception as e: + print(f"ERROR: {e!s}", flush=True) + + +class OCRServiceClient: + """Client for interacting with the OCR service. + + This class manages the subprocess for the OCR service and provides methods + to send images for processing and receive results. + """ + + def __init__(self, verbose: bool = False): + """Spawns the OCR service as a subprocess. + + Args: + verbose: If True, prints debug information to stdout. + Defaults to False. + """ + self._process = None + self.verbose = verbose + self._lock = threading.Lock() + + def install(self): + """Installs the PaddleOCR package in a virtual environment. + This function checks for the existence of a virtual environment and + the PaddleOCR package. If not found, it creates a new virtual environment + and installs the necessary packages, including PaddleOCR and PyTorch. + """ + import torch -def extract_ocr_text(image: Union[str, bytes, Any]) -> list[dict]: - """Extracts text from an image using OCR. + if ( + VENV_SCRIPT_PATH.exists() + and VENV_SCRIPT_PATH.read_text() != SCRIPT_PATH.read_text() + ): + shutil.rmtree(VENV_SCRIPT_PATH.parent) + VENV_SCRIPT_PATH.parent.mkdir(parents=True, exist_ok=True) + if not VENV_SCRIPT_PATH.exists(): + print("Creating a virtual environment and installing PaddleOCR...") + subprocess.check_call( + ["uv", "venv", VENV_SCRIPT_PATH.parent, "--python", "3.10"] + ) + VENV_SCRIPT_PATH.write_text(SCRIPT_PATH.read_text()) + venv_python = VENV_SCRIPT_PATH.parent / "bin" / "python" + cmd = [ + "uv", + "pip", + "install", + "--python", + str(venv_python), + "torch", + "paddleocr", + "setuptools", + ] + if torch.cuda.is_available() or FORCE_GPU: + cmd.append("paddlepaddle-gpu==2.6.2") + elif platform.system() == "Darwin": + cmd.append("paddlepaddle==0.0.0") + cmd.append("-f") + cmd.append("https://www.paddlepaddle.org.cn/whl/mac/cpu/develop.html") + else: + cmd.append("paddlepaddle") - This function processes an image with the PaddleOCR model to extract text - and bounding boxes for each detected text line. + subprocess.check_call(cmd) + print("Installation complete. Downloading models...") + + @property + def process(self) -> subprocess.Popen: + """Returns the subprocess for the OCR service. + + If the subprocess is not running, it starts a new one and registers + cleanup handlers to terminate the process on exit or signal. 
+ + Returns: + subprocess.Popen: The running OCR service process. + + Raises: + RuntimeError: If the OCR service is not running or has failed. + """ + if self._process is None: + print("Starting OCR process...") + self.install() + venv_python = VENV_SCRIPT_PATH.parent / "bin" / "python" + + # Ensure process dies with main script + if os.name == "nt": + creationflags = subprocess.CREATE_NEW_PROCESS_GROUP + preexec_fn = None + else: + creationflags = 0 + preexec_fn = os.setsid + + self._process = subprocess.Popen( + [str(venv_python), str(VENV_SCRIPT_PATH)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + bufsize=1, + preexec_fn=preexec_fn, + creationflags=creationflags, + ) + + atexit.register(self.stop) + signal.signal(signal.SIGTERM, self.stop) + signal.signal(signal.SIGINT, self.stop) + return self._process + + def process_image(self, image: np.array) -> list[dict]: + """Sends an image to the OCR service and retrieves the results.""" + with self._lock: + if self.process.poll() is not None: + print("OCR Service terminated unexpectedly. Restarting...") + self._process = None + if self.process.poll() is not None: + raise RuntimeError("Failed to restart OCR Service") + print("OCR Service restarted successfully") + + request = { + "image": base64.b64encode(image.tobytes()).decode("utf-8"), + "dtype": str(image.dtype), + "shape": image.shape, + } + + self.process.stdin.write(json.dumps(request) + "\n") + self.process.stdin.flush() + + start = datetime.now() + while (datetime.now() - start).total_seconds() < 200: + response = self.process.stdout.readline().strip() + if response and self.verbose: + print(response) + if response.startswith("OCR RESULT:"): + return json.loads(response[11:]) + if response.startswith("ERROR:"): + raise RuntimeError(response[6:]) + raise TimeoutError("OCR service timed out") + + def stop(self, *args: Any): + """Terminates the OCR service.""" + if self._process is not None: + self._process.terminate() + self._process.wait() + self._process = None + + +OCR_CLIENT = OCRServiceClient() + + +def extract_ocr_text(image: str | bytes | Any) -> list[dict]: + """Extracts text from an image using the OCR service. + + This function sends an image to the OCR service for processing and returns + the extracted text and bounding boxes. Args: image: The input image. Can be a file path, bytes, or a numpy array. @@ -47,18 +266,8 @@ def extract_ocr_text(image: Union[str, bytes, Any]) -> list[dict]: - 'bbox': The bounding box coordinates [x1, y1, x2, y2] - 'content': The extracted text content """ - model, _ = load_model() - - result = model.ocr(image, cls=False) - data = [] - for idx in range(len(result)): - res = result[idx] - if res: - for line in res: - data.append( - { - "bbox": line[0][0] + line[0][2], - "content": line[1][0], - } - ) - return data + return OCR_CLIENT.process_image(image) + + +if __name__ == "__main__": + OCRService().run() diff --git a/docketanalyzer_ocr/remote.py b/docketanalyzer_ocr/remote.py index 545c15d..5437429 100644 --- a/docketanalyzer_ocr/remote.py +++ b/docketanalyzer_ocr/remote.py @@ -1,11 +1,12 @@ import base64 import json import time -from typing import Any, Dict, Generator, List, Optional, Union +from collections.abc import Generator +from typing import Any import requests -from .utils import RUNPOD_API_KEY, RUNPOD_OCR_ENDPOINT_ID +from docketanalyzer_core import env class RemoteClient: @@ -15,20 +16,21 @@ class RemoteClient: authentication, request formatting, and streaming response handling. 
""" - def __init__(self, api_key: Optional[str] = None, endpoint_url: Optional[str] = None): + def __init__(self, api_key: str | None = None, endpoint_url: str | None = None): """Initialize the remote client. Args: - api_key: API key for authentication. If None, uses RUNPOD_API_KEY from environment. + api_key: API key for authentication. If None, uses RUNPOD_API_KEY + from environment. endpoint_url: Full endpoint URL. If None, constructs URL from RUNPOD_OCR_ENDPOINT_ID or defaults to localhost. """ - self.api_key = api_key or RUNPOD_API_KEY + self.api_key = api_key or env.RUNPOD_API_KEY if endpoint_url: self.base_url = endpoint_url - elif RUNPOD_OCR_ENDPOINT_ID: - self.base_url = f"https://api.runpod.ai/v2/{RUNPOD_OCR_ENDPOINT_ID}" + elif env.RUNPOD_OCR_ENDPOINT_ID: + self.base_url = f"https://api.runpod.ai/v2/{env.RUNPOD_OCR_ENDPOINT_ID}" else: self.base_url = "http://localhost:8000" @@ -38,26 +40,27 @@ def __init__(self, api_key: Optional[str] = None, endpoint_url: Optional[str] = def __call__( self, - s3_key: Optional[str] = None, - file: Optional[bytes] = None, - filename: Optional[str] = None, + s3_key: str | None = None, + file: bytes | None = None, + filename: str | None = None, batch_size: int = 1, stream: bool = True, - timeout: int = 600, + timeout: int = 300, poll_interval: float = 1.0, - **extra_params, - ) -> Union[List[Dict[str, Any]], Generator[Dict[str, Any], None, None]]: + **kwargs: Any, + ) -> list[dict[str, Any]] | Generator[dict[str, Any], None, None]: """Make a request to the remote endpoint. Args: s3_key: S3 key to the PDF file. Either s3_key or file must be provided. - file: Binary PDF data or base64-encoded string. Either s3_key or file must be provided. + file: Binary PDF data or base64-encoded string. Either s3_key or + file must be provided. filename: Optional filename for the PDF. batch_size: Batch size for processing. Defaults to 1. stream: Whether to stream the response. Defaults to True. timeout: Request timeout in seconds. Defaults to 600 (10 minutes). poll_interval: Interval in seconds between status checks. Defaults to 1.0. - **extra_params: Additional parameters to include in the input payload. + **kwargs: Additional parameters to include in the input payload. Returns: If stream=True, returns a generator yielding response chunks. @@ -81,7 +84,7 @@ def __call__( if filename: input_data["filename"] = filename - input_data.update(extra_params) + input_data.update(kwargs) payload = {"input": input_data} @@ -97,7 +100,7 @@ def __call__( break return results - def _submit_job(self, payload: Dict[str, Any], timeout: int) -> str: + def _submit_job(self, payload: dict[str, Any], timeout: int) -> str: """Submit a job to the remote endpoint. 
Args: @@ -113,18 +116,19 @@ def _submit_job(self, payload: Dict[str, Any], timeout: int) -> str: """ url = f"{self.base_url}/run" - response = requests.post(url, headers=self.headers, json=payload, timeout=timeout) + response = requests.post( + url, headers=self.headers, json=payload, timeout=timeout + ) response.raise_for_status() - try: - result = response.json() - if "id" not in result: - raise ValueError(f"Invalid response format, missing 'id': {result}") - return result["id"] - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + result = response.json() + if "id" not in result: + raise ValueError(f"Invalid response format, missing 'id': {result}") + return result["id"] - def _stream_results(self, job_id: str, timeout: int, poll_interval: float) -> Generator[Dict[str, Any], None, None]: + def _stream_results( + self, job_id: str, timeout: int, poll_interval: float + ) -> Generator[dict[str, Any], None, None]: """Stream results from a job. Args: @@ -142,30 +146,26 @@ def _stream_results(self, job_id: str, timeout: int, poll_interval: float) -> Ge """ url = f"{self.base_url}/stream/{job_id}" start_time = time.time() - completed = False - while not completed and time.time() - start_time < timeout: + while time.time() - start_time < timeout: try: - with requests.post(url, headers=self.headers, stream=True, timeout=timeout) as response: + with requests.post( + url, headers=self.headers, stream=True, timeout=timeout + ) as response: if response.status_code == 200: for line in response.iter_lines(): if not line: continue - try: - data = json.loads(line.decode("utf-8")) - - if data.get("status") == "COMPLETED": - completed = True - - yield data - - if data.get("status") in ["COMPLETED", "FAILED", "CANCELLED"]: - return - - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + data = json.loads(line.decode("utf-8")) + yield data + if data.get("status") in [ + "COMPLETED", + "FAILED", + "CANCELLED", + ]: + return elif response.status_code == 404: time.sleep(poll_interval) else: @@ -175,10 +175,9 @@ def _stream_results(self, job_id: str, timeout: int, poll_interval: float) -> Ge time.sleep(poll_interval) continue - if not completed and time.time() - start_time >= timeout: - raise TimeoutError(f"Streaming results timed out after {timeout} seconds") + raise TimeoutError(f"Streaming results timed out after {timeout} seconds") - def get_status(self, job_id: str, timeout: int = 30) -> Dict[str, Any]: + def get_status(self, job_id: str, timeout: int = 30) -> dict[str, Any]: """Get the status of a job. Args: @@ -197,12 +196,9 @@ def get_status(self, job_id: str, timeout: int = 30) -> Dict[str, Any]: response = requests.post(url, headers=self.headers, timeout=timeout) response.raise_for_status() - try: - return response.json() - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + return response.json() - def cancel_job(self, job_id: str, timeout: int = 30) -> Dict[str, Any]: + def cancel_job(self, job_id: str, timeout: int = 30) -> dict[str, Any]: """Cancel a job. 
Args: @@ -221,12 +217,9 @@ def cancel_job(self, job_id: str, timeout: int = 30) -> Dict[str, Any]: response = requests.post(url, headers=self.headers, timeout=timeout) response.raise_for_status() - try: - return response.json() - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + return response.json() - def purge_queue(self, timeout: int = 30) -> Dict[str, Any]: + def purge_queue(self, timeout: int = 30) -> dict[str, Any]: """Purge all queued jobs. Args: @@ -244,12 +237,9 @@ def purge_queue(self, timeout: int = 30) -> Dict[str, Any]: response = requests.post(url, headers=self.headers, timeout=timeout) response.raise_for_status() - try: - return response.json() - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + return response.json() - def get_health(self, timeout: int = 30) -> Dict[str, Any]: + def get_health(self, timeout: int = 30) -> dict[str, Any]: """Get endpoint health information. Args: @@ -267,7 +257,4 @@ def get_health(self, timeout: int = 30) -> Dict[str, Any]: response = requests.get(url, headers=self.headers, timeout=timeout) response.raise_for_status() - try: - return response.json() - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse response: {e}") + return response.json() diff --git a/docketanalyzer_ocr/setup/run.py b/docketanalyzer_ocr/setup/run.py deleted file mode 100644 index 3e432d9..0000000 --- a/docketanalyzer_ocr/setup/run.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Docket Analyzer OCR Module - Main Entry Point. - -This script serves as the main entry point for the Docket Analyzer OCR module -when run directly. It demonstrates basic usage by processing a test PDF file -and printing the extracted text from each page. -""" - -from pathlib import Path - -from docketanalyzer_ocr import pdf_document - -if __name__ == "__main__": - # Process a test PDF file and print the extracted text - path = Path(__file__).parent / "test.pdf" - doc = pdf_document(path) - for page in doc.stream(): - print(page.text) diff --git a/docketanalyzer_ocr/utils.py b/docketanalyzer_ocr/utils.py index fdf5bd5..6481e1b 100644 --- a/docketanalyzer_ocr/utils.py +++ b/docketanalyzer_ocr/utils.py @@ -1,142 +1,180 @@ -import os import tempfile from pathlib import Path -from typing import Optional, Union -import boto3 -from botocore.client import Config -from dotenv import load_dotenv - -load_dotenv(override=True) +import fitz +import numpy as np +from PIL import Image +from docketanalyzer_core import load_s3 BASE_DIR = Path(__file__).resolve().parent -RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY") -RUNPOD_OCR_ENDPOINT_ID = os.getenv("RUNPOD_OCR_ENDPOINT_ID") +def load_pdf( + file: bytes | None = None, + s3_key: str | None = None, + filename: str | None = None, +) -> tuple[bytes, str]: + """Loads a PDF file either from binary content or S3. + This function handles loading a PDF file from either binary or from an S3 bucket. + It returns the binary content of the PDF file and the filename. -AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") -AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") -S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") -S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") + Args: + file: PDF file content as bytes. Defaults to None. + s3_key: S3 key if the PDF should be fetched from S3. Defaults to None. + filename: Optional filename to use. If not provided, will be derived + from s3_key or set to a default. 
+ Returns: + tuple[bytes, str]: A tuple containing: + - The binary content of the PDF file + - The filename of the PDF -s3_client = boto3.client( - "s3", - endpoint_url=S3_ENDPOINT_URL, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - config=Config(signature_version="s3v4"), -) + Raises: + ValueError: If neither file nor s3_key is provided. + """ + if file is None and s3_key is None: + raise ValueError("Either file or s3_key must be provided") + if filename is None: + filename = Path(s3_key).name if s3_key else "document.pdf" + + # If we already have the file content, just return it + if file is not None: + return file, filename + + # Otherwise, we need to download from S3 + with tempfile.NamedTemporaryFile() as temp_file: + temp_path = Path(temp_file.name) + load_s3().download(s3_key, str(temp_path)) + return temp_path.read_bytes(), filename + + +def page_to_image(page: fitz.Page, dpi: int = 200) -> np.ndarray: + """Converts a PDF page to a numpy image array. -def upload_to_s3(file_path: Union[str, Path], s3_key: str, overwrite: bool = False) -> bool: - """Uploads a file to an S3 bucket. + This function renders a PDF page at the specified DPI and converts it to a numpy + array. If the resulting image would be too large, it falls back to a + lower resolution. Args: - file_path: Local path to the file to upload. - s3_key: S3 key (path) where the file will be stored. - overwrite: If True, overwrites existing file. Defaults to False. + page: The pymupdf Page object to convert. + dpi: The dots per inch resolution to render at. Defaults to 200. Returns: - bool: True if upload was successful, False otherwise. + np.ndarray: The page as a numpy array in RGB format. """ - try: - file_path = Path(file_path) - if not overwrite: - try: - s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=s3_key) - return False - except Exception: - pass + mat = fitz.Matrix(dpi / 72, dpi / 72) + pm = page.get_pixmap(matrix=mat, alpha=False) - s3_client.upload_file(str(file_path), S3_BUCKET_NAME, s3_key) - return True - except Exception as e: - print(f"Error uploading file to S3: {str(e)}") - return False + if pm.width > 4500 or pm.height > 4500: + pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) + img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) + img = np.array(img) -def download_from_s3(s3_key: str, local_path: Optional[Union[str, Path]] = None) -> Optional[Path]: - """Downloads a file from an S3 bucket. + return img - Args: - s3_key: S3 key (path) of the file to download. - local_path: Local path where to save the file. If None, saves to the same name as s3_key. - Returns: - Optional[Path]: Path to the downloaded file if successful, None otherwise. - """ - try: - if local_path is None: - local_path = Path(Path(s3_key).name) - else: - local_path = Path(local_path) +def extract_native_text(page: fitz.Page, dpi: int) -> list[dict]: + """Extracts text content and bounding boxes from a PDF page using native PDF text. - s3_client.download_file(S3_BUCKET_NAME, s3_key, str(local_path)) - return local_path - except Exception as e: - print(f"Error downloading file from S3: {str(e)}") - return None + This function extracts text directly from the PDF's internal structure + rather than using OCR. + Args: + page: The pymupdf Page object to extract text from. + dpi: The resolution to use when scaling bounding boxes. -def delete_from_s3(s3_key: str) -> bool: - """Deletes a file from an S3 bucket. 
+ Returns: + list[dict]: A list of dictionaries, each containing: + - 'bbox': The bounding box coordinates [x1, y1, x2, y2] + - 'content': The text content of the line + """ + blocks = page.get_text("dict")["blocks"] + data = [] + for block in blocks: + if "lines" in block: + for line in block["lines"]: + content = "".join([span["text"] for span in line["spans"]]) + if content.strip(): + line["bbox"] = tuple([(dpi / 72) * x for x in line["bbox"]]) + data.append( + { + "bbox": line["bbox"], + "content": content, + } + ) + return data + + +def has_images(page: fitz.Page) -> bool: + """Checks if a page has images that are large enough to potentially contain text. Args: - s3_key: S3 key (path) of the file to delete. + page: The pymupdf Page object to check. Returns: - bool: True if deletion was successful, False otherwise. + bool: True if the page contains images of a significant size, False otherwise. """ - try: - s3_client.delete_object(Bucket=S3_BUCKET_NAME, Key=s3_key) - return True - except Exception as e: - print(f"Warning: Failed to delete S3 object {s3_key}: {e}") - return False + image_list = page.get_images(full=True) + for _, img_info in enumerate(image_list): + xref = img_info[0] + base_image = page.parent.extract_image(xref) -def load_pdf( - file: Optional[bytes] = None, - s3_key: Optional[str] = None, - filename: Optional[str] = None, -) -> tuple[bytes, str]: - """Loads a PDF file either from binary content or S3. + if base_image: + width = base_image["width"] + height = base_image["height"] + if width > 10 and height > 10: + return True + + return False - This function handles loading a PDF file from either binary content or from an S3 bucket. - It returns the binary content of the PDF file and the filename. + +def has_text_annotations(page: fitz.Page) -> bool: + """Checks if a page has annotations that could contain text. Args: - file: PDF file content as bytes. Defaults to None. - s3_key: S3 key if the PDF should be fetched from S3. Defaults to None. - filename: Optional filename to use. If not provided, will be derived from s3_key or set to a default. + page: The pymupdf Page object to check. Returns: - tuple[bytes, str]: A tuple containing: - - The binary content of the PDF file - - The filename of the PDF - - Raises: - ValueError: If neither file nor s3_key is provided. + bool: True if the page has text-containing annotations, False otherwise. """ - if file is None and s3_key is None: - raise ValueError("Either file or s3_key must be provided") + annots = page.annots() - if filename is None: - if s3_key: - filename = Path(s3_key).name - else: - filename = "document.pdf" + if annots: + for annot in annots: + annot_type = annot.type[1] + if annot_type in [fitz.PDF_ANNOT_FREE_TEXT, fitz.PDF_ANNOT_WIDGET]: + return True - # If we already have the file content, just return it - if file is not None: - return file, filename + return False - # Otherwise, we need to download from S3 - with tempfile.NamedTemporaryFile() as temp_file: - temp_path = Path(temp_file.name) - download_from_s3(s3_key, temp_path) - return temp_path.read_bytes(), filename + +def page_needs_ocr(page: fitz.Page) -> bool: + """Determines if a page needs OCR processing. 
+ + This function checks various conditions to decide if OCR is needed: + - If the page has no text + - If the page has CID-encoded text (often indicates non-extractable text) + - If the page has text annotations + - If the page has images that might contain text + - If the page has many drawing paths (might be scanned text) + + Args: + page: The pymupdf Page object to check. + + Returns: + bool: True if the page needs OCR processing, False otherwise. + """ + page_text = page.get_text() + + return ( + page_text.strip() == "" + or "(cid:" in page_text + or has_text_annotations(page) + or has_images(page) + or len(page.get_drawings()) > 10 + ) diff --git a/handler.py b/handler.py index a8f8470..1d1fd9e 100644 --- a/handler.py +++ b/handler.py @@ -1,9 +1,12 @@ +from collections.abc import Generator from datetime import datetime -from typing import Generator import runpod from docketanalyzer_ocr import load_pdf, pdf_document +from docketanalyzer_ocr.ocr import OCR_CLIENT + +process = OCR_CLIENT.process def handler(event: dict) -> Generator[dict, None, None]: @@ -34,26 +37,25 @@ def handler(event: dict) -> Generator[dict, None, None]: batch_size = inputs.get("batch_size", 1) try: - # Load the PDF file (now returns binary data) if inputs.get("s3_key"): - pdf_data, filename = load_pdf(s3_key=inputs.pop("s3_key"), filename=filename) + pdf_data, filename = load_pdf( + s3_key=inputs.pop("s3_key"), filename=filename + ) elif inputs.get("file"): pdf_data, filename = load_pdf(file=inputs.pop("file"), filename=filename) else: raise ValueError("Neither 's3_key' nor 'file' provided in input") - # Process the PDF using the binary data doc = pdf_document(pdf_data, filename=filename) - completed = 0 - for page in doc.stream(batch_size=batch_size): - completed += 1 + for i, page in enumerate(doc.stream(batch_size=batch_size)): duration = (datetime.now() - start).total_seconds() yield { "page": page.data, "seconds_elapsed": duration, - "progress": completed / len(doc), + "progress": (i + 1) / len(doc), "status": "success", } + doc.close() except Exception as e: yield { "error": str(e), diff --git a/pyproject.toml b/pyproject.toml index 9801b21..31b2d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,46 +8,51 @@ version = "0.1.0" authors = [ { name = "Nathan Dahlberg" }, ] -description = "Docket Analyzer OCR Container" +description = "Docket Analyzer OCR Utility" readme = "README.md" -requires-python = ">=3.10,<3.11" +requires-python = ">=3.10" dependencies = [ - "boto3", - "click", "dill", "doclayout_yolo", + "docketanalyzer-core>=0.1.0", "fastapi", "huggingface-hub", - "paddlepaddle", - "paddleocr", + "numpy<2", "pymupdf", - "pytest", - "pytest-cov", - "python-dotenv", - "ruff", "runpod", + "uv", "uvicorn", ] [project.optional-dependencies] -gpu = [ - "paddlepaddle-gpu==2.6.2", +dev = [ + "docketanalyzer-core[dev]>=0.1.3", ] [tool.ruff] -lint.select = ["E", "F", "I"] -line-length = 120 +lint.select = ["E", "F", "I", "B", "UP", "N", "SIM", "PD", "NPY", "PTH", "RUF", "D"] +lint.ignore = ["D100", "D104"] [tool.ruff.lint.isort] -known-first-party = ["docketanalyzer_ocr"] +known-first-party = ["docketanalyzer_core", "docketanalyzer_ocr"] section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["I001", "I002"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.pytest.ini_options] -# Filter out warnings that aren't relevant to our code +log_cli = true +log_cli_level = "INFO" +addopts = "-ra -q
--cov=docketanalyzer_ocr" +testpaths = ["tests"] +pythonpath = "." filterwarnings = [ "ignore::DeprecationWarning:importlib._bootstrap", "ignore::DeprecationWarning:thop.profile", "ignore::DeprecationWarning:setuptools.command.easy_install", "ignore::DeprecationWarning:pkg_resources", "ignore::DeprecationWarning:sys", -] \ No newline at end of file +] diff --git a/service.py b/service.py index 36bb1d5..88efa40 100644 --- a/service.py +++ b/service.py @@ -2,9 +2,9 @@ import base64 import json import uuid -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any import uvicorn from fastapi import BackgroundTasks, FastAPI, HTTPException @@ -13,40 +13,34 @@ from pydantic import BaseModel from docketanalyzer_ocr import load_pdf, pdf_document -from docketanalyzer_ocr.utils import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY +from docketanalyzer_ocr.ocr import OCR_CLIENT + +process = OCR_CLIENT.process -# Dictionary to store job information jobs = {} -# Cleanup task reference cleanup_task = None +ocr_semaphore = asyncio.Semaphore(1) + async def cleanup_old_jobs(): """Periodically clean up old jobs to prevent memory leaks.""" while True: try: - # Wait for 1 hour between cleanups await asyncio.sleep(3600) - - # Get current time now = datetime.now() - - # Find jobs older than 24 hours - old_jobs = [] - for job_id, job in jobs.items(): - created_at = datetime.fromisoformat(job["created_at"]) - if (now - created_at).total_seconds() > 86400: # 24 hours - old_jobs.append(job_id) - - # Remove old jobs + old_jobs = [ + job_id + for job_id, job in jobs.items() + if (now - datetime.fromisoformat(job["created_at"])).total_seconds() + > 86400 + ] for job_id in old_jobs: del jobs[job_id] - except asyncio.CancelledError: break except Exception: - # Log error and continue pass @@ -58,12 +52,13 @@ async def lifespan(app: FastAPI): yield + print("Stopping OCR service...", flush=True) + OCR_CLIENT.stop() + if cleanup_task: cleanup_task.cancel() - try: + with suppress(asyncio.CancelledError): await cleanup_task - except asyncio.CancelledError: - pass app = FastAPI( @@ -84,9 +79,9 @@ async def lifespan(app: FastAPI): class JobInput(BaseModel): """Input model for job submission.""" - s3_key: Optional[str] = None - file: Optional[str] = None # Base64 encoded file content - filename: Optional[str] = None + s3_key: str | None = None + file: str | None = None + filename: str | None = None batch_size: int = 1 @@ -106,7 +101,7 @@ class JobStatus(BaseModel): """Status model for job status.""" status: str - stream: Optional[List[Dict[str, Any]]] = None + stream: list[dict[str, Any]] | None = None async def process_document(job_id: str, input_data: JobInput): @@ -121,48 +116,41 @@ async def process_document(job_id: str, input_data: JobInput): jobs[job_id]["stream"] = [] try: - # Load the PDF data if input_data.s3_key: - if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY: - raise ValueError("You must set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables") - pdf_data, filename = load_pdf(s3_key=input_data.s3_key, filename=input_data.filename) + pdf_data, filename = load_pdf( + s3_key=input_data.s3_key, filename=input_data.filename + ) elif input_data.file: - # Decode base64 file content pdf_bytes = base64.b64decode(input_data.file) pdf_data, filename = load_pdf(file=pdf_bytes, filename=input_data.filename) else: raise ValueError("Neither 's3_key' nor 'file' provided in input") 
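The two input branches above mirror the `JobInput` fields a client sends to `/run`. A minimal client-side sketch (hedged: host, port, S3 key, and file path are placeholders; the `{"id": ...}` response shape follows `run_job` below, and `requests` is assumed to be available in the client environment):

```python
import base64
from pathlib import Path

import requests  # assumed available in the client environment

pdf_b64 = base64.b64encode(Path("tests/fixtures/document.pdf").read_bytes()).decode()

# Inline upload: the service base64-decodes `file` and hands the bytes to load_pdf.
payload = {"input": {"file": pdf_b64, "filename": "document.pdf", "batch_size": 1}}
# S3 variant: reference an object the worker can fetch instead of inlining the bytes.
# payload = {"input": {"s3_key": "ocr/document.pdf", "filename": "document.pdf"}}

job_id = requests.post("http://127.0.0.1:8000/run", json=payload).json()["id"]
print(job_id)
```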
- # Process the PDF - doc = pdf_document(pdf_data, filename=filename) - completed = 0 + async with ocr_semaphore: + doc = pdf_document(pdf_data, filename=filename) + pages = list(doc.stream(batch_size=input_data.batch_size)) - # Stream the results - for page in doc.stream(batch_size=input_data.batch_size): - completed += 1 + for i, page in enumerate(pages): duration = (datetime.now() - start).total_seconds() - # Create a stream item with the page data in the format expected by RemoteClient stream_item = { "output": { "page": page.data, "seconds_elapsed": duration, - "progress": completed / len(doc), + "progress": (i + 1) / len(doc), "status": "success", } } - # Add to job stream jobs[job_id]["stream"].append(stream_item) - # Small delay to prevent CPU hogging await asyncio.sleep(0.1) - # Mark job as completed jobs[job_id]["status"] = "COMPLETED" + doc.close() except Exception as e: - # Handle errors + print(f"Error processing job {job_id}: {e}", flush=True) error_result = { "output": { "error": str(e), @@ -191,7 +179,6 @@ async def run_job(request: JobRequest, background_tasks: BackgroundTasks): "created_at": datetime.now().isoformat(), } - # Start processing in the background background_tasks.add_task(process_document, job_id, request.input) return {"id": job_id} @@ -210,11 +197,9 @@ async def stream_job(job_id: str): if job_id not in jobs: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") - # If job is still pending, return 404 to mimic RunPod behavior if jobs[job_id]["status"] == "PENDING": raise HTTPException(status_code=404, detail="Job not ready yet") - # Get the current stream position stream_position = 0 async def generate(): @@ -250,7 +235,9 @@ async def job_status(job_id: str): return { "status": jobs[job_id]["status"], - "stream": jobs[job_id]["stream"] if jobs[job_id]["status"] != "PENDING" else None, + "stream": jobs[job_id]["stream"] + if jobs[job_id]["status"] != "PENDING" + else None, } @@ -281,8 +268,9 @@ async def health_check(): Returns: dict: Health information. """ - # Count active jobs - active_jobs = sum(1 for job in jobs.values() if job["status"] in ["PENDING", "IN_PROGRESS"]) + active_jobs = sum( + 1 for job in jobs.values() if job["status"] in ["PENDING", "IN_PROGRESS"] + ) return { "workers": { diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..82c1144 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from pathlib import Path + +from docketanalyzer_ocr import pdf_document + +if __name__ == "__main__": + # Process a test PDF file to set up additional dependencies + path = Path(__file__).parent / "tests" / "fixtures" / "document.pdf" + doc = pdf_document(path, use_s3=False) + + for page in doc.stream(): + print(page.text) diff --git a/test_input.json b/test_input.json deleted file mode 100644 index a5d10e9..0000000 --- a/test_input.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "input": { - "s3_key": "ocr/test.pdf", - "filename": "test.pdf", - "batch_size": 1 - } -} \ No newline at end of file diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 4d75948..0000000 --- a/tests/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# Docket Analyzer OCR Tests - -This directory contains unit tests for the Docket Analyzer OCR package.
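With test_input.json removed above, the equivalent handler input can be built directly in Python. A rough sketch for exercising the RunPod handler locally (hedged: assumes the handler reads its options from `event["input"]` as the old test_input.json implied, and that raw PDF bytes are accepted for the `file` field):

```python
from pathlib import Path

from handler import handler

event = {
    "input": {
        "file": Path("tests/fixtures/document.pdf").read_bytes(),
        "filename": "document.pdf",
        "batch_size": 1,
    }
}

# handler() is a generator that yields one progress update per processed page.
for update in handler(event):
    print(update.get("progress"), update.get("status"))
```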
- -## Test Structure - -The tests are organized by module: - -- `test_document.py`: Tests for document processing functionality -- `test_ocr.py`: Tests for OCR functionality -- `test_layout.py`: Tests for layout analysis functionality -- `test_utils.py`: Tests for utility functions -- `test_remote.py`: Tests for remote processing functionality -- `test_service.py`: Tests for the OCR service API endpoints - -## Running Tests and Code Coverage - -```bash -pytest --cov=docketanalyzer_ocr tests/ --cov-report=xml --cov-branch --junitxml=junit.xml -o junit_family=legacy -``` - -## Code Quality - -```bash -ruff format . && ruff check --fix . -``` - -## Test Fixtures - -The tests use fixtures defined in `conftest.py`: - -- `sample_pdf_bytes`: A simple PDF file in memory -- `sample_pdf_path`: A temporary PDF file on disk -- `sample_image`: A sample image for OCR testing -- `test_pdf_path`: Path to the real test PDF document diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index 9a695f2..29d2394 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,62 +1,34 @@ -import os -import sys -import tempfile from pathlib import Path -import fitz -import numpy as np import pytest -from PIL import Image - -sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +import simplejson as json @pytest.fixture -def sample_pdf_bytes(): - """Create a simple PDF file in memory for testing.""" - doc = fitz.open() - page = doc.new_page(width=595, height=842) # A4 size - - # Add some text to the page - page.insert_text((100, 100), "Sample PDF Document", fontsize=16) - page.insert_text((100, 150), "This is a test document for OCR testing.", fontsize=12) - page.insert_text((100, 200), "It contains some text that can be extracted.", fontsize=12) - - # Save to bytes - pdf_bytes = doc.tobytes() - doc.close() - - return pdf_bytes +def fixture_dir(): + """Path to the fixtures directory.""" + return Path(__file__).parent / "fixtures" @pytest.fixture -def sample_pdf_path(sample_pdf_bytes): - """Create a temporary PDF file on disk for testing.""" - with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: - tmp.write(sample_pdf_bytes) - tmp_path = tmp.name - - yield Path(tmp_path) - - if os.path.exists(tmp_path): - os.unlink(tmp_path) - - -@pytest.fixture -def sample_image(): - """Create a sample image with text for OCR testing.""" - # Create a white image - width, height = 800, 600 - image = Image.new("RGB", (width, height), color="white") - - # Convert to numpy array for easier handling - img_array = np.array(image) - - return img_array +def sample_pdf_path(fixture_dir): + """Path to the sample PDF file for testing.""" + return fixture_dir / "document.pdf" @pytest.fixture -def test_pdf_path(): - """Path to the real test PDF document.""" - pdf_path = Path(__file__).parent.parent / "docketanalyzer_ocr" / "setup" / "test.pdf" - return pdf_path +def sample_pdf_json(fixture_dir): + """Parsed JSON data for the sample PDF file.""" + return json.loads((fixture_dir / "document.json").read_text()) + + +def compare_docs(doc1, doc2): + """Compare two processed PDF documents for equality.""" + for page1, page2 in zip(doc1, doc2, strict=False): + for block1, block2 in zip(page1, page2, strict=False): + for line1, line2 in zip(block1, block2, strict=False): + if line1.text != line2.text: + return False + if block1.block_type != block2.block_type: + return False + return True diff --git a/tests/fixtures/document.json
b/tests/fixtures/document.json new file mode 100644 index 0000000..d9c007c --- /dev/null +++ b/tests/fixtures/document.json @@ -0,0 +1 @@ +{"filename": "document.pdf", "pages": [{"i": 0, "blocks": [{"i": 0, "bbox": [540, 331, 1161, 423], "type": "title", "lines": [{"i": 0, "bbox": [660.8055538601345, 336.8056403266059, 1056.6751268174912, 375.8056216769748], "content": "NOTICE: SLIP OPINION "}, {"i": 1, "bbox": [543.7222374810112, 384.13891262478296, 1165.3417799207898, 423.1388939751519], "content": "(not the court\u2019s final written decision) "}]}, {"i": 1, "bbox": [199, 518, 1455, 609], "type": "text", "lines": [{"i": 0, "bbox": [200.06667243109808, 523.472171359592, 1458.2455105251736, 562.4721527099609], "content": "The opinion that begins on the next page is a slip opinion. Slip opinions are the "}, {"i": 1, "bbox": [200.06667243109808, 571.1389329698351, 1064.0084160698784, 610.138914320204], "content": "written opinions that are originally filed by the court. "}]}, {"i": 2, "bbox": [196, 635, 1497, 1058], "type": "text", "lines": [{"i": 0, "bbox": [200.06667243109808, 640.8054775661892, 1448.7341986762153, 679.8054589165581], "content": "A slip opinion is not necessarily the court\u2019s final written decision. Slip opinions "}, {"i": 1, "bbox": [200.06667243109808, 688.138919406467, 1464.283243815104, 727.1389431423611], "content": "can be changed by subsequent court orders. For example, a court may issue an "}, {"i": 2, "bbox": [200.06667243109808, 735.5556064181858, 1486.1682467990452, 774.5556301540798], "content": "order making substantive changes to a slip opinion or publishing for precedential "}, {"i": 3, "bbox": [200.06667243109808, 783.2221984863281, 1482.8955756293403, 822.2222222222222], "content": "purposes a previously \u201cunpublished\u201d opinion. Additionally, nonsubstantive edits "}, {"i": 4, "bbox": [200.06667243109808, 830.5555555555555, 1412.6944647894966, 869.5555792914496], "content": "(for style, grammar, citation, format, punctuation, etc.) are made before the "}, {"i": 5, "bbox": [200.06667243109808, 878.2222323947483, 1500.8726331922742, 917.2222561306423], "content": "opinions that have precedential value are published in the official reports of court "}, {"i": 6, "bbox": [200.06667243109808, 925.5555894639757, 1498.8143920898438, 964.5556131998698], "content": "decisions: the Washington Reports 2d and the Washington Appellate Reports. An "}, {"i": 7, "bbox": [200.06667243109808, 973.2222663031683, 1451.1355929904514, 1012.2222900390625], "content": "opinion in the official reports replaces the slip opinion as the official opinion of "}, {"i": 8, "bbox": [200.06667243109808, 1020.5555386013455, 364.202880859375, 1059.5555623372395], "content": "the court. "}]}, {"i": 3, "bbox": [196, 1084, 1496, 1368], "type": "text", "lines": [{"i": 0, "bbox": [200.06667243109808, 1090.2222527398003, 1464.4849989149304, 1129.2222764756943], "content": "The slip opinion that begins on the next page is for a published opinion, and it "}, {"i": 1, "bbox": [200.06667243109808, 1137.5555250379773, 1494.6592542860242, 1176.5555487738716], "content": "has since been revised for publication in the printed official reports. The official "}, {"i": 2, "bbox": [200.06667243109808, 1185.2778116861978, 1490.0197347005208, 1224.277835422092], "content": "text of the court\u2019s opinion is found in the advance sheets and the bound volumes "}, {"i": 3, "bbox": [200.06667243109808, 1232.611083984375, 1361.5407307942708, 1271.611107720269], "content": "of the official reports. 
Also, an electronic version (intended to mirror the "}, {"i": 4, "bbox": [200.06667243109808, 1280.2777608235676, 1499.0592108832466, 1319.2777845594617], "content": "language found in the official reports) of the revised opinion can be found, free of "}, {"i": 5, "bbox": [200.06667243109808, 1327.6111178927952, 1358.425055609809, 1366.6111416286892], "content": "charge, at this website: https://www.lexisnexis.com/clients/wareports. "}]}, {"i": 4, "bbox": [198, 1393, 1467, 1532], "type": "text", "lines": [{"i": 0, "bbox": [200.06667243109808, 1397.2777472601997, 1469.399685329861, 1436.2777709960938], "content": "For more information about precedential (published) opinions, nonprecedential "}, {"i": 1, "bbox": [200.06667243109808, 1444.9442545572915, 1240.1014539930554, 1483.944363064236], "content": "(unpublished) opinions, slip opinions, and the official reports, see "}, {"i": 2, "bbox": [200.06667243109808, 1492.2776963975693, 1438.7585110134548, 1531.277804904514], "content": "https://www.courts.wa.gov/opinions and the information that is linked there. "}]}]}, {"i": 1, "blocks": [{"i": 0, "bbox": [44, 53, 1570, 253], "type": "abandon", "lines": [{"i": 0, "bbox": [51.0, 60.0, 1554.0, 92.0], "content": "For the current opinion, go to https://www.lexisnexis.com/cliAILeDwareports/."}, {"i": 1, "bbox": [1220.0, 99.0, 1374.0, 131.0], "content": "4/10/2023"}, {"i": 2, "bbox": [1166.0, 135.0, 1427.0, 174.0], "content": "Court of Appeals"}, {"i": 3, "bbox": [1224.0, 172.0, 1374.0, 211.0], "content": "Division I"}, {"i": 4, "bbox": [1139.0, 208.0, 1453.0, 250.0], "content": "State of Washington"}]}, {"i": 1, "bbox": [342, 351, 1353, 393], "type": "title", "lines": [{"i": 0, "bbox": [351.0, 355.0, 1349.0, 387.0], "content": "IN THE COURT OF APPEALS OF THE STATE OF WASHINGTON"}]}, {"i": 2, "bbox": [259, 465, 805, 547], "type": "text", "lines": [{"i": 0, "bbox": [266.0, 472.0, 797.0, 504.0], "content": "CULINARY VENTURES, LTD, d/b/a"}, {"i": 1, "bbox": [263.0, 511.0, 443.0, 543.0], "content": "BITEMOJO,"}]}, {"i": 3, "bbox": [969, 503, 1322, 702], "type": "text", "lines": [{"i": 0, "bbox": [975.0, 509.0, 1183.0, 541.0], "content": "No. 83486-0-I"}, {"i": 1, "bbox": [977.0, 587.0, 1206.0, 619.0], "content": "DIVISION ONE"}, {"i": 2, "bbox": [977.0, 660.0, 1317.0, 692.0], "content": "PUBLISHED OPINION"}]}, {"i": 4, "bbox": [561, 585, 715, 625], "type": "text", "lines": [{"i": 0, "bbox": [561.0, 584.0, 716.0, 623.0], "content": "Appellant,"}]}, {"i": 5, "bbox": [257, 732, 756, 856], "type": "text", "lines": [{"i": 0, "bbox": [266.0, 740.0, 732.0, 772.0], "content": "MICROSOFT CORPORATION,"}, {"i": 1, "bbox": [561.0, 814.0, 753.0, 852.0], "content": "Respondent."}]}, {"i": 6, "bbox": [243, 927, 1454, 1819], "type": "text", "lines": [{"i": 0, "bbox": [351.0, 933.0, 1280.0, 965.0], "content": "Chung, J. Culinary Ventures d/b/a Bitemojo, the creator of a"}, {"i": 1, "bbox": [252.0, 1011.0, 1381.0, 1043.0], "content": "smartphone application for self-guided food tours, entered into a subscription"}, {"i": 2, "bbox": [249.0, 1084.0, 1404.0, 1123.0], "content": "agreement with Microsoft Ireland for its Azure online cloud-based data storage"}, {"i": 3, "bbox": [252.0, 1164.0, 1395.0, 1196.0], "content": "services. The agreement included a forum selection clause specifying that if it"}, {"i": 4, "bbox": [245.0, 1235.0, 1446.0, 1276.0], "content": "brought an action to enforce the agreement, Bitemojo would bring such an action"}, {"i": 5, "bbox": [249.0, 1318.0, 1427.0, 1350.0], "content": "in Ireland. 
At Bitemojo's request, Azure twice suspended Bitemojo's account, as."}, {"i": 6, "bbox": [249.0, 1393.0, 1333.0, 1425.0], "content": "well as the required payments. Thereafter, Azure deleted Bitemojo's data.."}, {"i": 7, "bbox": [252.0, 1471.0, 1284.0, 1503.0], "content": "Subsequently, Bitemojo sued Microsoft Corporation in King County for."}, {"i": 8, "bbox": [247.0, 1549.0, 1312.0, 1577.0], "content": "promissory estoppel, breach of contract, conversion, and violation of the."}, {"i": 9, "bbox": [252.0, 1625.0, 1441.0, 1657.0], "content": "Washington Consumer Protection Act (CPA), chapter 19.86 RCW. The trial court."}, {"i": 10, "bbox": [247.0, 1698.0, 1358.0, 1735.0], "content": "granted Microsoft Corporation's CR 12(b)(3) motion to dismiss for improper."}, {"i": 11, "bbox": [252.0, 1778.0, 1081.0, 1808.0], "content": "venue based on the agreement's forum selection clause.."}]}]}]} \ No newline at end of file diff --git a/docketanalyzer_ocr/setup/test.pdf b/tests/fixtures/document.pdf similarity index 100% rename from docketanalyzer_ocr/setup/test.pdf rename to tests/fixtures/document.pdf diff --git a/tests/test_document.py b/tests/test_document.py deleted file mode 100644 index 7a0871e..0000000 --- a/tests/test_document.py +++ /dev/null @@ -1,101 +0,0 @@ -import json -import tempfile -from pathlib import Path - -import fitz -import numpy as np - -from docketanalyzer_ocr.document import ( - PDFDocument, - extract_native_text, - page_to_image, - pdf_document, -) - - -class TestDocumentCore: - """Core tests for document processing functionality.""" - - def test_document_processing(self, sample_pdf_bytes): - """Test basic document processing without mocks.""" - # Create document and process - doc = PDFDocument(sample_pdf_bytes, filename="test.pdf") - doc.process() - - # Verify document structure - assert len(doc.pages) > 0 - assert doc.filename == "test.pdf" - - # Check page content - for page in doc.pages: - assert page.text - assert len(page.blocks) > 0 - - # Check block structure - for block in page.blocks: - assert hasattr(block, "bbox") - assert hasattr(block, "block_type") - assert hasattr(block, "lines") - - # Check line structure - for line in block.lines: - assert hasattr(line, "bbox") - assert hasattr(line, "content") - - def test_page_to_image_conversion(self, sample_pdf_bytes): - """Test converting PDF pages to images.""" - doc = fitz.open("pdf", sample_pdf_bytes) - page = doc[0] - - # Test with default DPI - img = page_to_image(page) - assert isinstance(img, np.ndarray) - assert img.shape[2] == 3 # RGB image - - # Test with custom DPI - img_high_dpi = page_to_image(page, dpi=300) - assert img_high_dpi.shape[0] > img.shape[0] # Higher resolution - - doc.close() - - def test_native_text_extraction(self, sample_pdf_bytes): - """Test extracting native text from PDF pages.""" - doc = fitz.open("pdf", sample_pdf_bytes) - page = doc[0] - - text_data = extract_native_text(page, 100) - - assert isinstance(text_data, list) - assert len(text_data) > 0 - - # Check structure of text items - for item in text_data: - assert "bbox" in item - assert "content" in item - assert isinstance(item["bbox"], tuple) - assert len(item["bbox"]) == 4 - assert isinstance(item["content"], str) - - doc.close() - - def test_document_save_load(self, sample_pdf_bytes): - """Test saving and loading document data.""" - # Process document - doc = pdf_document(sample_pdf_bytes, filename="test.pdf").process() - - # Save to temporary file - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - 
tmp_path = Path(tmp.name) - doc.save(tmp_path) - - # Load saved data - with open(tmp_path, "r") as f: - data = json.load(f) - - # Verify data structure - assert data["filename"] == "test.pdf" - assert "pages" in data - assert len(data["pages"]) > 0 - - # Clean up - tmp_path.unlink() diff --git a/tests/test_layout.py b/tests/test_layout.py deleted file mode 100644 index cceb632..0000000 --- a/tests/test_layout.py +++ /dev/null @@ -1,115 +0,0 @@ -import fitz - -from docketanalyzer_ocr.document import page_to_image -from docketanalyzer_ocr.layout import ( - boxes_overlap, - load_model, - merge_boxes, - merge_overlapping_blocks, - predict_layout, -) - - -class TestLayout: - """Tests for layout analysis functionality.""" - - def test_boxes_overlap(self): - """Test detection of overlapping and non-overlapping boxes.""" - # Overlapping boxes - box1 = (10, 10, 50, 50) - box2 = (30, 30, 70, 70) - assert boxes_overlap(box1, box2) is True - - # Box contained within another - box3 = (20, 20, 40, 40) - assert boxes_overlap(box1, box3) is True - - # Edge touching - box4 = (50, 10, 90, 50) - assert boxes_overlap(box1, box4) is True - - # Non-overlapping boxes - box5 = (60, 60, 100, 100) - assert boxes_overlap(box1, box5) is False - - def test_merge_boxes(self): - """Test merging of bounding boxes.""" - # Overlapping boxes - box1 = (10, 10, 50, 50) - box2 = (30, 30, 70, 70) - merged = merge_boxes(box1, box2) - assert merged == (10, 10, 70, 70) - - # Non-overlapping boxes - box3 = (100, 100, 150, 150) - merged2 = merge_boxes(box1, box3) - assert merged2 == (10, 10, 150, 150) - - def test_merge_overlapping_blocks(self): - """Test merging overlapping layout blocks.""" - # Test with empty input - assert merge_overlapping_blocks([]) == [] - - # Test with blocks to merge - blocks = [ - {"type": "text", "bbox": (10, 10, 50, 50)}, - {"type": "text", "bbox": (30, 30, 70, 70)}, - {"type": "figure", "bbox": (100, 100, 150, 150)}, - {"type": "title", "bbox": (20, 20, 40, 40)}, # Title has higher priority - ] - - result = merge_overlapping_blocks(blocks) - - assert len(result) == 2 # Should merge the overlapping blocks - - # First block should be the merged one with type 'title' (highest priority) - assert result[0]["type"] == "title" - assert result[0]["bbox"] == (10, 10, 70, 70) - - # Second block should be the figure - assert result[1]["type"] == "figure" - assert result[1]["bbox"] == (100, 100, 150, 150) - - def test_model_loading(self): - """Test loading the layout model.""" - # Reset global model to ensure it's loaded - import docketanalyzer_ocr.layout - - docketanalyzer_ocr.layout.LAYOUT_MODEL = None - - # Load the model - model, device = load_model() - - # Verify model was loaded - assert model is not None - assert device in ["cpu", "cuda"] - - # Verify the model is cached - model2, _ = load_model() - assert model2 is model # Should be the same instance - - def test_layout_prediction(self, test_pdf_path): - """Test layout prediction on a real PDF.""" - # Open the PDF and get the first page - doc = fitz.open(test_pdf_path) - page = doc[0] - - # Convert the page to an image - image = page_to_image(page, dpi=150) - - # Run layout analysis - layout_results = predict_layout([image]) - - # Verify layout results - assert isinstance(layout_results, list) - assert len(layout_results) > 0 - - # Check structure of layout results - for block in layout_results[0]: - assert "bbox" in block - assert "type" in block - assert isinstance(block["bbox"], tuple) or isinstance(block["bbox"], list) - assert len(block["bbox"]) == 4 - 
assert block["type"] in ["text", "title", "list", "table", "figure"] - - doc.close() diff --git a/tests/test_local.py b/tests/test_local.py new file mode 100644 index 0000000..654fd27 --- /dev/null +++ b/tests/test_local.py @@ -0,0 +1,47 @@ +from docketanalyzer_ocr import pdf_document + +from .conftest import compare_docs + + +def test_load(sample_pdf_path, sample_pdf_json): + """Test loading a PDF document from serialized data.""" + # Load from path + sample_pdf_json_path = sample_pdf_path.with_suffix(".json") + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json_path) + + # Load from dict + doc2 = pdf_document(sample_pdf_path, load=sample_pdf_json) + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" + + edited_pdf_json = sample_pdf_json.copy() + edited_pdf_json["pages"][0]["blocks"][0]["lines"][0]["content"] = "Edited text" + doc3 = pdf_document(sample_pdf_path, load=edited_pdf_json) + assert not compare_docs(doc1, doc3), "Edited documents should not be equal" + + +def test_local_process(sample_pdf_path, sample_pdf_json): + """Test process method, loading pdf from path.""" + # Sample doc for comparison + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json) + + # Doc to process from path + doc2 = pdf_document(sample_pdf_path).process() + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" + + +def test_local_stream(sample_pdf_path, sample_pdf_json): + """Test stream method, loading pdf from bytes.""" + # Sample doc for comparison + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json) + + # Doc to process from bytes + doc2 = pdf_document(sample_pdf_path.read_bytes()) + for _ in doc2.stream(): + pass + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" diff --git a/tests/test_ocr.py b/tests/test_ocr.py deleted file mode 100644 index 53abf3d..0000000 --- a/tests/test_ocr.py +++ /dev/null @@ -1,59 +0,0 @@ -import time - -import fitz - -from docketanalyzer_ocr.document import page_to_image -from docketanalyzer_ocr.ocr import extract_ocr_text, load_model - - -class TestOCR: - """Tests for OCR functionality.""" - - def test_model_loading(self): - """Test loading the OCR model.""" - # Reset global model to ensure it's loaded - import docketanalyzer_ocr.ocr - - docketanalyzer_ocr.ocr.OCR_MODEL = None - - # Load the model - model, device = load_model() - - # Verify model was loaded - assert model is not None - assert device in ["cpu", "cuda"] - - # Verify the model is cached - model2, _ = load_model() - assert model2 is model # Should be the same instance - - def test_real_ocr_extraction(self, test_pdf_path): - """Test actual OCR extraction on a real PDF document.""" - # Open the PDF and get the first page - doc = fitz.open(test_pdf_path) - page = doc[0] - - # Convert the page to an image - image = page_to_image(page, dpi=150) # Lower DPI for faster processing - - # Run OCR on the image - ocr_start = time.time() - ocr_results = extract_ocr_text(image) - ocr_time = time.time() - ocr_start - - # Verify OCR results - assert isinstance(ocr_results, list) - assert len(ocr_results) > 0 - - # Check structure of OCR results - for result in ocr_results: - assert "bbox" in result - assert "content" in result - assert isinstance(result["bbox"], list) - assert len(result["bbox"]) == 4 - assert isinstance(result["content"], str) - - # Verify
OCR actually ran (should take some time) - assert ocr_time > 0.1 - - doc.close() diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..79cc9dc --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,25 @@ +import logging +import subprocess +import sys + +from docketanalyzer_core import notabs + + +def test_import_time(): + """Test the import time of the package.""" + timing_code = notabs(""" + import time + start = time.time() + import docketanalyzer_ocr + end = time.time() + print(end - start) + """) + + result = subprocess.run( + [sys.executable, "-c", timing_code], capture_output=True, text=True, check=True + ) + + import_time = float(result.stdout.strip()) + + logging.info(f"docketanalyzer_ocr import time: {import_time:.4f} seconds") + assert import_time < 1, f"Import time is too long: {import_time} seconds" diff --git a/tests/test_remote.py b/tests/test_remote.py deleted file mode 100644 index 3f37421..0000000 --- a/tests/test_remote.py +++ /dev/null @@ -1,63 +0,0 @@ -from datetime import datetime - -from docketanalyzer_ocr.document import PDFDocument -from docketanalyzer_ocr.remote import RemoteClient -from docketanalyzer_ocr.utils import ( - RUNPOD_API_KEY, - RUNPOD_OCR_ENDPOINT_ID, -) - - -class TestRemote: - """Tests for remote processing functionality.""" - - def test_client_initialization(self): - """Test remote client initialization.""" - client = RemoteClient() - - assert client.api_key == RUNPOD_API_KEY - assert client.base_url == f"https://api.runpod.ai/v2/{RUNPOD_OCR_ENDPOINT_ID}" - assert client.headers == { - "Authorization": f"Bearer {RUNPOD_API_KEY}", - "Content-Type": "application/json", - } - - def test_client_initialization_with_custom_url(self): - """Test remote client initialization with custom endpoint URL.""" - custom_url = "http://example.com/api" - client = RemoteClient(endpoint_url=custom_url) - - assert client.base_url == custom_url - assert client.headers["Content-Type"] == "application/json" - - def test_client_health_check(self): - """Test client health check.""" - client = RemoteClient() - health = client.get_health() - - assert isinstance(health, dict) - assert "workers" in health - assert "jobs" in health - - def test_remote_processing(self, test_pdf_path): - """Test remote document processing.""" - # Load the test PDF - test_pdf_bytes = test_pdf_path.read_bytes() - - # Process with remote=True - start = datetime.now() - doc = PDFDocument(test_pdf_bytes, remote=True) - - # Test streaming API - processed_pages = [] - for page in doc.stream(): - processed_pages.append(page) - assert len(page.text) > 0 - assert len(page.blocks) > 0 - - # Verify all pages were processed - assert len(processed_pages) == len(doc.pages) - - # The test should take a reasonable amount of time for real streaming - elapsed_time = (datetime.now() - start).total_seconds() - assert elapsed_time > 1.0, "Test completed too quickly for real streaming" diff --git a/tests/test_runpod.py b/tests/test_runpod.py new file mode 100644 index 0000000..d11fe20 --- /dev/null +++ b/tests/test_runpod.py @@ -0,0 +1,52 @@ +from docketanalyzer_core import env +from docketanalyzer_ocr import pdf_document + +from .conftest import compare_docs + + +def test_runpod_post(monkeypatch, sample_pdf_path, sample_pdf_json): + """Test process method, loading pdf from path.""" + key_check = bool(env.RUNPOD_API_KEY) + assert key_check, "RUNPOD_API_KEY is not set" + key_check = bool(env.RUNPOD_OCR_ENDPOINT_ID) + assert key_check, "RUNPOD_OCR_ENDPOINT_ID is not set" + + doc1 = 
pdf_document(sample_pdf_path, load=sample_pdf_json) + doc2 = pdf_document(sample_pdf_path, remote=True, use_s3=False) + + endpoint_check = env.RUNPOD_OCR_ENDPOINT_ID in doc2.remote_client.base_url + assert endpoint_check, "Endpoint ID not found in remote client URL" + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "disabled") + key_check = env.AWS_ACCESS_KEY_ID == "disabled" + assert key_check, "AWS_ACCESS_KEY_ID not disabled" + + for _ in doc2.stream(batch_size=2): + pass + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" + + +def test_runpod_s3(sample_pdf_path, sample_pdf_json): + """Test stream method via RunPod with S3 enabled.""" + key_check = bool(env.AWS_S3_BUCKET_NAME) + assert key_check, "AWS_S3_BUCKET_NAME is not set" + key_check = bool(env.AWS_ACCESS_KEY_ID) + assert key_check, "AWS_ACCESS_KEY_ID is not set" + key_check = bool(env.AWS_SECRET_ACCESS_KEY) + assert key_check, "AWS_SECRET_ACCESS_KEY is not set" + key_check = bool(env.AWS_S3_ENDPOINT_URL) + assert key_check, "AWS_S3_ENDPOINT_URL is not set" + + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json) + doc2 = pdf_document(sample_pdf_path, remote=True) + + assert doc2.s3_available, "S3 availability check failed" + assert doc2.use_s3, "S3 setting didn't default to true" + + for _ in doc2.stream(): + pass + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" diff --git a/tests/test_service.py b/tests/test_service.py index fc560f5..b0f23b6 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -1,47 +1,16 @@ -import base64 -import json -import logging +import importlib +import multiprocessing +import sys import time import pytest -import requests import uvicorn -from fastapi.testclient import TestClient -from service import app, jobs +from docketanalyzer_core import env +from docketanalyzer_ocr import pdf_document +from service import app -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", - datefmt="%H:%M:%S", - force=True, # Force reconfiguration of the root logger -) -logger = logging.getLogger(__name__) - -# Ensure logs are displayed during test execution -logging.getLogger().setLevel(logging.INFO) - - -def log_timing(start_time, message): - """Log elapsed time with a message.""" - elapsed = time.time() - start_time - logger.info(f"{message} - Took {elapsed:.2f} seconds") - return time.time() - - -@pytest.fixture -def client(): - """Create a test client for the FastAPI app.""" - return TestClient(app) - - -@pytest.fixture -def clear_jobs(): - """Clear the jobs dictionary before and after each test.""" - jobs.clear() - yield - jobs.clear() +from .conftest import compare_docs def run_server(): @@ -50,131 +19,92 @@ def run_server(): This function is used by the service_process fixture to avoid pickling the FastAPI app directly.
""" - import importlib - import sys - - # Force reload the service module to ensure we get a fresh instance if "service" in sys.modules: importlib.reload(sys.modules["service"]) - # Import the app from the service module - from service import app - - # Run the server - logger.info("Starting FastAPI server") uvicorn.run(app, host="127.0.0.1", port=8000, log_level="error") @pytest.fixture(scope="function") def service_process(): """Start the FastAPI service in a separate process for testing.""" - # Create a process to run the service using the run_server function - logger.info("Creating service process") - - import multiprocessing as mp - - mp.set_start_method("spawn", force=True) - process = mp.Process(target=run_server) + multiprocessing.set_start_method("spawn", force=True) + process = multiprocessing.Process(target=run_server) - # Start the service - start_time = time.time() process.start() - logger.info(f"Service process started with PID: {process.pid}") - - # Wait for the service to start time.sleep(2) - log_timing(start_time, "Service startup wait completed") yield - # Terminate the service process - start_time = time.time() - logger.info(f"Terminating service process with PID: {process.pid}") + time.sleep(0.5) process.terminate() - process.join() - log_timing(start_time, "Service process terminated") - - -class TestOCRService: - """Tests for the OCR service.""" - - def test_service_api_endpoints(self, test_pdf_path, service_process, clear_jobs): - """Test the service API endpoints directly. - - This test ensures that: - 1. The /run endpoint accepts a job and returns a job ID - 2. The /status endpoint returns the correct job status - 3. The /stream endpoint streams job results - """ - logger.info("=" * 80) - logger.info("STARTING TEST: test_service_api_endpoints") - overall_start = time.time() - - # Read the test PDF file - start_time = time.time() - pdf_bytes = test_pdf_path.read_bytes() - logger.info(f"Test PDF path: {test_pdf_path}, size: {len(pdf_bytes)} bytes") - log_timing(start_time, "Read PDF file") - - # Base64 encode the PDF bytes for JSON serialization - start_time = time.time() - pdf_base64 = base64.b64encode(pdf_bytes).decode("utf-8") - log_timing(start_time, "Base64 encode PDF") - - # Submit a job to the service - start_time = time.time() - logger.info("Submitting job to service") - response = requests.post( - "http://127.0.0.1:8000/run", json={"input": {"file": pdf_base64, "filename": "test.pdf", "batch_size": 1}} - ) - # Check that the job was submitted successfully - assert response.status_code == 200, f"Expected status code 200, got {response.status_code}" - job_id = response.json()["id"] - assert job_id, "Job ID should be returned" - logger.info(f"Job submitted successfully with ID: {job_id}") - log_timing(start_time, "Submit job") - - # Wait for the job to start processing (max 30 seconds) - start_time = time.time() - max_attempts = 30 - status = "PENDING" - for attempt in range(max_attempts): - logger.info(f"Checking job status (attempt {attempt + 1}/{max_attempts})") - status_response = requests.post(f"http://127.0.0.1:8000/status/{job_id}") - assert status_response.status_code == 200, f"Status endpoint failed with code {status_response.status_code}" - - status = status_response.json()["status"] - logger.info(f"Job status: {status}") - if status != "PENDING": - break - time.sleep(1) - log_timing(start_time, f"Wait for job to start processing (status: {status})") - - # Check that the job status is either IN_PROGRESS or COMPLETED - assert status in ["IN_PROGRESS", 
"COMPLETED"], f"Job status should be IN_PROGRESS or COMPLETED, got {status}" - - # Test the stream endpoint - start_time = time.time() - logger.info("Testing stream endpoint") - stream_response = requests.post(f"http://127.0.0.1:8000/stream/{job_id}", stream=True) - assert stream_response.status_code == 200, f"Stream endpoint failed with code {stream_response.status_code}" - - # Check that we get at least one result from the stream - results = [] - page_count = 0 - logger.info("Reading complete stream response") - for line in stream_response.iter_lines(): - if line: - result = json.loads(line.decode("utf-8")) - logger.info(f"Received stream result: {result.keys()}") - results.append(result) - if "stream" in result: - for item in result["stream"]: - if "output" in item and "page" in item["output"]: - page_count += 1 - logger.info(f"Processed page {page_count}") - - assert len(results) > 0, "Should get at least one result from the stream" - log_timing(start_time, f"Stream test completed with {len(results)} results") - - log_timing(overall_start, "COMPLETED TEST: test_service_api_endpoints") + process.join(timeout=3) + + if process.is_alive(): + print("Process didn't terminate gracefully, forcing kill") + process.kill() + process.join(timeout=2) + + assert not process.is_alive(), "Failed to terminate the service process" + + +def test_service_post(monkeypatch, sample_pdf_path, sample_pdf_json, service_process): + """Test service with post, explicitly overriding endpoint.""" + time.sleep(5) # Let service start + # Do this to confirm override + key_check = bool(env.RUNPOD_API_KEY) + assert key_check, "RUNPOD_API_KEY is not set" + key_check = bool(env.RUNPOD_OCR_ENDPOINT_ID) + assert key_check, "RUNPOD_OCR_ENDPOINT_ID is not set" + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "disabled") + key_check = env.AWS_ACCESS_KEY_ID == "disabled" + assert key_check, "AWS_ACCESS_KEY_ID not disabled" + + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json) + doc2 = pdf_document( + sample_pdf_path, + remote=True, + use_s3=False, + endpoint_url="http://localhost:8000/", + ) + + endpoint_check = "localhost" in doc2.remote_client.base_url + assert endpoint_check, "Not using local endpoint" + + status = doc2.remote_client.get_health() + assert status["workers"]["total"] == 1, "No active workers found" + + for _ in doc2.stream(): + pass + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" + + +def test_service_s3(monkeypatch, sample_pdf_path, sample_pdf_json, service_process): + """Test service with s3, with implicit local endpoint.""" + time.sleep(5) # Let service start + monkeypatch.setenv("RUNPOD_OCR_ENDPOINT_ID", "") + key_check = bool(env.RUNPOD_OCR_ENDPOINT_ID) + assert not key_check, "RUNPOD_OCR_ENDPOINT_ID is not disabled" + + doc1 = pdf_document(sample_pdf_path, load=sample_pdf_json) + doc2 = pdf_document( + sample_pdf_path, + remote=True, + use_s3=True, + endpoint_url="http://localhost:8000/", + ) + + endpoint_check = "localhost" in doc2.remote_client.base_url + assert endpoint_check, "Not using local endpoint" + assert doc2.s3_available, "S3 availability check failed" + + status = doc2.remote_client.get_health() + assert status["workers"]["total"] == 1, "No active workers found" + + doc2.process() + + assert len(doc1) == len(doc2), "Document lengths do not match" + assert compare_docs(doc1, doc2), "Processed documents are not equal" diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 
index 68235d9..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -import tempfile -from pathlib import Path - -import pytest - -from docketanalyzer_ocr.utils import load_pdf - - -class TestUtils: - """Tests for utility functions.""" - - def test_load_pdf_from_bytes(self, sample_pdf_bytes): - """Test loading a PDF from bytes.""" - pdf_bytes, filename = load_pdf(file=sample_pdf_bytes, filename="test.pdf") - - assert pdf_bytes == sample_pdf_bytes - assert filename == "test.pdf" - - def test_load_pdf_missing_params(self): - """Test loading a PDF with missing parameters.""" - with pytest.raises(ValueError): - load_pdf() # No file or s3_key provided - - def test_s3_operations(self): - """Test S3 operations with real AWS resources.""" - - # Create a test file - with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp: - tmp.write(b"Test S3 file content") - tmp_path = Path(tmp.name) - - try: - # Import here to avoid errors if AWS credentials are not set - from docketanalyzer_ocr.utils import download_from_s3, upload_to_s3 - - # Upload to S3 - s3_key = f"test/test_file_{os.urandom(4).hex()}.txt" - upload_success = upload_to_s3(tmp_path, s3_key, overwrite=True) - - assert upload_success is True - - # Download from S3 - with tempfile.TemporaryDirectory() as tmp_dir: - download_path = Path(tmp_dir) / "downloaded.txt" - result_path = download_from_s3(s3_key, download_path) - - assert result_path == download_path - assert download_path.exists() - assert download_path.read_bytes() == b"Test S3 file content" - - finally: - # Clean up - if tmp_path.exists(): - tmp_path.unlink()
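With upload_to_s3 and download_from_s3 gone, the only S3 interaction left in this package is the download inside load_pdf, which now goes through docketanalyzer-core. A rough sketch of that path (hedged: the import location of `load_s3` is assumed; only its `.download()` call appears earlier in this diff):

```python
import tempfile
from pathlib import Path

from docketanalyzer_core import load_s3  # import path assumed, not shown in this diff


def fetch_pdf_bytes(s3_key: str) -> bytes:
    """Download a PDF from S3 the same way load_pdf does above."""
    with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp:
        load_s3().download(s3_key, tmp.name)  # mirrors load_s3().download(s3_key, str(temp_path))
        return Path(tmp.name).read_bytes()


# Usage (placeholder key):
# pdf_bytes = fetch_pdf_bytes("ocr/document.pdf")
```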