You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.0 KiB
110 lines
3.0 KiB
package engine
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// TestPostgreSQLOnlySupport ensures that non-PostgreSQL syntax is properly rejected
|
|
func TestPostgreSQLOnlySupport(t *testing.T) {
|
|
engine := NewTestSQLEngine()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
sql string
|
|
shouldError bool
|
|
errorMsg string
|
|
desc string
|
|
}{
|
|
// Test that MySQL backticks are not supported for identifiers
|
|
{
|
|
name: "MySQL_Backticks_Table",
|
|
sql: "SELECT * FROM `user_events` LIMIT 1",
|
|
shouldError: true,
|
|
desc: "MySQL backticks for table names should be rejected",
|
|
},
|
|
{
|
|
name: "MySQL_Backticks_Column",
|
|
sql: "SELECT `column_name` FROM user_events LIMIT 1",
|
|
shouldError: true,
|
|
desc: "MySQL backticks for column names should be rejected",
|
|
},
|
|
|
|
// Test that PostgreSQL double quotes work (should NOT error)
|
|
{
|
|
name: "PostgreSQL_Double_Quotes_OK",
|
|
sql: `SELECT "user_id" FROM user_events LIMIT 1`,
|
|
shouldError: false,
|
|
desc: "PostgreSQL double quotes for identifiers should work",
|
|
},
|
|
|
|
// Note: MySQL functions like YEAR(), MONTH() may parse but won't have proper implementations
|
|
// They're removed from the engine so they won't work correctly, but we don't explicitly reject them
|
|
|
|
// Test that PostgreSQL EXTRACT works (should NOT error)
|
|
{
|
|
name: "PostgreSQL_EXTRACT_OK",
|
|
sql: "SELECT EXTRACT(YEAR FROM CURRENT_DATE) FROM user_events LIMIT 1",
|
|
shouldError: false,
|
|
desc: "PostgreSQL EXTRACT function should work",
|
|
},
|
|
|
|
// Test that single quotes work for string literals but not identifiers
|
|
{
|
|
name: "Single_Quotes_String_Literal_OK",
|
|
sql: "SELECT 'hello world' FROM user_events LIMIT 1",
|
|
shouldError: false,
|
|
desc: "Single quotes for string literals should work",
|
|
},
|
|
}
|
|
|
|
passCount := 0
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result, err := engine.ExecuteSQL(context.Background(), tc.sql)
|
|
|
|
if tc.shouldError {
|
|
// We expect this query to fail
|
|
if err == nil && result.Error == nil {
|
|
t.Errorf("❌ Expected error for %s, but query succeeded", tc.desc)
|
|
return
|
|
}
|
|
|
|
// Check for specific error message if provided
|
|
if tc.errorMsg != "" {
|
|
errorText := ""
|
|
if err != nil {
|
|
errorText = err.Error()
|
|
} else if result.Error != nil {
|
|
errorText = result.Error.Error()
|
|
}
|
|
|
|
if !strings.Contains(errorText, tc.errorMsg) {
|
|
t.Errorf("❌ Expected error containing '%s', got: %s", tc.errorMsg, errorText)
|
|
return
|
|
}
|
|
}
|
|
|
|
t.Logf("CORRECTLY REJECTED: %s", tc.desc)
|
|
passCount++
|
|
} else {
|
|
// We expect this query to succeed
|
|
if err != nil {
|
|
t.Errorf("Unexpected error for %s: %v", tc.desc, err)
|
|
return
|
|
}
|
|
|
|
if result.Error != nil {
|
|
t.Errorf("Unexpected result error for %s: %v", tc.desc, result.Error)
|
|
return
|
|
}
|
|
|
|
t.Logf("CORRECTLY ACCEPTED: %s", tc.desc)
|
|
passCount++
|
|
}
|
|
})
|
|
}
|
|
|
|
t.Logf("PostgreSQL-only compliance: %d/%d tests passed", passCount, len(testCases))
|
|
}
|