You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
407 lines
10 KiB
407 lines
10 KiB
package specifications |
|
|
|
import ( |
|
"fmt" |
|
"reflect" |
|
"strings" |
|
"time" |
|
) |
|
|
|
type Specification[T any] interface { |
|
IsSatisfiedBy(candidate T) bool |
|
And(other Specification[T]) Specification[T] |
|
Or(other Specification[T]) Specification[T] |
|
Not() Specification[T] |
|
GetConditions() []Condition |
|
GetSpec() *Spec |
|
} |
|
|
|
type SimpleSpecification[T any] struct { |
|
spec *Spec |
|
} |
|
|
|
func NewSimpleSpecification[T any](spec *Spec) *SimpleSpecification[T] { |
|
return &SimpleSpecification[T]{spec: spec} |
|
} |
|
|
|
func (s *SimpleSpecification[T]) IsSatisfiedBy(candidate T) bool { |
|
return s.evaluateSpec(s.spec, candidate) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) And(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: s, right: other, op: "AND"} |
|
} |
|
|
|
func (s *SimpleSpecification[T]) Or(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: s, right: other, op: "OR"} |
|
} |
|
|
|
func (s *SimpleSpecification[T]) Not() Specification[T] { |
|
return NewSimpleSpecification[T](Not(s.spec)) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) GetSpec() *Spec { |
|
return s.spec |
|
} |
|
|
|
func (s *SimpleSpecification[T]) GetConditions() []Condition { |
|
return GetConditions(s.spec) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) evaluateSpec(spec *Spec, candidate T) bool { |
|
if spec == nil { |
|
return false |
|
} |
|
|
|
if spec.LogicalGroup != nil { |
|
return s.evaluateLogicalGroup(spec.LogicalGroup, candidate) |
|
} |
|
|
|
if spec.Condition != nil { |
|
return s.evaluateCondition(*spec.Condition, candidate) |
|
} |
|
|
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) evaluateLogicalGroup(group *LogicalGroup, candidate T) bool { |
|
switch group.Operator { |
|
case GROUP_AND: |
|
return s.evaluateAndGroup(group, candidate) |
|
case GROUP_OR: |
|
return s.evaluateOrGroup(group, candidate) |
|
case GROUP_NOT: |
|
if group.Spec != nil { |
|
return !s.evaluateSpec(group.Spec, candidate) |
|
} |
|
} |
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) evaluateAndGroup(group *LogicalGroup, candidate T) bool { |
|
for _, cond := range group.Conditions { |
|
if !s.evaluateCondition(cond, candidate) { |
|
return false |
|
} |
|
} |
|
|
|
if group.Spec != nil { |
|
return s.evaluateSpec(group.Spec, candidate) |
|
} |
|
|
|
return len(group.Conditions) > 0 || group.Spec != nil |
|
} |
|
|
|
func (s *SimpleSpecification[T]) evaluateOrGroup(group *LogicalGroup, candidate T) bool { |
|
for _, cond := range group.Conditions { |
|
if s.evaluateCondition(cond, candidate) { |
|
return true |
|
} |
|
} |
|
|
|
if group.Spec != nil { |
|
return s.evaluateSpec(group.Spec, candidate) |
|
} |
|
|
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) evaluateCondition(condition Condition, candidate T) bool { |
|
fieldValue, err := s.getFieldValue(candidate, condition.Field) |
|
if err != nil { |
|
return false |
|
} |
|
|
|
return s.compareValues(fieldValue, condition.Operator, condition.Value) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) getFieldValue(candidate T, fieldName string) (interface{}, error) { |
|
v := reflect.ValueOf(candidate) |
|
|
|
if v.Kind() == reflect.Ptr { |
|
if v.IsNil() { |
|
return nil, fmt.Errorf("candidate is nil") |
|
} |
|
v = v.Elem() |
|
} |
|
|
|
if v.Kind() != reflect.Struct { |
|
return nil, fmt.Errorf("candidate is not a struct") |
|
} |
|
|
|
getterName := "Get" + strings.Title(fieldName) |
|
|
|
originalV := reflect.ValueOf(candidate) |
|
if method := originalV.MethodByName(getterName); method.IsValid() { |
|
return s.callMethod(method) |
|
} |
|
|
|
if method := originalV.MethodByName(fieldName); method.IsValid() { |
|
return s.callMethod(method) |
|
} |
|
|
|
if v.Kind() == reflect.Struct { |
|
if method := v.MethodByName(getterName); method.IsValid() { |
|
return s.callMethod(method) |
|
} |
|
|
|
if method := v.MethodByName(fieldName); method.IsValid() { |
|
return s.callMethod(method) |
|
} |
|
|
|
if field := v.FieldByName(fieldName); field.IsValid() && field.CanInterface() { |
|
return field.Interface(), nil |
|
} |
|
} |
|
|
|
return nil, fmt.Errorf("field %s not found", fieldName) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) callMethod(method reflect.Value) (interface{}, error) { |
|
if !method.IsValid() || method.Type().NumIn() != 0 || method.Type().NumOut() == 0 { |
|
return nil, fmt.Errorf("invalid method") |
|
} |
|
|
|
results := method.Call(nil) |
|
if len(results) == 0 { |
|
return nil, fmt.Errorf("method returned no values") |
|
} |
|
|
|
return results[0].Interface(), nil |
|
} |
|
|
|
func (s *SimpleSpecification[T]) compareValues(fieldValue interface{}, operator string, compareValue interface{}) bool { |
|
if fieldValue == nil { |
|
return (operator == OP_EQ && compareValue == nil) || (operator == OP_NEQ && compareValue != nil) |
|
} |
|
|
|
if s.isTimeComparable(fieldValue, compareValue) { |
|
return s.compareTimes(fieldValue, operator, compareValue) |
|
} |
|
|
|
if operator == OP_IN || operator == OP_NIN { |
|
return s.compareIn(fieldValue, operator, compareValue) |
|
} |
|
|
|
return s.compareGeneral(fieldValue, operator, compareValue) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) isTimeComparable(fieldValue, compareValue interface{}) bool { |
|
_, fieldIsTime := fieldValue.(time.Time) |
|
_, compareIsTime := compareValue.(time.Time) |
|
|
|
if fieldIsTime || compareIsTime { |
|
return true |
|
} |
|
|
|
return s.hasTimeMethod(fieldValue) || s.hasTimeMethod(compareValue) |
|
} |
|
|
|
func (s *SimpleSpecification[T]) hasTimeMethod(value interface{}) bool { |
|
v := reflect.ValueOf(value) |
|
if !v.IsValid() || v.Kind() == reflect.Ptr { |
|
return false |
|
} |
|
|
|
method := v.MethodByName("Time") |
|
return method.IsValid() && method.Type().NumIn() == 0 && method.Type().NumOut() == 1 |
|
} |
|
|
|
func (s *SimpleSpecification[T]) compareTimes(fieldValue interface{}, operator string, compareValue interface{}) bool { |
|
fieldTime := s.extractTime(fieldValue) |
|
compareTime := s.extractTime(compareValue) |
|
|
|
if fieldTime == nil || compareTime == nil { |
|
return false |
|
} |
|
|
|
switch operator { |
|
case OP_EQ: |
|
return fieldTime.Equal(*compareTime) |
|
case OP_NEQ: |
|
return !fieldTime.Equal(*compareTime) |
|
case OP_GT: |
|
return fieldTime.After(*compareTime) |
|
case OP_GTE: |
|
return fieldTime.After(*compareTime) || fieldTime.Equal(*compareTime) |
|
case OP_LT: |
|
return fieldTime.Before(*compareTime) |
|
case OP_LTE: |
|
return fieldTime.Before(*compareTime) || fieldTime.Equal(*compareTime) |
|
} |
|
|
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) extractTime(value interface{}) *time.Time { |
|
switch v := value.(type) { |
|
case time.Time: |
|
return &v |
|
case string: |
|
if t, err := time.Parse(time.RFC3339, v); err == nil { |
|
return &t |
|
} |
|
default: |
|
if method := reflect.ValueOf(value).MethodByName("Time"); method.IsValid() { |
|
if results := method.Call(nil); len(results) > 0 { |
|
if t, ok := results[0].Interface().(time.Time); ok { |
|
return &t |
|
} |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (s *SimpleSpecification[T]) compareIn(fieldValue interface{}, operator string, compareValue interface{}) bool { |
|
compareSlice, ok := compareValue.([]interface{}) |
|
if !ok { |
|
return false |
|
} |
|
|
|
for _, v := range compareSlice { |
|
if reflect.DeepEqual(fieldValue, v) { |
|
return operator == OP_IN |
|
} |
|
} |
|
|
|
return operator == OP_NIN |
|
} |
|
|
|
func (s *SimpleSpecification[T]) compareGeneral(fieldValue interface{}, operator string, compareValue interface{}) bool { |
|
fieldVal := reflect.ValueOf(fieldValue) |
|
compareVal := reflect.ValueOf(compareValue) |
|
|
|
if !s.makeComparable(&fieldVal, &compareVal) { |
|
return false |
|
} |
|
|
|
switch operator { |
|
case OP_EQ: |
|
return reflect.DeepEqual(fieldValue, compareValue) |
|
case OP_NEQ: |
|
return !reflect.DeepEqual(fieldValue, compareValue) |
|
case OP_GT: |
|
return s.isGreater(fieldVal, compareVal) |
|
case OP_GTE: |
|
return s.isGreater(fieldVal, compareVal) || reflect.DeepEqual(fieldValue, compareValue) |
|
case OP_LT: |
|
return s.isLess(fieldVal, compareVal) |
|
case OP_LTE: |
|
return s.isLess(fieldVal, compareVal) || reflect.DeepEqual(fieldValue, compareValue) |
|
} |
|
|
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) makeComparable(fieldVal, compareVal *reflect.Value) bool { |
|
if fieldVal.Kind() == compareVal.Kind() { |
|
return true |
|
} |
|
|
|
if compareVal.CanConvert(fieldVal.Type()) { |
|
*compareVal = compareVal.Convert(fieldVal.Type()) |
|
return true |
|
} |
|
|
|
if fieldVal.CanConvert(compareVal.Type()) { |
|
*fieldVal = fieldVal.Convert(compareVal.Type()) |
|
return true |
|
} |
|
|
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) isGreater(fieldVal, compareVal reflect.Value) bool { |
|
switch fieldVal.Kind() { |
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|
return fieldVal.Int() > compareVal.Int() |
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|
return fieldVal.Uint() > compareVal.Uint() |
|
case reflect.Float32, reflect.Float64: |
|
return fieldVal.Float() > compareVal.Float() |
|
case reflect.String: |
|
return fieldVal.String() > compareVal.String() |
|
} |
|
return false |
|
} |
|
|
|
func (s *SimpleSpecification[T]) isLess(fieldVal, compareVal reflect.Value) bool { |
|
switch fieldVal.Kind() { |
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|
return fieldVal.Int() < compareVal.Int() |
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|
return fieldVal.Uint() < compareVal.Uint() |
|
case reflect.Float32, reflect.Float64: |
|
return fieldVal.Float() < compareVal.Float() |
|
case reflect.String: |
|
return fieldVal.String() < compareVal.String() |
|
} |
|
return false |
|
} |
|
|
|
type CompositeSpecification[T any] struct { |
|
left Specification[T] |
|
right Specification[T] |
|
op string |
|
} |
|
|
|
func (c *CompositeSpecification[T]) IsSatisfiedBy(candidate T) bool { |
|
switch c.op { |
|
case "AND": |
|
return c.left.IsSatisfiedBy(candidate) && c.right.IsSatisfiedBy(candidate) |
|
case "OR": |
|
return c.left.IsSatisfiedBy(candidate) || c.right.IsSatisfiedBy(candidate) |
|
} |
|
return false |
|
} |
|
|
|
func (c *CompositeSpecification[T]) And(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: c, right: other, op: "AND"} |
|
} |
|
|
|
func (c *CompositeSpecification[T]) Or(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: c, right: other, op: "OR"} |
|
} |
|
|
|
func (c *CompositeSpecification[T]) Not() Specification[T] { |
|
return &NotSpecification[T]{spec: c} |
|
} |
|
|
|
func (c *CompositeSpecification[T]) GetConditions() []Condition { |
|
leftConditions := c.left.GetConditions() |
|
rightConditions := c.right.GetConditions() |
|
return append(leftConditions, rightConditions...) |
|
} |
|
|
|
func (c *CompositeSpecification[T]) GetSpec() *Spec { |
|
return nil |
|
} |
|
|
|
type NotSpecification[T any] struct { |
|
spec Specification[T] |
|
} |
|
|
|
func (n *NotSpecification[T]) IsSatisfiedBy(candidate T) bool { |
|
return !n.spec.IsSatisfiedBy(candidate) |
|
} |
|
|
|
func (n *NotSpecification[T]) And(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: n, right: other, op: "AND"} |
|
} |
|
|
|
func (n *NotSpecification[T]) Or(other Specification[T]) Specification[T] { |
|
return &CompositeSpecification[T]{left: n, right: other, op: "OR"} |
|
} |
|
|
|
func (n *NotSpecification[T]) Not() Specification[T] { |
|
return n.spec |
|
} |
|
|
|
func (n *NotSpecification[T]) GetConditions() []Condition { |
|
return n.spec.GetConditions() |
|
} |
|
|
|
func (n *NotSpecification[T]) GetSpec() *Spec { |
|
return nil |
|
} |