diff --git a/apps/db-adapter/internal/memgraph/person.go b/apps/db-adapter/internal/memgraph/person.go index 03ddacf..4002612 100644 --- a/apps/db-adapter/internal/memgraph/person.go +++ b/apps/db-adapter/internal/memgraph/person.go @@ -9,9 +9,10 @@ import ( ) func CreatePerson(ctx context.Context, person *api.PersonProperties) neo4j.ManagedTransactionWork { + convertedPerson := StructToMap(person) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, CreatePersonCypherQuery, map[string]any{ - "Person": *person, + "Person": convertedPerson, }) if err != nil { return nil, err @@ -45,10 +46,11 @@ func GetPersonById(ctx context.Context, id int) neo4j.ManagedTransactionWork { } func UpdatePerson(ctx context.Context, id int, person *api.PersonProperties) neo4j.ManagedTransactionWork { + convertedPerson := StructToMap(person) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, UpdatePersonCypherQuery, map[string]any{ "id": id, - "props": *person, + "props": convertedPerson, }) if err != nil { return nil, err diff --git a/apps/db-adapter/internal/memgraph/person_google.go b/apps/db-adapter/internal/memgraph/person_google.go index 80b52a8..4594826 100644 --- a/apps/db-adapter/internal/memgraph/person_google.go +++ b/apps/db-adapter/internal/memgraph/person_google.go @@ -26,10 +26,11 @@ func GetPersonByGoogleId(ctx context.Context, googleId string) neo4j.ManagedTran } func UpdatePersonByInviteCode(ctx context.Context, inviteCode string, person *api.PersonProperties) neo4j.ManagedTransactionWork { + convertedPerson := StructToMap(person) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, UpdatePersonByInviteCodeCypherQuery, map[string]any{ "invite_code": inviteCode, - "props": *person, + "props": convertedPerson, }) if err != nil { return nil, err diff --git a/apps/db-adapter/internal/memgraph/relationship.go b/apps/db-adapter/internal/memgraph/relationship.go index 9716d24..dba0b23 100644 --- a/apps/db-adapter/internal/memgraph/relationship.go +++ b/apps/db-adapter/internal/memgraph/relationship.go @@ -43,11 +43,12 @@ func DeleteRelationship(ctx context.Context, id1, id2 int) neo4j.ManagedTransact func UpdateRelationship( ctx context.Context, id1, id2 int, relationship api.FamilyRelationship, ) neo4j.ManagedTransactionWork { + convertedRelationship := StructToMap(relationship) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, UpdateRelationshipCypherQuery, map[string]any{ "id1": id1, "id2": id2, - "relationship": relationship, + "relationship": convertedRelationship, }) if err != nil { return nil, err @@ -65,12 +66,13 @@ func UpdateRelationship( func CreateChildParentRelationship( ctx context.Context, childId, parentId int, relationship api.FamilyRelationship, ) neo4j.ManagedTransactionWork { + convertedRelationship := StructToMap(relationship) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, CreateChildParentRelationshipCypherQuery, map[string]any{ "childId": childId, "parentId": parentId, - "childRelationship": relationship, - "parentRelationship": relationship, + "childRelationship": convertedRelationship, + "parentRelationship": convertedRelationship, }) if err != nil { return nil, err @@ -83,12 +85,13 @@ func CreateChildParentRelationship( func CreateSiblingRelationship( ctx context.Context, siblingId1, siblingId2 int, relationship api.FamilyRelationship, ) neo4j.ManagedTransactionWork { + convertedRelationship := StructToMap(relationship) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, CreateSiblingRelationshipCypherQuery, map[string]any{ "id1": siblingId1, "id2": siblingId2, - "Relationship1": relationship, - "Relationship2": relationship, + "Relationship1": convertedRelationship, + "Relationship2": convertedRelationship, }) if err != nil { return nil, err @@ -101,12 +104,13 @@ func CreateSiblingRelationship( func CreateSpouseRelationship( ctx context.Context, spouseId1, spouseId2 int, relationship api.FamilyRelationship, ) neo4j.ManagedTransactionWork { + convertedRelationship := StructToMap(relationship) return func(tx neo4j.ManagedTransaction) (any, error) { result, err := tx.Run(ctx, CreateSpouseRelationshipCypherQuery, map[string]any{ "id1": spouseId1, "id2": spouseId2, - "Relationship1": relationship, - "Relationship2": relationship, + "Relationship1": convertedRelationship, + "Relationship2": convertedRelationship, }) if err != nil { return nil, err diff --git a/apps/db-adapter/internal/memgraph/struct_to_map.go b/apps/db-adapter/internal/memgraph/struct_to_map.go new file mode 100644 index 0000000..f818827 --- /dev/null +++ b/apps/db-adapter/internal/memgraph/struct_to_map.go @@ -0,0 +1,106 @@ +package memgraph + +import ( + "reflect" + "time" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" +) + +// StructToMap recursively converts a struct to a map using JSON tags. +// Nil pointers and unexported fields are excluded. +func StructToMap(input interface{}) map[string]interface{} { + result := make(map[string]interface{}) + value := reflect.ValueOf(input) + + if value.Kind() == reflect.Ptr { + if value.IsNil() { + return result + } + value = value.Elem() + } + + typ := value.Type() + + for i := 0; i < value.NumField(); i++ { + field := typ.Field(i) + fieldValue := value.Field(i) + + // Skip unexported fields + if field.PkgPath != "" { + continue + } + + // Get the JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + // Only use the name before the first comma + jsonKey := field.Name + if jsonTag != "" { + jsonKey = jsonTag + if commaIdx := indexComma(jsonKey); commaIdx >= 0 { + jsonKey = jsonKey[:commaIdx] + } + } + + // Skip empty json keys (e.g., tag is `json:"-"`) + if jsonKey == "" { + continue + } + + // Handle nil pointers + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + continue + } + + // Dereference pointers + val := fieldValue + if fieldValue.Kind() == reflect.Ptr { + val = fieldValue.Elem() + } + + if isPreservedType(val.Interface()) { + result[jsonKey] = val.Interface() + + continue + } + + // Recurse into nested structs + switch val.Kind() { + case reflect.Struct: + result[jsonKey] = StructToMap(val.Interface()) + default: + result[jsonKey] = val.Interface() + } + } + + return result +} + +func indexComma(tag string) int { + for i, r := range tag { + if r == ',' { + return i + } + } + return -1 +} + +// Checks if a value is one of the preserved types that shouldn't be expanded recursively +func isPreservedType(v interface{}) bool { + switch v.(type) { + case dbtype.Point2D, *dbtype.Point2D, + dbtype.Point3D, *dbtype.Point3D, + time.Time, + dbtype.LocalDateTime, + dbtype.Date, + dbtype.Time, + dbtype.LocalTime, + dbtype.Duration: + return true + default: + return false + } +} diff --git a/apps/db-adapter/internal/memgraph/struct_to_map_test.go b/apps/db-adapter/internal/memgraph/struct_to_map_test.go new file mode 100644 index 0000000..97b116c --- /dev/null +++ b/apps/db-adapter/internal/memgraph/struct_to_map_test.go @@ -0,0 +1,100 @@ +package memgraph + +import ( + "reflect" + "testing" + "time" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" +) + +type NestedStruct struct { + NestedField string `json:"nested_field"` +} + +type TestStruct struct { + ExportedField string `json:"exported_field"` + unexportedField string // Should be ignored + IgnoredField string `json:"-"` + PointerField *string `json:"pointer_field"` + Nested NestedStruct `json:"nested"` + NilPointer *string `json:"nil_pointer"` +} + +func TestStructToMap(t *testing.T) { + // Test data + pointerValue := "pointer value" + testStruct := TestStruct{ + ExportedField: "exported value", + unexportedField: "unexported value", + IgnoredField: "ignored value", + PointerField: &pointerValue, + Nested: NestedStruct{ + NestedField: "nested value", + }, + NilPointer: nil, + } + + expected := map[string]interface{}{ + "exported_field": "exported value", + "pointer_field": "pointer value", + "nested": map[string]interface{}{ + "nested_field": "nested value", + }, + } + + result := StructToMap(testStruct) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("StructToMap() = %v, want %v", result, expected) + } +} + +func TestStructToMap_NilPointer(t *testing.T) { + var nilPointer *TestStruct + result := StructToMap(nilPointer) + + if len(result) != 0 { + t.Errorf("StructToMap(nil) = %v, want empty map", result) + } +} + +func TestStructToMap_EmptyStruct(t *testing.T) { + type EmptyStruct struct{} + result := StructToMap(EmptyStruct{}) + + if len(result) != 0 { + t.Errorf("StructToMap(EmptyStruct{}) = %v, want empty map", result) + } +} +func TestIsPreservedType(t *testing.T) { + // Test cases for preserved types + tests := []struct { + name string + input interface{} + expected bool + }{ + {"Point2D", dbtype.Point2D{}, true}, + {"Pointer to Point2D", &dbtype.Point2D{}, true}, + {"Point3D", dbtype.Point3D{}, true}, + {"Pointer to Point3D", &dbtype.Point3D{}, true}, + {"Time", time.Time{}, true}, + {"LocalDateTime", dbtype.LocalDateTime{}, true}, + {"Date", dbtype.Date{}, true}, + {"Time", dbtype.Time{}, true}, + {"LocalTime", dbtype.LocalTime{}, true}, + {"Duration", dbtype.Duration{}, true}, + {"String", "not preserved", false}, + {"Integer", 123, false}, + {"Struct", struct{}{}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPreservedType(tt.input) + if result != tt.expected { + t.Errorf("isPreservedType(%v) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +}