S3: Add tests for PyArrow with native S3 filesystem (#7508)
* PyArrow native S3 filesystem
* add sse-s3 tests
* update
* minor
* ENABLE_SSE_S3
* Update test_pyarrow_native_s3.py
* clean up
* refactoring
* Update test_pyarrow_native_s3.py
7 changed files with 1008 additions and 5 deletions
- 22   .github/workflows/s3-parquet-tests.yml
- 92   test/s3/parquet/Makefile
- 87   test/s3/parquet/README.md
- 134  test/s3/parquet/example_pyarrow_native.py
- 41   test/s3/parquet/parquet_test_utils.py
- 383  test/s3/parquet/test_pyarrow_native_s3.py
- 254  test/s3/parquet/test_sse_s3_compatibility.py
test/s3/parquet/example_pyarrow_native.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "pyarrow>=22",
#     "boto3>=1.28.0",
# ]
# ///

"""
Simple example of using PyArrow's native S3 filesystem with SeaweedFS.

This is a minimal example demonstrating how to write and read Parquet files
using PyArrow's built-in S3FileSystem without any additional dependencies
like s3fs.

Usage:
    # Set environment variables
    export S3_ENDPOINT_URL=localhost:8333
    export S3_ACCESS_KEY=some_access_key1
    export S3_SECRET_KEY=some_secret_key1
    export BUCKET_NAME=test-parquet-bucket

    # Run the script
    python3 example_pyarrow_native.py

    # Or run with uv (if available)
    uv run example_pyarrow_native.py
"""

import os
import secrets

import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.parquet as pq

from parquet_test_utils import create_sample_table

# Configuration
BUCKET_NAME = os.getenv("BUCKET_NAME", "test-parquet-bucket")
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "localhost:8333")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY", "some_access_key1")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY", "some_secret_key1")

# Determine scheme from endpoint
if S3_ENDPOINT_URL.startswith("http://"):
    scheme = "http"
    endpoint = S3_ENDPOINT_URL[7:]
elif S3_ENDPOINT_URL.startswith("https://"):
    scheme = "https"
    endpoint = S3_ENDPOINT_URL[8:]
else:
    scheme = "http"  # Default to http for localhost
    endpoint = S3_ENDPOINT_URL

print(f"Connecting to S3 endpoint: {scheme}://{endpoint}")

# Initialize PyArrow's NATIVE S3 filesystem
s3 = pafs.S3FileSystem(
    access_key=S3_ACCESS_KEY,
    secret_key=S3_SECRET_KEY,
    endpoint_override=endpoint,
    scheme=scheme,
    allow_bucket_creation=True,
    allow_bucket_deletion=True,
)

print("✓ Connected to S3 endpoint")


# Create bucket if needed (using boto3)
try:
    import boto3
    from botocore.exceptions import ClientError

    s3_client = boto3.client(
        's3',
        endpoint_url=f"{scheme}://{endpoint}",
        aws_access_key_id=S3_ACCESS_KEY,
        aws_secret_access_key=S3_SECRET_KEY,
        region_name='us-east-1',
    )

    try:
        s3_client.head_bucket(Bucket=BUCKET_NAME)
        print(f"✓ Bucket exists: {BUCKET_NAME}")
    except ClientError as e:
        if e.response['Error']['Code'] == '404':
            print(f"Creating bucket: {BUCKET_NAME}")
            s3_client.create_bucket(Bucket=BUCKET_NAME)
            print(f"✓ Bucket created: {BUCKET_NAME}")
        else:
            raise
except ImportError:
    print("Warning: boto3 not available, assuming bucket exists")

# Generate a unique filename
filename = f"{BUCKET_NAME}/dataset-{secrets.token_hex(8)}/test.parquet"

print(f"\nWriting Parquet dataset to: {filename}")

# Write dataset
table = create_sample_table(200_000)
pads.write_dataset(
    table,
    filename,
    filesystem=s3,
    format="parquet",
)

print(f"✓ Wrote {table.num_rows:,} rows")

# Read with pq.read_table
print("\nReading with pq.read_table...")
table_read = pq.read_table(filename, filesystem=s3)
print(f"✓ Read {table_read.num_rows:,} rows")

# Read with pq.ParquetDataset
print("\nReading with pq.ParquetDataset...")
dataset = pq.ParquetDataset(filename, filesystem=s3)
table_dataset = dataset.read()
print(f"✓ Read {table_dataset.num_rows:,} rows")

# Read with pads.dataset
print("\nReading with pads.dataset...")
dataset_pads = pads.dataset(filename, filesystem=s3)
table_pads = dataset_pads.to_table()
print(f"✓ Read {table_pads.num_rows:,} rows")

print("\n✅ All operations completed successfully!")
print(f"\nFile written to: {filename}")
print("You can verify the file using the SeaweedFS S3 API or weed shell")

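# Optional verification sketch (added as an illustration, not part of the original
# example): if the boto3 client above was created, list the objects that
# pads.write_dataset produced under the dataset prefix. write_dataset treats the
# given path as a base directory and writes one or more part files inside it.
try:
    prefix = filename.split("/", 1)[1]  # strip the bucket name from the path
    listing = s3_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=prefix)
    for obj in listing.get("Contents", []):
        print(f"  - {obj['Key']} ({obj['Size']} bytes)")
except NameError:
    # boto3 was not importable above, so s3_client does not exist; skip verification
    pass
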
test/s3/parquet/parquet_test_utils.py
@@ -0,0 +1,41 @@
"""
Shared utility functions for PyArrow Parquet tests.

This module provides common test utilities used across multiple test scripts
to avoid code duplication and ensure consistency.
"""

import pyarrow as pa


def create_sample_table(num_rows: int = 5) -> pa.Table:
    """Create a sample PyArrow table for testing.

    Args:
        num_rows: Number of rows to generate (default: 5)

    Returns:
        PyArrow Table with test data containing:
        - id: int64 sequential IDs (0 to num_rows-1)
        - name: string user names (user_0, user_1, ...)
        - value: float64 values (id * 1.5)
        - flag: bool alternating True/False based on even/odd id

    Example:
        >>> table = create_sample_table(3)
        >>> print(table)
        pyarrow.Table
        id: int64
        name: string
        value: double
        flag: bool
    """
    return pa.table(
        {
            "id": pa.array(range(num_rows), type=pa.int64()),
            "name": pa.array([f"user_{i}" for i in range(num_rows)], type=pa.string()),
            "value": pa.array([float(i) * 1.5 for i in range(num_rows)], type=pa.float64()),
            "flag": pa.array([i % 2 == 0 for i in range(num_rows)], type=pa.bool_()),
        }
    )

test/s3/parquet/test_pyarrow_native_s3.py
@@ -0,0 +1,383 @@
#!/usr/bin/env python3
"""
Test script for PyArrow's NATIVE S3 filesystem with SeaweedFS.

This test uses PyArrow's built-in S3FileSystem (pyarrow.fs.S3FileSystem)
instead of s3fs, providing a pure PyArrow solution for reading and writing
Parquet files to S3-compatible storage.

Requirements:
    - pyarrow>=10.0.0

Environment Variables:
    S3_ENDPOINT_URL: S3 endpoint (default: localhost:8333)
    S3_ACCESS_KEY: S3 access key (default: some_access_key1)
    S3_SECRET_KEY: S3 secret key (default: some_secret_key1)
    BUCKET_NAME: S3 bucket name (default: test-parquet-bucket)
    TEST_QUICK: Run only small/quick tests (default: 0, set to 1 for quick mode)

Usage:
    # Run with default environment variables
    python3 test_pyarrow_native_s3.py

    # Run with custom environment variables
    S3_ENDPOINT_URL=localhost:8333 \
    S3_ACCESS_KEY=mykey \
    S3_SECRET_KEY=mysecret \
    BUCKET_NAME=mybucket \
    python3 test_pyarrow_native_s3.py
"""

import os
import secrets
import sys
import logging
from typing import Optional

import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.parquet as pq

try:
    import boto3
    from botocore.exceptions import ClientError
    HAS_BOTO3 = True
except ImportError:
    HAS_BOTO3 = False

from parquet_test_utils import create_sample_table

logging.basicConfig(level=logging.INFO, format="%(message)s")

# Configuration from environment variables with defaults
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "localhost:8333")
S3_ACCESS_KEY = os.environ.get("S3_ACCESS_KEY", "some_access_key1")
S3_SECRET_KEY = os.environ.get("S3_SECRET_KEY", "some_secret_key1")
BUCKET_NAME = os.getenv("BUCKET_NAME", "test-parquet-bucket")
TEST_QUICK = os.getenv("TEST_QUICK", "0") == "1"

# Create randomized test directory
TEST_RUN_ID = secrets.token_hex(8)
TEST_DIR = f"parquet-native-tests/{TEST_RUN_ID}"

# Test file sizes
TEST_SIZES = {
    "small": 5,
    "large": 200_000,  # This will create multiple row groups
}

# Filter to only small tests if quick mode is enabled
if TEST_QUICK:
    TEST_SIZES = {"small": TEST_SIZES["small"]}
    logging.info("Quick test mode enabled - running only small tests")


def init_s3_filesystem() -> tuple[Optional[pafs.S3FileSystem], str, str]:
    """Initialize PyArrow's native S3 filesystem.

    Returns:
        tuple: (S3FileSystem instance, scheme, endpoint)
    """
    try:
        logging.info("Initializing PyArrow S3FileSystem...")
        logging.info(f"  Endpoint: {S3_ENDPOINT_URL}")
        logging.info(f"  Bucket: {BUCKET_NAME}")

        # Determine scheme from endpoint
        if S3_ENDPOINT_URL.startswith("http://"):
            scheme = "http"
            endpoint = S3_ENDPOINT_URL[7:]  # Remove http://
        elif S3_ENDPOINT_URL.startswith("https://"):
            scheme = "https"
            endpoint = S3_ENDPOINT_URL[8:]  # Remove https://
        else:
            # Default to http for localhost
            scheme = "http"
            endpoint = S3_ENDPOINT_URL

        # Enable bucket creation and deletion for testing
        s3 = pafs.S3FileSystem(
            access_key=S3_ACCESS_KEY,
            secret_key=S3_SECRET_KEY,
            endpoint_override=endpoint,
            scheme=scheme,
            allow_bucket_creation=True,
            allow_bucket_deletion=True,
        )

        logging.info("✓ PyArrow S3FileSystem initialized successfully\n")
        return s3, scheme, endpoint
    except Exception:
        logging.exception("✗ Failed to initialize PyArrow S3FileSystem")
        return None, "", ""


def ensure_bucket_exists_boto3(scheme: str, endpoint: str) -> bool:
    """Ensure the test bucket exists using boto3."""
    if not HAS_BOTO3:
        logging.error("boto3 is required for bucket creation")
        return False

    try:
        # Create boto3 client
        endpoint_url = f"{scheme}://{endpoint}"
        s3_client = boto3.client(
            's3',
            endpoint_url=endpoint_url,
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            region_name='us-east-1',
        )

        # Check if bucket exists
        try:
            s3_client.head_bucket(Bucket=BUCKET_NAME)
            logging.info(f"✓ Bucket exists: {BUCKET_NAME}")
            return True
        except ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                # Bucket doesn't exist, create it
                logging.info(f"Creating bucket: {BUCKET_NAME}")
                s3_client.create_bucket(Bucket=BUCKET_NAME)
                logging.info(f"✓ Bucket created: {BUCKET_NAME}")
                return True
            else:
                raise
    except Exception:
        logging.exception("✗ Failed to create/check bucket")
        return False


def ensure_bucket_exists(s3: pafs.S3FileSystem) -> bool:
    """Ensure the test bucket exists using PyArrow's native S3FileSystem."""
    try:
        # Check if bucket exists by trying to list it
        try:
            file_info = s3.get_file_info(BUCKET_NAME)
            if file_info.type == pafs.FileType.Directory:
                logging.info(f"✓ Bucket exists: {BUCKET_NAME}")
                return True
        except OSError as e:
            # OSError typically means bucket not found or network/permission issues
            error_msg = str(e).lower()
            if "not found" in error_msg or "does not exist" in error_msg or "nosuchbucket" in error_msg:
                logging.debug(f"Bucket '{BUCKET_NAME}' not found, will attempt creation: {e}")
            else:
                # Log other OSErrors (network, auth, etc.) for debugging
                logging.debug(f"Error checking bucket '{BUCKET_NAME}', will attempt creation anyway: {type(e).__name__}: {e}")
        except Exception as e:
            # Catch any other unexpected exceptions and log them
            logging.debug(f"Unexpected error checking bucket '{BUCKET_NAME}', will attempt creation: {type(e).__name__}: {e}")

        # Try to create the bucket
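        # Added note: with pyarrow.fs.S3FileSystem, create_dir() on a top-level path
        # creates the bucket itself, which is why the filesystem is constructed with
        # allow_bucket_creation=True.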
        logging.info(f"Creating bucket: {BUCKET_NAME}")
        s3.create_dir(BUCKET_NAME)
        logging.info(f"✓ Bucket created: {BUCKET_NAME}")
        return True
    except Exception:
        logging.exception(f"✗ Failed to create/check bucket '{BUCKET_NAME}' with PyArrow")
        return False


def test_write_and_read(s3: pafs.S3FileSystem, test_name: str, num_rows: int) -> tuple[bool, str]:
    """Test writing and reading a Parquet dataset using PyArrow's native S3 filesystem."""
    try:
        table = create_sample_table(num_rows)

        # Write using pads.write_dataset
        filename = f"{BUCKET_NAME}/{TEST_DIR}/{test_name}/data.parquet"
        logging.info(f"  Writing {num_rows:,} rows to {filename}...")

        pads.write_dataset(
            table,
            filename,
            filesystem=s3,
            format="parquet",
        )
        logging.info("  ✓ Write completed")

        # Test Method 1: Read with pq.read_table
        logging.info("  Reading with pq.read_table...")
        table_read = pq.read_table(filename, filesystem=s3)
        if table_read.num_rows != num_rows:
            return False, f"pq.read_table: Row count mismatch (expected {num_rows}, got {table_read.num_rows})"

        # Check schema first
        if not table_read.schema.equals(table.schema):
            return False, f"pq.read_table: Schema mismatch (expected {table.schema}, got {table_read.schema})"

        # Sort both tables by 'id' column before comparison to handle potential row order differences
        table_sorted = table.sort_by([('id', 'ascending')])
        table_read_sorted = table_read.sort_by([('id', 'ascending')])

        if not table_read_sorted.equals(table_sorted):
            # Provide more detailed error information
            error_details = []
            for col_name in table.column_names:
                col_original = table_sorted.column(col_name)
                col_read = table_read_sorted.column(col_name)
                if not col_original.equals(col_read):
                    error_details.append(f"column '{col_name}' differs")
            return False, f"pq.read_table: Table contents mismatch ({', '.join(error_details)})"
        logging.info(f"  ✓ pq.read_table: {table_read.num_rows:,} rows")

        # Test Method 2: Read with pq.ParquetDataset
        logging.info("  Reading with pq.ParquetDataset...")
        dataset = pq.ParquetDataset(filename, filesystem=s3)
        table_dataset = dataset.read()
        if table_dataset.num_rows != num_rows:
            return False, f"pq.ParquetDataset: Row count mismatch (expected {num_rows}, got {table_dataset.num_rows})"

        # Sort before comparison
        table_dataset_sorted = table_dataset.sort_by([('id', 'ascending')])
        if not table_dataset_sorted.equals(table_sorted):
            error_details = []
            for col_name in table.column_names:
                col_original = table_sorted.column(col_name)
                col_read = table_dataset_sorted.column(col_name)
                if not col_original.equals(col_read):
                    error_details.append(f"column '{col_name}' differs")
            return False, f"pq.ParquetDataset: Table contents mismatch ({', '.join(error_details)})"
        logging.info(f"  ✓ pq.ParquetDataset: {table_dataset.num_rows:,} rows")

        # Test Method 3: Read with pads.dataset
        logging.info("  Reading with pads.dataset...")
        dataset_pads = pads.dataset(filename, filesystem=s3)
        table_pads = dataset_pads.to_table()
        if table_pads.num_rows != num_rows:
            return False, f"pads.dataset: Row count mismatch (expected {num_rows}, got {table_pads.num_rows})"

        # Sort before comparison
        table_pads_sorted = table_pads.sort_by([('id', 'ascending')])
        if not table_pads_sorted.equals(table_sorted):
            error_details = []
            for col_name in table.column_names:
                col_original = table_sorted.column(col_name)
                col_read = table_pads_sorted.column(col_name)
                if not col_original.equals(col_read):
                    error_details.append(f"column '{col_name}' differs")
            return False, f"pads.dataset: Table contents mismatch ({', '.join(error_details)})"
        logging.info(f"  ✓ pads.dataset: {table_pads.num_rows:,} rows")

        return True, "All read methods passed"

    except Exception as exc:
        logging.exception("  ✗ Test failed")
        return False, f"{type(exc).__name__}: {exc}"


def cleanup_test_files(s3: pafs.S3FileSystem) -> None:
    """Clean up test files from S3.

    Note: We cannot use s3.delete_dir() directly because SeaweedFS uses implicit
    directories (path prefixes without physical directory objects). PyArrow's
    delete_dir() attempts to delete the directory marker itself, which fails with
    "INTERNAL_FAILURE" on SeaweedFS. Instead, we list and delete files individually,
    letting implicit directories disappear automatically.
    """
    try:
        test_path = f"{BUCKET_NAME}/{TEST_DIR}"
        logging.info(f"Cleaning up test directory: {test_path}")

        # List and delete files individually to handle implicit directories
        try:
            file_selector = pafs.FileSelector(test_path, recursive=True)
            files = s3.get_file_info(file_selector)

            # Delete files first (not directories)
            for file_info in files:
                if file_info.type == pafs.FileType.File:
                    s3.delete_file(file_info.path)
                    logging.debug(f"  Deleted file: {file_info.path}")

            logging.info("✓ Test directory cleaned up")
        except OSError as e:
            # Handle the case where the path doesn't exist or is inaccessible
            if "does not exist" in str(e).lower() or "not found" in str(e).lower():
                logging.info("✓ Test directory already clean or doesn't exist")
            else:
                raise
    except Exception:
        logging.exception("Failed to cleanup test directory")


def main():
    """Run all tests with PyArrow's native S3 filesystem."""
    print("=" * 80)
    print("PyArrow Native S3 Filesystem Tests for SeaweedFS")
    print("Testing Parquet Files with Multiple Row Groups")
    if TEST_QUICK:
        print("*** QUICK TEST MODE - Small files only ***")
    print("=" * 80 + "\n")

    print("Configuration:")
    print(f"  S3 Endpoint: {S3_ENDPOINT_URL}")
    print(f"  Access Key: {S3_ACCESS_KEY}")
    print(f"  Bucket: {BUCKET_NAME}")
    print(f"  Test Directory: {TEST_DIR}")
    print(f"  Quick Mode: {'Yes (small files only)' if TEST_QUICK else 'No (all file sizes)'}")
    print(f"  PyArrow Version: {pa.__version__}")
    print()

    # Initialize S3 filesystem
    s3, scheme, endpoint = init_s3_filesystem()
    if s3 is None:
        print("Cannot proceed without S3 connection")
        return 1

    # Ensure bucket exists - try PyArrow first, fall back to boto3
    bucket_created = ensure_bucket_exists(s3)
    if not bucket_created:
        logging.info("Trying to create bucket with boto3...")
        bucket_created = ensure_bucket_exists_boto3(scheme, endpoint)

    if not bucket_created:
        print("Cannot proceed without bucket")
        return 1

    results = []

    # Test all file sizes
    for size_name, num_rows in TEST_SIZES.items():
        print(f"\n{'='*80}")
        print(f"Testing with {size_name} files ({num_rows:,} rows)")
        print(f"{'='*80}\n")

        test_name = f"{size_name}_test"
        success, message = test_write_and_read(s3, test_name, num_rows)
        results.append((test_name, success, message))

        status = "✓ PASS" if success else "✗ FAIL"
        print(f"\n{status}: {message}\n")

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    passed = sum(1 for _, success, _ in results if success)
    total = len(results)
    print(f"\nTotal: {passed}/{total} passed\n")

    for test_name, success, message in results:
        status = "✓" if success else "✗"
        print(f"  {status} {test_name}: {message}")

    print("\n" + "=" * 80)
    if passed == total:
        print("✓ ALL TESTS PASSED!")
    else:
        print(f"✗ {total - passed} test(s) failed")

    print("=" * 80 + "\n")

    # Cleanup
    cleanup_test_files(s3)

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())

test/s3/parquet/test_sse_s3_compatibility.py
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
Test script for SSE-S3 compatibility with PyArrow native S3 filesystem.

This test specifically targets the SSE-S3 multipart upload bug where
SeaweedFS panics with "bad IV length" when reading multipart uploads
that were encrypted with bucket-default SSE-S3.

Requirements:
    - pyarrow>=10.0.0
    - boto3>=1.28.0

Environment Variables:
    S3_ENDPOINT_URL: S3 endpoint (default: localhost:8333)
    S3_ACCESS_KEY: S3 access key (default: some_access_key1)
    S3_SECRET_KEY: S3 secret key (default: some_secret_key1)
    BUCKET_NAME: S3 bucket name (default: test-parquet-bucket)

Usage:
    # Start SeaweedFS with SSE-S3 enabled
    make start-seaweedfs-ci ENABLE_SSE_S3=true

    # Run the test
    python3 test_sse_s3_compatibility.py
"""

import os
import secrets
import sys
import logging
from typing import Optional

import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.parquet as pq

try:
    import boto3
    from botocore.exceptions import ClientError
    HAS_BOTO3 = True
except ImportError:
    HAS_BOTO3 = False
    logging.exception("boto3 is required for this test")
    sys.exit(1)

from parquet_test_utils import create_sample_table

logging.basicConfig(level=logging.INFO, format="%(message)s")

# Configuration
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "localhost:8333")
S3_ACCESS_KEY = os.environ.get("S3_ACCESS_KEY", "some_access_key1")
S3_SECRET_KEY = os.environ.get("S3_SECRET_KEY", "some_secret_key1")
BUCKET_NAME = os.getenv("BUCKET_NAME", "test-parquet-bucket")

TEST_RUN_ID = secrets.token_hex(8)
TEST_DIR = f"sse-s3-tests/{TEST_RUN_ID}"

# Test sizes designed to trigger multipart uploads
# PyArrow typically uses 5MB chunks, so these sizes should trigger multipart
TEST_SIZES = {
    "tiny": 10,             # Single part
    "small": 1_000,         # Single part
    "medium": 50_000,       # Single part (~1.5MB)
    "large": 200_000,       # Multiple parts (~6MB)
    "very_large": 500_000,  # Multiple parts (~15MB)
}
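# Added sizing note (illustrative assumption: each row of the sample schema
# serializes to roughly 30 bytes of Parquet data, consistent with the estimates above):
#   200_000 rows * ~30 B ~= 6 MB  -> exceeds a single 5 MB part, forcing a multipart upload
#   50_000 rows  * ~30 B ~= 1.5 MB -> stays within a single part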


def init_s3_filesystem() -> tuple[Optional[pafs.S3FileSystem], str, str]:
    """Initialize PyArrow's native S3 filesystem."""
    try:
        logging.info("Initializing PyArrow S3FileSystem...")

        # Determine scheme from endpoint
        if S3_ENDPOINT_URL.startswith("http://"):
            scheme = "http"
            endpoint = S3_ENDPOINT_URL[7:]
        elif S3_ENDPOINT_URL.startswith("https://"):
            scheme = "https"
            endpoint = S3_ENDPOINT_URL[8:]
        else:
            scheme = "http"
            endpoint = S3_ENDPOINT_URL

        s3 = pafs.S3FileSystem(
            access_key=S3_ACCESS_KEY,
            secret_key=S3_SECRET_KEY,
            endpoint_override=endpoint,
            scheme=scheme,
            allow_bucket_creation=True,
            allow_bucket_deletion=True,
        )

        logging.info("✓ PyArrow S3FileSystem initialized\n")
        return s3, scheme, endpoint
    except Exception:
        logging.exception("✗ Failed to initialize PyArrow S3FileSystem")
        return None, "", ""


def ensure_bucket_exists(scheme: str, endpoint: str) -> bool:
    """Ensure the test bucket exists using boto3."""
    try:
        endpoint_url = f"{scheme}://{endpoint}"
        s3_client = boto3.client(
            's3',
            endpoint_url=endpoint_url,
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            region_name='us-east-1',
        )

        try:
            s3_client.head_bucket(Bucket=BUCKET_NAME)
            logging.info(f"✓ Bucket exists: {BUCKET_NAME}")
        except ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                logging.info(f"Creating bucket: {BUCKET_NAME}")
                s3_client.create_bucket(Bucket=BUCKET_NAME)
                logging.info(f"✓ Bucket created: {BUCKET_NAME}")
            else:
                logging.exception("✗ Failed to access bucket")
                return False

        # Note: SeaweedFS doesn't support GetBucketEncryption API
        # so we can't verify if SSE-S3 is enabled via API
        # We assume it's configured correctly in the s3.json config file
        logging.info("✓ Assuming SSE-S3 is configured in s3.json")
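        # For reference, on an endpoint that supports GetBucketEncryption the check
        # would look roughly like this (added sketch, deliberately not executed here):
        #   enc = s3_client.get_bucket_encryption(Bucket=BUCKET_NAME)
        #   algo = enc["ServerSideEncryptionConfiguration"]["Rules"][0][
        #       "ApplyServerSideEncryptionByDefault"]["SSEAlgorithm"]
        #   assert algo == "AES256"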
        return True

    except Exception:
        logging.exception("✗ Failed to check bucket")
        return False


def test_write_read_with_sse(
    s3: pafs.S3FileSystem,
    test_name: str,
    num_rows: int
) -> tuple[bool, str, int]:
    """Test writing and reading with SSE-S3 encryption."""
    try:
        table = create_sample_table(num_rows)
        filename = f"{BUCKET_NAME}/{TEST_DIR}/{test_name}/data.parquet"

        logging.info(f"  Writing {num_rows:,} rows...")
        pads.write_dataset(
            table,
            filename,
            filesystem=s3,
            format="parquet",
        )

        logging.info("  Reading back...")
        table_read = pq.read_table(filename, filesystem=s3)

        if table_read.num_rows != num_rows:
            return False, f"Row count mismatch: {table_read.num_rows} != {num_rows}", 0

        return True, "Success", table_read.num_rows

    except Exception as e:
        error_msg = f"{type(e).__name__}: {e!s}"
        logging.exception("  ✗ Failed")
        return False, error_msg, 0


def main():
    """Run SSE-S3 compatibility tests."""
    print("=" * 80)
    print("SSE-S3 Compatibility Tests for PyArrow Native S3")
    print("Testing Multipart Upload Encryption")
    print("=" * 80 + "\n")

    print("Configuration:")
    print(f"  S3 Endpoint: {S3_ENDPOINT_URL}")
    print(f"  Bucket: {BUCKET_NAME}")
    print(f"  Test Directory: {TEST_DIR}")
    print(f"  PyArrow Version: {pa.__version__}")
    print()

    # Initialize
    s3, scheme, endpoint = init_s3_filesystem()
    if s3 is None:
        print("Cannot proceed without S3 connection")
        return 1

    # Check bucket and SSE-S3
    if not ensure_bucket_exists(scheme, endpoint):
        print("\n⚠ WARNING: Failed to access or create the test bucket!")
        print("This test requires a reachable bucket with SSE-S3 enabled.")
        print("Please ensure SeaweedFS is running with: make start-seaweedfs-ci ENABLE_SSE_S3=true")
        return 1

    print()
    results = []

    # Test all sizes
    for size_name, num_rows in TEST_SIZES.items():
        print(f"\n{'='*80}")
        print(f"Testing {size_name} dataset ({num_rows:,} rows)")
        print(f"{'='*80}")

        success, message, rows_read = test_write_read_with_sse(
            s3, size_name, num_rows
        )
        results.append((size_name, num_rows, success, message, rows_read))

        if success:
            print(f"  ✓ SUCCESS: Read {rows_read:,} rows")
        else:
            print(f"  ✗ FAILED: {message}")

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    passed = sum(1 for _, _, success, _, _ in results if success)
    total = len(results)
    print(f"\nTotal: {passed}/{total} tests passed\n")

    print(f"{'Size':<15} {'Rows':>10} {'Status':<10} {'Rows Read':>10} {'Message':<40}")
    print("-" * 90)
    for size_name, num_rows, success, message, rows_read in results:
        status = "✓ PASS" if success else "✗ FAIL"
        rows_str = f"{rows_read:,}" if success else "N/A"
        print(f"{size_name:<15} {num_rows:>10,} {status:<10} {rows_str:>10} {message[:40]}")

    print("\n" + "=" * 80)
    if passed == total:
        print("✓ ALL TESTS PASSED WITH SSE-S3!")
        print("\nThis means:")
        print("  - SSE-S3 encryption is working correctly")
        print("  - PyArrow native S3 filesystem is compatible")
        print("  - Multipart uploads are handled properly")
    else:
        print(f"✗ {total - passed} test(s) failed")
        print("\nPossible issues:")
        print("  - SSE-S3 multipart upload bug with empty IV")
        print("  - Encryption/decryption mismatch")
        print("  - File corruption during upload")

    print("=" * 80 + "\n")

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())
