cypher verify test

This commit is contained in:
2025-04-18 10:15:28 +02:00
parent 443db56df4
commit 56fe0f6f30
2 changed files with 48 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ package memgraph
import (
"fmt"
"regexp"
"strings"
)
@@ -104,19 +105,13 @@ var cypherDelimiters = map[string]string{
// VerifyString verifies if a string is valid and does not contain cypher injection
func VerifyString(s string) error {
s = strings.ToUpper(s)
for _, keyword := range cypherKeywords {
if strings.Contains(s, keyword) {
keywordPattern := fmt.Sprintf(`\b%s\b`, strings.ToUpper(keyword))
if match, _ := regexp.MatchString(keywordPattern, strings.ToUpper(s)); match {
return fmt.Errorf("invalid string: %s contains cypher keyword: %s", s, keyword)
}
}
for _, operator := range cypherOperators {
if strings.Contains(s, operator) {
return fmt.Errorf("invalid string: %s contains cypher operator: %s", s, operator)
}
}
for key := range cypherDelimiters {
if strings.Contains(s, key) {
return fmt.Errorf("invalid string: %s contains cypher delimiter: %s", s, key)

View File

@@ -0,0 +1,45 @@
package memgraph
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVerifyString(t *testing.T) {
tests := []struct {
input string
expectError bool
}{
{"MATCH (n) RETURN n", true}, // Contains Cypher keyword
{"Hello 'World'", true}, // Contains Cypher delimiter
{"Hello World", false}, // Valid string
}
for _, test := range tests {
err := VerifyString(test.input)
if test.expectError {
require.Error(t, err, "expected error for input: %s", test.input)
} else {
require.NoError(t, err, "did not expect error for input: %s", test.input)
}
}
}
func TestEscapeString(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"Hello 'World'", "Hello \\'World\\'"},
{"Hello \"World\"", "Hello \\\"World\\\""},
{"Hello `World`", "Hello ``World``"},
{"Hello World", "Hello World"}, // No delimiters to escape
}
for _, test := range tests {
result := EscapeString(test.input)
assert.Equal(t, test.expected, result, "unexpected result for input: %s", test.input)
}
}