fix templated parameter type, convert struct to map

This commit is contained in:
2025-04-19 16:47:20 +02:00
parent cc9a863311
commit a1b907024e
5 changed files with 223 additions and 10 deletions

View File

@@ -9,9 +9,10 @@ import (
)
func CreatePerson(ctx context.Context, person *api.PersonProperties) neo4j.ManagedTransactionWork {
convertedPerson := StructToMap(person)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, CreatePersonCypherQuery, map[string]any{
"Person": *person,
"Person": convertedPerson,
})
if err != nil {
return nil, err
@@ -45,10 +46,11 @@ func GetPersonById(ctx context.Context, id int) neo4j.ManagedTransactionWork {
}
func UpdatePerson(ctx context.Context, id int, person *api.PersonProperties) neo4j.ManagedTransactionWork {
convertedPerson := StructToMap(person)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, UpdatePersonCypherQuery, map[string]any{
"id": id,
"props": *person,
"props": convertedPerson,
})
if err != nil {
return nil, err

View File

@@ -26,10 +26,11 @@ func GetPersonByGoogleId(ctx context.Context, googleId string) neo4j.ManagedTran
}
func UpdatePersonByInviteCode(ctx context.Context, inviteCode string, person *api.PersonProperties) neo4j.ManagedTransactionWork {
convertedPerson := StructToMap(person)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, UpdatePersonByInviteCodeCypherQuery, map[string]any{
"invite_code": inviteCode,
"props": *person,
"props": convertedPerson,
})
if err != nil {
return nil, err

View File

@@ -43,11 +43,12 @@ func DeleteRelationship(ctx context.Context, id1, id2 int) neo4j.ManagedTransact
func UpdateRelationship(
ctx context.Context, id1, id2 int, relationship api.FamilyRelationship,
) neo4j.ManagedTransactionWork {
convertedRelationship := StructToMap(relationship)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, UpdateRelationshipCypherQuery, map[string]any{
"id1": id1,
"id2": id2,
"relationship": relationship,
"relationship": convertedRelationship,
})
if err != nil {
return nil, err
@@ -65,12 +66,13 @@ func UpdateRelationship(
func CreateChildParentRelationship(
ctx context.Context, childId, parentId int, relationship api.FamilyRelationship,
) neo4j.ManagedTransactionWork {
convertedRelationship := StructToMap(relationship)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, CreateChildParentRelationshipCypherQuery, map[string]any{
"childId": childId,
"parentId": parentId,
"childRelationship": relationship,
"parentRelationship": relationship,
"childRelationship": convertedRelationship,
"parentRelationship": convertedRelationship,
})
if err != nil {
return nil, err
@@ -83,12 +85,13 @@ func CreateChildParentRelationship(
func CreateSiblingRelationship(
ctx context.Context, siblingId1, siblingId2 int, relationship api.FamilyRelationship,
) neo4j.ManagedTransactionWork {
convertedRelationship := StructToMap(relationship)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, CreateSiblingRelationshipCypherQuery, map[string]any{
"id1": siblingId1,
"id2": siblingId2,
"Relationship1": relationship,
"Relationship2": relationship,
"Relationship1": convertedRelationship,
"Relationship2": convertedRelationship,
})
if err != nil {
return nil, err
@@ -101,12 +104,13 @@ func CreateSiblingRelationship(
func CreateSpouseRelationship(
ctx context.Context, spouseId1, spouseId2 int, relationship api.FamilyRelationship,
) neo4j.ManagedTransactionWork {
convertedRelationship := StructToMap(relationship)
return func(tx neo4j.ManagedTransaction) (any, error) {
result, err := tx.Run(ctx, CreateSpouseRelationshipCypherQuery, map[string]any{
"id1": spouseId1,
"id2": spouseId2,
"Relationship1": relationship,
"Relationship2": relationship,
"Relationship1": convertedRelationship,
"Relationship2": convertedRelationship,
})
if err != nil {
return nil, err

View File

@@ -0,0 +1,106 @@
package memgraph
import (
"reflect"
"time"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
)
// StructToMap recursively converts a struct to a map using JSON tags.
// Nil pointers and unexported fields are excluded.
func StructToMap(input interface{}) map[string]interface{} {
result := make(map[string]interface{})
value := reflect.ValueOf(input)
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return result
}
value = value.Elem()
}
typ := value.Type()
for i := 0; i < value.NumField(); i++ {
field := typ.Field(i)
fieldValue := value.Field(i)
// Skip unexported fields
if field.PkgPath != "" {
continue
}
// Get the JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag == "-" {
continue
}
// Only use the name before the first comma
jsonKey := field.Name
if jsonTag != "" {
jsonKey = jsonTag
if commaIdx := indexComma(jsonKey); commaIdx >= 0 {
jsonKey = jsonKey[:commaIdx]
}
}
// Skip empty json keys (e.g., tag is `json:"-"`)
if jsonKey == "" {
continue
}
// Handle nil pointers
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
continue
}
// Dereference pointers
val := fieldValue
if fieldValue.Kind() == reflect.Ptr {
val = fieldValue.Elem()
}
if isPreservedType(val.Interface()) {
result[jsonKey] = val.Interface()
continue
}
// Recurse into nested structs
switch val.Kind() {
case reflect.Struct:
result[jsonKey] = StructToMap(val.Interface())
default:
result[jsonKey] = val.Interface()
}
}
return result
}
func indexComma(tag string) int {
for i, r := range tag {
if r == ',' {
return i
}
}
return -1
}
// Checks if a value is one of the preserved types that shouldn't be expanded recursively
func isPreservedType(v interface{}) bool {
switch v.(type) {
case dbtype.Point2D, *dbtype.Point2D,
dbtype.Point3D, *dbtype.Point3D,
time.Time,
dbtype.LocalDateTime,
dbtype.Date,
dbtype.Time,
dbtype.LocalTime,
dbtype.Duration:
return true
default:
return false
}
}

View File

@@ -0,0 +1,100 @@
package memgraph
import (
"reflect"
"testing"
"time"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
)
type NestedStruct struct {
NestedField string `json:"nested_field"`
}
type TestStruct struct {
ExportedField string `json:"exported_field"`
unexportedField string // Should be ignored
IgnoredField string `json:"-"`
PointerField *string `json:"pointer_field"`
Nested NestedStruct `json:"nested"`
NilPointer *string `json:"nil_pointer"`
}
func TestStructToMap(t *testing.T) {
// Test data
pointerValue := "pointer value"
testStruct := TestStruct{
ExportedField: "exported value",
unexportedField: "unexported value",
IgnoredField: "ignored value",
PointerField: &pointerValue,
Nested: NestedStruct{
NestedField: "nested value",
},
NilPointer: nil,
}
expected := map[string]interface{}{
"exported_field": "exported value",
"pointer_field": "pointer value",
"nested": map[string]interface{}{
"nested_field": "nested value",
},
}
result := StructToMap(testStruct)
if !reflect.DeepEqual(result, expected) {
t.Errorf("StructToMap() = %v, want %v", result, expected)
}
}
func TestStructToMap_NilPointer(t *testing.T) {
var nilPointer *TestStruct
result := StructToMap(nilPointer)
if len(result) != 0 {
t.Errorf("StructToMap(nil) = %v, want empty map", result)
}
}
func TestStructToMap_EmptyStruct(t *testing.T) {
type EmptyStruct struct{}
result := StructToMap(EmptyStruct{})
if len(result) != 0 {
t.Errorf("StructToMap(EmptyStruct{}) = %v, want empty map", result)
}
}
func TestIsPreservedType(t *testing.T) {
// Test cases for preserved types
tests := []struct {
name string
input interface{}
expected bool
}{
{"Point2D", dbtype.Point2D{}, true},
{"Pointer to Point2D", &dbtype.Point2D{}, true},
{"Point3D", dbtype.Point3D{}, true},
{"Pointer to Point3D", &dbtype.Point3D{}, true},
{"Time", time.Time{}, true},
{"LocalDateTime", dbtype.LocalDateTime{}, true},
{"Date", dbtype.Date{}, true},
{"Time", dbtype.Time{}, true},
{"LocalTime", dbtype.LocalTime{}, true},
{"Duration", dbtype.Duration{}, true},
{"String", "not preserved", false},
{"Integer", 123, false},
{"Struct", struct{}{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isPreservedType(tt.input)
if result != tt.expected {
t.Errorf("isPreservedType(%v) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}