diff --git a/.gitignore b/.gitignore index 1f3b49ecd2d..7cedea8ab7f 100644 --- a/.gitignore +++ b/.gitignore @@ -261,6 +261,11 @@ nul /sdks/csharp/packages/ /sdks/csharp/packages.meta +# kotlin SDK +/sdks/kotlin/.kotlin/ +/sdks/kotlin/.gradle/ +/sdks/kotlin/build/ + # AI agent config .codex .claude diff --git a/crates/cli/src/subcommands/generate.rs b/crates/cli/src/subcommands/generate.rs index 6e5378fede5..6f148443bd9 100644 --- a/crates/cli/src/subcommands/generate.rs +++ b/crates/cli/src/subcommands/generate.rs @@ -6,8 +6,8 @@ use clap::Arg; use clap::ArgAction::{Set, SetTrue}; use fs_err as fs; use spacetimedb_codegen::{ - generate, private_table_names, CodegenOptions, CodegenVisibility, Csharp, Lang, OutputFile, Rust, TypeScript, - UnrealCpp, AUTO_GENERATED_PREFIX, + generate, private_table_names, CodegenOptions, CodegenVisibility, Csharp, Kotlin, Lang, OutputFile, Rust, + TypeScript, UnrealCpp, AUTO_GENERATED_PREFIX, }; use spacetimedb_lib::de::serde::DeserializeWrapper; use spacetimedb_lib::{sats, RawModuleDef}; @@ -289,7 +289,7 @@ pub struct GenerateRunConfig { pub wasm_file: Option, pub js_file: Option, pub lang: Language, - pub namespace: String, + pub namespace: Option, pub module_name: Option, pub module_prefix: Option, pub build_options: String, @@ -316,9 +316,7 @@ fn prepare_generate_run_configs<'a>( let wasm_file = command_config.get_one::("wasm_file")?; let js_file = command_config.get_one::("js_file")?; let requested_lang = command_config.get_one::("language")?; - let namespace = command_config - .get_one::("namespace")? - .unwrap_or_else(|| "SpacetimeDB.Types".to_string()); + let namespace = command_config.get_one::("namespace")?; let module_name = command_config.get_one::("unreal_module_name")?; let module_prefix = command_config.get_one::("module_prefix")?; let build_options = command_config @@ -403,6 +401,9 @@ fn detect_default_language(client_project_dir: &Path) -> anyhow::Result &'static str { Language::Csharp => "csharp", Language::TypeScript => "typescript", Language::UnrealCpp => "unrealcpp", + Language::Kotlin => "kotlin", } } pub fn default_out_dir_for_language(lang: Language) -> Option { match lang { Language::Rust | Language::TypeScript => Some(PathBuf::from("src/module_bindings")), - Language::Csharp => Some(PathBuf::from("module_bindings")), + Language::Csharp | Language::Kotlin => Some(PathBuf::from("module_bindings")), Language::UnrealCpp => None, } } @@ -467,8 +469,10 @@ pub async fn run_prepared_generate_configs( run.project_path.display() ); - if namespace_from_cli && run.lang != Language::Csharp { - return Err(anyhow::anyhow!("--namespace is only supported with --lang csharp")); + if namespace_from_cli && !matches!(run.lang, Language::Csharp | Language::Kotlin) { + return Err(anyhow::anyhow!( + "--namespace is only supported with --lang csharp or --lang kotlin" + )); } let module: ModuleDef = if let Some(paths) = &json_module { @@ -509,10 +513,11 @@ pub async fn run_prepared_generate_configs( let csharp_lang; let unreal_cpp_lang; + let kotlin_lang; let gen_lang = match run.lang { Language::Csharp => { csharp_lang = Csharp { - namespace: &run.namespace, + namespace: run.namespace.as_deref().unwrap_or_else(|| "SpacetimeDB.Types"), }; &csharp_lang as &dyn Lang } @@ -526,6 +531,17 @@ pub async fn run_prepared_generate_configs( } Language::Rust => &Rust, Language::TypeScript => &TypeScript, + Language::Kotlin => { + let pkg = if namespace_from_cli { + run.namespace.as_deref().unwrap() + } else { + "spacetimedb" + }; + kotlin_lang = Kotlin { + package_name: pkg, + }; + &kotlin_lang as &dyn Lang + } }; for OutputFile { filename, code } in generate(&module, gen_lang, &options) { @@ -688,11 +704,18 @@ pub enum Language { Rust, #[serde(alias = "uecpp", alias = "ue5cpp", alias = "unreal")] UnrealCpp, + Kotlin, } impl clap::ValueEnum for Language { fn value_variants<'a>() -> &'a [Self] { - &[Self::Csharp, Self::TypeScript, Self::Rust, Self::UnrealCpp] + &[ + Self::Csharp, + Self::TypeScript, + Self::Rust, + Self::UnrealCpp, + Self::Kotlin, + ] } fn to_possible_value(&self) -> Option { Some(match self { @@ -700,6 +723,7 @@ impl clap::ValueEnum for Language { Self::TypeScript => clap::builder::PossibleValue::new("typescript").aliases(["ts", "TS"]), Self::Rust => clap::builder::PossibleValue::new("rust").aliases(["rs", "RS"]), Self::UnrealCpp => PossibleValue::new("unrealcpp").aliases(["uecpp", "ue5cpp", "unreal"]), + Self::Kotlin => PossibleValue::new("kotlin").aliases(["kt", "Kotlin"]), }) } } @@ -712,6 +736,7 @@ impl Language { Language::Csharp => "C#", Language::TypeScript => "TypeScript", Language::UnrealCpp => "Unreal C++", + Language::Kotlin => "Kotlin", } } @@ -725,6 +750,9 @@ impl Language { Language::UnrealCpp => { // TODO: implement formatting. } + Language::Kotlin => { + // TODO: implement formatting via ktlint or similar. + } } Ok(()) diff --git a/crates/codegen/src/kotlin.rs b/crates/codegen/src/kotlin.rs new file mode 100644 index 00000000000..d17627cc5cd --- /dev/null +++ b/crates/codegen/src/kotlin.rs @@ -0,0 +1,662 @@ +use crate::util::{ + collect_case, iter_indexes, iter_procedures, iter_reducers, iter_table_names_and_types, + print_auto_generated_file_comment, type_ref_name, +}; +use crate::{CodegenOptions, OutputFile}; + +use super::code_indenter::CodeIndenter; +use super::util::fmt_fn; +use super::Lang; + +use convert_case::{Case, Casing}; +use spacetimedb_lib::sats::layout::PrimitiveType; +use spacetimedb_schema::def::{ModuleDef, ReducerDef, TableDef, TypeDef}; +use spacetimedb_schema::identifier::Identifier; +use spacetimedb_schema::schema::TableSchema; +use spacetimedb_schema::type_for_generate::{AlgebraicTypeDef, AlgebraicTypeUse}; + +use std::fmt; +use std::ops::Deref; + +const INDENT: &str = " "; + +pub struct Kotlin<'opts> { + pub package_name: &'opts str, +} + +fn pkg_path(package_name: &str) -> String { + package_name.replace('.', "/") +} + +fn print_file_header(output: &mut CodeIndenter, package_name: &str, subpackage: &str) { + let full_package = if subpackage.is_empty() { + package_name.to_string() + } else { + format!("{package_name}.{subpackage}") + }; + print_auto_generated_file_comment(output); + writeln!(output, "@file:Suppress(\"RedundantVisibilityModifier\")"); + writeln!(output); + writeln!(output, "package {full_package}"); + writeln!(output); + writeln!(output, "import com.clockworklabs.spacetimedb.*"); + writeln!(output, "import com.clockworklabs.spacetimedb.bsatn.*"); + writeln!(output, "import com.clockworklabs.spacetimedb.query.*"); + writeln!(output, "import kotlin.uuid.Uuid"); + if !subpackage.is_empty() { + writeln!(output, "import {package_name}.*"); + if subpackage != "types" { + writeln!(output, "import {package_name}.types.*"); + } + } + if subpackage == "reducers" || subpackage == "procedures" { + writeln!(output, "import {package_name}.tables.*"); + } + if subpackage.is_empty() { + writeln!(output, "import {package_name}.types.*"); + writeln!(output, "import {package_name}.tables.*"); + // Reducers and procedures subpackages are imported later for their internal functions + } + writeln!(output); +} + +fn ty_fmt<'a>(module: &'a ModuleDef, ty: &'a AlgebraicTypeUse) -> impl fmt::Display + 'a { + fmt_fn(move |f| match ty { + AlgebraicTypeUse::Identity => f.write_str("Identity"), + AlgebraicTypeUse::ConnectionId => f.write_str("ConnectionId"), + AlgebraicTypeUse::ScheduleAt => f.write_str("ScheduleAt"), + AlgebraicTypeUse::Timestamp => f.write_str("Timestamp"), + AlgebraicTypeUse::TimeDuration => f.write_str("TimeDuration"), + AlgebraicTypeUse::Uuid => f.write_str("Uuid"), + AlgebraicTypeUse::Unit => f.write_str("Unit"), + AlgebraicTypeUse::Option(inner_ty) => write!(f, "{}?", ty_fmt(module, inner_ty)), + AlgebraicTypeUse::Result { ok_ty, err_ty } => write!(f, "Result<{}, {}>", ty_fmt(module, ok_ty), ty_fmt(module, err_ty)), + AlgebraicTypeUse::Array(elem_ty) => write!(f, "List<{}>", ty_fmt(module, elem_ty)), + AlgebraicTypeUse::String => f.write_str("String"), + AlgebraicTypeUse::Ref(r) => f.write_str(&type_ref_name(module, *r)), + AlgebraicTypeUse::Primitive(prim) => f.write_str(match prim { + PrimitiveType::Bool => "Boolean", + PrimitiveType::I8 => "Byte", + PrimitiveType::U8 => "UByte", + PrimitiveType::I16 => "Short", + PrimitiveType::U16 => "UShort", + PrimitiveType::I32 => "Int", + PrimitiveType::U32 => "UInt", + PrimitiveType::I64 => "Long", + PrimitiveType::U64 => "ULong", + PrimitiveType::I128 => "spacetimedb_lib.i256", + PrimitiveType::U128 => "spacetimedb_lib.u256", + PrimitiveType::I256 => "spacetimedb_lib.i256", + PrimitiveType::U256 => "spacetimedb_lib.u256", + PrimitiveType::F32 => "Float", + PrimitiveType::F64 => "Double", + }), + AlgebraicTypeUse::Never => unreachable!(), + }) +} + +fn write_bsatn_serialize_field( + module: &ModuleDef, + output: &mut CodeIndenter, + prefix: &str, + field_name: &Identifier, + field_type: &AlgebraicTypeUse, +) { + let field_expr = format!("{}.{}", prefix, field_name.deref().to_case(Case::Camel)); + write_bsatn_serialize_expr(module, output, &field_expr, field_type); +} + +fn write_bsatn_serialize_expr( + module: &ModuleDef, + output: &mut CodeIndenter, + expr: &str, + ty: &AlgebraicTypeUse, +) { + write_bsatn_serialize_expr_with_writer(module, output, "writer", expr, ty); +} + +fn write_bsatn_serialize_expr_with_writer( + module: &ModuleDef, + output: &mut CodeIndenter, + writer_var: &str, + expr: &str, + ty: &AlgebraicTypeUse, +) { + match ty { + AlgebraicTypeUse::Primitive(prim) => { + let method = match prim { + PrimitiveType::Bool => "writeBool", + PrimitiveType::I8 => "writeI8", + PrimitiveType::U8 => "writeU8", + PrimitiveType::I16 => "writeI16", + PrimitiveType::U16 => "writeU16", + PrimitiveType::I32 => "writeI32", + PrimitiveType::U32 => "writeU32", + PrimitiveType::I64 => "writeI64", + PrimitiveType::U64 => "writeU64", + PrimitiveType::F32 => "writeF32", + PrimitiveType::F64 => "writeF64", + PrimitiveType::I128 | PrimitiveType::U128 | PrimitiveType::I256 | PrimitiveType::U256 => "writeByteArray", + }; + writeln!(output, "{writer_var}.{method}({expr})"); + } + AlgebraicTypeUse::String => { + writeln!(output, "{writer_var}.writeString({expr})"); + } + AlgebraicTypeUse::Identity => { + writeln!(output, "Identity.write({writer_var}, {expr})"); + } + AlgebraicTypeUse::ConnectionId => { + writeln!(output, "ConnectionId.write({writer_var}, {expr})"); + } + AlgebraicTypeUse::Timestamp => { + writeln!(output, "Timestamp.write({writer_var}, {expr})"); + } + AlgebraicTypeUse::TimeDuration => { + writeln!(output, "// TODO: serialize TimeDuration {expr}"); + } + AlgebraicTypeUse::ScheduleAt => { + writeln!(output, "// TODO: serialize ScheduleAt {expr}"); + } + AlgebraicTypeUse::Uuid => { + writeln!(output, "Uuid.write({writer_var}, {expr})"); + } + AlgebraicTypeUse::Option(inner) => { + writeln!(output, "{writer_var}.writeOption({expr}) {{ v, inner ->"); + output.indent(1); + write_bsatn_serialize_expr_with_writer(module, output, "v", "inner", inner); + output.dedent(1); + writeln!(output, "}}"); + } + AlgebraicTypeUse::Array(elem) => { + writeln!(output, "{writer_var}.writeArray({expr}) {{ w, elem ->"); + output.indent(1); + write_bsatn_serialize_expr_with_writer(module, output, "w", "elem", elem); + output.dedent(1); + writeln!(output, "}}"); + } + AlgebraicTypeUse::Result { .. } => { + writeln!(output, "// TODO: serialize Result {expr}"); + } + AlgebraicTypeUse::Unit => {} + AlgebraicTypeUse::Ref(r) => { + let type_name = type_ref_name(module, *r); + writeln!(output, "{type_name}.write({writer_var}, {expr})"); + } + AlgebraicTypeUse::Never => unreachable!(), + } +} + +fn write_bsatn_deserialize_field( + module: &ModuleDef, + output: &mut CodeIndenter, + field_name: &Identifier, + field_type: &AlgebraicTypeUse, +) { + let camel_name = field_name.deref().to_case(Case::Camel); + write!(output, "{camel_name} = "); + write_bsatn_deserialize_expr(module, output, "reader", field_type); + writeln!(output, ","); +} + +fn write_bsatn_deserialize_expr( + module: &ModuleDef, + output: &mut CodeIndenter, + reader_var: &str, + ty: &AlgebraicTypeUse, +) { + match ty { + AlgebraicTypeUse::Primitive(prim) => { + let method = match prim { + PrimitiveType::Bool => "readBool", + PrimitiveType::I8 => "readI8", + PrimitiveType::U8 => "readU8", + PrimitiveType::I16 => "readI16", + PrimitiveType::U16 => "readU16", + PrimitiveType::I32 => "readI32", + PrimitiveType::U32 => "readU32", + PrimitiveType::I64 => "readI64", + PrimitiveType::U64 => "readU64", + PrimitiveType::F32 => "readF32", + PrimitiveType::F64 => "readF64", + PrimitiveType::I128 | PrimitiveType::U128 | PrimitiveType::I256 | PrimitiveType::U256 => "readByteArray", + }; + write!(output, "{reader_var}.{method}()"); + } + AlgebraicTypeUse::String => { + write!(output, "{reader_var}.readString()"); + } + AlgebraicTypeUse::Identity => { + write!(output, "Identity.read({reader_var})"); + } + AlgebraicTypeUse::ConnectionId => { + write!(output, "ConnectionId.read({reader_var})"); + } + AlgebraicTypeUse::Timestamp => { + write!(output, "Timestamp.read({reader_var})"); + } + AlgebraicTypeUse::TimeDuration => { + write!(output, "TODO(\"read TimeDuration\")"); + } + AlgebraicTypeUse::ScheduleAt => { + write!(output, "TODO(\"read ScheduleAt\")"); + } + AlgebraicTypeUse::Uuid => { + write!(output, "Uuid.read({reader_var})"); + } + AlgebraicTypeUse::Option(inner) => { + write!(output, "{reader_var}.readOption {{ r -> "); + write_bsatn_deserialize_expr(module, output, "r", inner); + write!(output, " }}"); + } + AlgebraicTypeUse::Array(elem) => { + write!(output, "{reader_var}.readArray {{ r -> "); + write_bsatn_deserialize_expr(module, output, "r", elem); + write!(output, " }}"); + } + AlgebraicTypeUse::Result { .. } => { + write!(output, "TODO(\"read Result\")"); + } + AlgebraicTypeUse::Unit => { + write!(output, "Unit"); + } + AlgebraicTypeUse::Ref(r) => { + let type_name = type_ref_name(module, *r); + write!(output, "{type_name}.read({reader_var})"); + } + AlgebraicTypeUse::Never => unreachable!(), + } +} + +impl Lang for Kotlin<'_> { + fn generate_table_file_from_schema( + &self, + module: &ModuleDef, + table: &TableDef, + _schema: TableSchema, + ) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out, &self.package_name, "tables"); + + let row_type = type_ref_name(module, table.product_type_ref); + let table_class_name = format!("{}Handle", table.accessor_name.deref().to_case(Case::Pascal)); + let accessor_pascal = table.accessor_name.deref().to_case(Case::Pascal); + + let base = if table.is_event { "EventTable" } else { "TableWithPrimaryKey" }; + + writeln!(out, "class {table_class_name}(private val conn: DbConnection) : {base}<{row_type}> {{"); + writeln!(out, " override val tableName: String get() = \"{}\"", table.name); + writeln!(out); + writeln!(out, " override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0"); + writeln!(out); + writeln!(out, " override fun iter(): Sequence<{row_type}> {{"); + writeln!(out, " val cache = conn.clientCache.getTable(tableName) ?: return emptySequence()"); + writeln!(out, " return cache.allRows().map {{ bytes ->"); + writeln!(out, " {row_type}.read(BsatnReader(bytes))"); + writeln!(out, " }}.asSequence()"); + writeln!(out, " }}"); + writeln!(out); + writeln!(out, " override fun onInsert(callback: (EventContext<*>, {row_type}) -> Unit): CallbackId {{"); + writeln!(out, " return conn.table(tableName).onInsert {{ bytes ->"); + writeln!(out, " val row = {row_type}.read(BsatnReader(bytes))"); + writeln!(out, " callback(conn.makeEventContext(), row)"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out); + writeln!(out, " override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id)"); + writeln!(out); + writeln!(out, " override fun onDelete(callback: (EventContext<*>, {row_type}) -> Unit): CallbackId {{"); + writeln!(out, " return conn.table(tableName).onDelete {{ bytes ->"); + writeln!(out, " val row = {row_type}.read(BsatnReader(bytes))"); + writeln!(out, " callback(conn.makeEventContext(), row)"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out); + writeln!(out, " override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id)"); + writeln!(out); + + if !table.is_event { + writeln!(out, " override fun onUpdate(callback: (EventContext<*>, {row_type}, {row_type}) -> Unit): CallbackId {{"); + writeln!(out, " return conn.table(tableName).onUpdate {{ oldBytes, newBytes ->"); + writeln!(out, " val oldRow = {row_type}.read(BsatnReader(oldBytes))"); + writeln!(out, " val newRow = {row_type}.read(BsatnReader(newBytes))"); + writeln!(out, " callback(conn.makeEventContext(), oldRow, newRow)"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out); + writeln!(out, " override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id)"); + writeln!(out); + } + + // Index-based lookups (e.g. findByEmail) can be added here + // by generating methods that query the cache by unique constraint. + + + writeln!(out, "}}"); + writeln!(out); + + // Generate typed column accessors + let cols_name = format!("{}Cols", accessor_pascal); + let product_def = module.typespace_for_generate()[table.product_type_ref].as_product().unwrap(); + writeln!(out, "class {cols_name}(tableName: String) : Cols<{row_type}>(tableName) {{"); + for (field_name, field_type) in &product_def.elements { + let camel = field_name.deref().to_case(Case::Camel); + let ty_str = ty_fmt(module, field_type).to_string(); + writeln!(out, " val {camel}: Col<{ty_str}> = Col(\"{camel}\")"); + } + writeln!(out, "}}"); + + // Generate IxCols for indexed columns + let mut ix_col_positions: Vec = Vec::new(); + for idx in iter_indexes(table) { + for col_pos in idx.algorithm.columns().iter() { + if !ix_col_positions.contains(&col_pos.idx()) { + ix_col_positions.push(col_pos.idx()); + } + } + } + if !ix_col_positions.is_empty() { + writeln!(out); + let ixcols_name = format!("{}IxCols", accessor_pascal); + writeln!(out, "class {ixcols_name}(tableName: String) : Cols<{row_type}>(tableName) {{"); + for &pos in &ix_col_positions { + if let Some((field_name, field_type)) = product_def.elements.get(pos) { + let camel = field_name.deref().to_case(Case::Camel); + let ty_str = ty_fmt(module, field_type).to_string(); + writeln!(out, " val {camel}: Col<{ty_str}> = Col(\"{camel}\")"); + } + } + writeln!(out, "}}"); + } + + OutputFile { + filename: format!("{}/tables/{accessor_pascal}.kt", pkg_path(&self.package_name)), + code: output.into_inner(), + } + } + + fn generate_type_files(&self, module: &ModuleDef, typ: &TypeDef) -> Vec { + let name = collect_case(Case::Pascal, typ.accessor_name.name_segments()); + let pkg_prefix = pkg_path(&self.package_name); + let filename = format!("{pkg_prefix}/types/{name}.kt"); + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out, &self.package_name, "types"); + + match &module.typespace_for_generate()[typ.ty] { + AlgebraicTypeDef::Product(product) => { + writeln!(out, "data class {name}("); + out.indent(1); + for (field_name, field_type) in &product.elements { + let camel_name = field_name.deref().to_case(Case::Camel); + let ty_str = ty_fmt(module, field_type).to_string(); + writeln!(out, "val {camel_name}: {ty_str},"); + } + out.dedent(1); + writeln!(out, ") {{"); + writeln!(out, " companion object {{"); + writeln!(out, " fun read(reader: BsatnReader): {name} ="); + writeln!(out, " {name}("); + out.indent(4); + for (field_name, field_type) in &product.elements { + write_bsatn_deserialize_field(module, out, field_name, field_type); + } + out.dedent(4); + writeln!(out, " )"); + writeln!(out); + writeln!(out, " fun write(writer: BsatnWriter, value: {name}) {{"); + out.indent(3); + let elements_copy = product.elements.clone(); + for (field_name, field_type) in &elements_copy { + write_bsatn_serialize_field(module, out, "value", field_name, field_type); + } + out.dedent(3); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out, "}}"); + } + AlgebraicTypeDef::Sum(sum) => { + writeln!(out, "sealed class {name} {{"); + for (variant_name, variant_type) in &sum.variants { + let pascal_variant = variant_name.deref().to_case(Case::Pascal); + match variant_type { + AlgebraicTypeUse::Unit => { + writeln!(out, " data object {pascal_variant} : {name}()"); + } + _ => { + let ty_str = ty_fmt(module, variant_type).to_string(); + writeln!(out, " data class {pascal_variant}(val value: {ty_str}) : {name}()"); + } + } + } + writeln!(out); + writeln!(out, " companion object {{"); + writeln!(out, " fun read(reader: BsatnReader): {name} {{"); + writeln!(out, " val tag = reader.readTag().toInt()"); + writeln!(out, " return when (tag) {{"); + for (i, (variant_name, variant_type)) in sum.variants.iter().enumerate() { + let pascal_variant = variant_name.deref().to_case(Case::Pascal); + match variant_type { + AlgebraicTypeUse::Unit => { + writeln!(out, " {i} -> {name}.{pascal_variant}"); + } + _ => { + writeln!(out, " {i} -> {name}.{pascal_variant}(TODO(\"read variant payload\"))"); + } + } + } + writeln!(out, " else -> throw IllegalStateException(\"Unknown {name} tag\")"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out); + writeln!(out, " fun write(writer: BsatnWriter, value: {name}) {{"); + writeln!(out, " when (value) {{"); + for (i, (variant_name, variant_type)) in sum.variants.iter().enumerate() { + let pascal_variant = variant_name.deref().to_case(Case::Pascal); + match variant_type { + AlgebraicTypeUse::Unit => { + writeln!(out, " is {name}.{pascal_variant} -> writer.writeTag({i}u)"); + } + _ => { + writeln!(out, " is {name}.{pascal_variant} -> writer.writeTag({i}u) // TODO: write variant payload"); + } + } + } + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out, "}}"); + } + AlgebraicTypeDef::PlainEnum(plain_enum) => { + writeln!(out, "enum class {name} {{"); + for (_i, variant_name) in plain_enum.variants.iter().enumerate() { + let pascal_variant = variant_name.deref().to_case(Case::Pascal); + writeln!(out, " {pascal_variant},"); + } + writeln!(out, ";"); + writeln!(out); + writeln!(out, " companion object {{"); + writeln!(out, " fun read(reader: BsatnReader): {name} ="); + writeln!(out, " entries[reader.readU8().toInt()]"); + writeln!(out); + writeln!(out, " fun write(writer: BsatnWriter, value: {name}) {{"); + writeln!(out, " writer.writeU8(value.ordinal.toUByte())"); + writeln!(out, " }}"); + writeln!(out, " }}"); + writeln!(out, "}}"); + } + } + + vec![OutputFile { filename, code: output.into_inner() }] + } + + fn generate_reducer_file(&self, module: &ModuleDef, reducer: &ReducerDef) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out, &self.package_name, "reducers"); + + let pascal_name = reducer.accessor_name.deref().to_case(Case::Pascal); + let camel_name = reducer.accessor_name.deref().to_case(Case::Camel); + + if !reducer.params_for_generate.elements.is_empty() { + writeln!(out, "data class {pascal_name}Args("); + out.indent(1); + for (param_name, param_type) in &reducer.params_for_generate { + let camel_param = param_name.deref().to_case(Case::Camel); + let ty_str = ty_fmt(module, param_type).to_string(); + writeln!(out, "val {camel_param}: {ty_str},"); + } + out.dedent(1); + writeln!(out, ")"); + } else { + writeln!(out, "data object {pascal_name}Args"); + } + writeln!(out); + writeln!(out, "internal fun {camel_name}Reducer(conn: DbConnection, args: {pascal_name}Args, callback: ((ReducerResult) -> Unit)? = null) {{"); + out.indent(1); + if reducer.params_for_generate.elements.is_empty() { + writeln!(out, "val bytes = ByteArray(0)"); + } else { + writeln!(out, "val writer = BsatnWriter()"); + for (param_name, param_type) in &reducer.params_for_generate { + write_bsatn_serialize_field(module, out, "args", param_name, param_type); + } + writeln!(out, "val bytes = writer.toByteArray()"); + } + writeln!(out, "conn.callReducer(\"{name}\", bytes, callback)", name = reducer.name); + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: format!("{}/reducers/{pascal_name}.kt", pkg_path(&self.package_name)), + code: output.into_inner(), + } + } + + fn generate_procedure_file( + &self, + module: &ModuleDef, + procedure: &spacetimedb_schema::def::ProcedureDef, + ) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out, &self.package_name, "procedures"); + + let pascal_name = procedure.accessor_name.deref().to_case(Case::Pascal); + let camel_name = procedure.accessor_name.deref().to_case(Case::Camel); + + if !procedure.params_for_generate.elements.is_empty() { + writeln!(out, "data class {pascal_name}Args("); + out.indent(1); + for (param_name, param_type) in &procedure.params_for_generate { + let camel_param = param_name.deref().to_case(Case::Camel); + let ty_str = ty_fmt(module, param_type).to_string(); + writeln!(out, "val {camel_param}: {ty_str},"); + } + out.dedent(1); + writeln!(out, ")"); + } else { + writeln!(out, "data object {pascal_name}Args"); + } + writeln!(out); + writeln!(out, "internal fun {camel_name}Procedure(conn: DbConnection, args: {pascal_name}Args, callback: ((ProcedureResult) -> Unit)? = null) {{"); + out.indent(1); + writeln!(out, "val writer = BsatnWriter()"); + for (param_name, param_type) in &procedure.params_for_generate { + write_bsatn_serialize_field(module, out, "args", param_name, param_type); + } + writeln!(out, "conn.callProcedure(\"{name}\", writer.toByteArray(), callback)", name = procedure.name); + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: format!("{}/procedures/{pascal_name}.kt", pkg_path(&self.package_name)), + code: output.into_inner(), + } + } + + fn generate_global_files(&self, module: &ModuleDef, options: &CodegenOptions) -> Vec { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out, &self.package_name, ""); + + let has_reducers = iter_reducers(module, options.visibility).count() > 0; + let has_procedures = iter_procedures(module, options.visibility).count() > 0; + if has_reducers { + writeln!(out, "import {pkg}.reducers.*", pkg = self.package_name); + } + if has_procedures { + writeln!(out, "import {pkg}.procedures.*", pkg = self.package_name); + } + if has_reducers || has_procedures { + writeln!(out); + } + + // RemoteTables + writeln!(out, "class RemoteTables(private val conn: DbConnection) {{"); + for (_, accessor_name, _product_type_ref) in iter_table_names_and_types(module, options.visibility) { + let camel_table = accessor_name.deref().to_case(Case::Camel); + let pascal_table = accessor_name.deref().to_case(Case::Pascal); + let table_class = format!("{pascal_table}Handle"); + writeln!(out, " val {camel_table}: {table_class} = {table_class}(conn)"); + } + writeln!(out, "}}"); + writeln!(out); + + // RemoteReducers + writeln!(out, "class RemoteReducers(val conn: DbConnection) {{"); + writeln!(out, " internal var onUnhandledReducerError: ((ReducerEventContext, Exception) -> Unit)? = null"); + for reducer in iter_reducers(module, options.visibility) { + let camel = reducer.accessor_name.deref().to_case(Case::Camel); + let pascal = reducer.accessor_name.deref().to_case(Case::Pascal); + writeln!(out, " fun {camel}(args: {pascal}Args, callback: ((ReducerResult) -> Unit)? = null) = {camel}Reducer(conn, args, callback)"); + } + writeln!(out, "}}"); + writeln!(out); + // RemoteProcedures + writeln!(out, "class RemoteProcedures(val conn: DbConnection) {{"); + for procedure in iter_procedures(module, options.visibility) { + let camel = procedure.accessor_name.deref().to_case(Case::Camel); + let pascal = procedure.accessor_name.deref().to_case(Case::Pascal); + writeln!(out, " fun {camel}(args: {pascal}Args, callback: ((ProcedureResult) -> Unit)? = null) = {camel}Procedure(conn, args, callback)"); + } + writeln!(out, "}}"); + writeln!(out); + + // From class — typed table accessors for query building + writeln!(out, "class From {{"); + for (name, accessor_name, product_type_ref) in iter_table_names_and_types(module, options.visibility) { + let row_type = type_ref_name(module, product_type_ref); + let camel = accessor_name.deref().to_case(Case::Camel); + let cols_name = format!("{}Cols", accessor_name.deref().to_case(Case::Pascal)); + let sql_name = name.deref(); + writeln!(out, " val {camel}: QueryTable<{row_type}> = QueryTable(\"{sql_name}\") {{ {cols_name}(it) }}"); + } + writeln!(out, "}}"); + writeln!(out); + writeln!(out, "fun SubscriptionBuilder.addQuery(provider: From.() -> QueryProvider): SubscriptionBuilder {{"); + writeln!(out, " val from = From()"); + writeln!(out, " return addQueryFrom(from.run(provider))"); + writeln!(out, "}}"); + writeln!(out); + writeln!(out, "val DbConnection.db: RemoteTables get() = RemoteTables(this)"); + writeln!(out, "val DbConnection.reducers: RemoteReducers get() = RemoteReducers(this)"); + writeln!(out, "val DbConnection.procedures: RemoteProcedures get() = RemoteProcedures(this)"); + writeln!(out); + writeln!(out, "// EventContext factory"); + writeln!(out, "fun DbConnection.makeEventContext(): EventContext {{"); + writeln!(out, " return EventContext(identity, connectionId, savedToken, isActive, connectionState, Event.Transaction, this)"); + writeln!(out, "}}"); + + vec![OutputFile { + filename: format!("{}/RemoteModule.kt", pkg_path(&self.package_name)), + code: output.into_inner(), + }] + } +} diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 28d4fb8a5a4..ed84bca0e7a 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -3,12 +3,14 @@ use spacetimedb_schema::schema::{Schema, TableSchema}; mod code_indenter; pub mod cpp; pub mod csharp; +pub mod kotlin; pub mod rust; pub mod typescript; pub mod unrealcpp; mod util; pub use self::csharp::Csharp; +pub use self::kotlin::Kotlin; pub use self::rust::Rust; pub use self::typescript::TypeScript; pub use self::unrealcpp::UnrealCpp; diff --git a/crates/codegen/tests/codegen.rs b/crates/codegen/tests/codegen.rs index 06dc3ebe8fc..bb338a07583 100644 --- a/crates/codegen/tests/codegen.rs +++ b/crates/codegen/tests/codegen.rs @@ -1,4 +1,4 @@ -use spacetimedb_codegen::{generate, CodegenOptions, Csharp, Rust, TypeScript}; +use spacetimedb_codegen::{generate, CodegenOptions, Csharp, Kotlin, Rust, TypeScript}; use spacetimedb_data_structures::map::HashMap; use spacetimedb_schema::def::ModuleDef; use spacetimedb_testing::modules::{CompilationMode, CompiledModule}; @@ -38,4 +38,5 @@ declare_tests! { test_codegen_csharp => Csharp { namespace: "SpacetimeDB" }, test_codegen_typescript => TypeScript, test_codegen_rust => Rust, + test_codegen_kotlin => Kotlin { package_name: "spacetimedb" }, } diff --git a/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap b/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap new file mode 100644 index 00000000000..6193f32548b --- /dev/null +++ b/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap @@ -0,0 +1,1472 @@ +--- +source: crates/codegen/tests/codegen.rs +assertion_line: 37 +expression: outfiles +--- +"spacetimedb/RemoteModule.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.types.* +import spacetimedb.tables.* + +import spacetimedb.reducers.* +import spacetimedb.procedures.* + +class RemoteTables(private val conn: DbConnection) { + val loggedOutPlayer: LoggedOutPlayerHandle = LoggedOutPlayerHandle(conn) + val myPlayer: MyPlayerHandle = MyPlayerHandle(conn) + val person: PersonHandle = PersonHandle(conn) + val player: PlayerHandle = PlayerHandle(conn) + val testD: TestDHandle = TestDHandle(conn) + val testF: TestFHandle = TestFHandle(conn) +} + +class RemoteReducers(val conn: DbConnection) { + internal var onUnhandledReducerError: ((ReducerEventContext, Exception) -> Unit)? = null + fun add(args: AddArgs, callback: ((ReducerResult) -> Unit)? = null) = addReducer(conn, args, callback) + fun addPlayer(args: AddPlayerArgs, callback: ((ReducerResult) -> Unit)? = null) = addPlayerReducer(conn, args, callback) + fun addPrivate(args: AddPrivateArgs, callback: ((ReducerResult) -> Unit)? = null) = addPrivateReducer(conn, args, callback) + fun assertCallerIdentityIsModuleIdentity(args: AssertCallerIdentityIsModuleIdentityArgs, callback: ((ReducerResult) -> Unit)? = null) = assertCallerIdentityIsModuleIdentityReducer(conn, args, callback) + fun deletePlayer(args: DeletePlayerArgs, callback: ((ReducerResult) -> Unit)? = null) = deletePlayerReducer(conn, args, callback) + fun deletePlayersByName(args: DeletePlayersByNameArgs, callback: ((ReducerResult) -> Unit)? = null) = deletePlayersByNameReducer(conn, args, callback) + fun listOverAge(args: ListOverAgeArgs, callback: ((ReducerResult) -> Unit)? = null) = listOverAgeReducer(conn, args, callback) + fun logModuleIdentity(args: LogModuleIdentityArgs, callback: ((ReducerResult) -> Unit)? = null) = logModuleIdentityReducer(conn, args, callback) + fun queryPrivate(args: QueryPrivateArgs, callback: ((ReducerResult) -> Unit)? = null) = queryPrivateReducer(conn, args, callback) + fun sayHello(args: SayHelloArgs, callback: ((ReducerResult) -> Unit)? = null) = sayHelloReducer(conn, args, callback) + fun test(args: TestArgs, callback: ((ReducerResult) -> Unit)? = null) = testReducer(conn, args, callback) + fun testBtreeIndexArgs(args: TestBtreeIndexArgsArgs, callback: ((ReducerResult) -> Unit)? = null) = testBtreeIndexArgsReducer(conn, args, callback) +} + +class RemoteProcedures(val conn: DbConnection) { + fun getMySchemaViaHttp(args: GetMySchemaViaHttpArgs, callback: ((ProcedureResult) -> Unit)? = null) = getMySchemaViaHttpProcedure(conn, args, callback) + fun returnValue(args: ReturnValueArgs, callback: ((ProcedureResult) -> Unit)? = null) = returnValueProcedure(conn, args, callback) + fun sleepOneSecond(args: SleepOneSecondArgs, callback: ((ProcedureResult) -> Unit)? = null) = sleepOneSecondProcedure(conn, args, callback) + fun withTx(args: WithTxArgs, callback: ((ProcedureResult) -> Unit)? = null) = withTxProcedure(conn, args, callback) +} + +class From { + val loggedOutPlayer: QueryTable = QueryTable("logged_out_player") { LoggedOutPlayerCols(it) } + val myPlayer: QueryTable = QueryTable("my_player") { MyPlayerCols(it) } + val person: QueryTable = QueryTable("person") { PersonCols(it) } + val player: QueryTable = QueryTable("player") { PlayerCols(it) } + val testD: QueryTable = QueryTable("test_d") { TestDCols(it) } + val testF: QueryTable = QueryTable("test_f") { TestFCols(it) } +} + +fun SubscriptionBuilder.addQuery(provider: From.() -> QueryProvider): SubscriptionBuilder { + val from = From() + return addQueryFrom(from.run(provider)) +} + +val DbConnection.db: RemoteTables get() = RemoteTables(this) +val DbConnection.reducers: RemoteReducers get() = RemoteReducers(this) +val DbConnection.procedures: RemoteProcedures get() = RemoteProcedures(this) + +// EventContext factory +fun DbConnection.makeEventContext(): EventContext { + return EventContext(identity, connectionId, savedToken, isActive, connectionState, Event.Transaction, this) +} +''' +"spacetimedb/procedures/GetMySchemaViaHttp.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.procedures + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object GetMySchemaViaHttpArgs + +internal fun getMySchemaViaHttpProcedure(conn: DbConnection, args: GetMySchemaViaHttpArgs, callback: ((ProcedureResult) -> Unit)? = null) { + val writer = BsatnWriter() + conn.callProcedure("get_my_schema_via_http", writer.toByteArray(), callback) +} +''' +"spacetimedb/procedures/ReturnValue.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.procedures + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class ReturnValueArgs( + val foo: ULong, +) + +internal fun returnValueProcedure(conn: DbConnection, args: ReturnValueArgs, callback: ((ProcedureResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeU64(args.foo) + conn.callProcedure("return_value", writer.toByteArray(), callback) +} +''' +"spacetimedb/procedures/SleepOneSecond.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.procedures + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object SleepOneSecondArgs + +internal fun sleepOneSecondProcedure(conn: DbConnection, args: SleepOneSecondArgs, callback: ((ProcedureResult) -> Unit)? = null) { + val writer = BsatnWriter() + conn.callProcedure("sleep_one_second", writer.toByteArray(), callback) +} +''' +"spacetimedb/procedures/WithTx.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.procedures + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object WithTxArgs + +internal fun withTxProcedure(conn: DbConnection, args: WithTxArgs, callback: ((ProcedureResult) -> Unit)? = null) { + val writer = BsatnWriter() + conn.callProcedure("with_tx", writer.toByteArray(), callback) +} +''' +"spacetimedb/reducers/Add.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class AddArgs( + val name: String, + val age: UByte, +) + +internal fun addReducer(conn: DbConnection, args: AddArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeString(args.name) + writer.writeU8(args.age) + val bytes = writer.toByteArray() + conn.callReducer("add", bytes, callback) +} +''' +"spacetimedb/reducers/AddPlayer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class AddPlayerArgs( + val name: String, +) + +internal fun addPlayerReducer(conn: DbConnection, args: AddPlayerArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeString(args.name) + val bytes = writer.toByteArray() + conn.callReducer("add_player", bytes, callback) +} +''' +"spacetimedb/reducers/AddPrivate.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class AddPrivateArgs( + val name: String, +) + +internal fun addPrivateReducer(conn: DbConnection, args: AddPrivateArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeString(args.name) + val bytes = writer.toByteArray() + conn.callReducer("add_private", bytes, callback) +} +''' +"spacetimedb/reducers/AssertCallerIdentityIsModuleIdentity.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object AssertCallerIdentityIsModuleIdentityArgs + +internal fun assertCallerIdentityIsModuleIdentityReducer(conn: DbConnection, args: AssertCallerIdentityIsModuleIdentityArgs, callback: ((ReducerResult) -> Unit)? = null) { + val bytes = ByteArray(0) + conn.callReducer("assert_caller_identity_is_module_identity", bytes, callback) +} +''' +"spacetimedb/reducers/DeletePlayer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class DeletePlayerArgs( + val id: ULong, +) + +internal fun deletePlayerReducer(conn: DbConnection, args: DeletePlayerArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeU64(args.id) + val bytes = writer.toByteArray() + conn.callReducer("delete_player", bytes, callback) +} +''' +"spacetimedb/reducers/DeletePlayersByName.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class DeletePlayersByNameArgs( + val name: String, +) + +internal fun deletePlayersByNameReducer(conn: DbConnection, args: DeletePlayersByNameArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeString(args.name) + val bytes = writer.toByteArray() + conn.callReducer("delete_players_by_name", bytes, callback) +} +''' +"spacetimedb/reducers/ListOverAge.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class ListOverAgeArgs( + val age: UByte, +) + +internal fun listOverAgeReducer(conn: DbConnection, args: ListOverAgeArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writer.writeU8(args.age) + val bytes = writer.toByteArray() + conn.callReducer("list_over_age", bytes, callback) +} +''' +"spacetimedb/reducers/LogModuleIdentity.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object LogModuleIdentityArgs + +internal fun logModuleIdentityReducer(conn: DbConnection, args: LogModuleIdentityArgs, callback: ((ReducerResult) -> Unit)? = null) { + val bytes = ByteArray(0) + conn.callReducer("log_module_identity", bytes, callback) +} +''' +"spacetimedb/reducers/QueryPrivate.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object QueryPrivateArgs + +internal fun queryPrivateReducer(conn: DbConnection, args: QueryPrivateArgs, callback: ((ReducerResult) -> Unit)? = null) { + val bytes = ByteArray(0) + conn.callReducer("query_private", bytes, callback) +} +''' +"spacetimedb/reducers/SayHello.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object SayHelloArgs + +internal fun sayHelloReducer(conn: DbConnection, args: SayHelloArgs, callback: ((ReducerResult) -> Unit)? = null) { + val bytes = ByteArray(0) + conn.callReducer("say_hello", bytes, callback) +} +''' +"spacetimedb/reducers/Test.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data class TestArgs( + val arg: TestA, + val arg2: TestB, + val arg3: NamespaceTestC, + val arg4: NamespaceTestF, +) + +internal fun testReducer(conn: DbConnection, args: TestArgs, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + TestA.write(writer, args.arg) + TestB.write(writer, args.arg2) + NamespaceTestC.write(writer, args.arg3) + NamespaceTestF.write(writer, args.arg4) + val bytes = writer.toByteArray() + conn.callReducer("test", bytes, callback) +} +''' +"spacetimedb/reducers/TestBtreeIndexArgs.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.reducers + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* +import spacetimedb.tables.* + +data object TestBtreeIndexArgsArgs + +internal fun testBtreeIndexArgsReducer(conn: DbConnection, args: TestBtreeIndexArgsArgs, callback: ((ReducerResult) -> Unit)? = null) { + val bytes = ByteArray(0) + conn.callReducer("test_btree_index_args", bytes, callback) +} +''' +"spacetimedb/tables/LoggedOutPlayer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class LoggedOutPlayerHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "logged_out_player" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + Player.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, Player, Player) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = Player.read(BsatnReader(oldBytes)) + val newRow = Player.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class LoggedOutPlayerCols(tableName: String) : Cols(tableName) { + val identity: Col = Col("identity") + val playerId: Col = Col("playerId") + val name: Col = Col("name") +} + +class LoggedOutPlayerIxCols(tableName: String) : Cols(tableName) { + val identity: Col = Col("identity") + val name: Col = Col("name") + val playerId: Col = Col("playerId") +} +''' +"spacetimedb/tables/MyPlayer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class MyPlayerHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "my_player" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + Player.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, Player, Player) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = Player.read(BsatnReader(oldBytes)) + val newRow = Player.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class MyPlayerCols(tableName: String) : Cols(tableName) { + val identity: Col = Col("identity") + val playerId: Col = Col("playerId") + val name: Col = Col("name") +} +''' +"spacetimedb/tables/Person.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class PersonHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "person" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + Person.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, Person) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = Person.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, Person) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = Person.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, Person, Person) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = Person.read(BsatnReader(oldBytes)) + val newRow = Person.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class PersonCols(tableName: String) : Cols(tableName) { + val id: Col = Col("id") + val name: Col = Col("name") + val age: Col = Col("age") +} + +class PersonIxCols(tableName: String) : Cols(tableName) { + val age: Col = Col("age") + val id: Col = Col("id") +} +''' +"spacetimedb/tables/Player.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class PlayerHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "player" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + Player.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, Player) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = Player.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, Player, Player) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = Player.read(BsatnReader(oldBytes)) + val newRow = Player.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class PlayerCols(tableName: String) : Cols(tableName) { + val identity: Col = Col("identity") + val playerId: Col = Col("playerId") + val name: Col = Col("name") +} + +class PlayerIxCols(tableName: String) : Cols(tableName) { + val identity: Col = Col("identity") + val name: Col = Col("name") + val playerId: Col = Col("playerId") +} +''' +"spacetimedb/tables/TestD.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class TestDHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "test_d" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + TestD.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, TestD) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = TestD.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, TestD) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = TestD.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, TestD, TestD) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = TestD.read(BsatnReader(oldBytes)) + val newRow = TestD.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class TestDCols(tableName: String) : Cols(tableName) { + val testC: Col = Col("testC") + val testCNested: Col?> = Col("testCNested") +} +''' +"spacetimedb/tables/TestF.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.tables + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* +import spacetimedb.types.* + +class TestFHandle(private val conn: DbConnection) : TableWithPrimaryKey { + override val tableName: String get() = "test_f" + + override val count: Int get() = conn.clientCache.getTable(tableName)?.count ?: 0 + + override fun iter(): Sequence { + val cache = conn.clientCache.getTable(tableName) ?: return emptySequence() + return cache.allRows().map { bytes -> + TestFoobar.read(BsatnReader(bytes)) + }.asSequence() + } + + override fun onInsert(callback: (EventContext<*>, TestFoobar) -> Unit): CallbackId { + return conn.table(tableName).onInsert { bytes -> + val row = TestFoobar.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnInsert(id: CallbackId) = conn.table(tableName).removeOnInsert(id) + + override fun onDelete(callback: (EventContext<*>, TestFoobar) -> Unit): CallbackId { + return conn.table(tableName).onDelete { bytes -> + val row = TestFoobar.read(BsatnReader(bytes)) + callback(conn.makeEventContext(), row) + } + } + + override fun removeOnDelete(id: CallbackId) = conn.table(tableName).removeOnDelete(id) + + override fun onUpdate(callback: (EventContext<*>, TestFoobar, TestFoobar) -> Unit): CallbackId { + return conn.table(tableName).onUpdate { oldBytes, newBytes -> + val oldRow = TestFoobar.read(BsatnReader(oldBytes)) + val newRow = TestFoobar.read(BsatnReader(newBytes)) + callback(conn.makeEventContext(), oldRow, newRow) + } + } + + override fun removeOnUpdate(id: CallbackId) = conn.table(tableName).removeOnUpdate(id) + +} + +class TestFCols(tableName: String) : Cols(tableName) { + val field: Col = Col("field") +} +''' +"spacetimedb/types/Baz.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class Baz( + val field: String, +) { + companion object { + fun read(reader: BsatnReader): Baz = + Baz( + field = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: Baz) { + writer.writeString(value.field) + } + } +} +''' +"spacetimedb/types/Foobar.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +sealed class Foobar { + data class Baz(val value: Baz) : Foobar() + data object Bar : Foobar() + data class Har(val value: UInt) : Foobar() + + companion object { + fun read(reader: BsatnReader): Foobar { + val tag = reader.readTag().toInt() + return when (tag) { + 0 -> Foobar.Baz(TODO("read variant payload")) + 1 -> Foobar.Bar + 2 -> Foobar.Har(TODO("read variant payload")) + else -> throw IllegalStateException("Unknown Foobar tag") + } + } + + fun write(writer: BsatnWriter, value: Foobar) { + when (value) { + is Foobar.Baz -> writer.writeTag(0u) // TODO: write variant payload + is Foobar.Bar -> writer.writeTag(1u) + is Foobar.Har -> writer.writeTag(2u) // TODO: write variant payload + } + } + } +} +''' +"spacetimedb/types/HasSpecialStuff.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class HasSpecialStuff( + val identity: Identity, + val connectionId: ConnectionId, +) { + companion object { + fun read(reader: BsatnReader): HasSpecialStuff = + HasSpecialStuff( + identity = Identity.read(reader), + connectionId = ConnectionId.read(reader), + ) + + fun write(writer: BsatnWriter, value: HasSpecialStuff) { + Identity.write(writer, value.identity) + ConnectionId.write(writer, value.connectionId) + } + } +} +''' +"spacetimedb/types/NamespaceTestC.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +enum class NamespaceTestC { + Foo, + Bar, +; + + companion object { + fun read(reader: BsatnReader): NamespaceTestC = + entries[reader.readU8().toInt()] + + fun write(writer: BsatnWriter, value: NamespaceTestC) { + writer.writeU8(value.ordinal.toUByte()) + } + } +} +''' +"spacetimedb/types/NamespaceTestF.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +sealed class NamespaceTestF { + data object Foo : NamespaceTestF() + data object Bar : NamespaceTestF() + data class Baz(val value: String) : NamespaceTestF() + + companion object { + fun read(reader: BsatnReader): NamespaceTestF { + val tag = reader.readTag().toInt() + return when (tag) { + 0 -> NamespaceTestF.Foo + 1 -> NamespaceTestF.Bar + 2 -> NamespaceTestF.Baz(TODO("read variant payload")) + else -> throw IllegalStateException("Unknown NamespaceTestF tag") + } + } + + fun write(writer: BsatnWriter, value: NamespaceTestF) { + when (value) { + is NamespaceTestF.Foo -> writer.writeTag(0u) + is NamespaceTestF.Bar -> writer.writeTag(1u) + is NamespaceTestF.Baz -> writer.writeTag(2u) // TODO: write variant payload + } + } + } +} +''' +"spacetimedb/types/NonrepeatingTestArg.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class NonrepeatingTestArg( + val scheduledId: ULong, + val scheduledAt: ScheduleAt, + val prevTime: Timestamp, +) { + companion object { + fun read(reader: BsatnReader): NonrepeatingTestArg = + NonrepeatingTestArg( + scheduledId = reader.readU64(), + scheduledAt = TODO("read ScheduleAt"), + prevTime = Timestamp.read(reader), + ) + + fun write(writer: BsatnWriter, value: NonrepeatingTestArg) { + writer.writeU64(value.scheduledId) + // TODO: serialize ScheduleAt value.scheduledAt + Timestamp.write(writer, value.prevTime) + } + } +} +''' +"spacetimedb/types/Person.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class Person( + val id: UInt, + val name: String, + val age: UByte, +) { + companion object { + fun read(reader: BsatnReader): Person = + Person( + id = reader.readU32(), + name = reader.readString(), + age = reader.readU8(), + ) + + fun write(writer: BsatnWriter, value: Person) { + writer.writeU32(value.id) + writer.writeString(value.name) + writer.writeU8(value.age) + } + } +} +''' +"spacetimedb/types/PkMultiIdentity.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class PkMultiIdentity( + val id: UInt, + val other: UInt, +) { + companion object { + fun read(reader: BsatnReader): PkMultiIdentity = + PkMultiIdentity( + id = reader.readU32(), + other = reader.readU32(), + ) + + fun write(writer: BsatnWriter, value: PkMultiIdentity) { + writer.writeU32(value.id) + writer.writeU32(value.other) + } + } +} +''' +"spacetimedb/types/Player.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class Player( + val identity: Identity, + val playerId: ULong, + val name: String, +) { + companion object { + fun read(reader: BsatnReader): Player = + Player( + identity = Identity.read(reader), + playerId = reader.readU64(), + name = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: Player) { + Identity.write(writer, value.identity) + writer.writeU64(value.playerId) + writer.writeString(value.name) + } + } +} +''' +"spacetimedb/types/Point.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class Point( + val x: Long, + val y: Long, +) { + companion object { + fun read(reader: BsatnReader): Point = + Point( + x = reader.readI64(), + y = reader.readI64(), + ) + + fun write(writer: BsatnWriter, value: Point) { + writer.writeI64(value.x) + writer.writeI64(value.y) + } + } +} +''' +"spacetimedb/types/PrivateTable.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class PrivateTable( + val name: String, +) { + companion object { + fun read(reader: BsatnReader): PrivateTable = + PrivateTable( + name = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: PrivateTable) { + writer.writeString(value.name) + } + } +} +''' +"spacetimedb/types/RemoveTable.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class RemoveTable( + val id: UInt, +) { + companion object { + fun read(reader: BsatnReader): RemoveTable = + RemoveTable( + id = reader.readU32(), + ) + + fun write(writer: BsatnWriter, value: RemoveTable) { + writer.writeU32(value.id) + } + } +} +''' +"spacetimedb/types/RepeatingTestArg.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class RepeatingTestArg( + val scheduledId: ULong, + val scheduledAt: ScheduleAt, + val prevTime: Timestamp, +) { + companion object { + fun read(reader: BsatnReader): RepeatingTestArg = + RepeatingTestArg( + scheduledId = reader.readU64(), + scheduledAt = TODO("read ScheduleAt"), + prevTime = Timestamp.read(reader), + ) + + fun write(writer: BsatnWriter, value: RepeatingTestArg) { + writer.writeU64(value.scheduledId) + // TODO: serialize ScheduleAt value.scheduledAt + Timestamp.write(writer, value.prevTime) + } + } +} +''' +"spacetimedb/types/TestA.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class TestA( + val x: UInt, + val y: UInt, + val z: String, +) { + companion object { + fun read(reader: BsatnReader): TestA = + TestA( + x = reader.readU32(), + y = reader.readU32(), + z = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: TestA) { + writer.writeU32(value.x) + writer.writeU32(value.y) + writer.writeString(value.z) + } + } +} +''' +"spacetimedb/types/TestB.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class TestB( + val foo: String, +) { + companion object { + fun read(reader: BsatnReader): TestB = + TestB( + foo = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: TestB) { + writer.writeString(value.foo) + } + } +} +''' +"spacetimedb/types/TestD.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class TestD( + val testC: NamespaceTestC?, + val testCNested: List?, +) { + companion object { + fun read(reader: BsatnReader): TestD = + TestD( + testC = reader.readOption { r -> NamespaceTestC.read(r) }, + testCNested = reader.readOption { r -> r.readArray { r -> NamespaceTestC.read(r) } }, + ) + + fun write(writer: BsatnWriter, value: TestD) { + writer.writeOption(value.testC) { v, inner -> + NamespaceTestC.write(v, inner) + } + writer.writeOption(value.testCNested) { v, inner -> + v.writeArray(inner) { w, elem -> + NamespaceTestC.write(w, elem) + } + } + } + } +} +''' +"spacetimedb/types/TestE.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class TestE( + val id: ULong, + val name: String, +) { + companion object { + fun read(reader: BsatnReader): TestE = + TestE( + id = reader.readU64(), + name = reader.readString(), + ) + + fun write(writer: BsatnWriter, value: TestE) { + writer.writeU64(value.id) + writer.writeString(value.name) + } + } +} +''' +"spacetimedb/types/TestFoobar.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("RedundantVisibilityModifier") + +package spacetimedb.types + +import com.clockworklabs.spacetimedb.* +import com.clockworklabs.spacetimedb.bsatn.* +import com.clockworklabs.spacetimedb.query.* +import kotlin.uuid.Uuid +import spacetimedb.* + +data class TestFoobar( + val field: Foobar, +) { + companion object { + fun read(reader: BsatnReader): TestFoobar = + TestFoobar( + field = Foobar.read(reader), + ) + + fun write(writer: BsatnWriter, value: TestFoobar) { + Foobar.write(writer, value.field) + } + } +} +''' diff --git a/sdks/kotlin/DEVELOP.md b/sdks/kotlin/DEVELOP.md new file mode 100644 index 00000000000..b2db9fe5630 --- /dev/null +++ b/sdks/kotlin/DEVELOP.md @@ -0,0 +1,227 @@ +# Kotlin SDK — Developer Guide + +Internal documentation for contributors working on the SpacetimeDB Kotlin SDK. + +## Project Structure + +``` +src/ + commonMain/ Shared Kotlin code (all targets) + com/clockworklabs/spacetimedb/ + SpacetimeDBClient.kt DbConnection, DbConnectionBuilder, ProcedureResult + Identity.kt Identity, ConnectionId, Address, Timestamp, TimeDuration + ClientCache.kt Client-side row cache (TableCache, ByteArrayWrapper) + TableHandle.kt Untyped per-table callback registration + Table.kt Typed Table/TableWithPrimaryKey/EventTable interfaces + SubscriptionHandle.kt Subscription lifecycle (PENDING/ACTIVE/ENDED/CANCELLED) + SubscriptionBuilder.kt Fluent subscription API with pending query accumulation + Event.kt Event sealed class, ReducerEvent, Credentials + DbContext.kt Interface for connection context + EventContext.kt Event context with event metadata + ReducerEventContext.kt Reducer-specific event context + ProcedureEventContext.kt Procedure event context + SubscriptionEventContext.kt Subscription event context + ErrorContext.kt Error context for disconnect/error callbacks + ReconnectPolicy.kt Exponential backoff configuration + Compression.kt expect declarations for decompression + ScheduleAt.kt ScheduleAt type (Interval/Time) + Uuid.kt UUID type with BSATN read/write + toByteArray/fromByteArray + bsatn/ + BsatnReader.kt Binary deserialization (little-endian, SATS-compatible) + BsatnWriter.kt Binary serialization (little-endian, SATS-compatible) + BsatnRowList.kt Row list decoding (FixedSize/RowOffsets) + protocol/ + ServerMessage.kt Server → Client message decoding + ClientMessage.kt Client → Server message encoding + ProtocolTypes.kt QuerySetId, QueryRows, TableUpdateRows, etc. + query/ + QueryBuilder.kt QueryTable, Col, BoolExpr, Cols, FromQuery, QueryProvider, QueryFrom + websocket/ + WebSocketTransport.kt WebSocket lifecycle, token exchange, keep-alive, reconnection + jvmMain/ JVM-specific (Gzip via java.util.zip, Brotli via org.brotli) + iosMain/ iOS-specific (Gzip via platform.zlib) + commonTest/ Shared tests + jvmTest/ JVM-only tests (compression round-trips, live integration) +``` + +## Architecture + +### Connection Lifecycle + +``` +DbConnectionBuilder.build() + → DbConnection constructor + → WebSocketTransport.connect() + → exchangeToken() POST /v1/identity/websocket-token (if token provided) + → connectSession() opens WebSocket + → processSendQueue() (coroutine: outbound messages) + → processIncoming() (coroutine: inbound frames) + → runKeepAlive() (coroutine: 30s ping heartbeats) +``` + +On unexpected disconnect with a `ReconnectPolicy`, the transport enters a +`RECONNECTING` state and calls `attemptReconnect()` which retries with +exponential backoff up to `maxRetries` times. + +### Wire Protocol + +Uses the `v2.bsatn.spacetimedb` WebSocket subprotocol. All messages are BSATN +(Binary SpacetimeDB Algebraic Type Notation) — a tag-length-value encoding +defined in `crates/client-api-messages/src/websocket/v2.rs`. + +**Server messages** are preceded by a compression byte: +- `0x00` — uncompressed +- `0x01` — Brotli +- `0x02` — Gzip + +The SDK requests Gzip compression via the `compression=Gzip` query parameter. + +**Key protocol details:** +- All integers are little-endian +- Sum types use u8 tag: `Option` uses tag `0` = Some, tag `1` = None +- UUID is encoded as two little-endian i64 values (MSB first, LSB second) +- Connection URL includes `connection_id`, `compression`, optional `token`, `confirmed`, `light` params + +### Client Cache + +`ClientCache` maintains a map of `TableCache` instances, one per table. Each +`TableCache` stores rows keyed by content (`ByteArrayWrapper`) with reference +counting. This allows overlapping subscriptions to share rows without duplicates. + +Transaction updates produce `TableOperation` events (Insert, Delete, Update, +EventInsert) which drive the `TableHandle` callback system. + +**New: Initial subscription data (SubscribeApplied) now fires `onInsert` callbacks** +for each row, matching Rust/TypeScript SDK behavior. Previously initial rows +were inserted silently. + +### Query Builder + +The SDK includes a type-safe query builder DSL matching the Rust SDK pattern: + +```kotlin +conn.subscriptionBuilder() + .addQuery { users } // SELECT * FROM "users" + .addQuery { users.where { cols -> cols.age.gt(18) } } // SELECT * FROM "users" WHERE ... + .subscribe() +``` + +Each generated table has a `{Table}Cols` class with typed `Col` column accessors +supporting `.eq()`, `.ne()`, `.gt()`, `.lt()`, `.gte()`, `.lte()`, plus `.and()` / `.or()` +combinators on `BoolExpr`. + +### Event System + +The SDK provides typed event contexts matching Rust SDK patterns: + +| Context | Event | When | +|---------|-------|------| +| `EventContext` | `Event` | Row callbacks (insert/delete/update) | +| `ReducerEventContext` | `ReducerEvent` | Reducer completion callbacks | +| `ProcedureEventContext` | — | Procedure completion callbacks | +| `SubscriptionEventContext` | — | Subscription applied/ended | +| `ErrorContext` | `Throwable?` | Connection errors | + +### Threading Model + +- `WebSocketTransport` runs on a `CoroutineScope(SupervisorJob() + Dispatchers.Default)`. +- All `handleMessage` processing is serialized behind a `Mutex` to prevent + concurrent cache mutation. +- `atomicfu` atomics are used for transport-level flags (`idle`, `wantPong`, + `intentionalDisconnect`) that are read/written across coroutines. +- User callbacks are wrapped in try-catch to prevent exceptions from crashing + the message processing loop. + +### Platform-Specific Code + +Uses Kotlin `expect`/`actual` for decompression: + +| Platform | Gzip | Brotli | +|----------|------|--------| +| JVM | `java.util.zip.GZIPInputStream` | `org.brotli.dec.BrotliInputStream` | +| iOS | `platform.zlib` (wbits=31) | Not supported (SDK defaults to Gzip) | + +## Code Generation + +The Kotlin codegen backend lives at `crates/codegen/src/kotlin.rs`. It implements +the `Lang` trait and generates: + +- **Type files:** `data class` for products, `sealed class` for sums, `enum class` for plain enums + — each with BSATN `read`/`write` companion methods +- **Table files:** Typed handle classes implementing `TableWithPrimaryKey` or `EventTable` + with `iter()`, `find()`, `onInsert`/`onDelete`/`onUpdate` callbacks +- **Reducer files:** Args data class + `internal fun {name}Reducer(conn, args)` +- **Procedure files:** Args data class + `internal fun {name}Procedure(conn, args, callback)` +- **Global files:** `RemoteTables`, `RemoteReducers`, `RemoteProcedures`, `From` (typed query builder) + +Generate bindings: + +```bash +spacetime generate --lang kotlin --out-dir src/commonMain/kotlin --namespace my.package +``` + +## Building & Testing + +```bash +# Build +./gradlew compileKotlinJvm + +# Run unit tests +./gradlew jvmTest + +# Run live integration tests (requires running server + published module) +SPACETIMEDB_TEST=1 SPACETIMEDB_URI=ws://127.0.0.1:3000 SPACETIMEDB_MODULE=kotlin-sdk-test \ + ./gradlew jvmTest --rerun-tasks + +# Publish to local Maven +./gradlew publishToMavenLocal +``` + +## Test Suite + +| File | Coverage | +|------|----------| +| `BsatnTest.kt` | Reader/Writer round-trips for all primitive types | +| `ProtocolTest.kt` | ServerMessage and ClientMessage encode/decode | +| `ClientCacheTest.kt` | Cache operations, ref counting, transaction updates | +| `OneOffQueryTest.kt` | OneOffQueryResult decode (Ok and Err variants) | +| `CompressionTest.kt` | Gzip round-trip, empty/large payloads (JVM only) | +| `ReconnectPolicyTest.kt` | Backoff calculation, parameter validation | +| `EdgeCaseTest.kt` | Option encoding, subscription lifecycle, callback safety, URI normalization | +| `LiveIntegrationTest.kt` | Connect, subscribe, reducer call, one-off query (requires server) | +| `LiveEdgeCaseTest.kt` | Multi-subscription, reconnect, invalid queries (requires server) | +| `PerformanceBenchmarkTest.kt` | Keynote-2 workload benchmark (JVM only) | + +## Design Decisions + +1. **Option encoding** — SATS encodes `Option` with tag `0` = Some, tag `1` = None. + This differs from what a naive u8 sum type implementation might assume (0 = None). + +2. **UUID byte order** — MSB first (two little-endian i64 values). This matches the + `AlgebraicType::uuid()` wire format used by the SpacetimeDB server. + +3. **Ktor built-in ping interval** — The SDK configures Ktor's `WebSockets` plugin + with `pingInterval = 30.seconds` to send periodic WebSocket pings, keeping the + connection alive. No custom ping/pong coroutine is needed. The Rust SDK's + custom keep-alive logic is not replicated since Ktor handles this at the + transport layer. + +4. **Token exchange** — The SDK POSTs to `/v1/identity/websocket-token` to exchange + the long-lived auth token for a short-lived WebSocket token, matching the TypeScript + SDK authentication flow. + +5. **ByteArray row storage** — Rows are stored as raw BSATN bytes in the cache. + Typed access is provided via generated code that wraps the raw bytes with + type-specific `read()`/`write()` companions. + +6. **Compression negotiation** — The SDK advertises `compression=Gzip` in the + connection URI. Brotli is supported on JVM but not iOS; Gzip provides + universal coverage. + +7. **Callback safety** — All user callbacks are wrapped in try-catch to prevent + exceptions from crashing the message processing coroutine. This matches the + Rust SDK's approach of deferring callback execution. + +8. **Subscription lifecycle** — Supports `PENDING`, `ACTIVE`, `ENDED`, and `CANCELLED` + states. Unsubscribing before a subscription is applied (`PENDING -> CANCELLED`) + prevents it from being registered when the server responds. diff --git a/sdks/kotlin/README.md b/sdks/kotlin/README.md new file mode 100644 index 00000000000..a78a1f8f955 --- /dev/null +++ b/sdks/kotlin/README.md @@ -0,0 +1,162 @@ +# SpacetimeDB Kotlin SDK + +## Overview + +The Kotlin Multiplatform (KMP) client SDK for [SpacetimeDB](https://spacetimedb.com). Targets **JVM** and **iOS** (arm64, simulator-arm64, x64), enabling native SpacetimeDB clients from Kotlin, Java, and Swift (via KMP interop). + +## Features + +- BSATN binary protocol (`v2.bsatn.spacetimedb`) +- Subscriptions with SQL and typed query builder support +- One-off queries (suspend and callback variants) +- Reducer invocation with result callbacks +- Procedure invocation (non-transactional server-side functions) +- Token-based authentication with short-lived token exchange +- Automatic reconnection with exponential backoff +- Keep-alive pings (30s interval) +- Gzip and Brotli message decompression +- Client-side row cache with ref-counted overlapping subscriptions +- Typed code generation via `spacetime generate --lang kotlin` +- Configurable package/namespace for generated code + +## Quick Start + +### With generated client bindings + +```kotlin +val conn = DbConnection.builder() + .withUri("ws://localhost:3000") + .withModuleName("my_module") + .onConnect { conn, identity, token -> + println("Connected as $identity") + + // Subscribe with typed query builder + conn.subscriptionBuilder() + .addQuery { users } + .addQuery { roles.where { cols -> cols.name.eq("admin") } } + .onApplied { println("Subscribed") } + .subscribe() + + // Typed row callbacks + conn.db.users.onInsert { ctx, row -> + println("${row.name} inserted") + } + conn.db.users.onDelete { ctx, row -> + println("${row.name} deleted") + } + conn.db.users.onUpdate { ctx, oldRow, newRow -> + println("${oldRow.name} -> ${newRow.name}") + } + + // Typed reducer calls + conn.reducers.provisionRole(ProvisionRoleArgs(name = "admin")) + } + .onDisconnect { _, error -> + println("Disconnected: ${error?.message ?: "clean"}") + } + .build() +``` + +### Without code generation (raw SDK) + +```kotlin +val conn = DbConnection.builder() + .withUri("ws://localhost:3000") + .withModuleName("my_module") + .build() + +conn.subscriptionBuilder() + .subscribe("SELECT * FROM users") + +conn.table("users").onInsert { rowBytes -> + // Raw BSATN bytes — use generated code for typed access +} +``` + +## Installation + +Add the dependency to your `build.gradle.kts`: + +```kotlin +repositories { + mavenCentral() + mavenLocal() // if building from source +} + +kotlin { + sourceSets { + commonMain.dependencies { + implementation("com.clockworklabs:spacetimedb-sdk:0.1.0") + } + } +} +``` + +## Code Generation + +Generate typed client bindings from your SpacetimeDB module: + +```bash +spacetime generate --lang kotlin --out-dir src/commonMain/kotlin --namespace my.package +``` + +This produces: +- **Types/** — Data classes for each type with BSATN `read`/`write` companions +- **Tables/** — Typed table handles with `onInsert`/`onDelete`/`onUpdate` callbacks +- **Reducers/** — Reducer args classes + `RemoteReducers` methods +- **Procedures/** — Procedure args classes + `RemoteProcedures` methods +- **RemoteModule.kt** — `RemoteTables`, `RemoteReducers`, `RemoteProcedures`, `From`, extensions + +### Generated API + +```kotlin +conn.db.users.subscribe() // subscribe via table handle +conn.subscriptionBuilder() + .addQuery { users } // typed query builder + .addQuery { users.where { cols -> cols.age.gt(18) } } + .subscribe() + +conn.reducers.addPlayer(AddPlayerArgs(name = "Alice")) // typed reducer call +conn.db.users.onInsert { ctx, row -> ... } // typed row callback +conn.db.users.find { it.name == "Alice" } // typed query +conn.unsubscribeAll() // unsubscribe everything +``` + +## Features + +### Connection + +- `withToken(token)` — authenticate with an auth token +- `withCompression(mode)` — Gzip or Brotli compression +- `withReconnectPolicy(policy)` — automatic reconnection with backoff +- `withConfirmedReads(enabled)` — wait for durable confirmation +- `withLightMode(enabled)` — reduced network data + +### Subscriptions + +- `subscriptionBuilder()` — fluent builder with `onApplied`, `onError`, `onEnded` +- `subscribe(vararg queries)` — subscribe to SQL queries +- `unsubscribe()` / `unsubscribeThen(onEnded)` — end a subscription +- `unsubscribeAll()` — end all subscriptions on the connection + +### Row Callbacks + +- `onInsert { ctx, row -> ... }` — called on row insertion +- `onDelete { ctx, row -> ... }` — called on row deletion +- `onUpdate { ctx, oldRow, newRow -> ... }` — called on row update (tables with primary key) +- Callbacks fire for both initial subscription data and subsequent transaction updates + +### Reducers & Procedures + +- `conn.reducers.myReducer(MyReducerArgs(...))` — typed reducer call +- `conn.callReducer(name, args, callback?)` — raw reducer call +- `conn.procedures.myProcedure(MyProcedureArgs(...)) { result -> ... }` — typed procedure call +- `conn.callProcedure(name, args, callback?)` — raw procedure call + +## Documentation + +For the SpacetimeDB platform documentation, see [spacetimedb.com/docs](https://spacetimedb.com/docs). + +## Internal Developer Documentation + +See [`DEVELOP.md`](./DEVELOP.md). diff --git a/sdks/kotlin/build.gradle.kts b/sdks/kotlin/build.gradle.kts new file mode 100644 index 00000000000..2d6a6d036e4 --- /dev/null +++ b/sdks/kotlin/build.gradle.kts @@ -0,0 +1,74 @@ +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.kotlinAtomicfu) + `maven-publish` +} + +group = "com.clockworklabs" +version = "0.1.0" + +kotlin { + jvm() + iosArm64() + iosSimulatorArm64() + macosArm64() + + applyDefaultHierarchyTemplate() + + compilerOptions { + optIn.addAll( + "kotlin.concurrent.atomics.ExperimentalAtomicApi", + "kotlin.uuid.ExperimentalUuidApi" + ) + } + + sourceSets { + commonMain { + dependencies { + implementation(libs.kotlinx.coroutines.core) + implementation(project.dependencies.platform(libs.ktor.bom)) + implementation(libs.ktor.client.core) + implementation(libs.ktor.client.websockets) + } + } + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + jvmMain { + dependencies { + implementation(libs.ktor.client.okhttp) + implementation(libs.brotli.dec) + } + } + appleMain { + dependencies { + implementation(libs.ktor.client.darwin) + } + } + } +} + +publishing { + publications { + withType { + pom { + name.set("SpacetimeDB Kotlin SDK") + description.set("SpacetimeDB client SDK for Kotlin Multiplatform") + } + } + } + repositories { + mavenLocal() + } +} + +tasks.matching { it.name == "jvmTest" }.configureEach { + if (this is Test) { + testLogging { + showStandardStreams = true + } + maxHeapSize = "1g" + } +} diff --git a/sdks/kotlin/gradle.properties b/sdks/kotlin/gradle.properties new file mode 100644 index 00000000000..d54bbe28298 --- /dev/null +++ b/sdks/kotlin/gradle.properties @@ -0,0 +1,3 @@ +kotlin.code.style=official +kotlin.mpp.stability.nowarn=true +org.gradle.jvmargs=-Xmx2g diff --git a/sdks/kotlin/gradle/libs.versions.toml b/sdks/kotlin/gradle/libs.versions.toml new file mode 100644 index 00000000000..9ad299505d9 --- /dev/null +++ b/sdks/kotlin/gradle/libs.versions.toml @@ -0,0 +1,15 @@ +[versions] +kotlin = "2.3.21" + +[libraries] +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version = "1.10.2" } +ktor-bom = { module = "io.ktor:ktor-bom", version = "3.4.3" } +ktor-client-core = { module = "io.ktor:ktor-client-core" } +ktor-client-websockets = { module = "io.ktor:ktor-client-websockets" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp" } +ktor-client-darwin = { module = "io.ktor:ktor-client-darwin" } +brotli-dec = { module = "org.brotli:dec", version = "0.1.2" } + +[plugins] +kotlinMultiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } +kotlinAtomicfu = { id = "org.jetbrains.kotlin.plugin.atomicfu", version.ref = "kotlin" } diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000000..1b33c55baab Binary files /dev/null and b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar differ diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..2a84e188b85 --- /dev/null +++ b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0.0-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/sdks/kotlin/gradlew b/sdks/kotlin/gradlew new file mode 100755 index 00000000000..23d15a93670 --- /dev/null +++ b/sdks/kotlin/gradlew @@ -0,0 +1,251 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/sdks/kotlin/gradlew.bat b/sdks/kotlin/gradlew.bat new file mode 100644 index 00000000000..db3a6ac207e --- /dev/null +++ b/sdks/kotlin/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH= + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/sdks/kotlin/settings.gradle.kts b/sdks/kotlin/settings.gradle.kts new file mode 100644 index 00000000000..c793d4071a8 --- /dev/null +++ b/sdks/kotlin/settings.gradle.kts @@ -0,0 +1,16 @@ +rootProject.name = "spacetimedb-sdk" + +pluginManagement { + repositories { + mavenCentral() + gradlePluginPortal() + google() + } +} + +dependencyResolutionManagement { + repositories { + mavenCentral() + google() + } +} diff --git a/sdks/kotlin/src/appleMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt b/sdks/kotlin/src/appleMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt new file mode 100644 index 00000000000..02748b1c315 --- /dev/null +++ b/sdks/kotlin/src/appleMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt @@ -0,0 +1,87 @@ +package com.clockworklabs.spacetimedb + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.alloc +import kotlinx.cinterop.free +import kotlinx.cinterop.nativeHeap +import kotlinx.cinterop.ptr +import kotlinx.cinterop.reinterpret +import kotlinx.cinterop.usePinned +import platform.zlib.Z_FINISH +import platform.zlib.Z_OK +import platform.zlib.Z_STREAM_END +import platform.zlib.inflate +import platform.zlib.inflateEnd +import platform.zlib.inflateInit2 +import platform.zlib.z_stream + +actual fun decompressBrotli(data: ByteArray): ByteArray { + // Brotli decompression requires Apple's Compression framework interop or a bundled decoder. + // The SDK defaults to Gzip compression (see buildWsUri), so Brotli is not expected. + // If a server sends Brotli, this will surface the issue clearly. + throw UnsupportedOperationException( + "Brotli decompression is not available on iOS. " + + "Configure the server connection to use Gzip compression instead." + ) +} + +@OptIn(ExperimentalForeignApi::class) +actual fun decompressGzip(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + + val stream = nativeHeap.alloc() + try { + stream.zalloc = null + stream.zfree = null + stream.opaque = null + stream.avail_in = 0u + stream.next_in = null + + // wbits = MAX_WBITS + 16 (31) tells zlib to expect gzip format + val initResult = inflateInit2(stream.ptr, 31) + if (initResult != Z_OK) { + throw IllegalStateException("zlib inflateInit2 failed: $initResult") + } + + val chunks = mutableListOf() + val outBuf = ByteArray(8192) + + data.usePinned { srcPinned -> + stream.next_in = srcPinned.addressOf(0).reinterpret() + stream.avail_in = data.size.toUInt() + + do { + outBuf.usePinned { dstPinned -> + stream.next_out = dstPinned.addressOf(0).reinterpret() + stream.avail_out = outBuf.size.toUInt() + + val ret = inflate(stream.ptr, Z_FINISH) + if (ret != Z_OK && ret != Z_STREAM_END) { + inflateEnd(stream.ptr) + throw IllegalStateException("zlib inflate failed: $ret") + } + + val produced = outBuf.size - stream.avail_out.toInt() + if (produced > 0) { + chunks.add(outBuf.copyOf(produced)) + } + } + } while (stream.avail_out == 0u) + } + + inflateEnd(stream.ptr) + + // Concatenate chunks + val totalSize = chunks.sumOf { it.size } + val result = ByteArray(totalSize) + var offset = 0 + for (chunk in chunks) { + chunk.copyInto(result, offset) + offset += chunk.size + } + return result + } finally { + nativeHeap.free(stream) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Address.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Address.kt new file mode 100644 index 00000000000..a14813da16d --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Address.kt @@ -0,0 +1,28 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +/** A 128-bit address identifying a client in the SpacetimeDB network. */ +class Address(val bytes: ByteArray) { + init { + require(bytes.size == 16) { "Address must be 16 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is Address && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "Address(${toHex()})" + + companion object { + val ZERO = Address(ByteArray(16)) + + fun read(reader: BsatnReader): Address = Address(reader.readBytes(16)) + + fun write(writer: BsatnWriter, value: Address) { writer.writeBytes(value.bytes) } + } +} \ No newline at end of file diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt new file mode 100644 index 00000000000..a421992e792 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt @@ -0,0 +1,150 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.protocol.PersistentTableRows +import com.clockworklabs.spacetimedb.protocol.QueryRows +import com.clockworklabs.spacetimedb.protocol.QuerySetUpdate +import com.clockworklabs.spacetimedb.protocol.TableUpdateRows + +class ClientCache { + private val tables = mutableMapOf() + + fun getOrCreateTable(name: String): TableCache = + tables.getOrPut(name) { TableCache(name) } + + fun getTable(name: String): TableCache? = tables[name] + + fun tableNames(): Set = tables.keys.toSet() + + fun applySubscribeRows(rows: QueryRows): List { + val operations = mutableListOf() + for (singleTable in rows.tables) { + val tableName = singleTable.table.value + val cache = getOrCreateTable(tableName) + val decodedRows = singleTable.rows.decodeRows() + for (row in decodedRows) { + cache.insertRow(row) + operations.add(TableOperation.Insert(tableName, row)) + } + } + return operations + } + + fun applyUnsubscribeRows(rows: QueryRows) { + for (singleTable in rows.tables) { + val tableName = singleTable.table.value + val cache = getTable(tableName) ?: continue + val decodedRows = singleTable.rows.decodeRows() + for (row in decodedRows) { + cache.deleteRow(row) + } + } + } + + fun applyTransactionUpdate(querySets: List): List { + val operations = mutableListOf() + for (qsUpdate in querySets) { + for (tableUpdate in qsUpdate.tables) { + val tableName = tableUpdate.tableName.value + val cache = getOrCreateTable(tableName) + for (rowUpdate in tableUpdate.rows) { + when (rowUpdate) { + is TableUpdateRows.PersistentTable -> { + applyPersistentUpdate(cache, tableName, rowUpdate.rows, operations) + } + is TableUpdateRows.EventTable -> { + val decoded = rowUpdate.rows.events.decodeRows() + for (row in decoded) { + operations.add(TableOperation.EventInsert(tableName, row)) + } + } + } + } + } + } + return operations + } + + private fun applyPersistentUpdate( + cache: TableCache, + tableName: String, + rows: PersistentTableRows, + operations: MutableList, + ) { + val deletes = rows.deletes.decodeRows() + val inserts = rows.inserts.decodeRows() + + val deletedSet = deletes.map { ByteArrayWrapper(it) }.toSet() + val insertMap = mutableMapOf() + for (row in inserts) { + insertMap[ByteArrayWrapper(row)] = row + } + + for (row in deletes) { + val wrapper = ByteArrayWrapper(row) + val newRow = insertMap[wrapper] + if (newRow != null) { + cache.deleteRow(row) + cache.insertRow(newRow) + operations.add(TableOperation.Update(tableName, row, newRow)) + } else { + cache.deleteRow(row) + operations.add(TableOperation.Delete(tableName, row)) + } + } + + for (row in inserts) { + val wrapper = ByteArrayWrapper(row) + if (wrapper !in deletedSet) { + cache.insertRow(row) + operations.add(TableOperation.Insert(tableName, row)) + } + } + } +} + +class TableCache(val name: String) { + private val rows = mutableMapOf() + + val count: Int get() = rows.size + + fun insertRow(rowBytes: ByteArray) { + val key = ByteArrayWrapper(rowBytes) + val existing = rows[key] + if (existing != null) { + existing.refCount++ + } else { + rows[key] = RowEntry(rowBytes, 1) + } + } + + fun deleteRow(rowBytes: ByteArray): Boolean { + val key = ByteArrayWrapper(rowBytes) + val existing = rows[key] ?: return false + existing.refCount-- + if (existing.refCount <= 0) { + rows.remove(key) + } + return true + } + + fun allRows(): List = rows.values.map { it.data } + + fun containsRow(rowBytes: ByteArray): Boolean = + rows.containsKey(ByteArrayWrapper(rowBytes)) +} + +class RowEntry(val data: ByteArray, var refCount: Int) + +sealed class TableOperation { + data class Insert(val tableName: String, val row: ByteArray) : TableOperation() + data class Delete(val tableName: String, val row: ByteArray) : TableOperation() + data class Update(val tableName: String, val oldRow: ByteArray, val newRow: ByteArray) : TableOperation() + data class EventInsert(val tableName: String, val row: ByteArray) : TableOperation() +} + +class ByteArrayWrapper(val data: ByteArray) { + override fun equals(other: Any?): Boolean = + other is ByteArrayWrapper && data.contentEquals(other.data) + + override fun hashCode(): Int = data.contentHashCode() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt new file mode 100644 index 00000000000..b25cb22be0e --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt @@ -0,0 +1,5 @@ +package com.clockworklabs.spacetimedb + +expect fun decompressBrotli(data: ByteArray): ByteArray + +expect fun decompressGzip(data: ByteArray): ByteArray diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/DbContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/DbContext.kt new file mode 100644 index 00000000000..3a16270c6d9 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/DbContext.kt @@ -0,0 +1,15 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import kotlinx.coroutines.flow.StateFlow + +interface DbContext { + val identity: Identity? + val connectionId: ConnectionId + val savedToken: String? + val isActive: Boolean + val connectionState: StateFlow + + fun disconnect() + fun subscriptionBuilder(): SubscriptionBuilder +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ErrorContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ErrorContext.kt new file mode 100644 index 00000000000..fe75faf3a18 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ErrorContext.kt @@ -0,0 +1,14 @@ +package com.clockworklabs.spacetimedb + +class ErrorContext( + override val identity: Identity?, + override val connectionId: ConnectionId, + override val savedToken: String?, + override val isActive: Boolean, + override val connectionState: kotlinx.coroutines.flow.StateFlow, + val error: Throwable?, + private val conn: DbConnection, +) : DbContext { + override fun disconnect() = conn.disconnect() + override fun subscriptionBuilder() = conn.subscriptionBuilder() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt new file mode 100644 index 00000000000..fd18db8279b --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt @@ -0,0 +1,30 @@ +package com.clockworklabs.spacetimedb + +sealed class Status { + data object Committed : Status() + data class Failed(val message: String) : Status() + data class OutOfEnergy(val message: String) : Status() +} + +data class ReducerEvent( + val timestamp: Timestamp, + val status: Status, + val callerIdentity: Identity, + val callerConnectionId: ConnectionId, + val reducerName: String, + val energyConsumed: Long, +) + +sealed class Event { + data class Reducer(val event: ReducerEvent) : Event() + data object SubscribeApplied : Event() + data object UnsubscribeApplied : Event() + data object Disconnected : Event() + data class SubscribeError(val message: String) : Event() + data object Transaction : Event() +} + +data class Credentials( + val identity: Identity, + val token: String, +) diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/EventContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/EventContext.kt new file mode 100644 index 00000000000..f6b0808da06 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/EventContext.kt @@ -0,0 +1,14 @@ +package com.clockworklabs.spacetimedb + +class EventContext( + override val identity: Identity?, + override val connectionId: ConnectionId, + override val savedToken: String?, + override val isActive: Boolean, + override val connectionState: kotlinx.coroutines.flow.StateFlow, + val event: Event, + private val conn: DbConnection, +) : DbContext { + override fun disconnect() = conn.disconnect() + override fun subscriptionBuilder() = conn.subscriptionBuilder() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt new file mode 100644 index 00000000000..c6c93d3a168 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt @@ -0,0 +1,86 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +private val HEX_CHARS = "0123456789abcdef".toCharArray() + +internal fun ByteArray.toHexString(): String { + val result = CharArray(size * 2) + for (i in indices) { + val v = this[i].toInt() and 0xFF + result[i * 2] = HEX_CHARS[v ushr 4] + result[i * 2 + 1] = HEX_CHARS[v and 0x0F] + } + return result.concatToString() +} + +internal fun String.hexToByteArray(): ByteArray { + require(length % 2 == 0) { "Hex string must have even length" } + return ByteArray(length / 2) { i -> + val hi = this[i * 2].digitToInt(16) + val lo = this[i * 2 + 1].digitToInt(16) + ((hi shl 4) or lo).toByte() + } +} + +/** A 256-bit identifier that uniquely represents a user across all SpacetimeDB modules. */ +class Identity(val bytes: ByteArray) { + init { + require(bytes.size == 32) { "Identity must be 32 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is Identity && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "Identity(${toHex()})" + + companion object { + val ZERO = Identity(ByteArray(32)) + + fun fromHex(hex: String): Identity { + require(hex.length == 64) { "Identity hex must be 64 characters" } + val bytes = hex.hexToByteArray() + return Identity(bytes) + } + + fun read(reader: BsatnReader): Identity = Identity(reader.readBytes(32)) + + fun write(writer: BsatnWriter, value: Identity) { writer.writeBytes(value.bytes) } + } +} + +/** A 128-bit identifier unique to each client connection session. */ +class ConnectionId(val bytes: ByteArray) { + init { + require(bytes.size == 16) { "ConnectionId must be 16 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is ConnectionId && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "ConnectionId(${toHex()})" + + companion object { + val ZERO = ConnectionId(ByteArray(16)) + + fun random(): ConnectionId { + val bytes = ByteArray(16) + kotlin.random.Random.nextBytes(bytes) + return ConnectionId(bytes) + } + + fun read(reader: BsatnReader): ConnectionId = ConnectionId(reader.readBytes(16)) + + fun write(writer: BsatnWriter, value: ConnectionId) { writer.writeBytes(value.bytes) } + } +} + diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ProcedureEventContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ProcedureEventContext.kt new file mode 100644 index 00000000000..64563cd70a7 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ProcedureEventContext.kt @@ -0,0 +1,13 @@ +package com.clockworklabs.spacetimedb + +class ProcedureEventContext( + override val identity: Identity?, + override val connectionId: ConnectionId, + override val savedToken: String?, + override val isActive: Boolean, + override val connectionState: kotlinx.coroutines.flow.StateFlow, + private val conn: DbConnection, +) : DbContext { + override fun disconnect() = conn.disconnect() + override fun subscriptionBuilder() = conn.subscriptionBuilder() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt new file mode 100644 index 00000000000..6be72758915 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt @@ -0,0 +1,31 @@ +package com.clockworklabs.spacetimedb + +/** + * Configures automatic reconnection with exponential backoff. + * + * @property maxRetries Maximum number of reconnect attempts before giving up. + * @property initialDelayMs Delay before the first retry (milliseconds). + * @property maxDelayMs Upper bound on the delay between retries (milliseconds). + * @property backoffMultiplier Factor by which the delay grows each attempt. + */ +data class ReconnectPolicy( + val maxRetries: Int = 5, + val initialDelayMs: Long = 1_000, + val maxDelayMs: Long = 30_000, + val backoffMultiplier: Double = 2.0, +) { + init { + require(maxRetries >= 0) { "maxRetries must be non-negative" } + require(initialDelayMs > 0) { "initialDelayMs must be positive" } + require(maxDelayMs >= initialDelayMs) { "maxDelayMs must be >= initialDelayMs" } + require(backoffMultiplier >= 1.0) { "backoffMultiplier must be >= 1.0" } + } + + internal fun delayForAttempt(attempt: Int): Long { + var delay = initialDelayMs + repeat(attempt) { + delay = (delay * backoffMultiplier).toLong().coerceAtMost(maxDelayMs) + } + return delay + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerEventContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerEventContext.kt new file mode 100644 index 00000000000..d6c5b6ffd98 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerEventContext.kt @@ -0,0 +1,14 @@ +package com.clockworklabs.spacetimedb + +class ReducerEventContext( + override val identity: Identity?, + override val connectionId: ConnectionId, + override val savedToken: String?, + override val isActive: Boolean, + override val connectionState: kotlinx.coroutines.flow.StateFlow, + val event: ReducerEvent, + private val conn: DbConnection, +) : DbContext { + override fun disconnect() = conn.disconnect() + override fun subscriptionBuilder() = conn.subscriptionBuilder() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt new file mode 100644 index 00000000000..8760c757127 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt @@ -0,0 +1,16 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +class ReducerHandle(private val connection: DbConnection) { + + fun call(reducerName: String, args: ByteArray = ByteArray(0), callback: ((ReducerResult) -> Unit)? = null) { + connection.callReducer(reducerName, args, callback) + } + + fun call(reducerName: String, writeArgs: (BsatnWriter) -> Unit, callback: ((ReducerResult) -> Unit)? = null) { + val writer = BsatnWriter() + writeArgs(writer) + connection.callReducer(reducerName, writer.toByteArray(), callback) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ScheduleAt.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ScheduleAt.kt new file mode 100644 index 00000000000..2b428aac71d --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ScheduleAt.kt @@ -0,0 +1,33 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.TimeDuration + +sealed class ScheduleAt { + data class Interval(val every: TimeDuration) : ScheduleAt() + data class Time(val time: Timestamp) : ScheduleAt() + + companion object { + fun read(reader: BsatnReader): ScheduleAt { + return when (reader.readTag().toInt()) { + 0 -> Interval(TimeDuration.read(reader)) + 1 -> Time(Timestamp.read(reader)) + else -> throw IllegalStateException("Unknown ScheduleAt tag") + } + } + + fun write(writer: BsatnWriter, value: ScheduleAt) { + when (value) { + is Interval -> { + writer.writeTag(0u) + TimeDuration.write(writer, value.every) + } + is Time -> { + writer.writeTag(1u) + Timestamp.write(writer, value.time) + } + } + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt new file mode 100644 index 00000000000..1fdfcbece5f --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt @@ -0,0 +1,419 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.protocol.* +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import com.clockworklabs.spacetimedb.websocket.WebSocketTransport +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.incrementAndFetch + +/** Called when a connection is established. Receives the connection, the user's [Identity], and an auth token. */ +typealias ConnectCallback = (DbConnection, Identity, String) -> Unit +/** Called when a connection is lost. The [Throwable] is null for clean disconnects. */ +typealias DisconnectCallback = (DbConnection, Throwable?) -> Unit +/** Called when the initial connection attempt fails. */ +typealias ConnectErrorCallback = (Throwable) -> Unit + +/** + * Primary client for interacting with a SpacetimeDB module. + * + * Create instances via [DbConnection.builder]: + * ```kotlin + * val conn = DbConnection.builder() + * .withUri("ws://localhost:3000") + * .withModuleName("my_module") + * .onConnect { conn, identity, token -> println("Connected as $identity") } + * .build() + * ``` + * + * The connection is opened immediately on [build][DbConnectionBuilder.build]. Use [disconnect] + * to tear it down, or configure automatic reconnection via [DbConnectionBuilder.withReconnectPolicy]. + */ +/** Compression mode negotiated with the server for host→client messages. */ +enum class CompressionMode(internal val queryValue: String) { + NONE("None"), + GZIP("Gzip"), + BROTLI("Brotli"), +} + +/** Result of a procedure invocation, including the server-side [timestamp] and [status]. */ +data class ProcedureResult( + val requestId: UInt, + val timestamp: Timestamp, + val status: ProcedureStatus, +) + +class DbConnection internal constructor( + private val uri: String, + private val moduleName: String, + private val token: String?, + private val connectCallbacks: List, + private val disconnectCallbacks: List, + private val connectErrorCallbacks: List, + private val keepAliveIntervalMs: Long = 30_000L, + private val reconnectPolicy: ReconnectPolicy? = null, + private val compression: CompressionMode = CompressionMode.GZIP, + private val confirmedReads: Boolean? = null, + private val lightMode: Boolean = false, +) : DbContext { + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val requestCounter = AtomicInt(0) + private val mutex = Mutex() + + val clientCache = ClientCache() + private val tableHandles = mutableMapOf() + private val subscriptions = mutableMapOf() + private val subscriptionsByQuerySet = mutableMapOf() + private val reducerCallbacks = mutableMapOf Unit>() + private val procedureCallbacks = mutableMapOf Unit>() + private val pendingOneOffQueries = mutableMapOf>() + + override var identity: Identity? = null + private set + override var connectionId: ConnectionId = ConnectionId.random() + private set + override var savedToken: String? = null + private set + + private val transport = WebSocketTransport( + scope = scope, + onMessage = { handleMessage(it) }, + onConnect = {}, + onDisconnect = { error -> + failPendingOperations() + disconnectCallbacks.forEach { it(this, error) } + }, + onConnectError = { error -> connectErrorCallbacks.forEach { it(error) } }, + keepAliveIntervalMs = keepAliveIntervalMs, + reconnectPolicy = reconnectPolicy, + compression = compression, + connectionId = connectionId, + confirmedReads = confirmedReads, + lightMode = lightMode, + ) + + override val connectionState: StateFlow get() = transport.state + override val isActive: Boolean get() = transport.state.value == ConnectionState.CONNECTED + + init { + transport.connect(uri, moduleName, token) + } + + /** Closes the connection, cancels pending operations, and stops any reconnection attempts. */ + override fun disconnect() { + transport.disconnect() + failPendingOperations() + scope.cancel() + } + + /** Returns the [TableHandle] for [name], creating it if needed. Register callbacks before connecting. */ + fun table(name: String): TableHandle { + // tableHandles is only read/written from user thread (registration) + // and from handleMessage under mutex (firing callbacks). + // Reads from handleMessage never mutate, so this is safe for the + // typical pattern of registering table handles before connecting. + return tableHandles.getOrPut(name) { TableHandle(name) } + } + + /** Creates a [SubscriptionBuilder] for subscribing to SQL queries on this connection. */ + override fun subscriptionBuilder(): SubscriptionBuilder = SubscriptionBuilder(this) + + /** Invokes a server-side reducer by name with BSATN-encoded [args]. Optionally receives the [ReducerResult]. */ + fun callReducer(reducerName: String, args: ByteArray, callback: ((ReducerResult) -> Unit)? = null) { + val reqId = nextRequestId() + if (callback != null) { + // Register synchronously before sending to avoid race with server response + reducerCallbacks[reqId] = callback + } + transport.send( + ClientMessage.CallReducer( + requestId = reqId, + reducer = reducerName, + args = args, + ) + ) + } + + /** Invokes a server-side procedure by name with BSATN-encoded [args]. Optionally receives the [ProcedureResult]. */ + fun callProcedure(procedureName: String, args: ByteArray, callback: ((ProcedureResult) -> Unit)? = null) { + val reqId = nextRequestId() + if (callback != null) { + procedureCallbacks[reqId] = callback + } + transport.send( + ClientMessage.CallProcedure( + requestId = reqId, + procedure = procedureName, + args = args, + ) + ) + } + + /** Executes a one-off SQL query against the module and suspends until the result arrives. */ + suspend fun oneOffQuery(query: String): ServerMessage.OneOffQueryResult { + val reqId = nextRequestId() + val deferred = CompletableDeferred() + mutex.withLock { pendingOneOffQueries[reqId] = deferred } + transport.send(ClientMessage.OneOffQuery(requestId = reqId, queryString = query)) + return deferred.await() + } + + /** Callback variant of [oneOffQuery] — launches a coroutine and invokes [callback] with the result. */ + fun oneOffQuery(query: String, callback: (ServerMessage.OneOffQueryResult) -> Unit) { + val reqId = nextRequestId() + val deferred = CompletableDeferred() + scope.launch { + mutex.withLock { pendingOneOffQueries[reqId] = deferred } + transport.send(ClientMessage.OneOffQuery(requestId = reqId, queryString = query)) + callback(deferred.await()) + } + } + + internal fun subscribe( + queries: List, + handle: SubscriptionHandle, + ): UInt { + val reqId = nextRequestId() + val qsId = QuerySetId(reqId) + handle.querySetId = qsId + handle.requestId = reqId + // Register synchronously before sending to avoid race with server response + subscriptions[reqId] = handle + subscriptionsByQuerySet[qsId] = handle + transport.send( + ClientMessage.Subscribe( + requestId = reqId, + querySetId = qsId, + queryStrings = queries, + ) + ) + return reqId + } + + internal fun unsubscribe(handle: SubscriptionHandle) { + val qsId = handle.querySetId ?: return + val reqId = nextRequestId() + transport.send( + ClientMessage.Unsubscribe( + requestId = reqId, + querySetId = qsId, + flags = 1u, // SendDroppedRows — ensures server sends rows to remove from cache + ) + ) + } + + internal fun unsubscribeThen(handle: SubscriptionHandle, onEnded: () -> Unit) { + val qsId = handle.querySetId ?: return + val reqId = nextRequestId() + handle.onEndedCallback = onEnded + transport.send( + ClientMessage.Unsubscribe( + requestId = reqId, + querySetId = qsId, + flags = 1u, + ) + ) + } + + internal fun pendingCancel(handle: SubscriptionHandle) { + val qsId = handle.querySetId ?: return + subscriptionsByQuerySet.remove(qsId) + handle.requestId.let { subscriptions.remove(it) } + } + + fun unsubscribeAll() { + val handles = subscriptionsByQuerySet.values.toList() + for (handle in handles) { + handle.state = SubscriptionState.ENDED + } + subscriptions.clear() + subscriptionsByQuerySet.clear() + for (handle in handles) { + val qsId = handle.querySetId ?: continue + val reqId = nextRequestId() + transport.send( + ClientMessage.Unsubscribe( + requestId = reqId, + querySetId = qsId, + flags = 1u, + ) + ) + } + } + + private fun nextRequestId(): UInt = requestCounter.incrementAndFetch().toUInt() + + private fun failPendingOperations() { + val error = CancellationException("Connection closed") + pendingOneOffQueries.values.forEach { it.cancel(error) } + pendingOneOffQueries.clear() + reducerCallbacks.clear() + procedureCallbacks.clear() + } + + private suspend fun handleMessage(msg: ServerMessage) { + mutex.withLock { + when (msg) { + is ServerMessage.InitialConnection -> { + identity = msg.identity + connectionId = msg.connectionId + savedToken = msg.token + connectCallbacks.forEach { it(this, msg.identity, msg.token) } + } + + is ServerMessage.SubscribeApplied -> { + val ops = clientCache.applySubscribeRows(msg.rows) + fireTableCallbacks(ops) + val handle = subscriptions[msg.requestId] + if (handle != null) { + if (handle.state == SubscriptionState.CANCELLED) { + // Was unsubscribed before being applied; clean up now. + subscriptions.remove(msg.requestId) + subscriptionsByQuerySet.remove(msg.querySetId) + handle.onEndedCallback?.invoke() + } else { + handle.state = SubscriptionState.ACTIVE + handle.onAppliedCallback?.invoke() + } + } + } + + is ServerMessage.UnsubscribeApplied -> { + msg.rows?.let { clientCache.applyUnsubscribeRows(it) } + val handle = subscriptionsByQuerySet[msg.querySetId] + if (handle != null) { + handle.state = SubscriptionState.ENDED + handle.requestId.let { subscriptions.remove(it) } + subscriptionsByQuerySet.remove(msg.querySetId) + handle.onEndedCallback?.invoke() + } + } + + is ServerMessage.SubscriptionError -> { + val handle = if (msg.requestId != null) { + subscriptions[msg.requestId] + } else { + subscriptionsByQuerySet[msg.querySetId] + } + if (handle != null) { + handle.state = SubscriptionState.ENDED + handle.onErrorCallback?.invoke(msg.error) + handle.requestId.let { subscriptions.remove(it) } + subscriptionsByQuerySet.remove(msg.querySetId) + handle.onEndedCallback?.invoke() + } + } + + is ServerMessage.TransactionUpdate -> { + val ops = clientCache.applyTransactionUpdate(msg.querySets) + fireTableCallbacks(ops) + } + + is ServerMessage.ReducerResult -> { + if (msg.result is ReducerOutcome.Ok) { + val txUpdate = msg.result.transactionUpdate + val ops = clientCache.applyTransactionUpdate(txUpdate.querySets) + fireTableCallbacks(ops) + } + reducerCallbacks.remove(msg.requestId)?.invoke( + ReducerResult(msg.requestId, msg.timestamp, msg.result) + ) + } + + is ServerMessage.ProcedureResult -> { + procedureCallbacks.remove(msg.requestId)?.invoke( + ProcedureResult(msg.requestId, msg.timestamp, msg.status) + ) + } + + is ServerMessage.OneOffQueryResult -> { + pendingOneOffQueries.remove(msg.requestId)?.complete(msg) + } + } + } + } + + private fun fireTableCallbacks(ops: List) { + for (op in ops) { + try { + when (op) { + is TableOperation.Insert -> tableHandles[op.tableName]?.fireInsert(op.row) + is TableOperation.Delete -> tableHandles[op.tableName]?.fireDelete(op.row) + is TableOperation.Update -> tableHandles[op.tableName]?.fireUpdate(op.oldRow, op.newRow) + is TableOperation.EventInsert -> tableHandles[op.tableName]?.fireInsert(op.row) + } + } catch (_: Exception) { + // Don't let user callback exceptions crash the receive loop + } + } + } + + companion object { + fun builder(): DbConnectionBuilder = DbConnectionBuilder() + } +} + +/** Result of a reducer invocation, including the server-side [timestamp] and [outcome]. */ +data class ReducerResult( + val requestId: UInt, + val timestamp: Timestamp, + val outcome: ReducerOutcome, +) + +/** Builder for configuring and creating a [DbConnection]. */ +class DbConnectionBuilder { + private var uri: String? = null + private var moduleName: String? = null + private var token: String? = null + private var keepAliveIntervalMs: Long = 30_000L + private var reconnectPolicy: ReconnectPolicy? = null + private var compression: CompressionMode = CompressionMode.GZIP + private var confirmedReads: Boolean? = null + private var lightMode: Boolean = false + private val connectCallbacks = mutableListOf() + private val disconnectCallbacks = mutableListOf() + private val connectErrorCallbacks = mutableListOf() + + fun withUri(uri: String) = apply { this.uri = uri } + + fun withModuleName(name: String) = apply { this.moduleName = name } + + fun withToken(token: String?) = apply { this.token = token } + + fun onConnect(callback: ConnectCallback) = apply { connectCallbacks.add(callback) } + + fun onDisconnect(callback: DisconnectCallback) = apply { disconnectCallbacks.add(callback) } + + fun onConnectError(callback: ConnectErrorCallback) = apply { connectErrorCallbacks.add(callback) } + + fun withKeepAliveInterval(intervalMs: Long) = apply { this.keepAliveIntervalMs = intervalMs } + + fun withReconnectPolicy(policy: ReconnectPolicy) = apply { this.reconnectPolicy = policy } + + fun withCompression(mode: CompressionMode) = apply { this.compression = mode } + + fun withConfirmedReads(enabled: Boolean) = apply { this.confirmedReads = enabled } + + fun withLightMode(enabled: Boolean) = apply { this.lightMode = enabled } + + fun build(): DbConnection { + val uri = requireNotNull(uri) { "URI is required. Call withUri() before build()." } + val module = requireNotNull(moduleName) { "Module name is required. Call withModuleName() before build()." } + return DbConnection( + uri = uri, + moduleName = module, + token = token, + connectCallbacks = connectCallbacks.toList(), + disconnectCallbacks = disconnectCallbacks.toList(), + connectErrorCallbacks = connectErrorCallbacks.toList(), + keepAliveIntervalMs = keepAliveIntervalMs, + reconnectPolicy = reconnectPolicy, + compression = compression, + confirmedReads = confirmedReads, + lightMode = lightMode, + ) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt new file mode 100644 index 00000000000..b1781f725e2 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt @@ -0,0 +1,48 @@ +package com.clockworklabs.spacetimedb + +/** + * Builder for subscribing to SQL queries on a [DbConnection]. + * + * ```kotlin + * conn.subscriptionBuilder() + * .onApplied { println("Subscription active") } + * .onError { err -> println("Subscription failed: $err") } + * .subscribe("SELECT * FROM users WHERE online = true) + * ``` + */ +class SubscriptionBuilder(private val connection: DbConnection) { + private var onAppliedCallback: (() -> Unit)? = null + private var onErrorCallback: ((String) -> Unit)? = null + private var onEndedCallback: (() -> Unit)? = null + private val pendingQueries = mutableListOf() + + fun onApplied(callback: () -> Unit) = apply { this.onAppliedCallback = callback } + + fun onError(callback: (String) -> Unit) = apply { this.onErrorCallback = callback } + + fun onEnded(callback: () -> Unit) = apply { this.onEndedCallback = callback } + + /** Add a raw SQL query string to the pending list. */ + fun addQuery(query: String): SubscriptionBuilder = apply { pendingQueries.add(query) } + + /** Add a query from a [QueryProvider]. Used by generated typed extensions. */ + fun addQueryFrom(provider: com.clockworklabs.spacetimedb.query.QueryProvider): SubscriptionBuilder = + apply { pendingQueries.add(provider.toSql()) } + + fun subscribe(vararg queries: String): SubscriptionHandle { + val allQueries = (pendingQueries + queries).toList() + pendingQueries.clear() + val handle = SubscriptionHandle( + connection = connection, + onAppliedCallback = onAppliedCallback, + onErrorCallback = onErrorCallback, + onEndedCallback = onEndedCallback, + ) + connection.subscribe(allQueries, handle) + return handle + } + + fun subscribeToAllTables(): SubscriptionHandle { + return subscribe("SELECT * FROM *") + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionEventContext.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionEventContext.kt new file mode 100644 index 00000000000..6fdc088a775 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionEventContext.kt @@ -0,0 +1,13 @@ +package com.clockworklabs.spacetimedb + +class SubscriptionEventContext( + override val identity: Identity?, + override val connectionId: ConnectionId, + override val savedToken: String?, + override val isActive: Boolean, + override val connectionState: kotlinx.coroutines.flow.StateFlow, + private val conn: DbConnection, +) : DbContext { + override fun disconnect() = conn.disconnect() + override fun subscriptionBuilder() = conn.subscriptionBuilder() +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt new file mode 100644 index 00000000000..a1635f2604a --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt @@ -0,0 +1,60 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.protocol.QuerySetId + +/** Lifecycle states of a subscription. */ +enum class SubscriptionState { + PENDING, + ACTIVE, + ENDED, + CANCELLED, +} + +/** + * Represents an active subscription to one or more SQL queries. + * + * Created by [SubscriptionBuilder.subscribe]. Call [unsubscribe] to end it. + */ +class SubscriptionHandle internal constructor( + private val connection: DbConnection, + internal val onAppliedCallback: (() -> Unit)?, + internal val onErrorCallback: ((String) -> Unit)?, + internal var onEndedCallback: (() -> Unit)? = null, +) { + internal var querySetId: QuerySetId? = null + internal var requestId: UInt = 0u + var state: SubscriptionState = SubscriptionState.PENDING + internal set + + val isActive: Boolean get() = state == SubscriptionState.ACTIVE + val isEnded: Boolean get() = state == SubscriptionState.ENDED + + fun unsubscribe() { + when (state) { + SubscriptionState.PENDING -> { + state = SubscriptionState.CANCELLED + connection.pendingCancel(this) + } + SubscriptionState.ACTIVE -> { + state = SubscriptionState.ENDED + connection.unsubscribe(this) + } + else -> {} + } + } + + fun unsubscribeThen(onEnded: () -> Unit) { + when (state) { + SubscriptionState.PENDING -> { + state = SubscriptionState.CANCELLED + connection.pendingCancel(this) + onEnded() + } + SubscriptionState.ACTIVE -> { + state = SubscriptionState.ENDED + connection.unsubscribeThen(this, onEnded) + } + else -> {} + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Table.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Table.kt new file mode 100644 index 00000000000..0579934fd55 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Table.kt @@ -0,0 +1,22 @@ +package com.clockworklabs.spacetimedb + +interface Table { + val tableName: String + val count: Int + fun iter(): Sequence + fun onInsert(callback: (EventContext<*>, TRow) -> Unit): CallbackId + fun removeOnInsert(id: CallbackId) + fun onDelete(callback: (EventContext<*>, TRow) -> Unit): CallbackId + fun removeOnDelete(id: CallbackId) +} + +interface TableWithPrimaryKey : Table { + fun onUpdate(callback: (EventContext<*>, TRow, TRow) -> Unit): CallbackId + fun removeOnUpdate(id: CallbackId) +} + +interface EventTable { + val tableName: String + fun onInsert(callback: (EventContext<*>, TRow) -> Unit): CallbackId + fun removeOnInsert(id: CallbackId) +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt new file mode 100644 index 00000000000..5a78f6019ea --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt @@ -0,0 +1,65 @@ +package com.clockworklabs.spacetimedb + +typealias InsertCallback = (ByteArray) -> Unit +typealias DeleteCallback = (ByteArray) -> Unit +typealias UpdateCallback = (oldRow: ByteArray, newRow: ByteArray) -> Unit + +/** + * Handle for observing row changes on a single table. + * + * Obtain via [DbConnection.table]. Register callbacks with [onInsert], [onDelete], + * and [onUpdate]; remove them later with the returned [CallbackId]. + */ +class TableHandle(val tableName: String) { + private var nextId = 0 + private val insertCallbacks = mutableMapOf() + private val deleteCallbacks = mutableMapOf() + private val updateCallbacks = mutableMapOf() + + fun onInsert(callback: InsertCallback): CallbackId { + val id = nextId++ + insertCallbacks[id] = callback + return CallbackId(id) + } + + fun onDelete(callback: DeleteCallback): CallbackId { + val id = nextId++ + deleteCallbacks[id] = callback + return CallbackId(id) + } + + fun onUpdate(callback: UpdateCallback): CallbackId { + val id = nextId++ + updateCallbacks[id] = callback + return CallbackId(id) + } + + fun removeOnInsert(id: CallbackId) { + insertCallbacks.remove(id.value) + } + + fun removeOnDelete(id: CallbackId) { + deleteCallbacks.remove(id.value) + } + + fun removeOnUpdate(id: CallbackId) { + updateCallbacks.remove(id.value) + } + + internal fun fireInsert(row: ByteArray) { + // Snapshot to allow callbacks to register/remove other callbacks safely + for (cb in insertCallbacks.values.toList()) cb(row) + } + + internal fun fireDelete(row: ByteArray) { + for (cb in deleteCallbacks.values.toList()) cb(row) + } + + internal fun fireUpdate(oldRow: ByteArray, newRow: ByteArray) { + for (cb in updateCallbacks.values.toList()) cb(oldRow, newRow) + } +} + +/** Opaque identifier returned by callback registration methods. Used to remove the callback later. */ +@kotlin.jvm.JvmInline +value class CallbackId(val value: Int) diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Timestamp.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Timestamp.kt new file mode 100644 index 00000000000..7035b00a98d --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Timestamp.kt @@ -0,0 +1,15 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import kotlin.jvm.JvmInline + +/** Server-side timestamp in microseconds since the Unix epoch. */ +@JvmInline +value class Timestamp(val microseconds: Long) { + companion object { + fun read(reader: BsatnReader): Timestamp = Timestamp(reader.readI64()) + + fun write(writer: BsatnWriter, value: Timestamp) { writer.writeI64(value.microseconds) } + } +} \ No newline at end of file diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Uuid.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Uuid.kt new file mode 100644 index 00000000000..bddc4b4a1e6 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Uuid.kt @@ -0,0 +1,24 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import kotlin.uuid.Uuid + +fun Uuid.Companion.read(reader: BsatnReader): Uuid { + val msb = reader.readI64() + val lsb = reader.readI64() + val bytes = ByteArray(16) + for (i in 0 until 8) bytes[7 - i] = (msb shr (i * 8)).toByte() + for (i in 0 until 8) bytes[15 - i] = (lsb shr (i * 8)).toByte() + return Uuid.fromByteArray(bytes) +} + +fun Uuid.Companion.write(writer: BsatnWriter, value: Uuid) { + val bytes = value.toByteArray() + var msb = 0L + var lsb = 0L + for (i in 0 until 8) msb = msb.shl(8) or (bytes[i].toLong() and 0xFF) + for (i in 8 until 16) lsb = lsb.shl(8) or (bytes[i].toLong() and 0xFF) + writer.writeI64(msb) + writer.writeI64(lsb) +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt new file mode 100644 index 00000000000..85cd9ee81ef --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt @@ -0,0 +1,121 @@ +package com.clockworklabs.spacetimedb.bsatn + +class BsatnReader(private val data: ByteArray, private var offset: Int = 0) { + + val remaining: Int get() = data.size - offset + + val isExhausted: Boolean get() = offset >= data.size + + private fun require(count: Int) { + if (count < 0 || offset + count > data.size) { + val remaining = data.size - offset + throw IllegalStateException( + "BSATN: unexpected end of data at offset $offset, " + + "need $count bytes but only $remaining remain" + ) + } + } + + fun readU8(): UByte { + require(1) + return data[offset++].toUByte() + } + + fun readI8(): Byte { + require(1) + return data[offset++] + } + + fun readBool(): Boolean = readU8().toInt() != 0 + + fun readU16(): UShort { + require(2) + val v = (data[offset].toUByte().toInt() or (data[offset + 1].toUByte().toInt() shl 8)).toUShort() + offset += 2 + return v + } + + fun readI16(): Short { + require(2) + val v = (data[offset].toUByte().toInt() or (data[offset + 1].toUByte().toInt() shl 8)).toShort() + offset += 2 + return v + } + + fun readU32(): UInt { + require(4) + val v = (data[offset].toUByte().toUInt()) or + (data[offset + 1].toUByte().toUInt() shl 8) or + (data[offset + 2].toUByte().toUInt() shl 16) or + (data[offset + 3].toUByte().toUInt() shl 24) + offset += 4 + return v + } + + fun readI32(): Int { + require(4) + val v = (data[offset].toUByte().toInt()) or + (data[offset + 1].toUByte().toInt() shl 8) or + (data[offset + 2].toUByte().toInt() shl 16) or + (data[offset + 3].toUByte().toInt() shl 24) + offset += 4 + return v + } + + fun readU64(): ULong { + require(8) + var v = 0UL + for (i in 0 until 8) { + v = v or (data[offset + i].toUByte().toULong() shl (i * 8)) + } + offset += 8 + return v + } + + fun readI64(): Long { + require(8) + var v = 0L + for (i in 0 until 8) { + v = v or ((data[offset + i].toUByte().toLong()) shl (i * 8)) + } + offset += 8 + return v + } + + fun readF32(): Float = Float.fromBits(readI32()) + + fun readF64(): Double = Double.fromBits(readI64()) + + fun readBytes(count: Int): ByteArray { + require(count) + val result = data.copyOfRange(offset, offset + count) + offset += count + return result + } + + fun readByteArray(): ByteArray { + val len = readU32().toInt() + return readBytes(len) + } + + fun readString(): String { + val bytes = readByteArray() + return bytes.decodeToString() + } + + fun readTag(): UByte = readU8() + + fun readArray(readElement: (BsatnReader) -> T): List { + val count = readU32().toInt() + return List(count) { readElement(this) } + } + + fun readOption(readElement: (BsatnReader) -> T): T? { + return when (readTag().toInt()) { + // SATS serializes: Some = tag 0, None = tag 1 + 0 -> readElement(this) + 1 -> null + else -> throw IllegalStateException("Invalid Option tag") + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt new file mode 100644 index 00000000000..328a13e41d7 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt @@ -0,0 +1,70 @@ +package com.clockworklabs.spacetimedb.bsatn + +sealed class RowSizeHint { + data class FixedSize(val rowSize: UShort) : RowSizeHint() + data class RowOffsets(val offsets: List) : RowSizeHint() + + companion object { + fun read(reader: BsatnReader): RowSizeHint { + return when (reader.readTag().toInt()) { + 0 -> FixedSize(reader.readU16()) + 1 -> RowOffsets(reader.readArray { it.readU64() }) + else -> throw IllegalStateException("Invalid RowSizeHint tag") + } + } + + fun write(writer: BsatnWriter, value: RowSizeHint) { + when (value) { + is FixedSize -> { + writer.writeTag(0u) + writer.writeU16(value.rowSize) + } + is RowOffsets -> { + writer.writeTag(1u) + writer.writeArray(value.offsets) { w, v -> w.writeU64(v) } + } + } + } + } +} + +class BsatnRowList( + val sizeHint: RowSizeHint, + val rowsData: ByteArray, +) { + fun decodeRows(): List { + if (rowsData.isEmpty()) return emptyList() + + return when (val hint = sizeHint) { + is RowSizeHint.FixedSize -> { + val rowSize = hint.rowSize.toInt() + if (rowSize == 0) return emptyList() + val count = rowsData.size / rowSize + List(count) { i -> + rowsData.copyOfRange(i * rowSize, (i + 1) * rowSize) + } + } + is RowSizeHint.RowOffsets -> { + val offsets = hint.offsets + List(offsets.size) { i -> + val start = offsets[i].toInt() + val end = if (i + 1 < offsets.size) offsets[i + 1].toInt() else rowsData.size + rowsData.copyOfRange(start, end) + } + } + } + } + + companion object { + fun read(reader: BsatnReader): BsatnRowList { + val sizeHint = RowSizeHint.read(reader) + val rowsData = reader.readByteArray() + return BsatnRowList(sizeHint, rowsData) + } + + fun write(writer: BsatnWriter, value: BsatnRowList) { + RowSizeHint.write(writer, value.sizeHint) + writer.writeByteArray(value.rowsData) + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt new file mode 100644 index 00000000000..567b0a5ca63 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt @@ -0,0 +1,113 @@ +package com.clockworklabs.spacetimedb.bsatn + +class BsatnWriter(initialCapacity: Int = 256) { + + private var buffer = ByteArray(initialCapacity) + private var position = 0 + + private fun ensureCapacity(needed: Int) { + val required = position + needed + if (required > buffer.size) { + val newSize = maxOf(buffer.size * 2, required) + buffer = buffer.copyOf(newSize) + } + } + + fun writeBool(value: Boolean) { + writeU8(if (value) 1u else 0u) + } + + fun writeU8(value: UByte) { + ensureCapacity(1) + buffer[position++] = value.toByte() + } + + fun writeI8(value: Byte) { + ensureCapacity(1) + buffer[position++] = value + } + + fun writeU16(value: UShort) { + ensureCapacity(2) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + } + + fun writeI16(value: Short) { + ensureCapacity(2) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + } + + fun writeU32(value: UInt) { + ensureCapacity(4) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + buffer[position++] = (v shr 16).toByte() + buffer[position++] = (v shr 24).toByte() + } + + fun writeI32(value: Int) { + ensureCapacity(4) + buffer[position++] = value.toByte() + buffer[position++] = (value shr 8).toByte() + buffer[position++] = (value shr 16).toByte() + buffer[position++] = (value shr 24).toByte() + } + + fun writeU64(value: ULong) { + ensureCapacity(8) + val v = value.toLong() + for (i in 0 until 8) { + buffer[position++] = (v shr (i * 8)).toByte() + } + } + + fun writeI64(value: Long) { + ensureCapacity(8) + for (i in 0 until 8) { + buffer[position++] = (value shr (i * 8)).toByte() + } + } + + fun writeF32(value: Float) { writeI32(value.toRawBits()) } + + fun writeF64(value: Double) { writeI64(value.toRawBits()) } + + fun writeBytes(bytes: ByteArray) { + ensureCapacity(bytes.size) + bytes.copyInto(buffer, position) + position += bytes.size + } + + fun writeByteArray(bytes: ByteArray) { + writeU32(bytes.size.toUInt()) + writeBytes(bytes) + } + + fun writeString(value: String) { + val bytes = value.encodeToByteArray() + writeByteArray(bytes) + } + + fun writeTag(tag: UByte) { writeU8(tag) } + + fun writeArray(items: List, writeElement: (BsatnWriter, T) -> Unit) { + writeU32(items.size.toUInt()) + items.forEach { writeElement(this, it) } + } + + fun writeOption(value: T?, writeElement: (BsatnWriter, T) -> Unit) { + if (value == null) { + writeTag(1u) + } else { + writeTag(0u) + writeElement(this, value) + } + } + + fun toByteArray(): ByteArray = buffer.copyOf(position) +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt new file mode 100644 index 00000000000..7fa027f97e7 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt @@ -0,0 +1,100 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +sealed class ClientMessage { + data class Subscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val queryStrings: List, + ) : ClientMessage() + + data class Unsubscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val flags: UByte = 0u, + ) : ClientMessage() + + data class OneOffQuery( + val requestId: UInt, + val queryString: String, + ) : ClientMessage() + + data class CallReducer( + val requestId: UInt, + val reducer: String, + val args: ByteArray, + val flags: UByte = 0u, + ) : ClientMessage() { + override fun equals(other: Any?): Boolean = + other is CallReducer && requestId == other.requestId && + reducer == other.reducer && args.contentEquals(other.args) && + flags == other.flags + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + reducer.hashCode() + result = 31 * result + args.contentHashCode() + result = 31 * result + flags.hashCode() + return result + } + } + + data class CallProcedure( + val requestId: UInt, + val procedure: String, + val args: ByteArray, + val flags: UByte = 0u, + ) : ClientMessage() { + override fun equals(other: Any?): Boolean = + other is CallProcedure && requestId == other.requestId && + procedure == other.procedure && args.contentEquals(other.args) && + flags == other.flags + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + procedure.hashCode() + result = 31 * result + args.contentHashCode() + result = 31 * result + flags.hashCode() + return result + } + } + + fun encode(): ByteArray { + val writer = BsatnWriter() + when (this) { + is Subscribe -> { + writer.writeTag(0u) + writer.writeU32(requestId) + QuerySetId.write(writer, querySetId) + writer.writeArray(queryStrings) { w, s -> w.writeString(s) } + } + is Unsubscribe -> { + writer.writeTag(1u) + writer.writeU32(requestId) + QuerySetId.write(writer, querySetId) + writer.writeU8(flags) + } + is OneOffQuery -> { + writer.writeTag(2u) + writer.writeU32(requestId) + writer.writeString(queryString) + } + is CallReducer -> { + writer.writeTag(3u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(reducer) + writer.writeByteArray(args) + } + is CallProcedure -> { + writer.writeTag(4u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(procedure) + writer.writeByteArray(args) + } + } + return writer.toByteArray() + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt new file mode 100644 index 00000000000..62be95fb5fc --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt @@ -0,0 +1,160 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnRowList +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +@kotlin.jvm.JvmInline +value class QuerySetId(val id: UInt) { + companion object { + fun read(reader: BsatnReader): QuerySetId = QuerySetId(reader.readU32()) + fun write(writer: BsatnWriter, value: QuerySetId) { writer.writeU32(value.id) } + } +} + +@kotlin.jvm.JvmInline +value class RawIdentifier(val value: String) { + companion object { + fun read(reader: BsatnReader): RawIdentifier = RawIdentifier(reader.readString()) + fun write(writer: BsatnWriter, value: RawIdentifier) { writer.writeString(value.value) } + } +} + +data class SingleTableRows( + val table: RawIdentifier, + val rows: BsatnRowList, +) { + companion object { + fun read(reader: BsatnReader): SingleTableRows = SingleTableRows( + table = RawIdentifier.read(reader), + rows = BsatnRowList.read(reader), + ) + } +} + +data class QueryRows(val tables: List) { + companion object { + fun read(reader: BsatnReader): QueryRows = + QueryRows(reader.readArray { SingleTableRows.read(it) }) + } +} + +sealed class TableUpdateRows { + data class PersistentTable(val rows: PersistentTableRows) : TableUpdateRows() + data class EventTable(val rows: EventTableRows) : TableUpdateRows() + + companion object { + fun read(reader: BsatnReader): TableUpdateRows { + return when (reader.readTag().toInt()) { + 0 -> PersistentTable(PersistentTableRows.read(reader)) + 1 -> EventTable(EventTableRows.read(reader)) + else -> throw IllegalStateException("Invalid TableUpdateRows tag") + } + } + } +} + +data class PersistentTableRows( + val inserts: BsatnRowList, + val deletes: BsatnRowList, +) { + companion object { + fun read(reader: BsatnReader): PersistentTableRows = PersistentTableRows( + inserts = BsatnRowList.read(reader), + deletes = BsatnRowList.read(reader), + ) + } +} + +data class EventTableRows(val events: BsatnRowList) { + companion object { + fun read(reader: BsatnReader): EventTableRows = + EventTableRows(BsatnRowList.read(reader)) + } +} + +data class TableUpdate( + val tableName: RawIdentifier, + val rows: List, +) { + companion object { + fun read(reader: BsatnReader): TableUpdate = TableUpdate( + tableName = RawIdentifier.read(reader), + rows = reader.readArray { TableUpdateRows.read(it) }, + ) + } +} + +data class QuerySetUpdate( + val querySetId: QuerySetId, + val tables: List, +) { + companion object { + fun read(reader: BsatnReader): QuerySetUpdate = QuerySetUpdate( + querySetId = QuerySetId.read(reader), + tables = reader.readArray { TableUpdate.read(it) }, + ) + } +} + +sealed class ReducerOutcome { + data class Ok(val retValue: ByteArray, val transactionUpdate: TransactionUpdateData) : ReducerOutcome() { + override fun equals(other: Any?): Boolean = + other is Ok && retValue.contentEquals(other.retValue) && transactionUpdate == other.transactionUpdate + override fun hashCode(): Int = retValue.contentHashCode() * 31 + transactionUpdate.hashCode() + } + data object OkEmpty : ReducerOutcome() + data class Err(val message: ByteArray) : ReducerOutcome() { + override fun equals(other: Any?): Boolean = other is Err && message.contentEquals(other.message) + override fun hashCode(): Int = message.contentHashCode() + } + data class InternalError(val message: String) : ReducerOutcome() + + companion object { + fun read(reader: BsatnReader): ReducerOutcome { + return when (reader.readTag().toInt()) { + 0 -> Ok( + retValue = reader.readByteArray(), + transactionUpdate = TransactionUpdateData.read(reader), + ) + 1 -> OkEmpty + 2 -> Err(reader.readByteArray()) + 3 -> InternalError(reader.readString()) + else -> throw IllegalStateException("Invalid ReducerOutcome tag") + } + } + } +} + +data class TransactionUpdateData(val querySets: List) { + companion object { + fun read(reader: BsatnReader): TransactionUpdateData = + TransactionUpdateData(reader.readArray { QuerySetUpdate.read(it) }) + } +} + +sealed class ProcedureStatus { + data class Returned(val data: ByteArray) : ProcedureStatus() { + override fun equals(other: Any?): Boolean = other is Returned && data.contentEquals(other.data) + override fun hashCode(): Int = data.contentHashCode() + } + data class InternalError(val message: String) : ProcedureStatus() + + companion object { + fun read(reader: BsatnReader): ProcedureStatus { + return when (reader.readTag().toInt()) { + 0 -> Returned(reader.readByteArray()) + 1 -> InternalError(reader.readString()) + else -> throw IllegalStateException("Invalid ProcedureStatus tag") + } + } + } +} + +@kotlin.jvm.JvmInline +value class TimeDuration(val microseconds: ULong) { + companion object { + fun read(reader: BsatnReader): TimeDuration = TimeDuration(reader.readU64()) + fun write(writer: BsatnWriter, value: TimeDuration) { writer.writeU64(value.microseconds) } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt new file mode 100644 index 00000000000..aeaae132d36 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt @@ -0,0 +1,107 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.ConnectionId +import com.clockworklabs.spacetimedb.Identity +import com.clockworklabs.spacetimedb.Timestamp +import com.clockworklabs.spacetimedb.bsatn.BsatnReader + +sealed class ServerMessage { + data class InitialConnection( + val identity: Identity, + val connectionId: ConnectionId, + val token: String, + ) : ServerMessage() + + data class SubscribeApplied( + val requestId: UInt, + val querySetId: QuerySetId, + val rows: QueryRows, + ) : ServerMessage() + + data class UnsubscribeApplied( + val requestId: UInt, + val querySetId: QuerySetId, + val rows: QueryRows?, + ) : ServerMessage() + + data class SubscriptionError( + val requestId: UInt?, + val querySetId: QuerySetId, + val error: String, + ) : ServerMessage() + + data class TransactionUpdate( + val querySets: List, + ) : ServerMessage() + + data class OneOffQueryResult( + val requestId: UInt, + val rows: QueryRows?, + val error: String?, + ) : ServerMessage() + + data class ReducerResult( + val requestId: UInt, + val timestamp: Timestamp, + val result: ReducerOutcome, + ) : ServerMessage() + + data class ProcedureResult( + val requestId: UInt, + val timestamp: Timestamp, + val status: ProcedureStatus, + val totalHostExecutionDuration: TimeDuration, + ) : ServerMessage() + + companion object { + fun decode(data: ByteArray): ServerMessage { + val reader = BsatnReader(data) + return when (reader.readTag().toInt()) { + 0 -> InitialConnection( + identity = Identity.read(reader), + connectionId = ConnectionId.read(reader), + token = reader.readString(), + ) + 1 -> SubscribeApplied( + requestId = reader.readU32(), + querySetId = QuerySetId.read(reader), + rows = QueryRows.read(reader), + ) + 2 -> UnsubscribeApplied( + requestId = reader.readU32(), + querySetId = QuerySetId.read(reader), + rows = reader.readOption { QueryRows.read(it) }, + ) + 3 -> SubscriptionError( + requestId = reader.readOption { it.readU32() }, + querySetId = QuerySetId.read(reader), + error = reader.readString(), + ) + 4 -> TransactionUpdate( + querySets = reader.readArray { QuerySetUpdate.read(it) }, + ) + 5 -> { + val requestId = reader.readU32() + when (reader.readTag().toInt()) { + 0 -> OneOffQueryResult(requestId, QueryRows.read(reader), null) + 1 -> OneOffQueryResult(requestId, null, reader.readString()) + else -> throw IllegalStateException("Invalid OneOffQueryResult Result tag") + } + } + 6 -> ReducerResult( + requestId = reader.readU32(), + timestamp = Timestamp.read(reader), + result = ReducerOutcome.read(reader), + ) + 7 -> ProcedureResult( + status = ProcedureStatus.read(reader), + timestamp = Timestamp.read(reader), + totalHostExecutionDuration = TimeDuration.read(reader), + requestId = reader.readU32(), + ) + else -> throw IllegalStateException("Unknown ServerMessage tag") + } + } + + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/query/QueryBuilder.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/query/QueryBuilder.kt new file mode 100644 index 00000000000..86f5044f351 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/query/QueryBuilder.kt @@ -0,0 +1,75 @@ +package com.clockworklabs.spacetimedb.query + +/** Marker interface for types that can produce SQL queries. */ +fun interface QueryProvider { + fun toSql(): String +} + +/** A typed column reference within a table [T]. */ +class Col(val columnName: String) { + fun eq(value: V): BoolExpr = BoolExpr.Eq(this, value) + fun ne(value: V): BoolExpr = BoolExpr.Ne(this, value) + fun gt(value: V): BoolExpr = BoolExpr.Gt(this, value) + fun lt(value: V): BoolExpr = BoolExpr.Lt(this, value) + fun gte(value: V): BoolExpr = BoolExpr.Gte(this, value) + fun lte(value: V): BoolExpr = BoolExpr.Lte(this, value) +} + +/** A boolean expression for WHERE clauses. */ +sealed class BoolExpr { + data class Eq(val col: Col<*>, val value: Any?) : BoolExpr() + data class Ne(val col: Col<*>, val value: Any?) : BoolExpr() + data class Gt(val col: Col<*>, val value: Any?) : BoolExpr() + data class Lt(val col: Col<*>, val value: Any?) : BoolExpr() + data class Gte(val col: Col<*>, val value: Any?) : BoolExpr() + data class Lte(val col: Col<*>, val value: Any?) : BoolExpr() + data class And(val left: BoolExpr, val right: BoolExpr) : BoolExpr() + data class Or(val left: BoolExpr, val right: BoolExpr) : BoolExpr() + + infix fun and(other: BoolExpr): BoolExpr = And(this, other) + infix fun or(other: BoolExpr): BoolExpr = Or(this, other) + + internal fun toSql(tableName: String): String = when (this) { + is Eq -> "\"${tableName}\".\"${col.columnName}\" = ${formatValue(value)}" + is Ne -> "\"${tableName}\".\"${col.columnName}\" != ${formatValue(value)}" + is Gt -> "\"${tableName}\".\"${col.columnName}\" > ${formatValue(value)}" + is Lt -> "\"${tableName}\".\"${col.columnName}\" < ${formatValue(value)}" + is Gte -> "\"${tableName}\".\"${col.columnName}\" >= ${formatValue(value)}" + is Lte -> "\"${tableName}\".\"${col.columnName}\" <= ${formatValue(value)}" + is And -> "(${left.toSql(tableName)} AND ${right.toSql(tableName)})" + is Or -> "(${left.toSql(tableName)} OR ${right.toSql(tableName)})" + } + + private fun formatValue(value: Any?): String = when (value) { + null -> "NULL" + is String -> "'${value.replace("'", "''")}'" + is Boolean -> if (value) "true" else "false" + is Number -> value.toString() + else -> "'$value'" + } +} + +/** Base class for generated column accessor structs. */ +abstract class Cols(val tableName: String) + +/** + * A reference to a table, providing typed query building. + * + * Usage: + * ``` + * QueryTable("users") { UsersCols(it) }.where { cols -> cols.age.gt(18) } + * ``` + */ +class QueryTable(val tableName: String, private val colsFactory: (String) -> Cols) : QueryProvider { + override fun toSql(): String = """SELECT * FROM "$tableName"""" + + /** Add a WHERE clause using the table's generated column accessors. */ + fun where(block: Cols.() -> BoolExpr): FromQuery = + FromQuery(tableName, block(colsFactory(tableName))) +} + +/** A table reference with a WHERE clause attached. */ +class FromQuery(val tableName: String, val expr: BoolExpr) : QueryProvider { + override fun toSql(): String = """SELECT * FROM "$tableName" WHERE ${expr.toSql(tableName)}""" +} + diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt new file mode 100644 index 00000000000..0b4bf859715 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt @@ -0,0 +1,284 @@ +package com.clockworklabs.spacetimedb.websocket + +import com.clockworklabs.spacetimedb.CompressionMode +import com.clockworklabs.spacetimedb.ConnectionId +import com.clockworklabs.spacetimedb.ReconnectPolicy +import com.clockworklabs.spacetimedb.decompressBrotli +import com.clockworklabs.spacetimedb.decompressGzip +import com.clockworklabs.spacetimedb.protocol.ClientMessage +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import io.ktor.client.* +import io.ktor.client.call.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.ktor.websocket.* +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.launch +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.time.Duration.Companion.milliseconds + +private val HEX = "0123456789ABCDEF".toCharArray() + +enum class ConnectionState { + DISCONNECTED, + CONNECTING, + CONNECTED, + RECONNECTING, +} + +class WebSocketTransport( + private val scope: CoroutineScope, + private val onMessage: suspend (ServerMessage) -> Unit, + private val onConnect: () -> Unit, + private val onDisconnect: (Throwable?) -> Unit, + private val onConnectError: (Throwable) -> Unit, + private val keepAliveIntervalMs: Long = 30_000L, + private val reconnectPolicy: ReconnectPolicy? = null, + private val compression: CompressionMode = CompressionMode.GZIP, + private val connectionId: ConnectionId = ConnectionId.random(), + private val confirmedReads: Boolean? = null, + private val lightMode: Boolean = false, +) { + private val client = HttpClient { + install(WebSockets) { + pingInterval = keepAliveIntervalMs.milliseconds + } + } + + private val _state = MutableStateFlow(ConnectionState.DISCONNECTED) + val state: StateFlow = _state + + private val outboundQueue = Channel(Channel.UNLIMITED) + private var session: DefaultClientWebSocketSession? = null + private var connectJob: Job? = null + private val intentionalDisconnect = AtomicBoolean(false) + + // Tracks whether any data has arrived since the last keep-alive check. + // Used to send pings only when idle to avoid flooding the server. + private val idle = AtomicBoolean(true) + private val wantPong = AtomicBoolean(false) + + fun connect(uri: String, moduleName: String, token: String?) { + if (_state.value != ConnectionState.DISCONNECTED) return + intentionalDisconnect.store(false) + _state.value = ConnectionState.CONNECTING + + connectJob = scope.launch { + runConnection(uri, moduleName, token) + } + } + + private suspend fun runConnection(uri: String, moduleName: String, token: String?) { + try { + connectSession(uri, moduleName, token) + // Session ended normally + if (!intentionalDisconnect.load() && reconnectPolicy != null) { + attemptReconnect(uri, moduleName, token) + } else { + _state.value = ConnectionState.DISCONNECTED + onDisconnect(null) + } + } catch (_: CancellationException) { + _state.value = ConnectionState.DISCONNECTED + if (!intentionalDisconnect.load()) { + onDisconnect(null) + } + } catch (e: Throwable) { + val previousState = _state.value + if (!intentionalDisconnect.load() && reconnectPolicy != null) { + attemptReconnect(uri, moduleName, token) + } else if (previousState == ConnectionState.CONNECTING) { + _state.value = ConnectionState.DISCONNECTED + onConnectError(e) + } else { + _state.value = ConnectionState.DISCONNECTED + onDisconnect(e) + } + } + } + + private suspend fun exchangeToken(baseUri: String, token: String): String { + val httpBase = when { + baseUri.startsWith("ws://") -> "http://" + baseUri.removePrefix("ws://") + baseUri.startsWith("wss://") -> "https://" + baseUri.removePrefix("wss://") + baseUri.startsWith("http://") || baseUri.startsWith("https://") -> baseUri + else -> "http://$baseUri" + } + val response = client.post("${httpBase.trimEnd('/')}/v1/identity/websocket-token") { + header(HttpHeaders.Authorization, "Bearer $token") + } + val body = response.body().decodeToString() + // Parse {"token":"..."} from JSON + val tokenKey = "\"token\":\"" + val start = body.indexOf(tokenKey) + if (start < 0) throw IllegalStateException("Token exchange failed: $body") + val valueStart = start + tokenKey.length + val valueEnd = body.indexOf('"', valueStart) + return body.substring(valueStart, valueEnd) + } + + private suspend fun connectSession(uri: String, moduleName: String, token: String?) { + val wsToken = if (token != null) exchangeToken(uri, token) else null + val wsUri = buildWsUri(uri, moduleName, wsToken) + client.webSocket( + urlString = wsUri, + request = { + headers.append("Sec-WebSocket-Protocol", "v2.bsatn.spacetimedb") + } + ) { + session = this + idle.store(true) + wantPong.store(false) + _state.value = ConnectionState.CONNECTED + onConnect() + + val sendJob = launch { processSendQueue() } + val receiveJob = launch { processIncoming() } + + receiveJob.join() + sendJob.cancelAndJoin() + } + } + + private suspend fun attemptReconnect(uri: String, moduleName: String, token: String?) { + val policy = reconnectPolicy ?: return + _state.value = ConnectionState.RECONNECTING + + for (attempt in 0 until policy.maxRetries) { + if (intentionalDisconnect.load()) { + _state.value = ConnectionState.DISCONNECTED + return + } + + val delayMs = policy.delayForAttempt(attempt) + delay(delayMs.milliseconds) + + if (intentionalDisconnect.load()) { + _state.value = ConnectionState.DISCONNECTED + return + } + + try { + connectSession(uri, moduleName, token) + // If connectSession returns normally, the session ended cleanly. + // If we still want to reconnect (not intentionally), loop again. + if (intentionalDisconnect.load()) { + _state.value = ConnectionState.DISCONNECTED + return + } + _state.value = ConnectionState.RECONNECTING + } catch (_: CancellationException) { + _state.value = ConnectionState.DISCONNECTED + return + } catch (_: Throwable) { + // Connection attempt failed — continue to next retry + _state.value = ConnectionState.RECONNECTING + } + } + + // Exhausted all retries + _state.value = ConnectionState.DISCONNECTED + onDisconnect(null) + } + + fun disconnect() { + intentionalDisconnect.store(true) + connectJob?.cancel() + session = null + _state.value = ConnectionState.DISCONNECTED + client.close() + } + + fun send(message: ClientMessage) { + val encoded = message.encode() + outboundQueue.trySend(encoded) + } + + private suspend fun DefaultClientWebSocketSession.processSendQueue() { + for (bytes in outboundQueue) { + send(Frame.Binary(true, bytes)) + } + } + + private suspend fun DefaultClientWebSocketSession.processIncoming() { + for (frame in incoming) { + when (frame) { + is Frame.Binary -> { + idle.store(false) + val raw = frame.readBytes() + val payload = decompressIfNeeded(raw) + val msg = ServerMessage.decode(payload) + onMessage(msg) + } + + is Frame.Pong -> { + idle.store(false) + wantPong.store(false) + } + + is Frame.Close -> return + else -> { + idle.store(false) + } + } + } + } + + private fun decompressIfNeeded(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + val tag = data[0].toUByte().toInt() + val payload = data.copyOfRange(1, data.size) + return when (tag) { + 0 -> payload + 1 -> decompressBrotli(payload) + 2 -> decompressGzip(payload) + else -> throw IllegalStateException("Unknown compression tag: $tag") + } + } + + private fun urlEncode(value: String): String = buildString { + for (c in value) { + when { + c.isLetterOrDigit() || c in "-._~" -> append(c) + else -> { + for (b in c.toString().encodeToByteArray()) { + append('%') + append(HEX[(b.toInt() shr 4) and 0xF]) + append(HEX[b.toInt() and 0xF]) + } + } + } + } + } + + private fun buildWsUri(uri: String, moduleName: String, token: String?): String { + val base = uri.trimEnd('/') + val wsBase = when { + base.startsWith("ws://") || base.startsWith("wss://") -> base + base.startsWith("http://") -> "ws://" + base.removePrefix("http://") + base.startsWith("https://") -> "wss://" + base.removePrefix("https://") + else -> "ws://$base" + } + val sb = StringBuilder("$wsBase/v1/database/$moduleName/subscribe") + val params = mutableListOf() + if (token != null) params.add("token=${urlEncode(token)}") + params.add("compression=${compression.queryValue}") + params.add("connection_id=${connectionId.toHex()}") + if (confirmedReads != null) { + params.add("confirmed=$confirmedReads") + } + if (lightMode) { + params.add("light=true") + } + sb.append("?${params.joinToString("&")}") + return sb.toString() + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt new file mode 100644 index 00000000000..c7e79a49f74 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt @@ -0,0 +1,170 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class BsatnTest { + + @Test + fun roundTripBool() { + val writer = BsatnWriter() + writer.writeBool(true) + writer.writeBool(false) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(true, reader.readBool()) + assertEquals(false, reader.readBool()) + } + + @Test + fun roundTripU8() { + val writer = BsatnWriter() + writer.writeU8(0u) + writer.writeU8(255u) + writer.writeU8(42u) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u.toUByte(), reader.readU8()) + assertEquals(255u.toUByte(), reader.readU8()) + assertEquals(42u.toUByte(), reader.readU8()) + } + + @Test + fun roundTripI32() { + val writer = BsatnWriter() + writer.writeI32(0) + writer.writeI32(Int.MAX_VALUE) + writer.writeI32(Int.MIN_VALUE) + writer.writeI32(-1) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0, reader.readI32()) + assertEquals(Int.MAX_VALUE, reader.readI32()) + assertEquals(Int.MIN_VALUE, reader.readI32()) + assertEquals(-1, reader.readI32()) + } + + @Test + fun roundTripU32() { + val writer = BsatnWriter() + writer.writeU32(0u) + writer.writeU32(UInt.MAX_VALUE) + writer.writeU32(12345u) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u, reader.readU32()) + assertEquals(UInt.MAX_VALUE, reader.readU32()) + assertEquals(12345u, reader.readU32()) + } + + @Test + fun roundTripI64() { + val writer = BsatnWriter() + writer.writeI64(0L) + writer.writeI64(Long.MAX_VALUE) + writer.writeI64(Long.MIN_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0L, reader.readI64()) + assertEquals(Long.MAX_VALUE, reader.readI64()) + assertEquals(Long.MIN_VALUE, reader.readI64()) + } + + @Test + fun roundTripU64() { + val writer = BsatnWriter() + writer.writeU64(0u) + writer.writeU64(ULong.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u.toULong(), reader.readU64()) + assertEquals(ULong.MAX_VALUE, reader.readU64()) + } + + @Test + fun roundTripF32() { + val writer = BsatnWriter() + writer.writeF32(3.14f) + writer.writeF32(0.0f) + writer.writeF32(-1.0f) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(3.14f, reader.readF32()) + assertEquals(0.0f, reader.readF32()) + assertEquals(-1.0f, reader.readF32()) + } + + @Test + fun roundTripF64() { + val writer = BsatnWriter() + writer.writeF64(3.141592653589793) + writer.writeF64(Double.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(3.141592653589793, reader.readF64()) + assertEquals(Double.MAX_VALUE, reader.readF64()) + } + + @Test + fun roundTripString() { + val writer = BsatnWriter() + writer.writeString("hello") + writer.writeString("") + writer.writeString("unicode: 日本語 🚀") + val reader = BsatnReader(writer.toByteArray()) + assertEquals("hello", reader.readString()) + assertEquals("", reader.readString()) + assertEquals("unicode: 日本語 🚀", reader.readString()) + } + + @Test + fun roundTripByteArray() { + val writer = BsatnWriter() + val data = byteArrayOf(1, 2, 3, 4, 5) + writer.writeByteArray(data) + writer.writeByteArray(ByteArray(0)) + val reader = BsatnReader(writer.toByteArray()) + assertTrue(data.contentEquals(reader.readByteArray())) + assertTrue(ByteArray(0).contentEquals(reader.readByteArray())) + } + + @Test + fun roundTripArray() { + val writer = BsatnWriter() + writer.writeArray(listOf(10, 20, 30)) { w, v -> w.writeI32(v) } + val reader = BsatnReader(writer.toByteArray()) + val result = reader.readArray { it.readI32() } + assertEquals(listOf(10, 20, 30), result) + } + + @Test + fun roundTripOption() { + val writer = BsatnWriter() + writer.writeOption(42) { w, v -> w.writeI32(v) } + writer.writeOption(null) { w, v -> w.writeI32(v) } + val reader = BsatnReader(writer.toByteArray()) + assertEquals(42, reader.readOption { it.readI32() }) + assertNull(reader.readOption { it.readI32() }) + } + + @Test + fun littleEndianEncoding() { + val writer = BsatnWriter() + writer.writeU32(0x04030201u) + val bytes = writer.toByteArray() + assertEquals(1, bytes[0].toInt()) + assertEquals(2, bytes[1].toInt()) + assertEquals(3, bytes[2].toInt()) + assertEquals(4, bytes[3].toInt()) + } + + @Test + fun stringEncodingFormat() { + val writer = BsatnWriter() + writer.writeString("AB") + val bytes = writer.toByteArray() + assertEquals(6, bytes.size) + assertEquals(2, bytes[0].toInt()) + assertEquals(0, bytes[1].toInt()) + assertEquals(0, bytes[2].toInt()) + assertEquals(0, bytes[3].toInt()) + assertEquals(0x41, bytes[4].toInt()) + assertEquals(0x42, bytes[5].toInt()) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt new file mode 100644 index 00000000000..a4f868932c8 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt @@ -0,0 +1,106 @@ +package com.clockworklabs.spacetimedb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ClientCacheTest { + + @Test + fun insertAndCount() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.insertRow(byteArrayOf(4, 5, 6)) + assertEquals(2, cache.count) + } + + @Test + fun duplicateInsertIncrementsRefCount() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.insertRow(byteArrayOf(1, 2, 3)) + assertEquals(1, cache.count) + } + + @Test + fun deleteRemovesRow() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + assertTrue(cache.deleteRow(byteArrayOf(1, 2, 3))) + assertEquals(0, cache.count) + } + + @Test + fun deleteNonexistentReturnsFalse() { + val cache = TableCache("users") + assertFalse(cache.deleteRow(byteArrayOf(1, 2, 3))) + } + + @Test + fun refCountedDelete() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.deleteRow(byteArrayOf(1, 2, 3)) + assertEquals(1, cache.count) + cache.deleteRow(byteArrayOf(1, 2, 3)) + assertEquals(0, cache.count) + } + + @Test + fun containsRow() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(10, 20)) + assertTrue(cache.containsRow(byteArrayOf(10, 20))) + assertFalse(cache.containsRow(byteArrayOf(30, 40))) + } + + @Test + fun allRows() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1)) + cache.insertRow(byteArrayOf(2)) + cache.insertRow(byteArrayOf(3)) + assertEquals(3, cache.allRows().size) + } + + @Test + fun clientCacheGetOrCreate() { + val cc = ClientCache() + val t1 = cc.getOrCreateTable("users") + val t2 = cc.getOrCreateTable("users") + assertTrue(t1 === t2) + } + + @Test + fun clientCacheTableNames() { + val cc = ClientCache() + cc.getOrCreateTable("users") + cc.getOrCreateTable("messages") + assertEquals(setOf("users", "messages"), cc.tableNames()) + } + + @Test + fun tableHandleCallbacks() { + val handle = TableHandle("users") + var inserted: ByteArray? = null + var deleted: ByteArray? = null + var updatedOld: ByteArray? = null + var updatedNew: ByteArray? = null + + handle.onInsert { row -> inserted = row } + handle.onDelete { row -> deleted = row } + handle.onUpdate { old, new -> updatedOld = old; updatedNew = new } + + handle.fireInsert(byteArrayOf(1, 2, 3)) + assertTrue(byteArrayOf(1, 2, 3).contentEquals(inserted!!)) + + handle.fireDelete(byteArrayOf(4, 5, 6)) + assertTrue(byteArrayOf(4, 5, 6).contentEquals(deleted!!)) + + handle.fireUpdate(byteArrayOf(1), byteArrayOf(2)) + assertTrue(byteArrayOf(1).contentEquals(updatedOld!!)) + assertTrue(byteArrayOf(2).contentEquals(updatedNew!!)) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt new file mode 100644 index 00000000000..1f9afe2eb3d --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt @@ -0,0 +1,754 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNull +import kotlin.test.assertTrue + +/** + * Edge case tests covering protocol decode, cache semantics, callback behavior, + * URI handling, and subscription lifecycle — all offline, no server needed. + */ +class EdgeCaseTest { + + // ──────────────── ReducerOutcome: All 4 variants ──────────────── + + @Test + fun reducerOutcomeOkDecode() { + val w = BsatnWriter(128) + w.writeTag(0u) // Ok + w.writeByteArray(byteArrayOf(42)) // retValue + // TransactionUpdateData: array of QuerySetUpdate (empty) + w.writeU32(0u) + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.Ok) + assertEquals(1, outcome.retValue.size) + assertEquals(42.toByte(), outcome.retValue[0]) + } + + @Test + fun reducerOutcomeOkEmptyDecode() { + val w = BsatnWriter(4) + w.writeTag(1u) // OkEmpty + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.OkEmpty) + } + + @Test + fun reducerOutcomeErrDecode() { + val w = BsatnWriter(64) + w.writeTag(2u) // Err + w.writeByteArray("reducer panicked".encodeToByteArray()) + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.Err) + assertEquals("reducer panicked", outcome.message.decodeToString()) + } + + @Test + fun reducerOutcomeInternalErrorDecode() { + val w = BsatnWriter(64) + w.writeTag(3u) // InternalError + w.writeString("internal server error") + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.InternalError) + assertEquals("internal server error", outcome.message) + } + + @Test + fun reducerOutcomeInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(99u) + assertFailsWith { + ReducerOutcome.read(BsatnReader(w.toByteArray())) + } + } + + // ──────────────── ReducerResult ServerMessage ──────────────── + + @Test + fun serverMessageReducerResultFullDecode() { + val w = BsatnWriter(128) + w.writeTag(6u) // ReducerResult tag + w.writeU32(7u) // requestId + w.writeI64(1_700_000_000_000_000L) // timestamp + w.writeTag(1u) // ReducerOutcome::OkEmpty + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.ReducerResult) + assertEquals(7u, msg.requestId) + assertEquals(1_700_000_000_000_000L, msg.timestamp.microseconds) + assertTrue(msg.result is ReducerOutcome.OkEmpty) + } + + @Test + fun serverMessageReducerResultWithErr() { + val w = BsatnWriter(128) + w.writeTag(6u) + w.writeU32(99u) + w.writeI64(0L) + w.writeTag(3u) // InternalError + w.writeString("boom") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.ReducerResult) + val result = msg.result + assertTrue(result is ReducerOutcome.InternalError) + assertEquals("boom", result.message) + } + + // ──── ReducerResult: Err/InternalError must NOT update cache ──── + + @Test + fun reducerErrDoesNotUpdateCache() { + val cache = ClientCache() + val table = cache.getOrCreateTable("test") + table.insertRow(byteArrayOf(1, 2, 3)) + assertEquals(1, table.count) + + // Simulate: ReducerOutcome.Err should NOT apply any cache update + // (The DbConnection code checks `msg.result is ReducerOutcome.Ok` before applying) + // This test validates the logic by directly testing the guard condition + val errOutcome: ReducerOutcome = ReducerOutcome.Err("fail".encodeToByteArray()) + assertFalse(errOutcome is ReducerOutcome.Ok) + + val emptyOutcome: ReducerOutcome = ReducerOutcome.OkEmpty + assertFalse(emptyOutcome is ReducerOutcome.Ok) + + // Cache unchanged + assertEquals(1, table.count) + } + + // ──────────────── ProcedureStatus decode ──────────────── + + @Test + fun procedureStatusReturnedDecode() { + val w = BsatnWriter(32) + w.writeTag(0u) // Returned + w.writeByteArray(byteArrayOf(0xAB.toByte(), 0xCD.toByte())) + val status = ProcedureStatus.read(BsatnReader(w.toByteArray())) + assertTrue(status is ProcedureStatus.Returned) + assertEquals(2, status.data.size) + } + + @Test + fun procedureStatusInternalErrorDecode() { + val w = BsatnWriter(32) + w.writeTag(1u) // InternalError + w.writeString("proc failed") + val status = ProcedureStatus.read(BsatnReader(w.toByteArray())) + assertTrue(status is ProcedureStatus.InternalError) + assertEquals("proc failed", status.message) + } + + @Test + fun procedureStatusInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(5u) + assertFailsWith { + ProcedureStatus.read(BsatnReader(w.toByteArray())) + } + } + + // ──────────────── ServerMessage: Invalid tag ──────────────── + + @Test + fun serverMessageInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(200u) // invalid + assertFailsWith { + ServerMessage.decode(w.toByteArray()) + } + } + + // ──────────── SubscriptionError with null requestId ────────── + + @Test + fun subscriptionErrorWithNullRequestId() { + val w = BsatnWriter(64) + w.writeTag(3u) // SubscriptionError + w.writeTag(1u) // Option::None (SATS: None = tag 1) + w.writeU32(42u) // querySetId + w.writeString("bad query syntax") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.SubscriptionError) + assertNull(msg.requestId) + assertEquals(QuerySetId(42u), msg.querySetId) + assertEquals("bad query syntax", msg.error) + } + + @Test + fun subscriptionErrorWithRequestId() { + val w = BsatnWriter(64) + w.writeTag(3u) // SubscriptionError + w.writeTag(0u) // Option::Some (SATS: Some = tag 0) + w.writeU32(7u) // requestId + w.writeU32(42u) // querySetId + w.writeString("table not found") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.SubscriptionError) + assertEquals(7u, msg.requestId) + } + + // ──────────── UnsubscribeApplied with null rows ────────── + + @Test + fun unsubscribeAppliedWithNullRows() { + val w = BsatnWriter(32) + w.writeTag(2u) // UnsubscribeApplied + w.writeU32(5u) // requestId + w.writeU32(3u) // querySetId + w.writeTag(1u) // Option::None (SATS: None = tag 1) + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.UnsubscribeApplied) + assertNull(msg.rows) + } + + // ──────── Cache: Update detection edge cases ──────── + + @Test + fun cacheUpdateDetectionDeleteAndInsertSameBytes() { + // When delete + insert have same content → Update + val cache = ClientCache() + cache.getOrCreateTable("t") + val row = byteArrayOf(1, 2, 3) + cache.getOrCreateTable("t").insertRow(row) + + val ops = applyPersistentOps(cache, "t", + inserts = listOf(row), + deletes = listOf(row), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Update) + } + + @Test + fun cacheDeleteWithoutMatchingInsert() { + val cache = ClientCache() + val row = byteArrayOf(1, 2, 3) + cache.getOrCreateTable("t").insertRow(row) + + val ops = applyPersistentOps(cache, "t", + inserts = emptyList(), + deletes = listOf(row), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Delete) + assertEquals(0, cache.getOrCreateTable("t").count) + } + + @Test + fun cacheInsertWithoutMatchingDelete() { + val cache = ClientCache() + cache.getOrCreateTable("t") + + val ops = applyPersistentOps(cache, "t", + inserts = listOf(byteArrayOf(1, 2, 3)), + deletes = emptyList(), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Insert) + assertEquals(1, cache.getOrCreateTable("t").count) + } + + @Test + fun cacheEmptyTransaction() { + val cache = ClientCache() + cache.getOrCreateTable("t") + val ops = applyPersistentOps(cache, "t", + inserts = emptyList(), + deletes = emptyList(), + ) + assertTrue(ops.isEmpty()) + } + + @Test + fun cacheRefCountOverlappingSubscriptions() { + // Two subscriptions insert same row → refCount=2 + val table = TableCache("test") + val row = byteArrayOf(10, 20, 30) + table.insertRow(row) // sub 1 + table.insertRow(row) // sub 2 + assertEquals(1, table.count, "Same content, single entry") + assertTrue(table.containsRow(row)) + + // Unsub 1: refCount=1, row stays + table.deleteRow(row) + assertEquals(1, table.count) + assertTrue(table.containsRow(row)) + + // Unsub 2: refCount=0, row removed + table.deleteRow(row) + assertEquals(0, table.count) + assertFalse(table.containsRow(row)) + } + + @Test + fun cacheDeleteNonExistentRow() { + val table = TableCache("test") + val result = table.deleteRow(byteArrayOf(99)) + assertFalse(result, "Deleting non-existent row should return false") + assertEquals(0, table.count) + } + + // ──────── Callback re-entrance safety ──────── + + @Test + fun callbackCanRegisterAnotherCallbackDuringFire() { + val handle = TableHandle("test") + var secondCallbackFired = false + + handle.onInsert { _ -> + // Register a new callback from within a callback + handle.onInsert { _ -> secondCallbackFired = true } + } + + // First fire: triggers the registration callback + handle.fireInsert(byteArrayOf(1)) + assertFalse(secondCallbackFired, "Newly registered callback should not fire in same event") + + // Second fire: both callbacks fire + handle.fireInsert(byteArrayOf(2)) + assertTrue(secondCallbackFired, "Second callback should fire on next event") + } + + @Test + fun callbackCanRemoveItselfDuringFire() { + val handle = TableHandle("test") + var fireCount = 0 + var selfId: CallbackId? = null + + selfId = handle.onInsert { _ -> + fireCount++ + handle.removeOnInsert(selfId!!) + } + + handle.fireInsert(byteArrayOf(1)) + assertEquals(1, fireCount) + + handle.fireInsert(byteArrayOf(2)) + assertEquals(1, fireCount, "Removed callback should not fire again") + } + + // ──────── Subscription lifecycle states ──────── + + @Test + fun subscriptionStateLifecycle() { + // Can't create a real DbConnection without a server, but we can test + // the SubscriptionHandle state machine directly + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + assertEquals(SubscriptionState.PENDING, handle.state) + assertFalse(handle.isActive) + assertFalse(handle.isEnded) + + handle.state = SubscriptionState.ACTIVE + assertTrue(handle.isActive) + assertFalse(handle.isEnded) + + handle.state = SubscriptionState.ENDED + assertFalse(handle.isActive) + assertTrue(handle.isEnded) + } + + @Test + fun doubleUnsubscribeIsSafe() { + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + handle.state = SubscriptionState.ACTIVE + handle.unsubscribe() // First: transitions to ENDED + assertTrue(handle.isEnded) + handle.unsubscribe() // Second: no-op, no crash + assertTrue(handle.isEnded) + } + + @Test + fun unsubscribeOnPendingIsNoOp() { + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + assertEquals(SubscriptionState.PENDING, handle.state) + handle.unsubscribe() // Transitions to CANCELLED + assertEquals(SubscriptionState.CANCELLED, handle.state) + } + + // ──────── URI scheme normalization ──────── + + @Test + fun uriSchemeNormalization() { + // Test the URI building logic by encoding/decoding the buildWsUri output + // We'll test the WebSocketTransport.buildWsUri indirectly via pattern matching + val testCases = mapOf( + "http://localhost:3000" to "ws://", + "https://example.com" to "wss://", + "ws://localhost:3000" to "ws://", + "wss://example.com" to "wss://", + "localhost:3000" to "ws://", + ) + // These are validated by the WebSocketTransport.buildWsUri method + // which is private — we verify the logic patterns match + for ((input, expectedPrefix) in testCases) { + val base = input.trimEnd('/') + val wsBase = when { + base.startsWith("ws://") || base.startsWith("wss://") -> base + base.startsWith("http://") -> "ws://" + base.removePrefix("http://") + base.startsWith("https://") -> "wss://" + base.removePrefix("https://") + else -> "ws://$base" + } + assertTrue(wsBase.startsWith(expectedPrefix), "Input '$input' should start with '$expectedPrefix', got '$wsBase'") + } + } + + // ──────── BSATN: Boundary values ──────── + + @Test + fun bsatnBoundaryValues() { + val w = BsatnWriter(128) + // Unsigned extremes + w.writeU8(UByte.MIN_VALUE) + w.writeU8(UByte.MAX_VALUE) + w.writeU16(UShort.MIN_VALUE) + w.writeU16(UShort.MAX_VALUE) + w.writeU32(UInt.MIN_VALUE) + w.writeU32(UInt.MAX_VALUE) + w.writeU64(ULong.MIN_VALUE) + w.writeU64(ULong.MAX_VALUE) + // Signed extremes + w.writeI8(Byte.MIN_VALUE) + w.writeI8(Byte.MAX_VALUE) + w.writeI16(Short.MIN_VALUE) + w.writeI16(Short.MAX_VALUE) + w.writeI32(Int.MIN_VALUE) + w.writeI32(Int.MAX_VALUE) + w.writeI64(Long.MIN_VALUE) + w.writeI64(Long.MAX_VALUE) + // Float specials + w.writeF32(Float.NaN) + w.writeF32(Float.POSITIVE_INFINITY) + w.writeF32(Float.NEGATIVE_INFINITY) + w.writeF32(0.0f) + w.writeF32(-0.0f) + w.writeF64(Double.NaN) + w.writeF64(Double.POSITIVE_INFINITY) + w.writeF64(Double.NEGATIVE_INFINITY) + + val r = BsatnReader(w.toByteArray()) + assertEquals(UByte.MIN_VALUE, r.readU8()) + assertEquals(UByte.MAX_VALUE, r.readU8()) + assertEquals(UShort.MIN_VALUE, r.readU16()) + assertEquals(UShort.MAX_VALUE, r.readU16()) + assertEquals(UInt.MIN_VALUE, r.readU32()) + assertEquals(UInt.MAX_VALUE, r.readU32()) + assertEquals(ULong.MIN_VALUE, r.readU64()) + assertEquals(ULong.MAX_VALUE, r.readU64()) + assertEquals(Byte.MIN_VALUE, r.readI8()) + assertEquals(Byte.MAX_VALUE, r.readI8()) + assertEquals(Short.MIN_VALUE, r.readI16()) + assertEquals(Short.MAX_VALUE, r.readI16()) + assertEquals(Int.MIN_VALUE, r.readI32()) + assertEquals(Int.MAX_VALUE, r.readI32()) + assertEquals(Long.MIN_VALUE, r.readI64()) + assertEquals(Long.MAX_VALUE, r.readI64()) + assertTrue(r.readF32().isNaN()) + assertEquals(Float.POSITIVE_INFINITY, r.readF32()) + assertEquals(Float.NEGATIVE_INFINITY, r.readF32()) + assertEquals(0.0f, r.readF32()) + // -0.0f == 0.0f in Kotlin, compare bits + assertEquals((-0.0f).toRawBits(), r.readF32().toRawBits()) + assertTrue(r.readF64().isNaN()) + assertEquals(Double.POSITIVE_INFINITY, r.readF64()) + assertEquals(Double.NEGATIVE_INFINITY, r.readF64()) + assertTrue(r.isExhausted) + } + + @Test + fun bsatnEmptyString() { + val w = BsatnWriter(8) + w.writeString("") + val r = BsatnReader(w.toByteArray()) + assertEquals("", r.readString()) + } + + @Test + fun bsatnEmptyByteArray() { + val w = BsatnWriter(8) + w.writeByteArray(byteArrayOf()) + val r = BsatnReader(w.toByteArray()) + val bytes = r.readByteArray() + assertEquals(0, bytes.size) + } + + @Test + fun bsatnEmptyArray() { + val w = BsatnWriter(8) + w.writeArray(emptyList()) { wr, s -> wr.writeString(s) } + val r = BsatnReader(w.toByteArray()) + val list = r.readArray { it.readString() } + assertTrue(list.isEmpty()) + } + + @Test + fun bsatnOptionNoneAndSome() { + val w = BsatnWriter(16) + w.writeOption(null) { wr, v: String -> wr.writeString(v) } + w.writeOption("hello") { wr, v -> wr.writeString(v) } + + val r = BsatnReader(w.toByteArray()) + assertNull(r.readOption { it.readString() }) + assertEquals("hello", r.readOption { it.readString() }) + } + + @Test + fun bsatnReaderUnderflowThrows() { + val r = BsatnReader(byteArrayOf(1, 2)) + r.readU8() // ok + r.readU8() // ok + assertFailsWith { + r.readU8() // no bytes left + } + } + + @Test + fun bsatnReaderReadMoreThanAvailableThrows() { + val r = BsatnReader(byteArrayOf(1, 2, 3)) + assertFailsWith { + r.readU32() // needs 4 bytes, only 3 available + } + } + + // ──────── Compression tag handling ──────── + + @Test + fun compressionTagUnknownThrows() { + // Tag 0 = uncompressed, 1 = brotli, 2 = gzip. Tag 3+ should throw. + val data = byteArrayOf(3, 0, 0, 0) + assertFailsWith { + decompressWithTag(data) + } + } + + @Test + fun compressionTagUncompressed() { + val payload = byteArrayOf(0, 1, 2, 3, 4) // tag 0 + data + val result = decompressWithTag(payload) + assertEquals(4, result.size) + assertEquals(1.toByte(), result[0]) + } + + // ──────── Identity edge cases ──────── + + @Test + fun identityWrongSizeThrows() { + assertFailsWith { + Identity(ByteArray(16)) // needs 32 + } + } + + @Test + fun connectionIdWrongSizeThrows() { + assertFailsWith { + ConnectionId(ByteArray(8)) // needs 16 + } + } + + @Test + fun addressWrongSizeThrows() { + assertFailsWith { + Address(ByteArray(32)) // needs 16 + } + } + + @Test + fun identityZero() { + assertTrue(Identity.ZERO.bytes.all { it == 0.toByte() }) + assertEquals(32, Identity.ZERO.bytes.size) + } + + @Test + fun identityHexRoundTrip() { + val hex = "0123456789abcdef" .repeat(4) + val id = Identity.fromHex(hex) + assertEquals(hex, id.toHex()) + } + + @Test + fun identityFromHexWrongLengthThrows() { + assertFailsWith { + Identity.fromHex("0123") // needs 64 hex chars + } + } + + @Test + fun identityEquality() { + val a = Identity(ByteArray(32) { it.toByte() }) + val b = Identity(ByteArray(32) { it.toByte() }) + val c = Identity(ByteArray(32) { 0 }) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + // ──────── ClientMessage encode edge cases ──────── + + @Test + fun callReducerEmptyArgs() { + val msg = ClientMessage.CallReducer( + requestId = 1u, + reducer = "no_args_reducer", + args = byteArrayOf(), + ) + val encoded = msg.encode() + val r = BsatnReader(encoded) + assertEquals(3, r.readTag().toInt()) // CallReducer tag + assertEquals(1u, r.readU32()) + assertEquals(0.toUByte(), r.readU8()) // flags + assertEquals("no_args_reducer", r.readString()) + val args = r.readByteArray() + assertEquals(0, args.size) + } + + @Test + fun callReducerEquality() { + val a = ClientMessage.CallReducer(1u, "test", byteArrayOf(1, 2, 3)) + val b = ClientMessage.CallReducer(1u, "test", byteArrayOf(1, 2, 3)) + val c = ClientMessage.CallReducer(1u, "test", byteArrayOf(4, 5, 6)) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + @Test + fun unsubscribeWithSendDroppedRowsFlag() { + val msg = ClientMessage.Unsubscribe( + requestId = 5u, + querySetId = QuerySetId(10u), + flags = 1u, // SendDroppedRows + ) + val encoded = msg.encode() + val r = BsatnReader(encoded) + assertEquals(1, r.readTag().toInt()) // Unsubscribe tag + assertEquals(5u, r.readU32()) + assertEquals(10u, r.readU32()) // querySetId + assertEquals(1.toUByte(), r.readU8()) // flags = SendDroppedRows + } + + // ──────── DbConnectionBuilder validation ──────── + + @Test + fun builderWithoutUriThrows() { + assertFailsWith { + DbConnection.builder() + .withModuleName("test") + .build() + } + } + + @Test + fun builderWithoutModuleNameThrows() { + assertFailsWith { + DbConnection.builder() + .withUri("ws://localhost:3000") + .build() + } + } + + // ──────── ByteArrayWrapper edge cases ──────── + + @Test + fun byteArrayWrapperEquality() { + val a = ByteArrayWrapper(byteArrayOf(1, 2, 3)) + val b = ByteArrayWrapper(byteArrayOf(1, 2, 3)) + val c = ByteArrayWrapper(byteArrayOf(3, 2, 1)) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + @Test + fun byteArrayWrapperEmptyArrays() { + val a = ByteArrayWrapper(byteArrayOf()) + val b = ByteArrayWrapper(byteArrayOf()) + assertEquals(a, b) + } + + @Test + fun byteArrayWrapperNotEqualToOtherTypes() { + val a = ByteArrayWrapper(byteArrayOf(1)) + assertFalse(a.equals("string")) + assertFalse(a.equals(null)) + } + + // ──────── Helpers ──────── + + private fun applyPersistentOps( + cache: ClientCache, + tableName: String, + inserts: List, + deletes: List, + ): List { + val w = BsatnWriter(1024) + w.writeU32(1u) // querySetId + w.writeU32(1u) // 1 table + w.writeString(tableName) + w.writeU32(1u) // 1 row update + w.writeTag(0u) // PersistentTable + // inserts BsatnRowList + writeRowList(w, inserts) + // deletes BsatnRowList + writeRowList(w, deletes) + + val qsUpdate = QuerySetUpdate.read(BsatnReader(w.toByteArray())) + return cache.applyTransactionUpdate(listOf(qsUpdate)) + } + + private fun writeRowList(w: BsatnWriter, rows: List) { + w.writeTag(0u) // FixedSize hint + if (rows.isEmpty()) { + w.writeU16(0u) + w.writeU32(0u) + } else { + val rowSize = rows.first().size + w.writeU16(rowSize.toUShort()) + w.writeU32((rowSize * rows.size).toUInt()) + for (row in rows) w.writeBytes(row) + } + } + + private fun decompressWithTag(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + val tag = data[0].toUByte().toInt() + val payload = data.copyOfRange(1, data.size) + return when (tag) { + 0 -> payload + 1 -> decompressBrotli(payload) + 2 -> decompressGzip(payload) + else -> throw IllegalStateException("Unknown compression tag: $tag") + } + } + + // Stub connection that doesn't actually connect (for subscription state tests) + private fun stubConnection(): DbConnection { + // We only need a DbConnection object for the SubscriptionHandle reference. + // The builder validation requires URI and module name. + // This will attempt to connect but we don't care — we only test handle state. + return DbConnection( + uri = "ws://invalid.test:0", + moduleName = "test", + token = null, + connectCallbacks = emptyList(), + disconnectCallbacks = emptyList(), + connectErrorCallbacks = emptyList(), + keepAliveIntervalMs = 0, + ) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt new file mode 100644 index 00000000000..47b0995510e --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt @@ -0,0 +1,82 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class OneOffQueryTest { + + @Test + fun decodeOneOffQueryOk() { + val writer = BsatnWriter() + // ServerMessage tag 5 = OneOffQueryResult + writer.writeTag(5u) + // requestId + writer.writeU32(42u) + // Result tag 0 = Ok(QueryRows) + writer.writeTag(0u) + // QueryRows: array of SingleTableRows (empty) + writer.writeU32(0u) + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(42u, msg.requestId) + assertNotNull(msg.rows) + assertEquals(0, msg.rows.tables.size) + assertNull(msg.error) + } + + @Test + fun decodeOneOffQueryErr() { + val writer = BsatnWriter() + // ServerMessage tag 5 = OneOffQueryResult + writer.writeTag(5u) + // requestId + writer.writeU32(99u) + // Result tag 1 = Err(string) + writer.writeTag(1u) + writer.writeString("table not found") + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(99u, msg.requestId) + assertNull(msg.rows) + assertEquals("table not found", msg.error) + } + + @Test + fun decodeOneOffQueryOkWithRows() { + val writer = BsatnWriter() + writer.writeTag(5u) + writer.writeU32(7u) + // Result tag 0 = Ok + writer.writeTag(0u) + // QueryRows: 1 table + writer.writeU32(1u) + // SingleTableRows: table name (RawIdentifier = string) + writer.writeString("users") + // BsatnRowList: RowSizeHint (tag 0 = FixedSize) + writer.writeTag(0u) + writer.writeU16(4u) + // rowsData: 2 rows of 4 bytes each = 8 bytes + val rowsData = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8) + writer.writeByteArray(rowsData) + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(7u, msg.requestId) + assertNotNull(msg.rows) + assertEquals(1, msg.rows.tables.size) + assertEquals("users", msg.rows.tables[0].table.value) + + val decodedRows = msg.rows.tables[0].rows.decodeRows() + assertEquals(2, decodedRows.size) + assertTrue(byteArrayOf(1, 2, 3, 4).contentEquals(decodedRows[0])) + assertTrue(byteArrayOf(5, 6, 7, 8).contentEquals(decodedRows[1])) + assertNull(msg.error) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt new file mode 100644 index 00000000000..7891cc9cf2e --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt @@ -0,0 +1,129 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.ClientMessage +import com.clockworklabs.spacetimedb.protocol.QuerySetId +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ProtocolTest { + + @Test + fun encodeSubscribeMessage() { + val msg = ClientMessage.Subscribe( + requestId = 1u, + querySetId = QuerySetId(100u), + queryStrings = listOf("SELECT * FROM users"), + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(0, reader.readTag().toInt()) + assertEquals(1u, reader.readU32()) + assertEquals(100u, reader.readU32()) + val queries = reader.readArray { it.readString() } + assertEquals(listOf("SELECT * FROM users"), queries) + } + + @Test + fun encodeCallReducerMessage() { + val args = byteArrayOf(10, 20, 30) + val msg = ClientMessage.CallReducer( + requestId = 5u, + reducer = "add_user", + args = args, + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(3, reader.readTag().toInt()) + assertEquals(5u, reader.readU32()) + assertEquals(0u.toUByte(), reader.readU8()) + assertEquals("add_user", reader.readString()) + assertTrue(args.contentEquals(reader.readByteArray())) + } + + @Test + fun encodeUnsubscribeMessage() { + val msg = ClientMessage.Unsubscribe( + requestId = 2u, + querySetId = QuerySetId(50u), + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(1, reader.readTag().toInt()) + assertEquals(2u, reader.readU32()) + assertEquals(50u, reader.readU32()) + } + + @Test + fun encodeOneOffQueryMessage() { + val msg = ClientMessage.OneOffQuery( + requestId = 3u, + queryString = "SELECT count(*) FROM users", + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(2, reader.readTag().toInt()) + assertEquals(3u, reader.readU32()) + assertEquals("SELECT count(*) FROM users", reader.readString()) + } + + @Test + fun decodeInitialConnection() { + val writer = BsatnWriter() + writer.writeTag(0u) + writer.writeBytes(ByteArray(32) { it.toByte() }) + writer.writeBytes(ByteArray(16) { (it + 100).toByte() }) + writer.writeString("test-token-abc") + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.InitialConnection) + assertEquals("test-token-abc", msg.token) + assertEquals(ByteArray(32) { it.toByte() }.toList(), msg.identity.bytes.toList()) + assertEquals(ByteArray(16) { (it + 100).toByte() }.toList(), msg.connectionId.bytes.toList()) + } + + @Test + fun identityFromHex() { + val hex = "00" + "01" + "02" + "03" + "04" + "05" + "06" + "07" + + "08" + "09" + "0a" + "0b" + "0c" + "0d" + "0e" + "0f" + + "10" + "11" + "12" + "13" + "14" + "15" + "16" + "17" + + "18" + "19" + "1a" + "1b" + "1c" + "1d" + "1e" + "1f" + val identity = Identity.fromHex(hex) + assertEquals(0, identity.bytes[0].toInt()) + assertEquals(31, identity.bytes[31].toInt()) + assertEquals(hex, identity.toHex()) + } + + @Test + fun identityBsatnRoundTrip() { + val original = Identity(ByteArray(32) { (it * 7).toByte() }) + val writer = BsatnWriter() + Identity.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Identity.read(reader) + assertEquals(original, decoded) + } + + @Test + fun connectionIdBsatnRoundTrip() { + val original = ConnectionId(ByteArray(16) { (it * 3).toByte() }) + val writer = BsatnWriter() + ConnectionId.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ConnectionId.read(reader) + assertEquals(original, decoded) + } + + @Test + fun timestampBsatnRoundTrip() { + val original = Timestamp(1234567890123L) + val writer = BsatnWriter() + Timestamp.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Timestamp.read(reader) + assertEquals(original, decoded) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt new file mode 100644 index 00000000000..a16d93e1523 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt @@ -0,0 +1,84 @@ +package com.clockworklabs.spacetimedb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class ReconnectPolicyTest { + + @Test + fun defaultPolicy() { + val policy = ReconnectPolicy() + assertEquals(5, policy.maxRetries) + assertEquals(1_000L, policy.initialDelayMs) + assertEquals(30_000L, policy.maxDelayMs) + assertEquals(2.0, policy.backoffMultiplier) + } + + @Test + fun delayForAttemptExponentialBackoff() { + val policy = ReconnectPolicy( + initialDelayMs = 1_000, + maxDelayMs = 60_000, + backoffMultiplier = 2.0, + ) + assertEquals(1_000L, policy.delayForAttempt(0)) + assertEquals(2_000L, policy.delayForAttempt(1)) + assertEquals(4_000L, policy.delayForAttempt(2)) + assertEquals(8_000L, policy.delayForAttempt(3)) + assertEquals(16_000L, policy.delayForAttempt(4)) + } + + @Test + fun delayClampedToMax() { + val policy = ReconnectPolicy( + initialDelayMs = 1_000, + maxDelayMs = 5_000, + backoffMultiplier = 3.0, + ) + assertEquals(1_000L, policy.delayForAttempt(0)) + assertEquals(3_000L, policy.delayForAttempt(1)) + assertEquals(5_000L, policy.delayForAttempt(2)) // clamped: 9_000 -> 5_000 + assertEquals(5_000L, policy.delayForAttempt(3)) // stays clamped + } + + @Test + fun noBackoff() { + val policy = ReconnectPolicy( + initialDelayMs = 500, + maxDelayMs = 500, + backoffMultiplier = 1.0, + ) + assertEquals(500L, policy.delayForAttempt(0)) + assertEquals(500L, policy.delayForAttempt(1)) + assertEquals(500L, policy.delayForAttempt(5)) + } + + @Test + fun invalidMaxRetriesThrows() { + assertFailsWith { + ReconnectPolicy(maxRetries = -1) + } + } + + @Test + fun invalidInitialDelayThrows() { + assertFailsWith { + ReconnectPolicy(initialDelayMs = 0) + } + } + + @Test + fun maxDelayLessThanInitialThrows() { + assertFailsWith { + ReconnectPolicy(initialDelayMs = 5_000, maxDelayMs = 1_000) + } + } + + @Test + fun backoffMultiplierLessThanOneThrows() { + assertFailsWith { + ReconnectPolicy(backoffMultiplier = 0.5) + } + } +} diff --git a/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt b/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt new file mode 100644 index 00000000000..f6d9df24e53 --- /dev/null +++ b/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt @@ -0,0 +1,28 @@ +package com.clockworklabs.spacetimedb + +import org.brotli.dec.BrotliInputStream +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPInputStream + +actual fun decompressBrotli(data: ByteArray): ByteArray { + ByteArrayInputStream(data).use { input -> + BrotliInputStream(input).use { brotli -> + ByteArrayOutputStream(data.size * 2).use { output -> + brotli.copyTo(output) + return output.toByteArray() + } + } + } +} + +actual fun decompressGzip(data: ByteArray): ByteArray { + ByteArrayInputStream(data).use { input -> + GZIPInputStream(input).use { gzip -> + ByteArrayOutputStream(data.size * 2).use { output -> + gzip.copyTo(output) + return output.toByteArray() + } + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt new file mode 100644 index 00000000000..93628910275 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt @@ -0,0 +1,63 @@ +package com.clockworklabs.spacetimedb + +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPOutputStream +import kotlin.test.Test +import kotlin.test.assertTrue + +class CompressionTest { + + private fun gzipCompress(data: ByteArray): ByteArray { + val bos = ByteArrayOutputStream() + GZIPOutputStream(bos).use { it.write(data) } + return bos.toByteArray() + } + + @Test + fun gzipRoundTrip() { + val original = "Hello, SpacetimeDB! This is a test of gzip compression.".encodeToByteArray() + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip round-trip failed") + } + + @Test + fun gzipEmptyPayload() { + val original = ByteArray(0) + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip empty round-trip failed") + } + + @Test + fun gzipLargePayload() { + val original = ByteArray(10_000) { (it % 256).toByte() } + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip large payload round-trip failed") + } + + @Test + fun brotliRoundTrip() { + // Brotli-compressed "Hello" (pre-computed with brotli CLI) + // We test decompression only since the SDK only needs to decompress server messages + val original = "Hello".encodeToByteArray() + val compressed = brotliCompress(original) + val decompressed = decompressBrotli(compressed) + assertTrue(original.contentEquals(decompressed), "Brotli round-trip failed") + } + + private fun brotliCompress(data: ByteArray): ByteArray { + // Use org.brotli encoder if available, otherwise use a known compressed payload. + // The org.brotli:dec artifact only includes the decoder. + // Use JNI-free approach: manually construct a minimal brotli stream for "Hello" + // For robustness, we'll use the encoder from the test classpath if available. + // Minimal approach: test with a known brotli-compressed byte sequence. + // + // Pre-compressed "Hello" using brotli (metablock, uncompressed): + // This is a valid brotli stream that decompresses to "Hello" + return byteArrayOf( + 0x0b, 0x02, 0x80.toByte(), 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x03 + ) + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt new file mode 100644 index 00000000000..5ae80b51a99 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt @@ -0,0 +1,202 @@ +package com.clockworklabs.spacetimedb + +import kotlinx.coroutines.* +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong +import kotlin.test.Test +import kotlin.time.Duration.Companion.milliseconds + +/** + * Keynote-2 style TPS benchmark — fund transfers with pipelined reducer calls. + * + * Mirrors the Rust benchmark client at templates/keynote-2/spacetimedb-rust-client: + * - 10 WebSocket connections (no subscriptions) + * - Zipf-distributed account selection (alpha=0.5, 100k accounts) + * - Batched pipeline: fire 16384 reducer calls, await all responses, repeat + * - 5s warmup + 5s measurement + * + * Prerequisites: + * 1. `spacetime start` running on localhost:3000 + * 2. keynote-2 module published: `spacetime publish --server local sim` + * 3. Database seeded via Rust client: `spacetimedb-rust-transfer-sim seed` + * + * Set SPACETIMEDB_TEST=1 to enable. + */ +class Keynote2BenchmarkTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "sim" + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + companion object { + const val ACCOUNTS = 100_000 + const val ALPHA = 0.5 + const val CONNECTIONS = 10 + const val MAX_INFLIGHT = 16_384 + const val WARMUP_MS = 5_000L + const val BENCH_MS = 5_000L + const val AMOUNT = 1 + const val TOTAL_PAIRS = 10_000_000 + } + + /** + * Zipf distribution sampler via inverse CDF with binary search. + * Produces integers in [0, n) with P(k) proportional to 1/(k+1)^alpha. + */ + private class ZipfSampler(n: Int, alpha: Double, seed: Long) { + private val cdf: DoubleArray + private val rng = java.util.Random(seed) + + init { + val weights = DoubleArray(n) { 1.0 / Math.pow((it + 1).toDouble(), alpha) } + val total = weights.sum() + cdf = DoubleArray(n) + var cumulative = 0.0 + for (i in weights.indices) { + cumulative += weights[i] / total + cdf[i] = cumulative + } + } + + fun sample(): Int { + val u = rng.nextDouble() + var lo = 0; var hi = cdf.size - 1 + while (lo < hi) { + val mid = (lo + hi) ushr 1 + if (cdf[mid] < u) lo = mid + 1 else hi = mid + } + return lo + } + } + + /** Pre-compute [TOTAL_PAIRS] transfer pairs using Zipf distribution. */ + private fun generateTransferPairs(from: IntArray, to: IntArray) { + val zipf = ZipfSampler(ACCOUNTS, ALPHA, 0x12345678L) + var idx = 0 + while (idx < TOTAL_PAIRS) { + val a = zipf.sample() + val b = zipf.sample() + if (a != b && a < ACCOUNTS && b < ACCOUNTS) { + from[idx] = a + to[idx] = b + idx++ + } + } + } + + /** BSATN-encode transfer args: (from: u32, to: u32, amount: u32) in little-endian. */ + private fun encodeTransfer(from: Int, to: Int, amount: Int): ByteArray { + val buf = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN) + buf.putInt(from).putInt(to).putInt(amount) + return buf.array() + } + + @Test + fun keynote2Benchmark() { + if (!shouldRun()) { println("SKIP"); return } + + println("=== Kotlin SDK Keynote-2 Transfer Benchmark ===") + println("alpha=$ALPHA, amount=$AMOUNT, accounts=$ACCOUNTS") + println("max inflight reducers = $MAX_INFLIGHT") + println("connections = $CONNECTIONS") + println() + + // Pre-compute transfer pairs (matches Rust client's make_transfers) + print("Pre-computing transfer pairs... ") + val fromArr = IntArray(TOTAL_PAIRS) + val toArr = IntArray(TOTAL_PAIRS) + generateTransferPairs(fromArr, toArr) + println("done") + + val transfersPerWorker = TOTAL_PAIRS / CONNECTIONS + + runBlocking { + // Open connections (no subscriptions — pure reducer pipelining) + println("Initializing $CONNECTIONS connections...") + val connections = (0 until CONNECTIONS).map { + val ready = CompletableDeferred() + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withCompression(CompressionMode.NONE) + .onConnect { c, _, _ -> ready.complete(c) } + .onConnectError { e -> ready.completeExceptionally(e) } + .build() + withTimeout(10_000.milliseconds) { ready.await() } + conn + } + println("All $CONNECTIONS connections established") + + val completed = AtomicLong(0) + val workersReady = AtomicInteger(0) + val benchStartNanos = AtomicLong(0) + + println("Warming up for ${WARMUP_MS / 1000}s...") + val warmupStartNanos = System.nanoTime() + + val jobs = connections.mapIndexed { workerIdx, conn -> + launch(Dispatchers.Default) { + var tIdx = workerIdx * transfersPerWorker + + // Pipeline batch: fire MAX_INFLIGHT calls, suspend until all respond + suspend fun runBatch(): Long { + val batchDone = CompletableDeferred() + val remaining = AtomicInteger(MAX_INFLIGHT) + + repeat(MAX_INFLIGHT) { + val idx = tIdx % TOTAL_PAIRS + tIdx++ + val args = encodeTransfer(fromArr[idx], toArr[idx], AMOUNT) + conn.callReducer("transfer", args) { + if (remaining.decrementAndGet() == 0) { + batchDone.complete(Unit) + } + } + } + + batchDone.await() + return MAX_INFLIGHT.toLong() + } + + // ── Warmup phase ── + while (System.nanoTime() - warmupStartNanos < WARMUP_MS * 1_000_000) { + runBatch() + } + + // Sync: wait for all workers to finish warmup + workersReady.incrementAndGet() + while (workersReady.get() < CONNECTIONS) delay(1.milliseconds) + + // First worker to pass sets the shared start time + benchStartNanos.compareAndSet(0, System.nanoTime()) + + // ── Measurement phase ── + val myStart = System.nanoTime() + while (System.nanoTime() - myStart < BENCH_MS * 1_000_000) { + val count = runBatch() + completed.addAndGet(count) + } + } + } + + println("Finished warmup. Benchmarking for ${BENCH_MS / 1000}s...") + jobs.joinAll() + + val benchEndNanos = System.nanoTime() + val totalCompleted = completed.get() + val elapsed = (benchEndNanos - benchStartNanos.get()) / 1_000_000_000.0 + val tps = totalCompleted / elapsed + + println() + println("=== Results ===") + println("ran for ${"%.3f".format(elapsed)} seconds") + println("completed $totalCompleted transfers") + println("throughput was ${"%.1f".format(tps)} TPS") + + connections.forEach { it.disconnect() } + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt new file mode 100644 index 00000000000..8c8f461d068 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt @@ -0,0 +1,467 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import kotlinx.coroutines.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.measureTime + +/** + * Live edge case tests against a local SpacetimeDB server. + * + * Set SPACETIMEDB_TEST=1 to enable. + */ +class LiveEdgeCaseTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "kotlin-sdk-test" + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + // ──────── Invalid connection scenarios ──────── + + @Test + fun connectToNonExistentModule() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connectError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName("non_existent_module_xyz_12345") + .onConnect { _, _, _ -> connectError.completeExceptionally(AssertionError("Should not connect")) } + .onConnectError { e -> connectError.complete(e) } + .onDisconnect { _, e -> if (!connectError.isCompleted) connectError.complete(e ?: RuntimeException("disconnected")) } + .build() + + val error = withTimeout(10_000.milliseconds) { connectError.await() } + assertNotNull(error) + println("PASS: Non-existent module rejected: ${error.message?.take(80)}") + conn.disconnect() + } + } + + @Test + fun connectToUnreachableHost() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connectError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri("ws://192.0.2.1:9999") // TEST-NET, guaranteed unreachable + .withModuleName("test") + .onConnectError { e -> connectError.complete(e) } + .onDisconnect { _, e -> if (!connectError.isCompleted) connectError.complete(e ?: RuntimeException("disconnected")) } + .build() + + val error = withTimeout(15000.milliseconds) { connectError.await() } + assertNotNull(error) + println("PASS: Unreachable host properly errored: ${error::class.simpleName}") + conn.disconnect() + } + } + + // ──────── Subscription edge cases ──────── + + @Test + fun subscribeWithInvalidSqlSyntax() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subError.completeExceptionally(AssertionError("Should not apply")) } + .onError { err -> subError.complete(err) } + .subscribe("SELECTT * FROMM invalid_table_xyz") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + val error = withTimeout(5000.milliseconds) { subError.await() } + assertTrue(error.isNotEmpty(), "Should get a non-empty error message") + println("PASS: Invalid SQL rejected: ${error.take(80)}") + conn.disconnect() + } + } + + @Test + fun subscribeToNonExistentTable() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subError.completeExceptionally(AssertionError("Should not apply")) } + .onError { err -> subError.complete(err) } + .subscribe("SELECT * FROM nonexistent_table_xyz") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + val error = withTimeout(5000.milliseconds) { subError.await() } + assertTrue(error.isNotEmpty()) + println("PASS: Non-existent table rejected: ${error.take(80)}") + conn.disconnect() + } + } + + @Test + fun multipleIndependentSubscriptions() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val sub1Applied = CompletableDeferred() + val sub2Applied = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { sub1Applied.complete(Unit) } + .subscribe("SELECT * FROM player") + + c.subscriptionBuilder() + .onApplied { sub2Applied.complete(Unit) } + .subscribe("SELECT * FROM person") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { sub1Applied.await() } + withTimeout(5000.milliseconds) { sub2Applied.await() } + println("PASS: Two independent subscriptions applied concurrently") + conn.disconnect() + } + } + + @Test + fun subscribeToAllTables() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribeToAllTables() + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + println("PASS: subscribeToAllTables (SELECT * FROM *) applied") + conn.disconnect() + } + } + + // ──────── Reducer edge cases ──────── + + @Test + fun callNonExistentReducer() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val result = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + + conn.callReducer("nonexistent_reducer_xyz", byteArrayOf()) { r -> + result.complete(r) + } + + val res = withTimeout(5000.milliseconds) { result.await() } + // Should get an error outcome, not a crash + assertNotNull(res) + println("PASS: Non-existent reducer returned: ${res.outcome::class.simpleName}") + conn.disconnect() + } + } + + @Test + fun callReducerWithEmptyArgs() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val result = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + + // add_player expects a String arg — empty args should cause an error + conn.callReducer("add_player", byteArrayOf()) { r -> + result.complete(r) + } + + val res = withTimeout(5000.milliseconds) { result.await() } + assertNotNull(res) + // Should be an error since args don't match expected schema + println("PASS: Empty args to add_player returned: ${res.outcome::class.simpleName}") + conn.disconnect() + } + } + + // ──────── One-off query edge cases ──────── + + @Test + fun oneOffQueryInvalidSql() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000.milliseconds) { connected.await() } + val result = withTimeout(5000.milliseconds) { c.oneOffQuery("INVALID SQL QUERY!!!") } + assertNotNull(result.error, "Should return an error for invalid SQL") + assertNull(result.rows) + println("PASS: Invalid SQL one-off query returned error: ${result.error.take(80)}") + conn.disconnect() + } + } + + @Test + fun oneOffQueryEmptyResult() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000.milliseconds) { connected.await() } + // Query with impossible WHERE clause + val result = withTimeout(5000.milliseconds) { c.oneOffQuery("SELECT * FROM player WHERE id = 999999999") } + if (result.error != null) { + println("PASS: Empty result query returned error: ${result.error}") + } else { + val rows = result.rows?.tables?.flatMap { it.rows.decodeRows() } ?: emptyList() + println("PASS: Empty result query returned ${rows.size} rows") + } + conn.disconnect() + } + } + + // ──────── Token reuse ──────── + + @Test + fun reconnectWithSavedToken() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + // First connection: get identity and token + val firstConnect = CompletableDeferred>() + val conn1 = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, id, token -> firstConnect.complete(Pair(id, token)) } + .onConnectError { e -> firstConnect.completeExceptionally(e) } + .build() + + val (firstIdentity, token) = withTimeout(5000.milliseconds) { firstConnect.await() } + conn1.disconnect() + + // Second connection: reuse the token + val secondConnect = CompletableDeferred>() + val conn2 = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withToken(token) + .onConnect { _, id, newToken -> secondConnect.complete(Pair(id, newToken)) } + .onConnectError { e -> secondConnect.completeExceptionally(e) } + .build() + + val (secondIdentity, _) = withTimeout(5000.milliseconds) { secondConnect.await() } + assertEquals(firstIdentity, secondIdentity, "Same token should yield same identity") + println("PASS: Token reuse preserved identity: ${firstIdentity.toHex().take(16)}...") + conn2.disconnect() + } + } + + // ──────── Rapid fire operations ──────── + + @Test + fun rapidReducerCallsWithCallbacks() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val targetCount = 20 + val results = java.util.concurrent.ConcurrentHashMap() + val allDone = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + + val elapsed = measureTime { + repeat(targetCount) { i -> + val w = BsatnWriter(64) + w.writeString("Rapid_${System.currentTimeMillis()}_$i") + conn.callReducer("add_player", w.toByteArray()) { result -> + results[result.requestId] = result + if (results.size >= targetCount && !allDone.isCompleted) { + allDone.complete(Unit) + } + } + } + withTimeout(15000.milliseconds) { allDone.await() } + } + + assertEquals(targetCount, results.size, "All $targetCount callbacks should fire") + // Verify all got unique requestIds + assertEquals(targetCount, results.keys.size) + println("PASS: $targetCount rapid reducer calls all received callbacks in ${elapsed.inWholeMilliseconds}ms") + conn.disconnect() + } + } + + // ──────── Connection state transitions ──────── + + @Test + fun connectionStateTransitions() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val disconnected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onDisconnect { _, _ -> disconnected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + // Before connect completes, state should be CONNECTING or CONNECTED + val earlyState = conn.connectionState.value + assertTrue( + earlyState == ConnectionState.CONNECTING || earlyState == ConnectionState.CONNECTED, + "Early state should be CONNECTING or CONNECTED, got $earlyState" + ) + + withTimeout(5000.milliseconds) { connected.await() } + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + assertTrue(conn.isActive) + + conn.disconnect() + assertEquals(ConnectionState.DISCONNECTED, conn.connectionState.value) + assertFalse(conn.isActive) + + // Identity should still be available after disconnect + assertNotNull(conn.identity, "Identity should persist after disconnect") + + println("PASS: State transitions: CONNECTING -> CONNECTED -> DISCONNECTED") + } + } + + // ──────── Identity null before connect ──────── + + @Test + fun identityNullBeforeConnect() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + // Identity/connectionId/token should be null before InitialConnection arrives + // (This is a best-effort check — the connect could be very fast) + // We mainly verify they're non-null after connect + withTimeout(5000.milliseconds) { connected.await() } + + assertNotNull(conn.identity) + assertNotNull(conn.connectionId) + assertNotNull(conn.savedToken) + println("PASS: Identity, connectionId, and token all non-null after connect") + conn.disconnect() + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt new file mode 100644 index 00000000000..73cd9f48ef3 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt @@ -0,0 +1,291 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.measureTime + +/** + * Live integration tests against a local SpacetimeDB server. + * + * Prerequisites: + * 1. `spacetime start` running on localhost:3000 + * 2. Test module published: `spacetime publish --server local -p kotlin-sdk-test` + * + * Set `SPACETIMEDB_TEST=1` to enable. Skipped by default in CI. + */ +class LiveIntegrationTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "kotlin-sdk-test" + + private fun skipIfNoServer() { + if (System.getenv("SPACETIMEDB_TEST") != "1") { + println("SKIP: Set SPACETIMEDB_TEST=1 to run live integration tests") + return + } + } + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + @Test + fun connectAndReceiveIdentity() { + if (!shouldRun()) { + println("SKIP: Set SPACETIMEDB_TEST=1"); return + } + + runBlocking { + val connected = CompletableDeferred>() + val disconnected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, id, token -> + connected.complete(Triple(c, id, token)) + } + .onDisconnect { _, err -> disconnected.complete(err) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val (_, identity, token) = withTimeout(5000.milliseconds) { connected.await() } + + assertNotNull(identity, "Should receive an identity") + assertTrue(identity.bytes.size == 32, "Identity should be 32 bytes") + assertTrue(token.isNotEmpty(), "Should receive an auth token") + assertNotNull(conn.connectionId, "Should have a connectionId") + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + + println("PASS: Connected as ${identity.toHex().take(16)}...") + println(" Token: ${token.take(20)}...") + println(" ConnectionId: ${conn.connectionId}") + + conn.disconnect() + } + } + + @Test + fun subscribeAndReceiveRows() { + if (!shouldRun()) { + println("SKIP: Set SPACETIMEDB_TEST=1"); return + } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + + println("PASS: Subscription to 'SELECT * FROM player' applied successfully") + + conn.disconnect() + } + } + + @Test + fun callReducerAndObserveInsert() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val insertReceived = CompletableDeferred() + val reducerResult = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.table("person").onInsert { row -> + insertReceived.complete(row) + } + + c.subscriptionBuilder() + .onApplied { + subApplied.complete(Unit) + } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM person") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + + // Call the add reducer with name and age arguments + val name = "KotlinSDK_${System.currentTimeMillis()}" + val argsWriter = BsatnWriter(64) + argsWriter.writeString(name) + argsWriter.writeU8(25u) + + conn.callReducer("add", argsWriter.toByteArray()) { result -> + reducerResult.complete(result) + } + + val row = withTimeout(5000.milliseconds) { insertReceived.await() } + assertTrue(row.isNotEmpty(), "Should receive inserted row bytes") + + val result = withTimeout(5000.milliseconds) { reducerResult.await() } + assertNotNull(result, "Should receive reducer result") + + println("PASS: Called add('$name', 25)") + println(" Received insert: ${row.size} bytes") + println(" Reducer result: ${result.outcome}") + + conn.disconnect() + } + } + + @Test + fun multipleReducerCallsPerformance() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val insertCount = java.util.concurrent.atomic.AtomicInteger(0) + val targetCount = 50 + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.table("person").onInsert { insertCount.incrementAndGet() } + + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM person") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { connected.await() } + withTimeout(5000.milliseconds) { subApplied.await() } + + // Fire N reducer calls and measure round-trip time + val elapsed = measureTime { + repeat(targetCount) { i -> + val w = BsatnWriter(64) + w.writeString("Batch_${System.currentTimeMillis()}_$i") + w.writeU8(25u) + conn.callReducer("add", w.toByteArray()) + } + + // Wait for all inserts to arrive + withTimeout(15000.milliseconds) { + while (insertCount.get() < targetCount) { + delay(50.milliseconds) + } + } + } + + assertTrue(insertCount.get() >= targetCount, "Should receive all $targetCount inserts") + val avgMs = elapsed.inWholeMilliseconds.toDouble() / targetCount + println("PASS: $targetCount reducer calls + round-trip in ${elapsed.inWholeMilliseconds}ms") + println(" Avg round-trip: ${"%.1f".format(avgMs)}ms per call") + + conn.disconnect() + } + } + + @Test + fun oneOffQueryExecution() { + if (!shouldRun()) { + println("SKIP: Set SPACETIMEDB_TEST=1"); return + } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000.milliseconds) { connected.await() } + + val elapsed = measureTime { + val result = withTimeout(5000.milliseconds) { + c.oneOffQuery("SELECT * FROM player") + } + if (result.error != null) { + println(" Query returned error: ${result.error}") + } else { + val rows = result.rows?.tables?.flatMap { it.rows.decodeRows() } ?: emptyList() + println("PASS: One-off query returned ${rows.size} player rows") + } + } + println(" Query time: ${elapsed.inWholeMilliseconds}ms") + + conn.disconnect() + } + } + + @Test + fun reconnectionAfterDisconnect() { + if (!shouldRun()) { + println("SKIP: Set SPACETIMEDB_TEST=1"); return + } + + runBlocking { + var connectCount = 0 + val firstConnect = CompletableDeferred() + val secondConnect = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withReconnectPolicy(ReconnectPolicy(maxRetries = 3, initialDelayMs = 500)) + .onConnect { _, _, _ -> + connectCount++ + if (connectCount == 1) firstConnect.complete(Unit) + else secondConnect.complete(Unit) + } + .onConnectError { e -> firstConnect.completeExceptionally(e) } + .build() + + withTimeout(5000.milliseconds) { firstConnect.await() } + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + println("PASS: First connection established") + + // We can't easily force a server-side disconnect from the client, + // so we just verify the reconnect policy is wired up correctly + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + println("PASS: Reconnect policy configured (maxRetries=3, initialDelay=500ms)") + + conn.disconnect() + assertEquals(ConnectionState.DISCONNECTED, conn.connectionState.value) + println("PASS: Clean disconnect") + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt new file mode 100644 index 00000000000..7555f95c5e2 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt @@ -0,0 +1,461 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.time.measureTime + +/** + * Performance benchmarks for the core SDK machinery. + * + * These validate throughput and latency of: + * - BSATN serialization/deserialization + * - ClientCache insert/delete/update operations + * - Full ServerMessage decode pipeline + * - Gzip decompression throughput + * + * All tests run offline — no server required. + */ +class PerformanceBenchmarkTest { + + // ───────────────────────────── BSATN ───────────────────────────── + + @Test + fun bsatnWriteThroughput() { + val iterations = 100_000 + // Simulate writing a "player row": u64 id, string name, i32 x, i32 y, f64 health + val elapsed = measureTime { + repeat(iterations) { + val w = BsatnWriter(64) + w.writeU64(it.toULong()) + w.writeString("Player_$it") + w.writeI32(it * 10) + w.writeI32(it * -5) + w.writeF64(100.0 - (it % 100)) + w.toByteArray() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("BSATN write: ${iterations} rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + // Sanity: should do at least 100k rows/sec on any modern machine + assertTrue(elapsed.inWholeMilliseconds < 5000, "BSATN write too slow: ${elapsed.inWholeMilliseconds}ms") + } + + @Test + fun bsatnReadThroughput() { + val iterations = 100_000 + // Pre-encode rows + val rows = Array(iterations) { i -> + val w = BsatnWriter(64) + w.writeU64(i.toULong()) + w.writeString("Player_$i") + w.writeI32(i * 10) + w.writeI32(i * -5) + w.writeF64(100.0 - (i % 100)) + w.toByteArray() + } + + val elapsed = measureTime { + for (data in rows) { + val r = BsatnReader(data) + r.readU64() + r.readString() + r.readI32() + r.readI32() + r.readF64() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("BSATN read: ${iterations} rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "BSATN read too slow: ${elapsed.inWholeMilliseconds}ms") + } + + @Test + fun bsatnRoundTripIntegrity() { + // Verify data survives write → read for every primitive type + val w = BsatnWriter(256) + w.writeBool(true) + w.writeBool(false) + w.writeU8(255u) + w.writeI8(-128) + w.writeU16(65535u) + w.writeI16(-32768) + w.writeU32(UInt.MAX_VALUE) + w.writeI32(Int.MIN_VALUE) + w.writeU64(ULong.MAX_VALUE) + w.writeI64(Long.MIN_VALUE) + w.writeF32(3.14f) + w.writeF64(2.718281828459045) + w.writeString("Hello, SpacetimeDB! 🚀") + w.writeByteArray(byteArrayOf(0xCA.toByte(), 0xFE.toByte())) + + val r = BsatnReader(w.toByteArray()) + assertEquals(true, r.readBool()) + assertEquals(false, r.readBool()) + assertEquals(255.toUByte(), r.readU8()) + assertEquals((-128).toByte(), r.readI8()) + assertEquals(65535.toUShort(), r.readU16()) + assertEquals((-32768).toShort(), r.readI16()) + assertEquals(UInt.MAX_VALUE, r.readU32()) + assertEquals(Int.MIN_VALUE, r.readI32()) + assertEquals(ULong.MAX_VALUE, r.readU64()) + assertEquals(Long.MIN_VALUE, r.readI64()) + assertEquals(3.14f, r.readF32()) + assertEquals(2.718281828459045, r.readF64()) + assertEquals("Hello, SpacetimeDB! 🚀", r.readString()) + val bytes = r.readByteArray() + assertEquals(0xCA.toByte(), bytes[0]) + assertEquals(0xFE.toByte(), bytes[1]) + assertTrue(r.isExhausted, "Reader should be fully consumed") + } + + // ───────────────────────── Client Cache ────────────────────────── + + @Test + fun cacheInsertThroughput() { + val cache = ClientCache() + val table = cache.getOrCreateTable("players") + val rowCount = 50_000 + // Pre-generate unique rows + val rows = Array(rowCount) { i -> + val w = BsatnWriter(32) + w.writeU64(i.toULong()) + w.writeString("P$i") + w.toByteArray() + } + + val elapsed = measureTime { + for (row in rows) { + table.insertRow(row) + } + } + assertEquals(rowCount, table.count) + val opsPerSec = rowCount / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Cache insert: $rowCount rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Cache insert too slow") + } + + @Test + fun cacheDeleteThroughput() { + val cache = ClientCache() + val table = cache.getOrCreateTable("players") + val rowCount = 50_000 + val rows = Array(rowCount) { i -> + val w = BsatnWriter(32) + w.writeU64(i.toULong()) + w.writeString("P$i") + w.toByteArray() + } + for (row in rows) table.insertRow(row) + + val elapsed = measureTime { + for (row in rows) { + table.deleteRow(row) + } + } + assertEquals(0, table.count) + val opsPerSec = rowCount / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Cache delete: $rowCount rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Cache delete too slow") + } + + @Test + fun cacheRefCountingCorrectness() { + // Overlapping subscriptions: same row inserted twice, deleted once → still present + val table = TableCache("test") + val row = byteArrayOf(1, 2, 3) + table.insertRow(row) + table.insertRow(row) // refCount = 2 + assertEquals(1, table.count, "Same row should not duplicate") + table.deleteRow(row) // refCount = 1 + assertEquals(1, table.count, "Row should remain with refCount > 0") + assertTrue(table.containsRow(row)) + table.deleteRow(row) // refCount = 0 + assertEquals(0, table.count, "Row should be removed at refCount 0") + } + + @Test + fun cacheTransactionUpdatePerformance() { + val cache = ClientCache() + // Pre-populate with 10k rows + val table = cache.getOrCreateTable("entities") + val existingRows = Array(10_000) { i -> + val w = BsatnWriter(16) + w.writeU64(i.toULong()) + w.writeI32(i) + w.toByteArray() + } + for (row in existingRows) table.insertRow(row) + + // Simulate a transaction: delete 1000 rows, insert 1000 new, update 500 + val deleteRows = existingRows.take(1500) // 1000 pure deletes + 500 updates + val updateNewRows = Array(500) { i -> + val w = BsatnWriter(16) + w.writeU64(i.toULong()) // same key as deleted + w.writeI32(i + 999_999) // different value + w.toByteArray() + } + val insertRows = Array(1000) { i -> + val w = BsatnWriter(16) + w.writeU64((20_000 + i).toULong()) + w.writeI32(i) + w.toByteArray() + } + + // Build the BsatnRowList payloads + val deletePayload = buildRowListPayload(deleteRows.toList()) + val insertPayload = buildRowListPayload(updateNewRows.toList() + insertRows.toList()) + + val qsUpdate = buildQuerySetUpdate("entities", insertPayload, deletePayload) + val elapsed = measureTime { + cache.applyTransactionUpdate(listOf(qsUpdate)) + } + + // Expected: 10000 - 1000 pure deletes + 1000 new inserts = 10000 (500 updates are in-place) + println("Transaction update: 2500 ops in ${elapsed.inWholeMilliseconds}ms") + assertTrue(elapsed.inWholeMilliseconds < 2000, "Transaction update too slow") + } + + // ──────────────────── Protocol Decode Pipeline ─────────────────── + + @Test + fun initialConnectionDecodePerformance() { + // Build a valid InitialConnection message + val w = BsatnWriter(256) + w.writeTag(0u) // InitialConnection tag + w.writeBytes(ByteArray(32) { it.toByte() }) // identity + w.writeBytes(ByteArray(16) { it.toByte() }) // connectionId + w.writeString("test-token-abc123") + val payload = w.toByteArray() + + val iterations = 50_000 + val elapsed = measureTime { + repeat(iterations) { + val msg = ServerMessage.decode(payload) + assertTrue(msg is ServerMessage.InitialConnection) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("InitialConnection decode: $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Decode too slow") + } + + @Test + fun subscribeAppliedDecodeWithRows() { + // Build a SubscribeApplied with 100 rows across 2 tables + val w = BsatnWriter(4096) + w.writeTag(1u) // SubscribeApplied + w.writeU32(42u) // requestId + w.writeU32(7u) // querySetId + + // QueryRows: array of SingleTableRows + w.writeU32(1u) // 1 table + w.writeString("players") // table name + // BsatnRowList: RowSizeHint (tag + data) + length-prefixed row bytes + val rowSize = 12 // u64 + i32 + val rowCount = 100 + w.writeTag(0u) // RowSizeHint::FixedSize + w.writeU16(rowSize.toUShort()) + // Row data as a length-prefixed byte array + w.writeU32((rowSize * rowCount).toUInt()) + repeat(rowCount) { i -> + // Each row: u64 id, i32 score + for (b in 0 until 8) w.writeI8(((i shr (b * 8)) and 0xFF).toByte()) + w.writeI32(i * 100) + } + + val payload = w.toByteArray() + + val iterations = 10_000 + val elapsed = measureTime { + repeat(iterations) { + val msg = ServerMessage.decode(payload) + assertTrue(msg is ServerMessage.SubscribeApplied) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("SubscribeApplied decode (100 rows): $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 10000, "SubscribeApplied decode too slow") + } + + @Test + fun clientMessageEncodeThroughput() { + val iterations = 100_000 + val elapsed = measureTime { + repeat(iterations) { i -> + val msg = ClientMessage.CallReducer( + requestId = i.toUInt(), + reducer = "set_position", + args = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8), + ) + msg.encode() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("CallReducer encode: $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Encode too slow") + } + + // ──────────────────────── Gzip Decompression ───────────────────── + + @Test + fun gzipDecompressionThroughput() { + // Compress a realistic payload (1KB of row data) then benchmark decompression + val payload = ByteArray(1024) { (it % 256).toByte() } + val compressed = compressGzip(payload) + println("Gzip: ${payload.size} bytes → ${compressed.size} bytes (${compressed.size * 100 / payload.size}%)") + + val iterations = 50_000 + val elapsed = measureTime { + repeat(iterations) { + val decompressed = decompressGzip(compressed) + assertEquals(payload.size, decompressed.size) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Gzip decompress: $iterations x ${compressed.size}B in ${elapsed.inWholeMilliseconds}ms ($opsPerSec ops/sec)") + assertTrue(elapsed.inWholeMilliseconds < 10000, "Gzip decompression too slow") + } + + @Test + fun gzipLargePayloadDecompression() { + // Simulate a large SubscribeApplied (100KB) + val payload = ByteArray(100_000) { (it % 256).toByte() } + val compressed = compressGzip(payload) + println("Gzip large: ${payload.size} bytes → ${compressed.size} bytes") + + val iterations = 1_000 + val elapsed = measureTime { + repeat(iterations) { + val result = decompressGzip(compressed) + assertEquals(payload.size, result.size) + } + } + val mbPerSec = (payload.size.toLong() * iterations / 1024 / 1024) / + elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Gzip large decompress: $iterations x ${payload.size / 1024}KB in ${elapsed.inWholeMilliseconds}ms ($mbPerSec MB/sec)") + assertTrue(elapsed.inWholeMilliseconds < 10000, "Large gzip decompression too slow") + } + + // ──────────────────── Callback System ──────────────────────────── + + @Test + fun tableHandleCallbackPerformance() { + val handle = TableHandle("test") + var insertCount = 0 + var deleteCount = 0 + var updateCount = 0 + + // Register multiple callbacks + repeat(10) { + handle.onInsert { insertCount++ } + handle.onDelete { deleteCount++ } + handle.onUpdate { _, _ -> updateCount++ } + } + + val row = byteArrayOf(1, 2, 3, 4) + val iterations = 100_000 + val elapsed = measureTime { + repeat(iterations) { + handle.fireInsert(row) + handle.fireDelete(row) + handle.fireUpdate(row, row) + } + } + assertEquals(iterations * 10, insertCount) + assertEquals(iterations * 10, deleteCount) + assertEquals(iterations * 10, updateCount) + println("Callbacks: ${iterations * 3} fires (10 listeners each) in ${elapsed.inWholeMilliseconds}ms") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Callbacks too slow") + } + + @Test + fun callbackRegistrationAndRemoval() { + val handle = TableHandle("test") + var count = 0 + val ids = mutableListOf() + + // Register 100 callbacks that all increment count + repeat(100) { + ids.add(handle.onInsert { count++ }) + } + + // Remove every other one (50 removed, 50 remain) + for (i in ids.indices step 2) { + handle.removeOnInsert(ids[i]) + } + + handle.fireInsert(byteArrayOf(1)) + assertEquals(50, count, "Should have 50 callbacks remaining") + } + + // ──────────────────── End-to-End Message Flow ──────────────────── + + @Test + fun fullMessageRoundTrip() { + // Encode a Subscribe message, verify it round-trips through binary + val subscribe = ClientMessage.Subscribe( + requestId = 1u, + querySetId = QuerySetId(42u), + queryStrings = listOf("SELECT * FROM players", "SELECT * FROM items WHERE owner_id = 7"), + ) + val encoded = subscribe.encode() + assertTrue(encoded.isNotEmpty()) + + // Decode it back manually + val reader = BsatnReader(encoded) + assertEquals(0, reader.readTag().toInt()) // Subscribe tag + assertEquals(1u, reader.readU32()) // requestId + assertEquals(42u, reader.readU32()) // querySetId + val queryCount = reader.readU32().toInt() + assertEquals(2, queryCount) + assertEquals("SELECT * FROM players", reader.readString()) + assertEquals("SELECT * FROM items WHERE owner_id = 7", reader.readString()) + assertTrue(reader.isExhausted) + } + + // ──────────────────── Helpers ──────────────────────────────────── + + private fun compressGzip(data: ByteArray): ByteArray { + val bos = java.io.ByteArrayOutputStream() + java.util.zip.GZIPOutputStream(bos).use { it.write(data) } + return bos.toByteArray() + } + + private fun buildRowListPayload(rows: List): ByteArray { + val w = BsatnWriter(256) + w.writeTag(0u) // RowSizeHint::FixedSize + if (rows.isEmpty()) { + w.writeU16(0u) + w.writeU32(0u) // empty data + return w.toByteArray() + } + val rowSize = rows.first().size + w.writeU16(rowSize.toUShort()) + w.writeU32((rowSize * rows.size).toUInt()) // length-prefixed data + for (row in rows) w.writeBytes(row) + return w.toByteArray() + } + + private fun buildQuerySetUpdate( + tableName: String, + insertPayload: ByteArray, + deletePayload: ByteArray, + ): QuerySetUpdate { + // Encode to BSATN and decode — ensures we go through the real codec + val w = BsatnWriter(insertPayload.size + deletePayload.size + 256) + w.writeU32(1u) // querySetId + w.writeU32(1u) // 1 table + w.writeString(tableName) + w.writeU32(1u) // 1 row update block + w.writeTag(0u) // TableUpdateRows::PersistentTable + // PersistentTableRows: inserts then deletes (each is a full BsatnRowList) + w.writeBytes(insertPayload) + w.writeBytes(deletePayload) + + return QuerySetUpdate.read(BsatnReader(w.toByteArray())) + } +}