mirror of
https://github.com/vcscsvcscs/GenerationsHeritage.git
synced 2025-08-12 22:09:07 +02:00
cypher verify test
This commit is contained in:
@@ -2,6 +2,7 @@ package memgraph
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -104,19 +105,13 @@ var cypherDelimiters = map[string]string{
|
||||
|
||||
// VerifyString verifies if a string is valid and does not contain cypher injection
|
||||
func VerifyString(s string) error {
|
||||
s = strings.ToUpper(s)
|
||||
for _, keyword := range cypherKeywords {
|
||||
if strings.Contains(s, keyword) {
|
||||
keywordPattern := fmt.Sprintf(`\b%s\b`, strings.ToUpper(keyword))
|
||||
if match, _ := regexp.MatchString(keywordPattern, strings.ToUpper(s)); match {
|
||||
return fmt.Errorf("invalid string: %s contains cypher keyword: %s", s, keyword)
|
||||
}
|
||||
}
|
||||
|
||||
for _, operator := range cypherOperators {
|
||||
if strings.Contains(s, operator) {
|
||||
return fmt.Errorf("invalid string: %s contains cypher operator: %s", s, operator)
|
||||
}
|
||||
}
|
||||
|
||||
for key := range cypherDelimiters {
|
||||
if strings.Contains(s, key) {
|
||||
return fmt.Errorf("invalid string: %s contains cypher delimiter: %s", s, key)
|
||||
|
@@ -0,0 +1,45 @@
|
||||
package memgraph
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectError bool
|
||||
}{
|
||||
{"MATCH (n) RETURN n", true}, // Contains Cypher keyword
|
||||
{"Hello 'World'", true}, // Contains Cypher delimiter
|
||||
{"Hello World", false}, // Valid string
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
err := VerifyString(test.input)
|
||||
if test.expectError {
|
||||
require.Error(t, err, "expected error for input: %s", test.input)
|
||||
} else {
|
||||
require.NoError(t, err, "did not expect error for input: %s", test.input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Hello 'World'", "Hello \\'World\\'"},
|
||||
{"Hello \"World\"", "Hello \\\"World\\\""},
|
||||
{"Hello `World`", "Hello ``World``"},
|
||||
{"Hello World", "Hello World"}, // No delimiters to escape
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := EscapeString(test.input)
|
||||
assert.Equal(t, test.expected, result, "unexpected result for input: %s", test.input)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user