uni/WEB43-diary/internal/database/helper.go

163 lines
4.0 KiB
Go

package database
import (
"fmt"
"log"
"reflect"
"slices"
"strings"
"time"
)
// getSQLTableName returns the SQL table name for table: the name of its Go
// type converted to snake_case. A single level of pointer is dereferenced, so
// a T and a *T produce the same name.
func getSQLTableName(table any) string {
	t := reflect.TypeOf(table)
	if t.Kind() != reflect.Pointer {
		return toSnakeCase(t.Name())
	}
	return toSnakeCase(t.Elem().Name())
}
// toSnakeCase converts a CamelCase identifier to snake_case. An underscore is
// inserted before every ASCII upper-case letter except a leading one, and the
// whole result is then lower-cased. Runs of capitals are split letter by
// letter, e.g. "UserID" becomes "user_i_d".
func toSnakeCase(str string) string {
	var sb strings.Builder
	for pos, ch := range str {
		upperASCII := 'A' <= ch && ch <= 'Z'
		if upperASCII && pos != 0 {
			sb.WriteByte('_')
		}
		sb.WriteRune(ch)
	}
	return strings.ToLower(sb.String())
}
// getSQLColumns builds the column definitions for the struct (or pointer to
// struct) table.
//
// Only fields whose `db` tag is at least two bytes long become columns; the
// tag text is appended after the SQL type to form the column statement. The
// SQL type comes from the field's `dbType` tag when set, otherwise from
// goToMySQLTypes keyed by the field's Go type string. If neither yields a
// type, the process exits via log.Fatalf. The optional `foreignKey` tag is
// copied verbatim into the column.
func getSQLColumns(table any) []sqlColumn {
	v := reflect.ValueOf(table)
	if v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	structType := v.Type()

	result := []sqlColumn{}
	for idx, count := 0, structType.NumField(); idx < count; idx++ {
		f := structType.Field(idx)
		constraints := f.Tag.Get("db")
		if len(constraints) < 2 {
			continue
		}

		// An explicit dbType tag wins over the Go-type lookup table.
		sqlType := f.Tag.Get("dbType")
		if sqlType == "" {
			sqlType = goToMySQLTypes[f.Type.String()]
		}
		if sqlType == "" {
			log.Fatalf("[getSQLColumns] no type found for column '%s' -> '%s'", f.Name, f.Type)
		}

		result = append(result, sqlColumn{
			Name:       f.Name,
			Statement:  sqlType + " " + constraints,
			ForeignKey: f.Tag.Get("foreignKey"),
		})
	}
	return result
}
// getSQLTableFromInterface describes table as an *sqlTable: its snake_case
// name, its columns (see getSQLColumns) and an optional primary-key clause.
//
// The clause is produced only when table implements SQLPrimary and
// GetPrimaryKeys returns at least one key. Every key must equal the Name of
// one of the columns; otherwise the process exits via log.Fatalf. The clause
// has the form "PRIMARY KEY (a, b), ", and the trailing ", " is part of the
// string.
//
// @TODO: maybe save result from the func globally to avoid reflect overuse
func getSQLTableFromInterface(table any) *sqlTable {
	name := getSQLTableName(table)
	cols := getSQLColumns(table)

	var pkClause string
	if primary, ok := table.(SQLPrimary); ok {
		if keys := primary.GetPrimaryKeys(); len(keys) != 0 {
			for _, key := range keys {
				known := slices.IndexFunc(cols, func(c sqlColumn) bool { return c.Name == key }) >= 0
				if !known {
					log.Fatalf("[AddToCreateTable]: invalid primary key for table '%s'", name)
				}
			}
			pkClause = fmt.Sprintf("PRIMARY KEY (%s), ", strings.Join(keys, ", "))
		}
	}
	return &sqlTable{Name: name, Columns: cols, PrimaryKey: pkClause}
}
// serialize flattens the struct (or pointer to struct) val into parallel
// column-name and value slices for writing to the database.
//
// Only fields whose `db` tag is at least two bytes long are considered.
// Fields holding their zero value are skipped entirely, so false, 0, "" and
// the zero time.Time are never written. Column names are the Go field names.
//
// time.Time values are converted to UTC and then formatted as MySQL DATETIME
// text ("2006-01-02 15:04:05"). deserialize parses that text with time.Parse,
// which yields UTC. Formatting in UTC therefore keeps the round trip on the
// same instant whatever location the value carried. The layout drops
// sub-second precision. bool values are written as 1/0.
func serialize(val any) ([]string, []any) {
	reflectValue := reflect.ValueOf(val)
	if reflectValue.Kind() == reflect.Pointer {
		reflectValue = reflectValue.Elem()
	}
	reflectType := reflectValue.Type()

	var columns []string
	var values []any
	for i := 0; i < reflectType.NumField(); i++ {
		field := reflectType.Field(i)
		if len(field.Tag.Get("db")) < 2 {
			continue
		}
		reflectField := reflectValue.Field(i)
		if !reflectField.IsValid() || reflectField.IsZero() {
			continue
		}

		columns = append(columns, field.Name)
		switch field.Type {
		case reflect.TypeOf(time.Time{}):
			const mysqlTimeFormat = "2006-01-02 15:04:05"
			// Normalise to UTC: DATETIME carries no zone, and the reader
			// interprets the text as UTC.
			t := reflectField.Interface().(time.Time)
			values = append(values, t.UTC().Format(mysqlTimeFormat))
		case reflect.TypeOf(true):
			// Zero values were skipped above, so in practice only true
			// reaches this branch.
			flag := 0
			if reflectField.Bool() {
				flag = 1
			}
			values = append(values, flag)
		default:
			values = append(values, reflectField.Interface())
		}
	}
	return columns, values
}
// deserialize fills the struct pointed to by val from one database row, given
// as parallel columns/values slices. Columns are matched to struct fields by
// Go field name. Columns with no such field, or whose field cannot be set
// (e.g. unexported), are ignored. Any columns beyond len(values) are also
// ignored.
//
// A nil or zero value resets the field to its zero value. Otherwise:
//   - time.Time fields accept a time.Time, or MySQL DATETIME text
//     ("2006-01-02 15:04:05", optional fractional seconds) as []byte or
//     string, parsed as UTC;
//   - bool fields accept a bool, any integer (non-zero is true), or text as
//     []byte/string (anything except "" and "0" is true);
//   - any other field receives the value converted to the field's type.
//
// Values that cannot be parsed or converted are logged and the field is left
// unchanged, instead of panicking.
func deserialize(val any, columns []string, values []any) {
	reflectElem := reflect.ValueOf(val)
	if reflectElem.Kind() == reflect.Pointer {
		reflectElem = reflectElem.Elem()
	}
	for i, colName := range columns {
		if i >= len(values) {
			break
		}
		field := reflectElem.FieldByName(colName)
		if !field.IsValid() || !field.CanSet() {
			continue
		}
		reflectVal := reflect.ValueOf(values[i])
		if !reflectVal.IsValid() || reflectVal.IsZero() {
			field.SetZero()
			continue
		}
		switch field.Type() {
		case reflect.TypeOf(time.Time{}):
			parsed, ok := deserializeTime(values[i])
			if !ok {
				log.Println("[ERROR] could not parse time from db")
				continue
			}
			field.Set(reflect.ValueOf(parsed))
		case reflect.TypeOf(true):
			// Drivers return int64 (binary protocol) or []byte (text
			// protocol) for TINYINT, never a plain int, so comparing the
			// interface against the constant 1 would always be false.
			b, ok := deserializeBool(reflectVal)
			if !ok {
				log.Printf("[ERROR] could not convert %T to bool for column %q", values[i], colName)
				continue
			}
			field.SetBool(b)
		default:
			if !reflectVal.CanConvert(field.Type()) {
				log.Printf("[ERROR] could not convert %T to %s for column %q", values[i], field.Type(), colName)
				continue
			}
			field.Set(reflectVal.Convert(field.Type()))
		}
	}
}

