From 47991aa3773ecca2f05a7ad050039a1b503de092 Mon Sep 17 00:00:00 2001 From: Martin Hughes Date: Sat, 6 Jun 2026 21:28:39 +0100 Subject: [PATCH] Make WrappedObject Send + Sync using RwLock This changes the WrappedObject interface, so will need a version bump. The aim is to remove the unsafe Send and Sync implementations, to: - Make sure WrappedObject really is Send / Sync - Make it easier for users to use WrappedObject in a thread-safe way. --- Cargo.toml | 1 + src/aml/mod.rs | 658 ++++++++++++++++++++------------ src/aml/namespace.rs | 7 +- src/aml/object.rs | 113 +++--- src/aml/pci_routing.rs | 34 +- src/aml/resource.rs | 5 +- tests/static_assertions.rs | 6 + tools/aml_test_tools/src/lib.rs | 5 +- 8 files changed, 498 insertions(+), 331 deletions(-) create mode 100644 tests/static_assertions.rs diff --git a/Cargo.toml b/Cargo.toml index e429a7fe..466698a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ pci_types = { version = "0.10.0", public = true } byteorder = { version = "1.5.0", default-features = false } [dev-dependencies] +static_assertions = "1.1.0" aml_test_tools = { path = "tools/aml_test_tools" } pretty_env_logger = "0.5.0" diff --git a/src/aml/mod.rs b/src/aml/mod.rs index 65707db8..0486ba0c 100644 --- a/src/aml/mod.rs +++ b/src/aml/mod.rs @@ -57,7 +57,6 @@ use object::{ FieldUpdateRule, MethodFlags, Object, - ObjectToken, ObjectType, ReferenceKind, WrappedObject, @@ -104,7 +103,6 @@ where { handler: H, pub namespace: Spinlock, - pub object_token: Spinlock, integer_size: IntegerSize, region_handlers: Spinlock>>, @@ -139,7 +137,6 @@ where Interpreter { handler, namespace: Spinlock::new(Namespace::new(global_lock_mutex)), - object_token: Spinlock::new(unsafe { ObjectToken::create_interpreter_token() }), integer_size: IntegerSize::from_revision(dsdt_revision), region_handlers: Spinlock::new(BTreeMap::new()), global_lock_mutex, @@ -204,14 +201,19 @@ where trace!("Invoking AML method: {}", path); let object = self.namespace.lock().get(path.clone())?.clone(); - match &*object { - Object::Method { .. } => { - self.namespace.lock().add_level(path.clone(), NamespaceLevelKind::MethodLocals)?; - let context = MethodContext::new_from_method(object, args, path)?; - self.do_execute_method(context) - } - Object::NativeMethod { f, .. } => f(&args), - _ => Ok(object), + let read = object.read(); + if let Object::Method { .. } = &*read { + drop(read); + self.namespace.lock().add_level(path.clone(), NamespaceLevelKind::MethodLocals)?; + let context = MethodContext::new_from_method(object, args, path)?; + self.do_execute_method(context) + } else if let Object::NativeMethod { f, .. } = &*read { + let f = f.clone(); + drop(read); + f(&args) + } else { + drop(read); + Ok(object) } } @@ -285,8 +287,11 @@ where .evaluate_if_present(AmlName::from_str("_STA").unwrap().resolve(path)?, vec![]) { Ok(Some(result)) => { - let Object::Integer(result) = *result else { panic!() }; - let status = DeviceStatus(result); + let result_int = match &*result.read() { + Object::Integer(v) => *v, + _ => panic!(), + }; + let status = DeviceStatus(result_int); status.present() && status.functioning() } Ok(None) => true, @@ -452,22 +457,24 @@ where Opcode::Increment | Opcode::Decrement => { let [Argument::Object(operand)] = &op.arguments[..] else { panic!() }; let operand = operand.clone().unwrap_transparent_reference(); - let token = self.object_token.lock(); - - let Object::Integer(operand) = (unsafe { operand.gain_mut(&token) }) else { - Err(AmlError::ObjectNotOfExpectedType { + let mut operand_write = operand.write(); + + let new_value = match &mut *operand_write { + Object::Integer(val) => { + let new = match op.op { + Opcode::Increment => val.wrapping_add(1), + Opcode::Decrement => val.wrapping_sub(1), + _ => unreachable!(), + }; + *val = new; + new + } + other => Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::Integer, - got: operand.typ(), - })? + got: other.typ(), + })?, }; - let new_value = match op.op { - Opcode::Increment => operand.wrapping_add(1), - Opcode::Decrement => operand.wrapping_sub(1), - _ => unreachable!(), - }; - - *operand = new_value; context.contribute_arg(Argument::Object(Object::Integer(new_value).wrap())); context.retire_op(op); } @@ -494,11 +501,17 @@ where Argument::Object(source2), Argument::Object(target) ]); - let source1 = source1.as_buffer()?; - let source2 = source2.as_buffer()?; + let source1_buf = { + let source1_read = source1.read(); + source1_read.as_buffer()?.to_vec() + }; + let source2_buf = { + let source2_read = source2.read(); + source2_read.as_buffer()?.to_vec() + }; let result = { - let mut buffer = Vec::from(source1); - buffer.extend_from_slice(source2); + let mut buffer = source1_buf; + buffer.extend_from_slice(&source2_buf); // Add a new end-tag buffer.push(0x78); // Don't calculate the new real checksum - just use 0 @@ -513,35 +526,47 @@ where Opcode::Reset => { extract_args!(op => [Argument::Object(sync_object)]); let sync_object = sync_object.clone().unwrap_reference(); + let sync_read = sync_object.read(); - if let Object::Event(ref counter) = *sync_object { + if let Object::Event(ref counter) = *sync_read { counter.store(0, Ordering::Release); } else { return Err(AmlError::InvalidOperationOnObject { op: Operation::ResetEvent, - typ: sync_object.typ(), + typ: sync_read.typ(), }); } } Opcode::Signal => { extract_args!(op => [Argument::Object(sync_object)]); let sync_object = sync_object.clone().unwrap_reference(); + let sync_read = sync_object.read(); - if let Object::Event(ref counter) = *sync_object { + if let Object::Event(ref counter) = *sync_read { counter.fetch_add(1, Ordering::AcqRel); } else { return Err(AmlError::InvalidOperationOnObject { op: Operation::SignalEvent, - typ: sync_object.typ(), + typ: sync_read.typ(), }); } } Opcode::Wait => { extract_args!(op => [Argument::Object(sync_object), Argument::Object(timeout)]); let sync_object = sync_object.clone().unwrap_reference(); - let timeout = u64::min(timeout.as_integer()?, 0xffff); + let timeout = u64::min(timeout.read().as_integer()?, 0xffff); + + let sync_read = sync_object.read(); + let counter = if let Object::Event(counter) = &*sync_read { + counter.clone() + } else { + return Err(AmlError::InvalidOperationOnObject { + op: Operation::WaitEvent, + typ: sync_read.typ(), + }); + }; - if let Object::Event(ref counter) = *sync_object { + { /* * `Wait` returns a non-zero value if a timeout occurs and the event * was not signaled, and zero if it was. Timeout is specified in @@ -582,17 +607,12 @@ where context.contribute_arg(Argument::Object( Object::Integer(if timed_out { u64::MAX } else { 0 }).wrap(), )); - } else { - return Err(AmlError::InvalidOperationOnObject { - op: Operation::WaitEvent, - typ: sync_object.typ(), - }); } } Opcode::Notify => { // TODO: may need special handling on the node to get path? extract_args!(op => [Argument::Namestring(name), Argument::Object(value)]); - let value = value.as_integer()?; + let value = value.read().as_integer()?; info!("Notify {:?} with value {}", name, value); // TODO: support @@ -608,7 +628,7 @@ where } Opcode::Fatal => { extract_args!(op => [Argument::ByteData(typ), Argument::DWordData(code), Argument::Object(arg)]); - let arg = arg.as_integer()?; + let arg = arg.read().as_integer()?; self.handler.handle_fatal_error(*typ, *code, arg); context.retire_op(op); return Err(AmlError::FatalErrorEncountered); @@ -625,8 +645,8 @@ where let region = Object::OpRegion(OpRegion { space: RegionSpace::from(*region_space), - base: region_offset.as_integer()?, - length: region_length.as_integer()?, + base: region_offset.read().as_integer()?, + length: region_length.read().as_integer()?, parent_device_path: context.current_scope.clone(), }); self.namespace.lock().insert(name.resolve(&context.current_scope)?, region.wrap())?; @@ -639,9 +659,9 @@ where Argument::Object(oem_id), Argument::Object(oem_table_id), ]); - let _signature = signature.as_string()?; - let _oem_id = oem_id.as_string()?; - let _oem_table_id = oem_table_id.as_string()?; + signature.read().as_string()?; + oem_id.read().as_string()?; + oem_table_id.read().as_string()?; // TODO: once this is integrated into the rest of the crate, load the table log::warn!( @@ -663,7 +683,8 @@ where Argument::PkgLength(pkg_length), Argument::Object(buffer_size), ]); - let buffer_size = buffer_size.clone().unwrap_transparent_reference().as_integer()?; + let buffer_size = + buffer_size.clone().unwrap_transparent_reference().read().as_integer()?; let buffer_len = pkg_length - (context.current_block.pc - start_pc); let mut buffer = vec![0; buffer_size as usize]; @@ -707,7 +728,7 @@ where Opcode::VarPackage => { extract_args!(op[0..1] => [Argument::Object(total_elements)]); let total_elements = - total_elements.clone().unwrap_transparent_reference().as_integer()? as usize; + total_elements.clone().unwrap_transparent_reference().read().as_integer()? as usize; let mut elements = Vec::with_capacity(total_elements); for arg in &op.arguments[1..] { @@ -735,7 +756,7 @@ where Argument::PkgLength(then_length), Argument::Object(predicate), ]); - let predicate = predicate.as_integer()?; + let predicate = predicate.read().as_integer()?; let remaining_then_length = then_length - (context.current_block.pc - start_pc); if predicate > 0 { @@ -768,7 +789,7 @@ where | opcode @ Opcode::CreateQWordField => { extract_args!(op => [Argument::Object(buffer), Argument::Object(index)]); let name = context.namestring()?; - let index = index.as_integer()?; + let index = index.read().as_integer()?; let (offset, length) = match opcode { Opcode::CreateBitField => (index, 1), Opcode::CreateByteField => (index * 8, 8), @@ -786,8 +807,8 @@ where Opcode::CreateField => { extract_args!(op => [Argument::Object(buffer), Argument::Object(bit_index), Argument::Object(num_bits)]); let name = context.namestring()?; - let bit_index = bit_index.as_integer()?; - let num_bits = num_bits.as_integer()?; + let bit_index = bit_index.read().as_integer()?; + let num_bits = num_bits.read().as_integer()?; self.namespace.lock().insert( name.resolve(&context.current_scope)?, @@ -819,30 +840,36 @@ where } Opcode::CondRefOf => { extract_args!(op => [Argument::Object(object), Argument::Object(target)]); - let result = if let Object::Reference { kind: ReferenceKind::Unresolved, .. } = **object { - Object::Integer(0) - } else { - let reference = - Object::Reference { kind: ReferenceKind::RefOf, inner: object.clone() }.wrap(); - self.do_store(target.clone(), reference)?; - Object::Integer(u64::MAX) - }; + let result = + if let Object::Reference { kind: ReferenceKind::Unresolved, .. } = &*object.read() { + Object::Integer(0) + } else { + let reference = + Object::Reference { kind: ReferenceKind::RefOf, inner: object.clone() }.wrap(); + self.do_store(target.clone(), reference)?; + Object::Integer(u64::MAX) + }; context.contribute_arg(Argument::Object(result.wrap())); context.retire_op(op); } Opcode::DerefOf => { extract_args!(op => [Argument::Object(object)]); - let result = match **object { - Object::Reference { kind: _, inner: _ } => object.clone().unwrap_reference(), + let obj_read = object.read(); + let result = match &*obj_read { + Object::Reference { .. } => { + drop(obj_read); + object.clone().unwrap_reference() + } Object::String(_) => { - let path = AmlName::from_str(&object.as_string().unwrap())?; + let path = AmlName::from_str(&obj_read.as_string().unwrap())?; + drop(obj_read); let (_, object) = self.namespace.lock().search(&path, &context.current_scope)?; object.clone() } _ => { return Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::Reference, - got: object.typ(), + got: obj_read.typ(), }); } }; @@ -875,48 +902,62 @@ where } Opcode::Sleep => { extract_args!(op => [Argument::Object(msec)]); - self.handler.sleep(msec.as_integer()?); + self.handler.sleep(msec.read().as_integer()?); context.retire_op(op); } Opcode::Stall => { extract_args!(op => [Argument::Object(usec)]); - self.handler.stall(usec.as_integer()?); + self.handler.stall(usec.read().as_integer()?); context.retire_op(op); } Opcode::Acquire => { extract_args!(op => [Argument::Object(mutex)]); - let Object::Mutex { mutex, sync_level: _ } = **mutex else { - Err(AmlError::InvalidOperationOnObject { op: Operation::Acquire, typ: mutex.typ() })? + let mutex_handle = { + let mutex_read = mutex.read(); + let Object::Mutex { mutex: handle, sync_level: _ } = &*mutex_read else { + return Err(AmlError::InvalidOperationOnObject { + op: Operation::Acquire, + typ: mutex_read.typ(), + }); + }; + *handle }; let timeout = context.next_u16()?; // TODO: should we do something with the sync level?? - if mutex == self.global_lock_mutex { + if mutex_handle == self.global_lock_mutex { self.acquire_global_lock(timeout)?; } else { - self.handler.acquire(mutex, timeout)?; + self.handler.acquire(mutex_handle, timeout)?; } context.retire_op(op); } Opcode::Release => { extract_args!(op => [Argument::Object(mutex)]); - let Object::Mutex { mutex, sync_level: _ } = **mutex else { - Err(AmlError::InvalidOperationOnObject { op: Operation::Release, typ: mutex.typ() })? + let mutex_handle = { + let mutex_read = mutex.read(); + let Object::Mutex { mutex: handle, sync_level: _ } = &*mutex_read else { + return Err(AmlError::InvalidOperationOnObject { + op: Operation::Release, + typ: mutex_read.typ(), + }); + }; + *handle }; // TODO: should we do something with the sync level?? - if mutex == self.global_lock_mutex { + if mutex_handle == self.global_lock_mutex { self.release_global_lock()?; } else { - self.handler.release(mutex); + self.handler.release(mutex_handle); } context.retire_op(op); } Opcode::InternalMethodCall => { extract_args!(op[0..2] => [Argument::Object(method), Argument::Namestring(method_scope)]); - let args = op.arguments[2..] + let args: Vec = op.arguments[2..] .iter() .map(|arg| { if let Argument::Object(arg) = arg { @@ -927,21 +968,34 @@ where }) .collect(); - if let Object::Method { .. } = **method { - self.namespace - .lock() - .add_level(method_scope.clone(), NamespaceLevelKind::MethodLocals)?; - - let new_context = - MethodContext::new_from_method(method.clone(), args, method_scope.clone())?; - let old_context = mem::replace(&mut context, new_context); - context_stack.push(old_context); - context.retire_op(op); - } else if let Object::NativeMethod { ref f, .. } = **method { + let native_f = { + let method_read = method.read(); + if let Object::NativeMethod { ref f, .. } = *method_read { + Some(f.clone()) + } else { + None + } + }; + if let Some(f) = native_f { let result = f(&args)?; context.contribute_arg(Argument::Object(result)); } else { - panic!(); + // Must be a regular Method + let method_read = method.read(); + if let Object::Method { .. } = *method_read { + drop(method_read); + self.namespace + .lock() + .add_level(method_scope.clone(), NamespaceLevelKind::MethodLocals)?; + + let new_context = + MethodContext::new_from_method(method.clone(), args, method_scope.clone())?; + let old_context = mem::replace(&mut context, new_context); + context_stack.push(old_context); + context.retire_op(op); + } else { + panic!(); + } } } Opcode::Return => { @@ -967,7 +1021,8 @@ where // (they should return `0`) fn object_type(object: &Object) -> u64 { if let Object::Reference { kind: _, inner } = object { - object_type(&inner) + let inner_read = inner.read(); + object_type(&*inner_read) } else { match object.typ() { ObjectType::Uninitialized => 0, @@ -993,7 +1048,11 @@ where } } - context.contribute_arg(Argument::Object(Object::Integer(object_type(&object)).wrap())); + let result_type = { + let obj_read = object.read(); + object_type(&*obj_read) + }; + context.contribute_arg(Argument::Object(Object::Integer(result_type).wrap())); context.retire_op(op); } Opcode::SizeOf => self.do_size_of(&mut context, op)?, @@ -1006,7 +1065,7 @@ where Argument::Namestring(bank_name), Argument::Object(bank_value), ]); - let bank_value = bank_value.as_integer()?; + let bank_value = bank_value.read().as_integer()?; let field_flags = context.next()?; let (region, bank) = { @@ -1026,7 +1085,7 @@ where * false, skip over the rest of the loop, otherwise carry on. */ extract_args!(op => [Argument::Object(predicate)]); - let predicate = predicate.clone().unwrap_transparent_reference().as_integer()?; + let predicate = predicate.clone().unwrap_transparent_reference().read().as_integer()?; if predicate == 0 { // Exit from the while loop by skipping out of the current block @@ -1109,9 +1168,12 @@ where let Argument::Object(total_elements) = &package_op.arguments[0] else { panic!() }; - let total_elements = - total_elements.clone().unwrap_transparent_reference().as_integer()? - as usize; + let total_elements = total_elements + .clone() + .unwrap_transparent_reference() + .read() + .as_integer()? + as usize; // Update the expected number of arguments to terminate the in-flight op package_op.expected_arguments = package_op.arguments.len(); @@ -1432,7 +1494,9 @@ where (index, data) }; - if let Object::FieldUnit(ref data_fu) = *data { + let data_cloned = data.clone(); + let data_read = data_cloned.read(); + if let Object::FieldUnit(ref data_fu) = *data_read { if data_fu.flags.access_type_bytes()? < FieldFlags(field_flags).access_type_bytes()? { // On ACPICA this causes reads/writes to be truncated to the width of // the data register. @@ -1591,9 +1655,12 @@ where let object = self.namespace.lock().search(&name, &context.current_scope); match object { Ok((resolved_name, object)) => { + let obj_read = object.read(); if let Object::Method { flags, .. } | Object::NativeMethod { flags, .. } = - *object + &*obj_read { + let flags = *flags; + drop(obj_read); context.start(OpInFlight::new_with_dynamic( Opcode::InternalMethodCall, vec![Argument::Object(object), Argument::Namestring(resolved_name)], @@ -1611,13 +1678,17 @@ where ResolveBehaviour::TermArg, ], )) - } else if let Object::FieldUnit(ref field) = *object { - let value = self.do_field_read(field)?; + } else if let Object::FieldUnit(field) = &*obj_read { + let field = field.clone(); + drop(obj_read); + let value = self.do_field_read(&field)?; context.contribute_arg(Argument::Object(value)); - } else if let Object::BufferField { .. } = *object { - let value = object.read_buffer_field(self.integer_size)?; + } else if let Object::BufferField { .. } = &*obj_read { + let value = obj_read.read_buffer_field(self.integer_size)?; + drop(obj_read); context.contribute_arg(Argument::Object(value.wrap())); } else { + drop(obj_read); context.contribute_arg(Argument::Object(object)); } } @@ -1862,8 +1933,8 @@ where extract_args!(op[0..3] => [Argument::Object(left), Argument::Object(right), Argument::Object(target)]); let target2 = if op.op == Opcode::Divide { Some(&op.arguments[3]) } else { None }; - let left = left.clone().unwrap_transparent_reference().to_integer(self.integer_size)?; - let right = right.clone().unwrap_transparent_reference().to_integer(self.integer_size)?; + let left = left.clone().unwrap_transparent_reference().read().to_integer(self.integer_size)?; + let right = right.clone().unwrap_transparent_reference().read().to_integer(self.integer_size)?; let result = match op.op { Opcode::Add => left.wrapping_add(right), @@ -1895,7 +1966,7 @@ where fn do_unary_maths(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(operand)]); - let operand = operand.clone().unwrap_transparent_reference().as_integer()?; + let operand = operand.clone().unwrap_transparent_reference().read().as_integer()?; let result = match op.op { Opcode::FindSetLeftBit => { @@ -1937,7 +2008,7 @@ where fn do_logical_op(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { if op.op == Opcode::LNot { extract_args!(op => [Argument::Object(operand)]); - let operand = operand.clone().unwrap_transparent_reference().as_integer()?; + let operand = operand.clone().unwrap_transparent_reference().read().as_integer()?; let result = if operand == 0 { u64::MAX } else { 0 }; context.contribute_arg(Argument::Object(Object::Integer(result).wrap())); @@ -1952,25 +2023,29 @@ where let mut int_size = self.integer_size; // Make sure both sides are the same type. - let right = match *left { - Object::Integer(_) => &Object::Integer(right.as_integer()?), - Object::String(_) => &Object::String(right.as_string()?.parse().unwrap()), + let left_read = left.read(); + let right_obj: Object = match &*left_read { + Object::Integer(_) => Object::Integer(right.read().as_integer()?), + Object::String(_) => Object::String(right.read().as_string()?.into_owned()), Object::Buffer(_) => { // When doing && or ||, uACPI and NT only compare the first 4 bytes of a buffer. int_size = IntegerSize::FourBytes; - if right.typ() == ObjectType::Buffer { - &*right + let right_read = right.read(); + if right_read.typ() == ObjectType::Buffer { + right_read.clone() } else { - &Object::Buffer(right.to_buffer(self.integer_size)?) + Object::Buffer(right_read.to_buffer(self.integer_size)?) } } - _ => Err(AmlError::InvalidOperationOnObject { op: Operation::LogicalOp, typ: left.typ() })?, + _ => { + return Err(AmlError::InvalidOperationOnObject { op: Operation::LogicalOp, typ: left_read.typ() }); + } }; - let ordering = left.aml_cmp(right); + let ordering = left_read.aml_cmp(&right_obj); let result = match op.op { - Opcode::LAnd => (left.to_integer(int_size)? > 0) && (right.to_integer(int_size)? > 0), - Opcode::LOr => (left.to_integer(int_size)? > 0) || (right.to_integer(int_size)? > 0), + Opcode::LAnd => (left_read.to_integer(int_size)? > 0) && (right_obj.to_integer(int_size)? > 0), + Opcode::LOr => (left_read.to_integer(int_size)? > 0) || (right_obj.to_integer(int_size)? > 0), Opcode::LNotEqual => ordering?.is_ne(), Opcode::LLessEqual => ordering?.is_le(), Opcode::LGreaterEqual => ordering?.is_ge(), @@ -1989,17 +2064,18 @@ where fn do_to_buffer(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(operand), Argument::Object(target)]); let operand = operand.clone().unwrap_transparent_reference(); + let operand_read = operand.read(); - let result = match *operand { - Object::Buffer(ref bytes) => Object::Buffer(bytes.clone()), + let result = match &*operand_read { + Object::Buffer(bytes) => Object::Buffer(bytes.clone()), Object::Integer(value) => { if self.integer_size == IntegerSize::EightBytes { Object::Buffer(value.to_le_bytes().to_vec()) } else { - Object::Buffer((value as u32).to_le_bytes().to_vec()) + Object::Buffer((*value as u32).to_le_bytes().to_vec()) } } - Object::String(ref value) => { + Object::String(value) => { // XXX: an empty string is converted to an empty buffer, *without* the null-terminator if value.is_empty() { Object::Buffer(vec![]) @@ -2009,7 +2085,12 @@ where Object::Buffer(bytes) } } - _ => Err(AmlError::InvalidOperationOnObject { op: Operation::ToBuffer, typ: operand.typ() })?, + _ => { + return Err(AmlError::InvalidOperationOnObject { + op: Operation::ToBuffer, + typ: operand_read.typ(), + }); + } } .wrap(); @@ -2023,7 +2104,7 @@ where extract_args!(op => [Argument::Object(operand), Argument::Object(target)]); let operand = operand.clone().unwrap_transparent_reference(); - let result = Object::Integer(operand.to_integer(self.integer_size)?).wrap(); + let result = Object::Integer(operand.read().to_integer(self.integer_size)?).wrap(); let result = self.do_store(target.clone(), result)?; context.contribute_arg(Argument::Object(result)); context.retire_op(op); @@ -2033,8 +2114,9 @@ where fn do_to_string(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(source), Argument::Object(length), Argument::Object(target)]); let source = source.clone().unwrap_transparent_reference(); - let source = source.as_buffer()?; - let length = length.clone().unwrap_transparent_reference().as_integer()? as usize; + let source_read = source.read(); + let source = source_read.as_buffer()?; + let length = length.clone().unwrap_transparent_reference().read().as_integer()? as usize; let result = if source.is_empty() { Object::String(String::new()) @@ -2061,15 +2143,16 @@ where fn do_to_dec_hex_string(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(operand), Argument::Object(target)]); let operand = operand.clone().unwrap_transparent_reference(); + let operand_read = operand.read(); - let result = match *operand { - Object::String(ref value) => Object::String(value.clone()), + let result = match &*operand_read { + Object::String(value) => Object::String(value.clone()), Object::Integer(value) => match op.op { Opcode::ToDecimalString => Object::String(value.to_string()), Opcode::ToHexString => Object::String(alloc::format!("{value:#X}")), _ => panic!(), }, - Object::Buffer(ref bytes) => { + Object::Buffer(bytes) => { if bytes.is_empty() { Object::String(String::new()) } else { @@ -2089,7 +2172,12 @@ where Object::String(string) } } - _ => Err(AmlError::InvalidOperationOnObject { op: Operation::ToDecOrHexString, typ: operand.typ() })?, + _ => { + return Err(AmlError::InvalidOperationOnObject { + op: Operation::ToDecOrHexString, + typ: operand_read.typ(), + }); + } } .wrap(); @@ -2101,29 +2189,35 @@ where fn do_mid(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(source), Argument::Object(index), Argument::Object(length), Argument::Object(target)]); - let index = index.clone().unwrap_transparent_reference().as_integer()? as usize; - let length = length.clone().unwrap_transparent_reference().as_integer()? as usize; - - let result = match **source { - Object::String(ref string) => { - if index >= string.len() { - Object::String(String::new()) - } else { - let upper = usize::min(index + length, index + string.len()); - let chars = &string[index..upper]; - Object::String(String::from(chars)) + let index = index.clone().unwrap_transparent_reference().read().as_integer()? as usize; + let length = length.clone().unwrap_transparent_reference().read().as_integer()? as usize; + + let result = { + let source_unwrapped = source.clone().unwrap_transparent_reference(); + let source_read = source_unwrapped.read(); + match &*source_read { + Object::String(string) => { + if index >= string.len() { + Object::String(String::new()) + } else { + let upper = usize::min(index + length, index + string.len()); + let chars = &string[index..upper]; + Object::String(String::from(chars)) + } } - } - Object::Buffer(ref buffer) => { - if index >= buffer.len() { - Object::Buffer(vec![]) - } else { - let upper = usize::min(index + length, index + buffer.len()); - let bytes = &buffer[index..upper]; - Object::Buffer(bytes.to_vec()) + Object::Buffer(buffer) => { + if index >= buffer.len() { + Object::Buffer(vec![]) + } else { + let upper = usize::min(index + length, index + buffer.len()); + let bytes = &buffer[index..upper]; + Object::Buffer(bytes.to_vec()) + } + } + _ => { + return Err(AmlError::InvalidOperationOnObject { op: Operation::Mid, typ: source_read.typ() }); } } - _ => Err(AmlError::InvalidOperationOnObject { op: Operation::Mid, typ: source.typ() })?, } .wrap(); @@ -2149,7 +2243,11 @@ where Object::Integer(value) => value.to_string(), Object::Method { .. } | Object::NativeMethod { .. } => "[Control Method]".to_string(), Object::Mutex { .. } => "[Mutex]".to_string(), - Object::Reference { inner, .. } => resolve_as_string(&(inner.clone().unwrap_reference())), + Object::Reference { inner, .. } => { + let unwrapped = inner.clone().unwrap_reference(); + let r = unwrapped.read(); + resolve_as_string(&*r) + } Object::OpRegion(_) => "[Operation Region]".to_string(), Object::Package(_) => "[Package]".to_string(), Object::PowerResource { .. } => "[Power Resource]".to_string(), @@ -2161,29 +2259,31 @@ where } } - let result = match source1.typ() { + let result = match source1.read().typ() { ObjectType::Integer => { - let source1 = source1.as_integer()?; - let source2 = source2.to_integer(self.integer_size)?; + let s1 = source1.read().as_integer()?; + let s2 = source2.read().to_integer(self.integer_size)?; let mut buffer = Vec::new(); if self.integer_size == IntegerSize::EightBytes { - buffer.extend_from_slice(&source1.to_le_bytes()); - buffer.extend_from_slice(&source2.to_le_bytes()); + buffer.extend_from_slice(&s1.to_le_bytes()); + buffer.extend_from_slice(&s2.to_le_bytes()); } else { - buffer.extend_from_slice(&(source1 as u32).to_le_bytes()); - buffer.extend_from_slice(&(source2 as u32).to_le_bytes()); + buffer.extend_from_slice(&(s1 as u32).to_le_bytes()); + buffer.extend_from_slice(&(s2 as u32).to_le_bytes()); } Object::Buffer(buffer).wrap() } ObjectType::Buffer => { - let mut buffer = source1.as_buffer()?.to_vec(); - buffer.extend(source2.to_buffer(self.integer_size)?); + let mut buffer = source1.read().as_buffer()?.to_vec(); + buffer.extend(source2.read().to_buffer(self.integer_size)?); Object::Buffer(buffer).wrap() } _ => { - let source1 = resolve_as_string(&source1); - let source2 = resolve_as_string(&source2); - Object::String(source1 + &source2).wrap() + let s1_read = source1.read(); + let s2_read = source2.read(); + let s1 = resolve_as_string(&*s1_read); + let s2 = resolve_as_string(&*s2_read); + Object::String(s1 + &s2).wrap() } }; @@ -2195,7 +2295,7 @@ where fn do_from_bcd(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(value)]); - let mut value = value.clone().unwrap_transparent_reference().as_integer()?; + let mut value = value.clone().unwrap_transparent_reference().read().as_integer()?; let mut result = 0; let mut i = 1; @@ -2212,7 +2312,7 @@ where fn do_to_bcd(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(value)]); - let mut value = value.clone().unwrap_transparent_reference().as_integer()?; + let mut value = value.clone().unwrap_transparent_reference().read().as_integer()?; let mut result = 0; let mut i = 0; @@ -2230,12 +2330,13 @@ where fn do_size_of(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(object)]); let object = object.clone().unwrap_reference(); + let obj_read = object.read(); - let result = match *object { - Object::Buffer(ref buffer) => buffer.len(), - Object::String(ref str) => str.len(), - Object::Package(ref package) => package.len(), - _ => Err(AmlError::InvalidOperationOnObject { op: Operation::SizeOf, typ: object.typ() })?, + let result = match &*obj_read { + Object::Buffer(buffer) => buffer.len(), + Object::String(str) => str.len(), + Object::Package(package) => package.len(), + _ => Err(AmlError::InvalidOperationOnObject { op: Operation::SizeOf, typ: obj_read.typ() })?, }; context.contribute_arg(Argument::Object(Object::Integer(result as u64).wrap())); @@ -2246,10 +2347,11 @@ where fn do_index(&self, context: &mut MethodContext, op: OpInFlight) -> Result<(), AmlError> { extract_args!(op => [Argument::Object(object), Argument::Object(index_value), Argument::Object(target)]); let object = object.clone().unwrap_transparent_reference(); - let index_value = index_value.clone().unwrap_transparent_reference().as_integer()?; + let index_value = index_value.clone().unwrap_transparent_reference().read().as_integer()?; + let obj_read = object.read(); - let result = match *object { - Object::Buffer(ref buffer) => { + let result = match &*obj_read { + Object::Buffer(buffer) => { if index_value as usize >= buffer.len() { Err(AmlError::IndexOutOfBounds)? } @@ -2264,7 +2366,7 @@ where .wrap(), } } - Object::String(ref string) => { + Object::String(string) => { if index_value as usize >= string.len() { Err(AmlError::IndexOutOfBounds)? } @@ -2279,7 +2381,7 @@ where .wrap(), } } - Object::Package(ref package) => { + Object::Package(package) => { let Some(element) = package.get(index_value as usize) else { Err(AmlError::IndexOutOfBounds)? }; Object::Reference { kind: ReferenceKind::RefOf, inner: element.clone() } } @@ -2303,23 +2405,28 @@ where /// - Named objects are stored into, with implicit casting fn do_store(&self, target: WrappedObject, object: WrappedObject) -> Result { let object = object.unwrap_transparent_reference(); - let token = self.object_token.lock(); - match unsafe { target.gain_mut(&token) } { + // Extract reference kind and inner from target, releasing the read lock before mutating. + let target_read = target.read(); + match &*target_read { Object::Reference { kind, inner } => { let (target_object, overwrite) = match kind { ReferenceKind::Named => (inner.clone().unwrap_reference(), false), ReferenceKind::Local | ReferenceKind::Index => { - if let Object::Reference { kind: _, inner: ref inner_inner } = **inner { + let inner_read = inner.read(); + if let Object::Reference { inner: inner_inner, .. } = &*inner_read { (inner_inner.clone(), false) } else { + drop(inner_read); (inner.clone().unwrap_transparent_reference(), true) } } ReferenceKind::Arg => { - if let Object::Reference { kind: _, inner: ref inner_inner } = **inner { + let inner_read = inner.read(); + if let Object::Reference { inner: inner_inner, .. } = &*inner_read { (inner_inner.clone(), true) } else { + drop(inner_read); (inner.clone().unwrap_transparent_reference(), true) } } @@ -2327,28 +2434,30 @@ where return Err(AmlError::StoreToInvalidReferenceType); } }; + drop(target_read); if overwrite { - unsafe { - *target_object.gain_mut(&token) = (*object).clone(); - } + *target_object.write() = object.read().clone(); } else { - match &*target_object { + let target_read2 = target_object.read(); + match &*target_read2 { Object::Integer(_) | Object::String(_) | Object::Buffer(_) => { - let target_object = unsafe { target_object.gain_mut(&token) }; - target_object.replace_with_implicit_casting((*object).clone())?; + let new_val = object.read().clone(); + drop(target_read2); + target_object.write().replace_with_implicit_casting(new_val)?; } Object::BufferField { .. } => { - let target_object = unsafe { target_object.gain_mut(&token) }; - match unsafe { object.gain_mut(&token) } { - Object::Integer(value) => { - target_object.write_buffer_field(&value.to_le_bytes(), &token)? - } - Object::Buffer(value) => { - target_object.write_buffer_field(value.as_slice(), &token)? + // Read the value to write from `object` before acquiring write lock. + let write_data: alloc::vec::Vec = { + let obj_read = object.read(); + match &*obj_read { + Object::Integer(value) => value.to_le_bytes().to_vec(), + Object::Buffer(value) => value.clone(), + _ => panic!(), } - _ => panic!(), - } + }; + drop(target_read2); + target_object.write().write_buffer_field(&write_data)?; } /* @@ -2357,20 +2466,28 @@ where * write-then-read pattern. We should perform a read in those cases and * return that instead. */ - Object::FieldUnit(field_unit) => self.do_field_write(field_unit, object.clone())?, + Object::FieldUnit(field_unit) => { + let field_unit = field_unit.clone(); + drop(target_read2); + self.do_field_write(&field_unit, object.clone())?; + } _ => { return Err(AmlError::InvalidOperationOnObject { op: Operation::Store, - typ: target_object.typ(), + typ: target_read2.typ(), }); } } } } - Object::Debug => self.handler.handle_debug(&object), + Object::Debug => { + drop(target_read); + let object_read = object.read(); + self.handler.handle_debug(&*object_read); + } Object::Integer(0) => {} // Store to NullName - _ => return Err(AmlError::InvalidOperationOnObject { op: Operation::Store, typ: target.typ() }), + _ => return Err(AmlError::InvalidOperationOnObject { op: Operation::Store, typ: target_read.typ() }), } Ok(object) @@ -2384,16 +2501,21 @@ where /// - Index references cause the object at the index to be overwritten /// - Other reference operations are not allowed fn do_copy_object(&self, target: WrappedObject, object: WrappedObject) -> Result<(), AmlError> { - let Object::Reference { kind, ref inner } = *target else { - return Err(AmlError::InternalError("Target of CopyObject must be a reference".to_string())); - }; let object = object.clone().unwrap_transparent_reference(); - let token = self.object_token.lock(); + + let (kind, inner) = { + let target_read = target.read(); + let Object::Reference { kind, inner } = &*target_read else { + return Err(AmlError::InternalError("Target of CopyObject must be a reference".to_string())); + }; + (*kind, inner.clone()) + }; let dst = match kind { ReferenceKind::Named | ReferenceKind::Local => inner.clone().unwrap_transparent_reference(), ReferenceKind::Arg => { - if let Object::Reference { kind: _, inner: ref inner_inner } = **inner { + let inner_read = inner.read(); + if let Object::Reference { inner: inner_inner, .. } = &*inner_read { inner_inner.clone() } else { inner.clone().unwrap_transparent_reference() @@ -2403,10 +2525,7 @@ where ReferenceKind::RefOf | ReferenceKind::Unresolved => return Err(AmlError::StoreToInvalidReferenceType), }; - unsafe { - *dst.gain_mut(&token) = (*object).clone(); - } - + *dst.write() = object.read().clone(); Ok(()) } @@ -2437,21 +2556,35 @@ where }; let (read_region, index_field_idx) = match field.kind { - FieldUnitKind::Normal { ref region } => (region, 0), + FieldUnitKind::Normal { ref region } => { + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), 0) + } FieldUnitKind::Bank { ref region, ref bank, bank_value } => { - let Object::FieldUnit(ref bank) = **bank else { panic!() }; - assert!(matches!(bank.kind, FieldUnitKind::Normal { .. })); - self.do_field_write(bank, Object::Integer(bank_value).wrap())?; - (region, 0) + let bank_fu = { + let bank_read = bank.read(); + let Object::FieldUnit(ref fu) = *bank_read else { panic!() }; + assert!(matches!(fu.kind, FieldUnitKind::Normal { .. })); + fu.clone() + }; + self.do_field_write(&bank_fu, Object::Integer(bank_value).wrap())?; + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), 0) } FieldUnitKind::Index { index: _, ref data } => { - let Object::FieldUnit(ref data) = **data else { panic!() }; - let FieldUnitKind::Normal { region } = &data.kind else { panic!() }; - let reg_idx = field.bit_index / 8; - (region, reg_idx) + let (region_clone, reg_idx) = { + let data_read = data.read(); + let Object::FieldUnit(ref data_fu) = *data_read else { panic!() }; + let FieldUnitKind::Normal { ref region } = data_fu.kind else { panic!() }; + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), field.bit_index / 8) + }; + (region_clone, reg_idx) } }; - let Object::OpRegion(ref read_region) = **read_region else { panic!() }; /* * TODO: it might be worth having a fast path here for reads that don't do weird @@ -2476,20 +2609,25 @@ where } FieldUnitKind::Index { ref index, ref data } => { // Update index register - let Object::FieldUnit(ref index) = **index else { panic!() }; - let Object::FieldUnit(ref data) = **data else { panic!() }; + let (index_fu, data_bit_index) = { + let index_read = index.read(); + let Object::FieldUnit(ref fu) = *index_read else { panic!() }; + let data_read = data.read(); + let Object::FieldUnit(ref data_fu) = *data_read else { panic!() }; + (fu.clone(), data_fu.bit_index) + }; self.do_field_write( - index, + &index_fu, Object::Integer((index_field_idx + i * (access_width_bits / 8)) as u64).wrap(), )?; // The offset is always that of the data register, as we always read from the // base of the data register. - data.bit_index + data_bit_index } }; - let raw = self.do_native_region_read(read_region, aligned_offset / 8, access_width_bits / 8)?; + let raw = self.do_native_region_read(&read_region, aligned_offset / 8, access_width_bits / 8)?; let src_index = if i == 0 { field.bit_index % access_width_bits } else { 0 }; let remaining_length = field.bit_length - read_so_far; let length = if i == 0 { @@ -2511,11 +2649,19 @@ where fn do_field_write(&self, field: &FieldUnit, value: WrappedObject) -> Result<(), AmlError> { trace!("AML field write. Field = {:?}. Value = {}", field, value); - let value_bytes = match &*value { - Object::Integer(value) => &value.to_le_bytes() as &[u8], + let value_read = value.read(); + let value_bytes: &[u8] = match &*value_read { + Object::Integer(v) => &v.to_le_bytes() as &[u8], Object::Buffer(bytes) => bytes, - _ => Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::Integer, got: value.typ() })?, + _ => { + return Err(AmlError::ObjectNotOfExpectedType { + expected: ObjectType::Integer, + got: value_read.typ(), + }); + } }; + let value_bytes: alloc::vec::Vec = value_bytes.to_vec(); + let access_width_bits = field.flags.access_type_bytes()? * 8; // In this tuple: @@ -2523,21 +2669,35 @@ where // - index_field_idx is the initial index to write into the Index register of an index // field. For all other field types it is unused and set to zero. let (write_region, index_field_idx) = match field.kind { - FieldUnitKind::Normal { ref region } => (region, 0), + FieldUnitKind::Normal { ref region } => { + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), 0) + } FieldUnitKind::Bank { ref region, ref bank, bank_value } => { - let Object::FieldUnit(ref bank) = **bank else { panic!() }; - assert!(matches!(bank.kind, FieldUnitKind::Normal { .. })); - self.do_field_write(bank, Object::Integer(bank_value).wrap())?; - (region, 0) + let bank_fu = { + let bank_read = bank.read(); + let Object::FieldUnit(ref fu) = *bank_read else { panic!() }; + assert!(matches!(fu.kind, FieldUnitKind::Normal { .. })); + fu.clone() + }; + self.do_field_write(&bank_fu, Object::Integer(bank_value).wrap())?; + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), 0) } FieldUnitKind::Index { index: _, ref data } => { - let Object::FieldUnit(ref data) = **data else { panic!() }; - let FieldUnitKind::Normal { region: data_region } = &data.kind else { panic!() }; - let reg_idx = field.bit_index / 8; - (data_region, reg_idx) + let (region_clone, reg_idx) = { + let data_read = data.read(); + let Object::FieldUnit(ref data_fu) = *data_read else { panic!() }; + let FieldUnitKind::Normal { ref region } = data_fu.kind else { panic!() }; + let region_read = region.read(); + let Object::OpRegion(op_region) = &*region_read else { panic!() }; + (op_region.clone(), field.bit_index / 8) + }; + (region_clone, reg_idx) } }; - let Object::OpRegion(ref write_region) = **write_region else { panic!() }; // TODO: if the region wants locking, do that @@ -2557,16 +2717,21 @@ where } FieldUnitKind::Index { ref index, ref data } => { // Update index register - let Object::FieldUnit(ref index) = **index else { panic!() }; - let Object::FieldUnit(ref data) = **data else { panic!() }; + let (index_fu, data_bit_index) = { + let index_read = index.read(); + let Object::FieldUnit(ref fu) = *index_read else { panic!() }; + let data_read = data.read(); + let Object::FieldUnit(ref data_fu) = *data_read else { panic!() }; + (fu.clone(), data_fu.bit_index) + }; self.do_field_write( - index, + &index_fu, Object::Integer((index_field_idx + i * (access_width_bits / 8)) as u64).wrap(), )?; // The offset is always that of the data register, as we always read from the // base of the data register. - data.bit_index + data_bit_index } }; let dst_index = if i == 0 { field.bit_index % access_width_bits } else { 0 }; @@ -2579,7 +2744,7 @@ where let mut bytes = if dst_index > 0 || (field.bit_length - written_so_far) < access_width_bits { match field.flags.update_rule() { FieldUpdateRule::Preserve => self - .do_native_region_read(write_region, aligned_offset / 8, access_width_bits / 8)? + .do_native_region_read(&write_region, aligned_offset / 8, access_width_bits / 8)? .to_le_bytes(), FieldUpdateRule::WriteAsOnes => [0xff; 8], FieldUpdateRule::WriteAsZeros => [0; 8], @@ -2595,9 +2760,9 @@ where usize::min(remaining_length, access_width_bits) }; - object::copy_bits(value_bytes, written_so_far, &mut bytes, dst_index, length); + object::copy_bits(&value_bytes, written_so_far, &mut bytes, dst_index, length); self.do_native_region_write( - write_region, + &write_region, aligned_offset / 8, access_width_bits / 8, u64::from_le_bytes(bytes), @@ -2738,17 +2903,17 @@ where * cache them somewhere? */ let seg = match self.evaluate_if_present(AmlName::from_str("_SEG").unwrap().resolve(path)?, vec![])? { - Some(value) => value.as_integer()?, + Some(value) => value.read().as_integer()?, None => 0, }; let bus = match self.evaluate_if_present(AmlName::from_str("_BBN").unwrap().resolve(path)?, vec![])? { - Some(value) => value.as_integer()?, + Some(value) => value.read().as_integer()?, None => 0, }; let (device, function) = { let adr = self.evaluate_if_present(AmlName::from_str("_ADR").unwrap().resolve(path)?, vec![])?; let adr = match adr { - Some(adr) => adr.as_integer()?, + Some(adr) => adr.read().as_integer()?, None => 0, }; (adr.get_bits(16..32), adr.get_bits(0..16)) @@ -2928,12 +3093,13 @@ impl MethodContext { args: Vec, scope: AmlName, ) -> Result { - if let Object::Method { code, flags } = &*method { + let method_read = method.read(); + if let Object::Method { code, flags } = &*method_read { if args.len() != flags.arg_count() { return Err(AmlError::MethodArgCountIncorrect); } let block = Block { - stream: code as &[u8] as *const [u8], + stream: code.as_slice() as *const [u8], pc: 0, kind: BlockKind::Method { method_scope: scope.clone() }, }; @@ -2951,7 +3117,7 @@ impl MethodContext { }; Ok(context) } else { - Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::Method, got: method.typ() }) + Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::Method, got: method_read.typ() }) } } diff --git a/src/aml/namespace.rs b/src/aml/namespace.rs index c887bcac..a2761521 100644 --- a/src/aml/namespace.rs +++ b/src/aml/namespace.rs @@ -69,10 +69,11 @@ impl Namespace { if args.len() != 1 { return Err(AmlError::MethodArgCountIncorrect); } - let Object::String(ref feature) = *args[0] else { + let args0_read = args[0].read(); + let Object::String(ref feature) = *args0_read else { return Err(AmlError::ObjectNotOfExpectedType { expected: ObjectType::String, - got: args[0].typ(), + got: args0_read.typ(), }); }; @@ -383,7 +384,7 @@ impl fmt::Display for Namespace { if end { END } else { BRANCH }, name.as_str(), if flags.is_alias() { "[A] " } else { "" }, - **object + *object.read() )?; // If the object has a corresponding scope, print it here diff --git a/src/aml/object.rs b/src/aml/object.rs index 7d78447a..c42bcae4 100644 --- a/src/aml/object.rs +++ b/src/aml/object.rs @@ -6,9 +6,10 @@ use alloc::{ vec::Vec, }; use bit_field::BitField; -use core::{cell::UnsafeCell, cmp::Ordering, fmt, ops, sync::atomic::AtomicU64}; +use core::{cmp::Ordering, fmt, sync::atomic::AtomicU64}; +use spinning_top::{RwSpinlock, guard::{RwSpinlockReadGuard, RwSpinlockWriteGuard}}; -type NativeMethod = dyn Fn(&[WrappedObject]) -> Result; +type NativeMethod = dyn Fn(&[WrappedObject]) -> Result + Send + Sync; #[derive(Clone)] pub enum Object { @@ -36,7 +37,7 @@ pub enum Object { impl Object { pub fn native_method(num_args: u8, f: F) -> Object where - F: Fn(&[WrappedObject]) -> Result + 'static, + F: Fn(&[WrappedObject]) -> Result + Send + Sync + 'static, { let mut flags = 0; flags.set_bits(0..3, num_args); @@ -61,15 +62,15 @@ impl fmt::Display for Object { Object::Method { .. } => write!(f, "Method"), Object::NativeMethod { .. } => write!(f, "NativeMethod"), Object::Mutex { .. } => write!(f, "Mutex"), - Object::Reference { kind, inner } => write!(f, "Reference({:?} -> {})", kind, **inner), + Object::Reference { kind, inner } => write!(f, "Reference({:?} -> {})", kind, *inner.read()), Object::OpRegion(region) => write!(f, "{region:?}"), Object::Package(elements) => { write!(f, "Package {{ ")?; for (i, element) in elements.iter().enumerate() { if i == elements.len() - 1 { - write!(f, "{}", **element)?; + write!(f, "{}", *element.read())?; } else { - write!(f, "{}, ", **element)?; + write!(f, "{}, ", *element.read())?; } } write!(f, " }}")?; @@ -87,53 +88,38 @@ impl fmt::Display for Object { } } -/// `ObjectToken` is used to mediate mutable access to objects from a [`WrappedObject`]. It must be -/// acquired by locking the single token provided by [`super::Interpreter`]. -#[non_exhaustive] -pub struct ObjectToken { - _dont_construct_me: (), -} +#[derive(Clone)] +pub struct WrappedObject(Arc>); -impl ObjectToken { - /// Create an [`ObjectToken`]. This should **only** be done **once** by the main interpreter, - /// as contructing your own token allows invalid mutable access to objects. - pub(super) unsafe fn create_interpreter_token() -> ObjectToken { - ObjectToken { _dont_construct_me: () } +impl fmt::Debug for WrappedObject { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "WrappedObject({})", *self.0.read()) } } -#[derive(Clone, Debug)] -pub struct WrappedObject(Arc>); - impl WrappedObject { pub fn new(object: Object) -> WrappedObject { - #[allow(clippy::arc_with_non_send_sync)] - WrappedObject(Arc::new(UnsafeCell::new(object))) + WrappedObject(Arc::new(RwSpinlock::new(object))) } - /// Gain a mutable reference to an [`Object`] from this [`WrappedObject`]. - /// - /// # Safety - /// This requires an [`ObjectToken`] which is protected by a lock on [`super::Interpreter`], - /// which prevents mutable access to objects from multiple contexts. It does not, however, - /// prevent the same object, referenced from multiple [`WrappedObject`]s, having multiple - /// mutable (and therefore aliasing) references being made to it, and therefore care must be - /// taken in the interpreter to prevent this. - pub unsafe fn gain_mut<'r, 'a, 't>(&'a self, _token: &'t ObjectToken) -> &'r mut Object - where - 't: 'r, - 'a: 'r, - { - unsafe { &mut *(self.0.get()) } + pub fn read(&self) -> RwSpinlockReadGuard<'_, Object> { + self.0.read() + } + + pub fn write(&self) -> RwSpinlockWriteGuard<'_, Object> { + self.0.write() } pub fn unwrap_reference(self) -> WrappedObject { let mut object = self; loop { - if let Object::Reference { ref inner, .. } = *object { - object = inner.clone(); - } else { - return object.clone(); + let inner = { + let read = object.read(); + if let Object::Reference { inner, .. } = &*read { Some(inner.clone()) } else { None } + }; + match inner { + Some(inner) => object = inner, + None => return object, } } } @@ -143,34 +129,27 @@ impl WrappedObject { pub fn unwrap_transparent_reference(self) -> WrappedObject { let mut object = self; loop { - if let Object::Reference { kind, ref inner } = *object - && (kind == ReferenceKind::Local || kind == ReferenceKind::Arg || kind == ReferenceKind::Named) - { - object = inner.clone(); - } else { - return object.clone(); + let inner = { + let read = object.read(); + if let Object::Reference { kind, inner } = &*read + && (*kind == ReferenceKind::Local || *kind == ReferenceKind::Arg || *kind == ReferenceKind::Named) + { + Some(inner.clone()) + } else { + None + } + }; + match inner { + Some(inner) => object = inner, + None => return object, } } } } -impl ops::Deref for WrappedObject { - type Target = Object; - - fn deref(&self) -> &Self::Target { - /* - * SAFETY: elided lifetime ensures reference cannot outlive at least one reference-counted - * instance of the object. `WrappedObject::gain_mut` is unsafe, and so it is the user's - * responsibility to ensure shared references from `Deref` do not co-exist with an - * exclusive reference. - */ - unsafe { &*self.0.get() } - } -} - impl fmt::Display for WrappedObject { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Wrapped({})", unsafe { &*self.0.get() }) + write!(f, "Wrapped({})", *self.0.read()) } } @@ -261,9 +240,10 @@ impl Object { pub fn read_buffer_field(&self, integer_size: IntegerSize) -> Result { if let Self::BufferField { buffer, offset, length } = self { - let buffer = match **buffer { - Object::Buffer(ref buffer) => buffer.as_slice(), - Object::String(ref string) => string.as_bytes(), + let buffer_read = buffer.read(); + let buffer = match &*buffer_read { + Object::Buffer(buffer) => buffer.as_slice(), + Object::String(string) => string.as_bytes(), _ => panic!(), }; if *length <= integer_size as usize { @@ -280,10 +260,11 @@ impl Object { } } - pub fn write_buffer_field(&mut self, value: &[u8], token: &ObjectToken) -> Result<(), AmlError> { + pub fn write_buffer_field(&mut self, value: &[u8]) -> Result<(), AmlError> { // TODO: bounds check the buffer first to avoid panicking if let Self::BufferField { buffer, offset, length } = self { - let buffer = match unsafe { buffer.gain_mut(token) } { + let mut buffer_write = buffer.write(); + let buffer = match &mut *buffer_write { Object::Buffer(buffer) => buffer.as_mut_slice(), // XXX: this unfortunately requires us to trust AML to keep the string as valid // UTF8... maybe there is a better way? diff --git a/src/aml/pci_routing.rs b/src/aml/pci_routing.rs index 475ee9d6..c8c99ac7 100644 --- a/src/aml/pci_routing.rs +++ b/src/aml/pci_routing.rs @@ -66,9 +66,11 @@ impl PciRoutingTable { let prt = interpreter.evaluate(prt_path.clone(), vec![])?; - if let Object::Package(ref inner_values) = *prt { + let prt_read = prt.read(); + if let Object::Package(inner_values) = &*prt_read { for value in inner_values { - if let Object::Package(ref pin_package) = **value { + let value_read = value.read(); + if let Object::Package(pin_package) = &*value_read { /* * Each inner package has the following structure: * | Field | Type | Description | @@ -92,12 +94,13 @@ impl PciRoutingTable { * | | | pin is connected. | * | -----------|-----------|-----------------------------------------------------------| */ - let Object::Integer(address) = *pin_package[0] else { - return Err(AmlError::PrtInvalidAddress); + let address = match &*pin_package[0].read() { + Object::Integer(addr) => *addr, + _ => return Err(AmlError::PrtInvalidAddress), }; let device = address.get_bits(16..32).try_into().map_err(|_| AmlError::PrtInvalidAddress)?; let function = address.get_bits(0..16).try_into().map_err(|_| AmlError::PrtInvalidAddress)?; - let pin = match *pin_package[1] { + let pin = match &*pin_package[1].read() { Object::Integer(0) => Pin::IntA, Object::Integer(1) => Pin::IntB, Object::Integer(2) => Pin::IntC, @@ -105,14 +108,16 @@ impl PciRoutingTable { _ => return Err(AmlError::PrtInvalidPin), }; - match *pin_package[2] { + let p2_read = pin_package[2].read(); + match &*p2_read { Object::Integer(0) => { /* * The Source Index field contains the GSI number that this interrupt is attached * to. */ - let Object::Integer(gsi) = *pin_package[3] else { - return Err(AmlError::PrtInvalidGsi); + let gsi = match &*pin_package[3].read() { + Object::Integer(gsi) => *gsi, + _ => return Err(AmlError::PrtInvalidGsi), }; entries.push(PciRoute { device, @@ -121,11 +126,13 @@ impl PciRoutingTable { route_type: PciRouteType::Gsi(gsi as u32), }); } - Object::String(ref name) => { + Object::String(name) => { + let name = name.clone(); + drop(p2_read); let link_object_name = interpreter .namespace .lock() - .search_for_level(&AmlName::from_str(name)?, &prt_path)?; + .search_for_level(&AmlName::from_str(&name)?, &prt_path)?; entries.push(PciRoute { device, function, @@ -136,13 +143,16 @@ impl PciRoutingTable { _ => return Err(AmlError::PrtInvalidSource), } } else { - return Err(AmlError::InvalidOperationOnObject { op: Operation::DecodePrt, typ: value.typ() }); + return Err(AmlError::InvalidOperationOnObject { + op: Operation::DecodePrt, + typ: value_read.typ(), + }); } } Ok(PciRoutingTable { entries }) } else { - Err(AmlError::InvalidOperationOnObject { op: Operation::DecodePrt, typ: prt.typ() }) + Err(AmlError::InvalidOperationOnObject { op: Operation::DecodePrt, typ: prt_read.typ() }) } } diff --git a/src/aml/resource.rs b/src/aml/resource.rs index 5a85651e..03fdb116 100644 --- a/src/aml/resource.rs +++ b/src/aml/resource.rs @@ -16,7 +16,8 @@ pub enum Resource { /// Parse a `ResourceDescriptor` buffer into a list of resources. pub fn resource_descriptor_list(descriptor: WrappedObject) -> Result, AmlError> { - if let Object::Buffer(ref bytes) = *descriptor { + let desc_read = descriptor.read(); + if let Object::Buffer(ref bytes) = *desc_read { let mut descriptors = Vec::new(); let mut bytes = bytes.as_slice(); @@ -33,7 +34,7 @@ pub fn resource_descriptor_list(descriptor: WrappedObject) -> Result: Send, Sync); +assert_impl_all!(WrappedObject: Send, Sync); diff --git a/tools/aml_test_tools/src/lib.rs b/tools/aml_test_tools/src/lib.rs index 0407a9cd..0b8b4d45 100644 --- a/tools/aml_test_tools/src/lib.rs +++ b/tools/aml_test_tools/src/lib.rs @@ -448,10 +448,11 @@ where if let Some(result) = interpreter.evaluate_if_present(AmlName::from_str("\\MAIN").unwrap(), vec![])? { let expected_result = expected_result.as_ref().unwrap_or(&ExpectedResult::Integer(0)); - if result_matches(expected_result, &result) { + let result_read = result.read(); + if result_matches(expected_result, &*result_read) { Ok(()) } else { - let e = format!("Unexpected MAIN result: {}, expected: {:?}", *result, expected_result); + let e = format!("Unexpected MAIN result: {}, expected: {:?}", *result_read, expected_result); error!("{}", e); Err(AmlError::HostError(e)) }