From 31bc1c0bb97dba89143ef0b035529610a65eb8b1 Mon Sep 17 00:00:00 2001 From: Vargha Csongor Date: Thu, 10 Apr 2025 15:58:35 +0200 Subject: [PATCH] implement full mock for neo4j client operations --- .../internal/memgraph/mock/result.go | 116 ++++++++++++++++++ .../internal/memgraph/mock/transaction.go | 26 ++++ 2 files changed, 142 insertions(+) create mode 100644 apps/db-adapter/internal/memgraph/mock/result.go create mode 100644 apps/db-adapter/internal/memgraph/mock/transaction.go diff --git a/apps/db-adapter/internal/memgraph/mock/result.go b/apps/db-adapter/internal/memgraph/mock/result.go new file mode 100644 index 0000000..2cff55e --- /dev/null +++ b/apps/db-adapter/internal/memgraph/mock/result.go @@ -0,0 +1,116 @@ +package mock + +import ( + "context" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/stretchr/testify/mock" +) + +type Result struct { + neo4j.ResultWithContext + mock.Mock +} + +// Ensure Result implements neo4j.ResultWithContext interface +var _ neo4j.ResultWithContext = (*Result)(nil) + +func (m *Result) Keys() ([]string, error) { + args := m.Called() + if keys, ok := args.Get(0).([]string); ok { + return keys, args.Error(1) + } + return nil, args.Error(1) +} + +func (m *Result) NextRecord(ctx context.Context, record **neo4j.Record) bool { + args := m.Called(ctx, record) + return args.Bool(0) +} + +func (m *Result) Next(ctx context.Context) bool { + args := m.Called(ctx) + return args.Bool(0) +} + +func (m *Result) PeekRecord(ctx context.Context, record **neo4j.Record) bool { + args := m.Called(ctx, record) + return args.Bool(0) +} + +func (m *Result) Peek(ctx context.Context) bool { + args := m.Called(ctx) + return args.Bool(0) +} + +func (m *Result) Err() error { + args := m.Called() + return args.Error(0) +} + +func (m *Result) Record() *neo4j.Record { + args := m.Called() + if record, ok := args.Get(0).(*neo4j.Record); ok { + return record + } + + return nil +} + +func (m *Result) Collect(ctx context.Context) ([]*neo4j.Record, error) { + args := m.Called(ctx) + if records, ok := args.Get(0).([]*neo4j.Record); ok { + return records, args.Error(1) + } + + return nil, args.Error(1) +} + +func (m *Result) Records(ctx context.Context) func(yield func(*neo4j.Record, error) bool) { + args := m.Called(ctx) + if recordsFunc, ok := args.Get(0).(func(yield func(*neo4j.Record, error) bool)); ok { + return recordsFunc + } + + return nil +} + +func (m *Result) Single(ctx context.Context) (*neo4j.Record, error) { + args := m.Called(ctx) + if record, ok := args.Get(0).(*neo4j.Record); ok { + return record, args.Error(1) + } + + return nil, args.Error(1) +} + +func (m *Result) Consume(ctx context.Context) (neo4j.ResultSummary, error) { + args := m.Called(ctx) + if summary, ok := args.Get(0).(neo4j.ResultSummary); ok { + return summary, args.Error(1) + } + + return nil, args.Error(1) +} + +func (m *Result) IsOpen() bool { + args := m.Called() + return args.Bool(0) +} + +func (m *Result) buffer(ctx context.Context) { + m.Called(ctx) +} + +func (m *Result) legacy() neo4j.Result { + args := m.Called() + if result, ok := args.Get(0).(neo4j.Result); ok { + return result + } + + return nil +} + +func (m *Result) errorHandler(err error) { + m.Called(err) +} diff --git a/apps/db-adapter/internal/memgraph/mock/transaction.go b/apps/db-adapter/internal/memgraph/mock/transaction.go new file mode 100644 index 0000000..275daa0 --- /dev/null +++ b/apps/db-adapter/internal/memgraph/mock/transaction.go @@ -0,0 +1,26 @@ +package mock + +import ( + "context" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/stretchr/testify/mock" +) + +type Transaction struct { + neo4j.ManagedTransaction + mock.Mock +} + +func (m *Transaction) Run(ctx context.Context, cypher string, params map[string]any) (neo4j.ResultWithContext, error) { + args := m.Called(ctx, cypher, params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(neo4j.ResultWithContext), args.Error(1) +} + +func (m *Transaction) legacy() neo4j.Transaction { + return m.Called().Get(0).(neo4j.Transaction) +}