163 lines
4.0 KiB
Go
163 lines
4.0 KiB
Go
package database
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func getSQLTableName(table any) string {
|
|
reflectType := reflect.TypeOf(table)
|
|
if reflectType.Kind() == reflect.Pointer {
|
|
reflectType = reflectType.Elem()
|
|
}
|
|
return toSnakeCase(reflectType.Name())
|
|
}
|
|
|
|
// toSnakeCase converts an identifier such as "UserProfile" into snake_case
// ("user_profile"). An underscore is written before every ASCII upper-case
// letter except one at the very start, so runs of capitals are split letter
// by letter ("ID" -> "i_d"). The whole result is then lower-cased with
// strings.ToLower.
func toSnakeCase(str string) string {
	var sb strings.Builder
	for pos, ch := range str {
		isUpperASCII := 'A' <= ch && ch <= 'Z'
		if pos != 0 && isUpperASCII {
			sb.WriteByte('_')
		}
		sb.WriteRune(ch)
	}
	return strings.ToLower(sb.String())
}
|
|
|
|
func getSQLColumns(table any) []sqlColumn {
|
|
columns := []sqlColumn{}
|
|
|
|
reflectValue := reflect.ValueOf(table)
|
|
if reflectValue.Kind() == reflect.Pointer {
|
|
reflectValue = reflectValue.Elem()
|
|
}
|
|
reflectType := reflectValue.Type()
|
|
|
|
for i := 0; i < reflectType.NumField(); i++ {
|
|
field := reflectType.Field(i)
|
|
dbTag := field.Tag.Get("db")
|
|
if len(dbTag) < 2 {
|
|
continue
|
|
}
|
|
|
|
if dbType := field.Tag.Get("dbType"); len(dbType) > 0 {
|
|
dbTag = dbType + " " + dbTag
|
|
} else if dbType = goToMySQLTypes[field.Type.String()]; len(dbType) > 0 {
|
|
dbTag = dbType + " " + dbTag
|
|
} else {
|
|
log.Fatalf("[getSQLColumns] no type found for column '%s' -> '%s'", field.Name, field.Type)
|
|
}
|
|
|
|
foreignKey := field.Tag.Get("foreignKey")
|
|
columns = append(columns, sqlColumn{Name: field.Name, Statement: dbTag, ForeignKey: foreignKey})
|
|
}
|
|
|
|
return columns
|
|
}
|
|
|
|
// @TODO: maybe save result from the func globally to avoid reflect overuse
|
|
func getSQLTableFromInterface(table any) *sqlTable {
|
|
tableName := getSQLTableName(table)
|
|
columns := getSQLColumns(table)
|
|
|
|
var primaryKey string
|
|
if primary, ok := table.(SQLPrimary); ok {
|
|
primaryKeys := primary.GetPrimaryKeys()
|
|
|
|
if len(primaryKeys) > 0 {
|
|
if slices.ContainsFunc(primaryKeys, func(key string) bool {
|
|
return !slices.ContainsFunc(columns, func(column sqlColumn) bool { return column.Name == key })
|
|
}) {
|
|
log.Fatalf("[AddToCreateTable]: invalid primary key for table '%s'", tableName)
|
|
}
|
|
|
|
primaryKey = fmt.Sprintf("PRIMARY KEY (%s), ", strings.Join(primaryKeys, ", "))
|
|
}
|
|
|
|
}
|
|
|
|
return &sqlTable{Name: tableName, Columns: columns, PrimaryKey: primaryKey}
|
|
}
|
|
|
|
// serialize flattens a table struct (or a pointer to one) into parallel
// slices of column names and values for writing to the database ("to db").
//
// Only fields whose `db` tag is at least two characters long are considered.
// Fields holding their zero value are omitted entirely, so a false bool, 0 or
// "" is never emitted, and the bool -> 0 branch below does not fire in
// practice. time.Time values are rendered with the "2006-01-02 15:04:05"
// layout in the value's own location, bools become 1 or 0, and everything
// else is passed through unchanged. Column names are the Go field names.
//
// NOTE(review): deserialize parses that layout with time.Parse, which yields
// UTC, so non-UTC times will not round-trip — confirm callers store UTC.
func serialize(val any) ([]string, []any) {
	const mysqlTimeFormat = "2006-01-02 15:04:05"
	var (
		timeType = reflect.TypeOf(time.Time{})
		boolType = reflect.TypeOf(true)
	)

	structValue := reflect.Indirect(reflect.ValueOf(val))
	structType := structValue.Type()

	var (
		columns []string
		values  []any
	)
	for idx, total := 0, structType.NumField(); idx < total; idx++ {
		meta := structType.Field(idx)
		if len(meta.Tag.Get("db")) < 2 {
			continue
		}

		fieldValue := structValue.Field(idx)
		if !fieldValue.IsValid() || fieldValue.IsZero() {
			continue
		}
		columns = append(columns, meta.Name)

		var out any
		switch meta.Type {
		case timeType:
			out = fieldValue.Interface().(time.Time).Format(mysqlTimeFormat)
		case boolType:
			out = 0
			if fieldValue.Bool() {
				out = 1
			}
		default:
			out = fieldValue.Interface()
		}
		values = append(values, out)
	}

	return columns, values
}
|
|
|
|
// deserialize copies one database row ("from db") into the struct that val
// points to. columns[i] is matched against the struct's Go field names and
// receives values[i]. Names with no matching settable (exported) field are
// ignored. If values is shorter than columns, copying stops at the first
// column without a value.
//
// Conversion rules:
//   - nil or zero values reset the field to its zero value;
//   - time.Time fields accept a time.Time, or text ([]byte or string) in the
//     "2006-01-02 15:04:05" layout, which time.Parse interprets as UTC;
//   - bool fields accept a bool, any integer, unsigned or float value
//     (non-zero is true), or text where "0", "\x00" and "false" (any case)
//     mean false and anything else means true;
//   - every other field uses reflect conversion.
//
// Values that cannot be converted are logged, and the field is left unchanged.
func deserialize(val any, columns []string, values []any) {
	const mysqlTimeFormat = "2006-01-02 15:04:05"
	timeType := reflect.TypeOf(time.Time{})
	boolType := reflect.TypeOf(true)

	// asText returns the textual payload of v. SQL drivers commonly hand back
	// []byte (or a named byte-slice type) for textual and temporal columns.
	asText := func(v reflect.Value) (string, bool) {
		switch {
		case v.Kind() == reflect.String:
			return v.String(), true
		case v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8:
			return string(v.Bytes()), true
		}
		return "", false
	}

	reflectElem := reflect.ValueOf(val)
	if reflectElem.Kind() == reflect.Pointer {
		reflectElem = reflectElem.Elem()
	}

	for i, colName := range columns {
		// Guard against a short values slice instead of panicking on values[i].
		if i >= len(values) {
			log.Printf("[ERROR] no value from db for column '%s'", colName)
			break
		}

		field := reflectElem.FieldByName(colName)
		if !field.IsValid() || !field.CanSet() {
			continue
		}

		reflectVal := reflect.ValueOf(values[i])
		if !reflectVal.IsValid() || reflectVal.IsZero() {
			field.SetZero()
			continue
		}

		switch field.Type() {
		case timeType:
			if _, ok := values[i].(time.Time); ok {
				// The driver already parsed the value (e.g. parseTime enabled).
				field.Set(reflectVal)
				continue
			}
			text, ok := asText(reflectVal)
			if !ok {
				log.Printf("[ERROR] could not parse time from db: column '%s' has unsupported type %T", colName, values[i])
				continue
			}
			parsed, err := time.Parse(mysqlTimeFormat, text)
			if err != nil {
				log.Printf("[ERROR] could not parse time from db: column '%s': %v", colName, err)
				continue
			}
			field.Set(reflect.ValueOf(parsed))
		case boolType:
			// Drivers report integer columns as int64 (or as text), so compare
			// numerically. Comparing the interface against an untyped 1 only
			// ever matched a Go int.
			switch {
			case reflectVal.Kind() == reflect.Bool:
				field.SetBool(reflectVal.Bool())
			case reflectVal.CanInt():
				field.SetBool(reflectVal.Int() != 0)
			case reflectVal.CanUint():
				field.SetBool(reflectVal.Uint() != 0)
			case reflectVal.CanFloat():
				field.SetBool(reflectVal.Float() != 0)
			default:
				text, ok := asText(reflectVal)
				if !ok {
					log.Printf("[ERROR] could not parse bool from db: column '%s' has unsupported type %T", colName, values[i])
					continue
				}
				field.SetBool(text != "0" && text != "\x00" && !strings.EqualFold(text, "false"))
			}
		default:
			if !reflectVal.CanConvert(field.Type()) {
				log.Printf("[ERROR] could not convert column '%s' from %s to %s", colName, reflectVal.Type(), field.Type())
				continue
			}
			field.Set(reflectVal.Convert(field.Type()))
		}
	}
}
|