mirror of
https://github.com/vcscsvcscs/GenerationsHeritage.git
synced 2025-08-13 06:19:05 +02:00
fix templated parameter type, convert struct to map
This commit is contained in:
@@ -9,9 +9,10 @@ import (
|
||||
)
|
||||
|
||||
func CreatePerson(ctx context.Context, person *api.PersonProperties) neo4j.ManagedTransactionWork {
|
||||
convertedPerson := StructToMap(person)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, CreatePersonCypherQuery, map[string]any{
|
||||
"Person": *person,
|
||||
"Person": convertedPerson,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -45,10 +46,11 @@ func GetPersonById(ctx context.Context, id int) neo4j.ManagedTransactionWork {
|
||||
}
|
||||
|
||||
func UpdatePerson(ctx context.Context, id int, person *api.PersonProperties) neo4j.ManagedTransactionWork {
|
||||
convertedPerson := StructToMap(person)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, UpdatePersonCypherQuery, map[string]any{
|
||||
"id": id,
|
||||
"props": *person,
|
||||
"props": convertedPerson,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -26,10 +26,11 @@ func GetPersonByGoogleId(ctx context.Context, googleId string) neo4j.ManagedTran
|
||||
}
|
||||
|
||||
func UpdatePersonByInviteCode(ctx context.Context, inviteCode string, person *api.PersonProperties) neo4j.ManagedTransactionWork {
|
||||
convertedPerson := StructToMap(person)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, UpdatePersonByInviteCodeCypherQuery, map[string]any{
|
||||
"invite_code": inviteCode,
|
||||
"props": *person,
|
||||
"props": convertedPerson,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -43,11 +43,12 @@ func DeleteRelationship(ctx context.Context, id1, id2 int) neo4j.ManagedTransact
|
||||
func UpdateRelationship(
|
||||
ctx context.Context, id1, id2 int, relationship api.FamilyRelationship,
|
||||
) neo4j.ManagedTransactionWork {
|
||||
convertedRelationship := StructToMap(relationship)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, UpdateRelationshipCypherQuery, map[string]any{
|
||||
"id1": id1,
|
||||
"id2": id2,
|
||||
"relationship": relationship,
|
||||
"relationship": convertedRelationship,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,12 +66,13 @@ func UpdateRelationship(
|
||||
func CreateChildParentRelationship(
|
||||
ctx context.Context, childId, parentId int, relationship api.FamilyRelationship,
|
||||
) neo4j.ManagedTransactionWork {
|
||||
convertedRelationship := StructToMap(relationship)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, CreateChildParentRelationshipCypherQuery, map[string]any{
|
||||
"childId": childId,
|
||||
"parentId": parentId,
|
||||
"childRelationship": relationship,
|
||||
"parentRelationship": relationship,
|
||||
"childRelationship": convertedRelationship,
|
||||
"parentRelationship": convertedRelationship,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -83,12 +85,13 @@ func CreateChildParentRelationship(
|
||||
func CreateSiblingRelationship(
|
||||
ctx context.Context, siblingId1, siblingId2 int, relationship api.FamilyRelationship,
|
||||
) neo4j.ManagedTransactionWork {
|
||||
convertedRelationship := StructToMap(relationship)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, CreateSiblingRelationshipCypherQuery, map[string]any{
|
||||
"id1": siblingId1,
|
||||
"id2": siblingId2,
|
||||
"Relationship1": relationship,
|
||||
"Relationship2": relationship,
|
||||
"Relationship1": convertedRelationship,
|
||||
"Relationship2": convertedRelationship,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -101,12 +104,13 @@ func CreateSiblingRelationship(
|
||||
func CreateSpouseRelationship(
|
||||
ctx context.Context, spouseId1, spouseId2 int, relationship api.FamilyRelationship,
|
||||
) neo4j.ManagedTransactionWork {
|
||||
convertedRelationship := StructToMap(relationship)
|
||||
return func(tx neo4j.ManagedTransaction) (any, error) {
|
||||
result, err := tx.Run(ctx, CreateSpouseRelationshipCypherQuery, map[string]any{
|
||||
"id1": spouseId1,
|
||||
"id2": spouseId2,
|
||||
"Relationship1": relationship,
|
||||
"Relationship2": relationship,
|
||||
"Relationship1": convertedRelationship,
|
||||
"Relationship2": convertedRelationship,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
106
apps/db-adapter/internal/memgraph/struct_to_map.go
Normal file
106
apps/db-adapter/internal/memgraph/struct_to_map.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package memgraph
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
|
||||
)
|
||||
|
||||
// StructToMap recursively converts a struct to a map using JSON tags.
|
||||
// Nil pointers and unexported fields are excluded.
|
||||
func StructToMap(input interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
value := reflect.ValueOf(input)
|
||||
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() {
|
||||
return result
|
||||
}
|
||||
value = value.Elem()
|
||||
}
|
||||
|
||||
typ := value.Type()
|
||||
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := value.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if field.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
// Only use the name before the first comma
|
||||
jsonKey := field.Name
|
||||
if jsonTag != "" {
|
||||
jsonKey = jsonTag
|
||||
if commaIdx := indexComma(jsonKey); commaIdx >= 0 {
|
||||
jsonKey = jsonKey[:commaIdx]
|
||||
}
|
||||
}
|
||||
|
||||
// Skip empty json keys (e.g., tag is `json:"-"`)
|
||||
if jsonKey == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nil pointers
|
||||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Dereference pointers
|
||||
val := fieldValue
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
val = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if isPreservedType(val.Interface()) {
|
||||
result[jsonKey] = val.Interface()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Recurse into nested structs
|
||||
switch val.Kind() {
|
||||
case reflect.Struct:
|
||||
result[jsonKey] = StructToMap(val.Interface())
|
||||
default:
|
||||
result[jsonKey] = val.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func indexComma(tag string) int {
|
||||
for i, r := range tag {
|
||||
if r == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Checks if a value is one of the preserved types that shouldn't be expanded recursively
|
||||
func isPreservedType(v interface{}) bool {
|
||||
switch v.(type) {
|
||||
case dbtype.Point2D, *dbtype.Point2D,
|
||||
dbtype.Point3D, *dbtype.Point3D,
|
||||
time.Time,
|
||||
dbtype.LocalDateTime,
|
||||
dbtype.Date,
|
||||
dbtype.Time,
|
||||
dbtype.LocalTime,
|
||||
dbtype.Duration:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
100
apps/db-adapter/internal/memgraph/struct_to_map_test.go
Normal file
100
apps/db-adapter/internal/memgraph/struct_to_map_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package memgraph
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
|
||||
)
|
||||
|
||||
type NestedStruct struct {
|
||||
NestedField string `json:"nested_field"`
|
||||
}
|
||||
|
||||
type TestStruct struct {
|
||||
ExportedField string `json:"exported_field"`
|
||||
unexportedField string // Should be ignored
|
||||
IgnoredField string `json:"-"`
|
||||
PointerField *string `json:"pointer_field"`
|
||||
Nested NestedStruct `json:"nested"`
|
||||
NilPointer *string `json:"nil_pointer"`
|
||||
}
|
||||
|
||||
func TestStructToMap(t *testing.T) {
|
||||
// Test data
|
||||
pointerValue := "pointer value"
|
||||
testStruct := TestStruct{
|
||||
ExportedField: "exported value",
|
||||
unexportedField: "unexported value",
|
||||
IgnoredField: "ignored value",
|
||||
PointerField: &pointerValue,
|
||||
Nested: NestedStruct{
|
||||
NestedField: "nested value",
|
||||
},
|
||||
NilPointer: nil,
|
||||
}
|
||||
|
||||
expected := map[string]interface{}{
|
||||
"exported_field": "exported value",
|
||||
"pointer_field": "pointer value",
|
||||
"nested": map[string]interface{}{
|
||||
"nested_field": "nested value",
|
||||
},
|
||||
}
|
||||
|
||||
result := StructToMap(testStruct)
|
||||
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("StructToMap() = %v, want %v", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructToMap_NilPointer(t *testing.T) {
|
||||
var nilPointer *TestStruct
|
||||
result := StructToMap(nilPointer)
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("StructToMap(nil) = %v, want empty map", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructToMap_EmptyStruct(t *testing.T) {
|
||||
type EmptyStruct struct{}
|
||||
result := StructToMap(EmptyStruct{})
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("StructToMap(EmptyStruct{}) = %v, want empty map", result)
|
||||
}
|
||||
}
|
||||
func TestIsPreservedType(t *testing.T) {
|
||||
// Test cases for preserved types
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected bool
|
||||
}{
|
||||
{"Point2D", dbtype.Point2D{}, true},
|
||||
{"Pointer to Point2D", &dbtype.Point2D{}, true},
|
||||
{"Point3D", dbtype.Point3D{}, true},
|
||||
{"Pointer to Point3D", &dbtype.Point3D{}, true},
|
||||
{"Time", time.Time{}, true},
|
||||
{"LocalDateTime", dbtype.LocalDateTime{}, true},
|
||||
{"Date", dbtype.Date{}, true},
|
||||
{"Time", dbtype.Time{}, true},
|
||||
{"LocalTime", dbtype.LocalTime{}, true},
|
||||
{"Duration", dbtype.Duration{}, true},
|
||||
{"String", "not preserved", false},
|
||||
{"Integer", 123, false},
|
||||
{"Struct", struct{}{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isPreservedType(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isPreservedType(%v) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user