From 56fe0f6f300ad63fe70314dc843a44f644c4e1dd Mon Sep 17 00:00:00 2001 From: Vargha Csongor Date: Fri, 18 Apr 2025 10:15:28 +0200 Subject: [PATCH] cypher verify test --- .../internal/memgraph/cypher_verify_string.go | 11 ++--- .../memgraph/cypher_verify_string_test.go | 45 +++++++++++++++++++ 2 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 apps/db-adapter/internal/memgraph/cypher_verify_string_test.go diff --git a/apps/db-adapter/internal/memgraph/cypher_verify_string.go b/apps/db-adapter/internal/memgraph/cypher_verify_string.go index 0a75037..3b64557 100644 --- a/apps/db-adapter/internal/memgraph/cypher_verify_string.go +++ b/apps/db-adapter/internal/memgraph/cypher_verify_string.go @@ -2,6 +2,7 @@ package memgraph import ( "fmt" + "regexp" "strings" ) @@ -104,19 +105,13 @@ var cypherDelimiters = map[string]string{ // VerifyString verifies if a string is valid and does not contain cypher injection func VerifyString(s string) error { - s = strings.ToUpper(s) for _, keyword := range cypherKeywords { - if strings.Contains(s, keyword) { + keywordPattern := fmt.Sprintf(`\b%s\b`, strings.ToUpper(keyword)) + if match, _ := regexp.MatchString(keywordPattern, strings.ToUpper(s)); match { return fmt.Errorf("invalid string: %s contains cypher keyword: %s", s, keyword) } } - for _, operator := range cypherOperators { - if strings.Contains(s, operator) { - return fmt.Errorf("invalid string: %s contains cypher operator: %s", s, operator) - } - } - for key := range cypherDelimiters { if strings.Contains(s, key) { return fmt.Errorf("invalid string: %s contains cypher delimiter: %s", s, key) diff --git a/apps/db-adapter/internal/memgraph/cypher_verify_string_test.go b/apps/db-adapter/internal/memgraph/cypher_verify_string_test.go new file mode 100644 index 0000000..a0d6a81 --- /dev/null +++ b/apps/db-adapter/internal/memgraph/cypher_verify_string_test.go @@ -0,0 +1,45 @@ +package memgraph + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyString(t *testing.T) { + tests := []struct { + input string + expectError bool + }{ + {"MATCH (n) RETURN n", true}, // Contains Cypher keyword + {"Hello 'World'", true}, // Contains Cypher delimiter + {"Hello World", false}, // Valid string + } + + for _, test := range tests { + err := VerifyString(test.input) + if test.expectError { + require.Error(t, err, "expected error for input: %s", test.input) + } else { + require.NoError(t, err, "did not expect error for input: %s", test.input) + } + } +} + +func TestEscapeString(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello 'World'", "Hello \\'World\\'"}, + {"Hello \"World\"", "Hello \\\"World\\\""}, + {"Hello `World`", "Hello ``World``"}, + {"Hello World", "Hello World"}, // No delimiters to escape + } + + for _, test := range tests { + result := EscapeString(test.input) + assert.Equal(t, test.expected, result, "unexpected result for input: %s", test.input) + } +}