// deserializeTime converts a DATETIME value returned by the driver into a
// time.Time. Text is parsed as UTC; a time.Time is passed through unchanged.
// ok is false for unsupported types or unparsable text.
func deserializeTime(raw any) (parsed time.Time, ok bool) {
	const mysqlTimeFormat = "2006-01-02 15:04:05"
	var text string
	switch v := raw.(type) {
	case time.Time:
		return v, true
	case []byte:
		text = string(v)
	case string:
		text = v
	default:
		return time.Time{}, false
	}
	parsed, err := time.Parse(mysqlTimeFormat, text)
	return parsed, err == nil
}

// deserializeBool interprets a driver value as a boolean: bools as-is,
// integers as non-zero, and text ([]byte or string) as anything other than ""
// or "0". ok is false when the value's kind cannot represent a boolean.
func deserializeBool(raw reflect.Value) (value, ok bool) {
	switch raw.Kind() {
	case reflect.Bool:
		return raw.Bool(), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return raw.Int() != 0, true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return raw.Uint() != 0, true
	case reflect.String:
		s := raw.String()
		return s != "" && s != "0", true
	case reflect.Slice:
		if raw.Type().Elem().Kind() == reflect.Uint8 {
			s := string(raw.Bytes())
			return s != "" && s != "0", true
		}
	}
	return false, false
}